From 3113dcc0f76e4f2eaa30a010a3eb234c37a68826 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 24 Dec 2025 03:10:23 +0100 Subject: [PATCH 001/244] Add best-match tightening for local head bounds --- Main.lean | 1 + Nfp/Sound/HeadCert.lean | 4 +- Nfp/Sound/IO.lean | 106 +++++++++++++++++++++++++++----- Nfp/Untrusted/SoundCompute.lean | 4 +- SOUNDNESS_LIMITATIONS.md | 9 +-- 5 files changed, 104 insertions(+), 20 deletions(-) diff --git a/Main.lean b/Main.lean index 8806542..278b3f6 100644 --- a/Main.lean +++ b/Main.lean @@ -975,6 +975,7 @@ private def formatHeadBoundsLocal s!"ln1Bound={h.ln1Bound}, " ++ s!"wqOp={h.wqOpBound}, wkOp={h.wkOpBound}, " ++ s!"qk={h.qkFactorBound}, " ++ + s!"softmaxJacobianNormInfUB={h.softmaxJacobianNormInfUpperBound}, " ++ s!"wvOp={h.wvOpBound}, woOp={h.woOpBound}, " ++ s!"attn={h.attnJacBound}\n") "" header ++ body diff --git a/Nfp/Sound/HeadCert.lean b/Nfp/Sound/HeadCert.lean index c0b1c6c..9a042aa 100644 --- a/Nfp/Sound/HeadCert.lean +++ b/Nfp/Sound/HeadCert.lean @@ -60,6 +60,8 @@ structure HeadLocalContributionCert where wvOpBound : Rat woOpBound : Rat qkFactorBound : Rat + /-- Upper bound on the softmax Jacobian row-sum norm for this head. -/ + softmaxJacobianNormInfUpperBound : Rat /-- Upper bound on the per-head attention Jacobian contribution. 
-/ attnJacBound : Rat deriving Repr @@ -76,7 +78,7 @@ def Valid (eps : Rat) (c : HeadLocalContributionCert) : Prop := layerNormOpBoundConservative c.ln1MaxAbsGamma eps c.soundnessBits) ∧ c.qkFactorBound = c.wqOpBound * c.wkOpBound ∧ c.attnJacBound = - c.ln1Bound * softmaxJacobianNormInfWorst * c.wvOpBound * c.woOpBound + c.ln1Bound * c.softmaxJacobianNormInfUpperBound * c.wvOpBound * c.woOpBound instance (eps : Rat) (c : HeadLocalContributionCert) : Decidable (Valid eps c) := by unfold Valid diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 088c003..40abc74 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -88,6 +88,19 @@ private def checkSoftmaxMarginZero (cert : ModelCert) : Except String Unit := theorem checkSoftmaxMarginZero_spec_io : checkSoftmaxMarginZero = checkSoftmaxMarginZero := rfl +private def checkSoftmaxProbIntervalWorst (cert : ModelCert) : Except String Unit := + Id.run do + for idx in [:cert.layers.size] do + let layer := cert.layers[idx]! + if layer.softmaxProbLo ≠ 0 then + return .error s!"softmaxProbLo is unverified (layer {idx})" + if layer.softmaxProbHi ≠ 1 then + return .error s!"softmaxProbHi is unverified (layer {idx})" + return .ok () + +theorem checkSoftmaxProbIntervalWorst_spec_io : + checkSoftmaxProbIntervalWorst = checkSoftmaxProbIntervalWorst := rfl + private def recomputeAttnWeightBoundsBinary (path : System.FilePath) : IO (Except String AttnWeightBounds) := do let h ← IO.FS.Handle.mk path IO.FS.Mode.read @@ -271,16 +284,19 @@ def certifyModelFileGlobal if cert.geluDerivTarget ≠ geluTarget then return .error "model header gelu_kind mismatch" if cert.check then - match checkSoftmaxMarginZero cert with + match checkSoftmaxProbIntervalWorst cert with | .error e => return .error e | .ok _ => - match ← recomputeAttnWeightBounds path with - | .error e => - return .error s!"attnWeightBounds verification failed: {e}" - | .ok bounds => - match checkAttnWeightBounds cert bounds with + match checkSoftmaxMarginZero cert with | 
.error e => return .error e - | .ok _ => return .ok cert + | .ok _ => + match ← recomputeAttnWeightBounds path with + | .error e => + return .error s!"attnWeightBounds verification failed: {e}" + | .ok bounds => + match checkAttnWeightBounds cert bounds with + | .error e => return .error e + | .ok _ => return .ok cert return .error "sound certificate failed internal consistency checks" /-- Entry point for sound certification (global or local). -/ @@ -308,16 +324,19 @@ def certifyModelFile if cert.geluDerivTarget ≠ geluTarget then return .error "model header gelu_kind mismatch" if cert.check then - match checkSoftmaxMarginZero cert with + match checkSoftmaxProbIntervalWorst cert with | .error e => return .error e | .ok _ => - match ← recomputeAttnWeightBounds path with - | .error e => - return .error s!"attnWeightBounds verification failed: {e}" - | .ok bounds => - match checkAttnWeightBounds cert bounds with + match checkSoftmaxMarginZero cert with | .error e => return .error e - | .ok _ => return .ok cert + | .ok _ => + match ← recomputeAttnWeightBounds path with + | .error e => + return .error s!"attnWeightBounds verification failed: {e}" + | .ok bounds => + match checkAttnWeightBounds cert bounds with + | .error e => return .error e + | .ok _ => return .ok cert return .error "sound certificate failed internal consistency checks" /-- Compute per-head contribution bounds (global). -/ @@ -443,6 +462,65 @@ def certifyHeadPatternBestMatchLocalSweep return .ok certs return .error "head best-match sweep certificate failed internal checks" +/-- Compute local per-head attention contribution bounds tightened by + best-match pattern evidence. -/ +def certifyHeadBoundsLocalBestMatch + (path : System.FilePath) + (layerIdx headIdx : Nat) + (queryPos? : Option Nat := none) + (inputPath? 
: Option System.FilePath := none) + (inputDelta : Rat := 0) + (soundnessBits : Nat) + (targetOffset : Int := -1) + (maxSeqLen : Nat := 256) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (scalePow10 : Nat := 9) + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + IO (Except String HeadLocalContributionCert) := do + match ← readModelEps path with + | .error e => return .error e + | .ok eps => + match ← + certifyHeadBoundsLocal path + (inputPath? := inputPath?) (inputDelta := inputDelta) + (soundnessBits := soundnessBits) (scalePow10 := scalePow10) with + | .error e => return .error e + | .ok certs => + let base? := + certs.find? (fun c => c.layerIdx == layerIdx && c.headIdx == headIdx) + match base? with + | none => + return .error s!"no local head contribution cert for layer {layerIdx} head {headIdx}" + | some base => + match ← + certifyHeadPatternBestMatchLocal path layerIdx headIdx + (queryPos? := queryPos?) (inputPath? := inputPath?) 
+ (inputDelta := inputDelta) (soundnessBits := soundnessBits) + (targetOffset := targetOffset) (maxSeqLen := maxSeqLen) + (tightPattern := tightPattern) (tightPatternLayers := tightPatternLayers) + (perRowPatternLayers := perRowPatternLayers) + (softmaxExpEffort := softmaxExpEffort) with + | .error e => return .error e + | .ok pattern => + if pattern.layerIdx ≠ layerIdx || pattern.headIdx ≠ headIdx then + return .error "best-match pattern cert layer/head mismatch" + if pattern.softmaxExpEffort ≠ softmaxExpEffort then + return .error "best-match pattern cert softmax effort mismatch" + let softmaxBound := pattern.softmaxJacobianNormInfUpperBound + if softmaxBound > base.softmaxJacobianNormInfUpperBound then + return .error "best-match softmax bound is worse than baseline" + let attnJacBound := + base.ln1Bound * softmaxBound * base.wvOpBound * base.woOpBound + let tightened := + { base with + softmaxJacobianNormInfUpperBound := softmaxBound + attnJacBound := attnJacBound } + if tightened.check eps then + return .ok tightened + return .error "tightened head contribution certificate failed internal checks" + /-- Compute local head output lower bounds. 
-/ def certifyHeadValueLowerBoundLocal (path : System.FilePath) diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 5562929..9b5371f 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -2800,7 +2800,8 @@ private def certifyHeadBoundsLocalBinary ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.headDim hdr.modelDim vHidden scalePow10) attnUnion := addVecFixed attnUnion vOut - let attnW := ln1Bound * softmaxJacobianNormInfWorst * vCenteredOpBound * nWo + let softmaxJacobianBound := softmaxJacobianNormInfWorst + let attnW := ln1Bound * softmaxJacobianBound * vCenteredOpBound * nWo let cert : HeadLocalContributionCert := { layerIdx := l headIdx := hIdx @@ -2813,6 +2814,7 @@ private def certifyHeadBoundsLocalBinary wvOpBound := vCenteredOpBound woOpBound := nWo qkFactorBound := wqOp * wkOp + softmaxJacobianNormInfUpperBound := softmaxJacobianBound attnJacBound := attnW } if cert.check eps then diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 260ad98..4997f2c 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -10,10 +10,11 @@ soundness upgrade. It is intentionally brief and human-readable. discharge those assumptions from model weights. - `partitionDepth > 0` is rejected with an explicit error (no partitioning logic yet). - Affine arithmetic is only a scaffold (`Nfp/Sound/Affine.lean`) and not wired into SOUND certification. -- Softmax Jacobian bounds typically use probability intervals defaulted to `[0,1]`, so they - reduce to the worst case. Margin-derived tightening is computed by the untrusted path, but - trusted IO currently **rejects nonzero** `softmaxMarginLowerBound` because margin evidence is - unverified. +- Softmax Jacobian bounds are enforced to use the worst-case probability interval `[0,1]` in + trusted IO. 
Margin-derived tightening is computed by the untrusted path, but trusted IO + currently **rejects nonzero** `softmaxMarginLowerBound` because margin evidence is unverified. +- Local per-head contribution bounds can now be tightened using a best-match pattern certificate, + but this tightening does **not** propagate to layer-level ModelCert bounds. - Best-match pattern certificates now use a margin-derived softmax Jacobian bound with an effort-indexed `expLB` (scaled Taylor + squaring). The lower-bound correctness of `expLB` is not yet formalized in Lean. From 40dcd3b016357ade3ff87fa29a18bc5d88118303 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 24 Dec 2025 03:22:55 +0100 Subject: [PATCH 002/244] Add layer best-match margin certs --- Nfp/Sound/Cert.lean | 33 +++++++++++++ Nfp/Sound/HeadCert.lean | 83 +++++++++++++++++++++++++++++++++ Nfp/Sound/IO.lean | 69 +++++++++++++++++++++++++++ Nfp/Untrusted/SoundCompute.lean | 63 +++++++++++++++++++++++++ SOUNDNESS_LIMITATIONS.md | 3 ++ 5 files changed, 251 insertions(+) diff --git a/Nfp/Sound/Cert.lean b/Nfp/Sound/Cert.lean index aff9ca9..c197bca 100644 --- a/Nfp/Sound/Cert.lean +++ b/Nfp/Sound/Cert.lean @@ -107,6 +107,24 @@ theorem softmaxJacobianNormInfPortfolioBound_def (seqLen : Nat) (l : LayerAmplif #[softmaxJacobianNormInfBoundFromMargin seqLen l.softmaxMarginLowerBound l.softmaxExpEffort] := rfl +/-- Update margin evidence and recompute dependent softmax + residual bounds. 
-/ +def withSoftmaxMargin (seqLen : Nat) (marginLowerBound : Rat) (softmaxExpEffort : Nat) + (l : LayerAmplificationCert) : LayerAmplificationCert := + let l' := + { l with + softmaxMarginLowerBound := marginLowerBound + softmaxExpEffort := softmaxExpEffort } + let softmaxBound := softmaxJacobianNormInfPortfolioBound seqLen l' + let attnJacBound := + l'.ln1Bound * + ((seqLen : Rat) * l'.attnValueCoeff + softmaxBound * l'.attnPatternCoeff) + let mlpJacBound := l'.mlpJacBound + let C := attnJacBound + mlpJacBound + attnJacBound * mlpJacBound + { l' with + softmaxJacobianNormInfUpperBound := softmaxBound + attnJacBound := attnJacBound + C := C } + /-- Internal consistency checks for per-layer bounds. -/ def Valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) : Prop := @@ -268,6 +286,8 @@ theorem residual_bound_of_component_bounds_valid exact residual_bound_of_component_bounds (l := l) (A := A) (M := M) hC hA hM +theorem withSoftmaxMargin_spec : + withSoftmaxMargin = withSoftmaxMargin := rfl theorem Valid_spec : Valid = Valid := rfl theorem check_spec : check = check := rfl @@ -321,6 +341,18 @@ def check (c : ModelCert) : Bool := theorem check_iff (c : ModelCert) : c.check = true ↔ c.Valid := by simp [check, Valid] +/-- Replace a layer and recompute total amplification factor. -/ +def withUpdatedLayer (c : ModelCert) (layerIdx : Nat) (layer : LayerAmplificationCert) : + Option ModelCert := + if layer.layerIdx ≠ layerIdx then + none + else if layerIdx < c.layers.size then + let layers := c.layers.set! layerIdx layer + let total := layers.foldl (fun acc l => acc * (1 + l.C)) 1 + some { c with layers := layers, totalAmplificationFactor := total } + else + none + /-- Pretty printer. 
-/ def toString (c : ModelCert) : String := let header := @@ -360,6 +392,7 @@ instance : ToString ModelCert := ⟨toString⟩ theorem Valid_spec : Valid = Valid := rfl theorem check_spec : check = check := rfl +theorem withUpdatedLayer_spec : withUpdatedLayer = withUpdatedLayer := rfl theorem toString_spec : toString = toString := rfl end ModelCert diff --git a/Nfp/Sound/HeadCert.lean b/Nfp/Sound/HeadCert.lean index 9a042aa..975f6db 100644 --- a/Nfp/Sound/HeadCert.lean +++ b/Nfp/Sound/HeadCert.lean @@ -253,6 +253,79 @@ theorem check_iff (c : HeadBestMatchPatternCert) : c.check = true ↔ c.Valid := end HeadBestMatchPatternCert +/-! ## Layer-level best-match margin aggregation -/ + +/-- Index into a `(numHeads × seqLen)` margin array. -/ +def headQueryIndex (seqLen : Nat) (headIdx queryPos : Nat) : Nat := + headIdx * seqLen + queryPos + +/-- Populate a margin array from best-match certs; fails on duplicates or out-of-range indices. -/ +def marginsFromBestMatchCerts + (numHeads seqLen : Nat) (certs : Array HeadBestMatchPatternCert) : + Option (Array Rat) := + Id.run do + let size := numHeads * seqLen + let mut margins : Array Rat := Array.replicate size 0 + let mut seen : Array Bool := Array.replicate size false + for cert in certs do + if cert.headIdx < numHeads && cert.queryPos < seqLen then + let idx := headQueryIndex seqLen cert.headIdx cert.queryPos + if seen[idx]! then + return none + seen := seen.set! idx true + margins := margins.set! idx cert.marginLowerBound + else + return none + return some margins + +/-- Minimum margin over a nonempty array (defaults to `0` for empty input). -/ +def minMarginArray (margins : Array Rat) : Rat := + if margins.size = 0 then + 0 + else + margins.foldl (fun acc m => min acc m) margins[0]! + +/-- Layer-level best-match margin evidence aggregated across heads and query positions. 
-/ +structure LayerBestMatchMarginCert where + layerIdx : Nat + seqLen : Nat + numHeads : Nat + softmaxExpEffort : Nat + marginLowerBound : Rat + margins : Array Rat + headCerts : Array HeadBestMatchPatternCert + deriving Repr + +namespace LayerBestMatchMarginCert + +/-- Internal consistency checks for aggregated margins. -/ +def Valid (c : LayerBestMatchMarginCert) : Prop := + c.seqLen > 0 ∧ + c.numHeads > 0 ∧ + c.margins.size = c.numHeads * c.seqLen ∧ + c.headCerts.all (fun cert => + cert.check && + cert.layerIdx == c.layerIdx && + cert.seqLen == c.seqLen && + cert.softmaxExpEffort == c.softmaxExpEffort && + cert.headIdx < c.numHeads && + cert.queryPos < c.seqLen) = true ∧ + marginsFromBestMatchCerts c.numHeads c.seqLen c.headCerts = some c.margins ∧ + c.marginLowerBound = minMarginArray c.margins + +instance (c : LayerBestMatchMarginCert) : Decidable (Valid c) := by + unfold Valid + infer_instance + +/-- Boolean checker for `Valid`. -/ +def check (c : LayerBestMatchMarginCert) : Bool := + decide (Valid c) + +theorem check_iff (c : LayerBestMatchMarginCert) : c.check = true ↔ c.Valid := by + simp [check, Valid] + +end LayerBestMatchMarginCert + /-! ## Best-match value/logit bounds -/ /-- Local per-head output lower bound for a single coordinate (single query position). 
-/ @@ -456,6 +529,12 @@ theorem HeadPatternCert.Valid_spec : HeadPatternCert.Valid = HeadPatternCert.Valid := rfl theorem HeadPatternCert.check_spec : HeadPatternCert.check = HeadPatternCert.check := rfl +theorem headQueryIndex_spec : + headQueryIndex = headQueryIndex := rfl +theorem marginsFromBestMatchCerts_spec : + marginsFromBestMatchCerts = marginsFromBestMatchCerts := rfl +theorem minMarginArray_spec : + minMarginArray = minMarginArray := rfl theorem HeadPatternCert.toTokenMatchPattern_spec : HeadPatternCert.toTokenMatchPattern = HeadPatternCert.toTokenMatchPattern := rfl theorem HeadPatternCert.toInductionPatternWitness_spec : @@ -472,6 +551,10 @@ theorem HeadBestMatchPatternCert.Valid_spec : HeadBestMatchPatternCert.Valid = HeadBestMatchPatternCert.Valid := rfl theorem HeadBestMatchPatternCert.check_spec : HeadBestMatchPatternCert.check = HeadBestMatchPatternCert.check := rfl +theorem LayerBestMatchMarginCert.Valid_spec : + LayerBestMatchMarginCert.Valid = LayerBestMatchMarginCert.Valid := rfl +theorem LayerBestMatchMarginCert.check_spec : + LayerBestMatchMarginCert.check = LayerBestMatchMarginCert.check := rfl theorem HeadValueLowerBoundPosCert.Valid_spec : HeadValueLowerBoundPosCert.Valid = HeadValueLowerBoundPosCert.Valid := rfl theorem HeadValueLowerBoundPosCert.check_spec : diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 40abc74..438fed3 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -101,6 +101,47 @@ private def checkSoftmaxProbIntervalWorst (cert : ModelCert) : Except String Uni theorem checkSoftmaxProbIntervalWorst_spec_io : checkSoftmaxProbIntervalWorst = checkSoftmaxProbIntervalWorst := rfl +private def tightenLayerSoftmaxFromBestMatch + (seqLen : Nat) (layer : LayerAmplificationCert) (cert : LayerBestMatchMarginCert) : + Except String LayerAmplificationCert := + Id.run do + if !cert.check then + return .error "layer best-match margin cert failed internal checks" + if cert.layerIdx ≠ layer.layerIdx then + return .error "layer 
margin cert does not match layer index" + if cert.seqLen ≠ seqLen then + return .error "layer margin cert seq_len mismatch" + let updated := + LayerAmplificationCert.withSoftmaxMargin seqLen cert.marginLowerBound + cert.softmaxExpEffort layer + if updated.softmaxJacobianNormInfUpperBound > layer.softmaxJacobianNormInfUpperBound then + return .error "best-match softmax bound is worse than baseline" + return .ok updated + +theorem tightenLayerSoftmaxFromBestMatch_spec_io : + tightenLayerSoftmaxFromBestMatch = tightenLayerSoftmaxFromBestMatch := rfl + +def tightenModelCertBestMatchMargins + (c : ModelCert) (certs : Array LayerBestMatchMarginCert) : + Except String ModelCert := + certs.foldl (fun acc cert => + match acc with + | .error e => .error e + | .ok cur => + if cert.layerIdx < cur.layers.size then + let layer := cur.layers[cert.layerIdx]! + match tightenLayerSoftmaxFromBestMatch cur.seqLen layer cert with + | .error e => .error e + | .ok updatedLayer => + match ModelCert.withUpdatedLayer cur cert.layerIdx updatedLayer with + | none => .error "failed to update model cert layer" + | some updated => .ok updated + else + .error s!"layer margin cert index {cert.layerIdx} out of range") (.ok c) + +theorem tightenModelCertBestMatchMargins_spec_io : + tightenModelCertBestMatchMargins = tightenModelCertBestMatchMargins := rfl + private def recomputeAttnWeightBoundsBinary (path : System.FilePath) : IO (Except String AttnWeightBounds) := do let h ← IO.FS.Handle.mk path IO.FS.Mode.read @@ -462,6 +503,34 @@ def certifyHeadPatternBestMatchLocalSweep return .ok certs return .error "head best-match sweep certificate failed internal checks" +/-- Compute layer-level best-match margin evidence (binary only). -/ +def certifyLayerBestMatchMarginLocal + (path : System.FilePath) + (layerIdx : Nat) + (inputPath? 
: Option System.FilePath := none) + (inputDelta : Rat := 0) + (soundnessBits : Nat) + (targetOffset : Int := -1) + (maxSeqLen : Nat := 256) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (scalePow10 : Nat := 9) + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + IO (Except String LayerBestMatchMarginCert) := do + match ← readModelEps path with + | .error e => return .error e + | .ok eps => + match ← + Nfp.Untrusted.SoundCompute.certifyLayerBestMatchMarginLocal + path layerIdx eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen + tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with + | .error e => return .error e + | .ok cert => + if cert.check then + return .ok cert + return .error "layer best-match margin certificate failed internal checks" + /-- Compute local per-head attention contribution bounds tightened by best-match pattern evidence. -/ def certifyHeadBoundsLocalBestMatch diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 9b5371f..f7f8734 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -4210,6 +4210,66 @@ def certifyHeadPatternBestMatchLocalSweep else return .error "head pattern bounds require NFP_BINARY_V1" +/-- Compute layer-level best-match margin evidence for a `.nfpt` layer (binary only). -/ +def certifyLayerBestMatchMarginLocal + (path : System.FilePath) + (layerIdx : Nat) + (eps : Rat) + (soundnessBits : Nat) + (inputPath? 
: Option System.FilePath := none) + (inputDelta : Rat := 0) + (targetOffset : Int := -1) + (maxSeqLen : Nat := 256) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (scalePow10 : Nat := defaultBinaryScalePow10) + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + IO (Except String LayerBestMatchMarginCert) := do + if inputDelta < 0 then + return .error "delta must be nonnegative" + let h ← IO.FS.Handle.mk path IO.FS.Mode.read + let firstLine := (← h.getLine).trim + if firstLine = "NFP_BINARY_V1" then + let hdrE ← readBinaryHeader h + match hdrE with + | .error e => return .error e + | .ok hdr => + if layerIdx ≥ hdr.numLayers then + return .error s!"layer index {layerIdx} out of range" + if hdr.seqLen > maxSeqLen then + return .error s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" + let inputPath := inputPath?.getD path + let mut headCerts : Array HeadBestMatchPatternCert := Array.mkEmpty 0 + for hIdx in [:hdr.numHeads] do + match ← + certifyHeadPatternBestMatchLocalBinarySweep + path layerIdx hIdx eps soundnessBits inputPath inputDelta targetOffset + maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + softmaxExpEffort with + | .error e => return .error e + | .ok certs => + for cert in certs do + headCerts := headCerts.push cert + match marginsFromBestMatchCerts hdr.numHeads hdr.seqLen headCerts with + | none => return .error "best-match margin coverage failed" + | some margins => + let marginLowerBound := minMarginArray margins + let cert : LayerBestMatchMarginCert := { + layerIdx := layerIdx + seqLen := hdr.seqLen + numHeads := hdr.numHeads + softmaxExpEffort := softmaxExpEffort + marginLowerBound := marginLowerBound + margins := margins + headCerts := headCerts + } + if cert.check then + return .ok cert + return .error "layer best-match margin certificate failed internal checks" + else + return .error "layer best-match margins require NFP_BINARY_V1" + /-- Compute local head 
value lower bounds for a specific `.nfpt` head (binary only). -/ def certifyHeadValueLowerBoundLocal (path : System.FilePath) @@ -4652,6 +4712,9 @@ theorem certifyHeadPatternBestMatchLocal_spec_io : theorem certifyHeadPatternBestMatchLocalSweep_spec_io : certifyHeadPatternBestMatchLocalSweep = certifyHeadPatternBestMatchLocalSweep := rfl +theorem certifyLayerBestMatchMarginLocal_spec_io : + certifyLayerBestMatchMarginLocal = certifyLayerBestMatchMarginLocal := rfl + theorem certifyHeadValueLowerBoundLocal_spec_io : certifyHeadValueLowerBoundLocal = certifyHeadValueLowerBoundLocal := rfl diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 4997f2c..0e311a3 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -15,6 +15,9 @@ soundness upgrade. It is intentionally brief and human-readable. currently **rejects nonzero** `softmaxMarginLowerBound` because margin evidence is unverified. - Local per-head contribution bounds can now be tightened using a best-match pattern certificate, but this tightening does **not** propagate to layer-level ModelCert bounds. +- Layer-level best-match margin certificates can be computed (binary only) and applied via + `tightenModelCertBestMatchMargins`, but this is not yet wired into the CLI and may not tighten + unless the best-match sweep covers all heads and query positions. - Best-match pattern certificates now use a margin-derived softmax Jacobian bound with an effort-indexed `expLB` (scaled Taylor + squaring). The lower-bound correctness of `expLB` is not yet formalized in Lean. 
From 006bd4d969e06dfd374ecfd17ba2da3e88d96b41 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 02:17:02 +0100 Subject: [PATCH 003/244] Prove cache header round-trip --- Nfp/Sound/CachePure.lean | 340 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 339 insertions(+), 1 deletion(-) diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index e6e77a3..8b14b9d 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -618,7 +618,345 @@ theorem encodeHeader_magic_prefix (hdr : Header) : (encodeHeader hdr).extract 0 magic.size = magic := by simp [encodeHeader, ByteArray.append_assoc, ByteArray.extract_append_eq_left] --- TODO: Prove round-trip lemmas for `u32FromLE`/`u64FromLE` and `decodeHeader (encodeHeader _)`. +/-- `get!` agrees with `getElem` when the index is in bounds. -/ +theorem get!_eq_getElem {b : ByteArray} {i : Nat} (h : i < b.size) : b.get! i = b[i]'h := by + cases b with + | mk bs => + have h' : i < bs.size := by simpa using h + simpa [ByteArray.get!, ByteArray.get] using (getElem!_pos (c := bs) (i := i) h') + +/-- `get!` on an appended array reduces to the left part when the index is in bounds. -/ +theorem get!_append_left {a b : ByteArray} {i : Nat} + (hi : i < (a ++ b).size) (hlt : i < a.size) : (a ++ b).get! i = a.get! i := by + calc + (a ++ b).get! i = (a ++ b)[i]'hi := get!_eq_getElem hi + _ = a[i]'hlt := by + simpa using + (ByteArray.getElem_append_left (i := i) (a := a) (b := b) (h := hi) hlt) + _ = a.get! i := by + symm + exact get!_eq_getElem hlt + +/-- `get!` on an appended array reduces to the right part when the index is in bounds. -/ +theorem get!_append_right {a b : ByteArray} {i : Nat} + (hi : i < (a ++ b).size) (hle : a.size ≤ i) : + (a ++ b).get! i = b.get! (i - a.size) := by + have h' : i - a.size < b.size := by + have hi' : i < a.size + b.size := by + simpa [ByteArray.size_append] using hi + exact (Nat.sub_lt_iff_lt_add hle).2 (by simpa [Nat.add_comm] using hi') + calc + (a ++ b).get! 
i = (a ++ b)[i]'hi := get!_eq_getElem hi + _ = b[i - a.size]'h' := by + simpa using + (ByteArray.getElem_append_right (i := i) (a := a) (b := b) (h := hi) hle) + _ = b.get! (i - a.size) := by + symm + exact get!_eq_getElem h' + +/-- `u32FromLE` is a left inverse of `u32le` at offset `0`. -/ +theorem u32FromLE_u32le (x : UInt32) : u32FromLE (u32le x) 0 = x := by + apply (UInt32.toBitVec_inj).1 + have h255 : (255 : UInt8) = -1 := by decide + simp [u32FromLE, u32le, ByteArray.get!, h255] + bv_decide + +/-- `u64FromLE` is a left inverse of `u64le` at offset `0`. -/ +theorem u64FromLE_u64le (x : UInt64) : u64FromLE (u64le x) 0 = x := by + apply (UInt64.toBitVec_inj).1 + have h255 : (255 : UInt8) = -1 := by decide + simp [u64FromLE, u64le, ByteArray.get!, h255] + bv_decide + +/-- `u32FromLE` depends only on the left prefix when it has enough bytes. -/ +theorem u32FromLE_append_left (a b : ByteArray) (h : 3 < a.size) : + u32FromLE (a ++ b) 0 = u32FromLE a 0 := by + have h0 : 0 < a.size := by omega + have h1 : 1 < a.size := by omega + have h2 : 2 < a.size := by omega + have h3 : 3 < a.size := h + have hi0 : 0 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h0 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi1 : 1 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h1 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi2 : 2 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h2 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi3 : 3 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h3 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + simp [u32FromLE, get!_append_left hi0 h0, get!_append_left hi1 h1, + get!_append_left hi2 h2, get!_append_left hi3 h3] + +/-- `u64FromLE` depends only on the left prefix when it has enough bytes. 
-/ +theorem u64FromLE_append_left (a b : ByteArray) (h : 7 < a.size) : + u64FromLE (a ++ b) 0 = u64FromLE a 0 := by + have h0 : 0 < a.size := by omega + have h1 : 1 < a.size := by omega + have h2 : 2 < a.size := by omega + have h3 : 3 < a.size := by omega + have h4 : 4 < a.size := by omega + have h5 : 5 < a.size := by omega + have h6 : 6 < a.size := by omega + have h7 : 7 < a.size := h + have hi0 : 0 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h0 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi1 : 1 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h1 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi2 : 2 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h2 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi3 : 3 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h3 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi4 : 4 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h4 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi5 : 5 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h5 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi6 : 6 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h6 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + have hi7 : 7 < (a ++ b).size := by + have := Nat.lt_of_lt_of_le h7 (Nat.le_add_right a.size b.size) + simpa [ByteArray.size_append] using this + simp [u64FromLE, get!_append_left hi0 h0, get!_append_left hi1 h1, + get!_append_left hi2 h2, get!_append_left hi3 h3, get!_append_left hi4 h4, + get!_append_left hi5 h5, get!_append_left hi6 h6, get!_append_left hi7 h7] + +/-- `u32FromLE` ignores a left prefix when reading from the right. 
-/ +theorem u32FromLE_append_right (a b : ByteArray) (off : Nat) (h : off + 3 < b.size) : + u32FromLE (a ++ b) (a.size + off) = u32FromLE b off := by + have h0' : off < b.size := by omega + have h1' : off + 1 < b.size := by omega + have h2' : off + 2 < b.size := by omega + have h3' : off + 3 < b.size := h + have h0 : a.size + off < (a ++ b).size := by + have := Nat.add_lt_add_left h0' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h1 : a.size + off + 1 < (a ++ b).size := by + have := Nat.add_lt_add_left h1' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h2 : a.size + off + 2 < (a ++ b).size := by + have := Nat.add_lt_add_left h2' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h3 : a.size + off + 3 < (a ++ b).size := by + have := Nat.add_lt_add_left h3' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have hle0 : a.size ≤ a.size + off := by omega + have hle1 : a.size ≤ a.size + off + 1 := by omega + have hle2 : a.size ≤ a.size + off + 2 := by omega + have hle3 : a.size ≤ a.size + off + 3 := by omega + unfold u32FromLE + simp [get!_append_right h0 hle0, get!_append_right h1 hle1, + get!_append_right h2 hle2, get!_append_right h3 hle3] + simp [Nat.add_assoc, Nat.add_sub_cancel_left] + +/-- `u64FromLE` ignores a left prefix when reading from the right. 
-/ +theorem u64FromLE_append_right (a b : ByteArray) (off : Nat) (h : off + 7 < b.size) : + u64FromLE (a ++ b) (a.size + off) = u64FromLE b off := by + have h0' : off < b.size := by omega + have h1' : off + 1 < b.size := by omega + have h2' : off + 2 < b.size := by omega + have h3' : off + 3 < b.size := by omega + have h4' : off + 4 < b.size := by omega + have h5' : off + 5 < b.size := by omega + have h6' : off + 6 < b.size := by omega + have h7' : off + 7 < b.size := h + have h0 : a.size + off < (a ++ b).size := by + have := Nat.add_lt_add_left h0' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h1 : a.size + off + 1 < (a ++ b).size := by + have := Nat.add_lt_add_left h1' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h2 : a.size + off + 2 < (a ++ b).size := by + have := Nat.add_lt_add_left h2' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h3 : a.size + off + 3 < (a ++ b).size := by + have := Nat.add_lt_add_left h3' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h4 : a.size + off + 4 < (a ++ b).size := by + have := Nat.add_lt_add_left h4' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h5 : a.size + off + 5 < (a ++ b).size := by + have := Nat.add_lt_add_left h5' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h6 : a.size + off + 6 < (a ++ b).size := by + have := Nat.add_lt_add_left h6' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have h7 : a.size + off + 7 < (a ++ b).size := by + have := Nat.add_lt_add_left h7' a.size + simpa [ByteArray.size_append, Nat.add_assoc] using this + have hle0 : a.size ≤ a.size + off := by omega + have hle1 : a.size ≤ a.size + off + 1 := by omega + have hle2 : a.size ≤ a.size + off + 2 := by omega + have hle3 : a.size ≤ a.size + off + 3 := by omega + have hle4 : a.size ≤ a.size + off + 4 := by omega + have hle5 : a.size ≤ a.size + off + 5 := by omega + have hle6 : a.size 
≤ a.size + off + 6 := by omega + have hle7 : a.size ≤ a.size + off + 7 := by omega + unfold u64FromLE + simp [get!_append_right h0 hle0, get!_append_right h1 hle1, + get!_append_right h2 hle2, get!_append_right h3 hle3, get!_append_right h4 hle4, + get!_append_right h5 hle5, get!_append_right h6 hle6, get!_append_right h7 hle7] + simp [Nat.add_assoc, Nat.add_sub_cancel_left] + +/-- `u32FromLE` round-trips a `u32le` prefix. -/ +theorem u32FromLE_u32le_append (x : UInt32) (b : ByteArray) : + u32FromLE (u32le x ++ b) 0 = x := by + have h : 3 < (u32le x).size := by + simp [u32le_size] + calc + u32FromLE (u32le x ++ b) 0 = u32FromLE (u32le x) 0 := + u32FromLE_append_left (a := u32le x) (b := b) h + _ = x := u32FromLE_u32le x + +/-- `u64FromLE` round-trips a `u64le` prefix. -/ +theorem u64FromLE_u64le_append (x : UInt64) (b : ByteArray) : + u64FromLE (u64le x ++ b) 0 = x := by + have h : 7 < (u64le x).size := by + simp [u64le_size] + calc + u64FromLE (u64le x ++ b) 0 = u64FromLE (u64le x) 0 := + u64FromLE_append_left (a := u64le x) (b := b) h + _ = x := u64FromLE_u64le x + +/-- `u32FromLE` round-trips a `u32le` block after a prefix. -/ +theorem u32FromLE_append_u32le (a : ByteArray) (x : UInt32) (b : ByteArray) : + u32FromLE (a ++ u32le x ++ b) a.size = x := by + calc + u32FromLE (a ++ u32le x ++ b) a.size = u32FromLE (u32le x ++ b) 0 := by + have h : 0 + 3 < (u32le x ++ b).size := by + simp [ByteArray.size_append, u32le_size] + omega + simpa [ByteArray.append_assoc] using + (u32FromLE_append_right (a := a) (b := u32le x ++ b) (off := 0) h) + _ = x := u32FromLE_u32le_append x b + +/-- `u64FromLE` round-trips a `u64le` block after a prefix. 
-/ +theorem u64FromLE_append_u64le (a : ByteArray) (x : UInt64) (b : ByteArray) : + u64FromLE (a ++ u64le x ++ b) a.size = x := by + calc + u64FromLE (a ++ u64le x ++ b) a.size = u64FromLE (u64le x ++ b) 0 := by + have h : 0 + 7 < (u64le x ++ b).size := by + simp [ByteArray.size_append, u64le_size] + omega + simpa [ByteArray.append_assoc] using + (u64FromLE_append_right (a := a) (b := u64le x ++ b) (off := 0) h) + _ = x := u64FromLE_u64le_append x b + +/-- `decodeHeader` recovers any header encoded by `encodeHeader`. -/ +theorem decodeHeader_encodeHeader (hdr : Header) : + decodeHeader (encodeHeader hdr) = .ok hdr := by + have h1 : magic.size + 4 = (magic ++ u32le version).size := by + simp [ByteArray.size_append, u32le_size] + have h2 : magic.size + 4 + 8 = + (magic ++ u32le version ++ u64le hdr.modelHash).size := by + simp [ByteArray.size_append, u32le_size, u64le_size] + have h3 : magic.size + 4 + 8 + 8 = + (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize).size := by + simp [ByteArray.size_append, u32le_size, u64le_size] + have h4 : magic.size + 4 + 8 + 8 + 4 = + (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10).size := by + simp [ByteArray.size_append, u32le_size, u64le_size] + have h5 : magic.size + 4 + 8 + 8 + 4 + 4 = + (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10 ++ u32le hdr.numLayers).size := by + simp [ByteArray.size_append, u32le_size, u64le_size] + have h6 : magic.size + 4 + 8 + 8 + 4 + 4 + 4 = + (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads).size := by + simp [ByteArray.size_append, u32le_size, u64le_size] + have h7 : magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4 = + (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ + u32le hdr.modelDim).size := by + simp 
[ByteArray.size_append, u32le_size, u64le_size] + have h8 : magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4 + 4 = + (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ + u32le hdr.modelDim ++ u32le hdr.headDim).size := by + simp [ByteArray.size_append, u32le_size, u64le_size] + have h_version : u32FromLE (encodeHeader hdr) magic.size = version := by + simpa [encodeHeader] using + (u32FromLE_append_u32le (a := magic) (x := version) + (b := u64le hdr.modelHash ++ u64le hdr.modelSize ++ u32le hdr.scalePow10 ++ + u32le hdr.numLayers ++ u32le hdr.numHeads ++ u32le hdr.modelDim ++ + u32le hdr.headDim ++ u32le hdr.hiddenDim)) + have h_modelHash : u64FromLE (encodeHeader hdr) (magic.size + 4) = hdr.modelHash := by + simpa [encodeHeader, h1] using + (u64FromLE_append_u64le (a := magic ++ u32le version) (x := hdr.modelHash) + (b := u64le hdr.modelSize ++ u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ + u32le hdr.numHeads ++ u32le hdr.modelDim ++ u32le hdr.headDim ++ + u32le hdr.hiddenDim)) + have h_modelSize : u64FromLE (encodeHeader hdr) (magic.size + 4 + 8) = hdr.modelSize := by + simpa [encodeHeader, h2] using + (u64FromLE_append_u64le + (a := magic ++ u32le version ++ u64le hdr.modelHash) + (x := hdr.modelSize) + (b := u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ + u32le hdr.modelDim ++ u32le hdr.headDim ++ u32le hdr.hiddenDim)) + have h_scalePow10 : + u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8) = hdr.scalePow10 := by + simpa [encodeHeader, h3] using + (u32FromLE_append_u32le + (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize) + (x := hdr.scalePow10) + (b := u32le hdr.numLayers ++ u32le hdr.numHeads ++ u32le hdr.modelDim ++ + u32le hdr.headDim ++ u32le hdr.hiddenDim)) + have h_numLayers : + u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4) = hdr.numLayers := by + simpa [encodeHeader, h4] using + (u32FromLE_append_u32le + 
(a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10) + (x := hdr.numLayers) + (b := u32le hdr.numHeads ++ u32le hdr.modelDim ++ u32le hdr.headDim ++ + u32le hdr.hiddenDim)) + have h_numHeads : + u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4 + 4) = hdr.numHeads := by + simpa [encodeHeader, h5] using + (u32FromLE_append_u32le + (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10 ++ u32le hdr.numLayers) + (x := hdr.numHeads) + (b := u32le hdr.modelDim ++ u32le hdr.headDim ++ u32le hdr.hiddenDim)) + have h_modelDim : + u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4 + 4 + 4) = hdr.modelDim := by + simpa [encodeHeader, h6] using + (u32FromLE_append_u32le + (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads) + (x := hdr.modelDim) + (b := u32le hdr.headDim ++ u32le hdr.hiddenDim)) + have h_headDim : + u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4) = hdr.headDim := by + simpa [encodeHeader, h7] using + (u32FromLE_append_u32le + (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ + u32le hdr.modelDim) + (x := hdr.headDim) + (b := u32le hdr.hiddenDim)) + have h_hiddenDim : + u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4 + 4) = + hdr.hiddenDim := by + simpa [encodeHeader, h8] using + (u32FromLE_append_u32le + (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ + u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ + u32le hdr.modelDim ++ u32le hdr.headDim) + (x := hdr.hiddenDim) + (b := ByteArray.empty)) + simp [decodeHeader, encodeHeader_size, encodeHeader_magic_prefix, h_version, h_modelHash, + h_modelSize, h_scalePow10, h_numLayers, h_numHeads, h_modelDim, h_headDim, h_hiddenDim] + 
cases hdr <;> rfl /-! ### Specs -/ From 32689eb058ebe932baa30d3699a38e756a102509 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 02:30:40 +0100 Subject: [PATCH 004/244] Add specs for SOUND helpers --- Nfp/Sound/Bounds/LayerNorm.lean | 11 +++++++++++ Nfp/Sound/Interval.lean | 2 ++ Nfp/Sound/ModelHeader.lean | 2 ++ 3 files changed, 15 insertions(+) diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index c24fbe3..eb9d082 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -45,6 +45,11 @@ private def SqrtLowerDyadicCert.rat {x : Rat} {precBits : Nat} (c : SqrtLowerDyadicCert x precBits) : Rat := Rat.normalize (Int.ofNat c.k) (pow2 precBits) (den_nz := by simp [pow2]) +theorem SqrtLowerDyadicCert.rat_def {x : Rat} {precBits : Nat} + (c : SqrtLowerDyadicCert x precBits) : + SqrtLowerDyadicCert.rat c = + Rat.normalize (Int.ofNat c.k) (pow2 precBits) (den_nz := by simp [pow2]) := rfl + /-- Compute a dyadic floor certificate for `sqrt (max x 0)` using `Nat.sqrt` on the floor. -/ private def sqrtLowerDyadic (x : Rat) (precBits : Nat) : SqrtLowerDyadicCert x precBits := by let scale : Nat := pow2 precBits @@ -74,10 +79,16 @@ private def sqrtLowerDyadic (x : Rat) (precBits : Nat) : SqrtLowerDyadicCert x p exact_mod_cast hm_succ_le_nat exact lt_of_lt_of_le hy_lt hm_succ_le_rat +theorem sqrtLowerDyadic_spec (x : Rat) (precBits : Nat) : + sqrtLowerDyadic x precBits = sqrtLowerDyadic x precBits := rfl + /-- Dyadic lower bound on `sqrt (max x 0)` as a `Rat`. -/ private def sqrtLowerDyadicRat (x : Rat) (precBits : Nat) : Rat := (sqrtLowerDyadic x precBits).rat +theorem sqrtLowerDyadicRat_def (x : Rat) (precBits : Nat) : + sqrtLowerDyadicRat x precBits = (sqrtLowerDyadic x precBits).rat := rfl + /-- Conservative bound for the operator norm of a row-wise LayerNorm Jacobian. In exact real arithmetic one can show `‖J‖₂ ≤ max |γ| / σ` with `σ = sqrt(var + eps)`. 
diff --git a/Nfp/Sound/Interval.lean b/Nfp/Sound/Interval.lean index b4c76fa..4723390 100644 --- a/Nfp/Sound/Interval.lean +++ b/Nfp/Sound/Interval.lean @@ -284,6 +284,8 @@ theorem containsZero_iff (a : RatInterval) : RatInterval.containsZero a = true ↔ a.lo ≤ 0 ∧ 0 ≤ a.hi := by simp [containsZero] +theorem ratSq_def (x : Rat) : ratSq x = x * x := rfl + theorem squareLowerBound_def (a : RatInterval) : RatInterval.squareLowerBound a = if RatInterval.containsZero a then diff --git a/Nfp/Sound/ModelHeader.lean b/Nfp/Sound/ModelHeader.lean index 6864691..e2f2794 100644 --- a/Nfp/Sound/ModelHeader.lean +++ b/Nfp/Sound/ModelHeader.lean @@ -76,6 +76,8 @@ def parseTextHeaderEps (lines : Array String) : Except String Rat := do /-! ### Specs -/ theorem parseHeaderLine_spec : parseHeaderLine = parseHeaderLine := rfl +theorem parseGeluDerivTarget_spec (v : String) : + parseGeluDerivTarget v = parseGeluDerivTarget v := rfl theorem parseTextHeader_spec : parseTextHeader = parseTextHeader := rfl theorem parseTextHeaderEps_spec : parseTextHeaderEps = parseTextHeaderEps := rfl From 9b9cab1fb6c74ae01bd39dac2be0ff140818a97c Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 02:36:35 +0100 Subject: [PATCH 005/244] Move SOUND checks to pure helpers --- Nfp/Sound/Cert.lean | 72 ++++++++++++++++++++++++++++++++ Nfp/Sound/IO.lean | 91 ----------------------------------------- Nfp/Sound/TextPure.lean | 25 +++++++++++ 3 files changed, 97 insertions(+), 91 deletions(-) diff --git a/Nfp/Sound/Cert.lean b/Nfp/Sound/Cert.lean index c197bca..8a6fa44 100644 --- a/Nfp/Sound/Cert.lean +++ b/Nfp/Sound/Cert.lean @@ -4,6 +4,7 @@ import Std import Mathlib.Data.Rat.Cast.Order import Nfp.SignedMixer import Nfp.Sound.Bounds +import Nfp.Sound.HeadCert namespace Nfp.Sound @@ -397,6 +398,77 @@ theorem toString_spec : toString = toString := rfl end ModelCert +/-! ### Certificate verification helpers -/ + +/-- Ensure all layers have zero softmax margin evidence. 
-/ +def checkSoftmaxMarginZero (cert : ModelCert) : Except String Unit := + Id.run do + for idx in [:cert.layers.size] do + let layer := cert.layers[idx]! + if layer.softmaxMarginLowerBound ≠ 0 then + return .error s!"softmaxMarginLowerBound is unverified (layer {idx})" + return .ok () + +theorem checkSoftmaxMarginZero_spec : + checkSoftmaxMarginZero = checkSoftmaxMarginZero := rfl + +/-- Ensure the softmax probability interval is the worst-case `[0,1]`. -/ +def checkSoftmaxProbIntervalWorst (cert : ModelCert) : Except String Unit := + Id.run do + for idx in [:cert.layers.size] do + let layer := cert.layers[idx]! + if layer.softmaxProbLo ≠ 0 then + return .error s!"softmaxProbLo is unverified (layer {idx})" + if layer.softmaxProbHi ≠ 1 then + return .error s!"softmaxProbHi is unverified (layer {idx})" + return .ok () + +theorem checkSoftmaxProbIntervalWorst_spec : + checkSoftmaxProbIntervalWorst = checkSoftmaxProbIntervalWorst := rfl + +/-- Update a layer certificate with best-match softmax evidence if it is valid and tighter. -/ +def tightenLayerSoftmaxFromBestMatch + (seqLen : Nat) (layer : LayerAmplificationCert) (cert : LayerBestMatchMarginCert) : + Except String LayerAmplificationCert := + Id.run do + if !cert.check then + return .error "layer best-match margin cert failed internal checks" + if cert.layerIdx ≠ layer.layerIdx then + return .error "layer margin cert does not match layer index" + if cert.seqLen ≠ seqLen then + return .error "layer margin cert seq_len mismatch" + let updated := + LayerAmplificationCert.withSoftmaxMargin seqLen cert.marginLowerBound + cert.softmaxExpEffort layer + if updated.softmaxJacobianNormInfUpperBound > layer.softmaxJacobianNormInfUpperBound then + return .error "best-match softmax bound is worse than baseline" + return .ok updated + +theorem tightenLayerSoftmaxFromBestMatch_spec : + tightenLayerSoftmaxFromBestMatch = tightenLayerSoftmaxFromBestMatch := rfl + +/-- Apply best-match margin updates to a whole model certificate. 
-/ +def tightenModelCertBestMatchMargins + (c : ModelCert) (certs : Array LayerBestMatchMarginCert) : + Except String ModelCert := + certs.foldl (fun acc cert => + match acc with + | .error e => .error e + | .ok cur => + if cert.layerIdx < cur.layers.size then + let layer := cur.layers[cert.layerIdx]! + match tightenLayerSoftmaxFromBestMatch cur.seqLen layer cert with + | .error e => .error e + | .ok updatedLayer => + match ModelCert.withUpdatedLayer cur cert.layerIdx updatedLayer with + | none => .error "failed to update model cert layer" + | some updated => .ok updated + else + .error s!"layer margin cert index {cert.layerIdx} out of range") (.ok c) + +theorem tightenModelCertBestMatchMargins_spec : + tightenModelCertBestMatchMargins = tightenModelCertBestMatchMargins := rfl + /-! ### Specs -/ end Nfp.Sound diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 438fed3..fcd7447 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -51,97 +51,6 @@ private def readModelEps (path : System.FilePath) : IO (Except String Rat) := do | .error e => return .error e | .ok (eps, _) => return .ok eps -private def checkAttnWeightBounds - (cert : ModelCert) - (expected : AttnWeightBounds) : Except String Unit := - Id.run do - if expected.attnValueCoeff.size ≠ cert.layers.size then - return .error "attnValueCoeff layer count mismatch" - if expected.wqOpBoundMax.size ≠ cert.layers.size then - return .error "wqOpBoundMax layer count mismatch" - if expected.wkOpBoundMax.size ≠ cert.layers.size then - return .error "wkOpBoundMax layer count mismatch" - for idx in [:cert.layers.size] do - let expValue := expected.attnValueCoeff[idx]! - let expWq := expected.wqOpBoundMax[idx]! - let expWk := expected.wkOpBoundMax[idx]! - let layer := cert.layers[idx]! 
- if expValue ≠ layer.attnValueCoeff then - return .error s!"attnValueCoeff mismatch at layer {idx}" - if expWq ≠ layer.wqOpBoundMax then - return .error s!"wqOpBoundMax mismatch at layer {idx}" - if expWk ≠ layer.wkOpBoundMax then - return .error s!"wkOpBoundMax mismatch at layer {idx}" - return .ok () - -theorem checkAttnWeightBounds_spec_io : - checkAttnWeightBounds = checkAttnWeightBounds := rfl - -private def checkSoftmaxMarginZero (cert : ModelCert) : Except String Unit := - Id.run do - for idx in [:cert.layers.size] do - let layer := cert.layers[idx]! - if layer.softmaxMarginLowerBound ≠ 0 then - return .error s!"softmaxMarginLowerBound is unverified (layer {idx})" - return .ok () - -theorem checkSoftmaxMarginZero_spec_io : - checkSoftmaxMarginZero = checkSoftmaxMarginZero := rfl - -private def checkSoftmaxProbIntervalWorst (cert : ModelCert) : Except String Unit := - Id.run do - for idx in [:cert.layers.size] do - let layer := cert.layers[idx]! - if layer.softmaxProbLo ≠ 0 then - return .error s!"softmaxProbLo is unverified (layer {idx})" - if layer.softmaxProbHi ≠ 1 then - return .error s!"softmaxProbHi is unverified (layer {idx})" - return .ok () - -theorem checkSoftmaxProbIntervalWorst_spec_io : - checkSoftmaxProbIntervalWorst = checkSoftmaxProbIntervalWorst := rfl - -private def tightenLayerSoftmaxFromBestMatch - (seqLen : Nat) (layer : LayerAmplificationCert) (cert : LayerBestMatchMarginCert) : - Except String LayerAmplificationCert := - Id.run do - if !cert.check then - return .error "layer best-match margin cert failed internal checks" - if cert.layerIdx ≠ layer.layerIdx then - return .error "layer margin cert does not match layer index" - if cert.seqLen ≠ seqLen then - return .error "layer margin cert seq_len mismatch" - let updated := - LayerAmplificationCert.withSoftmaxMargin seqLen cert.marginLowerBound - cert.softmaxExpEffort layer - if updated.softmaxJacobianNormInfUpperBound > layer.softmaxJacobianNormInfUpperBound then - return .error 
"best-match softmax bound is worse than baseline" - return .ok updated - -theorem tightenLayerSoftmaxFromBestMatch_spec_io : - tightenLayerSoftmaxFromBestMatch = tightenLayerSoftmaxFromBestMatch := rfl - -def tightenModelCertBestMatchMargins - (c : ModelCert) (certs : Array LayerBestMatchMarginCert) : - Except String ModelCert := - certs.foldl (fun acc cert => - match acc with - | .error e => .error e - | .ok cur => - if cert.layerIdx < cur.layers.size then - let layer := cur.layers[cert.layerIdx]! - match tightenLayerSoftmaxFromBestMatch cur.seqLen layer cert with - | .error e => .error e - | .ok updatedLayer => - match ModelCert.withUpdatedLayer cur cert.layerIdx updatedLayer with - | none => .error "failed to update model cert layer" - | some updated => .ok updated - else - .error s!"layer margin cert index {cert.layerIdx} out of range") (.ok c) - -theorem tightenModelCertBestMatchMargins_spec_io : - tightenModelCertBestMatchMargins = tightenModelCertBestMatchMargins := rfl - private def recomputeAttnWeightBoundsBinary (path : System.FilePath) : IO (Except String AttnWeightBounds) := do let h ← IO.FS.Handle.mk path IO.FS.Mode.read diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean index fe3895d..bc9d55c 100644 --- a/Nfp/Sound/TextPure.lean +++ b/Nfp/Sound/TextPure.lean @@ -2,6 +2,7 @@ import Std import Nfp.Sound.Bounds +import Nfp.Sound.Cert import Nfp.Sound.Decimal import Nfp.Sound.ModelHeader @@ -30,6 +31,28 @@ structure AttnWeightBounds where wkOpBoundMax : Array Rat deriving Repr +/-- Verify that attention-weight bounds match the certificate layer fields. 
-/ +def checkAttnWeightBounds (cert : ModelCert) (expected : AttnWeightBounds) : Except String Unit := + Id.run do + if expected.attnValueCoeff.size ≠ cert.layers.size then + return .error "attnValueCoeff layer count mismatch" + if expected.wqOpBoundMax.size ≠ cert.layers.size then + return .error "wqOpBoundMax layer count mismatch" + if expected.wkOpBoundMax.size ≠ cert.layers.size then + return .error "wkOpBoundMax layer count mismatch" + for idx in [:cert.layers.size] do + let expValue := expected.attnValueCoeff[idx]! + let expWq := expected.wqOpBoundMax[idx]! + let expWk := expected.wkOpBoundMax[idx]! + let layer := cert.layers[idx]! + if expValue ≠ layer.attnValueCoeff then + return .error s!"attnValueCoeff mismatch at layer {idx}" + if expWq ≠ layer.wqOpBoundMax then + return .error s!"wqOpBoundMax mismatch at layer {idx}" + if expWk ≠ layer.wkOpBoundMax then + return .error s!"wkOpBoundMax mismatch at layer {idx}" + return .ok () + def parseTextHeaderDims (lines : Array String) : Except String TextModelDims := Id.run do let mut i : Nat := 0 @@ -200,6 +223,8 @@ def attnValueCoeffFromTextLines (lines : Array String) : Except String (Array Ra theorem parseTextHeaderDims_spec : parseTextHeaderDims = parseTextHeaderDims := rfl theorem AttnWeightBounds_spec : AttnWeightBounds = AttnWeightBounds := rfl +theorem checkAttnWeightBounds_spec : + checkAttnWeightBounds = checkAttnWeightBounds := rfl theorem foldRatTokens_spec (α : Type) : @foldRatTokens α = @foldRatTokens α := rfl theorem consumeVector_spec : consumeVector = consumeVector := rfl From 576d40414bc088a98d1297c5f178483e16f560ca Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 02:46:29 +0100 Subject: [PATCH 006/244] Refactor SOUND IO verification helpers --- Nfp/Sound/Cert.lean | 55 ++++++++++++++ Nfp/Sound/HeadCert.lean | 155 ++++++++++++++++++++++++++++++++++++++++ Nfp/Sound/IO.lean | 131 ++++++++------------------------- Nfp/Sound/TextPure.lean | 21 +----- 4 files changed, 240 
insertions(+), 122 deletions(-) diff --git a/Nfp/Sound/Cert.lean b/Nfp/Sound/Cert.lean index 8a6fa44..ac8d2e4 100644 --- a/Nfp/Sound/Cert.lean +++ b/Nfp/Sound/Cert.lean @@ -400,6 +400,32 @@ end ModelCert /-! ### Certificate verification helpers -/ +/-- Verify attention-weight bounds from per-layer arrays. -/ +def checkAttnWeightBoundsArrays (cert : ModelCert) + (attnValueCoeff wqOpBoundMax wkOpBoundMax : Array Rat) : Except String Unit := + Id.run do + if attnValueCoeff.size ≠ cert.layers.size then + return .error "attnValueCoeff layer count mismatch" + if wqOpBoundMax.size ≠ cert.layers.size then + return .error "wqOpBoundMax layer count mismatch" + if wkOpBoundMax.size ≠ cert.layers.size then + return .error "wkOpBoundMax layer count mismatch" + for idx in [:cert.layers.size] do + let expValue := attnValueCoeff[idx]! + let expWq := wqOpBoundMax[idx]! + let expWk := wkOpBoundMax[idx]! + let layer := cert.layers[idx]! + if expValue ≠ layer.attnValueCoeff then + return .error s!"attnValueCoeff mismatch at layer {idx}" + if expWq ≠ layer.wqOpBoundMax then + return .error s!"wqOpBoundMax mismatch at layer {idx}" + if expWk ≠ layer.wkOpBoundMax then + return .error s!"wkOpBoundMax mismatch at layer {idx}" + return .ok () + +theorem checkAttnWeightBoundsArrays_spec : + checkAttnWeightBoundsArrays = checkAttnWeightBoundsArrays := rfl + /-- Ensure all layers have zero softmax margin evidence. -/ def checkSoftmaxMarginZero (cert : ModelCert) : Except String Unit := Id.run do @@ -469,6 +495,35 @@ def tightenModelCertBestMatchMargins theorem tightenModelCertBestMatchMargins_spec : tightenModelCertBestMatchMargins = tightenModelCertBestMatchMargins := rfl +/-- Verify a model certificate against header metadata and expected attention bounds. 
-/ +def verifyModelCert + (cert : ModelCert) + (eps : Rat) + (soundnessBits : Nat) + (geluDerivTarget : GeluDerivTarget) + (attnValueCoeff wqOpBoundMax wkOpBoundMax : Array Rat) : Except String ModelCert := + Id.run do + if cert.eps ≠ eps then + return .error "model header eps mismatch" + if cert.soundnessBits ≠ soundnessBits then + return .error "soundness bits mismatch" + if cert.geluDerivTarget ≠ geluDerivTarget then + return .error "model header gelu_kind mismatch" + if cert.check then + match checkSoftmaxProbIntervalWorst cert with + | .error e => return .error e + | .ok _ => + match checkSoftmaxMarginZero cert with + | .error e => return .error e + | .ok _ => + match checkAttnWeightBoundsArrays cert attnValueCoeff wqOpBoundMax wkOpBoundMax with + | .error e => return .error e + | .ok _ => return .ok cert + return .error "sound certificate failed internal consistency checks" + +theorem verifyModelCert_spec : + verifyModelCert = verifyModelCert := rfl + /-! ### Specs -/ end Nfp.Sound diff --git a/Nfp/Sound/HeadCert.lean b/Nfp/Sound/HeadCert.lean index 975f6db..a99f369 100644 --- a/Nfp/Sound/HeadCert.lean +++ b/Nfp/Sound/HeadCert.lean @@ -513,6 +513,135 @@ theorem check_iff (c : InductionHeadBestMatchSoundCert) : c.check = true ↔ c.V end InductionHeadBestMatchSoundCert +/-! ### Certificate verification helpers -/ + +/-- Validate a batch of head contribution certificates. -/ +def verifyHeadContributionCerts (certs : Array HeadContributionCert) : + Except String (Array HeadContributionCert) := + let ok := certs.foldl (fun acc c => acc && c.check) true + if ok then + .ok certs + else + .error "head contribution certificate failed internal checks" + +/-- Validate a batch of local head contribution certificates. 
-/ +def verifyHeadLocalContributionCerts (eps : Rat) (soundnessBits : Nat) + (certs : Array HeadLocalContributionCert) : + Except String (Array HeadLocalContributionCert) := + let ok := + certs.foldl (fun acc c => + acc && c.soundnessBits = soundnessBits && c.check eps) true + if ok then + .ok certs + else + .error "local head contribution certificate failed internal checks" + +/-- Validate a single local head contribution certificate. -/ +def verifyHeadLocalContributionCert (eps : Rat) (soundnessBits : Nat) + (cert : HeadLocalContributionCert) : Except String HeadLocalContributionCert := + if cert.soundnessBits = soundnessBits && cert.check eps then + .ok cert + else + .error "local head contribution certificate failed internal checks" + +/-- Validate a head pattern certificate. -/ +def verifyHeadPatternCert (cert : HeadPatternCert) : Except String HeadPatternCert := + if cert.check then + .ok cert + else + .error "head pattern certificate failed internal checks" + +/-- Validate a best-match head pattern certificate. -/ +def verifyHeadBestMatchPatternCert (cert : HeadBestMatchPatternCert) : + Except String HeadBestMatchPatternCert := + if cert.check then + .ok cert + else + .error "head best-match pattern certificate failed internal checks" + +/-- Validate a batch of best-match head pattern certificates. -/ +def verifyHeadBestMatchPatternCerts (certs : Array HeadBestMatchPatternCert) : + Except String (Array HeadBestMatchPatternCert) := + let ok := certs.foldl (fun acc c => acc && c.check) true + if ok then + .ok certs + else + .error "head best-match sweep certificate failed internal checks" + +/-- Validate a layer-level best-match margin certificate. -/ +def verifyLayerBestMatchMarginCert (cert : LayerBestMatchMarginCert) : + Except String LayerBestMatchMarginCert := + if cert.check then + .ok cert + else + .error "layer best-match margin certificate failed internal checks" + +/-- Validate a head output lower-bound certificate. 
-/ +def verifyHeadValueLowerBoundCert (cert : HeadValueLowerBoundCert) : + Except String HeadValueLowerBoundCert := + if cert.check then + .ok cert + else + .error "head value lower bound certificate failed internal checks" + +/-- Validate a head logit-difference lower-bound certificate. -/ +def verifyHeadLogitDiffLowerBoundCert (cert : HeadLogitDiffLowerBoundCert) : + Except String HeadLogitDiffLowerBoundCert := + if cert.check then + .ok cert + else + .error "head logit-diff lower bound certificate failed internal checks" + +/-- Validate an induction-head certificate. -/ +def verifyInductionHeadSoundCert (cert : InductionHeadSoundCert) : + Except String InductionHeadSoundCert := + if cert.check then + .ok cert + else + .error "induction head certificate failed internal checks" + +/-- Validate a best-match induction-head certificate. -/ +def verifyInductionHeadBestMatchSoundCert (cert : InductionHeadBestMatchSoundCert) : + Except String InductionHeadBestMatchSoundCert := + if cert.check then + .ok cert + else + .error "best-match induction head certificate failed internal checks" + +/-- Locate a local head contribution certificate for a specific layer/head. -/ +def findHeadLocalContribution (certs : Array HeadLocalContributionCert) + (layerIdx headIdx : Nat) : Except String HeadLocalContributionCert := + match certs.find? (fun c => c.layerIdx == layerIdx && c.headIdx == headIdx) with + | some c => .ok c + | none => .error s!"no local head contribution cert for layer {layerIdx} head {headIdx}" + +/-- Tighten a local head contribution certificate using best-match evidence. 
-/ +def tightenHeadLocalContributionBestMatch + (eps : Rat) + (soundnessBits : Nat) + (base : HeadLocalContributionCert) + (pattern : HeadBestMatchPatternCert) + (softmaxExpEffort : Nat) : Except String HeadLocalContributionCert := + Id.run do + let _ ← verifyHeadLocalContributionCert eps soundnessBits base + let _ ← verifyHeadBestMatchPatternCert pattern + if pattern.layerIdx ≠ base.layerIdx || pattern.headIdx ≠ base.headIdx then + return .error "best-match pattern cert layer/head mismatch" + if pattern.softmaxExpEffort ≠ softmaxExpEffort then + return .error "best-match pattern cert softmax effort mismatch" + let softmaxBound := pattern.softmaxJacobianNormInfUpperBound + if softmaxBound > base.softmaxJacobianNormInfUpperBound then + return .error "best-match softmax bound is worse than baseline" + let attnJacBound := + base.ln1Bound * softmaxBound * base.wvOpBound * base.woOpBound + let tightened := + { base with + softmaxJacobianNormInfUpperBound := softmaxBound + attnJacBound := attnJacBound } + if tightened.check eps then + return .ok tightened + return .error "tightened head contribution certificate failed internal checks" + /-! 
### Specs -/ theorem HeadContributionCert.Valid_spec : @@ -571,5 +700,31 @@ theorem InductionHeadBestMatchSoundCert.Valid_spec : InductionHeadBestMatchSoundCert.Valid = InductionHeadBestMatchSoundCert.Valid := rfl theorem InductionHeadBestMatchSoundCert.check_spec : InductionHeadBestMatchSoundCert.check = InductionHeadBestMatchSoundCert.check := rfl +theorem verifyHeadContributionCerts_spec : + verifyHeadContributionCerts = verifyHeadContributionCerts := rfl +theorem verifyHeadLocalContributionCerts_spec : + verifyHeadLocalContributionCerts = verifyHeadLocalContributionCerts := rfl +theorem verifyHeadLocalContributionCert_spec : + verifyHeadLocalContributionCert = verifyHeadLocalContributionCert := rfl +theorem verifyHeadPatternCert_spec : + verifyHeadPatternCert = verifyHeadPatternCert := rfl +theorem verifyHeadBestMatchPatternCert_spec : + verifyHeadBestMatchPatternCert = verifyHeadBestMatchPatternCert := rfl +theorem verifyHeadBestMatchPatternCerts_spec : + verifyHeadBestMatchPatternCerts = verifyHeadBestMatchPatternCerts := rfl +theorem verifyLayerBestMatchMarginCert_spec : + verifyLayerBestMatchMarginCert = verifyLayerBestMatchMarginCert := rfl +theorem verifyHeadValueLowerBoundCert_spec : + verifyHeadValueLowerBoundCert = verifyHeadValueLowerBoundCert := rfl +theorem verifyHeadLogitDiffLowerBoundCert_spec : + verifyHeadLogitDiffLowerBoundCert = verifyHeadLogitDiffLowerBoundCert := rfl +theorem verifyInductionHeadSoundCert_spec : + verifyInductionHeadSoundCert = verifyInductionHeadSoundCert := rfl +theorem verifyInductionHeadBestMatchSoundCert_spec : + verifyInductionHeadBestMatchSoundCert = verifyInductionHeadBestMatchSoundCert := rfl +theorem findHeadLocalContribution_spec : + findHeadLocalContribution = findHeadLocalContribution := rfl +theorem tightenHeadLocalContributionBestMatch_spec : + tightenHeadLocalContributionBestMatch = tightenHeadLocalContributionBestMatch := rfl end Nfp.Sound diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 
fcd7447..1f0808f 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -204,10 +204,7 @@ def certifyHeadBoundsBinary match ← Nfp.Untrusted.SoundCompute.certifyHeadBoundsBinary path scalePow10 with | .error e => return .error e | .ok certs => - let ok := certs.foldl (fun acc c => acc && c.check) true - if ok then - return .ok certs - return .error "head contribution certificate failed internal checks" + return verifyHeadContributionCerts certs /-- Soundly compute conservative per-layer residual amplification constants from a `.nfpt` file. -/ def certifyModelFileGlobal @@ -227,27 +224,12 @@ def certifyModelFileGlobal softmaxMarginLowerBound softmaxExpEffort with | .error e => return .error e | .ok cert => - if cert.eps ≠ eps then - return .error "model header eps mismatch" - if cert.soundnessBits ≠ soundnessBits then - return .error "soundness bits mismatch" - if cert.geluDerivTarget ≠ geluTarget then - return .error "model header gelu_kind mismatch" - if cert.check then - match checkSoftmaxProbIntervalWorst cert with - | .error e => return .error e - | .ok _ => - match checkSoftmaxMarginZero cert with - | .error e => return .error e - | .ok _ => - match ← recomputeAttnWeightBounds path with - | .error e => - return .error s!"attnWeightBounds verification failed: {e}" - | .ok bounds => - match checkAttnWeightBounds cert bounds with - | .error e => return .error e - | .ok _ => return .ok cert - return .error "sound certificate failed internal consistency checks" + match ← recomputeAttnWeightBounds path with + | .error e => + return .error s!"attnWeightBounds verification failed: {e}" + | .ok bounds => + return verifyModelCert cert eps soundnessBits geluTarget + bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax /-- Entry point for sound certification (global or local). 
-/ def certifyModelFile @@ -267,27 +249,12 @@ def certifyModelFile softmaxMarginLowerBound softmaxExpEffort with | .error e => return .error e | .ok cert => - if cert.eps ≠ eps then - return .error "model header eps mismatch" - if cert.soundnessBits ≠ soundnessBits then - return .error "soundness bits mismatch" - if cert.geluDerivTarget ≠ geluTarget then - return .error "model header gelu_kind mismatch" - if cert.check then - match checkSoftmaxProbIntervalWorst cert with - | .error e => return .error e - | .ok _ => - match checkSoftmaxMarginZero cert with - | .error e => return .error e - | .ok _ => - match ← recomputeAttnWeightBounds path with - | .error e => - return .error s!"attnWeightBounds verification failed: {e}" - | .ok bounds => - match checkAttnWeightBounds cert bounds with - | .error e => return .error e - | .ok _ => return .ok cert - return .error "sound certificate failed internal consistency checks" + match ← recomputeAttnWeightBounds path with + | .error e => + return .error s!"attnWeightBounds verification failed: {e}" + | .ok bounds => + return verifyModelCert cert eps soundnessBits geluTarget + bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax /-- Compute per-head contribution bounds (global). -/ def certifyHeadBounds @@ -297,10 +264,7 @@ def certifyHeadBounds match ← Nfp.Untrusted.SoundCompute.certifyHeadBounds path scalePow10 with | .error e => return .error e | .ok certs => - let ok := certs.foldl (fun acc c => acc && c.check) true - if ok then - return .ok certs - return .error "head contribution certificate failed internal checks" + return verifyHeadContributionCerts certs /-- Compute local per-head attention contribution bounds. -/ def certifyHeadBoundsLocal @@ -318,12 +282,7 @@ def certifyHeadBoundsLocal path eps inputPath? 
inputDelta soundnessBits scalePow10 with | .error e => return .error e | .ok certs => - let ok := - certs.foldl (fun acc c => - acc && c.soundnessBits = soundnessBits && c.check eps) true - if ok then - return .ok certs - return .error "local head contribution certificate failed internal checks" + return verifyHeadLocalContributionCerts eps soundnessBits certs /-- Compute local attention pattern bounds for a specific head. -/ def certifyHeadPatternLocal @@ -349,9 +308,7 @@ def certifyHeadPatternLocal tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with | .error e => return .error e | .ok cert => - if cert.check then - return .ok cert - return .error "head pattern certificate failed internal checks" + return verifyHeadPatternCert cert /-- Compute local best-match pattern bounds for a specific head. -/ def certifyHeadPatternBestMatchLocal @@ -379,9 +336,7 @@ def certifyHeadPatternBestMatchLocal softmaxExpEffort with | .error e => return .error e | .ok cert => - if cert.check then - return .ok cert - return .error "head best-match pattern certificate failed internal checks" + return verifyHeadBestMatchPatternCert cert /-- Compute local best-match pattern bounds for a sweep of heads. -/ def certifyHeadPatternBestMatchLocalSweep @@ -407,10 +362,7 @@ def certifyHeadPatternBestMatchLocalSweep tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with | .error e => return .error e | .ok certs => - let ok := certs.foldl (fun acc c => acc && c.check) true - if ok then - return .ok certs - return .error "head best-match sweep certificate failed internal checks" + return verifyHeadBestMatchPatternCerts certs /-- Compute layer-level best-match margin evidence (binary only). 
-/ def certifyLayerBestMatchMarginLocal @@ -436,9 +388,7 @@ def certifyLayerBestMatchMarginLocal tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with | .error e => return .error e | .ok cert => - if cert.check then - return .ok cert - return .error "layer best-match margin certificate failed internal checks" + return verifyLayerBestMatchMarginCert cert /-- Compute local per-head attention contribution bounds tightened by best-match pattern evidence. -/ @@ -466,12 +416,9 @@ def certifyHeadBoundsLocalBestMatch (soundnessBits := soundnessBits) (scalePow10 := scalePow10) with | .error e => return .error e | .ok certs => - let base? := - certs.find? (fun c => c.layerIdx == layerIdx && c.headIdx == headIdx) - match base? with - | none => - return .error s!"no local head contribution cert for layer {layerIdx} head {headIdx}" - | some base => + match findHeadLocalContribution certs layerIdx headIdx with + | .error e => return .error e + | .ok base => match ← certifyHeadPatternBestMatchLocal path layerIdx headIdx (queryPos? := queryPos?) (inputPath? := inputPath?) 
@@ -482,22 +429,8 @@ def certifyHeadBoundsLocalBestMatch (softmaxExpEffort := softmaxExpEffort) with | .error e => return .error e | .ok pattern => - if pattern.layerIdx ≠ layerIdx || pattern.headIdx ≠ headIdx then - return .error "best-match pattern cert layer/head mismatch" - if pattern.softmaxExpEffort ≠ softmaxExpEffort then - return .error "best-match pattern cert softmax effort mismatch" - let softmaxBound := pattern.softmaxJacobianNormInfUpperBound - if softmaxBound > base.softmaxJacobianNormInfUpperBound then - return .error "best-match softmax bound is worse than baseline" - let attnJacBound := - base.ln1Bound * softmaxBound * base.wvOpBound * base.woOpBound - let tightened := - { base with - softmaxJacobianNormInfUpperBound := softmaxBound - attnJacBound := attnJacBound } - if tightened.check eps then - return .ok tightened - return .error "tightened head contribution certificate failed internal checks" + return tightenHeadLocalContributionBestMatch + eps soundnessBits base pattern softmaxExpEffort /-- Compute local head output lower bounds. -/ def certifyHeadValueLowerBoundLocal @@ -523,9 +456,7 @@ def certifyHeadValueLowerBoundLocal maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 with | .error e => return .error e | .ok cert => - if cert.check then - return .ok cert - return .error "head value lower bound certificate failed internal checks" + return verifyHeadValueLowerBoundCert cert /-- Compute local head logit-difference lower bounds. -/ def certifyHeadLogitDiffLowerBoundLocal @@ -552,9 +483,7 @@ def certifyHeadLogitDiffLowerBoundLocal perRowPatternLayers scalePow10 with | .error e => return .error e | .ok cert => - if cert.check then - return .ok cert - return .error "head logit-diff lower bound certificate failed internal checks" + return verifyHeadLogitDiffLowerBoundCert cert /-- Sound induction-head certification (local path). 
-/ def certifyInductionSound @@ -585,9 +514,7 @@ def certifyInductionSound perRowPatternLayers targetToken? negativeToken? softmaxExpEffort with | .error e => return .error e | .ok cert => - if cert.check then - return .ok cert - return .error "induction head certificate failed internal checks" + return verifyInductionHeadSoundCert cert /-- Sound best-match induction-head certification (local path). -/ def certifyInductionSoundBestMatch @@ -619,8 +546,6 @@ def certifyInductionSoundBestMatch tightPatternLayers perRowPatternLayers targetToken? negativeToken? softmaxExpEffort with | .error e => return .error e | .ok cert => - if cert.check then - return .ok cert - return .error "best-match induction head certificate failed internal checks" + return verifyInductionHeadBestMatchSoundCert cert end Nfp.Sound diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean index bc9d55c..715b4f2 100644 --- a/Nfp/Sound/TextPure.lean +++ b/Nfp/Sound/TextPure.lean @@ -33,25 +33,8 @@ structure AttnWeightBounds where /-- Verify that attention-weight bounds match the certificate layer fields. -/ def checkAttnWeightBounds (cert : ModelCert) (expected : AttnWeightBounds) : Except String Unit := - Id.run do - if expected.attnValueCoeff.size ≠ cert.layers.size then - return .error "attnValueCoeff layer count mismatch" - if expected.wqOpBoundMax.size ≠ cert.layers.size then - return .error "wqOpBoundMax layer count mismatch" - if expected.wkOpBoundMax.size ≠ cert.layers.size then - return .error "wkOpBoundMax layer count mismatch" - for idx in [:cert.layers.size] do - let expValue := expected.attnValueCoeff[idx]! - let expWq := expected.wqOpBoundMax[idx]! - let expWk := expected.wkOpBoundMax[idx]! - let layer := cert.layers[idx]! 
- if expValue ≠ layer.attnValueCoeff then - return .error s!"attnValueCoeff mismatch at layer {idx}" - if expWq ≠ layer.wqOpBoundMax then - return .error s!"wqOpBoundMax mismatch at layer {idx}" - if expWk ≠ layer.wkOpBoundMax then - return .error s!"wkOpBoundMax mismatch at layer {idx}" - return .ok () + checkAttnWeightBoundsArrays cert expected.attnValueCoeff expected.wqOpBoundMax + expected.wkOpBoundMax def parseTextHeaderDims (lines : Array String) : Except String TextModelDims := Id.run do From 8e55999ed0666f045287f6581e5d5d4f195f3dcf Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 02:56:25 +0100 Subject: [PATCH 007/244] Move attn weight bounds to pure helper --- Nfp/Sound/BinaryPure.lean | 21 +++++++++++++++++++++ Nfp/Sound/IO.lean | 25 ++++++++++++------------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index 8b4dadd..df62f94 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -317,6 +317,25 @@ def attnQKMaxFromScaledPairs (scalePow10 : Nat) (pairs : Array (Int × Int)) : R max acc.2 (ratOfScaledInt scalePow10 p.2))) (0, 0) +/-- Compute per-layer attention-weight bound arrays from scaled-int pairs. -/ +def attnWeightBoundsArraysFromScaledPairs (scalePow10 : Nat) + (valuePairs qkPairs : Array (Array (Int × Int))) : + Except String (Array Rat × Array Rat × Array Rat) := + Id.run do + if valuePairs.size ≠ qkPairs.size then + return .error s!"attn weight bounds layer count mismatch: \ +value={valuePairs.size}, qk={qkPairs.size}" + let mut coeffs : Array Rat := Array.mkEmpty valuePairs.size + let mut wqMaxs : Array Rat := Array.mkEmpty valuePairs.size + let mut wkMaxs : Array Rat := Array.mkEmpty valuePairs.size + for idx in [:valuePairs.size] do + let coeff := attnValueCoeffFromScaledPairs scalePow10 valuePairs[idx]! + let (wqMax, wkMax) := attnQKMaxFromScaledPairs scalePow10 qkPairs[idx]! 
+ coeffs := coeffs.push coeff + wqMaxs := wqMaxs.push wqMax + wkMaxs := wkMaxs.push wkMax + return .ok (coeffs, wqMaxs, wkMaxs) + /-! ### Derived properties -/ private theorem pure_eq_ok {ε α : Type} (x : α) : (pure x : Except ε α) = .ok x := rfl @@ -381,5 +400,7 @@ theorem attnValueCoeffFromScaledPairs_spec_binary_pure : attnValueCoeffFromScaledPairs = attnValueCoeffFromScaledPairs := rfl theorem attnQKMaxFromScaledPairs_spec_binary_pure : attnQKMaxFromScaledPairs = attnQKMaxFromScaledPairs := rfl +theorem attnWeightBoundsArraysFromScaledPairs_spec_binary_pure : + attnWeightBoundsArraysFromScaledPairs = attnWeightBoundsArraysFromScaledPairs := rfl end Nfp.Sound diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 1f0808f..2c54376 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -64,9 +64,8 @@ private def recomputeAttnWeightBoundsBinary match ← Nfp.Untrusted.SoundBinary.skipF64Array h (hdr.seqLen * hdr.modelDim) with | .error e => return .error e | .ok _ => pure () - let mut coeffs : Array Rat := Array.mkEmpty hdr.numLayers - let mut wqMaxs : Array Rat := Array.mkEmpty hdr.numLayers - let mut wkMaxs : Array Rat := Array.mkEmpty hdr.numLayers + let mut valuePairsLayers : Array (Array (Int × Int)) := Array.mkEmpty hdr.numLayers + let mut qkPairsLayers : Array (Array (Int × Int)) := Array.mkEmpty hdr.numLayers for _l in [:hdr.numLayers] do let mut valuePairs : Array (Int × Int) := Array.mkEmpty hdr.numHeads let mut qkPairs : Array (Int × Int) := Array.mkEmpty hdr.numHeads @@ -161,11 +160,8 @@ private def recomputeAttnWeightBoundsBinary let _ := ln1GammaScaled let _ := ln1BetaScaled let _ := ln2GammaScaled - let coeff := attnValueCoeffFromScaledPairs scalePow10 valuePairs - let (wqMax, wkMax) := attnQKMaxFromScaledPairs scalePow10 qkPairs - coeffs := coeffs.push coeff - wqMaxs := wqMaxs.push wqMax - wkMaxs := wkMaxs.push wkMax + valuePairsLayers := valuePairsLayers.push valuePairs + qkPairsLayers := qkPairsLayers.push qkPairs match ← 
Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with | .error e => return .error e | .ok _ => pure () @@ -175,11 +171,14 @@ private def recomputeAttnWeightBoundsBinary match ← Nfp.Untrusted.SoundBinary.skipF64Array h (hdr.modelDim * hdr.vocabSize) with | .error e => return .error e | .ok _ => pure () - return .ok { - attnValueCoeff := coeffs - wqOpBoundMax := wqMaxs - wkOpBoundMax := wkMaxs - } + match attnWeightBoundsArraysFromScaledPairs scalePow10 valuePairsLayers qkPairsLayers with + | .error e => return .error e + | .ok (coeffs, wqMaxs, wkMaxs) => + return .ok { + attnValueCoeff := coeffs + wqOpBoundMax := wqMaxs + wkOpBoundMax := wkMaxs + } private def recomputeAttnWeightBoundsText (path : System.FilePath) : IO (Except String AttnWeightBounds) := do From 1c8070b5f58ad76ab91c3e35b97a7b1a64e5c3b1 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 07:16:42 +0100 Subject: [PATCH 008/244] Extract pure IO helpers --- AGENTS.md | 2 + Nfp/IO.lean | 405 +-------------------------------------------- Nfp/IO/Pure.lean | 415 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 418 insertions(+), 404 deletions(-) create mode 100644 Nfp/IO/Pure.lean diff --git a/AGENTS.md b/AGENTS.md index 1f5d40b..245b366 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -331,6 +331,8 @@ but you **must** update this list in the same commit. - `IO.lean` - Parsing/loading/tokenization/report formatting glue. - **IO-only principle:** no heavy proofs; keep it as a bridge to filesystem/CLI. +- `IO/Pure.lean` + - Pure parsing, construction, and tokenization helpers used by `IO.lean`. - `Main.lean` - CLI entrypoint and subcommand wiring. Keep it thin: - argument parsing + calling into `Nfp.IO` / `Discovery` / `Nfp.Sound.*` reporting helpers, diff --git a/Nfp/IO.lean b/Nfp/IO.lean index a0a338b..dcc5f86 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -1,6 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Discovery +import Nfp.IO.Pure /-! 
# Model IO: Loading Pre-trained Weights @@ -55,344 +55,6 @@ namespace Nfp open IO -/-! ## Float Parsing Utilities -/ - -private def pow10PowTable : Array Float := Id.run do - -- Precompute `Float.pow 10.0 k` for k=0..308 so we avoid calling `Float.pow` per token. - let mut out : Array Float := Array.mkEmpty 309 - for k in [:309] do - out := out.push (Float.pow 10.0 k.toFloat) - out - -private def pow10Pow (n : Nat) : Float := - if n < pow10PowTable.size then - pow10PowTable[n]! - else - Float.pow 10.0 n.toFloat - -private def parseFloatRange (s : String) (start stop : String.Pos.Raw) : Option Float := Id.run do - -- This is a faster, allocation-free version of the previous `parseFloat`, but it preserves - -- the exact Float computation structure (Nat parsing + `Float.pow`) to keep results stable. - - let parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do - let mut p := start - if p ≥ stop then - return none - let mut acc : Nat := 0 - let mut saw : Bool := false - while p < stop do - let c := p.get s - if ('0' ≤ c) && (c ≤ '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - saw := true - p := p.next s - else - return none - if saw then some acc else none - - let mut p := start - if p ≥ stop then - return none - - let mut negative := false - let c0 := p.get s - if c0 = '-' then - negative := true - p := p.next s - else if c0 = '+' then - p := p.next s - - if p ≥ stop then - return none - - -- Find exponent marker the same way as the old parser: accept exactly one `e` if present, - -- otherwise accept exactly one `E`. - let mut ePos : Option String.Pos.Raw := none - let mut eCount : Nat := 0 - let mut EPos : Option String.Pos.Raw := none - let mut ECount : Nat := 0 - let mut q := p - while q < stop do - let c := q.get s - if c = 'e' then - eCount := eCount + 1 - if eCount = 1 then ePos := some q - else if c = 'E' then - ECount := ECount + 1 - if ECount = 1 then EPos := some q - q := q.next s - - let expMarker? 
: Option String.Pos.Raw := - if eCount = 1 then ePos else if ECount = 1 then EPos else none - - let mantEnd : String.Pos.Raw := - match expMarker? with - | some ep => ep - | none => stop - - -- Find decimal point in mantissa (must be 0 or 1 occurrences). - let mut dotPos : Option String.Pos.Raw := none - let mut dotCount : Nat := 0 - let mut r := p - while r < mantEnd do - if r.get s = '.' then - dotCount := dotCount + 1 - if dotCount = 1 then dotPos := some r - r := r.next s - if dotCount > 1 then - return none - - let (intStart, intStop, fracStart?, fracStop) := - match dotPos with - | none => (p, mantEnd, none, mantEnd) - | some dp => (p, dp, some (dp.next s), mantEnd) - - let intN? : Option Nat := - if dotPos.isSome && intStart = intStop then - some 0 - else - parseNatRange s intStart intStop - - let fracN? : Option Nat := - match fracStart? with - | none => none - | some fs => - if fs = fracStop then some 0 else parseNatRange s fs fracStop - - let mantissa? : Option Float := - match dotPos, intN?, fracN? with - | none, some iN, _ => - some iN.toFloat - | some _, some iN, some fN => - let fracLen := (fracStop.byteIdx - (fracStart?.getD fracStop).byteIdx) - let divisor := pow10Pow fracLen - some (iN.toFloat + fN.toFloat / divisor) - | some _, _, none => - -- `.` present but no fractional parse (shouldn't happen), treat as invalid. - none - | _, none, _ => none - - let some mantissa := mantissa? | return none - - let value : Float := - match expMarker? with - | none => mantissa - | some ep => - let expStart := ep.next s - if expStart ≥ stop then - mantissa - else - -- Parse exponent, but if it is malformed, ignore it (old behavior). 
- let c := expStart.get s - let (expNeg, es) := - if c = '-' then (true, expStart.next s) - else if c = '+' then (false, expStart.next s) - else (false, expStart) - match parseNatRange s es stop with - | none => mantissa - | some eNat => - let p10 := pow10Pow eNat - if expNeg then mantissa / p10 else mantissa * p10 - - some (if negative then -value else value) - -/-- Parse a floating point number from a string. -/ -def parseFloat (s : String) : Option Float := Id.run do - let s := s.trim - if s.isEmpty then - none - else - parseFloatRange s 0 s.rawEndPos - -private def appendFloatsFromLine (line : String) (acc : Array Float) : Array Float := Id.run do - let mut out := acc - let s := line - let mut p : String.Pos.Raw := 0 - let endPos := s.rawEndPos - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - while p < endPos do - while p < endPos && isWs (p.get s) do - p := p.next s - let start := p - while p < endPos && !isWs (p.get s) do - p := p.next s - if start < p then - match parseFloatRange s start p with - | some x => out := out.push x - | none => pure () - out - -private def parseFloatsFromLines (lines : Array String) (cap : Nat := 0) : Array Float := - Id.run do - let mut out : Array Float := Array.mkEmpty cap - for line in lines do - out := appendFloatsFromLine line out - out - -private def spawnParseFloats (lines : Array String) (cap : Nat := 0) : Task (Array Float) := - Task.spawn (fun _ => parseFloatsFromLines lines cap) - -/-- Parse a line of space-separated floats. -/ -def parseFloatLine (line : String) : Array Float := - appendFloatsFromLine line #[] - -/-! ## Nat Parsing Utilities -/ - -/-- Parse a line of space-separated natural numbers. 
-/ -private def parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do - let mut p := start - if p ≥ stop then - return none - let mut acc : Nat := 0 - let mut saw : Bool := false - while p < stop do - let c := p.get s - if ('0' ≤ c) && (c ≤ '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - saw := true - p := p.next s - else - return none - if saw then some acc else none - -private def appendNatsFromLine (line : String) (acc : Array Nat) : Array Nat := Id.run do - let mut out := acc - let s := line - let mut p : String.Pos.Raw := 0 - let endPos := s.rawEndPos - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - while p < endPos do - while p < endPos && isWs (p.get s) do - p := p.next s - let start := p - while p < endPos && !isWs (p.get s) do - p := p.next s - if start < p then - match parseNatRange s start p with - | some n => out := out.push n - | none => pure () - out - -def parseNatLine (line : String) : Array Nat := - appendNatsFromLine line #[] - -/-! ## Matrix Construction for IO - -For IO operations, we need to create matrices from runtime data. -We use a safe construction that always ensures size invariants hold -by padding/truncating the data. --/ - -/-- Build a ConcreteMatrix from float data, padding or truncating as needed. - This is safe because we ensure the data has exactly the right size. 
-/ -def buildMatrix (rows cols : Nat) (data : Array Float) : ConcreteMatrix := - let expectedSize := rows * cols - let normalizedData : Array Float := - if data.size < expectedSize then - data ++ (Array.replicate (expectedSize - data.size) 0.0) - else if data.size > expectedSize then - data.toSubarray 0 expectedSize |>.toArray - else - data - -- Use Array.ofFn to get the exact size we need with a proof - let finalData := Array.ofFn fun (i : Fin expectedSize) => - normalizedData.getD i.val 0.0 - { - numRows := rows - numCols := cols - data := finalData - size_eq := Array.size_ofFn - } - -/-- Result of loading a model. -/ -inductive LoadResult - | ok (model : ConcreteModel) - | error (msg : String) - -namespace LoadResult - -def isOk : LoadResult → Bool - | ok _ => true - | error _ => false - -def getModel : LoadResult → Option ConcreteModel - | ok m => some m - | error _ => none - -def getError : LoadResult → Option String - | ok _ => none - | error msg => some msg - -end LoadResult - -/-! ## Text Format Parsing -/ - -/-- NFP file header structure. -/ -structure NfpHeader where - numLayers : Nat - numHeads : Nat - modelDim : Nat - headDim : Nat - hiddenDim : Nat - vocabSize : Nat - seqLen : Nat - deriving Repr - -/-- Build a ConcreteAttentionLayer from weight matrices. - The dimension proofs are satisfied by construction (buildMatrix ensures correct sizes). 
-/ -def mkAttentionLayer - (modelDim headDim : Nat) - (wq wk wv wo bq bk bv : Array Float) : ConcreteAttentionLayer := - let wQ := buildMatrix modelDim headDim wq - let bQ := buildMatrix 1 headDim bq - let wK := buildMatrix modelDim headDim wk - let bK := buildMatrix 1 headDim bk - let wV := buildMatrix modelDim headDim wv - let bV := buildMatrix 1 headDim bv - let wO := buildMatrix headDim modelDim wo - { - modelDim := modelDim - headDim := headDim - W_Q := wQ - b_Q := bQ - W_K := wK - b_K := bK - W_V := wV - b_V := bV - W_O := wO - W_Q_dims := ⟨rfl, rfl⟩ - b_Q_dims := ⟨rfl, rfl⟩ - W_K_dims := ⟨rfl, rfl⟩ - b_K_dims := ⟨rfl, rfl⟩ - W_V_dims := ⟨rfl, rfl⟩ - b_V_dims := ⟨rfl, rfl⟩ - W_O_dims := ⟨rfl, rfl⟩ - } - -/-- Build a ConcreteMLPLayer from weight matrices. - The dimension proofs are satisfied by construction. -/ -def mkMLPLayer - (modelDim hiddenDim : Nat) - (win wout bin bout : Array Float) : ConcreteMLPLayer := - let wIn := buildMatrix modelDim hiddenDim win - let wOut := buildMatrix hiddenDim modelDim wout - let bIn := buildMatrix 1 hiddenDim bin - let bOut := buildMatrix 1 modelDim bout - { - modelDim := modelDim - hiddenDim := hiddenDim - W_in := wIn - W_out := wOut - b_in := bIn - b_out := bOut - W_in_dims := ⟨rfl, rfl⟩ - W_out_dims := ⟨rfl, rfl⟩ - b_in_dims := ⟨rfl, rfl⟩ - b_out_dims := ⟨rfl, rfl⟩ - } - /-- Load a model from NFP text format content. -/ def loadFromText (_content : String) : IO LoadResult := do return .error "NFP_TEXT format is deprecated; use NFP_BINARY_V1" @@ -607,71 +269,6 @@ def loadModel (path : System.FilePath) : IO LoadResult := do else return .error s!"Unsupported file format: {path.extension.getD "unknown"}" -/-! ## Tokenization Utilities -/ - -/-- Simple tokenizer with vocabulary mapping. 
-/ -structure Tokenizer where - /-- Token strings in order of ID -/ - tokens : Array String - /-- Unknown token ID -/ - unkId : Nat - /-- Padding token ID -/ - padId : Nat - /-- End of sequence token ID -/ - eosId : Nat - -namespace Tokenizer - -/-- Create a tokenizer from vocabulary list. -/ -def fromVocabList (tokens : Array String) - (unkId padId eosId : Nat := 0) : Tokenizer := - { tokens := tokens, unkId := unkId, padId := padId, eosId := eosId } - -/-- Find a token's ID in the vocabulary. -/ -def findToken (t : Tokenizer) (word : String) : Nat := - match t.tokens.findIdx? (· == word) with - | some idx => idx - | none => t.unkId - -/-- Tokenize a string using simple whitespace splitting. -/ -def tokenize (t : Tokenizer) (text : String) : Array Nat := Id.run do - let words := text.splitOn " " |>.filter (· ≠ "") - let mut ids : Array Nat := #[] - for word in words do - ids := ids.push (t.findToken word) - ids - -/-- Decode token IDs back to text. -/ -def decode (t : Tokenizer) (ids : Array Nat) : String := - let tokens := ids.filterMap fun id => - if id < t.tokens.size then some t.tokens[id]! - else none - " ".intercalate tokens.toList - -end Tokenizer - -/-- Look up embeddings for token IDs from the embedding matrix. -/ -def lookupEmbeddings (embeddings : ConcreteMatrix) (tokenIds : Array Nat) - (seqLen : Nat) (padId : Nat := 0) : ConcreteMatrix := Id.run do - let modelDim := embeddings.numCols - let mut data : Array Float := #[] - - for pos in [:seqLen] do - let tokenId := if pos < tokenIds.size then tokenIds[pos]! else padId - -- Copy embedding row for this token - for dim in [:modelDim] do - let val := embeddings.get tokenId dim - data := data.push val - - buildMatrix seqLen modelDim data - -/-- Set the input embeddings in a model for a given prompt (token IDs). 
-/ -def ConcreteModel.withInputTokens (model : ConcreteModel) - (embeddings : ConcreteMatrix) (tokenIds : Array Nat) - (padId : Nat := 0) : ConcreteModel := - let inputEmb := lookupEmbeddings embeddings tokenIds model.seqLen padId - { model with inputEmbeddings := inputEmb, inputTokens := some tokenIds } - /-! ## Analysis Report Generation -/ /-- Format for circuit analysis results. -/ diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean new file mode 100644 index 0000000..e6c49ee --- /dev/null +++ b/Nfp/IO/Pure.lean @@ -0,0 +1,415 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Discovery + +/-! +# Pure helpers for model IO + +Pure parsing, construction, and tokenization utilities shared by the CLI-facing IO layer. +-/ + +namespace Nfp + +/-! ## Float Parsing Utilities -/ + +private def pow10PowTable : Array Float := Id.run do + -- Precompute `Float.pow 10.0 k` for k=0..308 so we avoid calling `Float.pow` per token. + let mut out : Array Float := Array.mkEmpty 309 + for k in [:309] do + out := out.push (Float.pow 10.0 k.toFloat) + out + +private def pow10Pow (n : Nat) : Float := + if n < pow10PowTable.size then + pow10PowTable[n]! + else + Float.pow 10.0 n.toFloat + +private def parseFloatRange (s : String) (start stop : String.Pos.Raw) : Option Float := Id.run do + -- This is a faster, allocation-free version of the previous `parseFloat`, but it preserves + -- the exact Float computation structure (Nat parsing + `Float.pow`) to keep results stable. 
+ + let parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do + let mut p := start + if p >= stop then + return none + let mut acc : Nat := 0 + let mut saw : Bool := false + while p < stop do + let c := p.get s + if ('0' <= c) && (c <= '9') then + acc := acc * 10 + (c.toNat - '0'.toNat) + saw := true + p := p.next s + else + return none + if saw then some acc else none + + let mut p := start + if p >= stop then + return none + + let mut negative := false + let c0 := p.get s + if c0 = '-' then + negative := true + p := p.next s + else if c0 = '+' then + p := p.next s + + if p >= stop then + return none + + -- Find exponent marker the same way as the old parser: accept exactly one `e` if present, + -- otherwise accept exactly one `E`. + let mut ePos : Option String.Pos.Raw := none + let mut eCount : Nat := 0 + let mut EPos : Option String.Pos.Raw := none + let mut ECount : Nat := 0 + let mut q := p + while q < stop do + let c := q.get s + if c = 'e' then + eCount := eCount + 1 + if eCount = 1 then ePos := some q + else if c = 'E' then + ECount := ECount + 1 + if ECount = 1 then EPos := some q + q := q.next s + + let expMarker? : Option String.Pos.Raw := + if eCount = 1 then ePos else if ECount = 1 then EPos else none + + let mantEnd : String.Pos.Raw := + match expMarker? with + | some ep => ep + | none => stop + + -- Find decimal point in mantissa (must be 0 or 1 occurrences). + let mut dotPos : Option String.Pos.Raw := none + let mut dotCount : Nat := 0 + let mut r := p + while r < mantEnd do + if r.get s = '.' then + dotCount := dotCount + 1 + if dotCount = 1 then dotPos := some r + r := r.next s + if dotCount > 1 then + return none + + let (intStart, intStop, fracStart?, fracStop) := + match dotPos with + | none => (p, mantEnd, none, mantEnd) + | some dp => (p, dp, some (dp.next s), mantEnd) + + let intN? : Option Nat := + if dotPos.isSome && intStart = intStop then + some 0 + else + parseNatRange s intStart intStop + + let fracN? 
: Option Nat := + match fracStart? with + | none => none + | some fs => + if fs = fracStop then some 0 else parseNatRange s fs fracStop + + let mantissa? : Option Float := + match dotPos, intN?, fracN? with + | none, some iN, _ => + some iN.toFloat + | some _, some iN, some fN => + let fracLen := (fracStop.byteIdx - (fracStart?.getD fracStop).byteIdx) + let divisor := pow10Pow fracLen + some (iN.toFloat + fN.toFloat / divisor) + | some _, _, none => + -- `.` present but no fractional parse (shouldn't happen), treat as invalid. + none + | _, none, _ => none + + let some mantissa := mantissa? | return none + + let value : Float := + match expMarker? with + | none => mantissa + | some ep => + let expStart := ep.next s + if expStart >= stop then + mantissa + else + -- Parse exponent, but if it is malformed, ignore it (old behavior). + let c := expStart.get s + let (expNeg, es) := + if c = '-' then (true, expStart.next s) + else if c = '+' then (false, expStart.next s) + else (false, expStart) + match parseNatRange s es stop with + | none => mantissa + | some eNat => + let p10 := pow10Pow eNat + if expNeg then mantissa / p10 else mantissa * p10 + + some (if negative then -value else value) + +/-- Parse a floating point number from a string. 
-/ +def parseFloat (s : String) : Option Float := Id.run do + let s := s.trim + if s.isEmpty then + none + else + parseFloatRange s 0 s.rawEndPos + +private def appendFloatsFromLine (line : String) (acc : Array Float) : Array Float := Id.run do + let mut out := acc + let s := line + let mut p : String.Pos.Raw := 0 + let endPos := s.rawEndPos + let isWs (c : Char) : Bool := + c = ' ' || c = '\t' || c = '\n' || c = '\r' + while p < endPos do + while p < endPos && isWs (p.get s) do + p := p.next s + let start := p + while p < endPos && !isWs (p.get s) do + p := p.next s + if start < p then + match parseFloatRange s start p with + | some x => out := out.push x + | none => pure () + out + +private def parseFloatsFromLines (lines : Array String) (cap : Nat := 0) : Array Float := + Id.run do + let mut out : Array Float := Array.mkEmpty cap + for line in lines do + out := appendFloatsFromLine line out + out + +private def spawnParseFloats (lines : Array String) (cap : Nat := 0) : Task (Array Float) := + Task.spawn (fun _ => parseFloatsFromLines lines cap) + +/-- Parse a line of space-separated floats. -/ +def parseFloatLine (line : String) : Array Float := + appendFloatsFromLine line #[] + +/-! ## Nat Parsing Utilities -/ + +/-- Parse a line of space-separated natural numbers. 
-/ +private def parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do + let mut p := start + if p >= stop then + return none + let mut acc : Nat := 0 + let mut saw : Bool := false + while p < stop do + let c := p.get s + if ('0' <= c) && (c <= '9') then + acc := acc * 10 + (c.toNat - '0'.toNat) + saw := true + p := p.next s + else + return none + if saw then some acc else none + +private def appendNatsFromLine (line : String) (acc : Array Nat) : Array Nat := Id.run do + let mut out := acc + let s := line + let mut p : String.Pos.Raw := 0 + let endPos := s.rawEndPos + let isWs (c : Char) : Bool := + c = ' ' || c = '\t' || c = '\n' || c = '\r' + while p < endPos do + while p < endPos && isWs (p.get s) do + p := p.next s + let start := p + while p < endPos && !isWs (p.get s) do + p := p.next s + if start < p then + match parseNatRange s start p with + | some n => out := out.push n + | none => pure () + out + +def parseNatLine (line : String) : Array Nat := + appendNatsFromLine line #[] + +/-! ## Matrix Construction for IO -/ + +/- Build a ConcreteMatrix from float data, padding or truncating as needed. + This is safe because we ensure the data has exactly the right size. -/ +def buildMatrix (rows cols : Nat) (data : Array Float) : ConcreteMatrix := + let expectedSize := rows * cols + let normalizedData : Array Float := + if data.size < expectedSize then + data ++ (Array.replicate (expectedSize - data.size) 0.0) + else if data.size > expectedSize then + data.toSubarray 0 expectedSize |>.toArray + else + data + -- Use Array.ofFn to get the exact size we need with a proof. + let finalData := Array.ofFn fun (i : Fin expectedSize) => + normalizedData.getD i.val 0.0 + { + numRows := rows + numCols := cols + data := finalData + size_eq := Array.size_ofFn + } + +/-! ## Load Result Helpers -/ + +/-- Result of loading a model. 
-/ +inductive LoadResult + | ok (model : ConcreteModel) + | error (msg : String) + +namespace LoadResult + +def isOk : LoadResult -> Bool + | ok _ => true + | error _ => false + +def getModel : LoadResult -> Option ConcreteModel + | ok m => some m + | error _ => none + +def getError : LoadResult -> Option String + | ok _ => none + | error msg => some msg + +end LoadResult + +/-! ## Text Format Parsing -/ + +/-- NFP file header structure. -/ +structure NfpHeader where + numLayers : Nat + numHeads : Nat + modelDim : Nat + headDim : Nat + hiddenDim : Nat + vocabSize : Nat + seqLen : Nat + deriving Repr + +/- Build a ConcreteAttentionLayer from weight matrices. + The dimension proofs are satisfied by construction (buildMatrix ensures correct sizes). -/ +def mkAttentionLayer + (modelDim headDim : Nat) + (wq wk wv wo bq bk bv : Array Float) : ConcreteAttentionLayer := + let wQ := buildMatrix modelDim headDim wq + let bQ := buildMatrix 1 headDim bq + let wK := buildMatrix modelDim headDim wk + let bK := buildMatrix 1 headDim bk + let wV := buildMatrix modelDim headDim wv + let bV := buildMatrix 1 headDim bv + let wO := buildMatrix headDim modelDim wo + { + modelDim := modelDim + headDim := headDim + W_Q := wQ + b_Q := bQ + W_K := wK + b_K := bK + W_V := wV + b_V := bV + W_O := wO + W_Q_dims := And.intro rfl rfl + b_Q_dims := And.intro rfl rfl + W_K_dims := And.intro rfl rfl + b_K_dims := And.intro rfl rfl + W_V_dims := And.intro rfl rfl + b_V_dims := And.intro rfl rfl + W_O_dims := And.intro rfl rfl + } + +/- Build a ConcreteMLPLayer from weight matrices. + The dimension proofs are satisfied by construction. 
-/ +def mkMLPLayer + (modelDim hiddenDim : Nat) + (win wout bin bout : Array Float) : ConcreteMLPLayer := + let wIn := buildMatrix modelDim hiddenDim win + let wOut := buildMatrix hiddenDim modelDim wout + let bIn := buildMatrix 1 hiddenDim bin + let bOut := buildMatrix 1 modelDim bout + { + modelDim := modelDim + hiddenDim := hiddenDim + W_in := wIn + W_out := wOut + b_in := bIn + b_out := bOut + W_in_dims := And.intro rfl rfl + W_out_dims := And.intro rfl rfl + b_in_dims := And.intro rfl rfl + b_out_dims := And.intro rfl rfl + } + +/-! ## Tokenization Utilities -/ + +/-- Simple tokenizer with vocabulary mapping. -/ +structure Tokenizer where + /-- Token strings in order of ID. -/ + tokens : Array String + /-- Unknown token ID. -/ + unkId : Nat + /-- Padding token ID. -/ + padId : Nat + /-- End of sequence token ID. -/ + eosId : Nat + +namespace Tokenizer + +/-- Create a tokenizer from vocabulary list. -/ +def fromVocabList (tokens : Array String) + (unkId padId eosId : Nat := 0) : Tokenizer := + { tokens := tokens, unkId := unkId, padId := padId, eosId := eosId } + +/-- Find a token's ID in the vocabulary. -/ +def findToken (t : Tokenizer) (word : String) : Nat := + match t.tokens.findIdx? (fun tok => tok == word) with + | some idx => idx + | none => t.unkId + +/-- Tokenize a string using simple whitespace splitting. -/ +def tokenize (t : Tokenizer) (text : String) : Array Nat := Id.run do + let words := text.splitOn " " |>.filter (fun w => w != "") + let mut ids : Array Nat := #[] + for word in words do + ids := ids.push (t.findToken word) + ids + +/-- Decode token IDs back to text. -/ +def decode (t : Tokenizer) (ids : Array Nat) : String := + let tokens := ids.filterMap fun id => + if id < t.tokens.size then some t.tokens[id]! + else none + " ".intercalate tokens.toList + +end Tokenizer + +/-! ## Embedding Utilities -/ + +/-- Look up embeddings for token IDs from the embedding matrix. 
-/ +def lookupEmbeddings (embeddings : ConcreteMatrix) (tokenIds : Array Nat) + (seqLen : Nat) (padId : Nat := 0) : ConcreteMatrix := Id.run do + let modelDim := embeddings.numCols + let mut data : Array Float := #[] + + for pos in [:seqLen] do + let tokenId := if pos < tokenIds.size then tokenIds[pos]! else padId + -- Copy embedding row for this token. + for dim in [:modelDim] do + let val := embeddings.get tokenId dim + data := data.push val + + buildMatrix seqLen modelDim data + +/-- Set the input embeddings in a model for a given prompt (token IDs). -/ +def ConcreteModel.withInputTokens (model : ConcreteModel) + (embeddings : ConcreteMatrix) (tokenIds : Array Nat) + (padId : Nat := 0) : ConcreteModel := + let inputEmb := lookupEmbeddings embeddings tokenIds model.seqLen padId + { model with inputEmbeddings := inputEmb, inputTokens := some tokenIds } + +end Nfp From 1568583a420049f80ead7b28a0232b5db0535d26 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 17:14:49 +0100 Subject: [PATCH 009/244] Add best-match margin tightening for certify --- Main.lean | 63 ++++++++++++++++++++++++++++++++++++++++----- Nfp/Sound/Cert.lean | 24 +++++++++++++++++ Nfp/Sound/IO.lean | 47 +++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 6 deletions(-) diff --git a/Main.lean b/Main.lean index 278b3f6..55be692 100644 --- a/Main.lean +++ b/Main.lean @@ -873,6 +873,13 @@ private structure CertifyArgs where deltaStr : String softmaxMarginStr : String softmaxExpEffort : Nat + bestMatchMargins : Bool + targetOffset : Int + maxSeqLen : Nat + tightPattern : Bool + tightPatternLayers : Nat + perRowPatternLayers : Nat + scalePow10 : Nat outputPath? : Option System.FilePath private def parseCertifyArgs (p : Parsed) : CertifyArgs := @@ -885,6 +892,13 @@ private def parseCertifyArgs (p : Parsed) : CertifyArgs := let softmaxMarginStr := p.flag? "softmaxMargin" |>.map (·.as! String) |>.getD "0" let softmaxExpEffort := p.flag? "softmaxExpEffort" |>.map (·.as! 
Nat) |>.getD Nfp.Sound.defaultSoftmaxExpEffort + let bestMatchMargins := p.flag? "bestMatchMargins" |>.isSome + let targetOffset := p.flag? "targetOffset" |>.map (·.as! Int) |>.getD (-1) + let maxSeqLen := p.flag? "maxSeqLen" |>.map (·.as! Nat) |>.getD 0 + let tightPattern := p.flag? "tightPattern" |>.isSome + let tightPatternLayers := p.flag? "tightPatternLayers" |>.map (·.as! Nat) |>.getD 1 + let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 + let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) { modelPath := ⟨modelPathStr⟩ modelPathStr := modelPathStr @@ -895,6 +909,13 @@ private def parseCertifyArgs (p : Parsed) : CertifyArgs := deltaStr := deltaStr softmaxMarginStr := softmaxMarginStr softmaxExpEffort := softmaxExpEffort + bestMatchMargins := bestMatchMargins + targetOffset := targetOffset + maxSeqLen := maxSeqLen + tightPattern := tightPattern + tightPatternLayers := tightPatternLayers + perRowPatternLayers := perRowPatternLayers + scalePow10 := scalePow10 outputPath? := outputPath? } private def runCertifyAction (args : CertifyArgs) : ExceptT String IO Nfp.Sound.ModelCert := do @@ -925,12 +946,35 @@ private def runCertifyAction (args : CertifyArgs) : ExceptT String IO Nfp.Sound. "local certification requested via --delta, but the model file has no \ EMBEDDINGS section before the first LAYER (legacy text format). Pass --input \ containing EMBEDDINGS or omit --delta for global certification." - let cert ← ExceptT.mk <| - Nfp.Sound.certifyModelFile args.modelPath args.soundnessBits - (inputPath? := inputPath?) (inputDelta := delta) (partitionDepth := args.partitionDepth) - (softmaxMarginLowerBound := softmaxMarginLowerBound) - (softmaxExpEffort := args.softmaxExpEffort) - return cert + let inputPath? 
← + if args.bestMatchMargins && inputPath?.isNone then + let hasEmbeddings ← hasEmbeddingsBeforeLayers args.modelPath + if hasEmbeddings then + pure (some args.modelPath) + else + throw <| + "best-match margin tightening requires local input with EMBEDDINGS. \ +Pass --input or use a model file that embeds EMBEDDINGS." + else + pure inputPath? + if args.bestMatchMargins && softmaxMarginLowerBound != 0 then + throw "best-match margin tightening is incompatible with --softmaxMargin" + if args.bestMatchMargins then + let cert ← ExceptT.mk <| + Nfp.Sound.certifyModelFileBestMatchMargins args.modelPath args.soundnessBits + (inputPath? := inputPath?) (inputDelta := delta) (partitionDepth := args.partitionDepth) + (targetOffset := args.targetOffset) (maxSeqLen := args.maxSeqLen) + (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) + (perRowPatternLayers := args.perRowPatternLayers) (scalePow10 := args.scalePow10) + (softmaxExpEffort := args.softmaxExpEffort) + return cert + else + let cert ← ExceptT.mk <| + Nfp.Sound.certifyModelFile args.modelPath args.soundnessBits + (inputPath? := inputPath?) 
(inputDelta := delta) (partitionDepth := args.partitionDepth) + (softmaxMarginLowerBound := softmaxMarginLowerBound) + (softmaxExpEffort := args.softmaxExpEffort) + return cert private structure HeadBoundsArgs where modelPath : System.FilePath @@ -1610,6 +1654,13 @@ for legacy text)" if --input is omitted, uses EMBEDDINGS in the model file when present)" softmaxMargin : String; "Lower bound on softmax logit margin (default: 0)" softmaxExpEffort : Nat; "Effort level for margin-based exp lower bounds (default: 1)" + bestMatchMargins; "Apply best-match margin tightening (binary + local only)" + targetOffset : Int; "Token-match offset for best-match margins (default: -1)" + maxSeqLen : Nat; "Max sequence length for best-match margins (default: 0 uses full seq_len)" + tightPattern; "Use tighter (slower) pattern bounds for best-match margins" + tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" + perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" + scalePow10 : Nat; "Fixed-point scale exponent for best-match margins (default: 9)" soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" partitionDepth : Nat; "Partition depth for input splitting (default: 0; >0 scaffold only)" o, output : String; "Write report to file instead of stdout" diff --git a/Nfp/Sound/Cert.lean b/Nfp/Sound/Cert.lean index ac8d2e4..b04c76f 100644 --- a/Nfp/Sound/Cert.lean +++ b/Nfp/Sound/Cert.lean @@ -524,6 +524,30 @@ def verifyModelCert theorem verifyModelCert_spec : verifyModelCert = verifyModelCert := rfl +/-- Verify a model certificate and apply best-match margin tightening. 
-/ +def verifyModelCertBestMatchMargins + (cert : ModelCert) + (eps : Rat) + (soundnessBits : Nat) + (geluDerivTarget : GeluDerivTarget) + (attnValueCoeff wqOpBoundMax wkOpBoundMax : Array Rat) + (marginCerts : Array LayerBestMatchMarginCert) : Except String ModelCert := + Id.run do + match verifyModelCert cert eps soundnessBits geluDerivTarget + attnValueCoeff wqOpBoundMax wkOpBoundMax with + | .error e => return .error e + | .ok base => + match tightenModelCertBestMatchMargins base marginCerts with + | .error e => return .error e + | .ok tightened => + if tightened.check then + return .ok tightened + else + return .error "best-match margin tightening produced invalid cert" + +theorem verifyModelCertBestMatchMargins_spec : + verifyModelCertBestMatchMargins = verifyModelCertBestMatchMargins := rfl + /-! ### Specs -/ end Nfp.Sound diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 2c54376..2209aef 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -389,6 +389,53 @@ def certifyLayerBestMatchMarginLocal | .ok cert => return verifyLayerBestMatchMarginCert cert +/-- Soundly compute conservative bounds and tighten them using best-match margin evidence. -/ +def certifyModelFileBestMatchMargins + (path : System.FilePath) + (soundnessBits : Nat) + (inputPath? 
: Option System.FilePath := none) + (inputDelta : Rat := 0) + (partitionDepth : Nat := 0) + (targetOffset : Int := -1) + (maxSeqLen : Nat := 0) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (scalePow10 : Nat := 9) + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do + match ← readBinaryModelHeader path with + | .error e => return .error e + | .ok hdr => + if inputPath?.isNone then + return .error "best-match margin tightening requires local input" + let maxSeqLen' := if maxSeqLen = 0 then hdr.seqLen else maxSeqLen + match ← + Nfp.Untrusted.SoundCompute.certifyModelFile + path hdr.eps hdr.geluDerivTarget soundnessBits inputPath? inputDelta partitionDepth + (softmaxMarginLowerBound := 0) (softmaxExpEffort := softmaxExpEffort) with + | .error e => return .error e + | .ok cert => + match ← recomputeAttnWeightBounds path with + | .error e => + return .error s!"attnWeightBounds verification failed: {e}" + | .ok bounds => + let mut marginCerts : Array LayerBestMatchMarginCert := Array.mkEmpty hdr.numLayers + for layerIdx in [:hdr.numLayers] do + match ← + certifyLayerBestMatchMarginLocal path layerIdx + (inputPath? := inputPath?) (inputDelta := inputDelta) + (soundnessBits := soundnessBits) + (targetOffset := targetOffset) (maxSeqLen := maxSeqLen') + (tightPattern := tightPattern) + (tightPatternLayers := tightPatternLayers) + (perRowPatternLayers := perRowPatternLayers) + (scalePow10 := scalePow10) + (softmaxExpEffort := softmaxExpEffort) with + | .error e => return .error e + | .ok cert => marginCerts := marginCerts.push cert + return verifyModelCertBestMatchMargins cert hdr.eps soundnessBits hdr.geluDerivTarget + bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax marginCerts + /-- Compute local per-head attention contribution bounds tightened by best-match pattern evidence. 
-/ def certifyHeadBoundsLocalBestMatch From 1c511a1b80efc1b7d0f5d80fa1435153bbc0bc68 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 17:17:08 +0100 Subject: [PATCH 010/244] Update soundness limitations for best-match margins --- SOUNDNESS_LIMITATIONS.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 0e311a3..0a384f7 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -10,14 +10,13 @@ soundness upgrade. It is intentionally brief and human-readable. discharge those assumptions from model weights. - `partitionDepth > 0` is rejected with an explicit error (no partitioning logic yet). - Affine arithmetic is only a scaffold (`Nfp/Sound/Affine.lean`) and not wired into SOUND certification. -- Softmax Jacobian bounds are enforced to use the worst-case probability interval `[0,1]` in - trusted IO. Margin-derived tightening is computed by the untrusted path, but trusted IO - currently **rejects nonzero** `softmaxMarginLowerBound` because margin evidence is unverified. -- Local per-head contribution bounds can now be tightened using a best-match pattern certificate, - but this tightening does **not** propagate to layer-level ModelCert bounds. -- Layer-level best-match margin certificates can be computed (binary only) and applied via - `tightenModelCertBestMatchMargins`, but this is not yet wired into the CLI and may not tighten - unless the best-match sweep covers all heads and query positions. +- Softmax Jacobian bounds in the standard `certify` path still use the worst-case probability + interval `[0,1]`; direct `--softmaxMargin` is rejected because margin evidence is unverified. +- Best-match margin tightening is now available via `nfp certify --bestMatchMargins` (binary + local + inputs with EMBEDDINGS). It runs a full best-match sweep across heads and query positions, which + can be expensive and will fail if coverage is incomplete. 
+- Per-head best-match tightening (used by head-pattern/induction certs) is still separate from + model-level certification unless `--bestMatchMargins` is used. - Best-match pattern certificates now use a margin-derived softmax Jacobian bound with an effort-indexed `expLB` (scaled Taylor + squaring). The lower-bound correctness of `expLB` is not yet formalized in Lean. @@ -31,7 +30,8 @@ soundness upgrade. It is intentionally brief and human-readable. - Implement input-space partitioning in the SOUND local path and plumb it through the certify pipeline. - Replace or augment interval propagation with affine forms to preserve correlations. - Add sound probability interval extraction for softmax (requires sound exp/log-sum-exp bounds). -- Verify or compute margin evidence in the trusted path so margin-derived softmax tightening can be enabled. +- Verify or compute margin evidence in the trusted path so margin-derived softmax tightening can be + enabled without a best-match sweep and without rejecting `--softmaxMargin`. - Tighten GeLU derivative envelopes to the exact interval supremum if desired. - Discharge the bridge theorem’s component-norm assumptions from certificates/model weights, and connect the resulting statement to the `Linearization` Jacobian theorems. 
From 9a28cedf760c093ceafa19719a5534ba24d8ddae Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 26 Dec 2025 17:34:41 +0100 Subject: [PATCH 011/244] Add portfolio bounds for expLB and softmax --- Nfp/Sound/Bounds/Exp.lean | 43 ++++++++++++++++++++++----------- Nfp/Sound/Bounds/Portfolio.lean | 28 +++++++++++++++++++++ Nfp/Sound/Bounds/Softmax.lean | 37 ++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 14 deletions(-) diff --git a/Nfp/Sound/Bounds/Exp.lean b/Nfp/Sound/Bounds/Exp.lean index 45a6256..70f3eca 100644 --- a/Nfp/Sound/Bounds/Exp.lean +++ b/Nfp/Sound/Bounds/Exp.lean @@ -4,6 +4,7 @@ import Mathlib.Algebra.Order.Ring.Unbundled.Rat import Mathlib.Data.Finset.Lattice.Fold import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Data.Nat.Factorial.Basic +import Nfp.Sound.Bounds.Portfolio namespace Nfp.Sound @@ -61,27 +62,41 @@ def expLBPortfolio : Array (Nat × Nat) := theorem expLBPortfolio_def : expLBPortfolio = #[(2, 4), (3, 6), (4, 8)] := rfl -/-- Portfolio lower bound on `exp`, with a baseline `1 + x` candidate. -/ -def expLB (x : Rat) (effort : Nat) : Rat := - let base : Rat := max 0 ((1 : Rat) + x) - let limit := min effort expLBPortfolio.size +/-- Portfolio of `expLBScaledTaylor` candidates, truncated by effort. 
-/ +def expLBCandidates (x : Rat) (effort : Nat) : Array Rat := Id.run do - let mut best := base + let limit := min effort expLBPortfolio.size + let mut out : Array Rat := Array.mkEmpty limit for i in [:limit] do let cand := expLBScaledTaylor x (expLBPortfolio[i]!).2 (expLBPortfolio[i]!).1 - best := max best cand - return best + out := out.push cand + return out -theorem expLB_def (x : Rat) (effort : Nat) : - expLB x effort = - let base : Rat := max 0 ((1 : Rat) + x) - let limit := min effort expLBPortfolio.size +theorem expLBCandidates_def (x : Rat) (effort : Nat) : + expLBCandidates x effort = Id.run do - let mut best := base + let limit := min effort expLBPortfolio.size + let mut out : Array Rat := Array.mkEmpty limit for i in [:limit] do let cand := expLBScaledTaylor x (expLBPortfolio[i]!).2 (expLBPortfolio[i]!).1 - best := max best cand - return best := rfl + out := out.push cand + return out := rfl + +/-- Portfolio lower bound on `exp`, with a baseline `1 + x` candidate. -/ +def expLB (x : Rat) (effort : Nat) : Rat := + let base : Rat := max 0 ((1 : Rat) + x) + lbBest base (expLBCandidates x effort) + +theorem expLB_def (x : Rat) (effort : Nat) : + expLB x effort = + let base : Rat := max 0 ((1 : Rat) + x) + lbBest base (expLBCandidates x effort) := rfl + +/-- `expLB` never undercuts its baseline `1 + x` lower bound. -/ +theorem expLB_ge_base (x : Rat) (effort : Nat) : + max 0 ((1 : Rat) + x) ≤ expLB x effort := by + dsimp [expLB] + exact lbBest_ge_base (base := max 0 ((1 : Rat) + x)) (cands := expLBCandidates x effort) /-- Default effort used for margin-derived softmax bounds. 
-/ def defaultSoftmaxExpEffort : Nat := 1 diff --git a/Nfp/Sound/Bounds/Portfolio.lean b/Nfp/Sound/Bounds/Portfolio.lean index 0fe4279..fddef5e 100644 --- a/Nfp/Sound/Bounds/Portfolio.lean +++ b/Nfp/Sound/Bounds/Portfolio.lean @@ -17,6 +17,20 @@ def ubBest (base : Rat) (cands : Array Rat) : Rat := theorem ubBest_def (base : Rat) (cands : Array Rat) : ubBest base cands = cands.foldl min base := rfl +/-- `ubBest` never exceeds its baseline upper bound. -/ +theorem ubBest_le_base (base : Rat) (cands : Array Rat) : ubBest base cands ≤ base := by + classical + have hList : cands.toList.foldl min base ≤ base := by + induction cands.toList generalizing base with + | nil => simp + | cons x xs ih => + simp only [List.foldl] + have h := ih (base := min base x) + exact le_trans h (min_le_left _ _) + have hArray : cands.foldl min base ≤ base := by + simpa [Array.foldl_toList] using hList + simpa [ubBest] using hArray + /-- Best lower bound among candidates (never worse than `base`). -/ def lbBest (base : Rat) (cands : Array Rat) : Rat := cands.foldl max base @@ -24,4 +38,18 @@ def lbBest (base : Rat) (cands : Array Rat) : Rat := theorem lbBest_def (base : Rat) (cands : Array Rat) : lbBest base cands = cands.foldl max base := rfl +/-- `lbBest` never undercuts its baseline lower bound. 
-/ +theorem lbBest_ge_base (base : Rat) (cands : Array Rat) : base ≤ lbBest base cands := by + classical + have hList : base ≤ cands.toList.foldl max base := by + induction cands.toList generalizing base with + | nil => simp + | cons x xs ih => + simp only [List.foldl] + have h := ih (base := max base x) + exact le_trans (le_max_left _ _) h + have hArray : base ≤ cands.foldl max base := by + simpa [Array.foldl_toList] using hList + simpa [lbBest] using hArray + end Nfp.Sound diff --git a/Nfp/Sound/Bounds/Softmax.lean b/Nfp/Sound/Bounds/Softmax.lean index d4e8db9..d13196e 100644 --- a/Nfp/Sound/Bounds/Softmax.lean +++ b/Nfp/Sound/Bounds/Softmax.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.Order.Ring.Unbundled.Rat +import Mathlib.Tactic.Linarith import Nfp.Sound.Bounds.Exp import Nfp.Sound.Bounds.Portfolio @@ -29,6 +30,20 @@ private def clamp01 (x : Rat) : Rat := theorem clamp01_def (x : Rat) : clamp01 x = max 0 (min x 1) := rfl +private theorem clamp01_nonneg (x : Rat) : 0 ≤ clamp01 x := by + dsimp [clamp01] + exact le_max_left _ _ + +private theorem clamp01_le_one (x : Rat) : clamp01 x ≤ 1 := by + have h0 : (0 : Rat) ≤ 1 := by + decide + have hmin : min x 1 ≤ (1 : Rat) := by + exact min_le_right _ _ + have hmax : max 0 (min x 1) ≤ (1 : Rat) := by + exact max_le_iff.mpr ⟨h0, hmin⟩ + dsimp [clamp01] + exact hmax + /-- Local upper bound on the row-sum softmax Jacobian norm given `p ∈ [pLo, pHi]`. -/ def softmaxJacobianNormInfBound (pLo pHi : Rat) : Rat := let lo0 := min pLo pHi @@ -143,6 +158,19 @@ theorem softmaxJacobianNormInfBoundFromMaxProb_def (pLo : Rat) : else half := rfl +/-- Margin-derived Jacobian bounds never exceed the worst-case `1/2`. 
-/ +theorem softmaxJacobianNormInfBoundFromMaxProb_le_worst (pLo : Rat) : + softmaxJacobianNormInfBoundFromMaxProb pLo ≤ softmaxJacobianNormInfWorst := by + have hp0 : 0 ≤ clamp01 pLo := clamp01_nonneg pLo + have hp1 : clamp01 pLo ≤ 1 := clamp01_le_one pLo + by_cases h : (2 : Rat)⁻¹ < clamp01 pLo + · have hbound : + (2 : Rat) * clamp01 pLo * (1 - clamp01 pLo) ≤ (2 : Rat)⁻¹ := by + nlinarith [hp0, hp1] + simpa [softmaxJacobianNormInfBoundFromMaxProb, softmaxJacobianNormInfWorst_def, h] + using hbound + · simp [softmaxJacobianNormInfBoundFromMaxProb, softmaxJacobianNormInfWorst_def, h] + /-- Upper bound on the row-sum softmax Jacobian norm from a logit margin. -/ def softmaxJacobianNormInfBoundFromMargin (seqLen : Nat) (margin : Rat) (expEffort : Nat) : Rat := softmaxJacobianNormInfBoundFromMaxProb (softmaxMaxProbLowerBound seqLen margin expEffort) @@ -153,4 +181,13 @@ theorem softmaxJacobianNormInfBoundFromMargin_def (seqLen : Nat) (margin : Rat) softmaxJacobianNormInfBoundFromMaxProb (softmaxMaxProbLowerBound seqLen margin expEffort) := rfl +/-- Margin-derived Jacobian bound never exceeds the worst-case `1/2`. 
-/ +theorem softmaxJacobianNormInfBoundFromMargin_le_worst (seqLen : Nat) (margin : Rat) + (expEffort : Nat) : + softmaxJacobianNormInfBoundFromMargin seqLen margin expEffort ≤ + softmaxJacobianNormInfWorst := by + simpa [softmaxJacobianNormInfBoundFromMargin_def] using + softmaxJacobianNormInfBoundFromMaxProb_le_worst + (pLo := softmaxMaxProbLowerBound seqLen margin expEffort) + end Nfp.Sound From e6233df6857e4186974c3fd1231e085f5ff8922d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 27 Dec 2025 00:16:59 +0100 Subject: [PATCH 012/244] Verify model weight-derived bounds --- Nfp/Sound/Cert.lean | 49 ++++++++++++++++---- Nfp/Sound/IO.lean | 64 +++++++++++++++++--------- Nfp/Sound/TextPure.lean | 99 ++++++++++++++++++++++++++++++++++------ SOUNDNESS_LIMITATIONS.md | 6 ++- 4 files changed, 172 insertions(+), 46 deletions(-) diff --git a/Nfp/Sound/Cert.lean b/Nfp/Sound/Cert.lean index b04c76f..899a296 100644 --- a/Nfp/Sound/Cert.lean +++ b/Nfp/Sound/Cert.lean @@ -400,9 +400,10 @@ end ModelCert /-! ### Certificate verification helpers -/ -/-- Verify attention-weight bounds from per-layer arrays. -/ -def checkAttnWeightBoundsArrays (cert : ModelCert) - (attnValueCoeff wqOpBoundMax wkOpBoundMax : Array Rat) : Except String Unit := +/-- Verify weight-derived bounds from per-layer arrays. 
-/ +def checkWeightBoundsArrays (cert : ModelCert) + (attnValueCoeff wqOpBoundMax wkOpBoundMax mlpWinBound mlpWoutBound + ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma : Array Rat) : Except String Unit := Id.run do if attnValueCoeff.size ≠ cert.layers.size then return .error "attnValueCoeff layer count mismatch" @@ -410,10 +411,25 @@ def checkAttnWeightBoundsArrays (cert : ModelCert) return .error "wqOpBoundMax layer count mismatch" if wkOpBoundMax.size ≠ cert.layers.size then return .error "wkOpBoundMax layer count mismatch" + if mlpWinBound.size ≠ cert.layers.size then + return .error "mlpWinBound layer count mismatch" + if mlpWoutBound.size ≠ cert.layers.size then + return .error "mlpWoutBound layer count mismatch" + if ln1MaxAbsGamma.size ≠ cert.layers.size then + return .error "ln1MaxAbsGamma layer count mismatch" + if ln1MaxAbsBeta.size ≠ cert.layers.size then + return .error "ln1MaxAbsBeta layer count mismatch" + if ln2MaxAbsGamma.size ≠ cert.layers.size then + return .error "ln2MaxAbsGamma layer count mismatch" for idx in [:cert.layers.size] do let expValue := attnValueCoeff[idx]! let expWq := wqOpBoundMax[idx]! let expWk := wkOpBoundMax[idx]! + let expMlpWin := mlpWinBound[idx]! + let expMlpWout := mlpWoutBound[idx]! + let expLn1Gamma := ln1MaxAbsGamma[idx]! + let expLn1Beta := ln1MaxAbsBeta[idx]! + let expLn2Gamma := ln2MaxAbsGamma[idx]! let layer := cert.layers[idx]! 
if expValue ≠ layer.attnValueCoeff then return .error s!"attnValueCoeff mismatch at layer {idx}" @@ -421,10 +437,20 @@ def checkAttnWeightBoundsArrays (cert : ModelCert) return .error s!"wqOpBoundMax mismatch at layer {idx}" if expWk ≠ layer.wkOpBoundMax then return .error s!"wkOpBoundMax mismatch at layer {idx}" + if expMlpWin ≠ layer.mlpWinBound then + return .error s!"mlpWinBound mismatch at layer {idx}" + if expMlpWout ≠ layer.mlpWoutBound then + return .error s!"mlpWoutBound mismatch at layer {idx}" + if expLn1Gamma ≠ layer.ln1MaxAbsGamma then + return .error s!"ln1MaxAbsGamma mismatch at layer {idx}" + if expLn1Beta ≠ layer.ln1MaxAbsBeta then + return .error s!"ln1MaxAbsBeta mismatch at layer {idx}" + if expLn2Gamma ≠ layer.ln2MaxAbsGamma then + return .error s!"ln2MaxAbsGamma mismatch at layer {idx}" return .ok () -theorem checkAttnWeightBoundsArrays_spec : - checkAttnWeightBoundsArrays = checkAttnWeightBoundsArrays := rfl +theorem checkWeightBoundsArrays_spec : + checkWeightBoundsArrays = checkWeightBoundsArrays := rfl /-- Ensure all layers have zero softmax margin evidence. 
-/ def checkSoftmaxMarginZero (cert : ModelCert) : Except String Unit := @@ -501,7 +527,9 @@ def verifyModelCert (eps : Rat) (soundnessBits : Nat) (geluDerivTarget : GeluDerivTarget) - (attnValueCoeff wqOpBoundMax wkOpBoundMax : Array Rat) : Except String ModelCert := + (attnValueCoeff wqOpBoundMax wkOpBoundMax mlpWinBound mlpWoutBound + ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma : Array Rat) : + Except String ModelCert := Id.run do if cert.eps ≠ eps then return .error "model header eps mismatch" @@ -516,7 +544,8 @@ def verifyModelCert match checkSoftmaxMarginZero cert with | .error e => return .error e | .ok _ => - match checkAttnWeightBoundsArrays cert attnValueCoeff wqOpBoundMax wkOpBoundMax with + match checkWeightBoundsArrays cert attnValueCoeff wqOpBoundMax wkOpBoundMax + mlpWinBound mlpWoutBound ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma with | .error e => return .error e | .ok _ => return .ok cert return .error "sound certificate failed internal consistency checks" @@ -530,11 +559,13 @@ def verifyModelCertBestMatchMargins (eps : Rat) (soundnessBits : Nat) (geluDerivTarget : GeluDerivTarget) - (attnValueCoeff wqOpBoundMax wkOpBoundMax : Array Rat) + (attnValueCoeff wqOpBoundMax wkOpBoundMax mlpWinBound mlpWoutBound + ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma : Array Rat) (marginCerts : Array LayerBestMatchMarginCert) : Except String ModelCert := Id.run do match verifyModelCert cert eps soundnessBits geluDerivTarget - attnValueCoeff wqOpBoundMax wkOpBoundMax with + attnValueCoeff wqOpBoundMax wkOpBoundMax + mlpWinBound mlpWoutBound ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma with | .error e => return .error e | .ok base => match tightenModelCertBestMatchMargins base marginCerts with diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 2209aef..83cc1c6 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -51,8 +51,8 @@ private def readModelEps (path : System.FilePath) : IO (Except String Rat) := do | .error e => return .error e | .ok (eps, _) => 
return .ok eps -private def recomputeAttnWeightBoundsBinary - (path : System.FilePath) : IO (Except String AttnWeightBounds) := do +private def recomputeModelWeightBoundsBinary + (path : System.FilePath) : IO (Except String ModelWeightBounds) := do let h ← IO.FS.Handle.mk path IO.FS.Mode.read match ← Nfp.Untrusted.SoundBinary.readBinaryHeader h with | .error e => return .error e @@ -66,6 +66,11 @@ private def recomputeAttnWeightBoundsBinary | .ok _ => pure () let mut valuePairsLayers : Array (Array (Int × Int)) := Array.mkEmpty hdr.numLayers let mut qkPairsLayers : Array (Array (Int × Int)) := Array.mkEmpty hdr.numLayers + let mut mlpWinBound : Array Rat := Array.mkEmpty hdr.numLayers + let mut mlpWoutBound : Array Rat := Array.mkEmpty hdr.numLayers + let mut ln1MaxAbsGamma : Array Rat := Array.mkEmpty hdr.numLayers + let mut ln1MaxAbsBeta : Array Rat := Array.mkEmpty hdr.numLayers + let mut ln2MaxAbsGamma : Array Rat := Array.mkEmpty hdr.numLayers for _l in [:hdr.numLayers] do let mut valuePairs : Array (Int × Int) := Array.mkEmpty hdr.numHeads let mut qkPairs : Array (Int × Int) := Array.mkEmpty hdr.numHeads @@ -155,11 +160,16 @@ private def recomputeAttnWeightBoundsBinary match ln2BetaScaledE with | .error e => return .error e | .ok _ => pure () - let _ := nWinScaled - let _ := nWoutScaled - let _ := ln1GammaScaled - let _ := ln1BetaScaled - let _ := ln2GammaScaled + let nWin := ratOfScaledInt scalePow10 nWinScaled + let nWout := ratOfScaledInt scalePow10 nWoutScaled + let ln1Gamma := ratOfScaledInt scalePow10 ln1GammaScaled + let ln1Beta := ratOfScaledInt scalePow10 ln1BetaScaled + let ln2Gamma := ratOfScaledInt scalePow10 ln2GammaScaled + mlpWinBound := mlpWinBound.push nWin + mlpWoutBound := mlpWoutBound.push nWout + ln1MaxAbsGamma := ln1MaxAbsGamma.push ln1Gamma + ln1MaxAbsBeta := ln1MaxAbsBeta.push ln1Beta + ln2MaxAbsGamma := ln2MaxAbsGamma.push ln2Gamma valuePairsLayers := valuePairsLayers.push valuePairs qkPairsLayers := qkPairsLayers.push qkPairs match ← 
Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with @@ -178,22 +188,27 @@ private def recomputeAttnWeightBoundsBinary attnValueCoeff := coeffs wqOpBoundMax := wqMaxs wkOpBoundMax := wkMaxs + mlpWinBound := mlpWinBound + mlpWoutBound := mlpWoutBound + ln1MaxAbsGamma := ln1MaxAbsGamma + ln1MaxAbsBeta := ln1MaxAbsBeta + ln2MaxAbsGamma := ln2MaxAbsGamma } -private def recomputeAttnWeightBoundsText - (path : System.FilePath) : IO (Except String AttnWeightBounds) := do +private def recomputeModelWeightBoundsText + (path : System.FilePath) : IO (Except String ModelWeightBounds) := do let contents ← IO.FS.readFile path let lines : Array String := (contents.splitOn "\n").toArray - return attnWeightBoundsFromTextLines lines + return modelWeightBoundsFromTextLines lines -private def recomputeAttnWeightBounds - (path : System.FilePath) : IO (Except String AttnWeightBounds) := do +private def recomputeModelWeightBounds + (path : System.FilePath) : IO (Except String ModelWeightBounds) := do let firstLine ← IO.FS.withFile path IO.FS.Mode.read fun h => h.getLine if firstLine.trim = "NFP_BINARY_V1" then - recomputeAttnWeightBoundsBinary path + recomputeModelWeightBoundsBinary path else - recomputeAttnWeightBoundsText path + recomputeModelWeightBoundsText path /-- Compute weight-only per-head contribution bounds from a binary `.nfpt`. 
-/ def certifyHeadBoundsBinary @@ -223,12 +238,14 @@ def certifyModelFileGlobal softmaxMarginLowerBound softmaxExpEffort with | .error e => return .error e | .ok cert => - match ← recomputeAttnWeightBounds path with + match ← recomputeModelWeightBounds path with | .error e => - return .error s!"attnWeightBounds verification failed: {e}" + return .error s!"model weight bounds verification failed: {e}" | .ok bounds => return verifyModelCert cert eps soundnessBits geluTarget bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax + bounds.mlpWinBound bounds.mlpWoutBound + bounds.ln1MaxAbsGamma bounds.ln1MaxAbsBeta bounds.ln2MaxAbsGamma /-- Entry point for sound certification (global or local). -/ def certifyModelFile @@ -248,12 +265,14 @@ def certifyModelFile softmaxMarginLowerBound softmaxExpEffort with | .error e => return .error e | .ok cert => - match ← recomputeAttnWeightBounds path with + match ← recomputeModelWeightBounds path with | .error e => - return .error s!"attnWeightBounds verification failed: {e}" + return .error s!"model weight bounds verification failed: {e}" | .ok bounds => return verifyModelCert cert eps soundnessBits geluTarget bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax + bounds.mlpWinBound bounds.mlpWoutBound + bounds.ln1MaxAbsGamma bounds.ln1MaxAbsBeta bounds.ln2MaxAbsGamma /-- Compute per-head contribution bounds (global). 
-/ def certifyHeadBounds @@ -415,9 +434,9 @@ def certifyModelFileBestMatchMargins (softmaxMarginLowerBound := 0) (softmaxExpEffort := softmaxExpEffort) with | .error e => return .error e | .ok cert => - match ← recomputeAttnWeightBounds path with + match ← recomputeModelWeightBounds path with | .error e => - return .error s!"attnWeightBounds verification failed: {e}" + return .error s!"model weight bounds verification failed: {e}" | .ok bounds => let mut marginCerts : Array LayerBestMatchMarginCert := Array.mkEmpty hdr.numLayers for layerIdx in [:hdr.numLayers] do @@ -434,7 +453,10 @@ def certifyModelFileBestMatchMargins | .error e => return .error e | .ok cert => marginCerts := marginCerts.push cert return verifyModelCertBestMatchMargins cert hdr.eps soundnessBits hdr.geluDerivTarget - bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax marginCerts + bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax + bounds.mlpWinBound bounds.mlpWoutBound + bounds.ln1MaxAbsGamma bounds.ln1MaxAbsBeta bounds.ln2MaxAbsGamma + marginCerts /-- Compute local per-head attention contribution bounds tightened by best-match pattern evidence. -/ diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean index 715b4f2..be8a929 100644 --- a/Nfp/Sound/TextPure.lean +++ b/Nfp/Sound/TextPure.lean @@ -24,17 +24,24 @@ structure TextModelDims where start : Nat deriving Repr -/-- Per-layer attention weight bounds extracted from a text model. -/ -structure AttnWeightBounds where +/-- Per-layer weight-derived bounds extracted from a text model. -/ +structure ModelWeightBounds where attnValueCoeff : Array Rat wqOpBoundMax : Array Rat wkOpBoundMax : Array Rat + mlpWinBound : Array Rat + mlpWoutBound : Array Rat + ln1MaxAbsGamma : Array Rat + ln1MaxAbsBeta : Array Rat + ln2MaxAbsGamma : Array Rat deriving Repr -/-- Verify that attention-weight bounds match the certificate layer fields. 
-/ -def checkAttnWeightBounds (cert : ModelCert) (expected : AttnWeightBounds) : Except String Unit := - checkAttnWeightBoundsArrays cert expected.attnValueCoeff expected.wqOpBoundMax - expected.wkOpBoundMax +/-- Verify that weight-derived bounds match the certificate layer fields. -/ +def checkModelWeightBounds (cert : ModelCert) (expected : ModelWeightBounds) : + Except String Unit := + checkWeightBoundsArrays cert expected.attnValueCoeff expected.wqOpBoundMax + expected.wkOpBoundMax expected.mlpWinBound expected.mlpWoutBound + expected.ln1MaxAbsGamma expected.ln1MaxAbsBeta expected.ln2MaxAbsGamma def parseTextHeaderDims (lines : Array String) : Except String TextModelDims := Id.run do @@ -126,6 +133,14 @@ def consumeVector let step := fun (acc : Array Rat) (x : Rat) => acc.push x foldRatTokens lines start n (Array.mkEmpty n) step +/-- Consume a vector of length `n` and return its max absolute entry. -/ +def consumeVectorMaxAbs + (lines : Array String) + (start : Nat) + (n : Nat) : Except String (Rat × Nat) := + let step := fun (acc : Rat) (x : Rat) => max acc (ratAbs x) + foldRatTokens lines start n 0 step + /-- Consume a matrix and return its row-sum norm. -/ def consumeMatrixNormInf (lines : Array String) @@ -136,8 +151,8 @@ def consumeMatrixNormInf | .error e => .error e | .ok (xs, next) => .ok (matrixNormInfOfRowMajor rows cols xs, next) -/-- Compute per-layer attention value and `W_Q/W_K` bounds from text model lines. -/ -def attnWeightBoundsFromTextLines (lines : Array String) : Except String AttnWeightBounds := +/-- Compute per-layer weight bounds from text model lines. 
-/ +def modelWeightBoundsFromTextLines (lines : Array String) : Except String ModelWeightBounds := Id.run do let infoE := parseTextHeaderDims lines let info ← @@ -149,6 +164,11 @@ def attnWeightBoundsFromTextLines (lines : Array String) : Except String AttnWei let mut attnValueCoeff : Array Rat := Array.replicate info.numLayers 0 let mut wqMax : Array Rat := Array.replicate info.numLayers 0 let mut wkMax : Array Rat := Array.replicate info.numLayers 0 + let mut mlpWinBound : Array Rat := Array.replicate info.numLayers 0 + let mut mlpWoutBound : Array Rat := Array.replicate info.numLayers 0 + let mut ln1MaxAbsGamma : Array Rat := Array.replicate info.numLayers 0 + let mut ln1MaxAbsBeta : Array Rat := Array.replicate info.numLayers 0 + let mut ln2MaxAbsGamma : Array Rat := Array.replicate info.numLayers 0 while i < lines.size do let line := lines[i]!.trim if line.startsWith "LAYER" then @@ -189,31 +209,82 @@ def attnWeightBoundsFromTextLines (lines : Array String) : Except String AttnWei attnValueCoeff := attnValueCoeff.set! r (attnValueCoeff[r]! + (nv * no)) i := next2 + else if line = "W_in" then + let r := curLayer + match consumeMatrixNormInf lines (i + 1) info.modelDim info.hiddenDim with + | .error e => return .error e + | .ok (nwin, next) => + if r < mlpWinBound.size then + mlpWinBound := mlpWinBound.set! r nwin + i := next + else if line = "W_out" then + let r := curLayer + match consumeMatrixNormInf lines (i + 1) info.hiddenDim info.modelDim with + | .error e => return .error e + | .ok (nwout, next) => + if r < mlpWoutBound.size then + mlpWoutBound := mlpWoutBound.set! r nwout + i := next + else if line = "LN1_GAMMA" then + let r := curLayer + match consumeVectorMaxAbs lines (i + 1) info.modelDim with + | .error e => return .error e + | .ok (g, next) => + if r < ln1MaxAbsGamma.size then + ln1MaxAbsGamma := ln1MaxAbsGamma.set! 
r g + i := next + else if line = "LN1_BETA" then + let r := curLayer + match consumeVectorMaxAbs lines (i + 1) info.modelDim with + | .error e => return .error e + | .ok (b, next) => + if r < ln1MaxAbsBeta.size then + ln1MaxAbsBeta := ln1MaxAbsBeta.set! r b + i := next + else if line = "LN2_GAMMA" then + let r := curLayer + match consumeVectorMaxAbs lines (i + 1) info.modelDim with + | .error e => return .error e + | .ok (g, next) => + if r < ln2MaxAbsGamma.size then + ln2MaxAbsGamma := ln2MaxAbsGamma.set! r g + i := next + else if line = "LN2_BETA" then + match consumeVectorMaxAbs lines (i + 1) info.modelDim with + | .error e => return .error e + | .ok (_, next) => + i := next else i := i + 1 return .ok { attnValueCoeff := attnValueCoeff wqOpBoundMax := wqMax wkOpBoundMax := wkMax + mlpWinBound := mlpWinBound + mlpWoutBound := mlpWoutBound + ln1MaxAbsGamma := ln1MaxAbsGamma + ln1MaxAbsBeta := ln1MaxAbsBeta + ln2MaxAbsGamma := ln2MaxAbsGamma } /-- Compute per-layer `attnValueCoeff` from text model lines. -/ def attnValueCoeffFromTextLines (lines : Array String) : Except String (Array Rat) := do - let bounds ← attnWeightBoundsFromTextLines lines + let bounds ← modelWeightBoundsFromTextLines lines return bounds.attnValueCoeff /-! 
### Specs -/ theorem parseTextHeaderDims_spec : parseTextHeaderDims = parseTextHeaderDims := rfl -theorem AttnWeightBounds_spec : AttnWeightBounds = AttnWeightBounds := rfl -theorem checkAttnWeightBounds_spec : - checkAttnWeightBounds = checkAttnWeightBounds := rfl +theorem ModelWeightBounds_spec : ModelWeightBounds = ModelWeightBounds := rfl +theorem checkModelWeightBounds_spec : + checkModelWeightBounds = checkModelWeightBounds := rfl theorem foldRatTokens_spec (α : Type) : @foldRatTokens α = @foldRatTokens α := rfl theorem consumeVector_spec : consumeVector = consumeVector := rfl +theorem consumeVectorMaxAbs_spec : consumeVectorMaxAbs = consumeVectorMaxAbs := rfl theorem consumeMatrixNormInf_spec : consumeMatrixNormInf = consumeMatrixNormInf := rfl -theorem attnWeightBoundsFromTextLines_spec : - attnWeightBoundsFromTextLines = attnWeightBoundsFromTextLines := rfl +theorem modelWeightBoundsFromTextLines_spec : + modelWeightBoundsFromTextLines = modelWeightBoundsFromTextLines := rfl theorem attnValueCoeffFromTextLines_spec : attnValueCoeffFromTextLines = attnValueCoeffFromTextLines := rfl diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 0a384f7..0bc839a 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -6,8 +6,10 @@ soundness upgrade. It is intentionally brief and human-readable. ### Current limitations - The bridge theorem in `Nfp/Sound/Bridge.lean` links `LayerAmplificationCert` bounds to `DeepLinearization` residual Jacobians, but it requires external operator-norm assumptions - (LN Jacobians, attention full Jacobian, and MLP factors). The trusted checker does not yet - discharge those assumptions from model weights. + (LN Jacobians, attention full Jacobian, and MLP factors). 
The trusted checker now recomputes + weight-derived bounds (W_Q/W_K/W_V/W_O, MLP W_in/W_out, LN1 gamma/beta, LN2 gamma) from model files, + but it still treats softmax probability or margin evidence as external and does not derive those + bounds from logits. - `partitionDepth > 0` is rejected with an explicit error (no partitioning logic yet). - Affine arithmetic is only a scaffold (`Nfp/Sound/Affine.lean`) and not wired into SOUND certification. - Softmax Jacobian bounds in the standard `certify` path still use the worst-case probability From 2487ab3be907c1a98d2cdadfdc379209e6ac9719 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 27 Dec 2025 01:13:21 +0100 Subject: [PATCH 013/244] Derive softmax prob bounds from score magnitude --- Nfp/Sound/Bounds/Attention.lean | 16 +++++++ Nfp/Sound/Bounds/Exp.lean | 62 +++++++++++++++++++++++++++ Nfp/Sound/Bounds/Softmax.lean | 38 +++++++++++++++++ Nfp/Sound/Bridge.lean | 3 +- Nfp/Sound/Cert.lean | 74 ++++++++++++++++++++++----------- Nfp/Untrusted/SoundCompute.lean | 31 +++++++++----- SOUNDNESS_LIMITATIONS.md | 6 ++- 7 files changed, 193 insertions(+), 37 deletions(-) diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index 881ad3c..72d94e2 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -64,4 +64,20 @@ theorem attnPatternCoeffBound_def (seqLen modelDim headDim : Nat) attnScoreGradBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound * (inputL1 * valueCoeff) := rfl +/-- Conservative bound on `|q·k|/sqrt(d_head)` using max-abs LN1 output and W_Q/W_K norms. 
-/ +def attnScoreAbsBound (modelDim headDim : Nat) + (ln1OutMaxAbs wqBound wkBound : Rat) : Rat := + let dRat : Rat := (modelDim : Nat) + let qMax := dRat * ln1OutMaxAbs * wqBound + let kMax := dRat * ln1OutMaxAbs * wkBound + sqrtUpperRat headDim * qMax * kMax + +theorem attnScoreAbsBound_def (modelDim headDim : Nat) + (ln1OutMaxAbs wqBound wkBound : Rat) : + attnScoreAbsBound modelDim headDim ln1OutMaxAbs wqBound wkBound = + let dRat : Rat := (modelDim : Nat) + let qMax := dRat * ln1OutMaxAbs * wqBound + let kMax := dRat * ln1OutMaxAbs * wkBound + sqrtUpperRat headDim * qMax * kMax := rfl + end Nfp.Sound diff --git a/Nfp/Sound/Bounds/Exp.lean b/Nfp/Sound/Bounds/Exp.lean index 70f3eca..202945e 100644 --- a/Nfp/Sound/Bounds/Exp.lean +++ b/Nfp/Sound/Bounds/Exp.lean @@ -98,6 +98,68 @@ theorem expLB_ge_base (x : Rat) (effort : Nat) : dsimp [expLB] exact lbBest_ge_base (base := max 0 ((1 : Rat) + x)) (cands := expLBCandidates x effort) +/-- Scaling exponent so `x / 2^s ≤ 1/2` for `x ≥ 0`. -/ +private def expUBScalePow (x : Rat) : Nat := + let half : Rat := (1 : Rat) / 2 + if x ≤ half then + 0 + else + Id.run do + let mut s : Nat := 0 + let mut y : Rat := x + while y > half do + s := s + 1 + y := y / (2 : Rat) + return s + +theorem expUBScalePow_def (x : Rat) : + expUBScalePow x = + let half : Rat := (1 : Rat) / 2 + if x ≤ half then + 0 + else + Id.run do + let mut s : Nat := 0 + let mut y : Rat := x + while y > half do + s := s + 1 + y := y / (2 : Rat) + return s := rfl + +/-! +### Exp upper bounds (geometric series + squaring) +-/ + +/-- Upper bound on `exp(x)` for `x ≥ 0` using `exp(z) ≤ 1/(1-z)` with scaling. 
-/ +def expUBScaledGeom (x : Rat) : Rat := + if x ≤ 0 then + 1 + else + let scalePow := expUBScalePow x + let scale : Rat := (Nat.pow 2 scalePow : Nat) + let z := x / scale + let denom := (1 : Rat) - z + if denom ≤ 0 then + 0 + else + let base := (1 : Rat) / denom + ratPow base (Nat.pow 2 scalePow) + +theorem expUBScaledGeom_def (x : Rat) : + expUBScaledGeom x = + if x ≤ 0 then + 1 + else + let scalePow := expUBScalePow x + let scale : Rat := (Nat.pow 2 scalePow : Nat) + let z := x / scale + let denom := (1 : Rat) - z + if denom ≤ 0 then + 0 + else + let base := (1 : Rat) / denom + ratPow base (Nat.pow 2 scalePow) := rfl + /-- Default effort used for margin-derived softmax bounds. -/ def defaultSoftmaxExpEffort : Nat := 1 diff --git a/Nfp/Sound/Bounds/Softmax.lean b/Nfp/Sound/Bounds/Softmax.lean index d13196e..214ae86 100644 --- a/Nfp/Sound/Bounds/Softmax.lean +++ b/Nfp/Sound/Bounds/Softmax.lean @@ -78,6 +78,44 @@ theorem softmaxJacobianNormInfBound_def (pLo pHi : Rat) : /-! ### Margin-derived softmax bounds -/ +/-- Probability interval from a uniform score bound `|s| ≤ B`. 
-/ +def softmaxProbIntervalFromScoreAbsBound (seqLen : Nat) (scoreAbsBound : Rat) + (expEffort : Nat) : Rat × Rat := + if seqLen = 0 then + (0, 1) + else if seqLen = 1 then + (1, 1) + else + let b := max 0 scoreAbsBound + let nRat : Rat := (seqLen : Nat) + let ePosUb := expUBScaledGeom b + let eNegLb := expLB (-b) expEffort + if eNegLb = 0 then + (0, 1) + else + let denomLo := eNegLb + (nRat - 1) * ePosUb + let denomHi := ePosUb + (nRat - 1) * eNegLb + (eNegLb / denomLo, ePosUb / denomHi) + +theorem softmaxProbIntervalFromScoreAbsBound_def (seqLen : Nat) (scoreAbsBound : Rat) + (expEffort : Nat) : + softmaxProbIntervalFromScoreAbsBound seqLen scoreAbsBound expEffort = + if seqLen = 0 then + (0, 1) + else if seqLen = 1 then + (1, 1) + else + let b := max 0 scoreAbsBound + let nRat : Rat := (seqLen : Nat) + let ePosUb := expUBScaledGeom b + let eNegLb := expLB (-b) expEffort + if eNegLb = 0 then + (0, 1) + else + let denomLo := eNegLb + (nRat - 1) * ePosUb + let denomHi := ePosUb + (nRat - 1) * eNegLb + (eNegLb / denomLo, ePosUb / denomHi) := rfl + /-- Lower bound on the maximum softmax probability from a logit margin. 
Uses a portfolio `expLB` to lower bound `exp(m)` and maps it to diff --git a/Nfp/Sound/Bridge.lean b/Nfp/Sound/Bridge.lean index 461cc0f..b1a8731 100644 --- a/Nfp/Sound/Bridge.lean +++ b/Nfp/Sound/Bridge.lean @@ -270,7 +270,8 @@ theorem attn_pattern_bound_of_cert (l.attnPatternCoeff : ℝ) = (attnPatternCoeffBound seqLenNat modelDimNat headDimNat l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax l.attnValueCoeff : ℝ) := by - rcases hValid with ⟨_hln1, _hln2, _hln1Out, _hsoftmax, hpat, _hattn, _hmlpCoeff, _hmlp, _hC⟩ + rcases hValid with ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, + _hsoftmax, hpat, _hattn, _hmlpCoeff, _hmlp, _hC⟩ exact congrArg (fun x : Rat => (x : ℝ)) hpat have hCoeff_eq : (seqLenNat : ℝ) * S_bound * V = diff --git a/Nfp/Sound/Cert.lean b/Nfp/Sound/Cert.lean index 899a296..8761714 100644 --- a/Nfp/Sound/Cert.lean +++ b/Nfp/Sound/Cert.lean @@ -109,19 +109,25 @@ theorem softmaxJacobianNormInfPortfolioBound_def (seqLen : Nat) (l : LayerAmplif l.softmaxExpEffort] := rfl /-- Update margin evidence and recompute dependent softmax + residual bounds. 
-/ -def withSoftmaxMargin (seqLen : Nat) (marginLowerBound : Rat) (softmaxExpEffort : Nat) - (l : LayerAmplificationCert) : LayerAmplificationCert := +def withSoftmaxMargin (seqLen modelDim headDim : Nat) + (marginLowerBound : Rat) (softmaxExpEffort : Nat) (l : LayerAmplificationCert) : + LayerAmplificationCert := let l' := { l with softmaxMarginLowerBound := marginLowerBound softmaxExpEffort := softmaxExpEffort } - let softmaxBound := softmaxJacobianNormInfPortfolioBound seqLen l' + let scoreAbsBound := + attnScoreAbsBound modelDim headDim l'.ln1OutMaxAbsBound l'.wqOpBoundMax l'.wkOpBoundMax + let (softmaxProbLo, softmaxProbHi) := + softmaxProbIntervalFromScoreAbsBound seqLen scoreAbsBound l'.softmaxExpEffort + let l'' := { l' with softmaxProbLo := softmaxProbLo, softmaxProbHi := softmaxProbHi } + let softmaxBound := softmaxJacobianNormInfPortfolioBound seqLen l'' let attnJacBound := - l'.ln1Bound * - ((seqLen : Rat) * l'.attnValueCoeff + softmaxBound * l'.attnPatternCoeff) - let mlpJacBound := l'.mlpJacBound + l''.ln1Bound * + ((seqLen : Rat) * l''.attnValueCoeff + softmaxBound * l''.attnPatternCoeff) + let mlpJacBound := l''.mlpJacBound let C := attnJacBound + mlpJacBound + attnJacBound * mlpJacBound - { l' with + { l'' with softmaxJacobianNormInfUpperBound := softmaxBound attnJacBound := attnJacBound C := C } @@ -129,6 +135,10 @@ def withSoftmaxMargin (seqLen : Nat) (marginLowerBound : Rat) (softmaxExpEffort /-- Internal consistency checks for per-layer bounds. -/ def Valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) : Prop := + let scoreAbsBound := + attnScoreAbsBound modelDim headDim l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax + let probInterval := + softmaxProbIntervalFromScoreAbsBound seqLen scoreAbsBound l.softmaxExpEffort l.ln1Bound = (match l.ln1VarianceLowerBound? 
with | some v => layerNormOpBoundLocal l.ln1MaxAbsGamma v eps sqrtPrecBits @@ -139,6 +149,8 @@ def Valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) | none => layerNormOpBoundConservative l.ln2MaxAbsGamma eps sqrtPrecBits) ∧ l.ln1OutMaxAbsBound = layerNormOutputMaxAbsBound modelDim l.ln1MaxAbsGamma l.ln1MaxAbsBeta ∧ + l.softmaxProbLo = probInterval.1 ∧ + l.softmaxProbHi = probInterval.2 ∧ l.softmaxJacobianNormInfUpperBound = softmaxJacobianNormInfPortfolioBound seqLen l ∧ l.attnPatternCoeff = @@ -178,7 +190,8 @@ theorem c_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim l.C = l.attnJacBound + l.mlpJacBound + l.attnJacBound * l.mlpJacBound := by rcases h with - ⟨_hln1, _hln2, _hln1Out, _hsoftmax, _hpat, _hattn, _hmlpCoeff, _hmlp, hC⟩ + ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, + _hsoftmax, _hpat, _hattn, _hmlpCoeff, _hmlp, hC⟩ exact hC /-- Extract the attention contribution identity from `Valid`. -/ @@ -190,7 +203,8 @@ theorem attnJacBound_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) ((seqLen : Rat) * l.attnValueCoeff + l.softmaxJacobianNormInfUpperBound * l.attnPatternCoeff) := by rcases h with - ⟨_hln1, _hln2, _hln1Out, _hsoftmax, _hpat, hattn, _hmlpCoeff, _hmlp, _hC⟩ + ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, + _hsoftmax, _hpat, hattn, _hmlpCoeff, _hmlp, _hC⟩ exact hattn /-- Extract the MLP coefficient identity from `Valid`. -/ @@ -198,7 +212,8 @@ theorem mlpCoeff_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim h (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : l.mlpCoeff = l.mlpWinBound * l.mlpWoutBound := by rcases h with - ⟨_hln1, _hln2, _hln1Out, _hsoftmax, _hpat, _hattn, hCoeff, _hmlp, _hC⟩ + ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, + _hsoftmax, _hpat, _hattn, hCoeff, _hmlp, _hC⟩ exact hCoeff /-- Extract the MLP contribution identity from `Valid`. 
-/ @@ -207,7 +222,8 @@ theorem mlpJacBound_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : l.mlpJacBound = l.ln2Bound * (l.mlpCoeff * l.mlpActDerivBound) := by rcases h with - ⟨_hln1, _hln2, _hln1Out, _hsoftmax, _hpat, _hattn, _hmlpCoeff, hmlp, _hC⟩ + ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, + _hsoftmax, _hpat, _hattn, _hmlpCoeff, hmlp, _hC⟩ exact hmlp /-- Cast the `C` identity to `ℝ` using `Valid`. -/ @@ -464,23 +480,32 @@ def checkSoftmaxMarginZero (cert : ModelCert) : Except String Unit := theorem checkSoftmaxMarginZero_spec : checkSoftmaxMarginZero = checkSoftmaxMarginZero := rfl -/-- Ensure the softmax probability interval is the worst-case `[0,1]`. -/ -def checkSoftmaxProbIntervalWorst (cert : ModelCert) : Except String Unit := +/-! ### Softmax probability interval checks -/ + +/-- Ensure the softmax probability interval matches the derived score bound. -/ +def checkSoftmaxProbIntervalDerived (cert : ModelCert) : Except String Unit := Id.run do for idx in [:cert.layers.size] do let layer := cert.layers[idx]! 
- if layer.softmaxProbLo ≠ 0 then - return .error s!"softmaxProbLo is unverified (layer {idx})" - if layer.softmaxProbHi ≠ 1 then - return .error s!"softmaxProbHi is unverified (layer {idx})" + let scoreAbsBound := + attnScoreAbsBound cert.modelDim cert.headDim layer.ln1OutMaxAbsBound + layer.wqOpBoundMax layer.wkOpBoundMax + let (probLo, probHi) := + softmaxProbIntervalFromScoreAbsBound cert.seqLen scoreAbsBound + layer.softmaxExpEffort + if layer.softmaxProbLo ≠ probLo then + return .error s!"softmaxProbLo mismatch at layer {idx}" + if layer.softmaxProbHi ≠ probHi then + return .error s!"softmaxProbHi mismatch at layer {idx}" return .ok () -theorem checkSoftmaxProbIntervalWorst_spec : - checkSoftmaxProbIntervalWorst = checkSoftmaxProbIntervalWorst := rfl +theorem checkSoftmaxProbIntervalDerived_spec : + checkSoftmaxProbIntervalDerived = checkSoftmaxProbIntervalDerived := rfl /-- Update a layer certificate with best-match softmax evidence if it is valid and tighter. -/ def tightenLayerSoftmaxFromBestMatch - (seqLen : Nat) (layer : LayerAmplificationCert) (cert : LayerBestMatchMarginCert) : + (seqLen modelDim headDim : Nat) (layer : LayerAmplificationCert) + (cert : LayerBestMatchMarginCert) : Except String LayerAmplificationCert := Id.run do if !cert.check then @@ -490,8 +515,8 @@ def tightenLayerSoftmaxFromBestMatch if cert.seqLen ≠ seqLen then return .error "layer margin cert seq_len mismatch" let updated := - LayerAmplificationCert.withSoftmaxMargin seqLen cert.marginLowerBound - cert.softmaxExpEffort layer + LayerAmplificationCert.withSoftmaxMargin seqLen modelDim headDim + cert.marginLowerBound cert.softmaxExpEffort layer if updated.softmaxJacobianNormInfUpperBound > layer.softmaxJacobianNormInfUpperBound then return .error "best-match softmax bound is worse than baseline" return .ok updated @@ -509,7 +534,8 @@ def tightenModelCertBestMatchMargins | .ok cur => if cert.layerIdx < cur.layers.size then let layer := cur.layers[cert.layerIdx]! 
- match tightenLayerSoftmaxFromBestMatch cur.seqLen layer cert with + match tightenLayerSoftmaxFromBestMatch cur.seqLen cur.modelDim cur.headDim + layer cert with | .error e => .error e | .ok updatedLayer => match ModelCert.withUpdatedLayer cur cert.layerIdx updatedLayer with @@ -538,7 +564,7 @@ def verifyModelCert if cert.geluDerivTarget ≠ geluDerivTarget then return .error "model header gelu_kind mismatch" if cert.check then - match checkSoftmaxProbIntervalWorst cert with + match checkSoftmaxProbIntervalDerived cert with | .error e => return .error e | .ok _ => match checkSoftmaxMarginZero cert with diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index f7f8734..7f7017d 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -346,8 +346,10 @@ private def certifyModelFileGlobalBinary wqMax wkMax attnValueCoeff let mlpCoeff := nWin * nWout let mlpActDerivBound := actDerivBound - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 + let scoreAbsBound := + attnScoreAbsBound hdr.modelDim hdr.headDim ln1OutMaxAbsBound wqMax wkMax + let (softmaxProbLo, softmaxProbHi) := + softmaxProbIntervalFromScoreAbsBound hdr.seqLen scoreAbsBound softmaxExpEffort let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi let softmaxMarginBound := softmaxJacobianNormInfBoundFromMargin hdr.seqLen softmaxMarginLowerBound softmaxExpEffort @@ -813,8 +815,10 @@ def certifyModelFileGlobal attnValueCoeffLayer let mlpCoeff := mlpWin[l]! * mlpWout[l]! let mlpActDerivBound := actDerivBound - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 + let scoreAbsBound := + attnScoreAbsBound d dh ln1OutMaxAbsBound (wqMax[l]!) (wkMax[l]!) 
+ let (softmaxProbLo, softmaxProbHi) := + softmaxProbIntervalFromScoreAbsBound n scoreAbsBound softmaxExpEffort let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi let softmaxMarginBound := softmaxJacobianNormInfBoundFromMargin n softmaxMarginLowerBound softmaxExpEffort @@ -2320,8 +2324,10 @@ private def certifyModelFileLocalText pos := nextBout let mlpOut := addConstVec mlpOut0 bout residualUnion := addVecIntervals residualUnion mlpOut - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 + let scoreAbsBound := + attnScoreAbsBound d dh ln1OutMaxAbsBound wqMax wkMax + let (softmaxProbLo, softmaxProbHi) := + softmaxProbIntervalFromScoreAbsBound n scoreAbsBound softmaxExpEffort let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi let softmaxMarginBound := @@ -2503,8 +2509,11 @@ private def certifyModelFileLocal rr := rrBout let mlpOut := addVecFixed mlpOut0 bOut residualUnion := addVecFixed residualUnion mlpOut - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 + let scoreAbsBound := + attnScoreAbsBound modelDim headDim ln1OutMaxAbsBound (wqMaxArr[l]!) + (wkMaxArr[l]!) 
+ let (softmaxProbLo, softmaxProbHi) := + softmaxProbIntervalFromScoreAbsBound inputSeqLen scoreAbsBound softmaxExpEffort let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi let softmaxMarginBound := softmaxJacobianNormInfBoundFromMargin inputSeqLen softmaxMarginLowerBound @@ -2671,8 +2680,10 @@ private def certifyModelFileLocalBinary let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpOut := addVecFixed mlpOut0 bOut residualUnion := addVecFixed residualUnion mlpOut - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 + let scoreAbsBound := + attnScoreAbsBound hdr.modelDim hdr.headDim ln1OutMaxAbsBound wqMax wkMax + let (softmaxProbLo, softmaxProbHi) := + softmaxProbIntervalFromScoreAbsBound hdr.seqLen scoreAbsBound softmaxExpEffort let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi let softmaxMarginBound := softmaxJacobianNormInfBoundFromMargin hdr.seqLen softmaxMarginLowerBound softmaxExpEffort diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 0bc839a..1aebd23 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -12,8 +12,10 @@ soundness upgrade. It is intentionally brief and human-readable. bounds from logits. - `partitionDepth > 0` is rejected with an explicit error (no partitioning logic yet). - Affine arithmetic is only a scaffold (`Nfp/Sound/Affine.lean`) and not wired into SOUND certification. -- Softmax Jacobian bounds in the standard `certify` path still use the worst-case probability - interval `[0,1]`; direct `--softmaxMargin` is rejected because margin evidence is unverified. +- Softmax Jacobian bounds in the standard `certify` path now derive a probability interval from a + global attention-score magnitude bound (LN1 max-abs + W_Q/W_K norms), but it is typically very + loose and often collapses to `[0,1]`. Direct `--softmaxMargin` is still rejected because margin + evidence is unverified. 
- Best-match margin tightening is now available via `nfp certify --bestMatchMargins` (binary + local inputs with EMBEDDINGS). It runs a full best-match sweep across heads and query positions, which can be expensive and will fail if coverage is incomplete. From 91531f43d3573761fcb1a0ee155b8586c7cf1c76 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 30 Dec 2025 03:49:01 +0100 Subject: [PATCH 014/244] Optimize induction best-match pass --- Nfp/Untrusted/SoundCompute.lean | 2762 +++++++++++++++++++++++++------ 1 file changed, 2241 insertions(+), 521 deletions(-) diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 7f7017d..5cac86e 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -1243,6 +1243,9 @@ private def addRowsFixed out := out.push (addVecFixed rows[i]! adds[i]!) return out +private def takePrefix {α : Type} (xs : Array α) (n : Nat) : Array α := + if xs.size ≤ n then xs else xs.extract 0 n + private def mlpRowFromScaled (cfg : Fixed10Cfg) (slack : Int) @@ -1306,6 +1309,20 @@ private def unionRowsFixed out := out.set! j { lo := min cur.lo r.lo, hi := max cur.hi r.hi } return out +private def prefixUnionRowsFixed + (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := + if rows.isEmpty then + #[] + else + Id.run do + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + let mut acc := rows[0]! + out := out.push acc + for i in [1:rows.size] do + acc := Fixed10Interval.unionVec acc rows[i]! + out := out.push acc + return out + private def consumeMatrixMulAndNormInfFixed (cfg : Fixed10Cfg) (slack : Int) @@ -1534,6 +1551,65 @@ private def loadTokensBinary | .error e => return .error e | .ok toks => return .ok (hdr, toks) +/-- Shared binary inputs for repeated local bound checks. 
-/ +private structure SharedBinaryInputs where + hdr : BinaryHeader + ln1Params : Array LayerNormParamsFixed + ln2Params : Array LayerNormParamsFixed + tokens : Array Int + residuals0 : Array (Array Fixed10Interval) + inputDelta : Rat + scalePow10 : Nat + +/-- Cached prefix views for a fixed query position. -/ +private structure SharedBinaryPrefix where + seqLenEff : Nat + residuals : Thunk (Array (Array Fixed10Interval)) + tokens : Thunk (Array Int) + +/-- Load shared model/input data once for reuse across best-match configs. -/ +private def loadSharedBinaryInputs + (path : System.FilePath) + (inputPath : System.FilePath) + (inputDelta : Rat) + (scalePow10 : Nat) : + IO (Except String SharedBinaryInputs) := do + let slack : Int := fixedUlpSlack + let action : ExceptT String IO SharedBinaryInputs := do + let (hdr, ln1Params, ln2Params) ← + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let residuals0 ← + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + if hdrTok.seqLen ≠ hdr.seqLen then + throw "token/embedding seq_len mismatch" + return { + hdr := hdr + ln1Params := ln1Params + ln2Params := ln2Params + tokens := tokens + residuals0 := residuals0 + inputDelta := inputDelta + scalePow10 := scalePow10 + } + action.run + +/-- Build cached prefix arrays for a fixed query position. 
-/ +private def mkSharedBinaryPrefix + (shared : SharedBinaryInputs) + (queryPos : Nat) + (causalPattern : Bool) : + SharedBinaryPrefix := + let seqLenEff : Nat := if causalPattern then queryPos + 1 else shared.hdr.seqLen + { + seqLenEff := seqLenEff + residuals := Thunk.mk (fun () => + if causalPattern then takePrefix shared.residuals0 seqLenEff else shared.residuals0) + tokens := Thunk.mk (fun () => + if causalPattern then takePrefix shared.tokens seqLenEff else shared.tokens) + } + private def skipToUnembeddingBinary (h : IO.FS.Handle) (hdr : BinaryHeader) : IO (Except String Unit) := do let action : ExceptT String IO Unit := do @@ -1572,13 +1648,35 @@ private def certifyHeadValueLowerBoundLocalBinaryAt (targetOffset : Int) (matchWeightLowerBound : Rat) (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : + (scalePow10 : Nat := defaultBinaryScalePow10) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (causalPattern : Bool := true) + (shared? : Option SharedBinaryInputs := none) + (prefix? : Option SharedBinaryPrefix := none) : IO (Except String HeadValueLowerBoundPosCert) := do let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack let action : ExceptT String IO HeadValueLowerBoundPosCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← + match shared? 
with + | some shared => + if shared.scalePow10 ≠ scalePow10 then + throw "shared scalePow10 mismatch" + if shared.inputDelta ≠ inputDelta then + throw "shared inputDelta mismatch" + pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) + | none => + let (hdr, ln1Params, ln2Params) ← + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let residuals0 ← + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + if hdrTok.seqLen ≠ hdr.seqLen then + throw "token/embedding seq_len mismatch" + pure (hdr, ln1Params, ln2Params, residuals0, tokens) if layerIdx ≥ hdr.numLayers then throw s!"layer index {layerIdx} out of range" if headIdx ≥ hdr.numHeads then @@ -1589,12 +1687,18 @@ private def certifyHeadValueLowerBoundLocalBinaryAt throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" if queryPos ≥ hdr.seqLen then throw s!"queryPos {queryPos} out of range" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" + let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen + let (residuals0, tokens) ← + match prefix? 
with + | some pref => + if pref.seqLenEff ≠ seqLenEff then + throw "prefix seq_len mismatch" + pure (pref.residuals.get, pref.tokens.get) + | none => + let residuals0 := + if causalPattern then takePrefix residualsBase seqLenEff else residualsBase + let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase + pure (residuals0, tokens) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -1650,7 +1754,7 @@ private def certifyHeadValueLowerBoundLocalBinaryAt | none => throw "missing W_O for requested head" | some xs => pure xs let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff for row in ln1Rows do let vHidden0 := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wv row @@ -1659,25 +1763,28 @@ private def certifyHeadValueLowerBoundLocalBinaryAt hdr.headDim hdr.modelDim wo vHidden vOutRows := vOutRows.push vOut let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then + if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then throw "query position has no valid target offset" let tIdx : Nat := Int.toNat ti let targetTok := tokens[tIdx]! let mut matchLo? : Option Int := none let mut nonmatchLo? : Option Int := none - for j in [:hdr.seqLen] do - let row := vOutRows[j]! - let vCoord := row[coord]!.lo - if tokens[j]! = targetTok then - matchLo? := - match matchLo? with - | none => some vCoord - | some m => some (min m vCoord) + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let row := vOutRows[j]! + let vCoord := row[coord]!.lo + if tokens[j]! = targetTok then + matchLo? := + match matchLo? with + | none => some vCoord + | some m => some (min m vCoord) + else + nonmatchLo? := + match nonmatchLo? 
with + | none => some vCoord + | some m => some (min m vCoord) else - nonmatchLo? := - match nonmatchLo? with - | none => some vCoord - | some m => some (min m vCoord) + pure () let matchLo ← match matchLo? with | none => throw "no matching keys for the requested offset" @@ -1704,45 +1811,116 @@ private def certifyHeadValueLowerBoundLocalBinaryAt return cert throw "head value lower bound (pos) failed internal consistency checks" else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion + let tightLayers : Nat := + if tightPattern then Nat.max 1 tightPatternLayers else 0 + if tightLayers > 0 && layerIdx ≤ l + tightLayers then + if causalPattern then + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut attnRows : Array (Array Fixed10Interval) := + Array.replicate hdr.seqLen zeroRow + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h 
(hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in ln1Rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let headRows := prefixUnionRowsFixed vOutRows + attnRows := addRowsFixed attnRows headRows + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnRows := addVecFixedRows attnRows attnBias + residuals := addRowsFixed residuals attnRows + else + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let groupRows := groupUnionRowsByToken ln1Rows tokens + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size + for row in groupRows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := 
vOutRows.push vOut + let vUnion := unionRowsFixed vOutRows + attnUnion := addVecFixed attnUnion vUnion + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion + else + let ln1Union := unionRowsFixed ln1Rows + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let (vHidden0, _nWv) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.headDim ln1Union scalePow10) + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let vHidden := addVecFixed vHidden0 bV + let (vOut, _nWo) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.headDim hdr.modelDim vHidden scalePow10) + attnUnion := addVecFixed attnUnion vOut + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff for row in residuals do let (ln2Out, _ln2VarLB) := fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits ln2Rows := ln2Rows.push ln2Out - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map 
Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut + let perRowLayers : Nat := perRowPatternLayers + if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then + let wIn ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let wOut ← + ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpRows := + mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + residuals := addRowsFixed residuals mlpRows + else + let ln2Union := unionRowsFixed ln2Rows + let (hidden0, _nWin) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.hiddenDim ln2Union scalePow10) + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let hiddenB := addVecFixed hidden0 bIn + let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let (mlpOut0, _nWout) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.hiddenDim hdr.modelDim actHidden scalePow10) + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpOut := addVecFixed mlpOut0 bOut + residuals := addVecFixedRows residuals mlpOut let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) @@ -1750,64 +1928,15 @@ private def certifyHeadValueLowerBoundLocalBinaryAt throw "target layer not reached" action.run +/-- Combined value + optional logit certs for a single query position (binary only). 
-/ +private structure HeadValueLogitCert where + value : HeadValueLowerBoundPosCert + logit? : Option HeadLogitDiffLowerBoundPosCert -private def readUnembeddingColumnsBinary - (path : System.FilePath) - (tokenA tokenB : Nat) - (scalePow10 : Nat) : - IO (Except String (BinaryHeader × Array Int × Array Int)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let action : ExceptT String IO (BinaryHeader × Array Int × Array Int) := do - let hdr ← ExceptT.mk (readBinaryHeader h) - if tokenA ≥ hdr.vocabSize || tokenB ≥ hdr.vocabSize then - throw "token index out of range for unembedding" - if tokenA = tokenB then - throw "target and negative tokens must differ" - let _ ← ExceptT.mk (skipToUnembeddingBinary h hdr) - let loTok := min tokenA tokenB - let hiTok := max tokenA tokenB - let swapped : Bool := tokenA > tokenB - let mut colA : Array Int := Array.mkEmpty hdr.modelDim - let mut colB : Array Int := Array.mkEmpty hdr.modelDim - for _r in [:hdr.modelDim] do - let _ ← ExceptT.mk (skipF64Array h loTok) - let vLo ← ExceptT.mk (readScaledFloat h scalePow10) - let _ ← ExceptT.mk (skipF64Array h (hiTok - loTok - 1)) - let vHi ← ExceptT.mk (readScaledFloat h scalePow10) - let _ ← ExceptT.mk (skipF64Array h (hdr.vocabSize - hiTok - 1)) - if swapped then - colA := colA.push vHi - colB := colB.push vLo - else - colA := colA.push vLo - colB := colB.push vHi - return (hdr, colA, colB) - action.run - -private def readLogitDiffDirectionBinary - (path : System.FilePath) - (targetToken negativeToken : Nat) - (scalePow10 : Nat) - (slack : Int) : - IO (Except String (BinaryHeader × Array Fixed10Interval)) := do - let action : ExceptT String IO (BinaryHeader × Array Fixed10Interval) := do - let (hdr, colTarget, colNeg) ← - ExceptT.mk (readUnembeddingColumnsBinary path targetToken negativeToken scalePow10) - if colTarget.size ≠ hdr.modelDim || colNeg.size ≠ hdr.modelDim then - throw "unembedding column size mismatch" - let targetIntervals := intervalsFromScaled colTarget slack - let 
negIntervals := intervalsFromScaled colNeg slack - let mut dir : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for i in [:hdr.modelDim] do - dir := dir.push (Fixed10Interval.sub targetIntervals[i]! negIntervals[i]!) - return (hdr, dir) - action.run - -/-- Compute local head logit-difference lower bounds at a specific query position (binary only). -/ -private def certifyHeadLogitDiffLowerBoundLocalBinaryAt +/-- Compute value and optional logit bounds for a head at a query position (binary only). -/ +private def certifyHeadValueLogitLowerBoundLocalBinaryAt (path : System.FilePath) - (layerIdx headIdx queryPos : Nat) - (targetToken negativeToken : Nat) + (layerIdx headIdx queryPos coord : Nat) (eps : Rat) (soundnessBits : Nat) (inputPath : System.FilePath) @@ -1815,33 +1944,60 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt (targetOffset : Int) (matchWeightLowerBound : Rat) (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String HeadLogitDiffLowerBoundPosCert) := do + (scalePow10 : Nat := defaultBinaryScalePow10) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (causalPattern : Bool := true) + (shared? : Option SharedBinaryInputs := none) + (prefix? : Option SharedBinaryPrefix := none) + (targetToken? : Option Nat := none) + (negativeToken? : Option Nat := none) + (direction? 
: Option (Thunk (Array Fixed10Interval)) := none) : + IO (Except String HeadValueLogitCert) := do let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadLogitDiffLowerBoundPosCert := do - let (hdrDir, direction) ← - ExceptT.mk (readLogitDiffDirectionBinary path targetToken negativeToken scalePow10 slack) - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if hdr.modelDim ≠ hdrDir.modelDim then - throw "unembedding model_dim mismatch" + let action : ExceptT String IO HeadValueLogitCert := do + let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← + match shared? with + | some shared => + if shared.scalePow10 ≠ scalePow10 then + throw "shared scalePow10 mismatch" + if shared.inputDelta ≠ inputDelta then + throw "shared inputDelta mismatch" + pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) + | none => + let (hdr, ln1Params, ln2Params) ← + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let residuals0 ← + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + if hdrTok.seqLen ≠ hdr.seqLen then + throw "token/embedding seq_len mismatch" + pure (hdr, ln1Params, ln2Params, residuals0, tokens) if layerIdx ≥ hdr.numLayers then throw s!"layer index {layerIdx} out of range" if headIdx ≥ hdr.numHeads then throw s!"head index {headIdx} out of range" - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" + if coord ≥ hdr.modelDim then + throw s!"coord index {coord} out of range" if hdr.seqLen > maxSeqLen then throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if direction.size ≠ hdr.modelDim then - throw "logit direction size mismatch" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, 
tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" + if queryPos ≥ hdr.seqLen then + throw s!"queryPos {queryPos} out of range" + let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen + let (residuals0, tokens) ← + match prefix? with + | some pref => + if pref.seqLenEff ≠ seqLenEff then + throw "prefix seq_len mismatch" + pure (pref.residuals.get, pref.tokens.get) + | none => + let residuals0 := + if causalPattern then takePrefix residualsBase seqLenEff else residualsBase + let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase + pure (residuals0, tokens) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -1897,7 +2053,7 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt | none => throw "missing W_O for requested head" | some xs => pure xs let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff for row in ln1Rows do let vHidden0 := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wv row @@ -1905,28 +2061,29 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt let vOut := matMulIntervalsFromScaled cfg slack hdr.headDim hdr.modelDim wo vHidden vOutRows := vOutRows.push vOut - let mut vDotRows : Array Fixed10Interval := Array.mkEmpty hdr.seqLen - for row in vOutRows do - vDotRows := vDotRows.push (fixedDotInterval cfg row direction) let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then + if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then throw "query position has no valid target offset" let tIdx : Nat := Int.toNat ti let targetTok := tokens[tIdx]! let mut matchLo? : Option Int := none let mut nonmatchLo? 
: Option Int := none - for j in [:hdr.seqLen] do - let vLo := (vDotRows[j]!).lo - if tokens[j]! = targetTok then - matchLo? := - match matchLo? with - | none => some vLo - | some m => some (min m vLo) + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let row := vOutRows[j]! + let vCoord := row[coord]!.lo + if tokens[j]! = targetTok then + matchLo? := + match matchLo? with + | none => some vCoord + | some m => some (min m vCoord) + else + nonmatchLo? := + match nonmatchLo? with + | none => some vCoord + | some m => some (min m vCoord) else - nonmatchLo? := - match nonmatchLo? with - | none => some vLo - | some m => some (min m vLo) + pure () let matchLo ← match matchLo? with | none => throw "no matching keys for the requested offset" @@ -1937,62 +2094,550 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt | some v => v let matchLoRat := ratOfScaledInt scalePow10 matchLo let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := matchWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadLogitDiffLowerBoundPosCert := { + let outputLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat + let value : HeadValueLowerBoundPosCert := { layerIdx := layerIdx headIdx := headIdx queryPos := queryPos - targetToken := targetToken - negativeToken := negativeToken - matchWeightLowerBound := weightLB - matchLogitLowerBound := matchLoRat - nonmatchLogitLowerBound := nonmatchLoRat - logitDiffLowerBound := outputLB + coord := coord + matchWeightLowerBound := matchWeightLowerBound + matchCoordLowerBound := matchLoRat + nonmatchCoordLowerBound := nonmatchLoRat + outputCoordLowerBound := outputLB } - if cert.check then - return cert - throw "head logit lower bound (pos) failed internal consistency checks" + if !value.check then + throw "head value certificate failed internal consistency checks" + let logit? ← + match targetToken?, negativeToken?, direction? 
with + | none, none, none => pure none + | some targetToken, some negativeToken, some direction => do + let dir := direction.get + if dir.size ≠ hdr.modelDim then + throw "logit direction size mismatch" + let mut vDotRows : Array Fixed10Interval := Array.mkEmpty seqLenEff + for row in vOutRows do + vDotRows := vDotRows.push (fixedDotInterval cfg row dir) + let mut matchLoLogit? : Option Int := none + let mut nonmatchLoLogit? : Option Int := none + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let vLo := (vDotRows[j]!).lo + if tokens[j]! = targetTok then + matchLoLogit? := + match matchLoLogit? with + | none => some vLo + | some m => some (min m vLo) + else + nonmatchLoLogit? := + match nonmatchLoLogit? with + | none => some vLo + | some m => some (min m vLo) + else + pure () + let matchLoLogit ← + match matchLoLogit? with + | none => throw "no matching keys for the requested offset" + | some v => pure v + let nonmatchLoLogit := + match nonmatchLoLogit? with + | none => matchLoLogit + | some v => v + let matchLoRat := ratOfScaledInt scalePow10 matchLoLogit + let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLoLogit + let logitLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat + let logitCert : HeadLogitDiffLowerBoundPosCert := { + layerIdx := layerIdx + headIdx := headIdx + queryPos := queryPos + targetToken := targetToken + negativeToken := negativeToken + matchWeightLowerBound := matchWeightLowerBound + matchLogitLowerBound := matchLoRat + nonmatchLogitLowerBound := nonmatchLoRat + logitDiffLowerBound := logitLB + } + if logitCert.check then + pure (some logitCert) + else + throw "head logit certificate failed internal consistency checks" + | _, _, _ => + throw "use both target and negative tokens (or neither)" + return { value := value, logit? := logit? 
} else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion + let tightLayers : Nat := + if tightPattern then Nat.max 1 tightPatternLayers else 0 + if tightLayers > 0 && layerIdx ≤ l + tightLayers then + if causalPattern then + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut attnRows : Array (Array Fixed10Interval) := + Array.replicate hdr.seqLen zeroRow + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in ln1Rows do 
+ let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let headRows := prefixUnionRowsFixed vOutRows + attnRows := addRowsFixed attnRows headRows + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnRows := addVecFixedRows attnRows attnBias + residuals := addRowsFixed residuals attnRows + else + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let groupRows := groupUnionRowsByToken ln1Rows tokens + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size + for row in groupRows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let vUnion := unionRowsFixed vOutRows + attnUnion := addVecFixed attnUnion vUnion + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion + else + let ln1Union := unionRowsFixed ln1Rows + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + for _h in 
[:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let (vHidden0, _nWv) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.headDim ln1Union scalePow10) + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let vHidden := addVecFixed vHidden0 bV + let (vOut, _nWo) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.headDim hdr.modelDim vHidden scalePow10) + attnUnion := addVecFixed attnUnion vOut + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion + let p2 := ln2Params.getD l defP + let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff + for row in residuals do + let (ln2Out, _ln2VarLB) := + fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits + ln2Rows := ln2Rows.push ln2Out + let perRowLayers : Nat := perRowPatternLayers + if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then + let wIn ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let wOut ← + ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpRows := + mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + residuals := addRowsFixed residuals mlpRows + else + let ln2Union := unionRowsFixed ln2Rows + let (hidden0, _nWin) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.hiddenDim ln2Union scalePow10) + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack 
scalePow10) + let hiddenB := addVecFixed hidden0 bIn + let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let (mlpOut0, _nWout) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.hiddenDim hdr.modelDim actHidden scalePow10) + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpOut := addVecFixed mlpOut0 bOut + residuals := addVecFixedRows residuals mlpOut + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + throw "target layer not reached" + action.run + + +private def readUnembeddingColumnsBinary + (path : System.FilePath) + (tokenA tokenB : Nat) + (scalePow10 : Nat) : + IO (Except String (BinaryHeader × Array Int × Array Int)) := do + let h ← IO.FS.Handle.mk path IO.FS.Mode.read + let action : ExceptT String IO (BinaryHeader × Array Int × Array Int) := do + let hdr ← ExceptT.mk (readBinaryHeader h) + if tokenA ≥ hdr.vocabSize || tokenB ≥ hdr.vocabSize then + throw "token index out of range for unembedding" + if tokenA = tokenB then + throw "target and negative tokens must differ" + let _ ← ExceptT.mk (skipToUnembeddingBinary h hdr) + let loTok := min tokenA tokenB + let hiTok := max tokenA tokenB + let swapped : Bool := tokenA > tokenB + let mut colA : Array Int := Array.mkEmpty hdr.modelDim + let mut colB : Array Int := Array.mkEmpty hdr.modelDim + for _r in [:hdr.modelDim] do + let _ ← ExceptT.mk (skipF64Array h loTok) + let vLo ← ExceptT.mk (readScaledFloat h scalePow10) + let _ ← ExceptT.mk (skipF64Array h (hiTok - loTok - 1)) + let vHi ← ExceptT.mk (readScaledFloat h scalePow10) + let _ ← ExceptT.mk (skipF64Array h (hdr.vocabSize - hiTok - 1)) + if swapped then + colA := colA.push vHi + colB := colB.push vLo + else + colA := colA.push vLo + colB := colB.push vHi + return (hdr, colA, colB) + action.run + +private def 
readLogitDiffDirectionBinary + (path : System.FilePath) + (targetToken negativeToken : Nat) + (scalePow10 : Nat) + (slack : Int) : + IO (Except String (BinaryHeader × Array Fixed10Interval)) := do + let action : ExceptT String IO (BinaryHeader × Array Fixed10Interval) := do + let (hdr, colTarget, colNeg) ← + ExceptT.mk (readUnembeddingColumnsBinary path targetToken negativeToken scalePow10) + if colTarget.size ≠ hdr.modelDim || colNeg.size ≠ hdr.modelDim then + throw "unembedding column size mismatch" + let targetIntervals := intervalsFromScaled colTarget slack + let negIntervals := intervalsFromScaled colNeg slack + let mut dir : Array Fixed10Interval := Array.mkEmpty hdr.modelDim + for i in [:hdr.modelDim] do + dir := dir.push (Fixed10Interval.sub targetIntervals[i]! negIntervals[i]!) + return (hdr, dir) + action.run + +/-- Compute local head logit-difference lower bounds at a specific query position (binary only). -/ +private def certifyHeadLogitDiffLowerBoundLocalBinaryAt + (path : System.FilePath) + (layerIdx headIdx queryPos : Nat) + (targetToken negativeToken : Nat) + (eps : Rat) + (soundnessBits : Nat) + (inputPath : System.FilePath) + (inputDelta : Rat) + (targetOffset : Int) + (matchWeightLowerBound : Rat) + (maxSeqLen : Nat) + (scalePow10 : Nat := defaultBinaryScalePow10) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (causalPattern : Bool := true) + (shared? : Option SharedBinaryInputs := none) + (prefix? : Option SharedBinaryPrefix := none) + (direction? : Option (Thunk (Array Fixed10Interval)) := none) : + IO (Except String HeadLogitDiffLowerBoundPosCert) := do + let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 + let slack : Int := fixedUlpSlack + let action : ExceptT String IO HeadLogitDiffLowerBoundPosCert := do + let (direction, hdrDir?) ← + match direction? 
with + | some thunk => pure (thunk.get, none) + | none => + let (hdrDir, dir) ← + ExceptT.mk <| + readLogitDiffDirectionBinary path targetToken negativeToken scalePow10 slack + pure (dir, some hdrDir) + let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← + match shared? with + | some shared => + if shared.scalePow10 ≠ scalePow10 then + throw "shared scalePow10 mismatch" + if shared.inputDelta ≠ inputDelta then + throw "shared inputDelta mismatch" + pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) + | none => + let (hdr, ln1Params, ln2Params) ← + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let residuals0 ← + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + if hdrTok.seqLen ≠ hdr.seqLen then + throw "token/embedding seq_len mismatch" + pure (hdr, ln1Params, ln2Params, residuals0, tokens) + match hdrDir? with + | some hdrDir => + if hdr.modelDim ≠ hdrDir.modelDim then + throw "unembedding model_dim mismatch" + | none => + if direction.size ≠ hdr.modelDim then + throw "logit direction size mismatch" + if layerIdx ≥ hdr.numLayers then + throw s!"layer index {layerIdx} out of range" + if headIdx ≥ hdr.numHeads then + throw s!"head index {headIdx} out of range" + if queryPos ≥ hdr.seqLen then + throw s!"queryPos {queryPos} out of range" + if hdr.seqLen > maxSeqLen then + throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" + let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen + if direction.size ≠ hdr.modelDim then + throw "logit direction size mismatch" + let (residuals0, tokens) ← + match prefix? 
with + | some pref => + if pref.seqLenEff ≠ seqLenEff then + throw "prefix seq_len mismatch" + pure (pref.residuals.get, pref.tokens.get) + | none => + let residuals0 := + if causalPattern then takePrefix residualsBase seqLenEff else residualsBase + let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase + pure (residuals0, tokens) + let h ← IO.FS.Handle.mk path IO.FS.Mode.read + let _ ← ExceptT.mk (readBinaryHeader h) + let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) + let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) + let defP : LayerNormParamsFixed := { + gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } + beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } + } + let mut residuals := residuals0 + for l in [:hdr.numLayers] do + let p1 := ln1Params.getD l defP + let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in residuals do + let (ln1Out, _ln1VarLB) := + fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits + ln1Rows := ln1Rows.push ln1Out + if l = layerIdx then + let mut wv? : Option (Array Int) := none + let mut bv? : Option (Array Int) := none + let mut wo? : Option (Array Int) := none + for hIdx in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + if hIdx = headIdx then + let wv ← + ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + wv? := some wv + let bV ← + ExceptT.mk <| + readScaledFloatArray h hdr.headDim scalePow10 + bv? := some bV + let wo ← + ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + wo? 
:= some wo + else + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) + let wv ← + match wv? with + | none => throw "missing W_V for requested head" + | some xs => pure xs + let bV ← + match bv? with + | none => throw "missing b_V for requested head" + | some xs => pure xs + let wo ← + match wo? with + | none => throw "missing W_O for requested head" + | some xs => pure xs + let bVIntervals := intervalsFromScaled bV slack + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff + for row in ln1Rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bVIntervals + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let mut vDotRows : Array Fixed10Interval := Array.mkEmpty seqLenEff + for row in vOutRows do + vDotRows := vDotRows.push (fixedDotInterval cfg row direction) + let ti : Int := (Int.ofNat queryPos) + targetOffset + if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then + throw "query position has no valid target offset" + let tIdx : Nat := Int.toNat ti + let targetTok := tokens[tIdx]! + let mut matchLo? : Option Int := none + let mut nonmatchLo? : Option Int := none + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let vLo := (vDotRows[j]!).lo + if tokens[j]! = targetTok then + matchLo? := + match matchLo? with + | none => some vLo + | some m => some (min m vLo) + else + nonmatchLo? := + match nonmatchLo? with + | none => some vLo + | some m => some (min m vLo) + else + pure () + let matchLo ← + match matchLo? with + | none => throw "no matching keys for the requested offset" + | some v => pure v + let nonmatchLo := + match nonmatchLo? 
with + | none => matchLo + | some v => v + let matchLoRat := ratOfScaledInt scalePow10 matchLo + let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo + let weightLB := matchWeightLowerBound + let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat + let cert : HeadLogitDiffLowerBoundPosCert := { + layerIdx := layerIdx + headIdx := headIdx + queryPos := queryPos + targetToken := targetToken + negativeToken := negativeToken + matchWeightLowerBound := weightLB + matchLogitLowerBound := matchLoRat + nonmatchLogitLowerBound := nonmatchLoRat + logitDiffLowerBound := outputLB + } + if cert.check then + return cert + throw "head logit lower bound (pos) failed internal consistency checks" + else + let tightLayers : Nat := + if tightPattern then Nat.max 1 tightPatternLayers else 0 + if tightLayers > 0 && layerIdx ≤ l + tightLayers then + if causalPattern then + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut attnRows : Array (Array Fixed10Interval) := + Array.replicate hdr.seqLen zeroRow + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in ln1Rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let headRows := prefixUnionRowsFixed vOutRows + attnRows := 
addRowsFixed attnRows headRows + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnRows := addVecFixedRows attnRows attnBias + residuals := addRowsFixed residuals attnRows + else + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let groupRows := groupUnionRowsByToken ln1Rows tokens + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size + for row in groupRows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let vUnion := unionRowsFixed vOutRows + attnUnion := addVecFixed attnUnion vUnion + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion + else + let ln1Union := unionRowsFixed ln1Rows + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let (vHidden0, _nWv) ← + ExceptT.mk 
(consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.headDim ln1Union scalePow10) + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let vHidden := addVecFixed vHidden0 bV + let (vOut, _nWo) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.headDim hdr.modelDim vHidden scalePow10) + attnUnion := addVecFixed attnUnion vOut + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff for row in residuals do let (ln2Out, _ln2VarLB) := fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits ln2Rows := ln2Rows.push ln2Out - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut + let perRowLayers : Nat := perRowPatternLayers + if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then + let wIn ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let wOut ← + ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 + let bOut ← ExceptT.mk 
(readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpRows := + mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + residuals := addRowsFixed residuals mlpRows + else + let ln2Union := unionRowsFixed ln2Rows + let (hidden0, _nWin) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.hiddenDim ln2Union scalePow10) + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let hiddenB := addVecFixed hidden0 bIn + let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let (mlpOut0, _nWout) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.hiddenDim hdr.modelDim actHidden scalePow10) + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpOut := addVecFixed mlpOut0 bOut + residuals := addVecFixedRows residuals mlpOut let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) @@ -2871,7 +3516,8 @@ private def certifyHeadPatternLocalBinary (tightPatternLayers : Nat) (perRowPatternLayers : Nat) (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String HeadPatternCert) := do let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack @@ -2908,26 +3554,36 @@ private def certifyHeadPatternLocalBinary ln1Rows := ln1Rows.push ln1Out if l = layerIdx then let mut wq? : Option (Array Int) := none + let mut bq? : Option (Array Int) := none let mut wk? : Option (Array Int) := none + let mut bk? : Option (Array Int) := none for hIdx in [:hdr.numHeads] do if hIdx = headIdx then let wq ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 wq? := some wq + let bQ ← + ExceptT.mk <| + readScaledFloatArray h hdr.headDim scalePow10 + bq? 
:= some bQ else let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) pure () - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) if hIdx = headIdx then let wk ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 wk? := some wk + let bK ← + ExceptT.mk <| + readScaledFloatArray h hdr.headDim scalePow10 + bk? := some bK else let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) pure () - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) @@ -2935,17 +3591,29 @@ private def certifyHeadPatternLocalBinary match wq? with | none => throw "missing W_Q for requested head" | some xs => pure xs + let bQ ← + match bq? with + | none => throw "missing b_Q for requested head" + | some xs => pure xs let wk ← match wk? with | none => throw "missing W_K for requested head" | some xs => pure xs + let bK ← + match bk? with + | none => throw "missing b_K for requested head" + | some xs => pure xs + let bQIntervals := intervalsFromScaled bQ slack + let bKIntervals := intervalsFromScaled bK slack let mut qRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen for row in ln1Rows do - qRows := qRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq row) - kRows := kRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row) + let qRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wq row + let kRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wk row + qRows := qRows.push (addVecFixed qRow0 bQIntervals) + kRows := kRows.push (addVecFixed kRow0 bKIntervals) let mut minTargetLower? 
: Option Int := none let mut maxOtherUpper? : Option Int := none let mut minTargetCount? : Option Nat := none @@ -2962,23 +3630,26 @@ private def certifyHeadPatternLocalBinary let mut maxOtherUpperRow? : Option Int := none let mut targetCount : Nat := 0 for j in [:hdr.seqLen] do - let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then - targetCount := targetCount + 1 - targetLower? := - match targetLower? with - | none => some dot.lo - | some m => some (min m dot.lo) - targetMaxLower? := - match targetMaxLower? with - | none => some dot.lo - | some m => some (max m dot.lo) + if !causalPattern || j ≤ i then + let dot := fixedDotInterval cfg qRow (kRows[j]!) + if tokens[j]! = targetTok then + targetCount := targetCount + 1 + targetLower? := + match targetLower? with + | none => some dot.lo + | some m => some (min m dot.lo) + targetMaxLower? := + match targetMaxLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + let cur := dot.hi + maxOtherUpperRow? := + match maxOtherUpperRow? with + | none => some cur + | some m => some (max m cur) else - let cur := dot.hi - maxOtherUpperRow? := - match maxOtherUpperRow? with - | none => some cur - | some m => some (max m cur) + pure () let targetLowerRow? := if tightPattern then targetMaxLower? else targetLower? match targetLowerRow? 
with @@ -3039,32 +3710,61 @@ private def certifyHeadPatternLocalBinary let tightLayers : Nat := if tightPattern then Nat.max 1 tightPatternLayers else 0 if tightLayers > 0 && layerIdx ≤ l + tightLayers then - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion + if causalPattern then + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut attnRows : Array (Array Fixed10Interval) := + Array.replicate hdr.seqLen zeroRow + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← 
ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in ln1Rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let headRows := prefixUnionRowsFixed vOutRows + attnRows := addRowsFixed attnRows headRows + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnRows := addVecFixedRows attnRows attnBias + residuals := addRowsFixed residuals attnRows + else + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let groupRows := groupUnionRowsByToken ln1Rows tokens + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size + for row in groupRows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let vUnion := unionRowsFixed vOutRows + attnUnion := addVecFixed attnUnion 
vUnion + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion let p2 := ln2Params.getD l defP let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen for row in residuals do @@ -3147,6 +3847,40 @@ private def certifyHeadPatternLocalBinary throw "target layer not reached" action.run +/-- Minimum relative improvement required to keep increasing softmax exp effort. -/ +private def defaultSoftmaxEffortMinRelImprove : Rat := (1 : Rat) / 100 + +/-- Choose a softmax exp effort by iterating until improvements are negligible. -/ +private def chooseSoftmaxExpEffort + (seqLen : Nat) (margin : Rat) (maxEffort : Nat) : + Nat × Rat × Rat := + let startEffort : Nat := if maxEffort = 0 then 0 else 1 + let weight0 : Rat := softmaxMaxProbLowerBound seqLen margin startEffort + let jac0 : Rat := softmaxJacobianNormInfBoundFromMargin seqLen margin startEffort + if startEffort ≥ maxEffort then + (startEffort, weight0, jac0) + else + Id.run do + let mut bestEff : Nat := startEffort + let mut bestWeight : Rat := weight0 + let mut bestJac : Rat := jac0 + let mut eff : Nat := startEffort + while eff < maxEffort do + eff := eff + 1 + let weight := softmaxMaxProbLowerBound seqLen margin eff + let jac := softmaxJacobianNormInfBoundFromMargin seqLen margin eff + if jac < bestJac then + let relImprove := + if bestJac = 0 then 0 else (bestJac - jac) / bestJac + bestEff := eff + bestWeight := weight + bestJac := jac + if relImprove < defaultSoftmaxEffortMinRelImprove then + eff := maxEffort + else + eff := maxEffort + return (bestEff, bestWeight, bestJac) + /-- Compute local head best-match pattern bounds for a specific `.nfpt` head (binary only). 
-/ private def certifyHeadPatternBestMatchLocalBinary (path : System.FilePath) @@ -3162,13 +3896,32 @@ private def certifyHeadPatternBestMatchLocalBinary (tightPatternLayers : Nat) (perRowPatternLayers : Nat) (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) + (shared? : Option SharedBinaryInputs := none) + (prefix? : Option SharedBinaryPrefix := none) : IO (Except String HeadBestMatchPatternCert) := do let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack let action : ExceptT String IO HeadBestMatchPatternCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← + match shared? with + | some shared => + if shared.scalePow10 ≠ scalePow10 then + throw "shared scalePow10 mismatch" + if shared.inputDelta ≠ inputDelta then + throw "shared inputDelta mismatch" + pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) + | none => + let (hdr, ln1Params, ln2Params) ← + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let residuals0 ← + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + if hdrTok.seqLen ≠ hdr.seqLen then + throw "token/embedding seq_len mismatch" + pure (hdr, ln1Params, ln2Params, residuals0, tokens) if layerIdx ≥ hdr.numLayers then throw s!"layer index {layerIdx} out of range" if headIdx ≥ hdr.numHeads then @@ -3182,12 +3935,18 @@ private def certifyHeadPatternBestMatchLocalBinary if hdr.seqLen = 0 then 0 else hdr.seqLen - 1 if queryPos ≥ hdr.seqLen then throw s!"queryPos {queryPos} out of range" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, 
tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" + let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen + let (residuals0, tokens) ← + match prefix? with + | some pref => + if pref.seqLenEff ≠ seqLenEff then + throw "prefix seq_len mismatch" + pure (pref.residuals.get, pref.tokens.get) + | none => + let residuals0 := + if causalPattern then takePrefix residualsBase seqLenEff else residualsBase + let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase + pure (residuals0, tokens) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -3206,19 +3965,27 @@ private def certifyHeadPatternBestMatchLocalBinary ln1Rows := ln1Rows.push ln1Out if l = layerIdx then let mut wq? : Option (Array Int) := none + let mut bq? : Option (Array Int) := none let mut wk? : Option (Array Int) := none + let mut bk? : Option (Array Int) := none for hIdx in [:hdr.numHeads] do if hIdx = headIdx then let wq ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 wq? := some wq - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let bQ ← + ExceptT.mk <| + readScaledFloatArray h hdr.headDim scalePow10 + bq? := some bQ let wk ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 wk? := some wk - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let bK ← + ExceptT.mk <| + readScaledFloatArray h hdr.headDim scalePow10 + bk? := some bK else let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) @@ -3231,35 +3998,50 @@ private def certifyHeadPatternBestMatchLocalBinary match wq? with | none => throw "missing W_Q for requested head" | some xs => pure xs + let bQ ← + match bq? with + | none => throw "missing b_Q for requested head" + | some xs => pure xs let wk ← match wk? 
with | none => throw "missing W_K for requested head" | some xs => pure xs + let bK ← + match bk? with + | none => throw "missing b_K for requested head" + | some xs => pure xs + let bQIntervals := intervalsFromScaled bQ slack + let bKIntervals := intervalsFromScaled bK slack let qRow := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + let qRow := addVecFixed qRow bQIntervals + let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff for row in ln1Rows do - kRows := kRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row) + let kRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wk row + kRows := kRows.push (addVecFixed kRow0 bKIntervals) let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then + if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then throw "query position has no valid target offset" let tIdx : Nat := Int.toNat ti let targetTok := tokens[tIdx]! let mut bestMatchLower? : Option Int := none let mut bestNonmatchUpper? : Option Int := none - for j in [:hdr.seqLen] do - let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let dot := fixedDotInterval cfg qRow (kRows[j]!) + if tokens[j]! = targetTok then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? with + | none => some dot.hi + | some m => some (max m dot.hi) else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) + pure () let bestMatchLower ← match bestMatchLower? 
with | none => throw "no matching keys for the requested offset" @@ -3272,10 +4054,8 @@ private def certifyHeadPatternBestMatchLocalBinary let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper let margin := ratOfScaledInt scalePow10 marginInt - let weightLB : Rat := - softmaxMaxProbLowerBound hdr.seqLen margin softmaxExpEffort - let softmaxJacobianUB : Rat := - softmaxJacobianNormInfBoundFromMargin hdr.seqLen margin softmaxExpEffort + let (effortUsed, weightLB, softmaxJacobianUB) := + chooseSoftmaxExpEffort hdr.seqLen margin softmaxExpEffort let cert : HeadBestMatchPatternCert := { layerIdx := layerIdx headIdx := headIdx @@ -3286,7 +4066,7 @@ private def certifyHeadPatternBestMatchLocalBinary bestMatchLogitLowerBound := bestMatchLowerRat bestNonmatchLogitUpperBound := bestNonmatchUpperRat marginLowerBound := margin - softmaxExpEffort := softmaxExpEffort + softmaxExpEffort := effortUsed bestMatchWeightLowerBound := weightLB softmaxJacobianNormInfUpperBound := softmaxJacobianUB } @@ -3297,32 +4077,61 @@ private def certifyHeadPatternBestMatchLocalBinary let tightLayers : Nat := if tightPattern then Nat.max 1 tightPatternLayers else 0 if tightLayers > 0 && layerIdx ≤ l + tightLayers then - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array 
Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion + if causalPattern then + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut attnRows : Array (Array Fixed10Interval) := + Array.replicate seqLenEff zeroRow + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff + for row in ln1Rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let headRows := prefixUnionRowsFixed vOutRows + attnRows := addRowsFixed attnRows headRows + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnRows := addVecFixedRows attnRows attnBias + residuals := addRowsFixed residuals attnRows + else + let mut attnUnion : Array Fixed10Interval := 
+ Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let groupRows := groupUnionRowsByToken ln1Rows tokens + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size + for row in groupRows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let vUnion := unionRowsFixed vOutRows + attnUnion := addVecFixed attnUnion vUnion + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion else let ln1Union := unionRowsFixed ln1Rows let mut attnUnion : Array Fixed10Interval := @@ -3345,7 +4154,7 @@ private def certifyHeadPatternBestMatchLocalBinary attnUnion := addVecFixed attnUnion attnBias residuals := addVecFixedRows residuals attnUnion let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff for row in residuals do let (ln2Out, _ln2VarLB) := fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits @@ -3396,25 +4205,37 @@ private def certifyHeadPatternBestMatchLocalBinarySweep (tightPatternLayers : Nat) (perRowPatternLayers : Nat) (scalePow10 : Nat := defaultBinaryScalePow10) - 
(softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) + (shared? : Option SharedBinaryInputs := none) : IO (Except String (Array HeadBestMatchPatternCert)) := do let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack let action : ExceptT String IO (Array HeadBestMatchPatternCert) := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let (hdr, ln1Params, ln2Params, residuals0, tokens) ← + match shared? with + | some shared => + if shared.scalePow10 ≠ scalePow10 then + throw "shared scalePow10 mismatch" + if shared.inputDelta ≠ inputDelta then + throw "shared inputDelta mismatch" + pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) + | none => + let (hdr, ln1Params, ln2Params) ← + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let residuals0 ← + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + if hdrTok.seqLen ≠ hdr.seqLen then + throw "token/embedding seq_len mismatch" + pure (hdr, ln1Params, ln2Params, residuals0, tokens) if layerIdx ≥ hdr.numLayers then throw s!"layer index {layerIdx} out of range" if headIdx ≥ hdr.numHeads then throw s!"head index {headIdx} out of range" if hdr.seqLen > maxSeqLen then throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -3433,19 +4254,27 @@ private def certifyHeadPatternBestMatchLocalBinarySweep ln1Rows := 
ln1Rows.push ln1Out if l = layerIdx then let mut wq? : Option (Array Int) := none + let mut bq? : Option (Array Int) := none let mut wk? : Option (Array Int) := none + let mut bk? : Option (Array Int) := none for hIdx in [:hdr.numHeads] do if hIdx = headIdx then let wq ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 wq? := some wq - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let bQ ← + ExceptT.mk <| + readScaledFloatArray h hdr.headDim scalePow10 + bq? := some bQ let wk ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 wk? := some wk - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let bK ← + ExceptT.mk <| + readScaledFloatArray h hdr.headDim scalePow10 + bk? := some bK else let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) @@ -3458,17 +4287,29 @@ private def certifyHeadPatternBestMatchLocalBinarySweep match wq? with | none => throw "missing W_Q for requested head" | some xs => pure xs + let bQ ← + match bq? with + | none => throw "missing b_Q for requested head" + | some xs => pure xs let wk ← match wk? with | none => throw "missing W_K for requested head" | some xs => pure xs + let bK ← + match bk? 
with + | none => throw "missing b_K for requested head" + | some xs => pure xs + let bQIntervals := intervalsFromScaled bQ slack + let bKIntervals := intervalsFromScaled bK slack let mut qRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen for row in ln1Rows do - qRows := qRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq row) - kRows := kRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row) + let qRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wq row + let kRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wk row + qRows := qRows.push (addVecFixed qRow0 bQIntervals) + kRows := kRows.push (addVecFixed kRow0 bKIntervals) let validPositions : Array Nat := Id.run do let mut out : Array Nat := Array.mkEmpty hdr.seqLen for i in [:hdr.seqLen] do @@ -3490,17 +4331,20 @@ private def certifyHeadPatternBestMatchLocalBinarySweep let mut bestMatchLower? : Option Int := none let mut bestNonmatchUpper? : Option Int := none for j in [:hdr.seqLen] do - let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) + if !causalPattern || j ≤ queryPos then + let dot := fixedDotInterval cfg qRow (kRows[j]!) + if tokens[j]! = targetTok then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? with + | none => some dot.hi + | some m => some (max m dot.hi) else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) + pure () let bestMatchLower ← match bestMatchLower? 
with | none => throw "no matching keys for the requested offset" @@ -3513,10 +4357,8 @@ private def certifyHeadPatternBestMatchLocalBinarySweep let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper let margin := ratOfScaledInt scalePow10 marginInt - let weightLB : Rat := - softmaxMaxProbLowerBound hdr.seqLen margin softmaxExpEffort - let softmaxJacobianUB : Rat := - softmaxJacobianNormInfBoundFromMargin hdr.seqLen margin softmaxExpEffort + let (effortUsed, weightLB, softmaxJacobianUB) := + chooseSoftmaxExpEffort hdr.seqLen margin softmaxExpEffort let cert : HeadBestMatchPatternCert := { layerIdx := layerIdx headIdx := headIdx @@ -3527,7 +4369,7 @@ private def certifyHeadPatternBestMatchLocalBinarySweep bestMatchLogitLowerBound := bestMatchLowerRat bestNonmatchLogitUpperBound := bestNonmatchUpperRat marginLowerBound := margin - softmaxExpEffort := softmaxExpEffort + softmaxExpEffort := effortUsed bestMatchWeightLowerBound := weightLB softmaxJacobianNormInfUpperBound := softmaxJacobianUB } @@ -3553,32 +4395,61 @@ private def certifyHeadPatternBestMatchLocalBinarySweep let tightLayers : Nat := if tightPattern then Nat.max 1 tightPatternLayers else 0 if tightLayers > 0 && layerIdx ≤ l + tightLayers then - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows 
: Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion + if causalPattern then + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut attnRows : Array (Array Fixed10Interval) := + Array.replicate hdr.seqLen zeroRow + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in ln1Rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let headRows := prefixUnionRowsFixed vOutRows + attnRows := addRowsFixed attnRows headRows + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnRows := addVecFixedRows attnRows attnBias + residuals := addRowsFixed residuals attnRows + else + let mut attnUnion : Array 
Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let groupRows := groupUnionRowsByToken ln1Rows tokens + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size + for row in groupRows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let vUnion := unionRowsFixed vOutRows + attnUnion := addVecFixed attnUnion vUnion + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion else let ln1Union := unionRowsFixed ln1Rows let mut attnUnion : Array Fixed10Interval := @@ -3648,7 +4519,11 @@ private def certifyHeadValueLowerBoundLocalBinary (inputPath : System.FilePath) (inputDelta : Rat) (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : + (scalePow10 : Nat := defaultBinaryScalePow10) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (causalPattern : Bool := true) : IO (Except String HeadValueLowerBoundCert) := do let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack @@ -3746,18 +4621,21 @@ private def certifyHeadValueLowerBoundLocalBinary let mut matchLo? 
: Option Int := none let mut nonmatchLo? : Option Int := none for j in [:hdr.seqLen] do - let row := vOutRows[j]! - let vCoord := row[coord]!.lo - if tokens[j]! = targetTok then - matchLo? := - match matchLo? with - | none => some vCoord - | some m => some (min m vCoord) + if !causalPattern || j ≤ i then + let row := vOutRows[j]! + let vCoord := row[coord]!.lo + if tokens[j]! = targetTok then + matchLo? := + match matchLo? with + | none => some vCoord + | some m => some (min m vCoord) + else + nonmatchLo? := + match nonmatchLo? with + | none => some vCoord + | some m => some (min m vCoord) else - nonmatchLo? := - match nonmatchLo? with - | none => some vCoord - | some m => some (min m vCoord) + pure () let matchLo := match matchLo? with | none => 0 @@ -3799,45 +4677,116 @@ private def certifyHeadValueLowerBoundLocalBinary return cert throw "head value lower bound failed internal consistency checks" else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← + let tightLayers : Nat := + if tightPattern then Nat.max 1 tightPatternLayers else 0 + if tightLayers > 0 && pattern.layerIdx ≤ l + tightLayers then + if causalPattern then + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut attnRows : Array (Array Fixed10Interval) := + Array.replicate hdr.seqLen zeroRow + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + 
readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in ln1Rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let headRows := prefixUnionRowsFixed vOutRows + attnRows := addRowsFixed attnRows headRows + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnRows := addVecFixedRows attnRows attnBias + residuals := addRowsFixed residuals attnRows + else + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let groupRows := groupUnionRowsByToken ln1Rows tokens + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size + for row in groupRows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let vUnion := unionRowsFixed vOutRows + attnUnion := addVecFixed attnUnion vUnion + let attnBias 
← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion + else + let ln1Union := unionRowsFixed ln1Rows + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let (vHidden0, _nWv) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.headDim ln1Union scalePow10) + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let vHidden := addVecFixed vHidden0 bV + let (vOut, _nWo) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.headDim hdr.modelDim vHidden scalePow10) + attnUnion := addVecFixed attnUnion vOut + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion + let p2 := ln2Params.getD l defP + let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in residuals do + let (ln2Out, _ln2VarLB) := + fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits + ln2Rows := ln2Rows.push ln2Out + let perRowLayers : Nat := perRowPatternLayers + if perRowLayers > 0 && pattern.layerIdx ≤ l + perRowLayers then + let wIn ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let wOut ← + ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpRows := + mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim 
wIn wOut bIn bOut ln2Rows + residuals := addRowsFixed residuals mlpRows + else + let ln2Union := unionRowsFixed ln2Rows + let (hidden0, _nWin) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← + hdr.modelDim hdr.hiddenDim ln2Union scalePow10) + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let hiddenB := addVecFixed hidden0 bIn + let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut + hdr.hiddenDim hdr.modelDim actHidden scalePow10) + let bOut ← 
ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpOut := addVecFixed mlpOut0 bOut + residuals := addVecFixedRows residuals mlpOut let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) @@ -3855,7 +4804,11 @@ private def certifyHeadLogitDiffLowerBoundLocalBinary (inputPath : System.FilePath) (inputDelta : Rat) (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : + (scalePow10 : Nat := defaultBinaryScalePow10) + (tightPattern : Bool := false) + (tightPatternLayers : Nat := 1) + (perRowPatternLayers : Nat := 0) + (causalPattern : Bool := true) : IO (Except String HeadLogitDiffLowerBoundCert) := do let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack @@ -3960,17 +4913,20 @@ private def certifyHeadLogitDiffLowerBoundLocalBinary let mut matchLo? : Option Int := none let mut nonmatchLo? : Option Int := none for j in [:hdr.seqLen] do - let vLo := (vDotRows[j]!).lo - if tokens[j]! = targetTok then - matchLo? := - match matchLo? with - | none => some vLo - | some m => some (min m vLo) + if !causalPattern || j ≤ i then + let vLo := (vDotRows[j]!).lo + if tokens[j]! = targetTok then + matchLo? := + match matchLo? with + | none => some vLo + | some m => some (min m vLo) + else + nonmatchLo? := + match nonmatchLo? with + | none => some vLo + | some m => some (min m vLo) else - nonmatchLo? := - match nonmatchLo? with - | none => some vLo - | some m => some (min m vLo) + pure () let matchLo := match matchLo? 
with | none => 0 @@ -4013,45 +4969,116 @@ private def certifyHeadLogitDiffLowerBoundLocalBinary return cert throw "head logit lower bound failed internal consistency checks" else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion + let tightLayers : Nat := + if tightPattern then Nat.max 1 tightPatternLayers else 0 + if tightLayers > 0 && pattern.layerIdx ≤ l + tightLayers then + if causalPattern then + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut attnRows : Array (Array Fixed10Interval) := + Array.replicate hdr.seqLen zeroRow + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← 
ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for row in ln1Rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let headRows := prefixUnionRowsFixed vOutRows + attnRows := addRowsFixed attnRows headRows + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnRows := addVecFixedRows attnRows attnBias + residuals := addRowsFixed residuals attnRows + else + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let groupRows := groupUnionRowsByToken ln1Rows tokens + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let wv ← ExceptT.mk <| + readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let wo ← ExceptT.mk <| + readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size + for row in groupRows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + vOutRows := vOutRows.push vOut + let vUnion := unionRowsFixed vOutRows + attnUnion := addVecFixed attnUnion vUnion + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows 
residuals attnUnion + else + let ln1Union := unionRowsFixed ln1Rows + let mut attnUnion : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + for _h in [:hdr.numHeads] do + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let (vHidden0, _nWv) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.headDim ln1Union scalePow10) + let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) + let vHidden := addVecFixed vHidden0 bV + let (vOut, _nWo) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.headDim hdr.modelDim vHidden scalePow10) + attnUnion := addVecFixed attnUnion vOut + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + attnUnion := addVecFixed attnUnion attnBias + residuals := addVecFixedRows residuals attnUnion let p2 := ln2Params.getD l defP let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen for row in residuals do let (ln2Out, _ln2VarLB) := fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits ln2Rows := ln2Rows.push ln2Out - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut + let perRowLayers : 
Nat := perRowPatternLayers + if perRowLayers > 0 && pattern.layerIdx ≤ l + perRowLayers then + let wIn ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let wOut ← + ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpRows := + mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + residuals := addRowsFixed residuals mlpRows + else + let ln2Union := unionRowsFixed ln2Rows + let (hidden0, _nWin) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.modelDim hdr.hiddenDim ln2Union scalePow10) + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let hiddenB := addVecFixed hidden0 bIn + let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let (mlpOut0, _nWout) ← + ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h + hdr.hiddenDim hdr.modelDim actHidden scalePow10) + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let mlpOut := addVecFixed mlpOut0 bOut + residuals := addVecFixedRows residuals mlpOut let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) @@ -4149,7 +5176,8 @@ def certifyHeadPatternLocal (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String HeadPatternCert) := do if inputDelta < 0 then return .error "delta must be nonnegative" @@ -4159,7 +5187,7 @@ def certifyHeadPatternLocal let inputPath := inputPath?.getD path certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath 
inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort + softmaxExpEffort causalPattern else return .error "head pattern bounds require NFP_BINARY_V1" @@ -4178,7 +5206,8 @@ def certifyHeadPatternBestMatchLocal (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String HeadBestMatchPatternCert) := do if inputDelta < 0 then return .error "delta must be nonnegative" @@ -4189,7 +5218,7 @@ def certifyHeadPatternBestMatchLocal certifyHeadPatternBestMatchLocalBinary path layerIdx headIdx queryPos? eps soundnessBits inputPath inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 softmaxExpEffort + perRowPatternLayers scalePow10 softmaxExpEffort causalPattern else return .error "head pattern bounds require NFP_BINARY_V1" @@ -4207,7 +5236,8 @@ def certifyHeadPatternBestMatchLocalSweep (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String (Array HeadBestMatchPatternCert)) := do if inputDelta < 0 then return .error "delta must be nonnegative" @@ -4217,7 +5247,7 @@ def certifyHeadPatternBestMatchLocalSweep let inputPath := inputPath?.getD path certifyHeadPatternBestMatchLocalBinarySweep path layerIdx headIdx eps soundnessBits inputPath inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers - scalePow10 softmaxExpEffort + scalePow10 softmaxExpEffort causalPattern else return .error "head pattern bounds require NFP_BINARY_V1" @@ -4235,7 +5265,8 @@ def certifyLayerBestMatchMarginLocal (tightPatternLayers : Nat := 1) 
(perRowPatternLayers : Nat := 0) (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String LayerBestMatchMarginCert) := do if inputDelta < 0 then return .error "delta must be nonnegative" @@ -4257,7 +5288,7 @@ def certifyLayerBestMatchMarginLocal certifyHeadPatternBestMatchLocalBinarySweep path layerIdx hIdx eps soundnessBits inputPath inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort with + softmaxExpEffort causalPattern with | .error e => return .error e | .ok certs => for cert in certs do @@ -4294,7 +5325,8 @@ def certifyHeadValueLowerBoundLocal (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) : + (scalePow10 : Nat := defaultBinaryScalePow10) + (causalPattern : Bool := true) : IO (Except String HeadValueLowerBoundCert) := do if inputDelta < 0 then return .error "delta must be nonnegative" @@ -4305,11 +5337,13 @@ def certifyHeadValueLowerBoundLocal let patternE ← certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + defaultSoftmaxExpEffort causalPattern match patternE with | .error e => return .error e | .ok pattern => - certifyHeadValueLowerBoundLocalBinary path pattern coord eps soundnessBits inputPath - inputDelta maxSeqLen scalePow10 + certifyHeadValueLowerBoundLocalBinary path pattern coord eps soundnessBits inputPath + inputDelta maxSeqLen scalePow10 tightPattern tightPatternLayers perRowPatternLayers + causalPattern else return .error "head value bounds require NFP_BINARY_V1" @@ -4326,7 +5360,8 @@ def certifyHeadLogitDiffLowerBoundLocal (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 
0) - (scalePow10 : Nat := defaultBinaryScalePow10) : + (scalePow10 : Nat := defaultBinaryScalePow10) + (causalPattern : Bool := true) : IO (Except String HeadLogitDiffLowerBoundCert) := do if inputDelta < 0 then return .error "delta must be nonnegative" @@ -4337,11 +5372,13 @@ def certifyHeadLogitDiffLowerBoundLocal let patternE ← certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + defaultSoftmaxExpEffort causalPattern match patternE with | .error e => return .error e | .ok pattern => - certifyHeadLogitDiffLowerBoundLocalBinary path pattern targetToken negativeToken - eps soundnessBits inputPath inputDelta maxSeqLen scalePow10 + certifyHeadLogitDiffLowerBoundLocalBinary path pattern targetToken negativeToken + eps soundnessBits inputPath inputDelta maxSeqLen scalePow10 tightPattern + tightPatternLayers perRowPatternLayers causalPattern else return .error "head logit bounds require NFP_BINARY_V1" @@ -4362,7 +5399,8 @@ def certifyInductionSound (perRowPatternLayers : Nat := 0) (targetToken? : Option Nat := none) (negativeToken? 
: Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String InductionHeadSoundCert) := do if inputDelta < 0 then return .error "delta must be nonnegative" @@ -4373,20 +5411,21 @@ def certifyInductionSound let p1E ← certifyHeadPatternLocalBinary path layer1 head1 eps soundnessBits inputPath inputDelta offset1 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort + softmaxExpEffort causalPattern match p1E with | .error e => return .error e | .ok p1 => let p2E ← certifyHeadPatternLocalBinary path layer2 head2 eps soundnessBits inputPath inputDelta offset2 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort + softmaxExpEffort causalPattern match p2E with | .error e => return .error e | .ok p2 => let vE ← certifyHeadValueLowerBoundLocalBinary path p2 coord eps soundnessBits inputPath - inputDelta maxSeqLen scalePow10 + inputDelta maxSeqLen scalePow10 tightPattern tightPatternLayers perRowPatternLayers + causalPattern match vE with | .error e => return .error e | .ok v => @@ -4396,7 +5435,8 @@ def certifyInductionSound | some targetToken, some negativeToken => do let logitE ← certifyHeadLogitDiffLowerBoundLocalBinary path p2 targetToken negativeToken eps soundnessBits inputPath inputDelta - maxSeqLen scalePow10 + maxSeqLen scalePow10 tightPattern tightPatternLayers perRowPatternLayers + causalPattern pure (logitE.map some) | _, _ => pure (.error "use both target and negative tokens (or neither)") @@ -4416,6 +5456,636 @@ def certifyInductionSound else return .error "induction sound cert requires NFP_BINARY_V1" +/-- Compute a best-match induction-head certificate in a single binary pass. 
-/ +private def certifyInductionSoundBestMatchLocalBinaryPair + (path : System.FilePath) + (layer1 head1 layer2 head2 coord queryPos : Nat) + (eps : Rat) + (soundnessBits : Nat) + (inputPath : System.FilePath) + (inputDelta : Rat) + (offset1 : Int) + (offset2 : Int) + (maxSeqLen : Nat) + (scalePow10 : Nat := defaultBinaryScalePow10) + (tightPattern : Bool) + (tightPatternLayers : Nat) + (perRowPatternLayers : Nat) + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) + (shared? : Option SharedBinaryInputs := none) + (prefix? : Option SharedBinaryPrefix := none) + (targetToken? : Option Nat := none) + (negativeToken? : Option Nat := none) + (direction? : Option (Thunk (Array Fixed10Interval)) := none) : + IO (Except String InductionHeadBestMatchSoundCert) := do + let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 + let slack : Int := fixedUlpSlack + let action : ExceptT String IO InductionHeadBestMatchSoundCert := do + let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← + match shared? 
with + | some shared => + if shared.scalePow10 ≠ scalePow10 then + throw "shared scalePow10 mismatch" + if shared.inputDelta ≠ inputDelta then + throw "shared inputDelta mismatch" + pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) + | none => + let (hdr, ln1Params, ln2Params) ← + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + let residuals0 ← + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + if hdrTok.seqLen ≠ hdr.seqLen then + throw "token/embedding seq_len mismatch" + pure (hdr, ln1Params, ln2Params, residuals0, tokens) + if layer1 ≥ hdr.numLayers then + throw s!"layer1 index {layer1} out of range" + if layer2 ≥ hdr.numLayers then + throw s!"layer2 index {layer2} out of range" + if head1 ≥ hdr.numHeads then + throw s!"head1 index {head1} out of range" + if head2 ≥ hdr.numHeads then + throw s!"head2 index {head2} out of range" + if coord ≥ hdr.modelDim then + throw s!"coord index {coord} out of range" + if hdr.seqLen > maxSeqLen then + throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" + if queryPos ≥ hdr.seqLen then + throw s!"queryPos {queryPos} out of range" + let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen + let (residuals0, tokens) ← + match prefix? with + | some pref => + if pref.seqLenEff ≠ seqLenEff then + throw "prefix seq_len mismatch" + pure (pref.residuals.get, pref.tokens.get) + | none => + let residuals0 := + if causalPattern then takePrefix residualsBase seqLenEff else residualsBase + let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase + pure (residuals0, tokens) + let useLogit ← + match targetToken?, negativeToken?, direction? 
with + | none, none, none => pure false + | some _, some _, some _ => pure true + | _, _, _ => throw "use both target and negative tokens (or neither)" + let calcLnRows + (rows : Array (Array Fixed10Interval)) + (p : LayerNormParamsFixed) : + Array (Array Fixed10Interval) := + Id.run do + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for row in rows do + let (lnOut, _varLB) := + fixedLayerNormRowApprox cfg row p.gamma p.beta eps soundnessBits + out := out.push lnOut + return out + let calcVOutRows + (rows : Array (Array Fixed10Interval)) + (wv wo : Array Int) + (bV : Array Fixed10Interval) : + Array (Array Fixed10Interval) := + Id.run do + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for row in rows do + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden + out := out.push vOut + return out + let calcVOut + (row : Array Fixed10Interval) + (wv wo : Array Int) + (bV : Array Fixed10Interval) : + Array Fixed10Interval := + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv row + let vHidden := addVecFixed vHidden0 bV + matMulIntervalsFromScaled cfg slack hdr.headDim hdr.modelDim wo vHidden + let bestMatchPattern + (layerIdx headIdx : Nat) + (ln1Rows : Array (Array Fixed10Interval)) + (wq wk : Array Int) + (bQ bK : Array Fixed10Interval) + (targetOffset : Int) : + ExceptT String IO HeadBestMatchPatternCert := do + let qRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) 
+ let qRow := addVecFixed qRow0 bQ + let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff + for row in ln1Rows do + let kRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wk row + kRows := kRows.push (addVecFixed kRow0 bK) + let ti : Int := (Int.ofNat queryPos) + targetOffset + if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then + throw "query position has no valid target offset" + let tIdx : Nat := Int.toNat ti + let targetTok := tokens[tIdx]! + let mut bestMatchLower? : Option Int := none + let mut bestNonmatchUpper? : Option Int := none + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let dot := fixedDotInterval cfg qRow (kRows[j]!) + if tokens[j]! = targetTok then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? with + | none => some dot.hi + | some m => some (max m dot.hi) + else + pure () + let bestMatchLower ← + match bestMatchLower? with + | none => throw "no matching keys for the requested offset" + | some v => pure v + let bestNonmatchUpper := + match bestNonmatchUpper? 
with + | none => bestMatchLower + | some v => v + let marginInt : Int := bestMatchLower - bestNonmatchUpper + let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower + let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper + let margin := ratOfScaledInt scalePow10 marginInt + let (effortUsed, weightLB, softmaxJacobianUB) := + chooseSoftmaxExpEffort hdr.seqLen margin softmaxExpEffort + let cert : HeadBestMatchPatternCert := { + layerIdx := layerIdx + headIdx := headIdx + seqLen := hdr.seqLen + queryPos := queryPos + targetOffset := targetOffset + targetToken := targetTok + bestMatchLogitLowerBound := bestMatchLowerRat + bestNonmatchLogitUpperBound := bestNonmatchUpperRat + marginLowerBound := margin + softmaxExpEffort := effortUsed + bestMatchWeightLowerBound := weightLB + softmaxJacobianNormInfUpperBound := softmaxJacobianUB + } + if cert.check then + return cert + throw "best-match head pattern certificate failed internal consistency checks" + let valueLogit + (ln1Rows : Array (Array Fixed10Interval)) + (matchWeightLowerBound : Rat) + (wv wo : Array Int) + (bV : Array Fixed10Interval) + (targetOffset : Int) : + ExceptT String IO HeadValueLogitCert := do + let vOutRows := calcVOutRows ln1Rows wv wo bV + let ti : Int := (Int.ofNat queryPos) + targetOffset + if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then + throw "query position has no valid target offset" + let tIdx : Nat := Int.toNat ti + let targetTok := tokens[tIdx]! + let mut matchLo? : Option Int := none + let mut nonmatchLo? : Option Int := none + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let row := vOutRows[j]! + let vCoord := row[coord]!.lo + if tokens[j]! = targetTok then + matchLo? := + match matchLo? with + | none => some vCoord + | some m => some (min m vCoord) + else + nonmatchLo? := + match nonmatchLo? with + | none => some vCoord + | some m => some (min m vCoord) + else + pure () + let matchLo ← + match matchLo? 
with + | none => throw "no matching keys for the requested offset" + | some v => pure v + let nonmatchLo := + match nonmatchLo? with + | none => matchLo + | some v => v + let matchLoRat := ratOfScaledInt scalePow10 matchLo + let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo + let outputLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat + let value : HeadValueLowerBoundPosCert := { + layerIdx := layer2 + headIdx := head2 + queryPos := queryPos + coord := coord + matchWeightLowerBound := matchWeightLowerBound + matchCoordLowerBound := matchLoRat + nonmatchCoordLowerBound := nonmatchLoRat + outputCoordLowerBound := outputLB + } + if !value.check then + throw "head value certificate failed internal consistency checks" + let logit? ← + if !useLogit then + pure none + else + match targetToken?, negativeToken?, direction? with + | some targetToken, some negativeToken, some direction => do + let dir := direction.get + if dir.size ≠ hdr.modelDim then + throw "logit direction size mismatch" + let mut vDotRows : Array Fixed10Interval := Array.mkEmpty seqLenEff + for row in vOutRows do + vDotRows := vDotRows.push (fixedDotInterval cfg row dir) + let mut matchLoLogit? : Option Int := none + let mut nonmatchLoLogit? : Option Int := none + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let vLo := (vDotRows[j]!).lo + if tokens[j]! = targetTok then + matchLoLogit? := + match matchLoLogit? with + | none => some vLo + | some m => some (min m vLo) + else + nonmatchLoLogit? := + match nonmatchLoLogit? with + | none => some vLo + | some m => some (min m vLo) + else + pure () + let matchLoLogit ← + match matchLoLogit? with + | none => throw "no matching keys for the requested offset" + | some v => pure v + let nonmatchLoLogit := + match nonmatchLoLogit? 
with + | none => matchLoLogit + | some v => v + let matchLoRat := ratOfScaledInt scalePow10 matchLoLogit + let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLoLogit + let logitLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat + let logitCert : HeadLogitDiffLowerBoundPosCert := { + layerIdx := layer2 + headIdx := head2 + queryPos := queryPos + targetToken := targetToken + negativeToken := negativeToken + matchWeightLowerBound := matchWeightLowerBound + matchLogitLowerBound := matchLoRat + nonmatchLogitLowerBound := nonmatchLoRat + logitDiffLowerBound := logitLB + } + if logitCert.check then + pure (some logitCert) + else + throw "head logit certificate failed internal consistency checks" + | _, _, _ => + throw "use both target and negative tokens (or neither)" + return { value := value, logit? := logit? } + let addAttn + (useTight : Bool) + (ln1Rows : Array (Array Fixed10Interval)) + (ln1Union? : Option (Array Fixed10Interval)) + (groupRows? : Option (Array (Array Fixed10Interval))) + (attnRows? : Option (Array (Array Fixed10Interval))) + (attnUnion? : Option (Array Fixed10Interval)) + (wv wo : Array Int) + (bV : Array Fixed10Interval) : + ExceptT String IO + (Option (Array (Array Fixed10Interval)) × Option (Array Fixed10Interval)) := do + if useTight then + if causalPattern then + let vOutRows := calcVOutRows ln1Rows wv wo bV + let headRows := prefixUnionRowsFixed vOutRows + match attnRows? with + | some rows => return (some (addRowsFixed rows headRows), attnUnion?) + | none => throw "missing attnRows" + else + let groupRows ← + match groupRows? with + | some rows => pure rows + | none => throw "missing group rows" + let vOutRows := calcVOutRows groupRows wv wo bV + let vUnion := unionRowsFixed vOutRows + match attnUnion? with + | some u => return (attnRows?, some (addVecFixed u vUnion)) + | none => throw "missing attnUnion" + else + let ln1Union ← + match ln1Union? 
with + | some row => pure row + | none => throw "missing ln1Union" + let vOut := calcVOut ln1Union wv wo bV + match attnUnion? with + | some u => return (attnRows?, some (addVecFixed u vOut)) + | none => throw "missing attnUnion" + let applyAttn + (rows : Array (Array Fixed10Interval)) + (useTight : Bool) + (attnRows? : Option (Array (Array Fixed10Interval))) + (attnUnion? : Option (Array Fixed10Interval)) + (attnBias : Array Fixed10Interval) : + ExceptT String IO (Array (Array Fixed10Interval)) := do + if useTight && causalPattern then + match attnRows? with + | some attnRows => + let attnRows := addVecFixedRows attnRows attnBias + return addRowsFixed rows attnRows + | none => throw "missing attnRows" + else + match attnUnion? with + | some attnUnion => + let attnUnion := addVecFixed attnUnion attnBias + return addVecFixedRows rows attnUnion + | none => throw "missing attnUnion" + let applyMlp + (rows : Array (Array Fixed10Interval)) + (usePerRow : Bool) + (p : LayerNormParamsFixed) + (wIn wOut : Array Int) + (bIn bOut : Array Fixed10Interval) : + Array (Array Fixed10Interval) := + let ln2Rows := calcLnRows rows p + if usePerRow then + let mlpRows := mlpRowsFromScaled cfg slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + addRowsFixed rows mlpRows + else + let ln2Union := unionRowsFixed ln2Rows + let mlpOut := mlpRowFromScaled cfg slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union + addVecFixedRows rows mlpOut + let h ← IO.FS.Handle.mk path IO.FS.Mode.read + let _ ← ExceptT.mk (readBinaryHeader h) + let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) + let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) + let defP : LayerNormParamsFixed := { + gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } + beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } + } + let mut residuals1 := residuals0 + let mut residuals2 := residuals0 + let mut residualsSame : Bool := true + let mut residualsV := residuals0 + let mut residualsSameV : 
Bool := true + let mut p1? : Option HeadBestMatchPatternCert := none + let mut p2? : Option HeadBestMatchPatternCert := none + let mut vlogit? : Option HeadValueLogitCert := none + for l in [:hdr.numLayers] do + let at1 := l = layer1 && p1?.isNone + let at2 := l = layer2 && p2?.isNone + let needUpdate1 := l < layer1 && p1?.isNone + let needUpdate2 := l < layer2 && p2?.isNone + let needUpdateV := needUpdate2 + let needRows1 := at1 || needUpdate1 + let needRows2 := at2 || needUpdate2 + let needRowsV := needRows2 + let ln1P := ln1Params.getD l defP + let mut ln1RowsShared? : Option (Array (Array Fixed10Interval)) := none + if residualsSame && (needRows1 || needRows2) then + ln1RowsShared? := some (calcLnRows residuals1 ln1P) + let mut ln1Rows1? : Option (Array (Array Fixed10Interval)) := none + let mut ln1Rows2? : Option (Array (Array Fixed10Interval)) := none + if needRows1 then + ln1Rows1? := + some (ln1RowsShared?.getD (calcLnRows residuals1 ln1P)) + if needRows2 then + ln1Rows2? := + some (ln1RowsShared?.getD (calcLnRows residuals2 ln1P)) + let mut ln1RowsV? : Option (Array (Array Fixed10Interval)) := none + if needRowsV then + if residualsSameV then + ln1RowsV? := ln1Rows2? + else + ln1RowsV? 
:= some (calcLnRows residualsV ln1P) + let tightLayers : Nat := + if tightPattern then Nat.max 1 tightPatternLayers else 0 + let useTight1 := needUpdate1 && tightLayers > 0 && layer1 ≤ l + tightLayers + let useTight2 := needUpdate2 && tightLayers > 0 && layer2 ≤ l + tightLayers + let usePerRow1 := + needUpdate1 && perRowPatternLayers > 0 && layer1 ≤ l + perRowPatternLayers + let usePerRow2 := + needUpdate2 && perRowPatternLayers > 0 && layer2 ≤ l + perRowPatternLayers + let useTightV := useTight2 + let usePerRowV := usePerRow2 + let skipAttnV := useTightV && causalPattern && seqLenEff < hdr.seqLen + let shareUpdateV := residualsSameV && needUpdateV && !skipAttnV + let shareUpdate := + residualsSame && needUpdate1 && needUpdate2 && + useTight1 = useTight2 && usePerRow1 = usePerRow2 + let zeroRow : Array Fixed10Interval := + Array.replicate hdr.modelDim { lo := 0, hi := 0 } + let mut ln1Union1? : Option (Array Fixed10Interval) := none + let mut ln1Union2? : Option (Array Fixed10Interval) := none + let mut groupRows1? : Option (Array (Array Fixed10Interval)) := none + let mut groupRows2? : Option (Array (Array Fixed10Interval)) := none + let mut attnRows1? : Option (Array (Array Fixed10Interval)) := none + let mut attnRows2? : Option (Array (Array Fixed10Interval)) := none + let mut attnUnion1? : Option (Array Fixed10Interval) := none + let mut attnUnion2? : Option (Array Fixed10Interval) := none + let mut ln1UnionV? : Option (Array Fixed10Interval) := none + let mut groupRowsV? : Option (Array (Array Fixed10Interval)) := none + let mut attnRowsV? : Option (Array (Array Fixed10Interval)) := none + let mut attnUnionV? : Option (Array Fixed10Interval) := none + let mut ln1UnionShared? : Option (Array Fixed10Interval) := none + let mut groupRowsShared? : Option (Array (Array Fixed10Interval)) := none + let mut attnRowsShared? : Option (Array (Array Fixed10Interval)) := none + let mut attnUnionShared? 
: Option (Array Fixed10Interval) := none + let ln1Rows1 := ln1Rows1?.getD #[] + let ln1Rows2 := ln1Rows2?.getD #[] + let ln1RowsV := ln1RowsV?.getD #[] + let ln1RowsShared := ln1RowsShared?.getD #[] + if shareUpdate then + if useTight1 then + if causalPattern then + attnRowsShared? := some (Array.replicate seqLenEff zeroRow) + else + groupRowsShared? := some (groupUnionRowsByToken ln1RowsShared tokens) + attnUnionShared? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + else + ln1UnionShared? := some (unionRowsFixed ln1RowsShared) + attnUnionShared? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + else + if needUpdate1 then + if useTight1 then + if causalPattern then + attnRows1? := some (Array.replicate seqLenEff zeroRow) + else + groupRows1? := some (groupUnionRowsByToken ln1Rows1 tokens) + attnUnion1? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + else + ln1Union1? := some (unionRowsFixed ln1Rows1) + attnUnion1? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + if needUpdate2 then + if useTight2 then + if causalPattern then + attnRows2? := some (Array.replicate seqLenEff zeroRow) + else + groupRows2? := some (groupUnionRowsByToken ln1Rows2 tokens) + attnUnion2? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + else + ln1Union2? := some (unionRowsFixed ln1Rows2) + attnUnion2? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + if needUpdateV && !shareUpdateV && !skipAttnV then + if useTightV then + if causalPattern then + attnRowsV? := some (Array.replicate seqLenEff zeroRow) + else + groupRowsV? := some (groupUnionRowsByToken ln1RowsV tokens) + attnUnionV? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + else + ln1UnionV? := some (unionRowsFixed ln1RowsV) + attnUnionV? 
:= some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + for hIdx in [:hdr.numHeads] do + let needQK := (at1 && hIdx = head1) || (at2 && hIdx = head2) + if needQK then + let wq ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bQ ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 + let wk ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bK ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 + let bQIntervals := intervalsFromScaled bQ slack + let bKIntervals := intervalsFromScaled bK slack + if at1 && hIdx = head1 then + let p1 ← bestMatchPattern layer1 head1 ln1Rows1 wq wk bQIntervals bKIntervals offset1 + p1? := some p1 + if at2 && hIdx = head2 then + let p2 ← bestMatchPattern layer2 head2 ln1Rows2 wq wk bQIntervals bKIntervals offset2 + p2? := some p2 + else + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let needUpdate := needUpdate1 || needUpdate2 + let needValue := at2 && hIdx = head2 + let needV := needUpdate || needValue + if needV then + let wv ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let bV ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 + let wo ← + ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let bVIntervals := intervalsFromScaled bV slack + if needUpdate then + if shareUpdate then + let (attnRows', attnUnion') ← + addAttn useTight1 ln1RowsShared ln1UnionShared? groupRowsShared? + attnRowsShared? attnUnionShared? wv wo bVIntervals + attnRowsShared? := attnRows' + attnUnionShared? := attnUnion' + else + if needUpdate1 then + let (attnRows', attnUnion') ← + addAttn useTight1 ln1Rows1 ln1Union1? groupRows1? + attnRows1? attnUnion1? wv wo bVIntervals + attnRows1? 
:= attnRows' + attnUnion1? := attnUnion' + if needUpdate2 then + let (attnRows', attnUnion') ← + addAttn useTight2 ln1Rows2 ln1Union2? groupRows2? + attnRows2? attnUnion2? wv wo bVIntervals + attnRows2? := attnRows' + attnUnion2? := attnUnion' + if needUpdateV && !shareUpdateV && !skipAttnV then + let (attnRows', attnUnion') ← + addAttn useTightV ln1RowsV ln1UnionV? groupRowsV? + attnRowsV? attnUnionV? wv wo bVIntervals + attnRowsV? := attnRows' + attnUnionV? := attnUnion' + if needValue then + let p2 ← + match p2? with + | some cert => pure cert + | none => throw "missing best-match pattern cert for value bound" + let vlogit ← + valueLogit ln1RowsV p2.bestMatchWeightLowerBound wv wo bVIntervals offset2 + vlogit? := some vlogit + else + let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) + let _ ← ExceptT.mk (skipF64Array h hdr.headDim) + let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) + if p1?.isSome && p2?.isSome && vlogit?.isSome && !(needUpdate1 || needUpdate2) then + match p1?, p2?, vlogit? with + | some p1, some p2, some vlogit => + let cert : InductionHeadBestMatchSoundCert := { + layer1Pattern := p1 + layer2Pattern := p2 + layer2Value := vlogit.value + layer2Logit? := vlogit.logit? + deltaLowerBound := vlogit.value.outputCoordLowerBound + } + if cert.check then + return cert + throw "induction head certificate failed internal consistency checks" + | _, _, _ => throw "induction head certificate failed internal consistency checks" + if needUpdate1 || needUpdate2 then + let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + if shareUpdate then + residuals1 ← applyAttn residuals1 useTight1 attnRowsShared? attnUnionShared? attnBias + residuals2 := residuals1 + else + if needUpdate1 then + residuals1 ← applyAttn residuals1 useTight1 attnRows1? attnUnion1? attnBias + if needUpdate2 then + residuals2 ← applyAttn residuals2 useTight2 attnRows2? attnUnion2? 
attnBias + if needUpdateV && !shareUpdateV && !skipAttnV then + residualsV ← applyAttn residualsV useTightV attnRowsV? attnUnionV? attnBias + let wIn ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 + let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) + let wOut ← + ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 + let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) + let ln2P := ln2Params.getD l defP + if shareUpdate then + residuals1 := applyMlp residuals1 usePerRow1 ln2P wIn wOut bIn bOut + residuals2 := residuals1 + else + if needUpdate1 then + residuals1 := applyMlp residuals1 usePerRow1 ln2P wIn wOut bIn bOut + if needUpdate2 then + residuals2 := applyMlp residuals2 usePerRow2 ln2P wIn wOut bIn bOut + if needUpdateV then + if shareUpdateV then + residualsV := residuals2 + else + residualsV := applyMlp residualsV usePerRowV ln2P wIn wOut bIn bOut + if shareUpdate then + residualsSame := true + else if needUpdate1 && needUpdate2 then + residualsSame := false + if needUpdateV then + if shareUpdateV then + residualsSameV := true + else + residualsSameV := false + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + match p1?, p2?, vlogit? with + | some p1, some p2, some vlogit => + let cert : InductionHeadBestMatchSoundCert := { + layer1Pattern := p1 + layer2Pattern := p2 + layer2Value := vlogit.value + layer2Logit? := vlogit.logit? + deltaLowerBound := vlogit.value.outputCoordLowerBound + } + if cert.check then + return cert + throw "induction head certificate failed internal consistency checks" + | _, _, _ => + throw "target layer not reached" + action.run + + /-- Compute a combined sound certificate for an induction-style head pair (best-match, binary only). 
-/ def certifyInductionSoundBestMatch @@ -4433,63 +6103,113 @@ def certifyInductionSoundBestMatch (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) + (iterTighten : Bool := false) (targetToken? : Option Nat := none) (negativeToken? : Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String InductionHeadBestMatchSoundCert) := do if inputDelta < 0 then return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let p1E ← - certifyHeadPatternBestMatchLocalBinary path layer1 head1 queryPos? eps soundnessBits - inputPath inputDelta offset1 maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 softmaxExpEffort - match p1E with - | .error e => return .error e - | .ok p1 => - let p2E ← - certifyHeadPatternBestMatchLocalBinary path layer2 head2 queryPos? eps soundnessBits - inputPath inputDelta offset2 maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 softmaxExpEffort - match p2E with - | .error e => return .error e - | .ok p2 => - let vE ← - certifyHeadValueLowerBoundLocalBinaryAt path layer2 head2 p2.queryPos coord - eps soundnessBits inputPath inputDelta offset2 p2.bestMatchWeightLowerBound - maxSeqLen scalePow10 - match vE with - | .error e => return .error e - | .ok v => - let logitE ← - match targetToken?, negativeToken? 
with - | none, none => pure (.ok none) - | some targetToken, some negativeToken => do - let logitE ← - certifyHeadLogitDiffLowerBoundLocalBinaryAt path layer2 head2 p2.queryPos - targetToken negativeToken eps soundnessBits inputPath inputDelta offset2 - p2.bestMatchWeightLowerBound maxSeqLen scalePow10 - pure (logitE.map some) - | _, _ => - pure (.error "use both target and negative tokens (or neither)") - match logitE with - | .error e => return .error e - | .ok logit? => - let cert : InductionHeadBestMatchSoundCert := { - layer1Pattern := p1 - layer2Pattern := p2 - layer2Value := v - layer2Logit? := logit? - deltaLowerBound := v.outputCoordLowerBound - } - if cert.check then - return .ok cert - return .error "induction head certificate failed internal consistency checks" - else - return .error "induction sound cert requires NFP_BINARY_V1" + let action : ExceptT String IO InductionHeadBestMatchSoundCert := do + let h ← IO.FS.Handle.mk path IO.FS.Mode.read + let firstLine := (← h.getLine).trim + if firstLine = "NFP_BINARY_V1" then + let inputPath := inputPath?.getD path + let timingEnabled ← ExceptT.lift <| IO.getEnv "NFP_TIMING" + let timing : Bool := timingEnabled.isSome + let timeIt {α : Type} (label : String) (action : ExceptT String IO α) : + ExceptT String IO α := do + if !timing then + action + else + let t0 ← ExceptT.lift IO.monoNanosNow + let r ← action + let t1 ← ExceptT.lift IO.monoNanosNow + let dtNs := t1 - t0 + let dtMs := dtNs / 1000000 + ExceptT.lift <| IO.eprintln s!"timing:{label} {dtMs}ms" + return r + let shared ← ExceptT.mk (loadSharedBinaryInputs path inputPath inputDelta scalePow10) + let queryPos : Nat := + match queryPos? with + | some q => q + | none => + if shared.hdr.seqLen = 0 then 0 else shared.hdr.seqLen - 1 + if queryPos ≥ shared.hdr.seqLen then + throw s!"queryPos {queryPos} out of range" + let prefixCache := mkSharedBinaryPrefix shared queryPos causalPattern + let direction? 
: Option (Thunk (Array Fixed10Interval)) ← + match targetToken?, negativeToken? with + | none, none => pure none + | some targetToken, some negativeToken => + let (hdrDir, dir) ← + ExceptT.mk (readLogitDiffDirectionBinary + path targetToken negativeToken scalePow10 fixedUlpSlack) + if hdrDir.modelDim ≠ shared.hdr.modelDim then + throw "unembedding model_dim mismatch" + pure (some (Thunk.mk (fun () => dir))) + | _, _ => + throw "use both target and negative tokens (or neither)" + let computeCert (useTight : Bool) (tightLayers perRowLayers : Nat) : + ExceptT String IO InductionHeadBestMatchSoundCert := do + let label := + s!"tight={useTight} tl={tightLayers} pr={perRowLayers}" + let cert ← + timeIt (s!"{label}:pair") <| + ExceptT.mk <| + certifyInductionSoundBestMatchLocalBinaryPair + path layer1 head1 layer2 head2 coord queryPos eps soundnessBits inputPath + inputDelta offset1 offset2 maxSeqLen scalePow10 useTight tightLayers + perRowLayers softmaxExpEffort causalPattern + (shared? := some shared) (prefix? := some prefixCache) + (targetToken? := targetToken?) (negativeToken? := negativeToken?) + (direction? := direction?) 
+ return cert + if !iterTighten then + let cert ← computeCert tightPattern tightPatternLayers perRowPatternLayers + return cert + else + let maxLayer := Nat.max layer1 layer2 + let tightFull := Nat.max 1 maxLayer + let perRowFull := maxLayer + let mut configs : Array (Bool × Nat × Nat) := + #[(tightPattern, tightPatternLayers, perRowPatternLayers)] + let needTightFull := (!tightPattern) || tightPatternLayers < tightFull + if needTightFull then + configs := configs.push (true, tightFull, perRowPatternLayers) + if perRowPatternLayers < perRowFull then + configs := configs.push (true, tightFull, perRowFull) + let tasks ← + ExceptT.lift <| + configs.mapM fun (useTight, tightLayers, perRowLayers) => + IO.asTask (computeCert useTight tightLayers perRowLayers).run + let results := tasks.map (fun t => t.get) + let mut best : Option (Rat × InductionHeadBestMatchSoundCert) := none + for i in [:configs.size] do + let res := results[i]! + match res with + | .error e => throw (toString e) + | .ok (.error msg) => throw msg + | .ok (.ok cert) => + let metric := + match cert.layer2Logit? with + | some logit => logit.logitDiffLowerBound + | none => cert.deltaLowerBound + best := + match best with + | none => some (metric, cert) + | some (bestMetric, bestCert) => + if metric > bestMetric then + some (metric, cert) + else + some (bestMetric, bestCert) + match best with + | none => throw "no induction certs computed" + | some (_, cert) => return cert + else + throw "induction sound cert requires NFP_BINARY_V1" + action.run /-! 
### Specs -/ From 65c736bac300b72ac3ae6cde3de7b90c7d17baa4 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 30 Dec 2025 03:58:34 +0100 Subject: [PATCH 015/244] Docs: note gitignored models dir --- AGENTS.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 245b366..0567f52 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -37,6 +37,9 @@ Before you finish any change: - `lake build -q --wfail` - `lake build nfp -q --wfail` +Note: `models/` is gitignored, so `rg` will skip it unless you pass `--no-ignore` +or `-uuu` (or equivalent) when searching. + --- ## 1. Non-Negotiables (Hard Rules) From e61f639975e48e6fcfa76d67d4688cae09a7262f Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 30 Dec 2025 03:58:48 +0100 Subject: [PATCH 016/244] Sound: fix attention score scale bound --- Nfp/Sound/Bounds/Attention.lean | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index 72d94e2..3da0736 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -70,7 +70,7 @@ def attnScoreAbsBound (modelDim headDim : Nat) let dRat : Rat := (modelDim : Nat) let qMax := dRat * ln1OutMaxAbs * wqBound let kMax := dRat * ln1OutMaxAbs * wkBound - sqrtUpperRat headDim * qMax * kMax + invSqrtUpperBound headDim * qMax * kMax theorem attnScoreAbsBound_def (modelDim headDim : Nat) (ln1OutMaxAbs wqBound wkBound : Rat) : @@ -78,6 +78,6 @@ theorem attnScoreAbsBound_def (modelDim headDim : Nat) let dRat : Rat := (modelDim : Nat) let qMax := dRat * ln1OutMaxAbs * wqBound let kMax := dRat * ln1OutMaxAbs * wkBound - sqrtUpperRat headDim * qMax * kMax := rfl + invSqrtUpperBound headDim * qMax * kMax := rfl end Nfp.Sound From 0332d9cc1280e1d7b44e9c08414d98a4df522117 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 30 Dec 2025 04:11:45 +0100 Subject: [PATCH 017/244] Add causal pattern flags and logit-diff helper --- Main.lean | 182 
++++++++++++++++++++++++++++++++++++++- Nfp/IO.lean | 50 +++++++++++ Nfp/Sound/HeadCert.lean | 5 +- Nfp/Sound/IO.lean | 61 ++++++++----- README.md | 1 + SOUNDNESS_LIMITATIONS.md | 10 ++- 6 files changed, 281 insertions(+), 28 deletions(-) diff --git a/Main.lean b/Main.lean index 55be692..65e46f9 100644 --- a/Main.lean +++ b/Main.lean @@ -879,6 +879,7 @@ private structure CertifyArgs where tightPattern : Bool tightPatternLayers : Nat perRowPatternLayers : Nat + causalPattern : Bool scalePow10 : Nat outputPath? : Option System.FilePath @@ -898,6 +899,7 @@ private def parseCertifyArgs (p : Parsed) : CertifyArgs := let tightPattern := p.flag? "tightPattern" |>.isSome let tightPatternLayers := p.flag? "tightPatternLayers" |>.map (·.as! Nat) |>.getD 1 let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 + let causalPattern := !p.hasFlag "noncausalPattern" let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) { modelPath := ⟨modelPathStr⟩ @@ -915,6 +917,7 @@ private def parseCertifyArgs (p : Parsed) : CertifyArgs := tightPattern := tightPattern tightPatternLayers := tightPatternLayers perRowPatternLayers := perRowPatternLayers + causalPattern := causalPattern scalePow10 := scalePow10 outputPath? := outputPath? } @@ -966,7 +969,7 @@ Pass --input or use a model file that embeds EMBEDDINGS." 
(targetOffset := args.targetOffset) (maxSeqLen := args.maxSeqLen) (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) (scalePow10 := args.scalePow10) - (softmaxExpEffort := args.softmaxExpEffort) + (softmaxExpEffort := args.softmaxExpEffort) (causalPattern := args.causalPattern) return cert else let cert ← ExceptT.mk <| @@ -1077,6 +1080,7 @@ private structure HeadPatternArgs where tightPatternLayers : Nat tightPattern : Bool perRowPatternLayers : Nat + causalPattern : Bool bestMatch : Bool sweep : Bool queryPos? : Option Nat @@ -1098,6 +1102,7 @@ private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := let tightPatternLayers := tightPatternLayers?.getD 1 let tightPattern := p.hasFlag "tightPattern" || tightPatternLayers?.isSome let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 + let causalPattern := !p.hasFlag "noncausalPattern" let bestMatch := p.hasFlag "bestMatch" let sweep := p.hasFlag "sweep" let queryPos? := p.flag? "queryPos" |>.map (·.as! Nat) @@ -1114,6 +1119,7 @@ private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := tightPatternLayers := tightPatternLayers tightPattern := tightPattern perRowPatternLayers := perRowPatternLayers + causalPattern := causalPattern bestMatch := bestMatch sweep := sweep queryPos? := queryPos? 
@@ -1188,6 +1194,7 @@ private def runHeadPatternAction (args : HeadPatternArgs) : ExceptT String IO St (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) (softmaxExpEffort := args.softmaxExpEffort) + (causalPattern := args.causalPattern) return formatHeadPatternBestMatchSweep args.layerIdx args.headIdx args.offset certs else let cert ← ExceptT.mk <| @@ -1198,6 +1205,7 @@ private def runHeadPatternAction (args : HeadPatternArgs) : ExceptT String IO St (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) (softmaxExpEffort := args.softmaxExpEffort) + (causalPattern := args.causalPattern) return formatHeadPatternBestMatch cert else if args.sweep then @@ -1210,6 +1218,7 @@ private def runHeadPatternAction (args : HeadPatternArgs) : ExceptT String IO St (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) (softmaxExpEffort := args.softmaxExpEffort) + (causalPattern := args.causalPattern) return formatHeadPatternLocal cert /-- Run the certify command - compute conservative, exact bounds in sound mode. -/ @@ -1263,6 +1272,8 @@ private structure InductionCertArgs where tightPatternLayers : Nat tightPattern : Bool perRowPatternLayers : Nat + iterTighten : Bool + causalPattern : Bool bestMatch : Bool queryPos? : Option Nat inputPath? : Option System.FilePath @@ -1291,6 +1302,8 @@ private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCer let tightPatternLayers := tightPatternLayers?.getD 1 let tightPattern := p.hasFlag "tightPattern" || tightPatternLayers?.isSome let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 + let iterTighten := p.hasFlag "iterTighten" + let causalPattern := !p.hasFlag "noncausalPattern" let bestMatch := p.hasFlag "bestMatch" let queryPos := p.flag? "queryPos" |>.map (·.as! 
Nat) let inputPath := p.flag? "input" |>.map (·.as! String) @@ -1330,6 +1343,8 @@ private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCer tightPatternLayers := tightPatternLayers tightPattern := tightPattern perRowPatternLayers := perRowPatternLayers + iterTighten := iterTighten + causalPattern := causalPattern bestMatch := bestMatch queryPos? := queryPos inputPath? := inputPath? @@ -1416,8 +1431,10 @@ private def runInductionCertAction (args : InductionCertArgs) : ExceptT String I (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) + (iterTighten := args.iterTighten) (targetToken? := args.targetToken?) (negativeToken? := args.negativeToken?) (softmaxExpEffort := args.softmaxExpEffort) + (causalPattern := args.causalPattern) return formatInductionBestMatch cert else let cert ← ExceptT.mk <| @@ -1430,6 +1447,7 @@ private def runInductionCertAction (args : InductionCertArgs) : ExceptT String I (perRowPatternLayers := args.perRowPatternLayers) (targetToken? := args.targetToken?) (negativeToken? := args.negativeToken?) (softmaxExpEffort := args.softmaxExpEffort) + (causalPattern := args.causalPattern) return formatInductionLocal cert /-- Run the induction-cert command - compute a sound induction head certificate. -/ @@ -1607,6 +1625,147 @@ def runDump (p : Parsed) : IO UInt32 := do let args := parseDumpArgs p runDumpWithArgs args +/-! 
## Logit-difference helpers -/ + +private def logitAt (residual : ConcreteMatrix) (pos : Nat) + (W_U : ConcreteMatrix) (token : Nat) : Except String Float := + if residual.numCols ≠ W_U.numRows then + .error "dimension mismatch: residual.numCols != W_U.numRows" + else if pos ≥ residual.numRows then + .error "position out of range" + else if token ≥ W_U.numCols then + .error "token out of range" + else + .ok <| Id.run do + let d := residual.numCols + let vocab := W_U.numCols + let rowBase := pos * d + let mut acc : Float := 0.0 + for k in [:d] do + acc := acc + residual.data[rowBase + k]! * W_U.data[k * vocab + token]! + return acc + +private def topNonTargetToken (residual : ConcreteMatrix) (pos : Nat) + (W_U : ConcreteMatrix) (targetToken : Nat) : Except String (Nat × Float) := + if residual.numCols ≠ W_U.numRows then + .error "dimension mismatch: residual.numCols != W_U.numRows" + else if pos ≥ residual.numRows then + .error "position out of range" + else if targetToken ≥ W_U.numCols then + .error "target token out of range" + else if W_U.numCols < 2 then + .error "vocab size too small to select non-target token" + else + .ok <| Id.run do + let d := residual.numCols + let vocab := W_U.numCols + let rowBase := pos * d + let mut bestTok : Nat := 0 + let mut bestLogit : Float := (-Float.inf) + let mut found : Bool := false + for tok in [:vocab] do + if tok ≠ targetToken then + found := true + let mut acc : Float := 0.0 + for k in [:d] do + acc := acc + residual.data[rowBase + k]! * W_U.data[k * vocab + tok]! + if acc > bestLogit then + bestTok := tok + bestLogit := acc + if found then + return (bestTok, bestLogit) + else + return (0, bestLogit) + +private structure LogitDiffArgs where + modelPath : System.FilePath + modelPathStr : String + target : Nat + negative : Nat + pos? : Option Nat + inputPath? : Option System.FilePath + autoNegative : Bool + +private def parseLogitDiffArgs (p : Parsed) : LogitDiffArgs := + let modelPathStr := p.positionalArg! "model" |>.as! 
String + let target := p.positionalArg! "target" |>.as! Nat + let negative := p.positionalArg! "negative" |>.as! Nat + let pos? := p.flag? "pos" |>.map (·.as! Nat) + let inputPath? := p.flag? "input" |>.map (System.FilePath.mk ∘ (·.as! String)) + let autoNegative := p.hasFlag "autoNegative" + { modelPath := ⟨modelPathStr⟩ + modelPathStr := modelPathStr + target := target + negative := negative + pos? := pos? + inputPath? := inputPath? + autoNegative := autoNegative } + +private def runLogitDiff (p : Parsed) : IO UInt32 := do + let args := parseLogitDiffArgs p + setStdoutLogNameFromModelPath args.modelPathStr + let loadResult ← loadModel args.modelPath + match loadResult with + | .error msg => + IO.eprintln s!"Error loading model: {msg}" + return 1 + | .ok model0 => + let model ← + match args.inputPath? with + | none => pure model0 + | some inputPath => + match ← loadInputBinary inputPath with + | .error msg => + IO.eprintln s!"Error loading input: {msg}" + return 1 + | .ok input => + if input.modelDim ≠ model0.modelDim then + IO.eprintln s!"Input model_dim mismatch ({input.modelDim} != {model0.modelDim})" + return 1 + pure { + model0 with + seqLen := input.seqLen + inputTokens := some input.tokens + inputEmbeddings := input.embeddings + } + match model.unembedding with + | none => + IO.eprintln "Error: Model is missing unembedding matrix (needed for logits)." + return 1 + | some W_U => + if model.seqLen = 0 then + IO.eprintln "Error: seq_len = 0; cannot compute logits." 
+ return 1 + let pos := args.pos?.getD (model.seqLen - 1) + if pos ≥ model.seqLen then + IO.eprintln s!"Error: pos={pos} out of bounds (seq_len={model.seqLen})" + return 1 + let fwd := model.runForward true + let residual := fwd.finalOutput + let negResult := + if args.autoNegative then + topNonTargetToken residual pos W_U args.target + else + match logitAt residual pos W_U args.negative with + | .ok logit => .ok (args.negative, logit) + | .error msg => .error msg + match logitAt residual pos W_U args.target, negResult with + | .ok targetLogit, .ok (negTok, negLogit) => + let diff := targetLogit - negLogit + IO.println s!"pos={pos} target={args.target} negative={negTok}" + if args.autoNegative then + IO.println "negativeSource=topNonTarget" + IO.println s!"logit(target)={targetLogit}" + IO.println s!"logit(negative)={negLogit}" + IO.println s!"logitDiff={diff}" + return 0 + | .error msg, _ => + IO.eprintln s!"Error computing target logit: {msg}" + return 1 + | _, .error msg => + IO.eprintln s!"Error computing negative logit: {msg}" + return 1 + /-- The analyze subcommand. -/ def analyzeCmd : Cmd := `[Cli| analyze VIA runAnalyze; @@ -1660,6 +1819,7 @@ if --input is omitted, uses EMBEDDINGS in the model file when present)" tightPattern; "Use tighter (slower) pattern bounds for best-match margins" tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" + noncausalPattern; "Disable causal-prefix restriction for pattern/value bounds" scalePow10 : Nat; "Fixed-point scale exponent for best-match margins (default: 9)" soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" partitionDepth : Nat; "Partition depth for input splitting (default: 0; >0 scaffold only)" @@ -1697,6 +1857,7 @@ LayerNorm epsilon is read from the model header." 
tightPattern; "Use tighter (slower) pattern bounds near the target layer" tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" + noncausalPattern; "Disable causal-prefix restriction for pattern/value bounds" bestMatch; "Use best-match (single-query) pattern bounds" sweep; "Sweep best-match bounds across all valid query positions" queryPos : Nat; "Query position for best-match bounds (default: last position)" @@ -1729,6 +1890,8 @@ LayerNorm epsilon is read from the model header." tightPattern; "Use tighter (slower) pattern bounds near the target layer" tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" + iterTighten; "Iteratively tighten best-match bounds (escalates tight/per-row layers to full)" + noncausalPattern; "Disable causal-prefix restriction for pattern/value bounds" bestMatch; "Use best-match (single-query) pattern bounds" queryPos : Nat; "Query position for best-match bounds (default: last position)" input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ @@ -1775,6 +1938,20 @@ def dumpCmd : Cmd := `[Cli| model : String; "Path to the model weights file (.nfpt)" ] +/-- The logit-diff subcommand. -/ +def logitDiffCmd : Cmd := `[Cli| + logit_diff VIA runLogitDiff; + "Compute empirical logit-difference for target vs. negative token." + FLAGS: + pos : Nat; "Token position (default: last position)" + input : String; "Optional input .nfpt file with TOKENS + EMBEDDINGS" + autoNegative; "Use top non-target logit as negative token (ignores provided negative)" + ARGS: + model : String; "Path to the model weights file (.nfpt)" + target : Nat; "Target token ID" + negative : Nat; "Negative token ID" +] + /-- The main CLI command. 
-/ def nfpCmd : Cmd := `[Cli| nfp NOOP; @@ -1788,7 +1965,8 @@ def nfpCmd : Cmd := `[Cli| inductionCertCmd; soundCacheCheckCmd; ropeCmd; - dumpCmd + dumpCmd; + logitDiffCmd ] /-- Main entry point. -/ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index dcc5f86..7c8b2d7 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -259,6 +259,56 @@ def loadBinary (h : IO.FS.Handle) : IO LoadResult := do return .ok model catch e => return .error s!"Binary load failed: {e}" + +/-- Input tokens + embeddings loaded from a binary `.nfpt` file. -/ +structure InputBinary where + /-- Sequence length parsed from the input header. -/ + seqLen : Nat + /-- Model dimension parsed from the input header. -/ + modelDim : Nat + /-- Token IDs parsed from the input file. -/ + tokens : Array Nat + /-- Input embeddings (seqLen × modelDim). -/ + embeddings : ConcreteMatrix + +/-- Load input tokens + embeddings from a binary `.nfpt` file. -/ +def loadInputBinary (path : System.FilePath) : IO (Except String InputBinary) := do + try + let h ← IO.FS.Handle.mk path IO.FS.Mode.read + let some magicLine ← readLine? h + | return .error "Empty input file" + let magic := magicLine.trim + if magic != "NFP_BINARY_V1" then + return .error "Invalid input magic: expected NFP_BINARY_V1" + let mut seqLen? : Option Nat := none + let mut modelDim? : Option Nat := none + let mut line? ← readLine? h + while true do + match line? with + | none => return .error "Unexpected EOF while reading input header" + | some line => + let t := line.trim + if t = "BINARY_START" then + break + if t.startsWith "seq_len=" then + match (t.drop 8).toNat? with + | some n => seqLen? := some n + | none => return .error "Invalid seq_len in input header" + else if t.startsWith "model_dim=" then + match (t.drop 10).toNat? with + | some n => modelDim? := some n + | none => return .error "Invalid model_dim in input header" + line? ← readLine? h + let some seqLen := seqLen? + | return .error "Missing seq_len in input header" + let some modelDim := modelDim? 
+ | return .error "Missing model_dim in input header" + let tokens ← readI32Array h seqLen + let embFloats ← readFloatArray h (seqLen * modelDim) + let embeddings := buildMatrix seqLen modelDim embFloats.data + return .ok { seqLen := seqLen, modelDim := modelDim, tokens := tokens, embeddings := embeddings } + catch e => + return .error s!"Binary input load failed: {e}" /-! ## File IO Operations -/ /-- Load a model from a file path. Supports .nfpt (binary) format. -/ diff --git a/Nfp/Sound/HeadCert.lean b/Nfp/Sound/HeadCert.lean index a99f369..33eae3b 100644 --- a/Nfp/Sound/HeadCert.lean +++ b/Nfp/Sound/HeadCert.lean @@ -290,6 +290,7 @@ structure LayerBestMatchMarginCert where layerIdx : Nat seqLen : Nat numHeads : Nat + /-- Max softmax exp effort allowed for per-head best-match certificates. -/ softmaxExpEffort : Nat marginLowerBound : Rat margins : Array Rat @@ -300,14 +301,14 @@ namespace LayerBestMatchMarginCert /-- Internal consistency checks for aggregated margins. -/ def Valid (c : LayerBestMatchMarginCert) : Prop := - c.seqLen > 0 ∧ + c.seqLen > 0 ∧ c.numHeads > 0 ∧ c.margins.size = c.numHeads * c.seqLen ∧ c.headCerts.all (fun cert => cert.check && cert.layerIdx == c.layerIdx && cert.seqLen == c.seqLen && - cert.softmaxExpEffort == c.softmaxExpEffort && + decide (cert.softmaxExpEffort ≤ c.softmaxExpEffort) && cert.headIdx < c.numHeads && cert.queryPos < c.seqLen) = true ∧ marginsFromBestMatchCerts c.numHeads c.seqLen c.headCerts = some c.margins ∧ diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 83cc1c6..40423de 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -315,7 +315,8 @@ def certifyHeadPatternLocal (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String HeadPatternCert) := do match ← readModelEps path with | .error e => return .error e @@ 
-323,7 +324,8 @@ def certifyHeadPatternLocal match ← Nfp.Untrusted.SoundCompute.certifyHeadPatternLocal path layerIdx headIdx eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen - tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with + tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort + causalPattern with | .error e => return .error e | .ok cert => return verifyHeadPatternCert cert @@ -342,7 +344,8 @@ def certifyHeadPatternBestMatchLocal (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String HeadBestMatchPatternCert) := do match ← readModelEps path with | .error e => return .error e @@ -351,7 +354,7 @@ def certifyHeadPatternBestMatchLocal Nfp.Untrusted.SoundCompute.certifyHeadPatternBestMatchLocal path layerIdx headIdx queryPos? eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort with + softmaxExpEffort causalPattern with | .error e => return .error e | .ok cert => return verifyHeadBestMatchPatternCert cert @@ -369,7 +372,8 @@ def certifyHeadPatternBestMatchLocalSweep (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String (Array HeadBestMatchPatternCert)) := do match ← readModelEps path with | .error e => return .error e @@ -377,7 +381,8 @@ def certifyHeadPatternBestMatchLocalSweep match ← Nfp.Untrusted.SoundCompute.certifyHeadPatternBestMatchLocalSweep path layerIdx headIdx eps soundnessBits inputPath? 
inputDelta targetOffset maxSeqLen - tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with + tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort + causalPattern with | .error e => return .error e | .ok certs => return verifyHeadBestMatchPatternCerts certs @@ -395,7 +400,8 @@ def certifyLayerBestMatchMarginLocal (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String LayerBestMatchMarginCert) := do match ← readModelEps path with | .error e => return .error e @@ -403,7 +409,8 @@ def certifyLayerBestMatchMarginLocal match ← Nfp.Untrusted.SoundCompute.certifyLayerBestMatchMarginLocal path layerIdx eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen - tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with + tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort + causalPattern with | .error e => return .error e | .ok cert => return verifyLayerBestMatchMarginCert cert @@ -421,7 +428,8 @@ def certifyModelFileBestMatchMargins (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String ModelCert) := do match ← readBinaryModelHeader path with | .error e => return .error e | .ok hdr => @@ -449,7 +457,8 @@ def certifyModelFileBestMatchMargins (tightPatternLayers := tightPatternLayers) (perRowPatternLayers := perRowPatternLayers) (scalePow10 := scalePow10) - (softmaxExpEffort := softmaxExpEffort) with + (softmaxExpEffort := softmaxExpEffort) + (causalPattern := causalPattern) with | .error e => return .error e | .ok cert => marginCerts := 
marginCerts.push cert return verifyModelCertBestMatchMargins cert hdr.eps soundnessBits hdr.geluDerivTarget @@ -473,7 +482,8 @@ def certifyHeadBoundsLocalBestMatch (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String HeadLocalContributionCert) := do match ← readModelEps path with | .error e => return .error e @@ -494,11 +504,12 @@ def certifyHeadBoundsLocalBestMatch (targetOffset := targetOffset) (maxSeqLen := maxSeqLen) (tightPattern := tightPattern) (tightPatternLayers := tightPatternLayers) (perRowPatternLayers := perRowPatternLayers) - (softmaxExpEffort := softmaxExpEffort) with + (softmaxExpEffort := softmaxExpEffort) + (causalPattern := causalPattern) with | .error e => return .error e | .ok pattern => return tightenHeadLocalContributionBestMatch - eps soundnessBits base pattern softmaxExpEffort + eps soundnessBits base pattern pattern.softmaxExpEffort /-- Compute local head output lower bounds. -/ def certifyHeadValueLowerBoundLocal @@ -513,7 +524,8 @@ def certifyHeadValueLowerBoundLocal (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) : + (scalePow10 : Nat := 9) + (causalPattern : Bool := true) : IO (Except String HeadValueLowerBoundCert) := do match ← readModelEps path with | .error e => return .error e @@ -521,7 +533,8 @@ def certifyHeadValueLowerBoundLocal match ← Nfp.Untrusted.SoundCompute.certifyHeadValueLowerBoundLocal path layerIdx headIdx coord eps soundnessBits inputPath? 
inputDelta targetOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 with + maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + causalPattern with | .error e => return .error e | .ok cert => return verifyHeadValueLowerBoundCert cert @@ -539,7 +552,8 @@ def certifyHeadLogitDiffLowerBoundLocal (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) : + (scalePow10 : Nat := 9) + (causalPattern : Bool := true) : IO (Except String HeadLogitDiffLowerBoundCert) := do match ← readModelEps path with | .error e => return .error e @@ -548,7 +562,7 @@ def certifyHeadLogitDiffLowerBoundLocal Nfp.Untrusted.SoundCompute.certifyHeadLogitDiffLowerBoundLocal path layerIdx headIdx targetToken negativeToken eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 with + perRowPatternLayers scalePow10 causalPattern with | .error e => return .error e | .ok cert => return verifyHeadLogitDiffLowerBoundCert cert @@ -570,7 +584,8 @@ def certifyInductionSound (perRowPatternLayers : Nat := 0) (targetToken? : Option Nat := none) (negativeToken? : Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String InductionHeadSoundCert) := do match ← readModelEps path with | .error e => return .error e @@ -579,7 +594,7 @@ def certifyInductionSound Nfp.Untrusted.SoundCompute.certifyInductionSound path layer1 head1 layer2 head2 coord eps soundnessBits inputPath? inputDelta offset1 offset2 maxSeqLen scalePow10 tightPattern tightPatternLayers - perRowPatternLayers targetToken? negativeToken? softmaxExpEffort with + perRowPatternLayers targetToken? negativeToken? 
softmaxExpEffort causalPattern with | .error e => return .error e | .ok cert => return verifyInductionHeadSoundCert cert @@ -600,9 +615,11 @@ def certifyInductionSoundBestMatch (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) + (iterTighten : Bool := false) (targetToken? : Option Nat := none) (negativeToken? : Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : + (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) + (causalPattern : Bool := true) : IO (Except String InductionHeadBestMatchSoundCert) := do match ← readModelEps path with | .error e => return .error e @@ -611,7 +628,9 @@ def certifyInductionSoundBestMatch Nfp.Untrusted.SoundCompute.certifyInductionSoundBestMatch path layer1 head1 layer2 head2 coord queryPos? eps soundnessBits inputPath? inputDelta offset1 offset2 maxSeqLen scalePow10 tightPattern - tightPatternLayers perRowPatternLayers targetToken? negativeToken? softmaxExpEffort with + tightPatternLayers perRowPatternLayers iterTighten targetToken? negativeToken? + softmaxExpEffort + causalPattern with | .error e => return .error e | .ok cert => return verifyInductionHeadBestMatchSoundCert cert diff --git a/README.md b/README.md index 6dfcb60..2cd2368 100644 --- a/README.md +++ b/README.md @@ -343,6 +343,7 @@ lake exe nfp induction_cert models/gpt2_rigorous.nfpt \ - `--perRowPatternLayers` sets how many layers use per-row MLP propagation (default: `0`). - `--bestMatch` switches to single-query best-match bounds (default query: last position). - `--queryPos` chooses the query position for best-match bounds (default: last position). +- `--iterTighten` iteratively tightens best-match bounds, escalating to full tight/per-row layers. ### `rope` diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 1aebd23..3700e24 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -19,9 +19,13 @@ soundness upgrade. It is intentionally brief and human-readable. 
- Best-match margin tightening is now available via `nfp certify --bestMatchMargins` (binary + local inputs with EMBEDDINGS). It runs a full best-match sweep across heads and query positions, which can be expensive and will fail if coverage is incomplete. -- Per-head best-match tightening (used by head-pattern/induction certs) is still separate from - model-level certification unless `--bestMatchMargins` is used. -- Best-match pattern certificates now use a margin-derived softmax Jacobian bound with an +- Local pattern/value/logit bounds now assume **causal attention** by default (prefix-only keys). + Use `--noncausalPattern` for non-causal models; otherwise these bounds are not sound. +- Per-head best-match tightening (used by head-pattern/induction certs) now records the **actual** + `softmaxExpEffort` chosen by iterative exp-portfolio tightening (early stop on low relative + improvement). The verifier accepts any per-head effort ≤ the requested cap, but model-level + certification still requires `--bestMatchMargins`. +- Best-match pattern certificates use a margin-derived softmax Jacobian bound with an effort-indexed `expLB` (scaled Taylor + squaring). The lower-bound correctness of `expLB` is not yet formalized in Lean. - GeLU derivative bounds are conservative envelopes; the exact interval supremum is not computed yet. From 491a009402b669245b5407b791ca512132f70b4e Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 30 Dec 2025 15:49:14 +0100 Subject: [PATCH 018/244] Tighten SOUND GeLU bounds and LN variance --- Main.lean | 7 + Nfp/Sound/Fixed.lean | 11 +- Nfp/Sound/Interval.lean | 127 ++++++++++- Nfp/Untrusted/SoundCompute.lean | 386 ++++++++++++++++++++++++-------- README.md | 2 +- 5 files changed, 425 insertions(+), 108 deletions(-) diff --git a/Main.lean b/Main.lean index 65e46f9..d2cc00c 100644 --- a/Main.lean +++ b/Main.lean @@ -1279,6 +1279,7 @@ private structure InductionCertArgs where inputPath? 
: Option System.FilePath delta : Rat maxSeqLen : Nat + scalePow10 : Nat outputPath? : Option System.FilePath /-- Parse and validate `induction-cert` arguments. -/ @@ -1309,6 +1310,7 @@ private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCer let inputPath := p.flag? "input" |>.map (·.as! String) let deltaStr := p.flag? "delta" |>.map (·.as! String) |>.getD "0" let maxSeqLen := p.flag? "maxSeqLen" |>.map (·.as! Nat) |>.getD 256 + let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 let outputPath := p.flag? "output" |>.map (·.as! String) let delta ← match Nfp.Sound.parseRat deltaStr with @@ -1350,6 +1352,7 @@ private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCer inputPath? := inputPath? delta := delta maxSeqLen := maxSeqLen + scalePow10 := scalePow10 outputPath? := outputPath? } @@ -1428,6 +1431,7 @@ private def runInductionCertAction (args : InductionCertArgs) : ExceptT String I (queryPos? := args.queryPos?) (inputPath? := args.inputPath?) (inputDelta := args.delta) (soundnessBits := args.soundnessBits) (offset1 := args.offset1) (offset2 := args.offset2) (maxSeqLen := args.maxSeqLen) + (scalePow10 := args.scalePow10) (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) @@ -1443,6 +1447,7 @@ private def runInductionCertAction (args : InductionCertArgs) : ExceptT String I (inputPath? := args.inputPath?) (inputDelta := args.delta) (soundnessBits := args.soundnessBits) (offset1 := args.offset1) (offset2 := args.offset2) (maxSeqLen := args.maxSeqLen) + (scalePow10 := args.scalePow10) (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) (targetToken? := args.targetToken?) (negativeToken? := args.negativeToken?) 
@@ -1867,6 +1872,7 @@ for legacy text)" soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" softmaxExpEffort : Nat; "Effort level for margin-based exp lower bounds (default: 1)" maxSeqLen : Nat; "Maximum sequence length to analyze (default: 256)" + scalePow10 : Nat; "Fixed-point scale exponent for best-match bounds (default: 9)" o, output : String; "Write report to file instead of stdout" ARGS: model : String; "Path to the model weights file (.nfpt)" @@ -1900,6 +1906,7 @@ for legacy text)" soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" softmaxExpEffort : Nat; "Effort level for margin-based exp lower bounds (default: 1)" maxSeqLen : Nat; "Maximum sequence length to analyze (default: 256)" + scalePow10 : Nat; "Fixed-point scale exponent for best-match bounds (default: 9)" o, output : String; "Write report to file instead of stdout" ARGS: model : String; "Path to the model weights file (.nfpt)" diff --git a/Nfp/Sound/Fixed.lean b/Nfp/Sound/Fixed.lean index 88c31e1..3c4208d 100644 --- a/Nfp/Sound/Fixed.lean +++ b/Nfp/Sound/Fixed.lean @@ -58,9 +58,13 @@ def sub (a b : Fixed10Interval) : Fixed10Interval := def relu (a : Fixed10Interval) : Fixed10Interval := { lo := max 0 a.lo, hi := max 0 a.hi } -/-- Conservative GeLU hull: `GeLU(x) ∈ [min(x,0), max(x,0)]`. -/ +/-- Conservative GeLU hull using a linear lower bound `GeLU(x) ≥ x/2`. + +For both exact and tanh GeLU, `GeLU(x) = x * g(x)` with `g(x) ∈ [0, 1]` and +`g(x) ≥ 1/2` when `x ≥ 0`, so `x/2` is a global lower bound. 
+-/ def geluOverapprox (a : Fixed10Interval) : Fixed10Interval := - { lo := min a.lo 0, hi := max a.hi 0 } + { lo := a.lo.ediv (Int.ofNat 2), hi := max a.hi 0 } private def absInt (x : Int) : Int := if x < 0 then -x else x @@ -161,7 +165,8 @@ theorem relu_def (a : Fixed10Interval) : Fixed10Interval.relu a = { lo := max 0 a.lo, hi := max 0 a.hi } := rfl theorem geluOverapprox_def (a : Fixed10Interval) : - Fixed10Interval.geluOverapprox a = { lo := min a.lo 0, hi := max a.hi 0 } := rfl + Fixed10Interval.geluOverapprox a = + { lo := a.lo.ediv (Int.ofNat 2), hi := max a.hi 0 } := rfl theorem absInt_spec (x : Int) : absInt x = absInt x := rfl diff --git a/Nfp/Sound/Interval.lean b/Nfp/Sound/Interval.lean index 4723390..9d11de6 100644 --- a/Nfp/Sound/Interval.lean +++ b/Nfp/Sound/Interval.lean @@ -36,6 +36,14 @@ def add (a b : RatInterval) : RatInterval := def sub (a b : RatInterval) : RatInterval := { lo := a.lo - b.hi, hi := a.hi - b.lo } +/-- Interval multiplication via endpoint products. -/ +def mul (a b : RatInterval) : RatInterval := + let p1 := a.lo * b.lo + let p2 := a.lo * b.hi + let p3 := a.hi * b.lo + let p4 := a.hi * b.hi + { lo := min (min p1 p2) (min p3 p4), hi := max (max p1 p2) (max p3 p4) } + /-- Scale an interval by a rational `c`, handling sign. -/ def scale (c : Rat) (a : RatInterval) : RatInterval := if c ≥ 0 then @@ -208,13 +216,72 @@ def varianceLowerBound (xs : Array RatInterval) : Rat := let exactLB := bestG / nRat return exactLB -/-- Over-approximate GeLU on an interval without any transcendental facts. +/-- Over-approximate GeLU on an interval without transcendental evaluation. -For all real `x`, `GeLU(x) = x·Φ(x)` lies between `x` and `0`. -Therefore `GeLU([lo,hi]) ⊆ [min(lo,0), max(hi,0)]`. +For both exact and tanh GeLU, `GeLU(x) = x * g(x)` with `g(x) ∈ [0, 1]` and +`g(x) ≥ 1/2` when `x ≥ 0`, so `GeLU(x) ≥ x/2` for all `x`. +We keep the standard upper bound `GeLU(x) ≤ max(x, 0)`. 
-/ def geluOverapprox (a : RatInterval) : RatInterval := - { lo := min a.lo 0, hi := max a.hi 0 } + { lo := a.lo / (2 : Rat), hi := max a.hi 0 } + +/-- Exp lower bound for all signs, using reciprocal of `expUB` for `x < 0`. -/ +private def expLBAll (x : Rat) (effort : Nat) : Rat := + if x ≥ 0 then + expLB x effort + else + let ub := expUBScaledGeom (-x) + if ub = 0 then 0 else (1 : Rat) / ub + +/-- Exp upper bound for all signs, using reciprocal of `expLB` for `x < 0`. -/ +private def expUBAll (x : Rat) (effort : Nat) : Rat := + if x ≥ 0 then + expUBScaledGeom x + else + let lb := expLB (-x) effort + if lb = 0 then 1 else (1 : Rat) / lb + +/-- Tanh over-approximation using exp bounds on the endpoints. -/ +def tanhOverapprox (a : RatInterval) (expEffort : Nat) : RatInterval := + let lo := min a.lo a.hi + let hi := max a.lo a.hi + let eLo := expLBAll ((2 : Rat) * lo) expEffort + let eHi := expUBAll ((2 : Rat) * hi) expEffort + let f : Rat → Rat := fun e => (e - 1) / (e + 1) + { lo := f eLo, hi := f eHi } + +/-- Tanh-based GeLU over-approximation using exp bounds. -/ +def geluOverapproxTanh (a : RatInterval) (expEffort : Nat := defaultSoftmaxExpEffort) : + RatInterval := + let x := { lo := min a.lo a.hi, hi := max a.lo a.hi } + let c : Rat := (44715 : Rat) / 1000000 + let kLo : Rat := (7978845608 : Rat) / 10000000000 + let kHi : Rat := (7978845609 : Rat) / 10000000000 + let kI : RatInterval := { lo := kLo, hi := kHi } + let x2 := mul x x + let x3 := mul x2 x + let sPoly := add x (scale c x3) + let s := mul kI sPoly + let t := tanhOverapprox s expEffort + let half : Rat := (1 : Rat) / 2 + let onePlus := add (const 1) t + let g := scale half onePlus + mul x g + +/-- Split-based tightening for tanh GeLU over-approximation. 
-/ +def geluOverapproxTanhSplit (a : RatInterval) (expEffort : Nat := defaultSoftmaxExpEffort) + (splitDepth : Nat := 0) : RatInterval := + if splitDepth = 0 then + geluOverapproxTanh a expEffort + else + let lo := min a.lo a.hi + let hi := max a.lo a.hi + let mid := (lo + hi) / (2 : Rat) + let left : RatInterval := { lo := lo, hi := mid } + let right : RatInterval := { lo := mid, hi := hi } + RatInterval.union + (geluOverapproxTanhSplit left expEffort (splitDepth - 1)) + (geluOverapproxTanhSplit right expEffort (splitDepth - 1)) /-- Upper bound on `max |gelu'(x)|` over a rational interval. -/ def geluDerivBound (target : GeluDerivTarget) (a : RatInterval) : Rat := @@ -249,6 +316,14 @@ theorem add_def (a b : RatInterval) : theorem sub_def (a b : RatInterval) : RatInterval.sub a b = { lo := a.lo - b.hi, hi := a.hi - b.lo } := rfl +theorem mul_def (a b : RatInterval) : + RatInterval.mul a b = + let p1 := a.lo * b.lo + let p2 := a.lo * b.hi + let p3 := a.hi * b.lo + let p4 := a.hi * b.hi + { lo := min (min p1 p2) (min p3 p4), hi := max (max p1 p2) (max p3 p4) } := rfl + theorem scale_def (c : Rat) (a : RatInterval) : RatInterval.scale c a = if c ≥ 0 then @@ -311,7 +386,49 @@ theorem varianceLowerBound_spec (xs : Array RatInterval) : RatInterval.varianceLowerBound xs = RatInterval.varianceLowerBound xs := rfl theorem geluOverapprox_def (a : RatInterval) : - RatInterval.geluOverapprox a = { lo := min a.lo 0, hi := max a.hi 0 } := rfl + RatInterval.geluOverapprox a = { lo := a.lo / (2 : Rat), hi := max a.hi 0 } := rfl + +theorem tanhOverapprox_def (a : RatInterval) (expEffort : Nat) : + RatInterval.tanhOverapprox a expEffort = + let lo := min a.lo a.hi + let hi := max a.lo a.hi + let eLo := + (fun x => + if x ≥ 0 then + expLB x expEffort + else + let ub := expUBScaledGeom (-x) + if ub = 0 then 0 else (1 : Rat) / ub) ((2 : Rat) * lo) + let eHi := + (fun x => + if x ≥ 0 then + expUBScaledGeom x + else + let lb := expLB (-x) expEffort + if lb = 0 then 1 else (1 : Rat) / lb) 
((2 : Rat) * hi) + let f : Rat → Rat := fun e => (e - 1) / (e + 1) + { lo := f eLo, hi := f eHi } := rfl + +theorem geluOverapproxTanh_def (a : RatInterval) (expEffort : Nat) : + RatInterval.geluOverapproxTanh a expEffort = + let x := { lo := min a.lo a.hi, hi := max a.lo a.hi } + let c : Rat := (44715 : Rat) / 1000000 + let kLo : Rat := (7978845608 : Rat) / 10000000000 + let kHi : Rat := (7978845609 : Rat) / 10000000000 + let kI : RatInterval := { lo := kLo, hi := kHi } + let x2 := RatInterval.mul x x + let x3 := RatInterval.mul x2 x + let sPoly := RatInterval.add x (RatInterval.scale c x3) + let s := RatInterval.mul kI sPoly + let t := RatInterval.tanhOverapprox s expEffort + let half : Rat := (1 : Rat) / 2 + let onePlus := RatInterval.add (RatInterval.const 1) t + let g := RatInterval.scale half onePlus + RatInterval.mul x g := rfl + +theorem geluOverapproxTanhSplit_spec (a : RatInterval) (expEffort : Nat) (splitDepth : Nat) : + RatInterval.geluOverapproxTanhSplit a expEffort splitDepth = + RatInterval.geluOverapproxTanhSplit a expEffort splitDepth := rfl end RatInterval diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 5cac86e..e116898 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -1089,6 +1089,12 @@ private def ratCeilMulNat (x : Rat) (k : Nat) : Int := let r := numK.emod (Int.ofNat den) if r = 0 then q else q + 1 +private def ratFloorMulNat (x : Rat) (k : Nat) : Int := + let num : Int := x.num + let den : Nat := x.den + let numK : Int := num * (Int.ofNat k) + numK.ediv (Int.ofNat den) + private def fixedMeanInterval (xs : Array Fixed10Interval) : Fixed10Interval := if xs.isEmpty then { lo := 0, hi := 0 } @@ -1129,6 +1135,55 @@ private def fixedVarianceLowerBoundRange (cfg : Fixed10Cfg) (xs : Array Fixed10I let δSq : Rat := δRat * δRat return δSq / ((2 : Rat) * nRat) +private def absInt (x : Int) : Int := if x < 0 then -x else x + +/-- Lower bound on variance using midpoint + radius 
deviation. -/ +private def fixedVarianceLowerBoundMidpoint (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : + Rat := + if xs.size < 2 then + 0 + else + Id.run do + let n : Nat := xs.size + let nInt : Int := Int.ofNat n + let d : Nat := 2 * cfg.scaleNat + let mut sumM : Int := 0 + let mut sumR : Int := 0 + for x in xs do + sumM := sumM + (x.lo + x.hi) + sumR := sumR + (x.hi - x.lo) + let mut varNum : Int := 0 + let mut errNum : Int := 0 + for x in xs do + let mInt := x.lo + x.hi + let rInt := x.hi - x.lo + let aNum := nInt * mInt - sumM + let rNum := nInt * rInt + sumR + varNum := varNum + aNum * aNum + errNum := errNum + (absInt aNum) * rNum + let num := varNum - 2 * errNum + if num <= 0 then + return 0 + let denNat : Nat := d * d * n * n * n + return (num : Rat) / (denNat : Rat) + +/-- Exact variance lower bound by converting to `RatInterval` and using the exact routine. -/ +private def fixedVarianceLowerBoundExact (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := + if xs.size < 2 then + 0 + else + let ratXs := + xs.map (fun x => { lo := ratOfScaledInt cfg.scalePow10 x.lo, + hi := ratOfScaledInt cfg.scalePow10 x.hi }) + RatInterval.varianceLowerBound ratXs + +/-- Best available variance lower bound from range + midpoint deviation. 
-/ +private def fixedVarianceLowerBound (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := + let rangeLB := fixedVarianceLowerBoundRange cfg xs + let midLB := fixedVarianceLowerBoundMidpoint cfg xs + let exactLB := fixedVarianceLowerBoundExact cfg xs + max rangeLB (max midLB exactLB) + private def fixedLayerNormRowApprox (cfg : Fixed10Cfg) (row : Array Fixed10Interval) @@ -1141,7 +1196,7 @@ private def fixedLayerNormRowApprox else Id.run do let μ := fixedMeanInterval row - let varLB := fixedVarianceLowerBoundRange cfg row + let varLB := fixedVarianceLowerBound cfg row let invσUpper : Rat := if varLB ≤ 0 then layerNormOpBoundConservative 1 eps soundnessBits @@ -1157,6 +1212,21 @@ private def fixedLayerNormRowApprox out := out.push (Fixed10Interval.add scaled beta[i]!) return (out, varLB) +private def fixedLayerNormRowsApprox + (cfg : Fixed10Cfg) + (rows : Array (Array Fixed10Interval)) + (p : LayerNormParamsFixed) + (eps : Rat) + (soundnessBits : Nat) : + Array (Array Fixed10Interval) := + let useTasks := rows.size > 32 + if useTasks then + let tasks := rows.map (fun row => + Task.spawn (fun _ => fixedLayerNormRowApprox cfg row p.gamma p.beta eps soundnessBits)) + tasks.map (fun t => (t.get).1) + else + rows.map (fun row => (fixedLayerNormRowApprox cfg row p.gamma p.beta eps soundnessBits).1) + private def readVecIntervals (r : SoundCache.I32Reader) (n : Nat) (slack : Int) : IO (Array Fixed10Interval × SoundCache.I32Reader) := do @@ -1218,6 +1288,34 @@ private def centeredAbsSumFixed (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) have h10pos : (0 : Nat) < 10 := by decide exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) +private def ratIntervalOfFixed (cfg : Fixed10Cfg) (a : Fixed10Interval) : RatInterval := + { lo := ratOfScaledInt cfg.scalePow10 a.lo, hi := ratOfScaledInt cfg.scalePow10 a.hi } + +private def fixedIntervalOfRat (cfg : Fixed10Cfg) (a : RatInterval) : Fixed10Interval := + { lo := ratFloorMulNat a.lo cfg.scaleNat, hi := 
ratCeilMulNat a.hi cfg.scaleNat } + +private def defaultGeluExpEffort : Nat := 2 +private def defaultGeluSplitDepth : Nat := 1 + +private def geluOverapproxRat (target : GeluDerivTarget) (a : RatInterval) : RatInterval := + match target with + | .tanh => RatInterval.geluOverapproxTanhSplit a defaultGeluExpEffort defaultGeluSplitDepth + | .exact => RatInterval.geluOverapprox a + +private def geluOverapproxFixed (cfg : Fixed10Cfg) (target : GeluDerivTarget) + (a : Fixed10Interval) : Fixed10Interval := + match target with + | .tanh => + let r := ratIntervalOfFixed cfg a + fixedIntervalOfRat cfg + (RatInterval.geluOverapproxTanhSplit r defaultGeluExpEffort defaultGeluSplitDepth) + | .exact => + Fixed10Interval.geluOverapprox a + +private def geluOverapproxFixedVec (cfg : Fixed10Cfg) (target : GeluDerivTarget) + (xs : Array Fixed10Interval) : Array Fixed10Interval := + xs.map (geluOverapproxFixed cfg target) + private def addVecFixed (a b : Array Fixed10Interval) : Array Fixed10Interval := Id.run do if a.size ≠ b.size then @@ -1248,6 +1346,7 @@ private def takePrefix {α : Type} (xs : Array α) (n : Nat) : Array α := private def mlpRowFromScaled (cfg : Fixed10Cfg) + (geluDerivTarget : GeluDerivTarget) (slack : Int) (modelDim hiddenDim : Nat) (wIn wOut : Array Int) @@ -1255,12 +1354,13 @@ private def mlpRowFromScaled (row : Array Fixed10Interval) : Array Fixed10Interval := let hidden0 := matMulIntervalsFromScaled cfg slack modelDim hiddenDim wIn row let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB let mlpOut0 := matMulIntervalsFromScaled cfg slack hiddenDim modelDim wOut actHidden addVecFixed mlpOut0 bOut private def mlpRowsFromScaled (cfg : Fixed10Cfg) + (geluDerivTarget : GeluDerivTarget) (slack : Int) (modelDim hiddenDim : Nat) (wIn wOut : Array Int) @@ -1269,10 +1369,11 @@ private def mlpRowsFromScaled let useTasks := rows.size > 32 if useTasks then 
let tasks := rows.map (fun row => - Task.spawn (fun _ => mlpRowFromScaled cfg slack modelDim hiddenDim wIn wOut bIn bOut row)) + Task.spawn (fun _ => + mlpRowFromScaled cfg geluDerivTarget slack modelDim hiddenDim wIn wOut bIn bOut row)) tasks.map (fun t => t.get) else - rows.map (mlpRowFromScaled cfg slack modelDim hiddenDim wIn wOut bIn bOut) + rows.map (mlpRowFromScaled cfg geluDerivTarget slack modelDim hiddenDim wIn wOut bIn bOut) private def groupUnionRowsByToken (rows : Array (Array Fixed10Interval)) @@ -1710,11 +1811,7 @@ private def certifyHeadValueLowerBoundLocalBinaryAt let mut residuals := residuals0 for l in [:hdr.numLayers] do let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out + let ln1Rows := fixedLayerNormRowsApprox cfg residuals p1 eps soundnessBits if l = layerIdx then let mut wv? : Option (Array Int) := none let mut bv? 
: Option (Array Int) := none @@ -1891,11 +1988,7 @@ private def certifyHeadValueLowerBoundLocalBinaryAt attnUnion := addVecFixed attnUnion attnBias residuals := addVecFixedRows residuals attnUnion let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out + let ln2Rows := fixedLayerNormRowsApprox cfg residuals p2 eps soundnessBits let perRowLayers : Nat := perRowPatternLayers if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then let wIn ← @@ -1905,7 +1998,8 @@ private def certifyHeadValueLowerBoundLocalBinaryAt ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows residuals := addRowsFixed residuals mlpRows else let ln2Union := unionRowsFixed ln2Rows @@ -1914,7 +2008,7 @@ private def certifyHeadValueLowerBoundLocalBinaryAt hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -2258,7 +2352,8 @@ private def certifyHeadValueLogitLowerBoundLocalBinaryAt ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn 
wOut bIn bOut ln2Rows + mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows residuals := addRowsFixed residuals mlpRows else let ln2Union := unionRowsFixed ln2Rows @@ -2267,7 +2362,7 @@ private def certifyHeadValueLogitLowerBoundLocalBinaryAt hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -2622,7 +2717,8 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows residuals := addRowsFixed residuals mlpRows else let ln2Union := unionRowsFixed ln2Rows @@ -2631,7 +2727,7 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -2952,7 +3048,7 @@ private def certifyModelFileLocalText pos := nextBin let hiddenB := addConstVec hidden bin let mlpActDerivBound := maxGeluDerivBound geluDerivTarget hiddenB - let actHidden := hiddenB.map 
RatInterval.geluOverapprox + let actHidden := hiddenB.map (geluOverapproxRat geluDerivTarget) pos := skipBlankLines lines pos if !(pos < lines.size && lines[pos]!.trim = "W_out") then return .error "missing W_out" @@ -3146,7 +3242,7 @@ private def certifyModelFileLocal rr := rrBin let hiddenB := addVecFixed hidden0 bIn let mlpActDerivBound := maxGeluDerivBoundFixed cfg geluDerivTarget hiddenB - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB let (mlpOut0, nWout, rrWout) ← consumeMatrixMulAndNormInfFixed cfg slack rr hiddenDim modelDim actHidden rr := rrWout @@ -3318,7 +3414,7 @@ private def certifyModelFileLocalBinary let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn let mlpActDerivBound := maxGeluDerivBoundFixed cfg geluDerivTarget hiddenB - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -3488,7 +3584,7 @@ private def certifyHeadBoundsLocalBinary hdr.modelDim hdr.hiddenDim ln2Out scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -3780,7 +3876,8 @@ private def certifyHeadPatternLocalBinary ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + 
mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows residuals := addRowsFixed residuals mlpRows else let ln2Union := unionRowsFixed ln2Rows @@ -3789,7 +3886,7 @@ private def certifyHeadPatternLocalBinary hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -3833,7 +3930,7 @@ private def certifyHeadPatternLocalBinary hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -4168,7 +4265,8 @@ private def certifyHeadPatternBestMatchLocalBinary ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows residuals := addRowsFixed residuals mlpRows else let ln2Union := unionRowsFixed ln2Rows @@ -4177,7 +4275,7 @@ private def certifyHeadPatternBestMatchLocalBinary hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map 
Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -4486,7 +4584,8 @@ private def certifyHeadPatternBestMatchLocalBinarySweep ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows residuals := addRowsFixed residuals mlpRows else let ln2Union := unionRowsFixed ln2Rows @@ -4495,7 +4594,7 @@ private def certifyHeadPatternBestMatchLocalBinarySweep hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -4771,7 +4870,8 @@ private def certifyHeadValueLowerBoundLocalBinary ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows residuals := addRowsFixed residuals mlpRows else let ln2Union := unionRowsFixed ln2Rows @@ -4780,7 +4880,7 @@ private def certifyHeadValueLowerBoundLocalBinary hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack 
scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -5063,7 +5163,8 @@ private def certifyHeadLogitDiffLowerBoundLocalBinary ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows residuals := addRowsFixed residuals mlpRows else let ln2Union := unionRowsFixed ln2Rows @@ -5072,7 +5173,7 @@ private def certifyHeadLogitDiffLowerBoundLocalBinary hdr.modelDim hdr.hiddenDim ln2Union scalePow10) let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox + let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB let (mlpOut0, _nWout) ← ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h hdr.hiddenDim hdr.modelDim actHidden scalePow10) @@ -5535,13 +5636,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (rows : Array (Array Fixed10Interval)) (p : LayerNormParamsFixed) : Array (Array Fixed10Interval) := - Id.run do - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for row in rows do - let (lnOut, _varLB) := - fixedLayerNormRowApprox cfg row p.gamma p.beta eps soundnessBits - out := out.push lnOut - return out + fixedLayerNormRowsApprox cfg rows p eps soundnessBits let calcVOutRows (rows : Array (Array Fixed10Interval)) (wv wo : Array Int) @@ -5746,6 +5841,61 @@ private def certifyInductionSoundBestMatchLocalBinaryPair | _, _, _ 
=> throw "use both target and negative tokens (or neither)" return { value := value, logit? := logit? } + let tightenQueryRowLower + (baseRow : Array Fixed10Interval) + (vOutRows : Array (Array Fixed10Interval)) + (matchWeightLowerBound : Rat) + (targetOffset : Int) : + ExceptT String IO (Array Fixed10Interval) := do + let ti : Int := (Int.ofNat queryPos) + targetOffset + if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then + throw "query position has no valid target offset" + let tIdx : Nat := Int.toNat ti + let targetTok := tokens[tIdx]! + let mut matchLo? : Option (Array Int) := none + let mut nonmatchLo? : Option (Array Int) := none + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let row := vOutRows[j]! + let rowLo : Array Int := row.map (fun x => x.lo) + if tokens[j]! = targetTok then + matchLo? := + match matchLo? with + | none => some rowLo + | some cur => + some <| Id.run do + let mut out : Array Int := Array.mkEmpty hdr.modelDim + for i in [:hdr.modelDim] do + out := out.push (min cur[i]! rowLo[i]!) + out + else + nonmatchLo? := + match nonmatchLo? with + | none => some rowLo + | some cur => + some <| Id.run do + let mut out : Array Int := Array.mkEmpty hdr.modelDim + for i in [:hdr.modelDim] do + out := out.push (min cur[i]! rowLo[i]!) + out + let matchLo ← + match matchLo? with + | none => throw "no matching keys for the requested offset" + | some v => pure v + let nonmatchLo := + match nonmatchLo? with + | none => matchLo + | some v => v + let mut tightened : Array Fixed10Interval := Array.mkEmpty hdr.modelDim + for i in [:hdr.modelDim] do + let matchLoRat := ratOfScaledInt scalePow10 matchLo[i]! + let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo[i]! + let outLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat + let outLBInt := ratFloorMulNat outLB cfg.scaleNat + let base := baseRow[i]! 
+ let newLo := max base.lo outLBInt + tightened := tightened.push { lo := newLo, hi := base.hi } + return tightened let addAttn (useTight : Bool) (ln1Rows : Array (Array Fixed10Interval)) @@ -5811,12 +5961,12 @@ private def certifyInductionSoundBestMatchLocalBinaryPair Array (Array Fixed10Interval) := let ln2Rows := calcLnRows rows p if usePerRow then - let mlpRows := mlpRowsFromScaled cfg slack + let mlpRows := mlpRowsFromScaled cfg hdr.geluDerivTarget slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows addRowsFixed rows mlpRows else let ln2Union := unionRowsFixed ln2Rows - let mlpOut := mlpRowFromScaled cfg slack + let mlpOut := mlpRowFromScaled cfg hdr.geluDerivTarget slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union addVecFixedRows rows mlpOut let h ← IO.FS.Handle.mk path IO.FS.Mode.read @@ -5987,11 +6137,26 @@ private def certifyInductionSoundBestMatchLocalBinaryPair attnRows1? := attnRows' attnUnion1? := attnUnion' if needUpdate2 then - let (attnRows', attnUnion') ← - addAttn useTight2 ln1Rows2 ln1Union2? groupRows2? - attnRows2? attnUnion2? wv wo bVIntervals - attnRows2? := attnRows' - attnUnion2? := attnUnion' + if l == layer1 && hIdx == head1 && useTight2 && causalPattern then + let p1 ← + match p1? with + | some cert => pure cert + | none => throw "missing best-match pattern cert for tightening" + let vOutRows := calcVOutRows ln1Rows2 wv wo bVIntervals + let mut headRows := prefixUnionRowsFixed vOutRows + let baseRow := headRows[queryPos]! + let tightRow ← + tightenQueryRowLower baseRow vOutRows p1.bestMatchWeightLowerBound offset1 + headRows := headRows.set! queryPos tightRow + match attnRows2? with + | some rows => attnRows2? := some (addRowsFixed rows headRows) + | none => throw "missing attnRows" + else + let (attnRows', attnUnion') ← + addAttn useTight2 ln1Rows2 ln1Union2? groupRows2? + attnRows2? attnUnion2? wv wo bVIntervals + attnRows2? := attnRows' + attnUnion2? 
:= attnUnion' if needUpdateV && !shareUpdateV && !skipAttnV then let (attnRows', attnUnion') ← addAttn useTightV ln1RowsV ln1UnionV? groupRowsV? @@ -6130,56 +6295,45 @@ def certifyInductionSoundBestMatch let dtMs := dtNs / 1000000 ExceptT.lift <| IO.eprintln s!"timing:{label} {dtMs}ms" return r - let shared ← ExceptT.mk (loadSharedBinaryInputs path inputPath inputDelta scalePow10) - let queryPos : Nat := - match queryPos? with - | some q => q - | none => - if shared.hdr.seqLen = 0 then 0 else shared.hdr.seqLen - 1 - if queryPos ≥ shared.hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let prefixCache := mkSharedBinaryPrefix shared queryPos causalPattern - let direction? : Option (Thunk (Array Fixed10Interval)) ← - match targetToken?, negativeToken? with - | none, none => pure none - | some targetToken, some negativeToken => - let (hdrDir, dir) ← - ExceptT.mk (readLogitDiffDirectionBinary - path targetToken negativeToken scalePow10 fixedUlpSlack) - if hdrDir.modelDim ≠ shared.hdr.modelDim then - throw "unembedding model_dim mismatch" - pure (some (Thunk.mk (fun () => dir))) - | _, _ => - throw "use both target and negative tokens (or neither)" - let computeCert (useTight : Bool) (tightLayers perRowLayers : Nat) : - ExceptT String IO InductionHeadBestMatchSoundCert := do - let label := - s!"tight={useTight} tl={tightLayers} pr={perRowLayers}" - let cert ← - timeIt (s!"{label}:pair") <| - ExceptT.mk <| - certifyInductionSoundBestMatchLocalBinaryPair - path layer1 head1 layer2 head2 coord queryPos eps soundnessBits inputPath - inputDelta offset1 offset2 maxSeqLen scalePow10 useTight tightLayers - perRowLayers softmaxExpEffort causalPattern - (shared? := some shared) (prefix? := some prefixCache) - (targetToken? := targetToken?) (negativeToken? := negativeToken?) - (direction? := direction?) 
- return cert - if !iterTighten then - let cert ← computeCert tightPattern tightPatternLayers perRowPatternLayers - return cert - else - let maxLayer := Nat.max layer1 layer2 - let tightFull := Nat.max 1 maxLayer - let perRowFull := maxLayer - let mut configs : Array (Bool × Nat × Nat) := - #[(tightPattern, tightPatternLayers, perRowPatternLayers)] - let needTightFull := (!tightPattern) || tightPatternLayers < tightFull - if needTightFull then - configs := configs.push (true, tightFull, perRowPatternLayers) - if perRowPatternLayers < perRowFull then - configs := configs.push (true, tightFull, perRowFull) + let computeBestAtScale (scalePow10 : Nat) + (configs : Array (Bool × Nat × Nat)) : + ExceptT String IO (Rat × InductionHeadBestMatchSoundCert) := do + let shared ← ExceptT.mk (loadSharedBinaryInputs path inputPath inputDelta scalePow10) + let queryPos : Nat := + match queryPos? with + | some q => q + | none => + if shared.hdr.seqLen = 0 then 0 else shared.hdr.seqLen - 1 + if queryPos ≥ shared.hdr.seqLen then + throw s!"queryPos {queryPos} out of range" + let prefixCache := mkSharedBinaryPrefix shared queryPos causalPattern + let direction? : Option (Thunk (Array Fixed10Interval)) ← + match targetToken?, negativeToken? 
with + | none, none => pure none + | some targetToken, some negativeToken => + let (hdrDir, dir) ← + ExceptT.mk (readLogitDiffDirectionBinary + path targetToken negativeToken scalePow10 fixedUlpSlack) + if hdrDir.modelDim ≠ shared.hdr.modelDim then + throw "unembedding model_dim mismatch" + pure (some (Thunk.mk (fun () => dir))) + | _, _ => + throw "use both target and negative tokens (or neither)" + let computeCert (useTight : Bool) (tightLayers perRowLayers : Nat) : + ExceptT String IO InductionHeadBestMatchSoundCert := do + let label := + s!"scale={scalePow10} tight={useTight} tl={tightLayers} pr={perRowLayers}" + let cert ← + timeIt (s!"{label}:pair") <| + ExceptT.mk <| + certifyInductionSoundBestMatchLocalBinaryPair + path layer1 head1 layer2 head2 coord queryPos eps soundnessBits inputPath + inputDelta offset1 offset2 maxSeqLen scalePow10 useTight tightLayers + perRowLayers softmaxExpEffort causalPattern + (shared? := some shared) (prefix? := some prefixCache) + (targetToken? := targetToken?) (negativeToken? := negativeToken?) + (direction? := direction?) 
+ return cert let tasks ← ExceptT.lift <| configs.mapM fun (useTight, tightLayers, perRowLayers) => @@ -6206,6 +6360,40 @@ def certifyInductionSoundBestMatch some (bestMetric, bestCert) match best with | none => throw "no induction certs computed" + | some bestPair => return bestPair + let maxLayer := Nat.max layer1 layer2 + let tightFull := Nat.max 1 maxLayer + let perRowFull := maxLayer + let mut configs : Array (Bool × Nat × Nat) := + #[(tightPattern, tightPatternLayers, perRowPatternLayers)] + let needTightFull := (!tightPattern) || tightPatternLayers < tightFull + if needTightFull then + configs := configs.push (true, tightFull, perRowPatternLayers) + if perRowPatternLayers < perRowFull then + configs := configs.push (true, tightFull, perRowFull) + if !iterTighten then + let (_, cert) ← computeBestAtScale scalePow10 configs + return cert + else + let mut bestOverall : Option (Rat × InductionHeadBestMatchSoundCert) := none + let mut scale := scalePow10 + let maxScale := scalePow10 + 2 + while scale ≤ maxScale do + let (metric, cert) ← computeBestAtScale scale configs + bestOverall := + match bestOverall with + | none => some (metric, cert) + | some (bestMetric, bestCert) => + if metric > bestMetric then + some (metric, cert) + else + some (bestMetric, bestCert) + if metric > 0 then + scale := maxScale + 1 + else + scale := scale + 1 + match bestOverall with + | none => throw "no induction certs computed" | some (_, cert) => return cert else throw "induction sound cert requires NFP_BINARY_V1" diff --git a/README.md b/README.md index 2cd2368..cec1a3f 100644 --- a/README.md +++ b/README.md @@ -343,7 +343,7 @@ lake exe nfp induction_cert models/gpt2_rigorous.nfpt \ - `--perRowPatternLayers` sets how many layers use per-row MLP propagation (default: `0`). - `--bestMatch` switches to single-query best-match bounds (default query: last position). - `--queryPos` chooses the query position for best-match bounds (default: last position). 
-- `--iterTighten` iteratively tightens best-match bounds, escalating to full tight/per-row layers. +- `--iterTighten` iteratively tightens best-match bounds (tight/per-row layers and scale precision). ### `rope` From 39dff3aa63eca0aae39db1bc49e02f0fc4a0167c Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 08:27:02 +0100 Subject: [PATCH 019/244] Optimize cache build and add benchmark --- Main.lean | 81 +++++++++++++++ Nfp/IO.lean | 15 ++- Nfp/Sound/CachePure.lean | 70 ++++++++----- Nfp/Untrusted/SoundBinary.lean | 15 ++- Nfp/Untrusted/SoundCacheIO.lean | 169 ++++++++++++++++++++++++++++++-- 5 files changed, 310 insertions(+), 40 deletions(-) diff --git a/Main.lean b/Main.lean index d2cc00c..35ee2c2 100644 --- a/Main.lean +++ b/Main.lean @@ -1528,6 +1528,75 @@ def runSoundCacheCheck (p : Parsed) : IO UInt32 := do let args := parseSoundCacheCheckArgs p runSoundCacheCheckWithArgs args +/-! ## Sound cache benchmark helpers -/ + +private structure SoundCacheBenchArgs where + modelPath : System.FilePath + scalePow10 : Nat + runs : Nat + +private def parseSoundCacheBenchArgs (p : Parsed) : SoundCacheBenchArgs := + let modelPath := p.positionalArg! "model" |>.as! String + let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 + let runs := p.flag? "runs" |>.map (·.as! 
Nat) |>.getD 1 + { modelPath := ⟨modelPath⟩, scalePow10 := scalePow10, runs := runs } + +private def runSoundCacheBenchWithArgs (args : SoundCacheBenchArgs) : IO UInt32 := do + if args.runs = 0 then + IO.eprintln "Error: --runs must be > 0" + return 1 + let modelHash ← Nfp.Untrusted.SoundCacheIO.fnv1a64File args.modelPath + let mdata ← args.modelPath.metadata + let modelSize : UInt64 := mdata.byteSize + let isBinaryE ← Nfp.Untrusted.SoundCacheIO.isBinaryModelFile args.modelPath + let isBinary ← + match isBinaryE with + | .error e => + IO.eprintln s!"Error: {e}" + return 1 + | .ok b => pure b + let formatStr := if isBinary then "binary" else "text" + let mut times : Array Nat := Array.mkEmpty args.runs + let mut lastBytes : Nat := 0 + for i in [:args.runs] do + let t0 ← IO.monoNanosNow + let bytesE ← + if isBinary then + Nfp.Untrusted.SoundCacheIO.buildCacheBytesBinary + args.modelPath args.scalePow10 modelHash modelSize + else + Nfp.Untrusted.SoundCacheIO.buildCacheBytesText + args.modelPath args.scalePow10 modelHash modelSize + let t1 ← IO.monoNanosNow + match bytesE with + | .error e => + IO.eprintln s!"Error: {e}" + return 1 + | .ok bytes => + let dtMs := (t1 - t0) / 1000000 + times := times.push dtMs + lastBytes := bytes.size + if args.runs > 1 then + IO.println s!"run {i + 1}: {dtMs}ms" + let t0 := times[0]! + let mut minT := t0 + let mut maxT := t0 + let mut sumT : Nat := 0 + for t in times do + if t < minT then + minT := t + if t > maxT then + maxT := t + sumT := sumT + t + let avgT := sumT / times.size + IO.println s!"cacheBuild format={formatStr} scalePow10={args.scalePow10} bytes={lastBytes}" + IO.println s!"cacheBuild runs={args.runs} min={minT}ms avg={avgT}ms max={maxT}ms" + return 0 + +def runSoundCacheBench (p : Parsed) : IO UInt32 := do + let args := parseSoundCacheBenchArgs p + runSoundCacheBenchWithArgs args + /-- Run the rope command - print a proof-backed RoPE operator norm certificate. 
-/ def runRoPE (p : Parsed) : IO UInt32 := do let seqLen := p.flag? "seqLen" |>.map (·.as! Nat) |>.getD 4 @@ -1923,6 +1992,17 @@ def soundCacheCheckCmd : Cmd := `[Cli| model : String; "Path to the model weights file (.nfpt)" ] +/-- The sound-cache-bench subcommand. -/ +def soundCacheBenchCmd : Cmd := `[Cli| + sound_cache_bench VIA runSoundCacheBench; + "Benchmark SOUND fixed-point cache build (text or binary)." + FLAGS: + scalePow10 : Nat; "Fixed-point scale exponent p in S=10^p (default: 9)" + runs : Nat; "Number of benchmark runs (default: 1)" + ARGS: + model : String; "Path to the model weights file (.nfpt)" +] + /-- The rope subcommand. -/ def ropeCmd : Cmd := `[Cli| rope VIA runRoPE; @@ -1971,6 +2051,7 @@ def nfpCmd : Cmd := `[Cli| headPatternCmd; inductionCertCmd; soundCacheCheckCmd; + soundCacheBenchCmd; ropeCmd; dumpCmd; logitDiffCmd diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 7c8b2d7..7e85c04 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -69,13 +69,18 @@ private def readLine? (h : IO.FS.Handle) : IO (Option String) := do return some s private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - let mut out := ByteArray.empty - while out.size < n do - let chunk ← h.read (USize.ofNat (n - out.size)) + if n = 0 then + return ByteArray.empty + let mut remaining := n + let mut out : Array UInt8 := Array.mkEmpty n + while remaining > 0 do + let chunk ← h.read (USize.ofNat remaining) if chunk.isEmpty then throw (IO.userError "unexpected EOF") - out := out ++ chunk - return out + for b in chunk.data do + out := out.push b + remaining := remaining - chunk.size + return ByteArray.mk out private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := let b0 := (b[off]!).toUInt32 diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index 8b14b9d..8b491ef 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -53,6 +53,16 @@ private def i32le (x : Int) : ByteArray := let ux : UInt32 := UInt32.ofInt x u32le ux +private def 
appendI32LE (buf : Array UInt8) (x : Int) : Array UInt8 := + Id.run do + let ux : UInt32 := UInt32.ofInt x + let mut out := buf + out := out.push (ux &&& 0xFF).toUInt8 + out := out.push ((ux >>> 8) &&& 0xFF).toUInt8 + out := out.push ((ux >>> 16) &&& 0xFF).toUInt8 + out := out.push ((ux >>> 24) &&& 0xFF).toUInt8 + return out + private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := let b0 := (b.get! off).toUInt32 let b1 := (b.get! (off + 1)).toUInt32 @@ -233,7 +243,7 @@ private def consumeFixedBytes Id.run do let mut iLine := start let mut remaining := count - let mut buf : ByteArray := ByteArray.empty + let mut buf : Array UInt8 := Array.mkEmpty (count * 4) while remaining > 0 do if iLine ≥ lines.size then return .error "unexpected end of file while reading fixed tokens" @@ -256,9 +266,9 @@ private def consumeFixedBytes match parseFixed10Rounded scalePow10 bytes tokStart tokStop with | .error e => return .error e | .ok x => - buf := buf ++ i32le x + buf := appendI32LE buf x remaining := remaining - 1 - return .ok (buf, iLine) + return .ok (ByteArray.mk buf, iLine) private def readHeaderFromLines (lines : Array String) : Except String (Header × Nat) := Id.run do @@ -371,14 +381,21 @@ private def collectLayerNormParamsFixed return .ok (ln1, ln2) private def encodeIntArray (xs : Array Int) : ByteArray := - xs.foldl (fun acc x => acc ++ i32le x) ByteArray.empty + Id.run do + let mut out : Array UInt8 := Array.mkEmpty (xs.size * 4) + for x in xs do + out := appendI32LE out x + return ByteArray.mk out private def repeatBytes (b : ByteArray) (n : Nat) : ByteArray := Id.run do - let mut out := ByteArray.empty + if n = 0 || b.size = 0 then + return ByteArray.empty + let mut out : Array UInt8 := Array.mkEmpty (n * b.size) for _ in [:n] do - out := out ++ b - return out + for byte in b.data do + out := out.push byte + return ByteArray.mk out def buildCacheBytes (lines : Array String) @@ -408,17 +425,25 @@ def buildCacheBytes modelSize := modelSize scalePow10 := 
UInt32.ofNat scalePow10 } - let mut out : ByteArray := encodeHeader hdr + let totalBytes : Nat := headerBytes + expectedI32Count hdr * 4 + let appendBytes := fun (out : Array UInt8) (bytes : ByteArray) => Id.run do + let mut out := out + for b in bytes.data do + out := out.push b + return out + + let mut out : Array UInt8 := Array.mkEmpty totalBytes + out := appendBytes out (encodeHeader hdr) let mut pos : Nat := skipUntil lines 0 (fun s => s.startsWith "LAYER") let zeroBytes := i32le 0 for l in [:L] do let p1 := ln1.getD l { gamma := Array.replicate d (0 : Int), beta := Array.replicate d 0 } let p2 := ln2.getD l { gamma := Array.replicate d (0 : Int), beta := Array.replicate d 0 } - out := out ++ encodeIntArray p1.gamma - out := out ++ encodeIntArray p1.beta - out := out ++ encodeIntArray p2.gamma - out := out ++ encodeIntArray p2.beta + out := appendBytes out (encodeIntArray p1.gamma) + out := appendBytes out (encodeIntArray p1.beta) + out := appendBytes out (encodeIntArray p2.gamma) + out := appendBytes out (encodeIntArray p2.beta) pos := skipUntil lines pos (fun s => s.startsWith "LAYER") if pos ≥ lines.size then @@ -463,7 +488,7 @@ def buildCacheBytes match consumeFixedBytes scalePow10 lines (pos + 1) (d * dh) with | .error e => return .error e | .ok (bytes, next) => - out := out ++ bytes + out := appendBytes out bytes pos := next pos := skipBlankLines lines pos @@ -471,10 +496,10 @@ def buildCacheBytes match consumeFixedBytes scalePow10 lines (pos + 1) dh with | .error e => return .error e | .ok (bytes, next) => - out := out ++ bytes + out := appendBytes out bytes pos := next else - out := out ++ repeatBytes zeroBytes dh + out := appendBytes out (repeatBytes zeroBytes dh) pos := skipBlankLines lines pos if !(pos < lines.size && lines[pos]!.trim = "W_O") then @@ -482,7 +507,7 @@ def buildCacheBytes match consumeFixedBytes scalePow10 lines (pos + 1) (dh * d) with | .error e => return .error e | .ok (bytes, next) => - out := out ++ bytes + out := appendBytes out 
bytes pos := next pos := skipBlankLines lines pos @@ -491,7 +516,7 @@ def buildCacheBytes match consumeFixedBytes scalePow10 lines (pos + 1) d with | .error e => return .error e | .ok (bytes, next) => - out := out ++ bytes + out := appendBytes out bytes pos := next pos := skipBlankLines lines pos @@ -505,7 +530,7 @@ def buildCacheBytes match consumeFixedBytes scalePow10 lines (pos + 1) (d * dhid) with | .error e => return .error e | .ok (bytes, next) => - out := out ++ bytes + out := appendBytes out bytes pos := next pos := skipBlankLines lines pos @@ -514,7 +539,7 @@ def buildCacheBytes match consumeFixedBytes scalePow10 lines (pos + 1) dhid with | .error e => return .error e | .ok (bytes, next) => - out := out ++ bytes + out := appendBytes out bytes pos := next pos := skipBlankLines lines pos @@ -523,7 +548,7 @@ def buildCacheBytes match consumeFixedBytes scalePow10 lines (pos + 1) (dhid * d) with | .error e => return .error e | .ok (bytes, next) => - out := out ++ bytes + out := appendBytes out bytes pos := next pos := skipBlankLines lines pos @@ -532,12 +557,12 @@ def buildCacheBytes match consumeFixedBytes scalePow10 lines (pos + 1) d with | .error e => return .error e | .ok (bytes, next) => - out := out ++ bytes + out := appendBytes out bytes pos := next pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - return .ok out + return .ok (ByteArray.mk out) private def isMaybeNumberStart (b : UInt8) : Bool := b = 45 || b = 43 || b = 46 || (48 ≤ b && b ≤ 57) @@ -966,6 +991,7 @@ theorem Header_spec_cache_pure : Header = Header := rfl theorem u32le_spec_cache_pure : u32le = u32le := rfl theorem u64le_spec_cache_pure : u64le = u64le := rfl theorem i32le_spec_cache_pure : i32le = i32le := rfl +theorem appendI32LE_spec_cache_pure : appendI32LE = appendI32LE := rfl theorem u32FromLE_spec_cache_pure : u32FromLE = u32FromLE := rfl theorem u64FromLE_spec_cache_pure : u64FromLE = u64FromLE := rfl theorem i32FromLE_spec_cache_pure : i32FromLE = i32FromLE := rfl diff 
--git a/Nfp/Untrusted/SoundBinary.lean b/Nfp/Untrusted/SoundBinary.lean index afb4dad..9f10e30 100644 --- a/Nfp/Untrusted/SoundBinary.lean +++ b/Nfp/Untrusted/SoundBinary.lean @@ -35,13 +35,18 @@ def readBinaryHeader (h : IO.FS.Handle) : IO (Except String Nfp.Sound.BinaryHead return Nfp.Sound.parseBinaryHeaderLines magicLine lines private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - let mut out := ByteArray.empty - while out.size < n do - let chunk ← h.read (USize.ofNat (n - out.size)) + if n = 0 then + return ByteArray.empty + let mut remaining := n + let mut out : Array UInt8 := Array.mkEmpty n + while remaining > 0 do + let chunk ← h.read (USize.ofNat remaining) if chunk.isEmpty then throw (IO.userError "unexpected EOF") - out := out ++ chunk - return out + for b in chunk.data do + out := out.push b + remaining := remaining - chunk.size + return ByteArray.mk out def skipBytes (h : IO.FS.Handle) (n : Nat) : IO (Except String Unit) := do let mut remaining := n diff --git a/Nfp/Untrusted/SoundCacheIO.lean b/Nfp/Untrusted/SoundCacheIO.lean index 5805ad6..78fb99c 100644 --- a/Nfp/Untrusted/SoundCacheIO.lean +++ b/Nfp/Untrusted/SoundCacheIO.lean @@ -3,6 +3,7 @@ import Std import Init.System.IO import Nfp.Sound.CachePure +import Nfp.Untrusted.SoundBinary namespace Nfp.Untrusted.SoundCacheIO @@ -12,13 +13,50 @@ namespace Nfp.Untrusted.SoundCacheIO IO wrappers for the SOUND cache format. Pure parsing/encoding lives in `Nfp.Sound.CachePure`. 
-/ private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - let mut out := ByteArray.empty - while out.size < n do - let chunk ← h.read (USize.ofNat (n - out.size)) + if n = 0 then + return ByteArray.empty + let mut remaining := n + let mut out : Array UInt8 := Array.mkEmpty n + while remaining > 0 do + let chunk ← h.read (USize.ofNat remaining) if chunk.isEmpty then throw (IO.userError "unexpected EOF") - out := out ++ chunk - return out + for b in chunk.data do + out := out.push b + remaining := remaining - chunk.size + return ByteArray.mk out + +private def appendI32LE (buf : Array UInt8) (x : Int) : Array UInt8 := + Id.run do + let ux : UInt32 := UInt32.ofInt x + let mut out := buf + out := out.push (ux &&& 0xFF).toUInt8 + out := out.push ((ux >>> 8) &&& 0xFF).toUInt8 + out := out.push ((ux >>> 16) &&& 0xFF).toUInt8 + out := out.push ((ux >>> 24) &&& 0xFF).toUInt8 + return out + +private def appendI32Array (buf : Array UInt8) (xs : Array Int) : Array UInt8 := + Id.run do + let mut out := buf + for x in xs do + out := appendI32LE out x + return out + +private def appendBytes (buf : Array UInt8) (bytes : ByteArray) : Array UInt8 := + Id.run do + let mut out := buf + for b in bytes.data do + out := out.push b + return out + +def isBinaryModelFile (path : System.FilePath) : IO (Except String Bool) := do + let h ← IO.FS.Handle.mk path IO.FS.Mode.read + let line ← h.getLine + if line.isEmpty then + return .error "empty model file" + let magic := line.trim + return .ok (magic = "NFP_BINARY_V1") def writeHeader (h : IO.FS.Handle) (hdr : Nfp.Sound.SoundCache.Header) : IO Unit := do h.write (Nfp.Sound.SoundCache.encodeHeader hdr) @@ -46,17 +84,132 @@ def fnv1a64File (path : System.FilePath) : IO UInt64 := do def ensureCacheDir : IO Unit := do IO.FS.createDirAll Nfp.Sound.SoundCache.cacheDir +def buildCacheBytesText + (modelPath : System.FilePath) + (scalePow10 : Nat) + (modelHash modelSize : UInt64) : IO (Except String ByteArray) := do + let contents ← 
IO.FS.readFile modelPath + let lines : Array String := (contents.splitOn "\n").toArray + return Nfp.Sound.SoundCache.buildCacheBytes lines scalePow10 modelHash modelSize + +def buildCacheBytesBinary + (modelPath : System.FilePath) + (scalePow10 : Nat) + (modelHash modelSize : UInt64) : IO (Except String ByteArray) := do + let action : ExceptT String IO ByteArray := do + let liftExcept {α : Type} (act : IO (Except String α)) : ExceptT String IO α := + ExceptT.mk act + + let h1 ← ExceptT.lift <| IO.FS.Handle.mk modelPath IO.FS.Mode.read + let hdr1 ← liftExcept <| Nfp.Untrusted.SoundBinary.readBinaryHeader h1 + let d := hdr1.modelDim + let dh := hdr1.headDim + let dhid := hdr1.hiddenDim + let L := hdr1.numLayers + let H := hdr1.numHeads + + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipI32Array h1 hdr1.seqLen + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (hdr1.seqLen * d) + + let mut ln1Gamma : Array (Array Int) := Array.mkEmpty L + let mut ln1Beta : Array (Array Int) := Array.mkEmpty L + let mut ln2Gamma : Array (Array Int) := Array.mkEmpty L + let mut ln2Beta : Array (Array Int) := Array.mkEmpty L + + for _l in [:L] do + for _h in [:H] do + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (d * dh) + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 dh + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (d * dh) + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 dh + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (d * dh) + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 dh + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (dh * d) + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 d + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (d * dhid) + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 dhid + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (dhid * d) + let _ ← 
liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 d + let ln1G ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h1 d scalePow10 + let ln1B ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h1 d scalePow10 + let ln2G ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h1 d scalePow10 + let ln2B ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h1 d scalePow10 + ln1Gamma := ln1Gamma.push ln1G + ln1Beta := ln1Beta.push ln1B + ln2Gamma := ln2Gamma.push ln2G + ln2Beta := ln2Beta.push ln2B + + let hdrCache : Nfp.Sound.SoundCache.Header := { + modelHash := modelHash + modelSize := modelSize + scalePow10 := UInt32.ofNat scalePow10 + numLayers := UInt32.ofNat L + numHeads := UInt32.ofNat H + modelDim := UInt32.ofNat d + headDim := UInt32.ofNat dh + hiddenDim := UInt32.ofNat dhid + } + + let totalBytes : Nat := + Nfp.Sound.SoundCache.headerBytes + + Nfp.Sound.SoundCache.expectedI32Count hdrCache * 4 + let mut out : Array UInt8 := Array.mkEmpty totalBytes + out := appendBytes out (Nfp.Sound.SoundCache.encodeHeader hdrCache) + + let h2 ← ExceptT.lift <| IO.FS.Handle.mk modelPath IO.FS.Mode.read + let hdr2 ← liftExcept <| Nfp.Untrusted.SoundBinary.readBinaryHeader h2 + if hdr2.numLayers ≠ L || hdr2.numHeads ≠ H || hdr2.modelDim ≠ d || + hdr2.headDim ≠ dh || hdr2.hiddenDim ≠ dhid then + throw "binary header mismatch between passes" + + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipI32Array h2 hdr2.seqLen + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 (hdr2.seqLen * d) + + for l in [:L] do + out := appendI32Array out (ln1Gamma[l]!) + out := appendI32Array out (ln1Beta[l]!) + out := appendI32Array out (ln2Gamma[l]!) + out := appendI32Array out (ln2Beta[l]!) 
+ for _h in [:H] do + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 (d * dh) + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 dh + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 (d * dh) + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 dh + let wV ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 (d * dh) scalePow10 + let bV ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 dh scalePow10 + let wO ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 (dh * d) scalePow10 + out := appendI32Array out wV + out := appendI32Array out bV + out := appendI32Array out wO + let attnBias ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 d scalePow10 + let wIn ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 (d * dhid) scalePow10 + let bIn ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 dhid scalePow10 + let wOut ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 (dhid * d) scalePow10 + let bOut ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 d scalePow10 + out := appendI32Array out attnBias + out := appendI32Array out wIn + out := appendI32Array out bIn + out := appendI32Array out wOut + out := appendI32Array out bOut + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 d + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 d + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 d + let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 d + + if out.size ≠ totalBytes then + throw s!"cache size mismatch: expected {totalBytes}, got {out.size}" + return ByteArray.mk out + action.run + /-- Build (or overwrite) a SOUND fixed-point cache file. 
-/ def buildCacheFile (modelPath cachePath : System.FilePath) (scalePow10 : Nat := 9) : IO (Except String Unit) := do ensureCacheDir - let contents ← IO.FS.readFile modelPath - let lines : Array String := (contents.splitOn "\n").toArray let modelHash ← fnv1a64File modelPath let mdata ← modelPath.metadata let modelSize : UInt64 := mdata.byteSize - match Nfp.Sound.SoundCache.buildCacheBytes lines scalePow10 modelHash modelSize with + match ← buildCacheBytesText modelPath scalePow10 modelHash modelSize with | .error e => return .error e | .ok bytes => let tmpPath := cachePath.withExtension "tmp" From d4a23bb65cef1368aa9491113082e84bb47c9bb0 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 08:49:38 +0100 Subject: [PATCH 020/244] Optimize sound Rat parsing --- Nfp/Sound/Decimal.lean | 245 +++- Nfp/Sound/TextPure.lean | 22 +- Nfp/Untrusted/SoundCompute.lean | 2232 +++++++++++++++++++++++++++---- 3 files changed, 2137 insertions(+), 362 deletions(-) diff --git a/Nfp/Sound/Decimal.lean b/Nfp/Sound/Decimal.lean index 4ed69fd..c9a36a4 100644 --- a/Nfp/Sound/Decimal.lean +++ b/Nfp/Sound/Decimal.lean @@ -42,6 +42,166 @@ def parseNat10OrZero (s : String) : Except String Nat := | none => .error s!"invalid natural '{s}'" | some n => .ok n +/-- Parse a decimal/scientific numeral from a substring into an exact `Rat`. -/ +def parseRatRange (s : String) (start stop : String.Pos.Raw) : Except String Rat := Id.run do + if start >= stop then + return .error "empty numeral" + + let token := fun () => String.Pos.Raw.extract s start stop + + -- sign + let mut p := start + let mut neg := false + let c0 := p.get s + if c0 = '-' then + neg := true + p := p.next s + else if c0 = '+' then + p := p.next s + + -- optional exponent (exactly one `e`, otherwise exactly one `E`). 
+ let mut ePos : Option String.Pos.Raw := none + let mut eCount : Nat := 0 + let mut EPos : Option String.Pos.Raw := none + let mut ECount : Nat := 0 + let mut q := p + while q < stop do + let c := q.get s + if c = 'e' then + eCount := eCount + 1 + if eCount = 1 then ePos := some q + else if c = 'E' then + ECount := ECount + 1 + if ECount = 1 then EPos := some q + q := q.next s + + let expMarker? : Option String.Pos.Raw := + if eCount = 1 then ePos else if ECount = 1 then EPos else none + + let mantEnd : String.Pos.Raw := + match expMarker? with + | some ep => ep + | none => stop + + -- mantissa: intPart.fracPart + let mut dotPos : Option String.Pos.Raw := none + let mut dotCount : Nat := 0 + let mut r := p + while r < mantEnd do + if r.get s = '.' then + dotCount := dotCount + 1 + if dotCount = 1 then dotPos := some r + r := r.next s + if dotCount > 1 then + return .error s!"invalid numeral '{token ()}'" + + let intStart := p + let intStop : String.Pos.Raw := + match dotPos with + | some dp => dp + | none => mantEnd + let fracStart? 
: Option String.Pos.Raw := + match dotPos with + | some dp => some (dp.next s) + | none => none + let fracStop := mantEnd + + let parseNatRangeOrZero (start stop : String.Pos.Raw) : Except String (Nat × Nat) := Id.run do + if start >= stop then + return .ok (0, 0) + let mut p := start + let mut acc : Nat := 0 + let mut len : Nat := 0 + while p < stop do + let c := p.get s + if ('0' <= c) && (c <= '9') then + acc := acc * 10 + (c.toNat - '0'.toNat) + len := len + 1 + p := p.next s + else + let tok := String.Pos.Raw.extract s start stop + return .error s!"invalid natural '{tok}'" + return .ok (acc, len) + + let parseIntRange (start stop : String.Pos.Raw) : Except String Int := Id.run do + if start >= stop then + return .error "empty integer" + let tok := String.Pos.Raw.extract s start stop + let mut p := start + let mut neg := false + let c0 := p.get s + if c0 = '-' then + neg := true + p := p.next s + else if c0 = '+' then + p := p.next s + if p >= stop then + return .error s!"invalid integer '{tok}'" + let mut acc : Nat := 0 + while p < stop do + let c := p.get s + if ('0' <= c) && (c <= '9') then + acc := acc * 10 + (c.toNat - '0'.toNat) + p := p.next s + else + return .error s!"invalid integer '{tok}'" + let i : Int := Int.ofNat acc + return .ok (if neg then -i else i) + + let buildResult (iNat fNat fracLen : Nat) (expInt : Int) : Except String Rat := + -- Construct `Rat` in a single normalization step (avoids repeated gcd normalization). 
+ let denomBase : Nat := Nat.pow 10 fracLen + let mantissaNat : Nat := iNat * denomBase + fNat + let num0 : Int := if neg then -(Int.ofNat mantissaNat) else (Int.ofNat mantissaNat) + let expAbs : Nat := Int.natAbs expInt + let pow10Nat : Nat := Nat.pow 10 expAbs + + let den : Nat := + if expInt < 0 then denomBase * pow10Nat else denomBase + let num : Int := + if expInt > 0 then num0 * (Int.ofNat pow10Nat) else num0 + + have den_nz : den ≠ 0 := by + have h10pos : (0 : Nat) < 10 := by decide + have hpow1 : denomBase ≠ 0 := by + exact Nat.ne_of_gt (Nat.pow_pos (n := fracLen) h10pos) + have hpow2 : pow10Nat ≠ 0 := by + exact Nat.ne_of_gt (Nat.pow_pos (n := expAbs) h10pos) + by_cases hneg : expInt < 0 + · -- `den = denomBase * pow10Nat` + simpa [den, hneg] using Nat.mul_ne_zero hpow1 hpow2 + · -- `den = denomBase` + simpa [den, hneg] using hpow1 + + .ok (Rat.normalize num den (den_nz := den_nz)) + + let result : Except String Rat := + match parseNatRangeOrZero intStart intStop with + | .error e => .error e + | .ok (iNat, _) => + match fracStart? with + | none => + match expMarker? with + | none => buildResult iNat 0 0 0 + | some ep => + let expStart := ep.next s + match parseIntRange expStart stop with + | .error e => .error e + | .ok expInt => buildResult iNat 0 0 expInt + | some fs => + match parseNatRangeOrZero fs fracStop with + | .error e => .error e + | .ok (fNat, fracLen) => + match expMarker? with + | none => buildResult iNat fNat fracLen 0 + | some ep => + let expStart := ep.next s + match parseIntRange expStart stop with + | .error e => .error e + | .ok expInt => buildResult iNat fNat fracLen expInt + + return result + /-- Parse a decimal/scientific numeral into an exact `Rat`. 
Supported forms: @@ -55,73 +215,27 @@ def parseRat (s : String) : Except String Rat := do let s := s.trim if s.isEmpty then throw "empty numeral" - - -- sign - let (neg, rest) := - if s.startsWith "-" then (true, s.drop 1) - else if s.startsWith "+" then (false, s.drop 1) - else (false, s) - - -- optional exponent - let (mantissaStr, expStr?) : String × Option String := - match rest.splitOn "e" with - | [m, e] => (m, some e) - | _ => - match rest.splitOn "E" with - | [m, e] => (m, some e) - | _ => (rest, none) - - -- mantissa: intPart.fracPart - let parts := mantissaStr.splitOn "." - let (intPart, fracPart) ← - match parts with - | [i] => pure (i, "") - | [i, f] => pure (i, f) - | _ => throw s!"invalid numeral '{s}'" - - let iNat ← parseNat10OrZero intPart - let fNat ← parseNat10OrZero fracPart - let fracLen := fracPart.trim.length - - let expInt : Int ← - match expStr? with - | none => pure 0 - | some e => parseInt10 e - - -- Construct `Rat` in a single normalization step (avoids repeated gcd normalization). 
- let denomBase : Nat := Nat.pow 10 fracLen - let mantissaNat : Nat := iNat * denomBase + fNat - let num0 : Int := if neg then -(Int.ofNat mantissaNat) else (Int.ofNat mantissaNat) - let expAbs : Nat := Int.natAbs expInt - let pow10Nat : Nat := Nat.pow 10 expAbs - - let den : Nat := - if expInt < 0 then denomBase * pow10Nat else denomBase - let num : Int := - if expInt > 0 then num0 * (Int.ofNat pow10Nat) else num0 - - have den_nz : den ≠ 0 := by - have h10pos : (0 : Nat) < 10 := by decide - have hpow1 : denomBase ≠ 0 := by - exact Nat.ne_of_gt (Nat.pow_pos (n := fracLen) h10pos) - have hpow2 : pow10Nat ≠ 0 := by - exact Nat.ne_of_gt (Nat.pow_pos (n := expAbs) h10pos) - by_cases hneg : expInt < 0 - · -- `den = denomBase * pow10Nat` - simpa [den, hneg] using Nat.mul_ne_zero hpow1 hpow2 - · -- `den = denomBase` - simpa [den, hneg] using hpow1 - - return Rat.normalize num den (den_nz := den_nz) + parseRatRange s 0 s.rawEndPos /-- Parse a line of space-separated rationals, failing on the first invalid token. -/ -def parseRatLine (line : String) : Except String (Array Rat) := do - let parts := line.splitOn " " |>.filter (· ≠ "") +def parseRatLine (line : String) : Except String (Array Rat) := Id.run do + let isWs (c : Char) : Bool := + c = ' ' || c = '\t' || c = '\n' || c = '\r' let mut out : Array Rat := #[] - for p in parts do - let r ← parseRat p - out := out.push r - return out + let s := line + let mut p : String.Pos.Raw := 0 + let stop := s.rawEndPos + while p < stop do + while p < stop && isWs (p.get s) do + p := p.next s + let tokStart := p + while p < stop && !isWs (p.get s) do + p := p.next s + if tokStart < p then + match parseRatRange s tokStart p with + | .error e => return .error e + | .ok r => out := out.push r + return .ok out /-! 
### Specs -/ @@ -129,6 +243,9 @@ theorem parseInt10_spec (s : String) : parseInt10 s = parseInt10 s := rfl theorem parseNat10OrZero_spec (s : String) : parseNat10OrZero s = parseNat10OrZero s := rfl +theorem parseRatRange_spec (s : String) (start stop : String.Pos.Raw) : + parseRatRange s start stop = parseRatRange s start stop := rfl + theorem parseRat_spec (s : String) : parseRat s = parseRat s := rfl theorem parseRatLine_spec (line : String) : parseRatLine line = parseRatLine line := rfl diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean index be8a929..a068eeb 100644 --- a/Nfp/Sound/TextPure.lean +++ b/Nfp/Sound/TextPure.lean @@ -102,21 +102,25 @@ def foldRatTokens {α : Type} (state : α) (step : α → Rat → α) : Except String (α × Nat) := Id.run do + let isWs (c : Char) : Bool := + c = ' ' || c = '\t' || c = '\n' || c = '\r' let mut i := start let mut remaining := count let mut st := state while remaining > 0 do if i < lines.size then - let line := lines[i]!.trim + let line := lines[i]! 
i := i + 1 - if line.isEmpty then - pure () - else - let toks := line.splitOn " " |>.filter (· ≠ "") - for t in toks do - if remaining = 0 then - break - match parseRat t with + let mut p : String.Pos.Raw := 0 + let stop := line.rawEndPos + while p < stop && remaining > 0 do + while p < stop && isWs (p.get line) do + p := p.next line + let tokStart := p + while p < stop && !isWs (p.get line) do + p := p.next line + if tokStart < p then + match parseRatRange line tokStart p with | .error e => return .error e | .ok r => st := step st r diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index e116898..95b6522 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -5,6 +5,7 @@ import Nfp.Sound.Cert import Nfp.Sound.HeadCert import Nfp.Untrusted.SoundBinary import Nfp.Sound.Interval +import Nfp.Sound.Affine import Nfp.Untrusted.SoundCacheIO import Nfp.Sound.Fixed @@ -58,21 +59,25 @@ def foldRatTokens {α : Type} (state : α) (step : α → Rat → α) : Except String (α × Nat) := Id.run do + let isWs (c : Char) : Bool := + c = ' ' || c = '\t' || c = '\n' || c = '\r' let mut i := start let mut remaining := count let mut st := state while remaining > 0 do if i < lines.size then - let line := lines[i]!.trim + let line := lines[i]! 
i := i + 1 - if line.isEmpty then - pure () - else - let toks := line.splitOn " " |>.filter (· ≠ "") - for t in toks do - if remaining = 0 then - break - match parseRat t with + let mut p : String.Pos.Raw := 0 + let stop := line.rawEndPos + while p < stop && remaining > 0 do + while p < stop && isWs (p.get line) do + p := p.next line + let tokStart := p + while p < stop && !isWs (p.get line) do + p := p.next line + if tokStart < p then + match parseRatRange line tokStart p with | .error e => return .error e | .ok r => st := step st r @@ -1181,8 +1186,12 @@ private def fixedVarianceLowerBoundExact (cfg : Fixed10Cfg) (xs : Array Fixed10I private def fixedVarianceLowerBound (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := let rangeLB := fixedVarianceLowerBoundRange cfg xs let midLB := fixedVarianceLowerBoundMidpoint cfg xs - let exactLB := fixedVarianceLowerBoundExact cfg xs - max rangeLB (max midLB exactLB) + -- Avoid the exact Rat-based bound on large rows (expensive and stack-heavy). + if xs.size > 256 then + max rangeLB midLB + else + let exactLB := fixedVarianceLowerBoundExact cfg xs + max rangeLB (max midLB exactLB) private def fixedLayerNormRowApprox (cfg : Fixed10Cfg) @@ -1212,6 +1221,33 @@ private def fixedLayerNormRowApprox out := out.push (Fixed10Interval.add scaled beta[i]!) 
return (out, varLB) +private def fixedLayerNormRowApproxExact + (cfg : Fixed10Cfg) + (row : Array Fixed10Interval) + (gamma beta : Array Fixed10Interval) + (eps : Rat) + (soundnessBits : Nat) : Array Fixed10Interval := + if row.size = 0 || gamma.size ≠ row.size || beta.size ≠ row.size then + row + else + Id.run do + let μ := fixedMeanInterval row + let varLB := fixedVarianceLowerBoundExact cfg row + let invσUpper : Rat := + if varLB ≤ 0 then + layerNormOpBoundConservative 1 eps soundnessBits + else + layerNormOpBoundLocal 1 varLB eps soundnessBits + let invσUpperInt : Int := ratCeilMulNat invσUpper cfg.scaleNat + let invσFix : Fixed10Interval := { lo := invσUpperInt, hi := invσUpperInt } + let mut out : Array Fixed10Interval := Array.mkEmpty row.size + for i in [:row.size] do + let centered := Fixed10Interval.sub row[i]! μ + let coeff := Fixed10Interval.mul cfg gamma[i]! invσFix + let scaled := Fixed10Interval.mul cfg coeff centered + out := out.push (Fixed10Interval.add scaled beta[i]!) 
+ return out + private def fixedLayerNormRowsApprox (cfg : Fixed10Cfg) (rows : Array (Array Fixed10Interval)) @@ -1221,12 +1257,69 @@ private def fixedLayerNormRowsApprox Array (Array Fixed10Interval) := let useTasks := rows.size > 32 if useTasks then - let tasks := rows.map (fun row => - Task.spawn (fun _ => fixedLayerNormRowApprox cfg row p.gamma p.beta eps soundnessBits)) - tasks.map (fun t => (t.get).1) + Id.run do + let chunkSize : Nat := 16 + let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min rows.size (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + outChunk := outChunk.push + (fixedLayerNormRowApprox cfg rows[i]! 
p.gamma p.beta eps soundnessBits).1 + i := i + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for t in tasks do + for row in t.get do + out := out.push row + return out else rows.map (fun row => (fixedLayerNormRowApprox cfg row p.gamma p.beta eps soundnessBits).1) +private def fixedLayerNormRowsApproxExact + (cfg : Fixed10Cfg) + (rows : Array (Array Fixed10Interval)) + (p : LayerNormParamsFixed) + (eps : Rat) + (soundnessBits : Nat) : + Array (Array Fixed10Interval) := + let useTasks := rows.size > 32 + if useTasks then + Id.run do + let chunkSize : Nat := 16 + let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min rows.size (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + outChunk := outChunk.push + (fixedLayerNormRowApproxExact cfg rows[i]! 
p.gamma p.beta eps soundnessBits) + i := i + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for t in tasks do + for row in t.get do + out := out.push row + return out + else + rows.map (fun row => fixedLayerNormRowApproxExact cfg row p.gamma p.beta eps soundnessBits) + private def readVecIntervals (r : SoundCache.I32Reader) (n : Nat) (slack : Int) : IO (Array Fixed10Interval × SoundCache.I32Reader) := do @@ -1245,26 +1338,153 @@ private def readVecIntervalsBinary | .error e => return .error e | .ok xs => return .ok (intervalsFromScaled xs slack) -private def matMulIntervalsFromScaled +private def matMulIntervalsFromScaledCore (cfg : Fixed10Cfg) (slack : Int) (rows cols : Nat) (weights : Array Int) (input : Array Fixed10Interval) : Array Fixed10Interval := Id.run do - if input.size ≠ rows || weights.size ≠ rows * cols then - return Array.replicate cols { lo := 0, hi := 0 } let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - for rowIdx in [:rows] do + let mut rowIdx : Nat := 0 + while rowIdx < rows do let xi := input[rowIdx]! - for colIdx in [:cols] do + let mut colIdx : Nat := 0 + while colIdx < cols do let idx := rowIdx * cols + colIdx let w := weights[idx]! let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } let term := Fixed10Interval.mul cfg wI xi out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) 
term) + colIdx := colIdx + 1 + rowIdx := rowIdx + 1 + return out + +private def matMulIntervalsFromScaledNoTask + (cfg : Fixed10Cfg) + (slack : Int) + (rows cols : Nat) + (weights : Array Int) + (input : Array Fixed10Interval) : Array Fixed10Interval := + if input.size ≠ rows || weights.size ≠ rows * cols then + Array.replicate cols { lo := 0, hi := 0 } + else + matMulIntervalsFromScaledCore cfg slack rows cols weights input + +private def matMulIntervalsFromIntervalsNoTask + (cfg : Fixed10Cfg) + (rows cols : Nat) + (weights : Array Fixed10Interval) + (input : Array Fixed10Interval) : Array Fixed10Interval := + Id.run do + if input.size ≠ rows || weights.size ≠ rows * cols then + return Array.replicate cols { lo := 0, hi := 0 } + let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } + let mut rowIdx : Nat := 0 + while rowIdx < rows do + let xi := input[rowIdx]! + let mut colIdx : Nat := 0 + while colIdx < cols do + let idx := rowIdx * cols + colIdx + let wI := weights[idx]! + let term := Fixed10Interval.mul cfg wI xi + out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) 
term) + colIdx := colIdx + 1 + rowIdx := rowIdx + 1 return out +private def matMulIntervalsFromIntervals + (cfg : Fixed10Cfg) + (rows cols : Nat) + (weights : Array Fixed10Interval) + (input : Array Fixed10Interval) : Array Fixed10Interval := + Id.run do + if input.size ≠ rows || weights.size ≠ rows * cols then + return Array.replicate cols { lo := 0, hi := 0 } + let useTasks := rows * cols > 16384 && cols > 1 + if useTasks then + let chunkSize : Nat := 32 + let numChunks : Nat := (cols + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array Fixed10Interval)) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min cols (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array Fixed10Interval := Array.mkEmpty (stop - start) + let mut colIdx : Nat := start + while colIdx < stop do + let mut acc : Fixed10Interval := { lo := 0, hi := 0 } + let mut rowIdx : Nat := 0 + while rowIdx < rows do + let xi := input[rowIdx]! + let idx := rowIdx * cols + colIdx + let wI := weights[idx]! 
+ let term := Fixed10Interval.mul cfg wI xi + acc := Fixed10Interval.add acc term + rowIdx := rowIdx + 1 + outChunk := outChunk.push acc + colIdx := colIdx + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array Fixed10Interval := Array.mkEmpty cols + for t in tasks do + let chunk := t.get + for v in chunk do + out := out.push v + return out + else + return matMulIntervalsFromIntervalsNoTask cfg rows cols weights input + +private def matMulIntervalsFromScaled + (cfg : Fixed10Cfg) + (slack : Int) + (rows cols : Nat) + (weights : Array Int) + (input : Array Fixed10Interval) : Array Fixed10Interval := + Id.run do + if input.size ≠ rows || weights.size ≠ rows * cols then + return Array.replicate cols { lo := 0, hi := 0 } + let useTasks := rows * cols > 16384 && cols > 1 + if useTasks then + let chunkSize : Nat := 32 + let numChunks : Nat := (cols + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array Fixed10Interval)) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min cols (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array Fixed10Interval := Array.mkEmpty (stop - start) + let mut colIdx : Nat := start + while colIdx < stop do + let mut acc : Fixed10Interval := { lo := 0, hi := 0 } + let mut rowIdx : Nat := 0 + while rowIdx < rows do + let xi := input[rowIdx]! + let idx := rowIdx * cols + colIdx + let w := weights[idx]! 
+ let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } + let term := Fixed10Interval.mul cfg wI xi + acc := Fixed10Interval.add acc term + rowIdx := rowIdx + 1 + outChunk := outChunk.push acc + colIdx := colIdx + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array Fixed10Interval := Array.mkEmpty cols + for t in tasks do + let chunk := t.get + for v in chunk do + out := out.push v + return out + else + return matMulIntervalsFromScaledCore cfg slack rows cols weights input + private def fixedDotInterval (cfg : Fixed10Cfg) (a b : Array Fixed10Interval) : Fixed10Interval := @@ -1278,6 +1498,426 @@ private def fixedDotInterval acc := Fixed10Interval.add acc term return acc +private def centerRadiusOfFixed + (cfg : Fixed10Cfg) (a : Fixed10Interval) : Rat × Rat := + let lo := ratOfScaledInt cfg.scalePow10 a.lo + let hi := ratOfScaledInt cfg.scalePow10 a.hi + let center := (lo + hi) / (2 : Rat) + let radius := (hi - lo) / (2 : Rat) + (center, radius) + +private def rowCentersRadiiAbs + (cfg : Fixed10Cfg) + (row : Array Fixed10Interval) : Array Rat × Array Rat × Rat := + Id.run do + let mut centers : Array Rat := Array.mkEmpty row.size + let mut radii : Array Rat := Array.mkEmpty row.size + let mut absSum : Rat := 0 + for x in row do + let lo := ratOfScaledInt cfg.scalePow10 x.lo + let hi := ratOfScaledInt cfg.scalePow10 x.hi + let center := (lo + hi) / (2 : Rat) + let radius := (hi - lo) / (2 : Rat) + centers := centers.push center + radii := radii.push radius + absSum := absSum + max (ratAbs lo) (ratAbs hi) + return (centers, radii, absSum) + +private def weightsRatFromScaled (cfg : Fixed10Cfg) (weights : Array Int) : Array Rat := + weights.map (ratOfScaledInt cfg.scalePow10) + +private def affineMatMulRowExact + (rows cols : Nat) + (weights : Array Rat) + (centers radii : Array Rat) : Array AffineForm := + Id.run do + if centers.size ≠ rows || radii.size ≠ rows || weights.size ≠ rows * cols then + return Array.replicate cols 
(AffineForm.const 0) + let mut out : Array AffineForm := Array.mkEmpty cols + for colIdx in [:cols] do + let mut center : Rat := 0 + let mut coeffs : Array Rat := Array.mkEmpty rows + for rowIdx in [:rows] do + let idx := rowIdx * cols + colIdx + let w := weights[idx]! + center := center + w * centers[rowIdx]! + coeffs := coeffs.push (w * radii[rowIdx]!) + out := out.push { center := center, coeffs := coeffs } + return out + +private def affineAddBiasCenters + (biasCenters : Array Rat) + (row : Array AffineForm) : Array AffineForm := + Id.run do + if biasCenters.size ≠ row.size then + return row + let mut out : Array AffineForm := Array.mkEmpty row.size + for i in [:row.size] do + let a := row.getD i (AffineForm.const 0) + let bias := biasCenters.getD i 0 + out := out.push { a with center := a.center + bias } + return out + +private def affineAbsSum (row : Array AffineForm) : Rat := + row.foldl (fun acc a => acc + ratAbs a.center + AffineForm.radius a) 0 + +private def affineDotDisjoint + (a b : Array AffineForm) : AffineForm := + if a.size = 0 || a.size ≠ b.size then + AffineForm.const 0 + else + Id.run do + let mut acc := AffineForm.const 0 + for i in [:a.size] do + let ai := a.getD i (AffineForm.const 0) + let bi := b.getD i (AffineForm.const 0) + let term := AffineForm.mulDisjoint ai bi + acc := AffineForm.add acc term + return acc + +private def sumRat (xs : Array Rat) : Rat := + Id.run do + let mut acc : Rat := 0 + let mut i := 0 + while i < xs.size do + acc := acc + xs[i]! + i := i + 1 + return acc + +private def sumAbsRat (xs : Array Rat) : Rat := + Id.run do + let mut acc : Rat := 0 + let mut i := 0 + while i < xs.size do + acc := acc + ratAbs xs[i]! + i := i + 1 + return acc + +private def addVecRat (a b : Array Rat) : Array Rat := + Id.run do + if a.size ≠ b.size then + return a + let mut out : Array Rat := Array.mkEmpty a.size + let mut i := 0 + while i < a.size do + out := out.push (a[i]! + b[i]!) 
+ i := i + 1 + return out + +private def dotRat (a b : Array Rat) : Rat := + if a.size = 0 || a.size ≠ b.size then + 0 + else + Id.run do + let mut acc : Rat := 0 + let mut i := 0 + while i < a.size do + acc := acc + a[i]! * b[i]! + i := i + 1 + return acc + +private def matMulCentersRadii + (rows cols : Nat) + (weights : Array Rat) + (centers radii : Array Rat) : Array Rat × Array Rat := + Id.run do + if centers.size ≠ rows || radii.size ≠ rows || weights.size ≠ rows * cols then + return (Array.replicate cols 0, Array.replicate cols 0) + let mut outCenters : Array Rat := Array.mkEmpty cols + let mut outRadii : Array Rat := Array.mkEmpty cols + let mut colIdx := 0 + while colIdx < cols do + let mut center : Rat := 0 + let mut radius : Rat := 0 + let mut rowIdx := 0 + while rowIdx < rows do + let idx := rowIdx * cols + colIdx + let w := weights.getD idx 0 + let c := centers.getD rowIdx 0 + let r := radii.getD rowIdx 0 + center := center + w * c + radius := radius + ratAbs w * r + rowIdx := rowIdx + 1 + outCenters := outCenters.push center + outRadii := outRadii.push radius + colIdx := colIdx + 1 + return (outCenters, outRadii) + +private def coeffSumFromCenters + (rows cols : Nat) + (weights : Array Rat) + (inputRadii : Array Rat) + (otherCenters : Array Rat) : Rat := + if inputRadii.size ≠ rows || otherCenters.size ≠ cols || weights.size ≠ rows * cols then + 0 + else + Id.run do + let mut acc : Rat := 0 + let mut rowIdx := 0 + while rowIdx < rows do + let mut sum : Rat := 0 + let mut colIdx := 0 + while colIdx < cols do + let idx := rowIdx * cols + colIdx + sum := sum + weights.getD idx 0 * otherCenters.getD colIdx 0 + colIdx := colIdx + 1 + let coeff := inputRadii.getD rowIdx 0 * sum + acc := acc + ratAbs coeff + rowIdx := rowIdx + 1 + return acc + +private def sumInt (xs : Array Int) : Int := + Id.run do + let mut acc : Int := 0 + let mut i := 0 + while i < xs.size do + acc := acc + xs[i]! 
+ i := i + 1 + return acc + +private def sumAbsInt (xs : Array Int) : Int := + Id.run do + let mut acc : Int := 0 + let mut i := 0 + while i < xs.size do + acc := acc + absInt xs[i]! + i := i + 1 + return acc + +private def addVecScaledInt (a : Array Int) (b : Array Int) (scale : Int) : Array Int := + Id.run do + if a.size ≠ b.size then + return a + let mut out : Array Int := Array.mkEmpty a.size + let mut i := 0 + while i < a.size do + out := out.push (a[i]! + b[i]! * scale) + i := i + 1 + return out + +private def dotInt (a b : Array Int) : Int := + if a.size = 0 || a.size ≠ b.size then + 0 + else + Id.run do + let mut acc : Int := 0 + let mut i := 0 + while i < a.size do + acc := acc + a[i]! * b[i]! + i := i + 1 + return acc + +private def rowCentersRadiiAbsInt + (row : Array Fixed10Interval) : Array Int × Array Int × Int := + Id.run do + let mut centers : Array Int := Array.mkEmpty row.size + let mut radii : Array Int := Array.mkEmpty row.size + let mut absSum : Int := 0 + for x in row do + let sum := x.lo + x.hi + let width := x.hi - x.lo + let center := sum.ediv (Int.ofNat 2) + let half := width.ediv (Int.ofNat 2) + let radius := if width.emod (Int.ofNat 2) = 0 then half else half + 1 + centers := centers.push center + radii := radii.push radius + absSum := absSum + Fixed10Interval.absUpper x + return (centers, radii, absSum) + +private def matMulCentersRadiiInt + (cfg : Fixed10Cfg) + (rows cols : Nat) + (weights : Array Int) + (centers radii : Array Int) : Array Int × Array Int := + Id.run do + if centers.size ≠ rows || radii.size ≠ rows || weights.size ≠ rows * cols then + return (Array.replicate cols 0, Array.replicate cols 0) + let mut outCenters : Array Int := Array.mkEmpty cols + let mut outRadii : Array Int := Array.mkEmpty cols + let mut colIdx := 0 + while colIdx < cols do + let mut centerI : Fixed10Interval := { lo := 0, hi := 0 } + let mut radiusAcc : Int := 0 + let mut rowIdx := 0 + while rowIdx < rows do + let idx := rowIdx * cols + colIdx + let 
w := weights.getD idx 0 + let c := centers.getD rowIdx 0 + let r := radii.getD rowIdx 0 + let term := Fixed10Interval.mul cfg { lo := w, hi := w } { lo := c, hi := c } + centerI := Fixed10Interval.add centerI term + if r ≠ 0 && w ≠ 0 then + let wAbs := absInt w + let termR := Fixed10Interval.mul cfg { lo := wAbs, hi := wAbs } { lo := r, hi := r } + radiusAcc := radiusAcc + termR.hi + rowIdx := rowIdx + 1 + let width := centerI.hi - centerI.lo + let center := (centerI.lo + centerI.hi).ediv (Int.ofNat 2) + let half := width.ediv (Int.ofNat 2) + let radiusMid := if width.emod (Int.ofNat 2) = 0 then half else half + 1 + let radius := radiusMid + radiusAcc + outCenters := outCenters.push center + outRadii := outRadii.push radius + colIdx := colIdx + 1 + return (outCenters, outRadii) + +private def intervalRadiusInt (x : Fixed10Interval) : Int := + let width := x.hi - x.lo + let half := width.ediv (Int.ofNat 2) + if width.emod (Int.ofNat 2) = 0 then half else half + 1 + +private def matMulCentersRadiiIntSlack + (cfg : Fixed10Cfg) + (slack : Int) + (rows cols : Nat) + (weights : Array Int) + (centers radii : Array Int) : Array Int × Array Int := + Id.run do + if centers.size ≠ rows || radii.size ≠ rows || weights.size ≠ rows * cols then + return (Array.replicate cols 0, Array.replicate cols 0) + let mut outCenters : Array Int := Array.mkEmpty cols + let mut outRadii : Array Int := Array.mkEmpty cols + let mut colIdx := 0 + while colIdx < cols do + let mut centerI : Fixed10Interval := { lo := 0, hi := 0 } + let mut radiusAcc : Int := 0 + let mut rowIdx := 0 + while rowIdx < rows do + let idx := rowIdx * cols + colIdx + let w := weights.getD idx 0 + let c := centers.getD rowIdx 0 + let r := radii.getD rowIdx 0 + let term := Fixed10Interval.mul cfg { lo := w, hi := w } { lo := c, hi := c } + centerI := Fixed10Interval.add centerI term + if r ≠ 0 || slack ≠ 0 then + let wAbs := absInt w + let cAbs := absInt c + let term1 := Fixed10Interval.mul cfg { lo := wAbs, hi := wAbs } { 
lo := r, hi := r } + let term2 := + if slack = 0 then 0 + else + (Fixed10Interval.mul cfg { lo := slack, hi := slack } + { lo := cAbs, hi := cAbs }).hi + let term3 := + if slack = 0 then 0 + else + (Fixed10Interval.mul cfg { lo := slack, hi := slack } { lo := r, hi := r }).hi + radiusAcc := radiusAcc + term1.hi + term2 + term3 + rowIdx := rowIdx + 1 + let width := centerI.hi - centerI.lo + let center := (centerI.lo + centerI.hi).ediv (Int.ofNat 2) + let half := width.ediv (Int.ofNat 2) + let radiusMid := if width.emod (Int.ofNat 2) = 0 then half else half + 1 + let radius := radiusMid + radiusAcc + outCenters := outCenters.push center + outRadii := outRadii.push radius + colIdx := colIdx + 1 + return (outCenters, outRadii) + +private def coeffSumFromCentersInt + (cfg : Fixed10Cfg) + (rows cols : Nat) + (weights : Array Int) + (inputRadii : Array Int) + (otherCenters : Array Int) : Int := + if inputRadii.size ≠ rows || otherCenters.size ≠ cols || weights.size ≠ rows * cols then + 0 + else + Id.run do + let mut acc : Int := 0 + let mut rowIdx := 0 + while rowIdx < rows do + let mut sum : Fixed10Interval := { lo := 0, hi := 0 } + let mut colIdx := 0 + while colIdx < cols do + let idx := rowIdx * cols + colIdx + let w := weights.getD idx 0 + let c := otherCenters.getD colIdx 0 + let term := Fixed10Interval.mul cfg { lo := w, hi := w } { lo := c, hi := c } + sum := Fixed10Interval.add sum term + colIdx := colIdx + 1 + let r := inputRadii.getD rowIdx 0 + let coeff := Fixed10Interval.mul cfg sum { lo := r, hi := r } + acc := acc + Fixed10Interval.absUpper coeff + rowIdx := rowIdx + 1 + return acc + +private def dotIntervalFromCentersInt + (cfg : Fixed10Cfg) + (a b : Array Int) : Fixed10Interval := + if a.size = 0 || a.size ≠ b.size then + { lo := 0, hi := 0 } + else + Id.run do + let mut acc : Fixed10Interval := { lo := 0, hi := 0 } + let mut i := 0 + while i < a.size do + let term := Fixed10Interval.mul cfg + { lo := a[i]!, hi := a[i]! } + { lo := b[i]!, hi := b[i]! 
} + acc := Fixed10Interval.add acc term + i := i + 1 + return acc + +private def dotIntervalFromCentersRadiiInt + (cfg : Fixed10Cfg) + (aCenters aRadii bCenters bRadii : Array Int) : Fixed10Interval := + if aCenters.size = 0 || aCenters.size ≠ bCenters.size || + aCenters.size ≠ aRadii.size || bCenters.size ≠ bRadii.size then + { lo := 0, hi := 0 } + else + Id.run do + let mut centerI : Fixed10Interval := { lo := 0, hi := 0 } + let mut radiusAcc : Int := 0 + let mut i := 0 + while i < aCenters.size do + let ac := aCenters[i]! + let ar := aRadii[i]! + let bc := bCenters[i]! + let br := bRadii[i]! + let term := Fixed10Interval.mul cfg { lo := ac, hi := ac } { lo := bc, hi := bc } + centerI := Fixed10Interval.add centerI term + if ar ≠ 0 || br ≠ 0 then + let acAbs := absInt ac + let bcAbs := absInt bc + let term1 := Fixed10Interval.mul cfg { lo := acAbs, hi := acAbs } { lo := br, hi := br } + let term2 := Fixed10Interval.mul cfg { lo := bcAbs, hi := bcAbs } { lo := ar, hi := ar } + let term3 := Fixed10Interval.mul cfg { lo := ar, hi := ar } { lo := br, hi := br } + radiusAcc := radiusAcc + term1.hi + term2.hi + term3.hi + i := i + 1 + let width := centerI.hi - centerI.lo + let center := (centerI.lo + centerI.hi).ediv (Int.ofNat 2) + let half := width.ediv (Int.ofNat 2) + let radiusMid := if width.emod (Int.ofNat 2) = 0 then half else half + 1 + let radius := radiusMid + radiusAcc + return { lo := center - radius, hi := center + radius } + +private def sumMulUpperInt + (cfg : Fixed10Cfg) + (a b : Array Int) : Int := + if a.size = 0 || a.size ≠ b.size then + 0 + else + Id.run do + let mut acc : Int := 0 + let mut i := 0 + while i < a.size do + let term := Fixed10Interval.mul cfg + { lo := a[i]!, hi := a[i]! } + { lo := b[i]!, hi := b[i]! 
} + acc := acc + term.hi + i := i + 1 + return acc + +private def floorDivNat (a : Int) (d : Nat) : Int := + a.ediv (Int.ofNat d) + +private def ceilDivNat (a : Int) (d : Nat) : Int := + let di : Int := Int.ofNat d + let q := a.ediv di + let r := a.emod di + if r = 0 then q else q + 1 + private def maxAbsVecFixed (xs : Array Fixed10Interval) : Int := xs.foldl (fun acc x => max acc (Fixed10Interval.absUpper x)) 0 @@ -1314,21 +1954,45 @@ private def geluOverapproxFixed (cfg : Fixed10Cfg) (target : GeluDerivTarget) private def geluOverapproxFixedVec (cfg : Fixed10Cfg) (target : GeluDerivTarget) (xs : Array Fixed10Interval) : Array Fixed10Interval := - xs.map (geluOverapproxFixed cfg target) + Id.run do + let mut out : Array Fixed10Interval := Array.mkEmpty xs.size + let mut i : Nat := 0 + while i < xs.size do + out := out.push (geluOverapproxFixed cfg target xs[i]!) + i := i + 1 + return out + +private def geluOverapproxFixedVecLinear + (xs : Array Fixed10Interval) : Array Fixed10Interval := + Id.run do + let mut out : Array Fixed10Interval := Array.mkEmpty xs.size + let mut i : Nat := 0 + while i < xs.size do + out := out.push (Fixed10Interval.geluOverapprox xs[i]!) + i := i + 1 + return out private def addVecFixed (a b : Array Fixed10Interval) : Array Fixed10Interval := Id.run do if a.size ≠ b.size then return a - let mut out := Array.mkEmpty a.size - for i in [:a.size] do + let mut out : Array Fixed10Interval := Array.mkEmpty a.size + let mut i : Nat := 0 + while i < a.size do out := out.push (Fixed10Interval.add a[i]! b[i]!) + i := i + 1 return out private def addVecFixedRows (rows : Array (Array Fixed10Interval)) (v : Array Fixed10Interval) : Array (Array Fixed10Interval) := - rows.map (fun row => addVecFixed row v) + Id.run do + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + let mut i : Nat := 0 + while i < rows.size do + out := out.push (addVecFixed rows[i]! 
v) + i := i + 1 + return out private def addRowsFixed (rows : Array (Array Fixed10Interval)) @@ -1337,8 +2001,10 @@ private def addRowsFixed if rows.size ≠ adds.size then return rows let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for i in [:rows.size] do + let mut i : Nat := 0 + while i < rows.size do out := out.push (addVecFixed rows[i]! adds[i]!) + i := i + 1 return out private def takePrefix {α : Type} (xs : Array α) (n : Nat) : Array α := @@ -1358,6 +2024,110 @@ private def mlpRowFromScaled let mlpOut0 := matMulIntervalsFromScaled cfg slack hiddenDim modelDim wOut actHidden addVecFixed mlpOut0 bOut +private def mlpRowFromScaledNoTask + (cfg : Fixed10Cfg) + (geluDerivTarget : GeluDerivTarget) + (slack : Int) + (modelDim hiddenDim : Nat) + (wIn wOut : Array Int) + (bIn bOut : Array Fixed10Interval) + (row : Array Fixed10Interval) : Array Fixed10Interval := + let hidden0 := matMulIntervalsFromScaledNoTask cfg slack modelDim hiddenDim wIn row + let hiddenB := addVecFixed hidden0 bIn + let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB + let mlpOut0 := matMulIntervalsFromScaledNoTask cfg slack hiddenDim modelDim wOut actHidden + addVecFixed mlpOut0 bOut + +/-- Linear GeLU-hull MLP row used to avoid the tanh/exp path in hot loops. 
-/ +private def mlpRowFromScaledLinear + (cfg : Fixed10Cfg) + (slack : Int) + (modelDim hiddenDim : Nat) + (wIn wOut : Array Int) + (bIn bOut : Array Fixed10Interval) + (row : Array Fixed10Interval) : Array Fixed10Interval := + let hidden0 := matMulIntervalsFromScaled cfg slack modelDim hiddenDim wIn row + let hiddenB := addVecFixed hidden0 bIn + let actHidden := geluOverapproxFixedVecLinear hiddenB + let mlpOut0 := matMulIntervalsFromScaled cfg slack hiddenDim modelDim wOut actHidden + addVecFixed mlpOut0 bOut + +private def mlpRowFromScaledLinearNoTask + (cfg : Fixed10Cfg) + (slack : Int) + (modelDim hiddenDim : Nat) + (wIn wOut : Array Int) + (bIn bOut : Array Fixed10Interval) + (row : Array Fixed10Interval) : Array Fixed10Interval := + let hidden0 := matMulIntervalsFromScaledNoTask cfg slack modelDim hiddenDim wIn row + let hiddenB := addVecFixed hidden0 bIn + let actHidden := geluOverapproxFixedVecLinear hiddenB + let mlpOut0 := matMulIntervalsFromScaledNoTask cfg slack hiddenDim modelDim wOut actHidden + addVecFixed mlpOut0 bOut + +private def mlpRowFromIntervalsNoTask + (cfg : Fixed10Cfg) + (geluDerivTarget : GeluDerivTarget) + (modelDim hiddenDim : Nat) + (wIn wOut : Array Fixed10Interval) + (bIn bOut : Array Fixed10Interval) + (row : Array Fixed10Interval) : Array Fixed10Interval := + let hidden0 := matMulIntervalsFromIntervalsNoTask cfg modelDim hiddenDim wIn row + let hiddenB := addVecFixed hidden0 bIn + let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB + let mlpOut0 := matMulIntervalsFromIntervalsNoTask cfg hiddenDim modelDim wOut actHidden + addVecFixed mlpOut0 bOut + +private def mlpRowFromIntervals + (cfg : Fixed10Cfg) + (geluDerivTarget : GeluDerivTarget) + (modelDim hiddenDim : Nat) + (wIn wOut : Array Fixed10Interval) + (bIn bOut : Array Fixed10Interval) + (row : Array Fixed10Interval) : Array Fixed10Interval := + let hidden0 := matMulIntervalsFromIntervals cfg modelDim hiddenDim wIn row + let hiddenB := addVecFixed hidden0 bIn + 
let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB + let mlpOut0 := matMulIntervalsFromIntervals cfg hiddenDim modelDim wOut actHidden + addVecFixed mlpOut0 bOut + +private def mlpRowsFromIntervals + (cfg : Fixed10Cfg) + (geluDerivTarget : GeluDerivTarget) + (modelDim hiddenDim : Nat) + (wIn wOut : Array Fixed10Interval) + (bIn bOut : Array Fixed10Interval) + (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := + let useTasks := rows.size > 32 + if useTasks then + Id.run do + let chunkSize : Nat := 16 + let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min rows.size (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + outChunk := outChunk.push + (mlpRowFromIntervalsNoTask cfg geluDerivTarget modelDim hiddenDim wIn wOut bIn + bOut rows[i]!) 
+ i := i + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for t in tasks do + for row in t.get do + out := out.push row + return out + else + rows.map (mlpRowFromIntervalsNoTask cfg geluDerivTarget modelDim hiddenDim wIn wOut bIn bOut) + private def mlpRowsFromScaled (cfg : Fixed10Cfg) (geluDerivTarget : GeluDerivTarget) @@ -1368,12 +2138,77 @@ private def mlpRowsFromScaled (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := let useTasks := rows.size > 32 if useTasks then - let tasks := rows.map (fun row => - Task.spawn (fun _ => - mlpRowFromScaled cfg geluDerivTarget slack modelDim hiddenDim wIn wOut bIn bOut row)) - tasks.map (fun t => t.get) + Id.run do + let chunkSize : Nat := 16 + let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min rows.size (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + outChunk := outChunk.push + (mlpRowFromScaledNoTask cfg geluDerivTarget slack modelDim hiddenDim wIn wOut bIn + bOut rows[i]!) + i := i + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for t in tasks do + for row in t.get do + out := out.push row + return out + else + rows.map (mlpRowFromScaledNoTask cfg geluDerivTarget slack modelDim hiddenDim wIn wOut bIn bOut) + +/-- Linear GeLU-hull per-row MLP for best-match induction hot paths. 
-/ +private def mlpRowsFromScaledLinear + (cfg : Fixed10Cfg) + (slack : Int) + (modelDim hiddenDim : Nat) + (wIn wOut : Array Int) + (bIn bOut : Array Fixed10Interval) + (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := + let wInIntervals := intervalsFromScaled wIn slack + let wOutIntervals := intervalsFromScaled wOut slack + let mlpRowFromIntervals (row : Array Fixed10Interval) : Array Fixed10Interval := + let hidden0 := matMulIntervalsFromIntervalsNoTask cfg modelDim hiddenDim wInIntervals row + let hiddenB := addVecFixed hidden0 bIn + let actHidden := geluOverapproxFixedVecLinear hiddenB + let mlpOut0 := matMulIntervalsFromIntervalsNoTask cfg hiddenDim modelDim wOutIntervals actHidden + addVecFixed mlpOut0 bOut + let useTasks := rows.size > 32 + if useTasks then + Id.run do + let chunkSize : Nat := 16 + let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min rows.size (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + outChunk := outChunk.push (mlpRowFromIntervals rows[i]!) 
+ i := i + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for t in tasks do + for row in t.get do + out := out.push row + return out else - rows.map (mlpRowFromScaled cfg geluDerivTarget slack modelDim hiddenDim wIn wOut bIn bOut) + rows.map mlpRowFromIntervals private def groupUnionRowsByToken (rows : Array (Array Fixed10Interval)) @@ -1383,7 +2218,8 @@ private def groupUnionRowsByToken return rows let mut uniqTokens : Array Int := #[] let mut uniqRows : Array (Array Fixed10Interval) := #[] - for i in [:rows.size] do + let mut i : Nat := 0 + while i < rows.size do let tok := tokens[i]! match uniqTokens.findIdx? (· == tok) with | some idx => @@ -1392,6 +2228,7 @@ private def groupUnionRowsByToken | none => uniqTokens := uniqTokens.push tok uniqRows := uniqRows.push rows[i]! + i := i + 1 return uniqRows private def unionRowsFixed @@ -1401,13 +2238,17 @@ private def unionRowsFixed else Id.run do let mut out := rows[0]! - for i in [1:rows.size] do + let mut i : Nat := 1 + while i < rows.size do let row := rows[i]! if row.size = out.size then - for j in [:out.size] do + let mut j : Nat := 0 + while j < out.size do let cur := out[j]! let r := row[j]! out := out.set! j { lo := min cur.lo r.lo, hi := max cur.hi r.hi } + j := j + 1 + i := i + 1 return out private def prefixUnionRowsFixed @@ -1419,9 +2260,11 @@ private def prefixUnionRowsFixed let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size let mut acc := rows[0]! out := out.push acc - for i in [1:rows.size] do + let mut i : Nat := 1 + while i < rows.size do acc := Fixed10Interval.unionVec acc rows[i]! 
out := out.push acc + i := i + 1 return out private def consumeMatrixMulAndNormInfFixed @@ -1437,9 +2280,11 @@ private def consumeMatrixMulAndNormInfFixed let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } let mut curRowAbs : Int := 0 let mut maxRowAbs : Int := 0 - for rowIdx in [:rows] do + let mut rowIdx : Nat := 0 + while rowIdx < rows do let xi := input[rowIdx]! - for colIdx in [:cols] do + let mut colIdx : Nat := 0 + while colIdx < cols do let (w, rr2) ← Nfp.Untrusted.SoundCacheIO.I32Reader.readI32 rr rr := rr2 let wAbsBound : Int := (if w < 0 then -w else w) + slack @@ -1447,8 +2292,10 @@ private def consumeMatrixMulAndNormInfFixed let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } let term := Fixed10Interval.mul cfg wI xi out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) term) + colIdx := colIdx + 1 maxRowAbs := max maxRowAbs curRowAbs curRowAbs := 0 + rowIdx := rowIdx + 1 let normInf : Rat := Rat.normalize maxRowAbs cfg.scaleNat (den_nz := by have h10pos : (0 : Nat) < 10 := by decide @@ -1473,9 +2320,11 @@ private def consumeMatrixMulAndNormInfFixedBinary let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } let mut curRowAbs : Int := 0 let mut maxRowAbs : Int := 0 - for rowIdx in [:rows] do + let mut rowIdx : Nat := 0 + while rowIdx < rows do let xi := input[rowIdx]! - for colIdx in [:cols] do + let mut colIdx : Nat := 0 + while colIdx < cols do let idx := rowIdx * cols + colIdx let w := vals[idx]! let wAbsBound : Int := (if w < 0 then -w else w) + slack @@ -1483,14 +2332,95 @@ private def consumeMatrixMulAndNormInfFixedBinary let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } let term := Fixed10Interval.mul cfg wI xi out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) 
term) + colIdx := colIdx + 1 maxRowAbs := max maxRowAbs curRowAbs curRowAbs := 0 + rowIdx := rowIdx + 1 let normInf : Rat := Rat.normalize maxRowAbs cfg.scaleNat (den_nz := by have h10pos : (0 : Nat) < 10 := by decide exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) return .ok (out, normInf) +private def consumeMatrixMulFixedBinaryStreaming + (cfg : Fixed10Cfg) + (slack : Int) + (h : IO.FS.Handle) + (rows cols : Nat) + (input : Array Fixed10Interval) + (scalePow10 : Nat) : + IO (Except String (Array Fixed10Interval)) := do + if input.size ≠ rows then + match ← skipF64Array h (rows * cols) with + | .error e => return .error e + | .ok _ => return .ok (Array.replicate cols { lo := 0, hi := 0 }) + let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } + let mut rowIdx : Nat := 0 + while rowIdx < rows do + let rowWeightsE ← readScaledFloatArray h cols scalePow10 + match rowWeightsE with + | .error e => return .error e + | .ok rowWeights => + let xi := input[rowIdx]! + let mut colIdx : Nat := 0 + while colIdx < cols do + let w := rowWeights[colIdx]! + let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } + let term := Fixed10Interval.mul cfg wI xi + out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) term) + colIdx := colIdx + 1 + rowIdx := rowIdx + 1 + return .ok out + +/-- Apply union-MLP propagation for binary bounds using streaming matmul. 
-/ +private def mlpUnionStepBinary + (cfg : Fixed10Cfg) + (slack : Int) + (h : IO.FS.Handle) + (modelDim hiddenDim : Nat) + (ln2Rows : Array (Array Fixed10Interval)) + (residuals : Array (Array Fixed10Interval)) + (scalePow10 : Nat) : + IO (Except String (Array (Array Fixed10Interval))) := do + let ln2Union := unionRowsFixed ln2Rows + let hidden0E ← + consumeMatrixMulFixedBinaryStreaming cfg slack h modelDim hiddenDim ln2Union scalePow10 + match hidden0E with + | .error e => return .error e + | .ok hidden0 => + let bInE ← readVecIntervalsBinary h hiddenDim slack scalePow10 + match bInE with + | .error e => return .error e + | .ok bIn => + let hiddenB := addVecFixed hidden0 bIn + -- Linear GeLU hull keeps the union path fast and avoids heavy tanh bounds. + let actHidden := geluOverapproxFixedVecLinear hiddenB + let mut mlpOut0 : Array Fixed10Interval := + Array.replicate modelDim { lo := 0, hi := 0 } + let mut rowIdx : Nat := 0 + while rowIdx < hiddenDim do + let rowWeightsE ← readScaledFloatArray h modelDim scalePow10 + match rowWeightsE with + | .error e => return .error e + | .ok rowWeights => + let xi := actHidden[rowIdx]! + let mut colIdx : Nat := 0 + while colIdx < modelDim do + let w := rowWeights[colIdx]! + let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } + let term := Fixed10Interval.mul cfg wI xi + mlpOut0 := mlpOut0.set! colIdx + (Fixed10Interval.add (mlpOut0[colIdx]!) 
term) + colIdx := colIdx + 1 + rowIdx := rowIdx + 1 + let bOutE ← readVecIntervalsBinary h modelDim slack scalePow10 + match bOutE with + | .error e => return .error e + | .ok bOut => + let mlpOut := addVecFixed mlpOut0 bOut + let residuals' := addVecFixedRows residuals mlpOut + return .ok residuals' + private def loadEmbeddingsUnionFixed (cfg : Fixed10Cfg) (path : System.FilePath) @@ -1632,15 +2562,47 @@ private def loadEmbeddingsIntervalsBinary | .ok scaled => if total = 0 then return .ok #[] - let mut rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for rowIdx in [:hdr.seqLen] do - let mut row : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for colIdx in [:hdr.modelDim] do - let idx := rowIdx * hdr.modelDim + colIdx - let v := scaled[idx]! - row := row.push { lo := v - deltaScaled, hi := v + deltaScaled } - rows := rows.push row - return .ok rows + let useTasks := hdr.seqLen > 32 + if useTasks then + let chunkSize : Nat := 16 + let numChunks : Nat := (hdr.seqLen + chunkSize - 1) / chunkSize + let mut tasks : + Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min hdr.seqLen (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut rowsChunk : + Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) + let mut rowIdx := start + while rowIdx < stop do + let mut row : Array Fixed10Interval := Array.mkEmpty hdr.modelDim + for colIdx in [:hdr.modelDim] do + let idx := rowIdx * hdr.modelDim + colIdx + let v := scaled[idx]! 
+ row := row.push { lo := v - deltaScaled, hi := v + deltaScaled } + rowsChunk := rowsChunk.push row + rowIdx := rowIdx + 1 + return rowsChunk) + chunkIdx := chunkIdx + 1 + let mut rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for t in tasks do + for row in t.get do + rows := rows.push row + return .ok rows + else + let mut rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen + for rowIdx in [:hdr.seqLen] do + let mut row : Array Fixed10Interval := Array.mkEmpty hdr.modelDim + for colIdx in [:hdr.modelDim] do + let idx := rowIdx * hdr.modelDim + colIdx + let v := scaled[idx]! + row := row.push { lo := v - deltaScaled, hi := v + deltaScaled } + rows := rows.push row + return .ok rows private def loadTokensBinary (path : System.FilePath) : IO (Except String (BinaryHeader × Array Int)) := do @@ -1677,12 +2639,23 @@ private def loadSharedBinaryInputs IO (Except String SharedBinaryInputs) := do let slack : Int := fixedUlpSlack let action : ExceptT String IO SharedBinaryInputs := do + let paramsTask ← + ExceptT.lift <| IO.asTask (collectLayerNormParamsBinary path scalePow10 slack) + let tokensTask ← + ExceptT.lift <| IO.asTask (loadTokensBinary inputPath) let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + match paramsTask.get with + | .error e => throw (toString e) + | .ok (.error msg) => throw msg + | .ok (.ok v) => pure v + let (hdrTok, tokens) ← + match tokensTask.get with + | .error e => throw (toString e) + | .ok (.error msg) => throw msg + | .ok (.ok v) => pure v let residuals0 ← ExceptT.mk (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) if hdrTok.seqLen ≠ hdr.seqLen then throw "token/embedding seq_len mismatch" return { @@ -1747,6 +2720,7 @@ private def certifyHeadValueLowerBoundLocalBinaryAt (inputPath : System.FilePath) (inputDelta : Rat) (targetOffset : Int) + (keyOffset : Int) 
(matchWeightLowerBound : Rat) (maxSeqLen : Nat) (scalePow10 : Nat := defaultBinaryScalePow10) @@ -1800,6 +2774,9 @@ private def certifyHeadValueLowerBoundLocalBinaryAt if causalPattern then takePrefix residualsBase seqLenEff else residualsBase let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase pure (residuals0, tokens) + let keyOffsetNat? : Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -1870,7 +2847,17 @@ private def certifyHeadValueLowerBoundLocalBinaryAt if !causalPattern || j ≤ queryPos then let row := vOutRows[j]! let vCoord := row[coord]!.lo - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLo? := match matchLo? with | none => some vCoord @@ -1927,13 +2914,16 @@ private def certifyHeadValueLowerBoundLocalBinaryAt let wo ← ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do + let mut rowIdx : Nat := 0 + while rowIdx < ln1Rows.size do + let row := ln1Rows[rowIdx]! 
let vHidden0 := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wv row let vHidden := addVecFixed vHidden0 bV let vOut := matMulIntervalsFromScaled cfg slack hdr.headDim hdr.modelDim wo vHidden vOutRows := vOutRows.push vOut + rowIdx := rowIdx + 1 let headRows := prefixUnionRowsFixed vOutRows attnRows := addRowsFixed attnRows headRows let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) @@ -1954,13 +2944,16 @@ private def certifyHeadValueLowerBoundLocalBinaryAt let wo ← ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do + let mut rowIdx : Nat := 0 + while rowIdx < groupRows.size do + let row := groupRows[rowIdx]! let vHidden0 := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wv row let vHidden := addVecFixed vHidden0 bV let vOut := matMulIntervalsFromScaled cfg slack hdr.headDim hdr.modelDim wo vHidden vOutRows := vOutRows.push vOut + rowIdx := rowIdx + 1 let vUnion := unionRowsFixed vOutRows attnUnion := addVecFixed attnUnion vUnion let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) @@ -1975,14 +2968,16 @@ private def certifyHeadValueLowerBoundLocalBinaryAt let _ ← ExceptT.mk (skipF64Array h hdr.headDim) let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) + let wv ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv ln1Union let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim 
hdr.modelDim vHidden scalePow10) + let wo ← + ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden attnUnion := addVecFixed attnUnion vOut let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) attnUnion := addVecFixed attnUnion attnBias @@ -2036,6 +3031,7 @@ private def certifyHeadValueLogitLowerBoundLocalBinaryAt (inputPath : System.FilePath) (inputDelta : Rat) (targetOffset : Int) + (keyOffset : Int) (matchWeightLowerBound : Rat) (maxSeqLen : Nat) (scalePow10 : Nat := defaultBinaryScalePow10) @@ -2092,6 +3088,9 @@ private def certifyHeadValueLogitLowerBoundLocalBinaryAt if causalPattern then takePrefix residualsBase seqLenEff else residualsBase let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase pure (residuals0, tokens) + let keyOffsetNat? : Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -2103,11 +3102,7 @@ private def certifyHeadValueLogitLowerBoundLocalBinaryAt let mut residuals := residuals0 for l in [:hdr.numLayers] do let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out + let ln1Rows := fixedLayerNormRowsApprox cfg residuals p1 eps soundnessBits if l = layerIdx then let mut wv? : Option (Array Int) := none let mut bv? : Option (Array Int) := none @@ -2166,7 +3161,17 @@ private def certifyHeadValueLogitLowerBoundLocalBinaryAt if !causalPattern || j ≤ queryPos then let row := vOutRows[j]! let vCoord := row[coord]!.lo - if tokens[j]! 
= targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLo? := match matchLo? with | none => some vCoord @@ -2208,15 +3213,34 @@ private def certifyHeadValueLogitLowerBoundLocalBinaryAt let dir := direction.get if dir.size ≠ hdr.modelDim then throw "logit direction size mismatch" - let mut vDotRows : Array Fixed10Interval := Array.mkEmpty seqLenEff - for row in vOutRows do - vDotRows := vDotRows.push (fixedDotInterval cfg row dir) + let vDotRows := + let useTasks := vOutRows.size > 32 + if useTasks then + let tasks := vOutRows.map (fun row => + Task.spawn (fun _ => fixedDotInterval cfg row dir)) + tasks.map (fun t => t.get) + else + Id.run do + let mut out : Array Fixed10Interval := Array.mkEmpty seqLenEff + for row in vOutRows do + out := out.push (fixedDotInterval cfg row dir) + return out let mut matchLoLogit? : Option Int := none let mut nonmatchLoLogit? : Option Int := none for j in [:seqLenEff] do if !causalPattern || j ≤ queryPos then let vLo := (vDotRows[j]!).lo - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLoLogit? := match matchLoLogit? 
with | none => some vLo @@ -2325,14 +3349,16 @@ private def certifyHeadValueLogitLowerBoundLocalBinaryAt let _ ← ExceptT.mk (skipF64Array h hdr.headDim) let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) + let wv ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 + let vHidden0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wv ln1Union let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) + let wo ← + ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + let vOut := matMulIntervalsFromScaled cfg slack + hdr.headDim hdr.modelDim wo vHidden attnUnion := addVecFixed attnUnion vOut let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) attnUnion := addVecFixed attnUnion attnBias @@ -2439,6 +3465,7 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt (inputPath : System.FilePath) (inputDelta : Rat) (targetOffset : Int) + (keyOffset : Int) (matchWeightLowerBound : Rat) (maxSeqLen : Nat) (scalePow10 : Nat := defaultBinaryScalePow10) @@ -2508,6 +3535,9 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt if causalPattern then takePrefix residualsBase seqLenEff else residualsBase let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase pure (residuals0, tokens) + let keyOffsetNat? 
: Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -2584,7 +3614,17 @@ private def certifyHeadLogitDiffLowerBoundLocalBinaryAt for j in [:seqLenEff] do if !causalPattern || j ≤ queryPos then let vLo := (vDotRows[j]!).lo - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLo? := match matchLo? with | none => some vLo @@ -3607,6 +4647,7 @@ private def certifyHeadPatternLocalBinary (inputPath : System.FilePath) (inputDelta : Rat) (targetOffset : Int) + (keyOffset : Int) (maxSeqLen : Nat) (tightPattern : Bool) (tightPatternLayers : Nat) @@ -3632,6 +4673,9 @@ private def certifyHeadPatternLocalBinary let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) if hdrTok.seqLen ≠ hdr.seqLen then throw "token/embedding seq_len mismatch" + let keyOffsetNat? : Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -3643,11 +4687,7 @@ private def certifyHeadPatternLocalBinary let mut residuals := residuals0 for l in [:hdr.numLayers] do let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out + let ln1Rows := fixedLayerNormRowsApprox cfg residuals p1 eps soundnessBits if l = layerIdx then let mut wq? 
: Option (Array Int) := none let mut bq? : Option (Array Int) := none @@ -3703,17 +4743,21 @@ private def certifyHeadPatternLocalBinary let bKIntervals := intervalsFromScaled bK slack let mut qRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do + let mut rowIdx : Nat := 0 + while rowIdx < ln1Rows.size do + let row := ln1Rows[rowIdx]! let qRow0 := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wq row let kRow0 := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wk row qRows := qRows.push (addVecFixed qRow0 bQIntervals) kRows := kRows.push (addVecFixed kRow0 bKIntervals) + rowIdx := rowIdx + 1 let mut minTargetLower? : Option Int := none let mut maxOtherUpper? : Option Int := none let mut minTargetCount? : Option Nat := none - for i in [:hdr.seqLen] do + let mut i : Nat := 0 + while i < hdr.seqLen do let ti : Int := (Int.ofNat i) + targetOffset if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then pure () @@ -3725,10 +4769,21 @@ private def certifyHeadPatternLocalBinary let mut targetMaxLower? : Option Int := none let mut maxOtherUpperRow? : Option Int := none let mut targetCount : Nat := 0 - for j in [:hdr.seqLen] do + let mut j : Nat := 0 + while j < hdr.seqLen do if !causalPattern || j ≤ i then let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < hdr.seqLen && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then targetCount := targetCount + 1 targetLower? := match targetLower? with @@ -3746,6 +4801,7 @@ private def certifyHeadPatternLocalBinary | some m => some (max m cur) else pure () + j := j + 1 let targetLowerRow? := if tightPattern then targetMaxLower? else targetLower? match targetLowerRow? 
with @@ -3767,6 +4823,7 @@ private def certifyHeadPatternLocalBinary match minTargetCount? with | none => some targetCount | some m => some (min m targetCount) + i := i + 1 let minTargetLower ← match minTargetLower? with | none => throw "no valid target positions for the requested offset" @@ -3792,6 +4849,7 @@ private def certifyHeadPatternLocalBinary headIdx := headIdx seqLen := hdr.seqLen targetOffset := targetOffset + keyOffset := keyOffset targetCountLowerBound := targetCountLB targetLogitLowerBound := targetLower otherLogitUpperBound := otherUpper @@ -3822,13 +4880,16 @@ private def certifyHeadPatternLocalBinary let wo ← ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do + let mut rowIdx : Nat := 0 + while rowIdx < ln1Rows.size do + let row := ln1Rows[rowIdx]! let vHidden0 := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wv row let vHidden := addVecFixed vHidden0 bV let vOut := matMulIntervalsFromScaled cfg slack hdr.headDim hdr.modelDim wo vHidden vOutRows := vOutRows.push vOut + rowIdx := rowIdx + 1 let headRows := prefixUnionRowsFixed vOutRows attnRows := addRowsFixed attnRows headRows let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) @@ -3849,54 +4910,21 @@ private def certifyHeadPatternLocalBinary let wo ← ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do + let mut rowIdx : Nat := 0 + while rowIdx < groupRows.size do + let row := groupRows[rowIdx]! 
let vHidden0 := matMulIntervalsFromScaled cfg slack hdr.modelDim hdr.headDim wv row let vHidden := addVecFixed vHidden0 bV let vOut := matMulIntervalsFromScaled cfg slack hdr.headDim hdr.modelDim wo vHidden vOutRows := vOutRows.push vOut + rowIdx := rowIdx + 1 let vUnion := unionRowsFixed vOutRows attnUnion := addVecFixed attnUnion vUnion let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) attnUnion := addVecFixed attnUnion attnBias residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim 
slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) else let ln1Union := unionRowsFixed ln1Rows let mut attnUnion : Array Fixed10Interval := @@ -3918,29 +4946,29 @@ private def certifyHeadPatternLocalBinary let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) attnUnion := addVecFixed attnUnion attnBias residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) + let p2 := ln2Params.getD l defP + let ln2Rows := fixedLayerNormRowsApprox cfg residuals p2 eps soundnessBits + let perRowLayers : Nat := perRowPatternLayers + if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then + let wIn ← + ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) + let wOut ← + ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - 
let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let mlpRows := + mlpRowsFromScaled cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows + residuals := addRowsFixed residuals mlpRows + else + let residuals' ← + ExceptT.mk (mlpUnionStepBinary cfg slack h + hdr.modelDim hdr.hiddenDim ln2Rows residuals scalePow10) + residuals := residuals' + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) throw "target layer not reached" action.run @@ -3988,10 +5016,12 @@ private def certifyHeadPatternBestMatchLocalBinary (inputPath : System.FilePath) (inputDelta : Rat) (targetOffset : Int) + (keyOffset : Int) (maxSeqLen : Nat) (tightPattern : Bool) (tightPatternLayers : Nat) (perRowPatternLayers : Nat) + (useAffine : Bool) (scalePow10 : Nat := defaultBinaryScalePow10) (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) (causalPattern : Bool := true) @@ -4044,6 +5074,9 @@ private def certifyHeadPatternBestMatchLocalBinary if causalPattern then takePrefix residualsBase seqLenEff else residualsBase let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase pure (residuals0, tokens) + let keyOffsetNat? 
: Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -4109,14 +5142,6 @@ private def certifyHeadPatternBestMatchLocalBinary | some xs => pure xs let bQIntervals := intervalsFromScaled bQ slack let bKIntervals := intervalsFromScaled bK slack - let qRow := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) - let qRow := addVecFixed qRow bQIntervals - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in ln1Rows do - let kRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row - kRows := kRows.push (addVecFixed kRow0 bKIntervals) let ti : Int := (Int.ofNat queryPos) + targetOffset if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then throw "query position has no valid target offset" @@ -4124,21 +5149,150 @@ private def certifyHeadPatternBestMatchLocalBinary let targetTok := tokens[tIdx]! let mut bestMatchLower? : Option Int := none let mut bestNonmatchUpper? : Option Int := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) + if useAffine then + let (qInputCenters, qInputRadii, _qAbsInput) := + rowCentersRadiiAbsInt (ln1Rows[queryPos]!) 
+ let (qCenters0, qRadii0) := + matMulCentersRadiiIntSlack cfg slack + hdr.modelDim hdr.headDim wq qInputCenters qInputRadii + let bQCenters := bQ + let bKCenters := bK + let bQRadii := bQIntervals.map intervalRadiusInt + let bKRadii := bKIntervals.map intervalRadiusInt + let qCenters := addVecScaledInt qCenters0 bQCenters 1 + let qRadii := addVecScaledInt qRadii0 bQRadii 1 + let useTasks := seqLenEff > 32 + if useTasks then + let chunkSize : Nat := 16 + let numChunks : Nat := (seqLenEff + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Option Int × Option Int)) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min seqLenEff (start + chunkSize) + tasks := tasks.push <| Task.spawn (fun _ => + Id.run do + let mut bestMatchLower? : Option Int := none + let mut bestNonmatchUpper? : Option Int := none + let mut j := start + while j < stop do + if !causalPattern || j ≤ queryPos then + let (kInputCenters, kInputRadii, _kAbsInput) := + rowCentersRadiiAbsInt (ln1Rows[j]!) + let (kCenters0, kRadii0) := + matMulCentersRadiiIntSlack cfg slack + hdr.modelDim hdr.headDim wk kInputCenters kInputRadii + let kCenters := addVecScaledInt kCenters0 bKCenters 1 + let kRadii := addVecScaledInt kRadii0 bKRadii 1 + let dot := + dotIntervalFromCentersRadiiInt cfg qCenters qRadii kCenters kRadii + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? 
with + | none => some dot.hi + | some m => some (max m dot.hi) + j := j + 1 + return (bestMatchLower?, bestNonmatchUpper?)) + chunkIdx := chunkIdx + 1 + for t in tasks do + let (matchChunk?, nonmatchChunk?) := t.get + if matchChunk?.isSome then + bestMatchLower? := + match bestMatchLower?, matchChunk? with + | none, some v => some v + | some cur, some v => some (max cur v) + | some cur, none => some cur + | none, none => none + if nonmatchChunk?.isSome then + bestNonmatchUpper? := + match bestNonmatchUpper?, nonmatchChunk? with + | none, some v => some v + | some cur, some v => some (max cur v) + | some cur, none => some cur + | none, none => none else - pure () + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let (kInputCenters, kInputRadii, _kAbsInput) := + rowCentersRadiiAbsInt (ln1Rows[j]!) + let (kCenters0, kRadii0) := + matMulCentersRadiiIntSlack cfg slack + hdr.modelDim hdr.headDim wk kInputCenters kInputRadii + let kCenters := addVecScaledInt kCenters0 bKCenters 1 + let kRadii := addVecScaledInt kRadii0 bKRadii 1 + let dot := + dotIntervalFromCentersRadiiInt cfg qCenters qRadii kCenters kRadii + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? with + | none => some dot.hi + | some m => some (max m dot.hi) + else + pure () + else + let qRow := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) 
+ let qRow := addVecFixed qRow bQIntervals + let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff + for row in ln1Rows do + let kRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wk row + kRows := kRows.push (addVecFixed kRow0 bKIntervals) + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let dot := fixedDotInterval cfg qRow (kRows[j]!) + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? with + | none => some dot.hi + | some m => some (max m dot.hi) + else + pure () let bestMatchLower ← match bestMatchLower? with | none => throw "no matching keys for the requested offset" @@ -4159,6 +5313,7 @@ private def certifyHeadPatternBestMatchLocalBinary seqLen := hdr.seqLen queryPos := queryPos targetOffset := targetOffset + keyOffset := keyOffset targetToken := targetTok bestMatchLogitLowerBound := bestMatchLowerRat bestNonmatchLogitUpperBound := bestNonmatchUpperRat @@ -4298,10 +5453,12 @@ private def certifyHeadPatternBestMatchLocalBinarySweep (inputPath : System.FilePath) (inputDelta : Rat) (targetOffset : Int) + (keyOffset : Int) (maxSeqLen : Nat) (tightPattern : Bool) (tightPatternLayers : Nat) (perRowPatternLayers : Nat) + (useAffine : Bool) (scalePow10 : Nat := defaultBinaryScalePow10) (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) (causalPattern : Bool := true) @@ -4334,6 +5491,8 @@ private def certifyHeadPatternBestMatchLocalBinarySweep throw s!"head index {headIdx} out of range" if hdr.seqLen > maxSeqLen then throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" + if useAffine then + throw "affine sweep is unsupported; use --bestMatch without 
--sweep" let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -4419,6 +5578,9 @@ private def certifyHeadPatternBestMatchLocalBinarySweep out if validPositions.isEmpty then throw "no valid query positions for the requested offset" + let keyOffsetNat? : Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let computeCert : Nat → Except String HeadBestMatchPatternCert := fun queryPos => do let ti : Int := (Int.ofNat queryPos) + targetOffset if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then @@ -4431,7 +5593,17 @@ private def certifyHeadPatternBestMatchLocalBinarySweep for j in [:hdr.seqLen] do if !causalPattern || j ≤ queryPos then let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < hdr.seqLen && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then bestMatchLower? := match bestMatchLower? with | none => some dot.lo @@ -4463,6 +5635,7 @@ private def certifyHeadPatternBestMatchLocalBinarySweep seqLen := hdr.seqLen queryPos := queryPos targetOffset := targetOffset + keyOffset := keyOffset targetToken := targetTok bestMatchLogitLowerBound := bestMatchLowerRat bestNonmatchLogitUpperBound := bestNonmatchUpperRat @@ -4645,6 +5818,9 @@ private def certifyHeadValueLowerBoundLocalBinary throw "token/embedding seq_len mismatch" if pattern.seqLen ≠ hdr.seqLen then throw "pattern seq_len mismatch" + let keyOffsetNat? 
: Option Nat := + if pattern.keyOffset ≥ 0 then some (Int.toNat pattern.keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-pattern.keyOffset) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -4723,7 +5899,17 @@ private def certifyHeadValueLowerBoundLocalBinary if !causalPattern || j ≤ i then let row := vOutRows[j]! let vCoord := row[coord]!.lo - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < hdr.seqLen && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLo? := match matchLo? with | none => some vCoord @@ -4935,6 +6121,9 @@ private def certifyHeadLogitDiffLowerBoundLocalBinary throw "token/embedding seq_len mismatch" if pattern.seqLen ≠ hdr.seqLen then throw "pattern seq_len mismatch" + let keyOffsetNat? : Option Nat := + if pattern.keyOffset ≥ 0 then some (Int.toNat pattern.keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-pattern.keyOffset) let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -5015,7 +6204,17 @@ private def certifyHeadLogitDiffLowerBoundLocalBinary for j in [:hdr.seqLen] do if !causalPattern || j ≤ i then let vLo := (vDotRows[j]!).lo - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < hdr.seqLen && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLo? := match matchLo? with | none => some vLo @@ -5272,6 +6471,7 @@ def certifyHeadPatternLocal (inputPath? 
: Option System.FilePath := none) (inputDelta : Rat := 0) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -5287,7 +6487,8 @@ def certifyHeadPatternLocal if firstLine = "NFP_BINARY_V1" then let inputPath := inputPath?.getD path certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers + scalePow10 softmaxExpEffort causalPattern else return .error "head pattern bounds require NFP_BINARY_V1" @@ -5302,10 +6503,12 @@ def certifyHeadPatternBestMatchLocal (inputPath? : Option System.FilePath := none) (inputDelta : Rat := 0) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) + (useAffine : Bool := false) (scalePow10 : Nat := defaultBinaryScalePow10) (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) (causalPattern : Bool := true) : @@ -5318,8 +6521,8 @@ def certifyHeadPatternBestMatchLocal let inputPath := inputPath?.getD path certifyHeadPatternBestMatchLocalBinary path layerIdx headIdx queryPos? eps soundnessBits inputPath - inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 softmaxExpEffort causalPattern + inputDelta targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers + perRowPatternLayers useAffine scalePow10 softmaxExpEffort causalPattern else return .error "head pattern bounds require NFP_BINARY_V1" @@ -5332,10 +6535,12 @@ def certifyHeadPatternBestMatchLocalSweep (inputPath? 
: Option System.FilePath := none) (inputDelta : Rat := 0) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) + (useAffine : Bool := false) (scalePow10 : Nat := defaultBinaryScalePow10) (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) (causalPattern : Bool := true) : @@ -5347,8 +6552,9 @@ def certifyHeadPatternBestMatchLocalSweep if firstLine = "NFP_BINARY_V1" then let inputPath := inputPath?.getD path certifyHeadPatternBestMatchLocalBinarySweep path layerIdx headIdx eps soundnessBits inputPath - inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers - scalePow10 softmaxExpEffort causalPattern + inputDelta targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers + perRowPatternLayers + useAffine scalePow10 softmaxExpEffort causalPattern else return .error "head pattern bounds require NFP_BINARY_V1" @@ -5361,6 +6567,7 @@ def certifyLayerBestMatchMarginLocal (inputPath? : Option System.FilePath := none) (inputDelta : Rat := 0) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -5387,8 +6594,8 @@ def certifyLayerBestMatchMarginLocal for hIdx in [:hdr.numHeads] do match ← certifyHeadPatternBestMatchLocalBinarySweep - path layerIdx hIdx eps soundnessBits inputPath inputDelta targetOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + path layerIdx hIdx eps soundnessBits inputPath inputDelta targetOffset keyOffset + maxSeqLen tightPattern tightPatternLayers perRowPatternLayers false scalePow10 softmaxExpEffort causalPattern with | .error e => return .error e | .ok certs => @@ -5422,6 +6629,7 @@ def certifyHeadValueLowerBoundLocal (inputPath? 
: Option System.FilePath := none) (inputDelta : Rat := 0) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -5437,7 +6645,8 @@ def certifyHeadValueLowerBoundLocal let inputPath := inputPath?.getD path let patternE ← certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers + scalePow10 defaultSoftmaxExpEffort causalPattern match patternE with | .error e => return .error e @@ -5457,6 +6666,7 @@ def certifyHeadLogitDiffLowerBoundLocal (inputPath? : Option System.FilePath := none) (inputDelta : Rat := 0) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -5472,7 +6682,8 @@ def certifyHeadLogitDiffLowerBoundLocal let inputPath := inputPath?.getD path let patternE ← certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers + scalePow10 defaultSoftmaxExpEffort causalPattern match patternE with | .error e => return .error e @@ -5493,6 +6704,8 @@ def certifyInductionSound (inputDelta : Rat := 0) (offset1 : Int := -1) (offset2 : Int := -1) + (keyOffset1 : Int := 0) + (keyOffset2 : Int := 0) (maxSeqLen : Nat := 256) (scalePow10 : Nat := defaultBinaryScalePow10) (tightPattern : Bool := false) @@ -5511,14 +6724,15 @@ def certifyInductionSound let inputPath := inputPath?.getD path let p1E ← certifyHeadPatternLocalBinary path layer1 head1 eps soundnessBits inputPath inputDelta - offset1 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + offset1 keyOffset1 maxSeqLen 
tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort causalPattern match p1E with | .error e => return .error e | .ok p1 => let p2E ← certifyHeadPatternLocalBinary path layer2 head2 eps soundnessBits inputPath inputDelta - offset2 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + offset2 keyOffset2 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers + scalePow10 softmaxExpEffort causalPattern match p2E with | .error e => return .error e @@ -5567,11 +6781,14 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (inputDelta : Rat) (offset1 : Int) (offset2 : Int) + (keyOffset1 : Int) + (keyOffset2 : Int) (maxSeqLen : Nat) (scalePow10 : Nat := defaultBinaryScalePow10) (tightPattern : Bool) (tightPatternLayers : Nat) (perRowPatternLayers : Nat) + (useAffine : Bool) (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) (causalPattern : Bool := true) (shared? : Option SharedBinaryInputs := none) @@ -5627,6 +6844,53 @@ private def certifyInductionSoundBestMatchLocalBinaryPair if causalPattern then takePrefix residualsBase seqLenEff else residualsBase let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase pure (residuals0, tokens) + let matchRows + (targetOffset : Int) + (keyOffset : Int) : Array Nat := + Id.run do + let ti : Int := (Int.ofNat queryPos) + targetOffset + if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then + return #[] + let tIdx : Nat := Int.toNat ti + let targetTok := tokens[tIdx]! + let keyOffsetNat? : Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) + let mut rows : Array Nat := Array.mkEmpty seqLenEff + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! 
= targetTok + if isMatch then + rows := rows.push j + else + pure () + return rows + let selectedRows : Array Nat := + let r1 := matchRows offset1 keyOffset1 + let r2 := matchRows offset2 keyOffset2 + if r1.isEmpty && r2.isEmpty then + #[] + else + Id.run do + let mut acc : Array Nat := Array.mkEmpty (r1.size + r2.size) + for v in r1 do + if !acc.contains v then + acc := acc.push v + for v in r2 do + if !acc.contains v then + acc := acc.push v + acc + let selectedRows? : Option (Array Nat) := + if selectedRows.isEmpty then none else some selectedRows let useLogit ← match targetToken?, negativeToken?, direction? with | none, none, none => pure false @@ -5637,21 +6901,59 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (p : LayerNormParamsFixed) : Array (Array Fixed10Interval) := fixedLayerNormRowsApprox cfg rows p eps soundnessBits + let calcLnRowsExact + (rows : Array (Array Fixed10Interval)) + (p : LayerNormParamsFixed) : + Array (Array Fixed10Interval) := + fixedLayerNormRowsApproxExact cfg rows p eps soundnessBits let calcVOutRows (rows : Array (Array Fixed10Interval)) (wv wo : Array Int) (bV : Array Fixed10Interval) : Array (Array Fixed10Interval) := - Id.run do - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for row in rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - out := out.push vOut - return out + let wvIntervals := intervalsFromScaled wv slack + let woIntervals := intervalsFromScaled wo slack + let useTasks := rows.size > 32 + if useTasks then + Id.run do + let chunkSize : Nat := 16 + let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop 
:= min rows.size (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + let vHidden0 := matMulIntervalsFromIntervalsNoTask cfg + hdr.modelDim hdr.headDim wvIntervals (rows[i]!) + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromIntervalsNoTask cfg + hdr.headDim hdr.modelDim woIntervals vHidden + outChunk := outChunk.push vOut + i := i + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for t in tasks do + for row in t.get do + out := out.push row + return out + else + Id.run do + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size + for row in rows do + let vHidden0 := matMulIntervalsFromIntervalsNoTask cfg + hdr.modelDim hdr.headDim wvIntervals row + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromIntervalsNoTask cfg + hdr.headDim hdr.modelDim woIntervals vHidden + out := out.push vOut + return out let calcVOut (row : Array Fixed10Interval) (wv wo : Array Int) @@ -5666,38 +6968,176 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (ln1Rows : Array (Array Fixed10Interval)) (wq wk : Array Int) (bQ bK : Array Fixed10Interval) - (targetOffset : Int) : + (targetOffset : Int) + (keyOffset : Int) + (useTasks : Bool := true) : ExceptT String IO HeadBestMatchPatternCert := do - let qRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) 
- let qRow := addVecFixed qRow0 bQ - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in ln1Rows do - let kRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row - kRows := kRows.push (addVecFixed kRow0 bK) let ti : Int := (Int.ofNat queryPos) + targetOffset if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then throw "query position has no valid target offset" let tIdx : Nat := Int.toNat ti let targetTok := tokens[tIdx]! + let keyOffsetNat? : Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let mut bestMatchLower? : Option Int := none let mut bestNonmatchUpper? : Option Int := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) + if useAffine then + let bQCenters := bQ.map (fun x => (x.lo + x.hi).ediv (Int.ofNat 2)) + let bKCenters := bK.map (fun x => (x.lo + x.hi).ediv (Int.ofNat 2)) + let bQRadii := bQ.map intervalRadiusInt + let bKRadii := bK.map intervalRadiusInt + let (qInputCenters, qInputRadii, _qAbsInput) := + rowCentersRadiiAbsInt (ln1Rows[queryPos]!) 
+ let (qCenters0, qRadii0) := + matMulCentersRadiiIntSlack cfg slack + hdr.modelDim hdr.headDim wq qInputCenters qInputRadii + let qCenters := addVecScaledInt qCenters0 bQCenters 1 + let qRadii := addVecScaledInt qRadii0 bQRadii 1 + let useTasksHere := useTasks && seqLenEff > 32 + if useTasksHere then + let chunkSize : Nat := 16 + let numChunks : Nat := (seqLenEff + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Option Int × Option Int)) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min seqLenEff (start + chunkSize) + tasks := tasks.push <| Task.spawn (fun _ => + Id.run do + let mut bestMatchLower? : Option Int := none + let mut bestNonmatchUpper? : Option Int := none + let mut j := start + while j < stop do + if !causalPattern || j ≤ queryPos then + let (kInputCenters, kInputRadii, _kAbsInput) := + rowCentersRadiiAbsInt (ln1Rows[j]!) + let (kCenters0, kRadii0) := + matMulCentersRadiiIntSlack cfg slack + hdr.modelDim hdr.headDim wk kInputCenters kInputRadii + let kCenters := addVecScaledInt kCenters0 bKCenters 1 + let kRadii := addVecScaledInt kRadii0 bKRadii 1 + let dot := + dotIntervalFromCentersRadiiInt cfg qCenters qRadii kCenters kRadii + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? with + | none => some dot.hi + | some m => some (max m dot.hi) + j := j + 1 + return (bestMatchLower?, bestNonmatchUpper?)) + chunkIdx := chunkIdx + 1 + for t in tasks do + let (matchChunk?, nonmatchChunk?) := t.get + if matchChunk?.isSome then + bestMatchLower? := + match bestMatchLower?, matchChunk? 
with + | none, some v => some v + | some cur, some v => some (max cur v) + | some cur, none => some cur + | none, none => none + if nonmatchChunk?.isSome then + bestNonmatchUpper? := + match bestNonmatchUpper?, nonmatchChunk? with + | none, some v => some v + | some cur, some v => some (max cur v) + | some cur, none => some cur + | none, none => none else - pure () + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let (kInputCenters, kInputRadii, _kAbsInput) := + rowCentersRadiiAbsInt (ln1Rows[j]!) + let (kCenters0, kRadii0) := + matMulCentersRadiiIntSlack cfg slack + hdr.modelDim hdr.headDim wk kInputCenters kInputRadii + let kCenters := addVecScaledInt kCenters0 bKCenters 1 + let kRadii := addVecScaledInt kRadii0 bKRadii 1 + let dot := + dotIntervalFromCentersRadiiInt cfg qCenters qRadii kCenters kRadii + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? with + | none => some dot.hi + | some m => some (max m dot.hi) + else + pure () + else + let qRow0 := matMulIntervalsFromScaled cfg slack + hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) 
+ let qRow := addVecFixed qRow0 bQ + let kRows := + let useTasksHere := useTasks && ln1Rows.size > 32 + if useTasksHere then + let tasks := ln1Rows.map (fun row => + Task.spawn (fun _ => + let kRow0 := matMulIntervalsFromScaledNoTask cfg slack + hdr.modelDim hdr.headDim wk row + addVecFixed kRow0 bK)) + tasks.map (fun t => t.get) + else + Id.run do + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff + for row in ln1Rows do + let kRow0 := matMulIntervalsFromScaledNoTask cfg slack + hdr.modelDim hdr.headDim wk row + out := out.push (addVecFixed kRow0 bK) + return out + for j in [:seqLenEff] do + if !causalPattern || j ≤ queryPos then + let dot := fixedDotInterval cfg qRow (kRows[j]!) + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then + bestMatchLower? := + match bestMatchLower? with + | none => some dot.lo + | some m => some (max m dot.lo) + else + bestNonmatchUpper? := + match bestNonmatchUpper? with + | none => some dot.hi + | some m => some (max m dot.hi) + else + pure () let bestMatchLower ← match bestMatchLower? 
with | none => throw "no matching keys for the requested offset" @@ -5718,6 +7158,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair seqLen := hdr.seqLen queryPos := queryPos targetOffset := targetOffset + keyOffset := keyOffset targetToken := targetTok bestMatchLogitLowerBound := bestMatchLowerRat bestNonmatchLogitUpperBound := bestNonmatchUpperRat @@ -5734,7 +7175,8 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (matchWeightLowerBound : Rat) (wv wo : Array Int) (bV : Array Fixed10Interval) - (targetOffset : Int) : + (targetOffset : Int) + (keyOffset : Int) : ExceptT String IO HeadValueLogitCert := do let vOutRows := calcVOutRows ln1Rows wv wo bV let ti : Int := (Int.ofNat queryPos) + targetOffset @@ -5742,13 +7184,26 @@ private def certifyInductionSoundBestMatchLocalBinaryPair throw "query position has no valid target offset" let tIdx : Nat := Int.toNat ti let targetTok := tokens[tIdx]! + let keyOffsetNat? : Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let mut matchLo? : Option Int := none let mut nonmatchLo? : Option Int := none for j in [:seqLenEff] do if !causalPattern || j ≤ queryPos then let row := vOutRows[j]! let vCoord := row[coord]!.lo - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLo? := match matchLo? with | none => some vCoord @@ -5800,7 +7255,17 @@ private def certifyInductionSoundBestMatchLocalBinaryPair for j in [:seqLenEff] do if !causalPattern || j ≤ queryPos then let vLo := (vDotRows[j]!).lo - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! 
= targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLoLogit? := match matchLoLogit? with | none => some vLo @@ -5845,20 +7310,34 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (baseRow : Array Fixed10Interval) (vOutRows : Array (Array Fixed10Interval)) (matchWeightLowerBound : Rat) - (targetOffset : Int) : + (targetOffset : Int) + (keyOffset : Int) : ExceptT String IO (Array Fixed10Interval) := do let ti : Int := (Int.ofNat queryPos) + targetOffset if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then throw "query position has no valid target offset" let tIdx : Nat := Int.toNat ti let targetTok := tokens[tIdx]! + let keyOffsetNat? : Option Nat := + if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none + let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let mut matchLo? : Option (Array Int) := none let mut nonmatchLo? : Option (Array Int) := none for j in [:seqLenEff] do if !causalPattern || j ≤ queryPos then let row := vOutRows[j]! let rowLo : Array Int := row.map (fun x => x.lo) - if tokens[j]! = targetTok then + let isMatch : Bool := + match keyOffsetNat? with + | some k => + let idx := j + k + idx < seqLenEff && tokens[idx]! = targetTok + | none => + if j < keyOffsetNeg then + false + else + tokens[j - keyOffsetNeg]! = targetTok + if isMatch then matchLo? := match matchLo? with | none => some rowLo @@ -5960,15 +7439,48 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (bIn bOut : Array Fixed10Interval) : Array (Array Fixed10Interval) := let ln2Rows := calcLnRows rows p + let geluTargetUnion : GeluDerivTarget := + if hdr.geluDerivTarget = .tanh then .exact else hdr.geluDerivTarget if usePerRow then - let mlpRows := mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - addRowsFixed rows mlpRows + match selectedRows? 
with + | none => + let wInIntervals := intervalsFromScaled wIn slack + let wOutIntervals := intervalsFromScaled wOut slack + let mlpRows := mlpRowsFromIntervals cfg geluTargetUnion + hdr.modelDim hdr.hiddenDim wInIntervals wOutIntervals bIn bOut ln2Rows + addRowsFixed rows mlpRows + | some idxs => + Id.run do + let ln2Union := unionRowsFixed ln2Rows + let mlpUnion := mlpRowFromScaled cfg geluTargetUnion slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union + let mut out := addVecFixedRows rows mlpUnion + for idx in idxs do + if idx < ln2Rows.size then + let mlpRow := mlpRowFromScaledNoTask cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut (ln2Rows[idx]!) + out := out.set! idx (addVecFixed (rows[idx]!) mlpRow) + return out else let ln2Union := unionRowsFixed ln2Rows - let mlpOut := mlpRowFromScaled cfg hdr.geluDerivTarget slack + let mlpOut := mlpRowFromScaled cfg geluTargetUnion slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union addVecFixedRows rows mlpOut + let awaitPattern + (pattern? : Option HeadBestMatchPatternCert) + (task? : Option (Task (Except IO.Error (Except String HeadBestMatchPatternCert)))) + (label : String) : + ExceptT String IO HeadBestMatchPatternCert := do + match pattern? with + | some cert => pure cert + | none => + match task? with + | none => throw label + | some task => + match task.get with + | .error e => throw (toString e) + | .ok (.error msg) => throw msg + | .ok (.ok cert) => pure cert let h ← IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) @@ -5984,6 +7496,10 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let mut residualsSameV : Bool := true let mut p1? : Option HeadBestMatchPatternCert := none let mut p2? : Option HeadBestMatchPatternCert := none + let mut p1Task? : + Option (Task (Except IO.Error (Except String HeadBestMatchPatternCert))) := none + let mut p2Task? 
: + Option (Task (Except IO.Error (Except String HeadBestMatchPatternCert))) := none let mut vlogit? : Option HeadValueLogitCert := none for l in [:hdr.numLayers] do let at1 := l = layer1 && p1?.isNone @@ -6006,6 +7522,12 @@ private def certifyInductionSoundBestMatchLocalBinaryPair if needRows2 then ln1Rows2? := some (ln1RowsShared?.getD (calcLnRows residuals2 ln1P)) + let mut ln1Rows1Exact? : Option (Array (Array Fixed10Interval)) := none + let mut ln1Rows2Exact? : Option (Array (Array Fixed10Interval)) := none + if at1 then + ln1Rows1Exact? := some (calcLnRowsExact residuals1 ln1P) + if at2 then + ln1Rows2Exact? := some (calcLnRowsExact residuals2 ln1P) let mut ln1RowsV? : Option (Array (Array Fixed10Interval)) := none if needRowsV then if residualsSameV then @@ -6090,7 +7612,10 @@ private def certifyInductionSoundBestMatchLocalBinaryPair else ln1UnionV? := some (unionRowsFixed ln1RowsV) attnUnionV? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) + let needUpdate := needUpdate1 || needUpdate2 for hIdx in [:hdr.numHeads] do + let needValue := at2 && hIdx = head2 + let needV := needUpdate || needValue let needQK := (at1 && hIdx = head1) || (at2 && hIdx = head2) if needQK then let wq ← @@ -6102,19 +7627,40 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let bQIntervals := intervalsFromScaled bQ slack let bKIntervals := intervalsFromScaled bK slack if at1 && hIdx = head1 then - let p1 ← bestMatchPattern layer1 head1 ln1Rows1 wq wk bQIntervals bKIntervals offset1 - p1? := some p1 + let ln1Rows1Exact := ln1Rows1Exact?.getD ln1Rows1 + if needV then + let task ← + ExceptT.lift <| + IO.asTask + (bestMatchPattern + layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 keyOffset1 + (useTasks := false)).run + p1Task? := some task + else + let p1 ← + bestMatchPattern + layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 keyOffset1 + p1? 
:= some p1 if at2 && hIdx = head2 then - let p2 ← bestMatchPattern layer2 head2 ln1Rows2 wq wk bQIntervals bKIntervals offset2 - p2? := some p2 + let ln1Rows2Exact := ln1Rows2Exact?.getD ln1Rows2 + if needV then + let task ← + ExceptT.lift <| + IO.asTask + (bestMatchPattern + layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 keyOffset2 + (useTasks := false)).run + p2Task? := some task + else + let p2 ← + bestMatchPattern + layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 keyOffset2 + p2? := some p2 else let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let needUpdate := needUpdate1 || needUpdate2 - let needValue := at2 && hIdx = head2 - let needV := needUpdate || needValue if needV then let wv ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 @@ -6139,14 +7685,14 @@ private def certifyInductionSoundBestMatchLocalBinaryPair if needUpdate2 then if l == layer1 && hIdx == head1 && useTight2 && causalPattern then let p1 ← - match p1? with - | some cert => pure cert - | none => throw "missing best-match pattern cert for tightening" + awaitPattern p1? p1Task? "missing best-match pattern cert for tightening" + p1? := some p1 let vOutRows := calcVOutRows ln1Rows2 wv wo bVIntervals let mut headRows := prefixUnionRowsFixed vOutRows let baseRow := headRows[queryPos]! let tightRow ← tightenQueryRowLower baseRow vOutRows p1.bestMatchWeightLowerBound offset1 + keyOffset1 headRows := headRows.set! queryPos tightRow match attnRows2? with | some rows => attnRows2? := some (addRowsFixed rows headRows) @@ -6165,11 +7711,11 @@ private def certifyInductionSoundBestMatchLocalBinaryPair attnUnionV? := attnUnion' if needValue then let p2 ← - match p2? 
with - | some cert => pure cert - | none => throw "missing best-match pattern cert for value bound" + awaitPattern p2? p2Task? "missing best-match pattern cert for value bound" + p2? := some p2 let vlogit ← valueLogit ln1RowsV p2.bestMatchWeightLowerBound wv wo bVIntervals offset2 + keyOffset2 vlogit? := some vlogit else let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) @@ -6234,6 +7780,14 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) + if l == layer1 && p1?.isNone then + let p1 ← + awaitPattern p1? p1Task? "missing best-match pattern cert for layer1" + p1? := some p1 + if l == layer2 && p2?.isNone then + let p2 ← + awaitPattern p2? p2Task? "missing best-match pattern cert for layer2" + p2? := some p2 match p1?, p2?, vlogit? with | some p1, some p2, some vlogit => let cert : InductionHeadBestMatchSoundCert := { @@ -6263,11 +7817,14 @@ def certifyInductionSoundBestMatch (inputDelta : Rat := 0) (offset1 : Int := -1) (offset2 : Int := -1) + (keyOffset1 : Int := 0) + (keyOffset2 : Int := 0) (maxSeqLen : Nat := 256) (scalePow10 : Nat := defaultBinaryScalePow10) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) + (useAffine : Bool := false) (iterTighten : Bool := false) (targetToken? : Option Nat := none) (negativeToken? : Option Nat := none) @@ -6295,10 +7852,116 @@ def certifyInductionSoundBestMatch let dtMs := dtNs / 1000000 ExceptT.lift <| IO.eprintln s!"timing:{label} {dtMs}ms" return r + let loadSharedAndDirection (scalePow10 : Nat) : + ExceptT String IO (SharedBinaryInputs × Option (Thunk (Array Fixed10Interval))) := do + let sharedTask ← + ExceptT.lift <| IO.asTask (loadSharedBinaryInputs path inputPath inputDelta scalePow10) + let directionTask? ← + match targetToken?, negativeToken? 
with + | none, none => pure none + | some targetToken, some negativeToken => + let task ← + ExceptT.lift <| + IO.asTask + (readLogitDiffDirectionBinary + path targetToken negativeToken scalePow10 fixedUlpSlack) + pure (some task) + | _, _ => + throw "use both target and negative tokens (or neither)" + let shared ← + match sharedTask.get with + | .error e => throw (toString e) + | .ok (.error msg) => throw msg + | .ok (.ok v) => pure v + let direction? ← + match directionTask? with + | none => pure none + | some task => + match task.get with + | .error e => throw (toString e) + | .ok (.error msg) => throw msg + | .ok (.ok (hdrDir, dir)) => + if hdrDir.modelDim ≠ shared.hdr.modelDim then + throw "unembedding model_dim mismatch" + pure (some (Thunk.mk (fun () => dir))) + return (shared, direction?) let computeBestAtScale (scalePow10 : Nat) (configs : Array (Bool × Nat × Nat)) : ExceptT String IO (Rat × InductionHeadBestMatchSoundCert) := do - let shared ← ExceptT.mk (loadSharedBinaryInputs path inputPath inputDelta scalePow10) + let (shared, direction?) ← loadSharedAndDirection scalePow10 + let queryPos : Nat := + match queryPos? with + | some q => q + | none => + if shared.hdr.seqLen = 0 then 0 else shared.hdr.seqLen - 1 + if queryPos ≥ shared.hdr.seqLen then + throw s!"queryPos {queryPos} out of range" + let prefixCache := mkSharedBinaryPrefix shared queryPos causalPattern + let computeCert (useTight : Bool) (tightLayers perRowLayers : Nat) : + ExceptT String IO InductionHeadBestMatchSoundCert := do + let label := + s!"scale={scalePow10} tight={useTight} tl={tightLayers} pr={perRowLayers}" + let cert ← + timeIt (s!"{label}:pair") <| + ExceptT.mk <| + certifyInductionSoundBestMatchLocalBinaryPair + path layer1 head1 layer2 head2 coord queryPos eps soundnessBits inputPath + inputDelta offset1 offset2 keyOffset1 keyOffset2 maxSeqLen scalePow10 useTight + tightLayers perRowLayers useAffine softmaxExpEffort causalPattern + (shared? := some shared) (prefix? 
:= some prefixCache) + (targetToken? := targetToken?) (negativeToken? := negativeToken?) + (direction? := direction?) + return cert + let metricOf (cert : InductionHeadBestMatchSoundCert) : Rat := + match cert.layer2Logit? with + | some logit => logit.logitDiffLowerBound + | none => cert.deltaLowerBound + -- Avoid nested task pools when per-row MLP already spawns tasks. + let parallelConfigs : Bool := + configs.size > 1 && configs.all (fun (_, _, perRowLayers) => perRowLayers = 0) + let mut best : Option (Rat × InductionHeadBestMatchSoundCert) := none + if parallelConfigs then + let tasks ← + ExceptT.lift <| + configs.mapM fun (useTight, tightLayers, perRowLayers) => + IO.asTask (computeCert useTight tightLayers perRowLayers).run + let results := tasks.map (fun t => t.get) + for i in [:configs.size] do + let res := results[i]! + match res with + | .error e => throw (toString e) + | .ok (.error msg) => throw msg + | .ok (.ok cert) => + let metric := metricOf cert + best := + match best with + | none => some (metric, cert) + | some (bestMetric, bestCert) => + if metric > bestMetric then + some (metric, cert) + else + some (bestMetric, bestCert) + else + for i in [:configs.size] do + let (useTight, tightLayers, perRowLayers) := configs[i]! + let cert ← computeCert useTight tightLayers perRowLayers + let metric := metricOf cert + best := + match best with + | none => some (metric, cert) + | some (bestMetric, bestCert) => + if metric > bestMetric then + some (metric, cert) + else + some (bestMetric, bestCert) + match best with + | none => throw "no induction certs computed" + | some bestPair => return bestPair + let computeBestAtScaleOrdered (scalePow10 : Nat) + (configs : Array (Bool × Nat × Nat)) + (stopAtPositive : Bool) : + ExceptT String IO (Rat × InductionHeadBestMatchSoundCert) := do + let (shared, direction?) ← loadSharedAndDirection scalePow10 let queryPos : Nat := match queryPos? 
with | some q => q @@ -6307,18 +7970,6 @@ def certifyInductionSoundBestMatch if queryPos ≥ shared.hdr.seqLen then throw s!"queryPos {queryPos} out of range" let prefixCache := mkSharedBinaryPrefix shared queryPos causalPattern - let direction? : Option (Thunk (Array Fixed10Interval)) ← - match targetToken?, negativeToken? with - | none, none => pure none - | some targetToken, some negativeToken => - let (hdrDir, dir) ← - ExceptT.mk (readLogitDiffDirectionBinary - path targetToken negativeToken scalePow10 fixedUlpSlack) - if hdrDir.modelDim ≠ shared.hdr.modelDim then - throw "unembedding model_dim mismatch" - pure (some (Thunk.mk (fun () => dir))) - | _, _ => - throw "use both target and negative tokens (or neither)" let computeCert (useTight : Bool) (tightLayers perRowLayers : Nat) : ExceptT String IO InductionHeadBestMatchSoundCert := do let label := @@ -6327,59 +7978,64 @@ def certifyInductionSoundBestMatch timeIt (s!"{label}:pair") <| ExceptT.mk <| certifyInductionSoundBestMatchLocalBinaryPair - path layer1 head1 layer2 head2 coord queryPos eps soundnessBits inputPath - inputDelta offset1 offset2 maxSeqLen scalePow10 useTight tightLayers - perRowLayers softmaxExpEffort causalPattern - (shared? := some shared) (prefix? := some prefixCache) - (targetToken? := targetToken?) (negativeToken? := negativeToken?) - (direction? := direction?) + path layer1 head1 layer2 head2 coord queryPos eps soundnessBits inputPath + inputDelta offset1 offset2 keyOffset1 keyOffset2 maxSeqLen scalePow10 useTight + tightLayers perRowLayers useAffine softmaxExpEffort causalPattern + (shared? := some shared) (prefix? := some prefixCache) + (targetToken? := targetToken?) (negativeToken? := negativeToken?) + (direction? := direction?) 
return cert - let tasks ← - ExceptT.lift <| - configs.mapM fun (useTight, tightLayers, perRowLayers) => - IO.asTask (computeCert useTight tightLayers perRowLayers).run - let results := tasks.map (fun t => t.get) + let metricOf (cert : InductionHeadBestMatchSoundCert) : Rat := + match cert.layer2Logit? with + | some logit => logit.logitDiffLowerBound + | none => cert.deltaLowerBound let mut best : Option (Rat × InductionHeadBestMatchSoundCert) := none for i in [:configs.size] do - let res := results[i]! - match res with - | .error e => throw (toString e) - | .ok (.error msg) => throw msg - | .ok (.ok cert) => - let metric := - match cert.layer2Logit? with - | some logit => logit.logitDiffLowerBound - | none => cert.deltaLowerBound - best := - match best with - | none => some (metric, cert) - | some (bestMetric, bestCert) => - if metric > bestMetric then - some (metric, cert) - else - some (bestMetric, bestCert) + let (useTight, tightLayers, perRowLayers) := configs[i]! + let cert ← computeCert useTight tightLayers perRowLayers + let metric := metricOf cert + if stopAtPositive && metric > 0 then + return (metric, cert) + best := + match best with + | none => some (metric, cert) + | some (bestMetric, bestCert) => + if metric > bestMetric then + some (metric, cert) + else + some (bestMetric, bestCert) match best with | none => throw "no induction certs computed" | some bestPair => return bestPair let maxLayer := Nat.max layer1 layer2 let tightFull := Nat.max 1 maxLayer let perRowFull := maxLayer - let mut configs : Array (Bool × Nat × Nat) := - #[(tightPattern, tightPatternLayers, perRowPatternLayers)] - let needTightFull := (!tightPattern) || tightPatternLayers < tightFull - if needTightFull then - configs := configs.push (true, tightFull, perRowPatternLayers) - if perRowPatternLayers < perRowFull then - configs := configs.push (true, tightFull, perRowFull) + let normalizeConfig (useTight : Bool) (tightLayers perRowLayers : Nat) : + Bool × Nat × Nat := + if useTight 
then + (true, Nat.max 1 tightLayers, perRowLayers) + else + (false, 0, perRowLayers) + let pushUnique (configs : Array (Bool × Nat × Nat)) (cfg : Bool × Nat × Nat) : + Array (Bool × Nat × Nat) := + if configs.any (fun c => c == cfg) then configs else configs.push cfg + let baseCfg : Bool × Nat × Nat := + normalizeConfig tightPattern tightPatternLayers perRowPatternLayers if !iterTighten then - let (_, cert) ← computeBestAtScale scalePow10 configs + let (_, cert) ← computeBestAtScale scalePow10 #[baseCfg] return cert else + let mut configs : Array (Bool × Nat × Nat) := #[baseCfg] + let needTightFull := (!tightPattern) || tightPatternLayers < tightFull + if needTightFull then + configs := pushUnique configs (normalizeConfig true tightFull perRowPatternLayers) + if perRowPatternLayers < perRowFull then + configs := pushUnique configs (normalizeConfig true tightFull perRowFull) + let scales : List Nat := [scalePow10, scalePow10 + 1, scalePow10 + 2] let mut bestOverall : Option (Rat × InductionHeadBestMatchSoundCert) := none - let mut scale := scalePow10 - let maxScale := scalePow10 + 2 - while scale ≤ maxScale do - let (metric, cert) ← computeBestAtScale scale configs + for scale in scales do + let (metric, cert) ← + computeBestAtScaleOrdered scale configs (stopAtPositive := true) bestOverall := match bestOverall with | none => some (metric, cert) @@ -6389,9 +8045,7 @@ def certifyInductionSoundBestMatch else some (bestMetric, bestCert) if metric > 0 then - scale := maxScale + 1 - else - scale := scale + 1 + return cert match bestOverall with | none => throw "no induction certs computed" | some (_, cert) => return cert From aebc8f4a92559e90b8d3634ca36bd9ad4c1dea24 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 08:51:03 +0100 Subject: [PATCH 021/244] Extend SOUND head patterns and bounds --- AGENTS.md | 1 - Main.lean | 62 +++++++++++++++++++++++---- Nfp/Induction.lean | 33 ++++++++++++++- Nfp/Sound/Affine.lean | 36 ++++++++++++++-- 
Nfp/Sound/BinaryPure.lean | 63 +++++++++++++++++++++++----- Nfp/Sound/Bounds/Exp.lean | 51 ++++++++++++++++++---- Nfp/Sound/Fixed.lean | 6 ++- Nfp/Sound/HeadCert.lean | 21 +++++++++- Nfp/Sound/IO.lean | 50 ++++++++++++++-------- Nfp/Sound/Interval.lean | 34 ++++++++++----- README.md | 8 +++- scripts/scan_gpt2_induction_sound.py | 8 +++- 12 files changed, 305 insertions(+), 68 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 0567f52..4c84131 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,7 +25,6 @@ but keep the core invariants and the “no fake proofs” ethos. ### Run the CLI (preferred integration path) One of these typically works (depending on your Lake setup): - `lake exe nfp --help` -- `./.lake/build/bin/nfp --help` If you add or change CLI behavior, validate at least: - `nfp --help` diff --git a/Main.lean b/Main.lean index 35ee2c2..5c552b5 100644 --- a/Main.lean +++ b/Main.lean @@ -48,7 +48,8 @@ lake exe nfp head_pattern model.nfpt --layer 0 --head 0 --delta 1/100 --offset - # Sound induction head certificate (binary only) lake exe nfp induction_cert model.nfpt --layer1 0 --head1 0 --layer2 1 --head2 0 \ - --coord 0 --delta 1/100 --offset1 -1 --offset2 -1 --target 42 --negative 17 + --coord 0 --delta 1/100 --offset1 -1 --offset2 0 --keyOffset2 -1 \ + --target 42 --negative 17 # Instantiate RoPE bounds for a specific shape lake exe nfp rope --seqLen 4 --pairs 8 @@ -1075,6 +1076,7 @@ private structure HeadPatternArgs where layerIdx : Nat headIdx : Nat offset : Int + keyOffset : Int soundnessBits : Nat softmaxExpEffort : Nat tightPatternLayers : Nat @@ -1082,6 +1084,7 @@ private structure HeadPatternArgs where perRowPatternLayers : Nat causalPattern : Bool bestMatch : Bool + useAffine : Bool sweep : Bool queryPos? : Option Nat inputPath? : Option System.FilePath @@ -1094,6 +1097,7 @@ private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := let layerIdx := p.flag? "layer" |>.map (·.as! Nat) |>.getD 0 let headIdx := p.flag? "head" |>.map (·.as! 
Nat) |>.getD 0 let offset := p.flag? "offset" |>.map (·.as! Int) |>.getD (-1) + let keyOffset := p.flag? "keyOffset" |>.map (·.as! Int) |>.getD 0 let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! Nat) |>.getD 20 let softmaxExpEffort := p.flag? "softmaxExpEffort" |>.map (·.as! Nat) @@ -1104,6 +1108,7 @@ private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 let causalPattern := !p.hasFlag "noncausalPattern" let bestMatch := p.hasFlag "bestMatch" + let useAffine := p.hasFlag "affine" let sweep := p.hasFlag "sweep" let queryPos? := p.flag? "queryPos" |>.map (·.as! Nat) let inputPath? := p.flag? "input" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) @@ -1114,6 +1119,7 @@ private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := layerIdx := layerIdx headIdx := headIdx offset := offset + keyOffset := keyOffset soundnessBits := soundnessBits softmaxExpEffort := softmaxExpEffort tightPatternLayers := tightPatternLayers @@ -1121,6 +1127,7 @@ private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := perRowPatternLayers := perRowPatternLayers causalPattern := causalPattern bestMatch := bestMatch + useAffine := useAffine sweep := sweep queryPos? := queryPos? inputPath? := inputPath? 
@@ -1131,10 +1138,11 @@ private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := private def formatHeadPatternBestMatchSweep (layerIdx headIdx : Nat) (offset : Int) + (keyOffset : Int) (certs : Array Nfp.Sound.HeadBestMatchPatternCert) : String := let header := "SOUND head pattern sweep (best-match): " ++ - s!"layer={layerIdx}, head={headIdx}, offset={offset}\n" + s!"layer={layerIdx}, head={headIdx}, offset={offset}, keyOffset={keyOffset}\n" let body := certs.foldl (fun acc cert => acc ++ @@ -1147,7 +1155,8 @@ private def formatHeadPatternBestMatch (cert : Nfp.Sound.HeadBestMatchPatternCert) : String := "SOUND head pattern (best-match): " ++ s!"layer={cert.layerIdx}, head={cert.headIdx}, " ++ - s!"offset={cert.targetOffset}, queryPos={cert.queryPos}\n" ++ + s!"offset={cert.targetOffset}, keyOffset={cert.keyOffset}, " ++ + s!"queryPos={cert.queryPos}\n" ++ s!"seqLen={cert.seqLen}, targetTok={cert.targetToken}, " ++ s!"bestMatchLogitLB={cert.bestMatchLogitLowerBound}, " ++ s!"bestNonmatchLogitUB={cert.bestNonmatchLogitUpperBound}\n" ++ @@ -1158,7 +1167,8 @@ private def formatHeadPatternBestMatch private def formatHeadPatternLocal (cert : Nfp.Sound.HeadPatternCert) : String := "SOUND head pattern (local): " ++ - s!"layer={cert.layerIdx}, head={cert.headIdx}, offset={cert.targetOffset}\n" ++ + s!"layer={cert.layerIdx}, head={cert.headIdx}, " ++ + s!"offset={cert.targetOffset}, keyOffset={cert.keyOffset}\n" ++ s!"seqLen={cert.seqLen}, " ++ s!"targetCountLB={cert.targetCountLowerBound}, " ++ s!"targetLogitLB={cert.targetLogitLowerBound}, " ++ @@ -1183,27 +1193,36 @@ private def runHeadPatternAction (args : HeadPatternArgs) : ExceptT String IO St else throw <| "head pattern bounds require EMBEDDINGS; pass --input for legacy text models." 
+ if args.useAffine && !args.bestMatch then + throw "affine bounds are only supported with --bestMatch" + if args.useAffine && args.sweep then + throw "affine sweep is unsupported; use --bestMatch without --sweep" if args.bestMatch then if args.sweep then let certs ← ExceptT.mk <| Nfp.Sound.certifyHeadPatternBestMatchLocalSweep args.modelPath args.layerIdx args.headIdx (inputPath? := inputPath?) (inputDelta := delta) (soundnessBits := args.soundnessBits) (targetOffset := args.offset) + (keyOffset := args.keyOffset) (maxSeqLen := args.maxSeqLen) (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) + (useAffine := args.useAffine) (softmaxExpEffort := args.softmaxExpEffort) (causalPattern := args.causalPattern) - return formatHeadPatternBestMatchSweep args.layerIdx args.headIdx args.offset certs + return formatHeadPatternBestMatchSweep args.layerIdx args.headIdx args.offset + args.keyOffset certs else let cert ← ExceptT.mk <| Nfp.Sound.certifyHeadPatternBestMatchLocal args.modelPath args.layerIdx args.headIdx (queryPos? := args.queryPos?) (inputPath? := inputPath?) (inputDelta := delta) (soundnessBits := args.soundnessBits) - (targetOffset := args.offset) (maxSeqLen := args.maxSeqLen) + (targetOffset := args.offset) (keyOffset := args.keyOffset) + (maxSeqLen := args.maxSeqLen) (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) + (useAffine := args.useAffine) (softmaxExpEffort := args.softmaxExpEffort) (causalPattern := args.causalPattern) return formatHeadPatternBestMatch cert @@ -1214,6 +1233,7 @@ private def runHeadPatternAction (args : HeadPatternArgs) : ExceptT String IO St Nfp.Sound.certifyHeadPatternLocal args.modelPath args.layerIdx args.headIdx (inputPath? := inputPath?) 
(inputDelta := delta) (soundnessBits := args.soundnessBits) (targetOffset := args.offset) + (keyOffset := args.keyOffset) (maxSeqLen := args.maxSeqLen) (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) @@ -1265,6 +1285,8 @@ private structure InductionCertArgs where coord : Nat offset1 : Int offset2 : Int + keyOffset1 : Int + keyOffset2 : Int targetToken? : Option Nat negativeToken? : Option Nat soundnessBits : Nat @@ -1275,6 +1297,7 @@ private structure InductionCertArgs where iterTighten : Bool causalPattern : Bool bestMatch : Bool + useAffine : Bool queryPos? : Option Nat inputPath? : Option System.FilePath delta : Rat @@ -1293,6 +1316,8 @@ private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCer let coord := p.flag? "coord" |>.map (·.as! Nat) |>.getD 0 let offset1 := p.flag? "offset1" |>.map (·.as! Int) |>.getD (-1) let offset2 := p.flag? "offset2" |>.map (·.as! Int) |>.getD (-1) + let keyOffset1 := p.flag? "keyOffset1" |>.map (·.as! Int) |>.getD 0 + let keyOffset2 := p.flag? "keyOffset2" |>.map (·.as! Int) |>.getD 0 let targetToken := p.flag? "target" |>.map (·.as! Nat) let negativeToken := p.flag? "negative" |>.map (·.as! Nat) let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! Nat) |>.getD 20 @@ -1306,6 +1331,7 @@ private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCer let iterTighten := p.hasFlag "iterTighten" let causalPattern := !p.hasFlag "noncausalPattern" let bestMatch := p.hasFlag "bestMatch" + let useAffine := p.hasFlag "affine" let queryPos := p.flag? "queryPos" |>.map (·.as! Nat) let inputPath := p.flag? "input" |>.map (·.as! String) let deltaStr := p.flag? "delta" |>.map (·.as! 
String) |>.getD "0" @@ -1338,6 +1364,8 @@ private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCer coord := coord offset1 := offset1 offset2 := offset2 + keyOffset1 := keyOffset1 + keyOffset2 := keyOffset2 targetToken? := targetToken negativeToken? := negativeToken soundnessBits := soundnessBits @@ -1348,6 +1376,7 @@ private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCer iterTighten := iterTighten causalPattern := causalPattern bestMatch := bestMatch + useAffine := useAffine queryPos? := queryPos inputPath? := inputPath? delta := delta @@ -1388,11 +1417,13 @@ private def formatInductionBestMatch "SOUND induction cert (best-match):\n" ++ s!"queryPos={p2.queryPos}\n" ++ s!"layer1=L{p1.layerIdx} H{p1.headIdx} offset={p1.targetOffset} " ++ + s!"keyOffset={p1.keyOffset} " ++ s!"targetTok={p1.targetToken} " ++ s!"marginLB={p1.marginLowerBound} " ++ s!"weightLB={p1.bestMatchWeightLowerBound} " ++ s!"softmaxExpEffort={p1.softmaxExpEffort}\n" ++ s!"layer2=L{p2.layerIdx} H{p2.headIdx} offset={p2.targetOffset} " ++ + s!"keyOffset={p2.keyOffset} " ++ s!"targetTok={p2.targetToken} " ++ s!"marginLB={p2.marginLowerBound} " ++ s!"weightLB={p2.bestMatchWeightLowerBound} " ++ @@ -1411,9 +1442,11 @@ private def formatInductionLocal let logitLine := formatInductionLogitLine cert.layer2Logit? 
"SOUND induction cert:\n" ++ s!"layer1=L{p1.layerIdx} H{p1.headIdx} offset={p1.targetOffset} " ++ + s!"keyOffset={p1.keyOffset} " ++ s!"marginLB={p1.marginLowerBound} weightLB={p1.targetWeightLowerBound} " ++ s!"softmaxExpEffort={p1.softmaxExpEffort}\n" ++ s!"layer2=L{p2.layerIdx} H{p2.headIdx} offset={p2.targetOffset} " ++ + s!"keyOffset={p2.keyOffset} " ++ s!"marginLB={p2.marginLowerBound} weightLB={p2.targetWeightLowerBound} " ++ s!"softmaxExpEffort={p2.softmaxExpEffort}\n" ++ s!"coord={v.coord} matchCountLB={p2.targetCountLowerBound} " ++ @@ -1424,17 +1457,22 @@ private def formatInductionLocal /-- Run the induction-cert action and return the report string. -/ private def runInductionCertAction (args : InductionCertArgs) : ExceptT String IO String := do + if args.useAffine && !args.bestMatch then + throw "affine bounds are only supported with --bestMatch" if args.bestMatch then let cert ← ExceptT.mk <| Nfp.Sound.certifyInductionSoundBestMatch args.modelPath args.layer1 args.head1 args.layer2 args.head2 args.coord (queryPos? := args.queryPos?) (inputPath? := args.inputPath?) (inputDelta := args.delta) (soundnessBits := args.soundnessBits) - (offset1 := args.offset1) (offset2 := args.offset2) (maxSeqLen := args.maxSeqLen) + (offset1 := args.offset1) (offset2 := args.offset2) + (keyOffset1 := args.keyOffset1) (keyOffset2 := args.keyOffset2) + (maxSeqLen := args.maxSeqLen) (scalePow10 := args.scalePow10) (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) + (useAffine := args.useAffine) (iterTighten := args.iterTighten) (targetToken? := args.targetToken?) (negativeToken? := args.negativeToken?) (softmaxExpEffort := args.softmaxExpEffort) @@ -1446,7 +1484,9 @@ private def runInductionCertAction (args : InductionCertArgs) : ExceptT String I args.layer1 args.head1 args.layer2 args.head2 args.coord (inputPath? := args.inputPath?) 
(inputDelta := args.delta) (soundnessBits := args.soundnessBits) - (offset1 := args.offset1) (offset2 := args.offset2) (maxSeqLen := args.maxSeqLen) + (offset1 := args.offset1) (offset2 := args.offset2) + (keyOffset1 := args.keyOffset1) (keyOffset2 := args.keyOffset2) + (maxSeqLen := args.maxSeqLen) (scalePow10 := args.scalePow10) (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) (perRowPatternLayers := args.perRowPatternLayers) @@ -1928,11 +1968,13 @@ LayerNorm epsilon is read from the model header." layer : Nat; "Layer index (default: 0)" head : Nat; "Head index (default: 0)" offset : Int; "Token-match offset (default: -1 for previous token, 0 for self)" + keyOffset : Int; "Key-position token offset (default: 0; use -1 with offset=0 for copy-next)" tightPattern; "Use tighter (slower) pattern bounds near the target layer" tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" noncausalPattern; "Disable causal-prefix restriction for pattern/value bounds" bestMatch; "Use best-match (single-query) pattern bounds" + affine; "Use affine Q/K dot bounds for best-match (single-query only)" sweep; "Sweep best-match bounds across all valid query positions" queryPos : Nat; "Query position for best-match bounds (default: last position)" input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ @@ -1960,6 +2002,9 @@ LayerNorm epsilon is read from the model header." 
coord : Nat; "Output coordinate for the value bound (default: 0)" offset1 : Int; "Token-match offset for layer1 (default: -1)" offset2 : Int; "Token-match offset for layer2 (default: -1)" + keyOffset1 : Int; "Key-position token offset for layer1 (default: 0)" + keyOffset2 : Int; "Key-position token offset for layer2 (default: 0; use -1 with \ +offset2=0 for copy-next)" target : Nat; "Target token ID for logit-diff bound (optional; requires --negative)" negative : Nat; "Negative token ID for logit-diff bound (optional; requires --target)" tightPattern; "Use tighter (slower) pattern bounds near the target layer" @@ -1968,6 +2013,7 @@ LayerNorm epsilon is read from the model header." iterTighten; "Iteratively tighten best-match bounds (escalates tight/per-row layers to full)" noncausalPattern; "Disable causal-prefix restriction for pattern/value bounds" bestMatch; "Use best-match (single-query) pattern bounds" + affine; "Use affine Q/K dot bounds for best-match" queryPos : Nat; "Query position for best-match bounds (default: last position)" input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ for legacy text)" diff --git a/Nfp/Induction.lean b/Nfp/Induction.lean index 38956fc..2f78871 100644 --- a/Nfp/Induction.lean +++ b/Nfp/Induction.lean @@ -97,6 +97,8 @@ structure TokenMatchPattern where seqLen : Nat /-- Target offset (e.g. `-1` for previous token). -/ targetOffset : Int + /-- Key-position offset used when matching tokens against the query's target token. -/ + keyOffset : Int /-- Lower bound on the number of matching-token keys. -/ targetCountLowerBound : Nat /-- Effort level for the `expLB` portfolio used in margin-to-weight bounds. -/ @@ -205,6 +207,20 @@ structure InductionPatternWitness where tokenMatch : TokenMatchPattern /-- The pattern targets the previous-token offset. -/ prevOffset : tokenMatch.targetOffset = -1 + /-- The key-token comparison uses no key offset. 
-/ + keyOffsetZero : tokenMatch.keyOffset = 0 + /-- Certified nontrivial attention mass on matching tokens. -/ + positiveMass : 0 < tokenMatch.targetWeightLowerBound + deriving Repr + +/-- A minimal sound witness for a copy-next induction-style attention pattern. -/ +structure CopyNextPatternWitness where + /-- Token-match pattern data (sound certificate output). -/ + tokenMatch : TokenMatchPattern + /-- The pattern uses the current query token as the target. -/ + targetOffsetZero : tokenMatch.targetOffset = 0 + /-- Keys are matched against the previous-token stream (copy-next). -/ + keyOffsetPrev : tokenMatch.keyOffset = -1 /-- Certified nontrivial attention mass on matching tokens. -/ positiveMass : 0 < tokenMatch.targetWeightLowerBound deriving Repr @@ -214,11 +230,26 @@ namespace TokenMatchPattern /-- Build an induction-style witness from a valid token-match pattern plus explicit assumptions. -/ def toInductionPatternWitness (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) - (hcount : 0 < p.targetCountLowerBound) (hoff : p.targetOffset = -1) : + (hcount : 0 < p.targetCountLowerBound) (hoff : p.targetOffset = -1) + (hkey : p.keyOffset = 0) : InductionPatternWitness := { tokenMatch := p prevOffset := hoff + keyOffsetZero := hkey + positiveMass := weight_lower_bound_pos_of_margin_pos p h hm hcount + } + +/-- Build a copy-next witness from a valid token-match pattern plus explicit assumptions. -/ +def toCopyNextPatternWitness + (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) + (hcount : 0 < p.targetCountLowerBound) (hoff : p.targetOffset = 0) + (hkey : p.keyOffset = -1) : + CopyNextPatternWitness := + { + tokenMatch := p + targetOffsetZero := hoff + keyOffsetPrev := hkey positiveMass := weight_lower_bound_pos_of_margin_pos p h hm hcount } diff --git a/Nfp/Sound/Affine.lean b/Nfp/Sound/Affine.lean index 866065f..7090f8c 100644 --- a/Nfp/Sound/Affine.lean +++ b/Nfp/Sound/Affine.lean @@ -8,8 +8,8 @@ namespace Nfp.Sound /-! 
# Affine arithmetic scaffolding (SOUND) -This module provides a minimal affine-form representation for future local -certification improvements. It is not yet integrated into the SOUND pipeline. +This module provides a minimal affine-form representation for local certification +improvements, used by optional affine Q/K bounds in SOUND best-match paths. -/ /-- Affine form `x = center + sum coeffs[i] * eps_i` with `eps_i in [-1, 1]`. -/ @@ -28,8 +28,8 @@ private def combineCoeffs (a b : Array Rat) (f : Rat → Rat → Rat) : Array Ra let n := max a.size b.size let mut out := Array.mkEmpty n for i in [:n] do - let ai := a.get? i |>.getD 0 - let bi := b.get? i |>.getD 0 + let ai := a.getD i 0 + let bi := b.getD i 0 out := out.push (f ai bi) return out @@ -48,6 +48,13 @@ def scale (c : Rat) (a : AffineForm) : AffineForm := { center := c * a.center coeffs := a.coeffs.map (fun k => c * k) } +/-- Append a fresh independent noise coefficient (skipped if zero). -/ +def appendNoise (a : AffineForm) (coeff : Rat) : AffineForm := + if coeff = 0 then + a + else + { center := a.center, coeffs := a.coeffs.push coeff } + /-- Sum of absolute noise coefficients (radius of the interval hull). -/ def radius (a : AffineForm) : Rat := a.coeffs.foldl (fun acc c => acc + ratAbs c) 0 @@ -57,6 +64,24 @@ def toInterval (a : AffineForm) : RatInterval := let r := radius a { lo := a.center - r, hi := a.center + r } +/-- Affine multiplication with aligned noise terms and a single remainder noise. -/ +def mul (a b : AffineForm) : AffineForm := + let coeffs := combineCoeffs a.coeffs b.coeffs + (fun ai bi => b.center * ai + a.center * bi) + let rem := radius a * radius b + appendNoise { center := a.center * b.center, coeffs := coeffs } rem + +/-- Affine multiplication treating noise terms as disjoint. 
-/ +def mulDisjoint (a b : AffineForm) : AffineForm := + Id.run do + let mut coeffs : Array Rat := Array.mkEmpty (a.coeffs.size + b.coeffs.size) + for ai in a.coeffs do + coeffs := coeffs.push (b.center * ai) + for bi in b.coeffs do + coeffs := coeffs.push (a.center * bi) + let rem := radius a * radius b + return appendNoise { center := a.center * b.center, coeffs := coeffs } rem + /-! ### Specs -/ theorem AffineForm_spec : AffineForm = AffineForm := rfl @@ -65,8 +90,11 @@ theorem combineCoeffs_spec : combineCoeffs = combineCoeffs := rfl theorem add_spec : add = add := rfl theorem sub_spec : sub = sub := rfl theorem scale_spec : scale = scale := rfl +theorem appendNoise_spec : appendNoise = appendNoise := rfl theorem radius_spec : radius = radius := rfl theorem toInterval_spec : toInterval = toInterval := rfl +theorem mul_spec : mul = mul := rfl +theorem mulDisjoint_spec : mulDisjoint = mulDisjoint := rfl end AffineForm diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index df62f94..6b0b7f5 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -194,13 +194,15 @@ def vectorMaxAbsScaledFromBytes (bytes : ByteArray) (n scalePow10 : Nat) : if bytes.size < n * 8 then throw "unexpected EOF" let mut maxAbs : Int := 0 - for i in [:n] do + let mut i : Nat := 0 + while i < n do let bits := u64FromLE bytes (i * 8) match floatAbsCeilScaled scalePow10 bits with | .error e => throw e | .ok absScaled => if absScaled > maxAbs then maxAbs := absScaled + i := i + 1 return maxAbs def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat) : @@ -212,7 +214,8 @@ def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat throw "unexpected EOF" let mut maxRowSum : Int := 0 let mut curRowSum : Int := 0 - for i in [:count] do + let mut i : Nat := 0 + while i < count do let bits := u64FromLE bytes (i * 8) match floatAbsCeilScaled scalePow10 bits with | .error e => throw e @@ -222,6 +225,7 @@ def 
matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat if curRowSum > maxRowSum then maxRowSum := curRowSum curRowSum := 0 + i := i + 1 return maxRowSum def scaledFloatArrayFromBytes (bytes : ByteArray) (count scalePow10 : Nat) : @@ -230,13 +234,46 @@ def scaledFloatArrayFromBytes (bytes : ByteArray) (count scalePow10 : Nat) : return #[] if bytes.size < count * 8 then throw "unexpected EOF" - let mut out : Array Int := Array.mkEmpty count - for i in [:count] do - let bits := u64FromLE bytes (i * 8) - match floatScaledCeilSigned scalePow10 bits with - | .error e => throw e - | .ok v => out := out.push v - return out + let useTasks := count > 16384 + if useTasks then + let chunkSize : Nat := 8192 + let numChunks : Nat := (count + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Except String (Array Int))) := Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min count (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array Int := Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + let bits := u64FromLE bytes (i * 8) + match floatScaledCeilSigned scalePow10 bits with + | .error e => return .error e + | .ok v => outChunk := outChunk.push v + i := i + 1 + return .ok outChunk) + chunkIdx := chunkIdx + 1 + let mut out : Array Int := Array.mkEmpty count + for t in tasks do + match t.get with + | .error e => throw e + | .ok chunk => + for v in chunk do + out := out.push v + return out + else + let mut out : Array Int := Array.mkEmpty count + let mut i : Nat := 0 + while i < count do + let bits := u64FromLE bytes (i * 8) + match floatScaledCeilSigned scalePow10 bits with + | .error e => throw e + | .ok v => out := out.push v + i := i + 1 + return out def scaledFloatFromBytes (bytes : ByteArray) (scalePow10 : Nat) : Except String Int := do @@ -254,9 +291,11 @@ def i32ArrayFromBytes (bytes : 
ByteArray) (count : Nat) : if bytes.size < count * 4 then throw "unexpected EOF" let mut out : Array Int := Array.mkEmpty count - for i in [:count] do + let mut i : Nat := 0 + while i < count do let v := i32FromLE bytes (i * 4) out := out.push v + i := i + 1 return out def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat) : @@ -269,7 +308,8 @@ def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : let mut maxRowSum : Nat := 0 let mut curRowSum : Nat := 0 let mut colSums : Array Nat := Array.replicate cols 0 - for i in [:count] do + let mut i : Nat := 0 + while i < count do let bits := u64FromLE bytes (i * 8) match floatAbsCeilScaled scalePow10 bits with | .error e => throw e @@ -282,6 +322,7 @@ def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : if curRowSum > maxRowSum then maxRowSum := curRowSum curRowSum := 0 + i := i + 1 let mut maxColSum : Nat := 0 for c in colSums do if c > maxColSum then diff --git a/Nfp/Sound/Bounds/Exp.lean b/Nfp/Sound/Bounds/Exp.lean index 202945e..3d861f7 100644 --- a/Nfp/Sound/Bounds/Exp.lean +++ b/Nfp/Sound/Bounds/Exp.lean @@ -14,14 +14,31 @@ open scoped BigOperators # Exp lower bounds (scaled Taylor + squaring) -/ -/-- Power function on `Rat` for natural exponents. -/ -private def ratPow (x : Rat) : Nat → Rat - | 0 => 1 - | n + 1 => ratPow x n * x +/-- Power function on `Rat` for natural exponents (iterative to avoid deep recursion). 
-/ +private def ratPow (x : Rat) (n : Nat) : Rat := + Id.run do + let mut acc : Rat := 1 + let mut base : Rat := x + let mut exp : Nat := n + while exp > 0 do + if exp % 2 = 1 then + acc := acc * base + base := base * base + exp := exp / 2 + return acc theorem ratPow_def (x : Rat) (n : Nat) : - ratPow x n = match n with | 0 => 1 | n + 1 => ratPow x n * x := by - cases n <;> rfl + ratPow x n = + Id.run do + let mut acc : Rat := 1 + let mut base : Rat := x + let mut exp : Nat := n + while exp > 0 do + if exp % 2 = 1 then + acc := acc * base + base := base * base + exp := exp / 2 + return acc := rfl /-- Factorial as a rational. -/ private def ratFactorial (n : Nat) : Rat := (Nat.factorial n : Nat) @@ -30,11 +47,29 @@ theorem ratFactorial_def (n : Nat) : ratFactorial n = (Nat.factorial n : Nat) := /-- Taylor partial sum for `exp` (all terms are nonnegative when `x ≥ 0`). -/ private def expTaylorLowerBound (x : Rat) (deg : Nat) : Rat := - Finset.sum (Finset.range (deg + 1)) fun k => ratPow x k / ratFactorial k + Id.run do + let mut term : Rat := 1 + let mut sum : Rat := 1 + let mut k : Nat := 1 + while k ≤ deg do + let kRat : Rat := (k : Nat) + term := term * x / kRat + sum := sum + term + k := k + 1 + return sum theorem expTaylorLowerBound_def (x : Rat) (deg : Nat) : expTaylorLowerBound x deg = - Finset.sum (Finset.range (deg + 1)) fun k => ratPow x k / ratFactorial k := rfl + Id.run do + let mut term : Rat := 1 + let mut sum : Rat := 1 + let mut k : Nat := 1 + while k ≤ deg do + let kRat : Rat := (k : Nat) + term := term * x / kRat + sum := sum + term + k := k + 1 + return sum := rfl /-- Lower bound on `exp` via scaled Taylor partial sums and repeated squaring. 
-/ def expLBScaledTaylor (x : Rat) (deg scalePow : Nat) : Rat := diff --git a/Nfp/Sound/Fixed.lean b/Nfp/Sound/Fixed.lean index 3c4208d..a89d86b 100644 --- a/Nfp/Sound/Fixed.lean +++ b/Nfp/Sound/Fixed.lean @@ -141,9 +141,11 @@ def unionVec (a b : Array Fixed10Interval) : Array Fixed10Interval := Id.run do if a.size ≠ b.size then return a - let mut out := Array.mkEmpty a.size - for i in [:a.size] do + let mut out : Array Fixed10Interval := Array.mkEmpty a.size + let mut i : Nat := 0 + while i < a.size do out := out.push (union a[i]! b[i]!) + i := i + 1 return out /-! ### Specs -/ diff --git a/Nfp/Sound/HeadCert.lean b/Nfp/Sound/HeadCert.lean index 33eae3b..012876e 100644 --- a/Nfp/Sound/HeadCert.lean +++ b/Nfp/Sound/HeadCert.lean @@ -100,6 +100,8 @@ structure HeadPatternCert where headIdx : Nat seqLen : Nat targetOffset : Int + /-- Key-position offset used for token matching. -/ + keyOffset : Int targetCountLowerBound : Nat targetLogitLowerBound : Rat otherLogitUpperBound : Rat @@ -217,6 +219,8 @@ structure HeadBestMatchPatternCert where seqLen : Nat queryPos : Nat targetOffset : Int + /-- Key-position offset used for token matching. 
-/ + keyOffset : Int targetToken : Int bestMatchLogitLowerBound : Rat bestNonmatchLogitUpperBound : Rat @@ -404,6 +408,7 @@ namespace HeadPatternCert def toTokenMatchPattern (c : HeadPatternCert) : Nfp.TokenMatchPattern := { seqLen := c.seqLen targetOffset := c.targetOffset + keyOffset := c.keyOffset targetCountLowerBound := c.targetCountLowerBound softmaxExpEffort := c.softmaxExpEffort targetWeightLowerBound := c.targetWeightLowerBound @@ -417,10 +422,20 @@ theorem toTokenMatchPattern_valid (c : HeadPatternCert) (h : c.Valid) : def toInductionPatternWitness (c : HeadPatternCert) (h : c.Valid) (hm : c.marginLowerBound > 0) - (hcount : 0 < c.targetCountLowerBound) (hoff : c.targetOffset = -1) : + (hcount : 0 < c.targetCountLowerBound) (hoff : c.targetOffset = -1) + (hkey : c.keyOffset = 0) : Nfp.InductionPatternWitness := Nfp.TokenMatchPattern.toInductionPatternWitness - (toTokenMatchPattern c) (toTokenMatchPattern_valid c h) hm hcount hoff + (toTokenMatchPattern c) (toTokenMatchPattern_valid c h) hm hcount hoff hkey + +/-- Build a copy-next witness from a head pattern certificate. 
-/ +def toCopyNextPatternWitness + (c : HeadPatternCert) (h : c.Valid) (hm : c.marginLowerBound > 0) + (hcount : 0 < c.targetCountLowerBound) (hoff : c.targetOffset = 0) + (hkey : c.keyOffset = -1) : + Nfp.CopyNextPatternWitness := + Nfp.TokenMatchPattern.toCopyNextPatternWitness + (toTokenMatchPattern c) (toTokenMatchPattern_valid c h) hm hcount hoff hkey end HeadPatternCert @@ -669,6 +684,8 @@ theorem HeadPatternCert.toTokenMatchPattern_spec : HeadPatternCert.toTokenMatchPattern = HeadPatternCert.toTokenMatchPattern := rfl theorem HeadPatternCert.toInductionPatternWitness_spec : HeadPatternCert.toInductionPatternWitness = HeadPatternCert.toInductionPatternWitness := rfl +theorem HeadPatternCert.toCopyNextPatternWitness_spec : + HeadPatternCert.toCopyNextPatternWitness = HeadPatternCert.toCopyNextPatternWitness := rfl theorem HeadValueLowerBoundCert.Valid_spec : HeadValueLowerBoundCert.Valid = HeadValueLowerBoundCert.Valid := rfl theorem HeadValueLowerBoundCert.check_spec : diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index 40423de..a7f385e 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -310,6 +310,7 @@ def certifyHeadPatternLocal (inputDelta : Rat := 0) (soundnessBits : Nat) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -323,9 +324,9 @@ def certifyHeadPatternLocal | .ok eps => match ← Nfp.Untrusted.SoundCompute.certifyHeadPatternLocal - path layerIdx headIdx eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen - tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort - causalPattern with + path layerIdx headIdx eps soundnessBits inputPath? 
inputDelta targetOffset keyOffset + maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + softmaxExpEffort causalPattern with | .error e => return .error e | .ok cert => return verifyHeadPatternCert cert @@ -339,10 +340,12 @@ def certifyHeadPatternBestMatchLocal (inputDelta : Rat := 0) (soundnessBits : Nat) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) + (useAffine : Bool := false) (scalePow10 : Nat := 9) (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) (causalPattern : Bool := true) : @@ -353,8 +356,8 @@ def certifyHeadPatternBestMatchLocal match ← Nfp.Untrusted.SoundCompute.certifyHeadPatternBestMatchLocal path layerIdx headIdx queryPos? eps soundnessBits inputPath? inputDelta targetOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort causalPattern with + keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers useAffine + scalePow10 softmaxExpEffort causalPattern with | .error e => return .error e | .ok cert => return verifyHeadBestMatchPatternCert cert @@ -367,10 +370,12 @@ def certifyHeadPatternBestMatchLocalSweep (inputDelta : Rat := 0) (soundnessBits : Nat) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) + (useAffine : Bool := false) (scalePow10 : Nat := 9) (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) (causalPattern : Bool := true) : @@ -380,9 +385,9 @@ def certifyHeadPatternBestMatchLocalSweep | .ok eps => match ← Nfp.Untrusted.SoundCompute.certifyHeadPatternBestMatchLocalSweep - path layerIdx headIdx eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen - tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort - causalPattern with + path layerIdx headIdx eps soundnessBits inputPath? 
inputDelta targetOffset keyOffset + maxSeqLen tightPattern tightPatternLayers perRowPatternLayers useAffine scalePow10 + softmaxExpEffort causalPattern with | .error e => return .error e | .ok certs => return verifyHeadBestMatchPatternCerts certs @@ -395,6 +400,7 @@ def certifyLayerBestMatchMarginLocal (inputDelta : Rat := 0) (soundnessBits : Nat) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -408,9 +414,9 @@ def certifyLayerBestMatchMarginLocal | .ok eps => match ← Nfp.Untrusted.SoundCompute.certifyLayerBestMatchMarginLocal - path layerIdx eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen - tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort - causalPattern with + path layerIdx eps soundnessBits inputPath? inputDelta targetOffset keyOffset + maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + softmaxExpEffort causalPattern with | .error e => return .error e | .ok cert => return verifyLayerBestMatchMarginCert cert @@ -477,6 +483,7 @@ def certifyHeadBoundsLocalBestMatch (inputDelta : Rat := 0) (soundnessBits : Nat) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -501,7 +508,8 @@ def certifyHeadBoundsLocalBestMatch certifyHeadPatternBestMatchLocal path layerIdx headIdx (queryPos? := queryPos?) (inputPath? := inputPath?) 
(inputDelta := inputDelta) (soundnessBits := soundnessBits) - (targetOffset := targetOffset) (maxSeqLen := maxSeqLen) + (targetOffset := targetOffset) (keyOffset := keyOffset) + (maxSeqLen := maxSeqLen) (tightPattern := tightPattern) (tightPatternLayers := tightPatternLayers) (perRowPatternLayers := perRowPatternLayers) (softmaxExpEffort := softmaxExpEffort) @@ -520,6 +528,7 @@ def certifyHeadValueLowerBoundLocal (inputDelta : Rat := 0) (soundnessBits : Nat) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -533,7 +542,7 @@ def certifyHeadValueLowerBoundLocal match ← Nfp.Untrusted.SoundCompute.certifyHeadValueLowerBoundLocal path layerIdx headIdx coord eps soundnessBits inputPath? inputDelta targetOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 + keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 causalPattern with | .error e => return .error e | .ok cert => @@ -548,6 +557,7 @@ def certifyHeadLogitDiffLowerBoundLocal (inputDelta : Rat := 0) (soundnessBits : Nat) (targetOffset : Int := -1) + (keyOffset : Int := 0) (maxSeqLen : Nat := 256) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) @@ -561,7 +571,7 @@ def certifyHeadLogitDiffLowerBoundLocal match ← Nfp.Untrusted.SoundCompute.certifyHeadLogitDiffLowerBoundLocal path layerIdx headIdx targetToken negativeToken eps soundnessBits inputPath? 
inputDelta - targetOffset maxSeqLen tightPattern tightPatternLayers + targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 causalPattern with | .error e => return .error e | .ok cert => @@ -577,6 +587,8 @@ def certifyInductionSound (soundnessBits : Nat) (offset1 : Int := -1) (offset2 : Int := -1) + (keyOffset1 : Int := 0) + (keyOffset2 : Int := 0) (maxSeqLen : Nat := 256) (scalePow10 : Nat := 9) (tightPattern : Bool := false) @@ -593,7 +605,8 @@ def certifyInductionSound match ← Nfp.Untrusted.SoundCompute.certifyInductionSound path layer1 head1 layer2 head2 coord eps soundnessBits inputPath? inputDelta - offset1 offset2 maxSeqLen scalePow10 tightPattern tightPatternLayers + offset1 offset2 keyOffset1 keyOffset2 maxSeqLen scalePow10 tightPattern + tightPatternLayers perRowPatternLayers targetToken? negativeToken? softmaxExpEffort causalPattern with | .error e => return .error e | .ok cert => @@ -610,11 +623,14 @@ def certifyInductionSoundBestMatch (soundnessBits : Nat) (offset1 : Int := -1) (offset2 : Int := -1) + (keyOffset1 : Int := 0) + (keyOffset2 : Int := 0) (maxSeqLen : Nat := 256) (scalePow10 : Nat := 9) (tightPattern : Bool := false) (tightPatternLayers : Nat := 1) (perRowPatternLayers : Nat := 0) + (useAffine : Bool := false) (iterTighten : Bool := false) (targetToken? : Option Nat := none) (negativeToken? : Option Nat := none) @@ -627,8 +643,8 @@ def certifyInductionSoundBestMatch match ← Nfp.Untrusted.SoundCompute.certifyInductionSoundBestMatch path layer1 head1 layer2 head2 coord queryPos? eps soundnessBits inputPath? inputDelta - offset1 offset2 maxSeqLen scalePow10 tightPattern - tightPatternLayers perRowPatternLayers iterTighten targetToken? negativeToken? + offset1 offset2 keyOffset1 keyOffset2 maxSeqLen scalePow10 tightPattern + tightPatternLayers perRowPatternLayers useAffine iterTighten targetToken? negativeToken? 
softmaxExpEffort causalPattern with | .error e => return .error e diff --git a/Nfp/Sound/Interval.lean b/Nfp/Sound/Interval.lean index 9d11de6..6a82fa2 100644 --- a/Nfp/Sound/Interval.lean +++ b/Nfp/Sound/Interval.lean @@ -271,17 +271,29 @@ def geluOverapproxTanh (a : RatInterval) (expEffort : Nat := defaultSoftmaxExpEf /-- Split-based tightening for tanh GeLU over-approximation. -/ def geluOverapproxTanhSplit (a : RatInterval) (expEffort : Nat := defaultSoftmaxExpEffort) (splitDepth : Nat := 0) : RatInterval := - if splitDepth = 0 then - geluOverapproxTanh a expEffort - else - let lo := min a.lo a.hi - let hi := max a.lo a.hi - let mid := (lo + hi) / (2 : Rat) - let left : RatInterval := { lo := lo, hi := mid } - let right : RatInterval := { lo := mid, hi := hi } - RatInterval.union - (geluOverapproxTanhSplit left expEffort (splitDepth - 1)) - (geluOverapproxTanhSplit right expEffort (splitDepth - 1)) + Id.run do + let mut stack : Array (RatInterval × Nat) := #[(a, splitDepth)] + let mut acc? : Option RatInterval := none + while stack.size > 0 do + let idx := stack.size - 1 + let (cur, depth) := stack[idx]! + stack := stack.pop + if depth = 0 then + let leaf := geluOverapproxTanh cur expEffort + acc? := + match acc? with + | none => some leaf + | some acc => some (union acc leaf) + else + let lo := min cur.lo cur.hi + let hi := max cur.lo cur.hi + let mid := (lo + hi) / (2 : Rat) + let left : RatInterval := { lo := lo, hi := mid } + let right : RatInterval := { lo := mid, hi := hi } + let depth' := depth - 1 + stack := stack.push (left, depth') + stack := stack.push (right, depth') + return acc?.getD (geluOverapproxTanh a expEffort) /-- Upper bound on `max |gelu'(x)|` over a rational interval. 
-/ def geluDerivBound (target : GeluDerivTarget) (a : RatInterval) : Rat := diff --git a/README.md b/README.md index cec1a3f..2944cb7 100644 --- a/README.md +++ b/README.md @@ -305,14 +305,16 @@ lake exe nfp head_bounds models/gpt2_rigorous.nfpt --delta 0.01 Computes a sound local attention pattern bound for a single head (binary only), propagating per-position intervals up to the target layer (bounded by `maxSeqLen`). -The pattern compares logits for keys whose token matches the query’s offset token -(e.g., `--offset -1` matches the previous token). +The pattern compares logits for keys whose **shifted-key token** matches the +query’s **offset token** (e.g., `--offset -1` matches the previous token, and +`--offset 0 --keyOffset -1` matches the copy-next pattern). ```bash lake exe nfp head_pattern models/gpt2_rigorous.nfpt --layer 0 --head 0 --delta 0.01 --offset -1 ``` - `--offset` selects the target key position relative to the query (default: `-1` for previous token). +- `--keyOffset` selects which key-position token is matched (default: `0` for the key token itself). - `--maxSeqLen` caps the sequence length analyzed for pattern bounds (default: `256`). - `--delta` sets the local input radius; LayerNorm ε is read from the model header (`layer_norm_eps`). - `--tightPattern` enables a slower but tighter pattern bound near the target layer. @@ -337,6 +339,8 @@ lake exe nfp induction_cert models/gpt2_rigorous.nfpt \ token-match head. - `--coord` chooses the output coordinate used for the value lower bound. - `--offset1/--offset2` adjust the token-match offsets (default: `-1`). +- `--keyOffset1/--keyOffset2` adjust the key-token offsets (default: `0`; + use `--offset2 0 --keyOffset2 -1` for copy-next induction). - `--target/--negative` optionally add a logit-diff lower bound using unembedding columns. - `--tightPattern` enables a slower but tighter pattern bound near the target layer. 
- `--tightPatternLayers` sets how many layers use tight bounds (default: `1`; implies `--tightPattern`). diff --git a/scripts/scan_gpt2_induction_sound.py b/scripts/scan_gpt2_induction_sound.py index b767d59..0d33efe 100755 --- a/scripts/scan_gpt2_induction_sound.py +++ b/scripts/scan_gpt2_induction_sound.py @@ -132,7 +132,9 @@ def main() -> int: parser.add_argument("--delta", default="0.01") parser.add_argument("--coord", type=int, default=0) parser.add_argument("--offset1", type=int, default=-1) - parser.add_argument("--offset2", type=int, default=-1) + parser.add_argument("--offset2", type=int, default=0) + parser.add_argument("--keyOffset1", type=int, default=0) + parser.add_argument("--keyOffset2", type=int, default=-1) parser.add_argument("--maxSeqLen", type=int, default=256) parser.add_argument("--jobs", type=int, default=1) parser.add_argument("--fast", action="store_true") @@ -202,6 +204,10 @@ def run_cert(pair: tuple[int, int, int, int]) -> tuple[tuple[int, int, int, int] str(args.offset1), "--offset2", str(args.offset2), + "--keyOffset1", + str(args.keyOffset1), + "--keyOffset2", + str(args.keyOffset2), "--delta", args.delta, "--maxSeqLen", From 7f05888e3a090bcfa55116b119f01ee26b41629c Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 09:45:12 +0100 Subject: [PATCH 022/244] Optimize tokenizer and text parsing --- Nfp/IO/Pure.lean | 59 ++++++++++++++++++++++++++++++----------- Nfp/Sound/TextPure.lean | 15 ++++++++--- 2 files changed, 55 insertions(+), 19 deletions(-) diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index e6c49ee..1416623 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -1,5 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Std import Nfp.Discovery /-! @@ -351,6 +352,8 @@ def mkMLPLayer structure Tokenizer where /-- Token strings in order of ID. -/ tokens : Array String + /-- Map from token string to its first ID. -/ + tokMap : Std.HashMap String Nat /-- Unknown token ID. 
-/ unkId : Nat /-- Padding token ID. -/ @@ -363,28 +366,46 @@ namespace Tokenizer /-- Create a tokenizer from vocabulary list. -/ def fromVocabList (tokens : Array String) (unkId padId eosId : Nat := 0) : Tokenizer := - { tokens := tokens, unkId := unkId, padId := padId, eosId := eosId } + let tokMap := + Id.run do + let mut out : Std.HashMap String Nat := Std.HashMap.emptyWithCapacity tokens.size + let mut i := tokens.size + while i > 0 do + i := i - 1 + out := out.insert tokens[i]! i + return out + { tokens := tokens, tokMap := tokMap, unkId := unkId, padId := padId, eosId := eosId } /-- Find a token's ID in the vocabulary. -/ def findToken (t : Tokenizer) (word : String) : Nat := - match t.tokens.findIdx? (fun tok => tok == word) with - | some idx => idx - | none => t.unkId + t.tokMap.getD word t.unkId /-- Tokenize a string using simple whitespace splitting. -/ def tokenize (t : Tokenizer) (text : String) : Array Nat := Id.run do - let words := text.splitOn " " |>.filter (fun w => w != "") let mut ids : Array Nat := #[] - for word in words do - ids := ids.push (t.findToken word) + let mut p : String.Pos.Raw := 0 + let stop := text.rawEndPos + while p < stop do + while p < stop && p.get text = ' ' do + p := p.next text + let start := p + while p < stop && p.get text ≠ ' ' do + p := p.next text + if start < p then + let word := String.Pos.Raw.extract text start p + ids := ids.push (t.findToken word) ids /-- Decode token IDs back to text. -/ def decode (t : Tokenizer) (ids : Array Nat) : String := - let tokens := ids.filterMap fun id => - if id < t.tokens.size then some t.tokens[id]! - else none - " ".intercalate tokens.toList + let tokens := ids.foldr + (fun id acc => + if id < t.tokens.size then + t.tokens[id]! 
:: acc + else + acc) + [] + " ".intercalate tokens end Tokenizer @@ -394,14 +415,20 @@ end Tokenizer def lookupEmbeddings (embeddings : ConcreteMatrix) (tokenIds : Array Nat) (seqLen : Nat) (padId : Nat := 0) : ConcreteMatrix := Id.run do let modelDim := embeddings.numCols - let mut data : Array Float := #[] + let rowCount := embeddings.numRows + let tokenIdsSize := tokenIds.size + let mut data : Array Float := Array.mkEmpty (seqLen * modelDim) for pos in [:seqLen] do - let tokenId := if pos < tokenIds.size then tokenIds[pos]! else padId + let tokenId := if pos < tokenIdsSize then tokenIds[pos]! else padId -- Copy embedding row for this token. - for dim in [:modelDim] do - let val := embeddings.get tokenId dim - data := data.push val + if tokenId < rowCount then + let rowBase := tokenId * modelDim + for dim in [:modelDim] do + data := data.push embeddings.data[rowBase + dim]! + else + for _ in [:modelDim] do + data := data.push 0.0 buildMatrix seqLen modelDim data diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean index a068eeb..2466c89 100644 --- a/Nfp/Sound/TextPure.lean +++ b/Nfp/Sound/TextPure.lean @@ -176,9 +176,18 @@ def modelWeightBoundsFromTextLines (lines : Array String) : Except String ModelW while i < lines.size do let line := lines[i]!.trim if line.startsWith "LAYER" then - let parts := line.splitOn " " |>.filter (· ≠ "") - if parts.length >= 2 then - curLayer := (parts[1]!).toNat? |>.getD 0 + let mut p : String.Pos.Raw := 0 + let stop := line.rawEndPos + while p < stop && p.get line ≠ ' ' do + p := p.next line + while p < stop && p.get line = ' ' do + p := p.next line + if p < stop then + let start := p + while p < stop && p.get line ≠ ' ' do + p := p.next line + let tok := String.Pos.Raw.extract line start p + curLayer := tok.toNat? 
|>.getD 0 i := i + 1 else if line = "W_Q" then let r := curLayer From aa4f11cca70fd12b58f55bc8aa5ae64c462f5f3d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 09:47:45 +0100 Subject: [PATCH 023/244] Optimize power iteration and binary parsing --- Nfp/Discovery.lean | 44 +++++++++++++++--------------- Nfp/Sound/BinaryPure.lean | 57 ++++++++++++++++++++++++++++----------- Nfp/Sound/CachePure.lean | 15 +++++------ 3 files changed, 71 insertions(+), 45 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index f5f005c..a3590c9 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -314,7 +314,7 @@ We approximate it using power iteration on M^T M. This is a fast **heuristic estimate** of how much `M` can stretch a vector. PERFORMANCE: Power iteration is O(iterations × n²) but heavily optimized: -- Pre-allocated vectors with `Array.ofFn` (no array copying) +- Pre-allocated vectors reused across iterations (`Array.replicate` + `set!`) - Direct loops instead of `List.range.foldl` (10× faster) - Bounds-checked access `v[j]!` and `Mv[i]!` (compiler optimizes in loops) -/ @@ -323,49 +323,49 @@ def operatorNormHeuristicPI (M : ConcreteMatrix) (numIterations : Nat := 20) : F let numCols := M.numCols if numRows = 0 || numCols = 0 then return 0.0 - -- Initialize with a vector of ones - let mut v : Array Float := .ofFn fun _ : Fin numCols => 1.0 - - -- Normalize initial vector - let initNorm := Float.sqrt (v.foldl (fun acc x => acc + x * x) 0.0) - if initNorm > 0.0 then - v := v.map (· / initNorm) + -- Initialize with a normalized vector of ones (avoids a fold + map). 
+ let initScale := 1.0 / Float.sqrt numCols.toFloat + let mut v : Array Float := Array.replicate numCols initScale + let mut Mv : Array Float := Array.replicate numRows 0.0 + let mut MTMv : Array Float := Array.replicate numCols 0.0 -- Power iteration: v ← (M^T M) v / ‖(M^T M) v‖ let mut sigma : Float := 0.0 for _ in [:numIterations] do -- Compute M v - let mut Mv : Array Float := .ofFn fun i : Fin numRows => Id.run do + let mut mvNormSq : Float := 0.0 + for i in [:numRows] do let mut acc : Float := 0.0 - let rowBase := i.val * numCols + let rowBase := i * numCols for j in [:numCols] do - -- SAFETY: v has size M.numCols, guaranteed by Array.ofFn + -- SAFETY: v has size M.numCols, guaranteed by Array.replicate. acc := acc + M.data[rowBase + j]! * v[j]! - return acc + Mv := Mv.set! i acc + mvNormSq := mvNormSq + acc * acc -- Compute M^T (M v) = (M^T M) v - let mut MTMv : Array Float := .ofFn fun j : Fin numCols => Id.run do + let mut mtmvNormSq : Float := 0.0 + for j in [:numCols] do let mut acc : Float := 0.0 - let col := j.val for i in [:numRows] do - -- SAFETY: Mv has size M.numRows, guaranteed by Array.ofFn above - acc := acc + M.data[i * numCols + col]! * Mv[i]! - return acc + -- SAFETY: Mv has size M.numRows, guaranteed by Array.replicate. + acc := acc + M.data[i * numCols + j]! * Mv[i]! + MTMv := MTMv.set! j acc + mtmvNormSq := mtmvNormSq + acc * acc -- Compute norm of MTMv (this is σ² times ‖v‖, and ‖v‖ ≈ 1) - let normSq := MTMv.foldl (fun acc x => acc + x * x) 0.0 - let norm := Float.sqrt normSq + let norm := Float.sqrt mtmvNormSq if norm < 1e-15 then return 0.0 -- σ² ≈ ‖MTMv‖ / ‖v‖ ≈ ‖MTMv‖ -- So σ ≈ ‖Mv‖ - let MvNorm := Float.sqrt (Mv.foldl (fun acc x => acc + x * x) 0.0) - sigma := MvNorm + sigma := Float.sqrt mvNormSq -- Normalize for next iteration - v := MTMv.map (· / norm) + for j in [:numCols] do + v := v.set! j (MTMv[j]! 
/ norm) -- Heuristic safety margin for numerical errors sigma * 1.01 diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index 6b0b7f5..51a5f5e 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -148,7 +148,10 @@ private def ceilDivNat (a : Int) (d : Nat) : Int := let r := a.emod di if r = 0 then q else q + 1 -private def floatAbsCeilScaled (scalePow10 : Nat) (bits : UInt64) : Except String Int := +private def scaleIntOfPow10 (scalePow10 : Nat) : Int := + Int.ofNat (Nat.pow 10 scalePow10) + +private def floatAbsCeilScaledCore (scaleInt : Int) (bits : UInt64) : Except String Int := let expBits : UInt64 := (bits >>> 52) &&& 0x7ff let mantBits : UInt64 := bits &&& 0x000f_ffff_ffff_ffff if expBits = 0x7ff then @@ -156,7 +159,6 @@ private def floatAbsCeilScaled (scalePow10 : Nat) (bits : UInt64) : Except Strin else if expBits = 0 && mantBits = 0 then .ok 0 else - let scale : Nat := Nat.pow 10 scalePow10 let mant : Nat := if expBits = 0 then mantBits.toNat @@ -170,15 +172,18 @@ private def floatAbsCeilScaled (scalePow10 : Nat) (bits : UInt64) : Except Strin let mInt : Int := Int.ofNat mant if expVal ≥ 0 then let pow2 := pow2Nat expVal.toNat - let num := mInt * Int.ofNat scale + let num := mInt * scaleInt .ok (num * Int.ofNat pow2) else let denPow := pow2Nat (-expVal).toNat - let num := mInt * Int.ofNat scale + let num := mInt * scaleInt .ok (ceilDivNat num denPow) -private def floatScaledCeilSigned (scalePow10 : Nat) (bits : UInt64) : Except String Int := - match floatAbsCeilScaled scalePow10 bits with +private def floatAbsCeilScaled (scalePow10 : Nat) (bits : UInt64) : Except String Int := + floatAbsCeilScaledCore (scaleIntOfPow10 scalePow10) bits + +private def floatScaledCeilSignedCore (scaleInt : Int) (bits : UInt64) : Except String Int := + match floatAbsCeilScaledCore scaleInt bits with | .error e => .error e | .ok absScaled => let signNeg : Bool := (bits >>> 63) = (1 : UInt64) @@ -187,17 +192,21 @@ private def 
floatScaledCeilSigned (scalePow10 : Nat) (bits : UInt64) : Except St else .ok absScaled +private def floatScaledCeilSigned (scalePow10 : Nat) (bits : UInt64) : Except String Int := + floatScaledCeilSignedCore (scaleIntOfPow10 scalePow10) bits + def vectorMaxAbsScaledFromBytes (bytes : ByteArray) (n scalePow10 : Nat) : Except String Int := do if n = 0 then return 0 if bytes.size < n * 8 then throw "unexpected EOF" + let scaleInt := scaleIntOfPow10 scalePow10 let mut maxAbs : Int := 0 let mut i : Nat := 0 while i < n do let bits := u64FromLE bytes (i * 8) - match floatAbsCeilScaled scalePow10 bits with + match floatAbsCeilScaledCore scaleInt bits with | .error e => throw e | .ok absScaled => if absScaled > maxAbs then @@ -212,12 +221,13 @@ def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat let count := rows * cols if bytes.size < count * 8 then throw "unexpected EOF" + let scaleInt := scaleIntOfPow10 scalePow10 let mut maxRowSum : Int := 0 let mut curRowSum : Int := 0 let mut i : Nat := 0 while i < count do let bits := u64FromLE bytes (i * 8) - match floatAbsCeilScaled scalePow10 bits with + match floatAbsCeilScaledCore scaleInt bits with | .error e => throw e | .ok absScaled => curRowSum := curRowSum + absScaled @@ -235,6 +245,7 @@ def scaledFloatArrayFromBytes (bytes : ByteArray) (count scalePow10 : Nat) : if bytes.size < count * 8 then throw "unexpected EOF" let useTasks := count > 16384 + let scaleInt := scaleIntOfPow10 scalePow10 if useTasks then let chunkSize : Nat := 8192 let numChunks : Nat := (count + chunkSize - 1) / chunkSize @@ -250,7 +261,7 @@ def scaledFloatArrayFromBytes (bytes : ByteArray) (count scalePow10 : Nat) : let mut i := start while i < stop do let bits := u64FromLE bytes (i * 8) - match floatScaledCeilSigned scalePow10 bits with + match floatScaledCeilSignedCore scaleInt bits with | .error e => return .error e | .ok v => outChunk := outChunk.push v i := i + 1 @@ -269,7 +280,7 @@ def scaledFloatArrayFromBytes 
(bytes : ByteArray) (count scalePow10 : Nat) : let mut i : Nat := 0 while i < count do let bits := u64FromLE bytes (i * 8) - match floatScaledCeilSigned scalePow10 bits with + match floatScaledCeilSignedCore scaleInt bits with | .error e => throw e | .ok v => out := out.push v i := i + 1 @@ -280,7 +291,7 @@ def scaledFloatFromBytes (bytes : ByteArray) (scalePow10 : Nat) : if bytes.size < 8 then throw "unexpected EOF" let bits := u64FromLE bytes 0 - match floatScaledCeilSigned scalePow10 bits with + match floatScaledCeilSignedCore (scaleIntOfPow10 scalePow10) bits with | .error e => throw e | .ok v => return v @@ -305,13 +316,14 @@ def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : let count := rows * cols if bytes.size < count * 8 then throw "unexpected EOF" + let scaleInt := scaleIntOfPow10 scalePow10 let mut maxRowSum : Nat := 0 let mut curRowSum : Nat := 0 let mut colSums : Array Nat := Array.replicate cols 0 let mut i : Nat := 0 while i < count do let bits := u64FromLE bytes (i * 8) - match floatAbsCeilScaled scalePow10 bits with + match floatAbsCeilScaledCore scaleInt bits with | .error e => throw e | .ok absScaled => let absNat := Int.toNat absScaled @@ -346,16 +358,26 @@ def defaultBinaryScalePow10 : Nat := 9 /-- Sum of per-head value-output norm products in scaled-int form. -/ def attnValueCoeffFromScaledPairs (scalePow10 : Nat) (pairs : Array (Int × Int)) : Rat := + let den : Nat := Nat.pow 10 scalePow10 + have den_nz : den ≠ 0 := by + have h10pos : (0 : Nat) < 10 := by decide + exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos) + let ratOfScaledIntLocal := fun (x : Int) => Rat.normalize x den (den_nz := den_nz) pairs.foldl (fun acc p => - acc + ratOfScaledInt scalePow10 p.1 * ratOfScaledInt scalePow10 p.2) 0 + acc + ratOfScaledIntLocal p.1 * ratOfScaledIntLocal p.2) 0 /-- Max per-head W_Q/W_K bounds in scaled-int form. 
-/ def attnQKMaxFromScaledPairs (scalePow10 : Nat) (pairs : Array (Int × Int)) : Rat × Rat := + let den : Nat := Nat.pow 10 scalePow10 + have den_nz : den ≠ 0 := by + have h10pos : (0 : Nat) < 10 := by decide + exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos) + let ratOfScaledIntLocal := fun (x : Int) => Rat.normalize x den (den_nz := den_nz) pairs.foldl (fun acc p => - (max acc.1 (ratOfScaledInt scalePow10 p.1), - max acc.2 (ratOfScaledInt scalePow10 p.2))) + (max acc.1 (ratOfScaledIntLocal p.1), + max acc.2 (ratOfScaledIntLocal p.2))) (0, 0) /-- Compute per-layer attention-weight bound arrays from scaled-int pairs. -/ @@ -416,7 +438,12 @@ theorem u32FromLE_spec_binary_pure : u32FromLE = u32FromLE := rfl theorem i32FromLE_spec_binary_pure : i32FromLE = i32FromLE := rfl theorem pow2Nat_spec_binary_pure : pow2Nat = pow2Nat := rfl theorem ceilDivNat_spec_binary_pure : ceilDivNat = ceilDivNat := rfl +theorem scaleIntOfPow10_spec_binary_pure : scaleIntOfPow10 = scaleIntOfPow10 := rfl +theorem floatAbsCeilScaledCore_spec_binary_pure : + floatAbsCeilScaledCore = floatAbsCeilScaledCore := rfl theorem floatAbsCeilScaled_spec_binary_pure : floatAbsCeilScaled = floatAbsCeilScaled := rfl +theorem floatScaledCeilSignedCore_spec_binary_pure : + floatScaledCeilSignedCore = floatScaledCeilSignedCore := rfl theorem floatScaledCeilSigned_spec_binary_pure : floatScaledCeilSigned = floatScaledCeilSigned := rfl theorem vectorMaxAbsScaledFromBytes_spec_binary_pure : diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index 8b491ef..9cffd12 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -206,7 +206,7 @@ private def countWsTokens (s : String) : Nat := let mut cnt : Nat := 0 while i < bytes.size do let b := bytes[i]! 
- let isWs : Bool := b = 32 || b = 9 + let isWs : Bool := b = 32 || b = 9 || b = 10 || b = 13 if isWs then inTok := false else if !inTok then @@ -223,16 +223,15 @@ private def skipTokensFast (lines : Array String) (start : Nat) (numTokens : Nat while remaining > 0 do if iLine ≥ lines.size then return .error "unexpected end of file while skipping tokens" - let line := lines[iLine]!.trim + let line := lines[iLine]! iLine := iLine + 1 - if line.isEmpty then + let c := countWsTokens line + if c = 0 then pure () + else if c ≥ remaining then + remaining := 0 else - let c := countWsTokens line - if c ≥ remaining then - remaining := 0 - else - remaining := remaining - c + remaining := remaining - c return .ok iLine private def consumeFixedBytes From b26b0dd3a4e18f5e01cbec077c689da65cdbd421 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 10:11:04 +0100 Subject: [PATCH 024/244] Optimize parsing and attention weights --- Nfp/Discovery.lean | 10 +++-- Nfp/Untrusted/SoundCompute.lean | 73 ++++++++++++++++++--------------- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index a3590c9..1e9ab84 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -2761,6 +2761,7 @@ def toMatrix (A : ConcreteAttentionWeights) : ConcreteMatrix where simpa using A.size_eq /-- Compute softmax for a row of logits. -/ +@[inline] def softmaxRow (logits : Array Float) : Array Float := Id.run do -- PERFORMANCE: keep arrays linear to enable in-place updates @@ -2790,9 +2791,12 @@ def compute (queries keys : ConcreteMatrix) (scale : Float) let n := seqLen * seqLen let mut weights : { w : Array Float // w.size = n } := ⟨Array.replicate n 0.0, by simp [n]⟩ + -- Reuse a single row buffer to avoid per-row allocations. + let mut rowScores : Array Float := Array.replicate seqLen (-1e30) for q in [:seqLen] do -- Initialize to -∞ and only fill the causal prefix when `causal = true`. 
- let mut rowScores : Array Float := Array.replicate seqLen (-1e30) + for i in [:seqLen] do + rowScores := rowScores.set! i (-1e30) let stop := if causal then min (q + 1) seqLen else seqLen let qBase := q * queries.numCols for j in [:stop] do @@ -2804,10 +2808,10 @@ def compute (queries keys : ConcreteMatrix) (scale : Float) -- and `d < cols ≤ queries.numCols/keys.numCols`. dotProd := dotProd + queries.data[qBase + d]! * keys.data[jBase + d]! rowScores := rowScores.set! j (dotProd / scale) - let row := softmaxRow rowScores + rowScores := softmaxRow rowScores let rowBase := q * seqLen for k in [:stop] do - let weights' := weights.1.set! (rowBase + k) (row[k]!) + let weights' := weights.1.set! (rowBase + k) (rowScores[k]!) have weights'SizeEq : weights'.size = n := by have hsize : weights'.size = weights.1.size := by -- `set!` is `setIfInBounds`, which preserves size. diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 95b6522..ec7f966 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -594,6 +594,14 @@ private def consumeVectorSkipFast (n : Nat) : Except String Nat := consumeTokensSkipFast lines start n +/-- Accumulator for streaming matrix multiplication with row-abs tracking. -/ +private structure MulAndNormAcc where + out : Array RatInterval + row : Nat + col : Nat + curRowAbs : Rat + maxRowAbs : Rat + /-! Streaming multiplication for row-major stored matrices. 
@@ -612,41 +620,38 @@ private def consumeMatrixMulAndNormInf Id.run do if input.size ≠ rows then return .error "input interval dimension mismatch" - let mut out : Array RatInterval := zeroIntervals cols - let mut iLine := start - let mut remaining := rows * cols - let mut idx : Nat := 0 - let mut curRowAbs : Rat := 0 - let mut maxRowAbs : Rat := 0 - while remaining > 0 do - if iLine ≥ lines.size then - return .error "unexpected end of file while reading matrix" - let line := lines[iLine]!.trim - iLine := iLine + 1 - if line.isEmpty then - pure () + let init : MulAndNormAcc := { + out := zeroIntervals cols + row := 0 + col := 0 + curRowAbs := 0 + maxRowAbs := 0 + } + let step := fun (st : MulAndNormAcc) (w : Rat) => + let r := st.row + let c := st.col + let curRowAbs := st.curRowAbs + ratAbs w + -- out[c] += w * input[r] + let term := RatInterval.scale w (input[r]!) + let out := st.out.set! c (RatInterval.add (st.out[c]!) term) + if c + 1 = cols then + { out := out + row := r + 1 + col := 0 + curRowAbs := 0 + maxRowAbs := max st.maxRowAbs curRowAbs } else - let toks := line.splitOn " " |>.filter (· ≠ "") - for t in toks do - if remaining = 0 then - break - match parseRat t with - | .error e => return .error e - | .ok w => - let r := idx / cols - let c := idx % cols - curRowAbs := curRowAbs + ratAbs w - -- out[c] += w * input[r] - let term := RatInterval.scale w (input[r]!) - out := out.set! c (RatInterval.add (out[c]!) term) - idx := idx + 1 - remaining := remaining - 1 - if c + 1 = cols then - maxRowAbs := max maxRowAbs curRowAbs - curRowAbs := 0 - -- Account for a partial last row (should not happen if rows*cols consumed). 
- maxRowAbs := max maxRowAbs curRowAbs - return .ok (out, maxRowAbs, iLine) + { out := out + row := r + col := c + 1 + curRowAbs := curRowAbs + maxRowAbs := st.maxRowAbs } + match foldRatTokens lines start (rows * cols) init step with + | .error e => return .error e + | .ok (st, next) => + -- Account for a partial last row (should not happen if rows*cols consumed). + let maxRowAbs := max st.maxRowAbs st.curRowAbs + return .ok (st.out, maxRowAbs, next) /-- Soundly compute conservative per-layer residual amplification constants from a `.nfpt` file. -/ def certifyModelFileGlobal From 76eb22f59bc1ac3006abed0a483ea75111d04459 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 10:21:18 +0100 Subject: [PATCH 025/244] Optimize ConcreteMatrix array operations --- Nfp/Discovery.lean | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index 1e9ab84..55a1900 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Batteries.Lean.Float +import Init.Data.Array.Extract /-! # Executable Circuit Discovery for Induction Heads @@ -122,15 +123,19 @@ def rowMaxAbs (M : ConcreteMatrix) (r : Nat) : Float := /-- Take the first `n` rows of a matrix (keeping all columns). -/ def takeRows (M : ConcreteMatrix) (n : Nat) : ConcreteMatrix := - if n ≥ M.numRows then + if h : n ≥ M.numRows then M else + let rowCount := n * M.numCols { numRows := n numCols := M.numCols - data := .ofFn fun idx : Fin (n * M.numCols) => - -- Since `n < M.numRows`, `idx.val < n*numCols ≤ numRows*numCols = data.size`. 
- M.data.getD idx.val 0.0 - size_eq := Array.size_ofFn } + data := M.data.extract 0 rowCount + size_eq := by + have hrows : n ≤ M.numRows := Nat.le_of_lt (Nat.lt_of_not_ge h) + have hsize : rowCount ≤ M.data.size := by + simpa [rowCount, M.size_eq] using Nat.mul_le_mul_right M.numCols hrows + simpa [rowCount] using + (Array.size_extract_of_le (as := M.data) (i := 0) (j := rowCount) hsize) } /-- Create a zero matrix of given dimensions. -/ def zeros (rows cols : Nat) : ConcreteMatrix where @@ -1287,7 +1292,7 @@ def add (A B : ConcreteMatrix) : ConcreteMatrix := numRows := A.numRows numCols := A.numCols data := .ofFn fun idx : Fin (A.numRows * A.numCols) => - A.data.getD idx.val 0.0 + B.data.getD idx.val 0.0 + A.data[idx.val]! + B.data[idx.val]! size_eq := Array.size_ofFn } else zeros 0 0 @@ -7850,7 +7855,7 @@ def ConcreteMatrix.sub (A B : ConcreteMatrix) : ConcreteMatrix := numRows := A.numRows numCols := A.numCols data := .ofFn fun idx : Fin (A.numRows * A.numCols) => - A.data.getD idx.val 0.0 - B.data.getD idx.val 0.0 + A.data[idx.val]! - B.data[idx.val]! size_eq := Array.size_ofFn } else ConcreteMatrix.zeros 0 0 From ba167237ca6b77d9952bed45e5afd268453f4126 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 10:31:42 +0100 Subject: [PATCH 026/244] Optimize parsing loops and matrix construction --- Nfp/IO/Pure.lean | 11 ++--------- Nfp/Sound/BinaryPure.lean | 13 ++++++++++--- Nfp/Sound/TextPure.lean | 17 ++++++++++++++--- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 1416623..0a982ce 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -240,16 +240,9 @@ def parseNatLine (line : String) : Array Nat := This is safe because we ensure the data has exactly the right size. 
-/ def buildMatrix (rows cols : Nat) (data : Array Float) : ConcreteMatrix := let expectedSize := rows * cols - let normalizedData : Array Float := - if data.size < expectedSize then - data ++ (Array.replicate (expectedSize - data.size) 0.0) - else if data.size > expectedSize then - data.toSubarray 0 expectedSize |>.toArray - else - data - -- Use Array.ofFn to get the exact size we need with a proof. + -- Use Array.ofFn to get the exact size we need while padding/truncating via getD. let finalData := Array.ofFn fun (i : Fin expectedSize) => - normalizedData.getD i.val 0.0 + data.getD i.val 0.0 { numRows := rows numCols := cols diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index 51a5f5e..1e217e6 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -224,6 +224,7 @@ def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat let scaleInt := scaleIntOfPow10 scalePow10 let mut maxRowSum : Int := 0 let mut curRowSum : Int := 0 + let mut colIdx : Nat := 0 let mut i : Nat := 0 while i < count do let bits := u64FromLE bytes (i * 8) @@ -231,10 +232,13 @@ def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat | .error e => throw e | .ok absScaled => curRowSum := curRowSum + absScaled - if (i + 1) % cols = 0 then + if colIdx + 1 = cols then if curRowSum > maxRowSum then maxRowSum := curRowSum curRowSum := 0 + colIdx := 0 + else + colIdx := colIdx + 1 i := i + 1 return maxRowSum @@ -320,6 +324,7 @@ def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : let mut maxRowSum : Nat := 0 let mut curRowSum : Nat := 0 let mut colSums : Array Nat := Array.replicate cols 0 + let mut colIdx : Nat := 0 let mut i : Nat := 0 while i < count do let bits := u64FromLE bytes (i * 8) @@ -328,12 +333,14 @@ def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : | .ok absScaled => let absNat := Int.toNat absScaled curRowSum := curRowSum + absNat - let colIdx := i 
% cols colSums := colSums.set! colIdx (colSums[colIdx]! + absNat) - if (i + 1) % cols = 0 then + if colIdx + 1 = cols then if curRowSum > maxRowSum then maxRowSum := curRowSum curRowSum := 0 + colIdx := 0 + else + colIdx := colIdx + 1 i := i + 1 let mut maxColSum : Nat := 0 for c in colSums do diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean index 2466c89..c82f904 100644 --- a/Nfp/Sound/TextPure.lean +++ b/Nfp/Sound/TextPure.lean @@ -151,9 +151,20 @@ def consumeMatrixNormInf (start : Nat) (rows cols : Nat) : Except String (Rat × Nat) := let count := rows * cols - match consumeVector lines start count with - | .error e => .error e - | .ok (xs, next) => .ok (matrixNormInfOfRowMajor rows cols xs, next) + if count = 0 then + .ok (0, start) + else + let step := fun (acc : Rat × Rat × Nat) (x : Rat) => + let (curRowSum, maxRowSum, colIdx) := acc + let curRowSum := curRowSum + ratAbs x + let colIdx := colIdx + 1 + if colIdx = cols then + (0, max maxRowSum curRowSum, 0) + else + (curRowSum, maxRowSum, colIdx) + match foldRatTokens lines start count (0, 0, 0) step with + | .error e => .error e + | .ok ((_, maxRowSum, _), next) => .ok (maxRowSum, next) /-- Compute per-layer weight bounds from text model lines. -/ def modelWeightBoundsFromTextLines (lines : Array String) : Except String ModelWeightBounds := From 4f8e8de4972f5dea677f5eadb5b260d5dee2ae14 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 10:43:28 +0100 Subject: [PATCH 027/244] Refactor parsing helpers --- Nfp/IO/Pure.lean | 122 +++++++++++++++----------------------- Nfp/Sound/BinaryPure.lean | 9 +-- Nfp/Sound/CachePure.lean | 9 +-- 3 files changed, 50 insertions(+), 90 deletions(-) diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 0a982ce..d36231d 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -13,12 +13,9 @@ namespace Nfp /-! 
## Float Parsing Utilities -/ -private def pow10PowTable : Array Float := Id.run do +private def pow10PowTable : Array Float := -- Precompute `Float.pow 10.0 k` for k=0..308 so we avoid calling `Float.pow` per token. - let mut out : Array Float := Array.mkEmpty 309 - for k in [:309] do - out := out.push (Float.pow 10.0 k.toFloat) - out + Array.ofFn fun k : Fin 309 => Float.pow 10.0 k.val.toFloat private def pow10Pow (n : Nat) : Float := if n < pow10PowTable.size then @@ -26,26 +23,26 @@ private def pow10Pow (n : Nat) : Float := else Float.pow 10.0 n.toFloat +private def parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do + let mut p := start + if p >= stop then + return none + let mut acc : Nat := 0 + let mut saw : Bool := false + while p < stop do + let c := p.get s + if ('0' <= c) && (c <= '9') then + acc := acc * 10 + (c.toNat - '0'.toNat) + saw := true + p := p.next s + else + return none + if saw then some acc else none + private def parseFloatRange (s : String) (start stop : String.Pos.Raw) : Option Float := Id.run do -- This is a faster, allocation-free version of the previous `parseFloat`, but it preserves -- the exact Float computation structure (Nat parsing + `Float.pow`) to keep results stable. 
- let parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do - let mut p := start - if p >= stop then - return none - let mut acc : Nat := 0 - let mut saw : Bool := false - while p < stop do - let c := p.get s - if ('0' <= c) && (c <= '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - saw := true - p := p.next s - else - return none - if saw then some acc else none - let mut p := start if p >= stop then return none @@ -160,24 +157,31 @@ def parseFloat (s : String) : Option Float := Id.run do else parseFloatRange s 0 s.rawEndPos -private def appendFloatsFromLine (line : String) (acc : Array Float) : Array Float := Id.run do - let mut out := acc - let s := line - let mut p : String.Pos.Raw := 0 - let endPos := s.rawEndPos - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - while p < endPos do - while p < endPos && isWs (p.get s) do - p := p.next s - let start := p - while p < endPos && !isWs (p.get s) do - p := p.next s - if start < p then - match parseFloatRange s start p with - | some x => out := out.push x - | none => pure () - out +@[inline] private def isWs (c : Char) : Bool := + c = ' ' || c = '\t' || c = '\n' || c = '\r' + +@[inline] private def foldTokensFromLine {α : Type} + (line : String) (init : α) + (step : α → String.Pos.Raw → String.Pos.Raw → α) : α := + Id.run do + let mut out := init + let mut p : String.Pos.Raw := 0 + let stop := line.rawEndPos + while p < stop do + while p < stop && isWs (p.get line) do + p := p.next line + let start := p + while p < stop && !isWs (p.get line) do + p := p.next line + if start < p then + out := step out start p + out + +private def appendFloatsFromLine (line : String) (acc : Array Float) : Array Float := + foldTokensFromLine line acc fun out start stop => + match parseFloatRange line start stop with + | some x => out.push x + | none => out private def parseFloatsFromLines (lines : Array String) (cap : Nat := 0) : Array Float := Id.run do @@ -195,41 +199,11 
@@ def parseFloatLine (line : String) : Array Float := /-! ## Nat Parsing Utilities -/ -/-- Parse a line of space-separated natural numbers. -/ -private def parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do - let mut p := start - if p >= stop then - return none - let mut acc : Nat := 0 - let mut saw : Bool := false - while p < stop do - let c := p.get s - if ('0' <= c) && (c <= '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - saw := true - p := p.next s - else - return none - if saw then some acc else none - -private def appendNatsFromLine (line : String) (acc : Array Nat) : Array Nat := Id.run do - let mut out := acc - let s := line - let mut p : String.Pos.Raw := 0 - let endPos := s.rawEndPos - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - while p < endPos do - while p < endPos && isWs (p.get s) do - p := p.next s - let start := p - while p < endPos && !isWs (p.get s) do - p := p.next s - if start < p then - match parseNatRange s start p with - | some n => out := out.push n - | none => pure () - out +private def appendNatsFromLine (line : String) (acc : Array Nat) : Array Nat := + foldTokensFromLine line acc fun out start stop => + match parseNatRange line start stop with + | some n => out.push n + | none => out def parseNatLine (line : String) : Array Nat := appendNatsFromLine line #[] diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index 1e217e6..fe11a37 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -3,6 +3,7 @@ import Std import Nfp.Sound.Activation import Nfp.Sound.Decimal +import Nfp.Sound.ModelHeader namespace Nfp.Sound @@ -25,14 +26,6 @@ structure BinaryHeader where geluDerivTarget : GeluDerivTarget deriving Repr -private def parseHeaderLine (line : String) : Option (String × String) := - let line := line.trim - if line.isEmpty then none - else - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none - private def 
readHeaderNat (k v : String) : Option Nat := if k = "num_layers" || k = "num_heads" || k = "model_dim" || k = "head_dim" || k = "hidden_dim" || k = "vocab_size" || k = "seq_len" then diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index 9cffd12..8d6befd 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -5,6 +5,7 @@ import Init.System.IO import Init.Data.ByteArray.Lemmas import Nfp.Sound.Decimal import Nfp.Sound.Fixed +import Nfp.Sound.ModelHeader namespace Nfp.Sound @@ -168,14 +169,6 @@ def fnv1a64Update (hash : UInt64) (chunk : ByteArray) : UInt64 := def fnv1a64 (bytes : ByteArray) : UInt64 := fnv1a64Update fnv1a64Init bytes -private def parseHeaderLine (line : String) : Option (String × String) := - let line := line.trim - if line.isEmpty then none - else - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none - private def findLineIdxFrom (lines : Array String) (start : Nat) (p : String → Bool) : Option Nat := Id.run do From a6fc01d87203f338dec46aa478d11f7d0c74ae35 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 10:54:58 +0100 Subject: [PATCH 028/244] Optimize binary parsing loops --- Nfp/IO.lean | 27 +++++++++++++------- Nfp/Sound/BinaryPure.lean | 52 +++++++++++++++++++++++++-------------- 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 7e85c04..409cadf 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -82,14 +82,14 @@ private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do remaining := remaining - chunk.size return ByteArray.mk out -private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := +@[inline] private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := let b0 := (b[off]!).toUInt32 let b1 := (b[off + 1]!).toUInt32 let b2 := (b[off + 2]!).toUInt32 let b3 := (b[off + 3]!).toUInt32 b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) -private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := +@[inline] private 
def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := let b0 := (b[off]!).toUInt64 let b1 := (b[off + 1]!).toUInt64 let b2 := (b[off + 2]!).toUInt64 @@ -110,27 +110,36 @@ private def i32FromLE (b : ByteArray) (off : Nat) : Int := let two32 : Int := Int.ofNat (Nat.pow 2 32) (Int.ofNat u.toNat) - two32 -private def floatFromLE (b : ByteArray) (off : Nat) : Float := +@[inline] private def floatFromLE (b : ByteArray) (off : Nat) : Float := Float.ofBits (u64FromLE b off) private def readFloatArray (h : IO.FS.Handle) (count : Nat) : IO FloatArray := do if count = 0 then return FloatArray.empty let bytes ← readExactly h (count * 8) - let data := Array.ofFn (fun i : Fin count => - floatFromLE bytes (i.val * 8)) + let mut data : Array Float := Array.replicate count 0.0 + let mut i : Nat := 0 + let mut off : Nat := 0 + while i < count do + data := data.set! i (floatFromLE bytes off) + off := off + 8 + i := i + 1 return .mk data private def readI32Array (h : IO.FS.Handle) (count : Nat) : IO (Array Nat) := do if count = 0 then return #[] let bytes ← readExactly h (count * 4) - let mut out : Array Nat := Array.mkEmpty count - for i in [:count] do - let v := i32FromLE bytes (i * 4) + let mut out : Array Nat := Array.replicate count 0 + let mut i : Nat := 0 + let mut off : Nat := 0 + while i < count do + let v := i32FromLE bytes off if v < 0 then throw (IO.userError s!"Negative token id at index {i}") - out := out.push v.toNat + out := out.set! i v.toNat + off := off + 4 + i := i + 1 return out /-- Load a model from the `.nfpt` binary format (NFP_BINARY_V1). 
-/ diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index fe11a37..a770e17 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -107,7 +107,7 @@ def parseBinaryHeaderLines (magicLine : String) (lines : Array String) : geluDerivTarget := geluVal } -private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := +@[inline] private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := let b0 := (b.get! off).toUInt64 let b1 := (b.get! (off + 1)).toUInt64 let b2 := (b.get! (off + 2)).toUInt64 @@ -119,14 +119,14 @@ private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) -private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := +@[inline] private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := let b0 := (b.get! off).toUInt32 let b1 := (b.get! (off + 1)).toUInt32 let b2 := (b.get! (off + 2)).toUInt32 let b3 := (b.get! (off + 3)).toUInt32 b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) -private def i32FromLE (b : ByteArray) (off : Nat) : Int := +@[inline] private def i32FromLE (b : ByteArray) (off : Nat) : Int := let u := u32FromLE b off if u ≤ 0x7fffffff then Int.ofNat u.toNat @@ -139,7 +139,7 @@ private def ceilDivNat (a : Int) (d : Nat) : Int := let di : Int := Int.ofNat d let q := a.ediv di let r := a.emod di - if r = 0 then q else q + 1 + if r = 0 then q else q + 1 private def scaleIntOfPow10 (scalePow10 : Nat) : Int := Int.ofNat (Nat.pow 10 scalePow10) @@ -197,13 +197,15 @@ def vectorMaxAbsScaledFromBytes (bytes : ByteArray) (n scalePow10 : Nat) : let scaleInt := scaleIntOfPow10 scalePow10 let mut maxAbs : Int := 0 let mut i : Nat := 0 + let mut off : Nat := 0 while i < n do - let bits := u64FromLE bytes (i * 8) + let bits := u64FromLE bytes off match floatAbsCeilScaledCore scaleInt bits with | .error e => throw e | .ok absScaled => if absScaled > maxAbs then maxAbs := absScaled + off := off + 8 i := 
i + 1 return maxAbs @@ -219,8 +221,9 @@ def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat let mut curRowSum : Int := 0 let mut colIdx : Nat := 0 let mut i : Nat := 0 + let mut off : Nat := 0 while i < count do - let bits := u64FromLE bytes (i * 8) + let bits := u64FromLE bytes off match floatAbsCeilScaledCore scaleInt bits with | .error e => throw e | .ok absScaled => @@ -232,6 +235,7 @@ def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat colIdx := 0 else colIdx := colIdx + 1 + off := off + 8 i := i + 1 return maxRowSum @@ -254,32 +258,40 @@ def scaledFloatArrayFromBytes (bytes : ByteArray) (count scalePow10 : Nat) : tasks := tasks.push <| Task.spawn (fun _ => Id.run do - let mut outChunk : Array Int := Array.mkEmpty (stop - start) + let mut outChunk : Array Int := Array.replicate (stop - start) 0 let mut i := start + let mut off := start * 8 + let mut outIdx : Nat := 0 while i < stop do - let bits := u64FromLE bytes (i * 8) + let bits := u64FromLE bytes off match floatScaledCeilSignedCore scaleInt bits with | .error e => return .error e - | .ok v => outChunk := outChunk.push v + | .ok v => outChunk := outChunk.set! outIdx v + off := off + 8 i := i + 1 + outIdx := outIdx + 1 return .ok outChunk) chunkIdx := chunkIdx + 1 - let mut out : Array Int := Array.mkEmpty count + let mut out : Array Int := Array.replicate count 0 + let mut outIdx : Nat := 0 for t in tasks do match t.get with | .error e => throw e | .ok chunk => for v in chunk do - out := out.push v + out := out.set! outIdx v + outIdx := outIdx + 1 return out else - let mut out : Array Int := Array.mkEmpty count + let mut out : Array Int := Array.replicate count 0 let mut i : Nat := 0 + let mut off : Nat := 0 while i < count do - let bits := u64FromLE bytes (i * 8) + let bits := u64FromLE bytes off match floatScaledCeilSignedCore scaleInt bits with | .error e => throw e - | .ok v => out := out.push v + | .ok v => out := out.set! 
i v + off := off + 8 i := i + 1 return out @@ -298,11 +310,13 @@ def i32ArrayFromBytes (bytes : ByteArray) (count : Nat) : return #[] if bytes.size < count * 4 then throw "unexpected EOF" - let mut out : Array Int := Array.mkEmpty count + let mut out : Array Int := Array.replicate count 0 let mut i : Nat := 0 + let mut off : Nat := 0 while i < count do - let v := i32FromLE bytes (i * 4) - out := out.push v + let v := i32FromLE bytes off + out := out.set! i v + off := off + 4 i := i + 1 return out @@ -319,8 +333,9 @@ def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : let mut colSums : Array Nat := Array.replicate cols 0 let mut colIdx : Nat := 0 let mut i : Nat := 0 + let mut off : Nat := 0 while i < count do - let bits := u64FromLE bytes (i * 8) + let bits := u64FromLE bytes off match floatAbsCeilScaledCore scaleInt bits with | .error e => throw e | .ok absScaled => @@ -334,6 +349,7 @@ def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : colIdx := 0 else colIdx := colIdx + 1 + off := off + 8 i := i + 1 let mut maxColSum : Nat := 0 for c in colSums do From 12fb130f8f304cb45f270695c7dcfb68c57f571a Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 11:18:54 +0100 Subject: [PATCH 029/244] Refactor IO parsing and staging --- Nfp/IO.lean | 12 ++++----- Nfp/Sound/CachePure.lean | 12 ++++----- Nfp/Untrusted/SoundBinary.lean | 46 +++++++++++---------------------- Nfp/Untrusted/SoundCacheIO.lean | 16 +----------- Nfp/Untrusted/SoundCompute.lean | 25 +++++++++--------- 5 files changed, 40 insertions(+), 71 deletions(-) diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 409cadf..0512343 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -198,15 +198,15 @@ def loadBinary (h : IO.FS.Handle) : IO LoadResult := do IO.println s!"[3/5] Loading {numLayers} layers with {numHeads} heads each..." 
- let mut layers : Array (Array ConcreteAttentionLayer) := #[] - let mut attnProjBias : Array ConcreteMatrix := #[] - let mut mlps : Array ConcreteMLPLayer := #[] - let mut ln1 : Array ConcreteLayerNormParams := #[] - let mut ln2 : Array ConcreteLayerNormParams := #[] + let mut layers : Array (Array ConcreteAttentionLayer) := Array.mkEmpty numLayers + let mut attnProjBias : Array ConcreteMatrix := Array.mkEmpty numLayers + let mut mlps : Array ConcreteMLPLayer := Array.mkEmpty numLayers + let mut ln1 : Array ConcreteLayerNormParams := Array.mkEmpty numLayers + let mut ln2 : Array ConcreteLayerNormParams := Array.mkEmpty numLayers for l in [:numLayers] do IO.println s!" Loading layer {l}/{numLayers}..." - let mut layerHeads : Array ConcreteAttentionLayer := #[] + let mut layerHeads : Array ConcreteAttentionLayer := Array.mkEmpty numHeads for _h in [:numHeads] do let wq ← readFloatArray h (modelDim * headDim) let bq ← readFloatArray h headDim diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index 8d6befd..dd3531d 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -193,19 +193,19 @@ private def skipBlankLines (lines : Array String) (start : Nat) : Nat := private def countWsTokens (s : String) : Nat := Id.run do - let bytes := s.toUTF8 - let mut i : Nat := 0 + let mut p : String.Pos.Raw := 0 + let stop := s.rawEndPos let mut inTok : Bool := false let mut cnt : Nat := 0 - while i < bytes.size do - let b := bytes[i]! 
- let isWs : Bool := b = 32 || b = 9 || b = 10 || b = 13 + while p < stop do + let c := p.get s + let isWs : Bool := c = ' ' || c = '\t' || c = '\n' || c = '\r' if isWs then inTok := false else if !inTok then inTok := true cnt := cnt + 1 - i := i + 1 + p := p.next s return cnt private def skipTokensFast (lines : Array String) (start : Nat) (numTokens : Nat) : diff --git a/Nfp/Untrusted/SoundBinary.lean b/Nfp/Untrusted/SoundBinary.lean index 9f10e30..8733624 100644 --- a/Nfp/Untrusted/SoundBinary.lean +++ b/Nfp/Untrusted/SoundBinary.lean @@ -34,7 +34,8 @@ def readBinaryHeader (h : IO.FS.Handle) : IO (Except String Nfp.Sound.BinaryHead line? ← readLine? h return Nfp.Sound.parseBinaryHeaderLines magicLine lines -private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do +/-- Read exactly `n` bytes or throw on EOF. -/ +def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do if n = 0 then return ByteArray.empty let mut remaining := n @@ -48,6 +49,13 @@ private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do remaining := remaining - chunk.size return ByteArray.mk out +@[inline] private def readExactlyExcept (h : IO.FS.Handle) (n : Nat) : + IO (Except String ByteArray) := do + try + return .ok (← readExactly h n) + catch + | _ => return .error "unexpected EOF" + def skipBytes (h : IO.FS.Handle) (n : Nat) : IO (Except String Unit) := do let mut remaining := n while remaining > 0 do @@ -68,11 +76,7 @@ def readVectorMaxAbsScaled (h : IO.FS.Handle) (n scalePow10 : Nat) : IO (Except String Int) := do if n = 0 then return .ok 0 - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (n * 8))) - catch - | _ => pure (Except.error "unexpected EOF") + let bytesE ← readExactlyExcept h (n * 8) match bytesE with | .error e => return .error e | .ok bytes => @@ -83,11 +87,7 @@ def readMatrixNormInfScaled (h : IO.FS.Handle) (rows cols scalePow10 : Nat) : if rows = 0 || cols = 0 then return .ok 0 let count := rows * cols 
- let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (count * 8))) - catch - | _ => pure (Except.error "unexpected EOF") + let bytesE ← readExactlyExcept h (count * 8) match bytesE with | .error e => return .error e | .ok bytes => @@ -97,22 +97,14 @@ def readScaledFloatArray (h : IO.FS.Handle) (count scalePow10 : Nat) : IO (Except String (Array Int)) := do if count = 0 then return .ok #[] - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (count * 8))) - catch - | _ => pure (Except.error "unexpected EOF") + let bytesE ← readExactlyExcept h (count * 8) match bytesE with | .error e => return .error e | .ok bytes => return Nfp.Sound.scaledFloatArrayFromBytes bytes count scalePow10 def readScaledFloat (h : IO.FS.Handle) (scalePow10 : Nat) : IO (Except String Int) := do - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h 8)) - catch - | _ => pure (Except.error "unexpected EOF") + let bytesE ← readExactlyExcept h 8 match bytesE with | .error e => return .error e | .ok bytes => @@ -122,11 +114,7 @@ def readI32Array (h : IO.FS.Handle) (count : Nat) : IO (Except String (Array Int)) := do if count = 0 then return .ok #[] - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (count * 4))) - catch - | _ => pure (Except.error "unexpected EOF") + let bytesE ← readExactlyExcept h (count * 4) match bytesE with | .error e => return .error e | .ok bytes => @@ -137,11 +125,7 @@ def readMatrixNormOneInfScaled (h : IO.FS.Handle) (rows cols scalePow10 : Nat) : if rows = 0 || cols = 0 then return .ok (0, 0) let count := rows * cols - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (count * 8))) - catch - | _ => pure (Except.error "unexpected EOF") + let bytesE ← readExactlyExcept h (count * 8) match bytesE with | .error e => return .error e | .ok bytes => diff --git a/Nfp/Untrusted/SoundCacheIO.lean b/Nfp/Untrusted/SoundCacheIO.lean index 
78fb99c..df99b72 100644 --- a/Nfp/Untrusted/SoundCacheIO.lean +++ b/Nfp/Untrusted/SoundCacheIO.lean @@ -12,20 +12,6 @@ namespace Nfp.Untrusted.SoundCacheIO IO wrappers for the SOUND cache format. Pure parsing/encoding lives in `Nfp.Sound.CachePure`. -/ -private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - if n = 0 then - return ByteArray.empty - let mut remaining := n - let mut out : Array UInt8 := Array.mkEmpty n - while remaining > 0 do - let chunk ← h.read (USize.ofNat remaining) - if chunk.isEmpty then - throw (IO.userError "unexpected EOF") - for b in chunk.data do - out := out.push b - remaining := remaining - chunk.size - return ByteArray.mk out - private def appendI32LE (buf : Array UInt8) (x : Int) : Array UInt8 := Id.run do let ux : UInt32 := UInt32.ofInt x @@ -62,7 +48,7 @@ def writeHeader (h : IO.FS.Handle) (hdr : Nfp.Sound.SoundCache.Header) : IO Unit h.write (Nfp.Sound.SoundCache.encodeHeader hdr) def readHeader (h : IO.FS.Handle) : IO Nfp.Sound.SoundCache.Header := do - let bytes ← readExactly h Nfp.Sound.SoundCache.headerBytes + let bytes ← Nfp.Untrusted.SoundBinary.readExactly h Nfp.Sound.SoundCache.headerBytes match Nfp.Sound.SoundCache.decodeHeader bytes with | .ok hdr => return hdr | .error e => throw (IO.userError e) diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index ec7f966..b059b8a 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -539,19 +539,19 @@ instead of calling `parseRat`. private def countWsTokens (s : String) : Nat := Id.run do - let bytes := s.toUTF8 - let mut i : Nat := 0 + let mut p : String.Pos.Raw := 0 + let stop := s.rawEndPos let mut inTok : Bool := false let mut cnt : Nat := 0 - while i < bytes.size do - let b := bytes[i]! 
- let isWs : Bool := b = 32 || b = 9 -- ' ' or '\t' + while p < stop do + let c := p.get s + let isWs : Bool := c = ' ' || c = '\t' || c = '\n' || c = '\r' if isWs then inTok := false else if !inTok then inTok := true cnt := cnt + 1 - i := i + 1 + p := p.next s return cnt private def consumeTokensSkipFast @@ -562,16 +562,15 @@ private def consumeTokensSkipFast while remaining > 0 do if iLine ≥ lines.size then return .error "unexpected end of file while skipping tokens" - let line := lines[iLine]!.trim + let line := lines[iLine]! iLine := iLine + 1 - if line.isEmpty then + let c := countWsTokens line + if c = 0 then pure () + else if c ≥ remaining then + remaining := 0 else - let c := countWsTokens line - if c ≥ remaining then - remaining := 0 - else - remaining := remaining - c + remaining := remaining - c return .ok iLine private def consumeMatrixSkip From 81dfff99dfec5214d6f4cb7c1a6719598562adf0 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 11:38:07 +0100 Subject: [PATCH 030/244] Refine tokenization and log sanitization --- Main.lean | 5 ++--- Nfp/IO/Pure.lean | 18 ++++-------------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/Main.lean b/Main.lean index 5c552b5..d022f64 100644 --- a/Main.lean +++ b/Main.lean @@ -75,9 +75,8 @@ private structure StdoutLogCtx where initialize stdoutLogCtxRef : IO.Ref (Option StdoutLogCtx) ← IO.mkRef none private def sanitizeFileComponent (s : String) : String := - String.ofList <| - s.toList.map fun c => - if c.isAlphanum || c = '_' || c = '-' || c = '.' then c else '_' + s.map fun c => + if c.isAlphanum || c = '_' || c = '-' || c = '.' 
then c else '_' private def timestampNowForLog : IO String := do let dt ← Std.Time.ZonedDateTime.now diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index d36231d..5c3599b 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -348,20 +348,10 @@ def findToken (t : Tokenizer) (word : String) : Nat := t.tokMap.getD word t.unkId /-- Tokenize a string using simple whitespace splitting. -/ -def tokenize (t : Tokenizer) (text : String) : Array Nat := Id.run do - let mut ids : Array Nat := #[] - let mut p : String.Pos.Raw := 0 - let stop := text.rawEndPos - while p < stop do - while p < stop && p.get text = ' ' do - p := p.next text - let start := p - while p < stop && p.get text ≠ ' ' do - p := p.next text - if start < p then - let word := String.Pos.Raw.extract text start p - ids := ids.push (t.findToken word) - ids +def tokenize (t : Tokenizer) (text : String) : Array Nat := + foldTokensFromLine text #[] fun out start stop => + let word := String.Pos.Raw.extract text start stop + out.push (t.findToken word) /-- Decode token IDs back to text. 
-/ def decode (t : Tokenizer) (ids : Array Nat) : String := From adadc47a1e4f2da1a57e2646e23c90d661b9d18a Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 11:52:25 +0100 Subject: [PATCH 031/244] Unify text line splitting in SOUND paths --- Nfp/Sound/IO.lean | 4 ++-- Nfp/Sound/ModelHeader.lean | 18 ++++++++++++++++++ Nfp/Untrusted/SoundCacheIO.lean | 5 +++-- Nfp/Untrusted/SoundCompute.lean | 7 ++++--- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index a7f385e..da9babf 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -23,7 +23,7 @@ This module is intentionally thin: it delegates witness generation to private def readTextModelHeader (path : System.FilePath) : IO (Except String TextHeader) := do let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray + let lines : Array String := splitLines contents return Nfp.Sound.parseTextHeader lines private def readBinaryModelHeader (path : System.FilePath) : @@ -198,7 +198,7 @@ private def recomputeModelWeightBoundsBinary private def recomputeModelWeightBoundsText (path : System.FilePath) : IO (Except String ModelWeightBounds) := do let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray + let lines : Array String := splitLines contents return modelWeightBoundsFromTextLines lines private def recomputeModelWeightBounds diff --git a/Nfp/Sound/ModelHeader.lean b/Nfp/Sound/ModelHeader.lean index e2f2794..b3ed573 100644 --- a/Nfp/Sound/ModelHeader.lean +++ b/Nfp/Sound/ModelHeader.lean @@ -21,6 +21,23 @@ def parseHeaderLine (line : String) : Option (String × String) := | [k, v] => some (k.trim, v.trim) | _ => none +/-- Split a string on `\n`, preserving empty lines. 
-/ +def splitLines (s : String) : Array String := + Id.run do + let mut out : Array String := #[] + let mut start : String.Pos.Raw := 0 + let mut p : String.Pos.Raw := 0 + let stop := s.rawEndPos + while p < stop do + if p.get s = '\n' then + out := out.push (String.Pos.Raw.extract s start p) + p := p.next s + start := p + else + p := p.next s + out := out.push (String.Pos.Raw.extract s start stop) + return out + /-- Minimal parsed header data for sound certification. -/ structure TextHeader where eps : Rat @@ -76,6 +93,7 @@ def parseTextHeaderEps (lines : Array String) : Except String Rat := do /-! ### Specs -/ theorem parseHeaderLine_spec : parseHeaderLine = parseHeaderLine := rfl +theorem splitLines_spec : splitLines = splitLines := rfl theorem parseGeluDerivTarget_spec (v : String) : parseGeluDerivTarget v = parseGeluDerivTarget v := rfl theorem parseTextHeader_spec : parseTextHeader = parseTextHeader := rfl diff --git a/Nfp/Untrusted/SoundCacheIO.lean b/Nfp/Untrusted/SoundCacheIO.lean index df99b72..cccee27 100644 --- a/Nfp/Untrusted/SoundCacheIO.lean +++ b/Nfp/Untrusted/SoundCacheIO.lean @@ -3,6 +3,7 @@ import Std import Init.System.IO import Nfp.Sound.CachePure +import Nfp.Sound.ModelHeader import Nfp.Untrusted.SoundBinary namespace Nfp.Untrusted.SoundCacheIO @@ -75,7 +76,7 @@ def buildCacheBytesText (scalePow10 : Nat) (modelHash modelSize : UInt64) : IO (Except String ByteArray) := do let contents ← IO.FS.readFile modelPath - let lines : Array String := (contents.splitOn "\n").toArray + let lines : Array String := Nfp.Sound.splitLines contents return Nfp.Sound.SoundCache.buildCacheBytes lines scalePow10 modelHash modelSize def buildCacheBytesBinary @@ -218,7 +219,7 @@ def checkTextTokenEnvelope (scalePow10 : Nat := 9) (maxTokens : Nat := 0) : IO (Except String Unit) := do let contents ← IO.FS.readFile modelPath - let lines : Array String := (contents.splitOn "\n").toArray + let lines : Array String := Nfp.Sound.splitLines contents return 
Nfp.Sound.SoundCache.checkTextTokenEnvelopeLines lines scalePow10 maxTokens /-- Check that the cache file size matches the expected tensor stream length. -/ diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index b059b8a..807b72a 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -3,6 +3,7 @@ import Std import Nfp.Sound.Cert import Nfp.Sound.HeadCert +import Nfp.Sound.ModelHeader import Nfp.Untrusted.SoundBinary import Nfp.Sound.Interval import Nfp.Sound.Affine @@ -667,7 +668,7 @@ def certifyModelFileGlobal return .error "partitionDepth > 0 not yet implemented" let actDerivBound := geluDerivBoundGlobal geluDerivTarget let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray + let lines : Array String := Nfp.Sound.splitLines contents -- Header let mut i : Nat := 0 while i < lines.size && lines[i]!.trim.isEmpty do @@ -890,7 +891,7 @@ private def loadEmbeddingsIntervals (path : System.FilePath) (seqLen modelDim : Nat) (delta : Rat) : IO (Except String (Array (Array RatInterval))) := do let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray + let lines : Array String := Nfp.Sound.splitLines contents let mut i : Nat := 0 while i < lines.size && lines[i]!.trim.isEmpty do i := i + 1 @@ -3903,7 +3904,7 @@ private def certifyModelFileLocalText if partitionDepth ≠ 0 then return .error "partitionDepth > 0 not yet implemented" let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray + let lines : Array String := Nfp.Sound.splitLines contents -- Header let mut i : Nat := 0 while i < lines.size && lines[i]!.trim.isEmpty do From daa29943ec2ad89dc54fddda8f99afedb145a3d1 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 11:56:56 +0100 Subject: [PATCH 032/244] Reuse SOUND text parsers in untrusted compute --- Nfp/Untrusted/SoundCompute.lean | 110 ++------------------------------ 1 file 
changed, 4 insertions(+), 106 deletions(-) diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 807b72a..299b56e 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -4,6 +4,7 @@ import Std import Nfp.Sound.Cert import Nfp.Sound.HeadCert import Nfp.Sound.ModelHeader +import Nfp.Sound.TextPure import Nfp.Untrusted.SoundBinary import Nfp.Sound.Interval import Nfp.Sound.Affine @@ -38,96 +39,8 @@ Trusted base: No `Float` arithmetic is used as an input to certification. -/ -/-- Parse `key=value` header lines. -/ -def parseHeaderLine (line : String) : Option (String × String) := - let line := line.trim - if line.isEmpty then none - else - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none - private def defaultBinaryScalePow10 : Nat := 9 -/-- Read `count` rationals from lines starting at `start`, folding into `state`. - -Returns `(state, nextLineIndex)`. --/ -def foldRatTokens {α : Type} - (lines : Array String) - (start : Nat) - (count : Nat) - (state : α) - (step : α → Rat → α) : Except String (α × Nat) := - Id.run do - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - let mut i := start - let mut remaining := count - let mut st := state - while remaining > 0 do - if i < lines.size then - let line := lines[i]! - i := i + 1 - let mut p : String.Pos.Raw := 0 - let stop := line.rawEndPos - while p < stop && remaining > 0 do - while p < stop && isWs (p.get line) do - p := p.next line - let tokStart := p - while p < stop && !isWs (p.get line) do - p := p.next line - if tokStart < p then - match parseRatRange line tokStart p with - | .error e => return .error e - | .ok r => - st := step st r - remaining := remaining - 1 - else - return .error "unexpected end of file while reading numbers" - return .ok (st, i) - -/-- Consume a vector of length `n` into an array. Returns `(xs, nextLineIndex)`. 
-/ -def consumeVector - (lines : Array String) - (start : Nat) - (n : Nat) : Except String (Array Rat × Nat) := - let step := fun (acc : Array Rat) (x : Rat) => acc.push x - foldRatTokens lines start n (Array.mkEmpty n) step - -/-- Consume a matrix in row-major order and return its exact row-sum norm -(`‖·‖∞` in column-vector convention, `‖·‖₁` in row-vector convention). - -Returns `(normInf, nextLineIndex)`. --/ -def consumeMatrixNormInf - (lines : Array String) - (start : Nat) - (rows cols : Nat) : Except String (Rat × Nat) := - let count := rows * cols - match consumeVector lines start count with - | .error e => .error e - | .ok (xs, next) => .ok (matrixNormInfOfRowMajor rows cols xs, next) - -/-- Parsed matrix norm is definitionally the spec-level bound on the parsed data. -/ -theorem consumeMatrixNormInf_spec - (lines : Array String) (start rows cols : Nat) (xs : Array Rat) (next : Nat) - (h : consumeVector lines start (rows * cols) = .ok (xs, next)) : - consumeMatrixNormInf lines start rows cols = - .ok (matrixNormInfOfRowMajor rows cols xs, next) := by - simp [consumeMatrixNormInf, h] - -/-- Consume a vector of length `n` and return `max |xᵢ|`. - -Returns `(maxAbs, nextLineIndex)`. --/ -def consumeMaxAbs - (lines : Array String) - (start : Nat) - (n : Nat) : Except String (Rat × Nat) := - let step := fun (m : Rat) (x : Rat) => max m (ratAbs x) - foldRatTokens lines start n 0 step - private def maxAbsOfVector (xs : Array Rat) : Rat := xs.foldl (fun acc x => max acc (ratAbs x)) 0 @@ -777,21 +690,21 @@ def certifyModelFileGlobal mlpWout := mlpWout.set! curLayer n i := next else if line = "LN1_GAMMA" then - match consumeMaxAbs lines (i + 1) d with + match consumeVectorMaxAbs lines (i + 1) d with | .error e => return .error e | .ok (m, next) => if curLayer < ln1GammaMax.size then ln1GammaMax := ln1GammaMax.set! 
curLayer m i := next else if line = "LN1_BETA" then - match consumeMaxAbs lines (i + 1) d with + match consumeVectorMaxAbs lines (i + 1) d with | .error e => return .error e | .ok (m, next) => if curLayer < ln1BetaMax.size then ln1BetaMax := ln1BetaMax.set! curLayer m i := next else if line = "LN2_GAMMA" then - match consumeMaxAbs lines (i + 1) d with + match consumeVectorMaxAbs lines (i + 1) d with | .error e => return .error e | .ok (m, next) => if curLayer < ln2GammaMax.size then @@ -8060,24 +7973,9 @@ def certifyInductionSoundBestMatch /-! ### Specs -/ -theorem parseHeaderLine_spec_io : - parseHeaderLine = parseHeaderLine := rfl - theorem defaultBinaryScalePow10_spec_io : defaultBinaryScalePow10 = defaultBinaryScalePow10 := rfl -theorem foldRatTokens_spec_io {α : Type} : - foldRatTokens (α := α) = foldRatTokens (α := α) := rfl - -theorem consumeVector_spec_io : - consumeVector = consumeVector := rfl - -theorem consumeMatrixNormInf_spec_io : - consumeMatrixNormInf = consumeMatrixNormInf := rfl - -theorem consumeMaxAbs_spec_io : - consumeMaxAbs = consumeMaxAbs := rfl - theorem maxAbsOfVector_spec_io : maxAbsOfVector = maxAbsOfVector := rfl From 06a0469e829374f97583a3371950ff3a42d1c5ce Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 12:09:01 +0100 Subject: [PATCH 033/244] Share SOUND text scanning utilities --- Nfp/Sound/CachePure.lean | 36 ++++-------------------- Nfp/Sound/ModelHeader.lean | 50 +++++++++++++++++++++++++++++++++ Nfp/Sound/TextPure.lean | 6 ++-- Nfp/Untrusted/SoundCompute.lean | 36 ++++-------------------- 4 files changed, 62 insertions(+), 66 deletions(-) diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index dd3531d..dd57b4e 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -171,42 +171,16 @@ def fnv1a64 (bytes : ByteArray) : UInt64 := private def findLineIdxFrom (lines : Array String) (start : Nat) (p : String → Bool) : Option Nat := - Id.run do - let mut i := start - while i < lines.size 
do - if p (lines[i]!.trim) then - return some i - i := i + 1 - return none + Nfp.Sound.findLineIdxFrom lines start p private def skipUntil (lines : Array String) (start : Nat) (p : String → Bool) : Nat := - match findLineIdxFrom lines start p with - | some i => i - | none => lines.size + Nfp.Sound.skipUntil lines start p private def skipBlankLines (lines : Array String) (start : Nat) : Nat := - Id.run do - let mut i := start - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - return i + Nfp.Sound.skipBlankLines lines start -private def countWsTokens (s : String) : Nat := - Id.run do - let mut p : String.Pos.Raw := 0 - let stop := s.rawEndPos - let mut inTok : Bool := false - let mut cnt : Nat := 0 - while p < stop do - let c := p.get s - let isWs : Bool := c = ' ' || c = '\t' || c = '\n' || c = '\r' - if isWs then - inTok := false - else if !inTok then - inTok := true - cnt := cnt + 1 - p := p.next s - return cnt +@[inline] private def countWsTokens (s : String) : Nat := + Nfp.Sound.countWsTokens s private def skipTokensFast (lines : Array String) (start : Nat) (numTokens : Nat) : Except String Nat := diff --git a/Nfp/Sound/ModelHeader.lean b/Nfp/Sound/ModelHeader.lean index b3ed573..41d2aa4 100644 --- a/Nfp/Sound/ModelHeader.lean +++ b/Nfp/Sound/ModelHeader.lean @@ -38,6 +38,51 @@ def splitLines (s : String) : Array String := out := out.push (String.Pos.Raw.extract s start stop) return out +/-- Whitespace predicate used in token scanners. -/ +@[inline] def isWsChar (c : Char) : Bool := + c = ' ' || c = '\t' || c = '\n' || c = '\r' + +/-- Count whitespace-separated tokens in a line. 
-/ +def countWsTokens (s : String) : Nat := + Id.run do + let mut p : String.Pos.Raw := 0 + let stop := s.rawEndPos + let mut inTok : Bool := false + let mut cnt : Nat := 0 + while p < stop do + let c := p.get s + if isWsChar c then + inTok := false + else if !inTok then + inTok := true + cnt := cnt + 1 + p := p.next s + return cnt + +/-- Find the first line index at or after `start` that satisfies `p`. -/ +def findLineIdxFrom (lines : Array String) (start : Nat) (p : String → Bool) : Option Nat := + Id.run do + let mut i := start + while i < lines.size do + if p (lines[i]!.trim) then + return some i + i := i + 1 + return none + +/-- Skip to the next line satisfying `p`, or return `lines.size` if none. -/ +def skipUntil (lines : Array String) (start : Nat) (p : String → Bool) : Nat := + match findLineIdxFrom lines start p with + | some i => i + | none => lines.size + +/-- Skip blank (whitespace-only) lines starting at `start`. -/ +def skipBlankLines (lines : Array String) (start : Nat) : Nat := + Id.run do + let mut i := start + while i < lines.size && lines[i]!.trim.isEmpty do + i := i + 1 + return i + /-- Minimal parsed header data for sound certification. 
-/ structure TextHeader where eps : Rat @@ -94,6 +139,11 @@ def parseTextHeaderEps (lines : Array String) : Except String Rat := do theorem parseHeaderLine_spec : parseHeaderLine = parseHeaderLine := rfl theorem splitLines_spec : splitLines = splitLines := rfl +theorem isWsChar_spec : isWsChar = isWsChar := rfl +theorem countWsTokens_spec : countWsTokens = countWsTokens := rfl +theorem findLineIdxFrom_spec : findLineIdxFrom = findLineIdxFrom := rfl +theorem skipUntil_spec : skipUntil = skipUntil := rfl +theorem skipBlankLines_spec : skipBlankLines = skipBlankLines := rfl theorem parseGeluDerivTarget_spec (v : String) : parseGeluDerivTarget v = parseGeluDerivTarget v := rfl theorem parseTextHeader_spec : parseTextHeader = parseTextHeader := rfl diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean index c82f904..1ce8125 100644 --- a/Nfp/Sound/TextPure.lean +++ b/Nfp/Sound/TextPure.lean @@ -102,8 +102,6 @@ def foldRatTokens {α : Type} (state : α) (step : α → Rat → α) : Except String (α × Nat) := Id.run do - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' let mut i := start let mut remaining := count let mut st := state @@ -114,10 +112,10 @@ def foldRatTokens {α : Type} let mut p : String.Pos.Raw := 0 let stop := line.rawEndPos while p < stop && remaining > 0 do - while p < stop && isWs (p.get line) do + while p < stop && isWsChar (p.get line) do p := p.next line let tokStart := p - while p < stop && !isWs (p.get line) do + while p < stop && !isWsChar (p.get line) do p := p.next line if tokStart < p then match parseRatRange line tokStart p with diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 299b56e..15f502c 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -421,25 +421,13 @@ private def minVarAcrossRows (rows : Array (Array RatInterval)) : Rat := private def findLineIdxFrom (lines : Array String) (start : Nat) (p : String → Bool) : Option Nat := - Id.run do - 
let mut i := start - while i < lines.size do - if p (lines[i]!.trim) then - return some i - i := i + 1 - return none + Nfp.Sound.findLineIdxFrom lines start p private def skipUntil (lines : Array String) (start : Nat) (p : String → Bool) : Nat := - match findLineIdxFrom lines start p with - | some i => i - | none => lines.size + Nfp.Sound.skipUntil lines start p private def skipBlankLines (lines : Array String) (start : Nat) : Nat := - Id.run do - let mut i := start - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - return i + Nfp.Sound.skipBlankLines lines start /-! ### Fast skipping without parsing @@ -451,22 +439,8 @@ Parsing decimals into `Rat` is expensive, so we skip these sections by **countin instead of calling `parseRat`. -/ -private def countWsTokens (s : String) : Nat := - Id.run do - let mut p : String.Pos.Raw := 0 - let stop := s.rawEndPos - let mut inTok : Bool := false - let mut cnt : Nat := 0 - while p < stop do - let c := p.get s - let isWs : Bool := c = ' ' || c = '\t' || c = '\n' || c = '\r' - if isWs then - inTok := false - else if !inTok then - inTok := true - cnt := cnt + 1 - p := p.next s - return cnt +@[inline] private def countWsTokens (s : String) : Nat := + Nfp.Sound.countWsTokens s private def consumeTokensSkipFast (lines : Array String) (start : Nat) (numTokens : Nat) : Except String Nat := From 2c525bba76dc97423230fe42092450b44e7305a4 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 12:19:04 +0100 Subject: [PATCH 034/244] Streamline Discovery aggregation loops --- Nfp/Discovery.lean | 107 +++++++++++++++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 34 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index 55a1900..b00462f 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -61,6 +61,50 @@ of minutes for full network analysis). 
namespace Nfp +@[inline] private def sumSquares (xs : Array Float) : Float := Id.run do + let mut acc : Float := 0.0 + for x in xs do + acc := acc + x * x + return acc + +@[inline] private def sumFloatArray (xs : Array Float) : Float := Id.run do + let mut acc : Float := 0.0 + for x in xs do + acc := acc + x + return acc + +@[inline] private def sumNatArray (xs : Array Nat) : Nat := Id.run do + let mut acc : Nat := 0 + for x in xs do + acc := acc + x + return acc + +@[inline] private def sumSizes {α : Type} (chunks : Array (Array α)) : Nat := Id.run do + let mut acc : Nat := 0 + for cs in chunks do + acc := acc + cs.size + return acc + +@[inline] private def maxArray (xs : Array Float) : Float := Id.run do + let mut m : Float := 0.0 + for x in xs do + if x > m then + m := x + return m + +@[inline] private def countTrue (xs : Array Bool) : Nat := Id.run do + let mut acc : Nat := 0 + for b in xs do + if b then + acc := acc + 1 + return acc + +@[inline] private def countTrueNested (xs : Array (Array Bool)) : Nat := Id.run do + let mut acc : Nat := 0 + for row in xs do + acc := acc + countTrue row + return acc + /-! ## Concrete Weight Representations -/ /-- A concrete weight matrix stored as nested Arrays. @@ -271,7 +315,7 @@ def matmul (A B : ConcreteMatrix) : ConcreteMatrix := /-- Compute Frobenius norm squared: Σᵢⱼ M[i,j]². -/ def frobeniusNormSq (M : ConcreteMatrix) : Float := - M.data.foldl (fun acc x => acc + x * x) 0.0 + sumSquares M.data /-- Compute Frobenius norm: √(Σᵢⱼ M[i,j]²). -/ def frobeniusNorm (M : ConcreteMatrix) : Float := @@ -1658,7 +1702,7 @@ def dot (v1 v2 : ConcreteMatrix) : Float := /-- Compute L2 norm of a vector (stored as n×1 matrix). -/ def vecNorm (v : ConcreteMatrix) : Float := if v.numCols = 1 then - Float.sqrt (v.data.foldl (fun acc x => acc + x * x) 0.0) + Float.sqrt (sumSquares v.data) else 0.0 /-- Vector subtraction for n×1 matrices. 
-/ @@ -2124,12 +2168,12 @@ def neuronOutputWeights (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Array Flo /-- Compute the L2 norm of input weights for a neuron. -/ def neuronInputNorm (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := let weights := layer.neuronInputWeights neuronIdx - Float.sqrt (weights.foldl (fun acc w => acc + w * w) 0.0) + Float.sqrt (sumSquares weights) /-- Compute the L2 norm of output weights for a neuron. -/ def neuronOutputNorm (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := let weights := layer.neuronOutputWeights neuronIdx - Float.sqrt (weights.foldl (fun acc w => acc + w * w) 0.0) + Float.sqrt (sumSquares weights) /-- Compute the "influence magnitude" of a neuron: ‖W_in[:,i]‖ · ‖W_out[i,:]‖ @@ -2542,12 +2586,12 @@ def decoderWeights (sae : ConcreteSAE) (featureIdx : Nat) : Array Float := /-- Compute the L2 norm of encoder weights for feature k. -/ def encoderNorm (sae : ConcreteSAE) (featureIdx : Nat) : Float := let weights := sae.encoderWeights featureIdx - Float.sqrt (weights.foldl (fun acc w => acc + w * w) 0.0) + Float.sqrt (sumSquares weights) /-- Compute the L2 norm of decoder weights for feature k. -/ def decoderNorm (sae : ConcreteSAE) (featureIdx : Nat) : Float := let weights := sae.decoderWeights featureIdx - Float.sqrt (weights.foldl (fun acc w => acc + w * w) 0.0) + Float.sqrt (sumSquares weights) /-- Compute the "influence magnitude" of feature k: ‖W_enc[:,k]‖ · ‖W_dec[k,:]‖ @@ -2856,7 +2900,7 @@ This bounds the Frobenius norm of the softmax Jacobian for that row. - Uniform over n → sqrt((n-1)/n) ≈ 1 for large n -/ def softmaxRowJacobianNorm (row : Array Float) : Float := - let sumSq := row.foldl (fun acc p => acc + p * p) 0.0 + let sumSq := sumSquares row Float.sqrt (max 0.0 (1.0 - sumSq)) /-- Compute the average softmax Jacobian norm across all rows of attention weights. @@ -3184,7 +3228,7 @@ def computeMLPLayerOpNormFromGeluDerivWithOpBounds if a > out[k]! then out := out.set! 
k a out - let globalDmax : Float := dMax.foldl (fun m x => max m x) 0.0 + let globalDmax : Float := maxArray dMax if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then -- If derivative information is degenerate, we can still use the global GeLU' upper bound (≈1.7). return legacy @@ -3351,7 +3395,7 @@ def computeMLPOpAbsSchurDiag (layer : ConcreteMLPLayer) (geluDeriv : ConcreteMat if a > out[k]! then out := out.set! k a out - let globalDmax : Float := dMaxVec.foldl (fun m x => max m x) 0.0 + let globalDmax : Float := maxArray dMaxVec if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } @@ -3444,7 +3488,7 @@ def ConcreteMLPLayer.precomputeJacobianBoundCore (layer : ConcreteMLPLayer) if a > out[k]! then out := out.set! k a out - let globalDmax : Float := dMax.foldl (fun m x => max m x) 0.0 + let globalDmax : Float := maxArray dMax if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then return none @@ -3631,7 +3675,7 @@ This avoids computing the full (N·D)² matrix! /-- Compute ‖valueTerm‖_F efficiently via factorization. -/ def computeValueTermNorm (attn : ConcreteAttentionWeights) (valueOutputProjFrobNormSq : Float) : Float := - let attnNormSq := attn.weights.foldl (fun acc x => acc + x * x) 0.0 + let attnNormSq := sumSquares attn.weights Float.sqrt (attnNormSq * valueOutputProjFrobNormSq) /-- Information needed to bound the pattern term. -/ @@ -5076,7 +5120,7 @@ This precomputes all attention patterns, projections, and norms once. 
let queryMean := meanVec queries let valueMean := meanVec values let nF := seqLen.toFloat - let vMeanNormSq : Float := valueMean.foldl (fun acc x => acc + x * x) 0.0 + let vMeanNormSq : Float := sumSquares valueMean let vFrobBound : Float := Float.sqrt (max 0.0 (vFrobBoundRaw * vFrobBoundRaw - nF * vMeanNormSq)) let vActGram := values.transpose.matmul values @@ -6632,7 +6676,7 @@ def findInductionHeadCandidatesFromCache (cache : PrecomputedCache) computeForLayer i.val -- Join without quadratic copying. - let total := chunks.foldl (fun acc cs => acc + cs.size) 0 + let total := sumSizes chunks let mut candidates : Array CandidateInductionHead := Array.mkEmpty total for cs in chunks do for c in cs do @@ -6763,7 +6807,7 @@ def findDeepCircuitCandidatesFromCache (cache : PrecomputedCache) computeForLayer i.val -- Join without quadratic copying. - let total := chunks.foldl (fun acc cs => acc + cs.size) 0 + let total := sumSizes chunks let mut candidates : Array DeepCircuitCandidate := Array.mkEmpty total for cs in chunks do for c in cs do @@ -6964,13 +7008,11 @@ def isIncluded (circuit : ConcreteCircuit) (comp : ComponentId) : Bool := /-- Count total number of included attention heads. -/ def countIncludedHeads (circuit : ConcreteCircuit) : Nat := - circuit.includedHeads.foldl (fun acc layer => - acc + layer.foldl (fun acc' b => if b then acc' + 1 else acc') 0) 0 + countTrueNested circuit.includedHeads /-- Count total number of included MLP neurons. -/ def countIncludedNeurons (circuit : ConcreteCircuit) : Nat := - circuit.includedNeurons.foldl (fun acc layer => - acc + layer.foldl (fun acc' b => if b then acc' + 1 else acc') 0) 0 + countTrueNested circuit.includedNeurons /-- Count total number of included components. -/ def countIncluded (circuit : ConcreteCircuit) : Nat := @@ -6978,11 +7020,11 @@ def countIncluded (circuit : ConcreteCircuit) : Nat := /-- Count total number of attention heads (included + excluded). 
-/ def totalHeads (circuit : ConcreteCircuit) : Nat := - circuit.headsPerLayer.foldl (· + ·) 0 + sumNatArray circuit.headsPerLayer /-- Count total number of MLP neurons (included + excluded). -/ def totalNeurons (circuit : ConcreteCircuit) : Nat := - circuit.neuronsPerLayer.foldl (· + ·) 0 + sumNatArray circuit.neuronsPerLayer /-- Count total number of components (included + excluded). -/ def totalComponents (circuit : ConcreteCircuit) : Nat := @@ -7145,21 +7187,19 @@ def isIncluded (circuit : SAECircuit) (comp : ComponentId) : Bool := /-- Count included heads. -/ def countIncludedHeads (circuit : SAECircuit) : Nat := - circuit.includedHeads.foldl (fun acc layer => - acc + layer.foldl (fun acc' b => if b then acc' + 1 else acc') 0) 0 + countTrueNested circuit.includedHeads /-- Count included features. -/ def countIncludedFeatures (circuit : SAECircuit) : Nat := - circuit.includedFeatures.foldl (fun acc layer => - acc + layer.foldl (fun acc' b => if b then acc' + 1 else acc') 0) 0 + countTrueNested circuit.includedFeatures /-- Total heads. -/ def totalHeads (circuit : SAECircuit) : Nat := - circuit.headsPerLayer.foldl (· + ·) 0 + sumNatArray circuit.headsPerLayer /-- Total features. -/ def totalFeatures (circuit : SAECircuit) : Nat := - circuit.featuresPerLayer.foldl (· + ·) 0 + sumNatArray circuit.featuresPerLayer /-- Create a full circuit (all components included). 
-/ def full (numLayers : Nat) (headsPerLayer featuresPerLayer : Array Nat) : SAECircuit where @@ -7766,8 +7806,7 @@ def ConcreteModel.runAblatedForward (model : ConcreteModel) (circuit : ConcreteC if hl : l < model.layers.size then let layerHeads := model.layers[l] let includedMask := circuit.includedHeads.getD l #[] - let includedCount := - includedMask.foldl (fun acc b => if b then acc + 1 else acc) 0 + let includedCount := countTrue includedMask let useParallelHeads := layerHeads.size >= 4 && includedCount >= 4 layerAttnOutputs := @@ -8196,8 +8235,8 @@ def computeAllImportance (model : ConcreteModel) : Array ComponentImportance := let neuronChunks : Array (Array ComponentImportance) := layerPairs.map (·.2) -- Join in the same order as the original loop: heads then neurons, increasing layer index. - let totalHeads := headChunks.foldl (fun acc cs => acc + cs.size) 0 - let totalNeurons := neuronChunks.foldl (fun acc cs => acc + cs.size) 0 + let totalHeads := sumSizes headChunks + let totalNeurons := sumSizes neuronChunks let mut result : Array ComponentImportance := Array.mkEmpty (totalHeads + totalNeurons) for cs in headChunks do for c in cs do @@ -8651,7 +8690,7 @@ def computeHeadTargetImportance (model : ConcreteModel) (layerIdx headIdx : Nat) let targetProj := projectedVec.vecNorm -- Scale by attention norm (as in standard valueTermNorm) - let attnNormSq := attn.weights.foldl (fun acc x => acc + x * x) 0.0 + let attnNormSq := sumSquares attn.weights let attnNorm := Float.sqrt attnNormSq let targetImportance := attnNorm * targetProj @@ -8758,8 +8797,8 @@ def computeAllTargetImportance (model : ConcreteModel) computeNeuronsForLayer i.val -- Join in the same order as the original loop: heads then neurons, increasing layer index. 
- let totalHeads := headChunks.foldl (fun acc cs => acc + cs.size) 0 - let totalNeurons := neuronChunks.foldl (fun acc cs => acc + cs.size) 0 + let totalHeads := sumSizes headChunks + let totalNeurons := sumSizes neuronChunks let mut result : Array TargetAwareImportance := Array.mkEmpty (totalHeads + totalNeurons) for cs in headChunks do for c in cs do @@ -9000,7 +9039,7 @@ def verifyDeepCircuit (model : ConcreteModel) } -- Step 4: Compute total ablation error - let totalAblation := ablationErrors.foldl (· + ·) 0.0 + let totalAblation := sumFloatArray ablationErrors -- Step 5: Compute total amplification factor let totalAmpFactor := computeSuffixAmplification normBounds 0 From 21980692597be3aa2ae05b7eba9fe3ab940b88f0 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 12:33:13 +0100 Subject: [PATCH 035/244] Optimize IO and SOUND parsing paths --- Nfp/IO.lean | 19 ++++++++++--------- Nfp/Sound/BinaryPure.lean | 16 +++++++++------- Nfp/Sound/CachePure.lean | 12 +++++++----- Nfp/Sound/ModelHeader.lean | 27 +++++++++++++++++++++------ Nfp/Untrusted/SoundBinary.lean | 12 ++++++------ 5 files changed, 53 insertions(+), 33 deletions(-) diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 0512343..6eaf85c 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -71,15 +71,15 @@ private def readLine? (h : IO.FS.Handle) : IO (Option String) := do private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do if n = 0 then return ByteArray.empty - let mut remaining := n - let mut out : Array UInt8 := Array.mkEmpty n - while remaining > 0 do - let chunk ← h.read (USize.ofNat remaining) + let mut out : Array UInt8 := Array.replicate n 0 + let mut off : Nat := 0 + while off < n do + let chunk ← h.read (USize.ofNat (n - off)) if chunk.isEmpty then throw (IO.userError "unexpected EOF") for b in chunk.data do - out := out.push b - remaining := remaining - chunk.size + out := out.set! 
off b + off := off + 1 return ByteArray.mk out @[inline] private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := @@ -101,14 +101,15 @@ private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) -private def i32FromLE (b : ByteArray) (off : Nat) : Int := +private def twoPow32 : Int := Int.ofNat (Nat.pow 2 32) + +@[inline] private def i32FromLE (b : ByteArray) (off : Nat) : Int := let u := u32FromLE b off let half : UInt32 := 0x80000000 if u < half then Int.ofNat u.toNat else - let two32 : Int := Int.ofNat (Nat.pow 2 32) - (Int.ofNat u.toNat) - two32 + (Int.ofNat u.toNat) - twoPow32 @[inline] private def floatFromLE (b : ByteArray) (off : Nat) : Float := Float.ofBits (u64FromLE b off) diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index a770e17..c12755c 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -27,11 +27,10 @@ structure BinaryHeader where deriving Repr private def readHeaderNat (k v : String) : Option Nat := - if k = "num_layers" || k = "num_heads" || k = "model_dim" || - k = "head_dim" || k = "hidden_dim" || k = "vocab_size" || k = "seq_len" then - v.toNat? - else - none + match k with + | "num_layers" | "num_heads" | "model_dim" + | "head_dim" | "hidden_dim" | "vocab_size" | "seq_len" => v.toNat? + | _ => none def parseBinaryHeaderLines (magicLine : String) (lines : Array String) : Except String BinaryHeader := do @@ -126,14 +125,16 @@ def parseBinaryHeaderLines (magicLine : String) (lines : Array String) : let b3 := (b.get! 
(off + 3)).toUInt32 b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) +private def twoPow32 : Int := Int.ofNat (Nat.pow 2 32) + @[inline] private def i32FromLE (b : ByteArray) (off : Nat) : Int := let u := u32FromLE b off if u ≤ 0x7fffffff then Int.ofNat u.toNat else - Int.ofNat u.toNat - (Int.ofNat (Nat.pow 2 32)) + Int.ofNat u.toNat - twoPow32 -private def pow2Nat (k : Nat) : Nat := Nat.pow 2 k +@[inline] private def pow2Nat (k : Nat) : Nat := Nat.pow 2 k private def ceilDivNat (a : Int) (d : Nat) : Int := let di : Int := Int.ofNat d @@ -452,6 +453,7 @@ theorem parseBinaryHeaderLines_spec_binary_pure : theorem u64FromLE_spec_binary_pure : u64FromLE = u64FromLE := rfl theorem u32FromLE_spec_binary_pure : u32FromLE = u32FromLE := rfl theorem i32FromLE_spec_binary_pure : i32FromLE = i32FromLE := rfl +theorem twoPow32_spec_binary_pure : twoPow32 = twoPow32 := rfl theorem pow2Nat_spec_binary_pure : pow2Nat = pow2Nat := rfl theorem ceilDivNat_spec_binary_pure : ceilDivNat = ceilDivNat := rfl theorem scaleIntOfPow10_spec_binary_pure : scaleIntOfPow10 = scaleIntOfPow10 := rfl diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index dd57b4e..9698cad 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -64,14 +64,14 @@ private def appendI32LE (buf : Array UInt8) (x : Int) : Array UInt8 := out := out.push ((ux >>> 24) &&& 0xFF).toUInt8 return out -private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := +@[inline] private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := let b0 := (b.get! off).toUInt32 let b1 := (b.get! (off + 1)).toUInt32 let b2 := (b.get! (off + 2)).toUInt32 let b3 := (b.get! (off + 3)).toUInt32 b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) -private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := +@[inline] private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := let b0 := (b.get! off).toUInt64 let b1 := (b.get! (off + 1)).toUInt64 let b2 := (b.get! 
(off + 2)).toUInt64 @@ -83,14 +83,15 @@ private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) -def i32FromLE (b : ByteArray) (off : Nat) : Int := +private def twoPow32 : Int := Int.ofNat (Nat.pow 2 32) + +@[inline] def i32FromLE (b : ByteArray) (off : Nat) : Int := let u := u32FromLE b off let half : UInt32 := 0x80000000 if u < half then Int.ofNat u.toNat else - let two32 : Int := Int.ofNat (Nat.pow 2 32) - (Int.ofNat u.toNat) - two32 + (Int.ofNat u.toNat) - twoPow32 def encodeHeader (hdr : Header) : ByteArray := magic @@ -961,6 +962,7 @@ theorem appendI32LE_spec_cache_pure : appendI32LE = appendI32LE := rfl theorem u32FromLE_spec_cache_pure : u32FromLE = u32FromLE := rfl theorem u64FromLE_spec_cache_pure : u64FromLE = u64FromLE := rfl theorem i32FromLE_spec_cache_pure : i32FromLE = i32FromLE := rfl +theorem twoPow32_spec_cache_pure : twoPow32 = twoPow32 := rfl theorem encodeHeader_spec_cache_pure : encodeHeader = encodeHeader := rfl theorem headerBytes_spec_cache_pure : headerBytes = headerBytes := rfl theorem decodeHeader_spec_cache_pure : decodeHeader = decodeHeader := rfl diff --git a/Nfp/Sound/ModelHeader.lean b/Nfp/Sound/ModelHeader.lean index 41d2aa4..ce25c9d 100644 --- a/Nfp/Sound/ModelHeader.lean +++ b/Nfp/Sound/ModelHeader.lean @@ -13,13 +13,28 @@ Pure parsing utilities for extracting trusted metadata from `NFP_TEXT` model hea -/ /-- Parse `key=value` header lines. -/ -def parseHeaderLine (line : String) : Option (String × String) := +def parseHeaderLine (line : String) : Option (String × String) := Id.run do let line := line.trim - if line.isEmpty then none - else - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none + if line.isEmpty then + return none + -- Scan once to avoid `splitOn` allocations; require exactly one '='. 
+ let s := line + let stop := s.rawEndPos + let mut eqPos : Option String.Pos.Raw := none + let mut eqCount : Nat := 0 + let mut p : String.Pos.Raw := 0 + while p < stop do + if p.get s = '=' then + eqCount := eqCount + 1 + if eqCount = 1 then + eqPos := some p + p := p.next s + if eqCount ≠ 1 then + return none + let some eq := eqPos | return none + let k := String.Pos.Raw.extract s 0 eq + let v := String.Pos.Raw.extract s (eq.next s) stop + return some (k.trim, v.trim) /-- Split a string on `\n`, preserving empty lines. -/ def splitLines (s : String) : Array String := diff --git a/Nfp/Untrusted/SoundBinary.lean b/Nfp/Untrusted/SoundBinary.lean index 8733624..e4797d9 100644 --- a/Nfp/Untrusted/SoundBinary.lean +++ b/Nfp/Untrusted/SoundBinary.lean @@ -38,15 +38,15 @@ def readBinaryHeader (h : IO.FS.Handle) : IO (Except String Nfp.Sound.BinaryHead def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do if n = 0 then return ByteArray.empty - let mut remaining := n - let mut out : Array UInt8 := Array.mkEmpty n - while remaining > 0 do - let chunk ← h.read (USize.ofNat remaining) + let mut out : Array UInt8 := Array.replicate n 0 + let mut off : Nat := 0 + while off < n do + let chunk ← h.read (USize.ofNat (n - off)) if chunk.isEmpty then throw (IO.userError "unexpected EOF") for b in chunk.data do - out := out.push b - remaining := remaining - chunk.size + out := out.set! 
off b + off := off + 1 return ByteArray.mk out @[inline] private def readExactlyExcept (h : IO.FS.Handle) (n : Nat) : From bf9352b44d8d6b4c9e8edc39e25b7a562ce729d2 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 12:40:26 +0100 Subject: [PATCH 036/244] Refactor core proofs without aesop --- Nfp/Appendix.lean | 2 +- Nfp/Mixer.lean | 26 ++++++++++++++------------ Nfp/PCC.lean | 7 +++---- Nfp/Reroute/Partition.lean | 5 ++++- Nfp/Uniqueness.lean | 7 +++---- 5 files changed, 25 insertions(+), 22 deletions(-) diff --git a/Nfp/Appendix.lean b/Nfp/Appendix.lean index ddfac2d..e927a0d 100644 --- a/Nfp/Appendix.lean +++ b/Nfp/Appendix.lean @@ -213,7 +213,7 @@ lemma pccMax_monotone (m : S → NNReal) : intro A hA rcases Finset.mem_filter.mp hA with ⟨hpow, hbudget⟩ have hbudget' : normMass m A ≤ τ₂ := hbudget.trans hle - aesop (add simp [feasible]) + exact (Finset.mem_filter.mpr ⟨hpow, hbudget'⟩) -- The τ₂-argmax is ≥ any τ₁-feasible value, in particular the τ₁-argmax. have hAτ1 := pccArg_mem (S:=S) m τ₁ have hAτ1_in : pccArg (S:=S) m τ₁ ∈ feasible (S:=S) m τ₂ := hsubset hAτ1 diff --git a/Nfp/Mixer.lean b/Nfp/Mixer.lean index f4604df..4b0cd56 100644 --- a/Nfp/Mixer.lean +++ b/Nfp/Mixer.lean @@ -151,19 +151,21 @@ lemma supported_comp {M : Mixer S T} {N : Mixer T U} -- For every j, either ¬R i j or ¬Q j k; hence the product weight vanishes. have hforall : ∀ j, M.w i j * N.w j k = 0 := by intro j - have hnot_and : ¬ (R i j ∧ Q j k) := by - -- `¬ ∃ j, R i j ∧ Q j k` ⇒ `¬ (R i j ∧ Q j k)` for each `j`. 
- intro hAnd - exact hnot ⟨j, hAnd⟩ - have hdisj : ¬ R i j ∨ ¬ Q j k := (not_and_or.mp hnot_and) - aesop (add safe [hM, hN]) - -- Sum of zero terms equals 0 - have : (∑ j, M.w i j * N.w j k) = 0 := by - have hfun : (fun j => M.w i j * N.w j k) = (fun _ => (0 : NNReal)) := by - funext j; simpa using hforall j - simp [hfun] + by_cases hR : R i j + · have hQ : ¬ Q j k := by + intro hQ + exact hnot ⟨j, hR, hQ⟩ + have hN0 : N.w j k = 0 := hN j k hQ + simp [hN0] + · have hM0 : M.w i j = 0 := hM i j hR + simp [hM0] + have hsum : (∑ j, M.w i j * N.w j k) = 0 := by + have hfun : (fun j => M.w i j * N.w j k) = fun _ => (0 : NNReal) := by + funext j + simpa using hforall j + simp [hfun] -- This is exactly the weight on `(i,k)` inside `M.comp N`. - simp [Mixer.comp, this] + simp [Mixer.comp, hsum] /-- The support (positions with nonzero mass) of a probability vector. -/ def supp (p : ProbVec S) : Set S := fun i => p.mass i ≠ 0 diff --git a/Nfp/PCC.lean b/Nfp/PCC.lean index 3eca79a..0cc3a0e 100644 --- a/Nfp/PCC.lean +++ b/Nfp/PCC.lean @@ -6,7 +6,6 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Data.Finset.Basic import Mathlib.Data.List.Basic import Mathlib.Algebra.Order.Monoid.Defs -import Aesop import Nfp.Prob import Nfp.Reroute.Heat @@ -18,9 +17,9 @@ formalization of Appendix A.4 of the accompanying documentation: * `tracerOfContrib` – builds a probability vector from nonnegative contributions. * `sum_monotone_chain` / `monotone_removed_mass` – monotonicity of accumulated mass - along nested mask chains (with tiny nonnegativity side-conditions handled by `aesop`). + along nested mask chains (with tiny nonnegativity side-conditions handled by `simp`). -All proofs are elementary (`simp`, small `aesop` calls on nonnegativity), and avoid `sorry`. +All proofs are elementary (`simp`, small local lemmas), and avoid `sorry`. 
-/ namespace Nfp @@ -60,7 +59,7 @@ lemma sum_monotone_chain [Fintype S] (A : ℕ → Finset S) (w : S → NNReal) have hstep : (A k₂).sum (fun i => w i) ≤ (A (k₂+1)).sum (fun i => w i) := by refine Finset.sum_le_sum_of_subset_of_nonneg (hchain k₂) ?_ intro i hi _ - aesop + exact (show (0 : NNReal) ≤ w i from bot_le) exact ih.trans hstep /-- Appendix A.4 (monotonicity helper): removed mass is monotone along a nested mask chain. -/ diff --git a/Nfp/Reroute/Partition.lean b/Nfp/Reroute/Partition.lean index 4fc7822..b9930c1 100644 --- a/Nfp/Reroute/Partition.lean +++ b/Nfp/Reroute/Partition.lean @@ -277,7 +277,10 @@ private lemma incrementsAux_pairwise (parts : List (Finset S)) (seen : Finset S) incrementsAux_mem_disjoint_seen (parts:=parts) (seen:=seen ∪ A) (B:=B) hB refine Finset.disjoint_left.mpr ?_ intro x hxHead hxB - aesop (add simp [Finset.disjoint_left, incrementsAux]) + have hxA : x ∈ A := (Finset.mem_sdiff.mp hxHead).1 + have hxUnion : x ∈ seen ∪ A := by + exact Finset.mem_union.mpr (Or.inr hxA) + exact (Finset.disjoint_left.mp hDisjoint) hxB hxUnion · simpa using ih (seen ∪ A) private lemma sdiff_union_left (A B C : Finset S) : diff --git a/Nfp/Uniqueness.lean b/Nfp/Uniqueness.lean index ad4a8c7..ff18d04 100644 --- a/Nfp/Uniqueness.lean +++ b/Nfp/Uniqueness.lean @@ -4,7 +4,6 @@ import Mathlib.Data.NNReal.Basic import Mathlib.Data.Fin.Basic import Mathlib.Data.Finset.Basic import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Aesop namespace Nfp @@ -19,8 +18,8 @@ story (fixed masks ⇒ linear system). It shows that for a finite DAG ordered by topological index, any two tracer families satisfying the same homogeneous linear recurrence coincide. This is a helper fact; Appendix A.5 in the paper is a counterexample narrative (not a formal lemma), so we avoid referring to A.5 here. -Small parent-index checks in the inductive step are discharged by `aesop`; the -overall proof structure (nested induction on the index bound) remains explicit. 
+Small parent-index checks in the inductive step are kept explicit; the overall +proof structure (nested induction on the index bound) remains explicit. -/ /-- A local mixing system over `n` nodes, where each node `i` aggregates parents @@ -82,7 +81,7 @@ theorem tracer_unique (L : LocalSystem n) {T T' : TracerFamily (S := S) n} have hlt : u.1 < k.succ := by simpa [hjeq] using L.topo (i := j) (u := u) hu have hle : u.1 ≤ k := Nat.le_of_lt_succ hlt - aesop + exact hle have hsum : (L.Pa j).sum (fun u => L.c j u * T u s) = (L.Pa j).sum (fun u => L.c j u * T' u s) := by classical From 3af9bc0e4b611048d3eec836ede50a57cedc822c Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 13:48:51 +0100 Subject: [PATCH 037/244] Refine core probability and layer lemmas --- Nfp/Influence.lean | 17 ++++------------- Nfp/Layers.lean | 6 +++--- Nfp/Mixer.lean | 3 --- Nfp/Prob.lean | 20 +++++++++----------- 4 files changed, 16 insertions(+), 30 deletions(-) diff --git a/Nfp/Influence.lean b/Nfp/Influence.lean index 9e4eba2..ed041df 100644 --- a/Nfp/Influence.lean +++ b/Nfp/Influence.lean @@ -83,19 +83,10 @@ lemma rowTotal_scaleRow_self (I : InfluenceSpec Site) (s0 : Site) (c : NNReal) : InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s0 = c * InfluenceSpec.rowTotal (Site := Site) I s0 := by classical - have hrewrite : - InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s0 = - ∑ t : Site, c * (if h : I.adj s0 t then I.κ h else 0) := by - simp [InfluenceSpec.rowTotal, scaleRow] - calc - InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s0 = - ∑ t : Site, c * (if h : I.adj s0 t then I.κ h else 0) := hrewrite - _ = c * InfluenceSpec.rowTotal (Site := Site) I s0 := by - -- pull the scalar outside the sum - simpa [InfluenceSpec.rowTotal] using - (Finset.mul_sum (s := (Finset.univ : Finset Site)) - (f := fun t : Site => (if h : I.adj s0 t then I.κ h else 0)) - (a := c)).symm + simpa [InfluenceSpec.rowTotal, scaleRow] 
using + (Finset.mul_sum (s := (Finset.univ : Finset Site)) + (f := fun t : Site => (if h : I.adj s0 t then I.κ h else 0)) + (a := c)).symm lemma rowTotal_scaleRow_other (I : InfluenceSpec Site) {s s0 : Site} (c : NNReal) (hs : s ≠ s0) : diff --git a/Nfp/Layers.lean b/Nfp/Layers.lean index dceaee6..6b39833 100644 --- a/Nfp/Layers.lean +++ b/Nfp/Layers.lean @@ -133,9 +133,9 @@ lemma Mixer.attention_supported {Query Key : Type*} [Fintype Query] [Fintype Key (hα : ∀ q, (∑ k, α q k) = 1) : Mixer.supported (Mixer.attention α hα) (fun q k => α q k ≠ 0) := by intro q k hne - simp only [Mixer.attention] - by_contra h - exact hne h + by_cases hzero : α q k = 0 + · simp [Mixer.attention, hzero] + · exact (hne hzero).elim end Attention diff --git a/Nfp/Mixer.lean b/Nfp/Mixer.lean index 4b0cd56..ae10e35 100644 --- a/Nfp/Mixer.lean +++ b/Nfp/Mixer.lean @@ -5,7 +5,6 @@ import Mathlib.Data.Fintype.Basic import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Data.Finset.Basic import Mathlib.Data.Set.Basic -import Aesop import Nfp.Prob /-! @@ -134,8 +133,6 @@ def supported (M : Mixer S T) (R : S → T → Prop) : Prop := (h : supported (S := S) (T := T) M R) {i : S} {j : T} (hij : ¬ R i j) : M.w i j = 0 := h i j hij -attribute [aesop safe] supported_zero - /-- Relational composition of supports: `R ⋆ Q` allows an edge `i → k` iff there exists `j` with `i → j` allowed by `R` and `j → k` allowed by `Q`. 
-/ def compSupport (R : S → T → Prop) (Q : T → U → Prop) : S → U → Prop := diff --git a/Nfp/Prob.lean b/Nfp/Prob.lean index ea652a0..a04515a 100644 --- a/Nfp/Prob.lean +++ b/Nfp/Prob.lean @@ -59,25 +59,23 @@ noncomputable def mix (c : NNReal) (hc : c ≤ 1) (p q : ProbVec ι) : ProbVec mass := fun i => c * p.mass i + (1 - c) * q.mass i norm_one := by classical + have hp : (∑ i, c * p.mass i) = c * (∑ i, p.mass i) := by + simpa using + (Finset.mul_sum (s := (Finset.univ : Finset ι)) (f := fun i : ι => p.mass i) + (a := c)).symm + have hq : + (∑ i, (1 - c) * q.mass i) = (1 - c) * (∑ i, q.mass i) := by + simpa using + (Finset.mul_sum (s := (Finset.univ : Finset ι)) (f := fun i : ι => q.mass i) + (a := (1 - c))).symm calc (∑ i, (c * p.mass i + (1 - c) * q.mass i)) = (∑ i, c * p.mass i) + (∑ i, (1 - c) * q.mass i) := by simp [Finset.sum_add_distrib] _ = c * (∑ i, p.mass i) + (1 - c) * (∑ i, q.mass i) := by - have hp : (∑ i, c * p.mass i) = c * (∑ i, p.mass i) := by - simpa using - (Finset.mul_sum (s := (Finset.univ : Finset ι)) (f := fun i : ι => p.mass i) - (a := c)).symm - have hq : - (∑ i, (1 - c) * q.mass i) = (1 - c) * (∑ i, q.mass i) := by - simpa using - (Finset.mul_sum (s := (Finset.univ : Finset ι)) (f := fun i : ι => q.mass i) - (a := (1 - c))).symm simp [hp, hq] _ = c * 1 + (1 - c) * 1 := by simp [ProbVec.sum_mass] - _ = c + (1 - c) := by - simp _ = 1 := by simpa using (add_tsub_cancel_of_le hc) } From b747b13fd9a9fece803e150ef1583b11a3543e81 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 13:53:49 +0100 Subject: [PATCH 038/244] Refine attribution and linearization helpers --- Nfp/Attribution.lean | 13 +++++++------ Nfp/Induction.lean | 1 - Nfp/Linearization.lean | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Nfp/Attribution.lean b/Nfp/Attribution.lean index 5bd5fe3..b672537 100644 --- a/Nfp/Attribution.lean +++ b/Nfp/Attribution.lean @@ -166,7 +166,7 @@ section Linearity variable {Input Output : Type*} [Fintype Input] 
[Fintype Output] /-- A method that produces attributions for any function. -/ -def AttributionMethod (Input Output : Type*) [Fintype Input] [Fintype Output] := +abbrev AttributionMethod (Input Output : Type*) [Fintype Input] [Fintype Output] := ((Input → ℝ) → Output → ℝ) → (Input → ℝ) → (Input → ℝ) → Attribution Input Output /-- The linearity axiom: attribution of a sum of functions equals the sum @@ -190,15 +190,16 @@ section Symmetry variable {Input Output : Type*} [Fintype Input] [Fintype Output] [DecidableEq Input] +/-- Swap two input coordinates in a feature vector. -/ +def swapInputs (x : Input → ℝ) (i j : Input) : Input → ℝ := + fun k => if k = i then x j else if k = j then x i else x k + /-- Two inputs are symmetric for a function if swapping them preserves the output. -/ def SymmetricInputs (f : (Input → ℝ) → Output → ℝ) (i j : Input) : Prop := ∀ (x : Input → ℝ) (o : Output), - f x o = f (Function.swap x i j) o - where - Function.swap (x : Input → ℝ) (i j : Input) : Input → ℝ := - fun k => if k = i then x j else if k = j then x i else x k + f x o = f (swapInputs x i j) o /-- The symmetry axiom: symmetric inputs receive equal attribution. -/ def Attribution.Symmetric @@ -215,7 +216,7 @@ section PathAttribution variable {Input : Type*} /-- A path from baseline `x₀` to input `x` parameterized by `t ∈ [0,1]`. -/ -def Path (Input : Type*) := (t : NNReal) → Input → ℝ +abbrev Path (Input : Type*) := (t : NNReal) → Input → ℝ /-- A straight-line (linear interpolation) path between two points. -/ noncomputable def straightPath (x x₀ : Input → ℝ) : Path Input := diff --git a/Nfp/Induction.lean b/Nfp/Induction.lean index 2f78871..3fb91fe 100644 --- a/Nfp/Induction.lean +++ b/Nfp/Induction.lean @@ -35,7 +35,6 @@ Together, these provide end-to-end certification of model behavior. 
namespace Nfp -open SignedMixer AttentionLinearization variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] diff --git a/Nfp/Linearization.lean b/Nfp/Linearization.lean index a5bbc5e..4e685e0 100644 --- a/Nfp/Linearization.lean +++ b/Nfp/Linearization.lean @@ -131,7 +131,7 @@ def reluMask (v : n → ℝ) : n → Prop := fun i => v i > 0 /-- The ReLU mask as a 0/1 indicator. -/ noncomputable def reluMaskIndicator (v : n → ℝ) : n → ℝ := - fun i => if v i > 0 then 1 else 0 + fun i => reluGrad (v i) /-- **ReLU Linearization**: The Jacobian of ReLU is a diagonal matrix with entries 0 or 1 based on whether the input is positive. From ce38edfa0854298beae97d5082af6406551ad155 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:00:50 +0100 Subject: [PATCH 039/244] Simplify signed mixer bounds and verification helpers --- Nfp/SignedMixer.lean | 11 +++++------ Nfp/Sound/Bridge.lean | 2 +- Nfp/Verification.lean | 11 +++++------ 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/Nfp/SignedMixer.lean b/Nfp/SignedMixer.lean index 9d438a3..d5f9f44 100644 --- a/Nfp/SignedMixer.lean +++ b/Nfp/SignedMixer.lean @@ -221,8 +221,7 @@ noncomputable def totalInfluence (M : SignedMixer S T) : ℝ := ∑ i, M.rowAbsS /-- Row-sum operator norm bound (induced ℓ1 for row-vector convention). -/ noncomputable def operatorNormBound (M : SignedMixer S T) [Nonempty S] : ℝ := - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) fun i => - ∑ j, |M.w i j| + Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) (fun i => rowAbsSum M i) /-! 
## Operator norm bound estimates -/ @@ -270,7 +269,7 @@ theorem operatorNormBound_add_le (M N : SignedMixer S T) [Nonempty S] : dsimp [operatorNormBound] refine (Finset.sup'_le_iff (s := Finset.univ) (H := Finset.univ_nonempty (α := S)) - (f := fun i => ∑ j, |(M + N).w i j|) + (f := fun i => rowAbsSum (M + N) i) (a := operatorNormBound M + operatorNormBound N)).2 ?_ intro i hi have hsum : rowAbsSum (M + N) i ≤ rowAbsSum M i + rowAbsSum N i := @@ -281,7 +280,7 @@ theorem operatorNormBound_add_le (M N : SignedMixer S T) [Nonempty S] : exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum N i) hi have hbound : rowAbsSum (M + N) i ≤ operatorNormBound M + operatorNormBound N := by exact le_trans hsum (add_le_add hM hN) - simpa [rowAbsSum] using hbound + simpa using hbound /-- Row absolute sums of a composition are bounded by row sums and the operator norm bound. -/ lemma rowAbsSum_comp_le (M : SignedMixer S T) (N : SignedMixer T U) (i : S) [Nonempty T] : @@ -329,7 +328,7 @@ theorem operatorNormBound_comp_le (M : SignedMixer S T) (N : SignedMixer T U) dsimp [operatorNormBound] refine (Finset.sup'_le_iff (s := Finset.univ) (H := Finset.univ_nonempty (α := S)) - (f := fun i => ∑ j, |(M.comp N).w i j|) + (f := fun i => rowAbsSum (M.comp N) i) (a := operatorNormBound M * operatorNormBound N)).2 ?_ intro i hi have hrow : rowAbsSum (M.comp N) i ≤ rowAbsSum M i * operatorNormBound N := @@ -341,7 +340,7 @@ theorem operatorNormBound_comp_le (M : SignedMixer S T) (N : SignedMixer T U) exact mul_le_mul_of_nonneg_right hM hNnonneg have hbound : rowAbsSum (M.comp N) i ≤ operatorNormBound M * operatorNormBound N := le_trans hrow hmul - simpa [rowAbsSum] using hbound + simpa using hbound /-- Operator norm bounds for a triple composition. 
-/ theorem operatorNormBound_comp3_le {V : Type*} [Fintype V] diff --git a/Nfp/Sound/Bridge.lean b/Nfp/Sound/Bridge.lean index b1a8731..2c8c470 100644 --- a/Nfp/Sound/Bridge.lean +++ b/Nfp/Sound/Bridge.lean @@ -62,7 +62,7 @@ theorem operatorNormBound_cast (M : RatMatrix S T) [Nonempty S] : = Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) (fun i => ((rowAbsSum M i : Rat) : ℝ)) := hsup_cast _ = SignedMixer.operatorNormBound (M.toSignedMixer) := by - simp [SignedMixer.operatorNormBound, rowAbsSum, toSignedMixer, + simp [SignedMixer.operatorNormBound, SignedMixer.rowAbsSum, rowAbsSum, toSignedMixer, ratAbs_eq_abs, Rat.cast_sum, Rat.cast_abs] /-- Casted row-major bound agrees with the `SignedMixer` operator norm bound. -/ diff --git a/Nfp/Verification.lean b/Nfp/Verification.lean index f2bbc26..5a8361d 100644 --- a/Nfp/Verification.lean +++ b/Nfp/Verification.lean @@ -173,14 +173,13 @@ private def topNonTargetToken (residual : ConcreteMatrix) (pos : Nat) return none private def containsHead (hs : Array HeadRef) (h : HeadRef) : Bool := - hs.any (fun x => x == h) + hs.contains h private def fullCircuit (model : ConcreteModel) : ConcreteCircuit := Id.run do - let mut headsPerLayer : Array Nat := Array.mkEmpty model.numLayers - let mut neuronsPerLayer : Array Nat := Array.mkEmpty model.numLayers - for l in [:model.numLayers] do - headsPerLayer := headsPerLayer.push (model.layers.getD l #[]).size - neuronsPerLayer := neuronsPerLayer.push (model.numNeuronsAtLayer l) + let headsPerLayer := + Array.ofFn (fun l : Fin model.numLayers => (model.layers.getD l.1 #[]).size) + let neuronsPerLayer := + Array.ofFn (fun l : Fin model.numLayers => model.numNeuronsAtLayer l.1) ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer private def runForwardAblatingHeads (model : ConcreteModel) (heads : Array HeadRef) From f3dc3a8667f416da1a303ca76c67ca066e04f835 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:06:05 +0100 Subject: [PATCH 040/244] 
Streamline ablation discrepancy and verification deltas --- Nfp/Abstraction.lean | 18 +++++------------- Nfp/Verification.lean | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/Nfp/Abstraction.lean b/Nfp/Abstraction.lean index 72fc5f3..486cb92 100644 --- a/Nfp/Abstraction.lean +++ b/Nfp/Abstraction.lean @@ -219,7 +219,7 @@ noncomputable def ablationDiscrepancy (D : DeepLinearization (n := n) (d := d)) (blocked : Set (n × d)) [DecidablePred blocked] (v : (n × d) → ℝ) (j : n × d) : ℝ := - |(D.ablateJacobian blocked).apply v j - (D.ablateValueTerm blocked).apply v j| + |(D.ablationError blocked).apply v j| /-- **Causal Consistency Bound**: The ablation discrepancy is bounded by the pattern term's influence on the input. @@ -238,19 +238,11 @@ theorem causal_consistency_bound simp only [ablationDiscrepancy] -- The key insight: ablation error = ablated pattern term have h := D.ablationError_eq_ablatedPatternTerm blocked - -- The difference in applications - calc |(D.ablateJacobian blocked).apply v j - (D.ablateValueTerm blocked).apply v j| - = |(D.ablationError blocked).apply v j| := by - congr 1 - simp only [DeepLinearization.ablationError, SignedMixer.apply_def, SignedMixer.sub_w] - rw [← Finset.sum_sub_distrib] - apply Finset.sum_congr rfl - intro i _ - ring - _ = |((DeepPatternTerm D).ablate blocked).apply v j| := by - rw [h] + calc |(D.ablationError blocked).apply v j| + = |((DeepPatternTerm D).ablate blocked).apply v j| := by + rw [h] _ = |∑ i : n × d, if blocked i then 0 else v i * (DeepPatternTerm D).w i j| := by - simp only [SignedMixer.apply_ablate] + simp only [SignedMixer.apply_ablate] _ ≤ ∑ i : n × d, |if blocked i then 0 else v i * (DeepPatternTerm D).w i j| := abs_sum_le_sum_abs _ _ _ = ∑ i : n × d, if blocked i then 0 else |v i| * |(DeepPatternTerm D).w i j| := by diff --git a/Nfp/Verification.lean b/Nfp/Verification.lean index 5a8361d..3b3ec7f 100644 --- a/Nfp/Verification.lean +++ b/Nfp/Verification.lean @@ -143,6 
+143,12 @@ private def logitAt (residual : ConcreteMatrix) (pos : Nat) return acc else none +private def deltaAt (residual : ConcreteMatrix) (pos : Nat) + (W_U : ConcreteMatrix) (targetToken negativeToken : Nat) : Float := + let targetLogit := (logitAt residual pos W_U targetToken).getD 0.0 + let negLogit := (logitAt residual pos W_U negativeToken).getD 0.0 + targetLogit - negLogit + private def topNonTargetToken (residual : ConcreteMatrix) (pos : Nat) (W_U : ConcreteMatrix) (targetToken : Nat) : Option (Nat × Float) := Id.run do if residual.numCols = W_U.numRows ∧ pos < residual.numRows ∧ @@ -312,8 +318,6 @@ def verifyCircuit (ctx : VerificationContext) (candidateHeads : Array HeadRef) : if !competence then failures := failures.push s!"Axiom1(baseline competence) failed: Δ_base={baseDelta} ≤ \ ε={ctx.cfg.competenceEpsilon}" - - if !competence then let axioms : AxiomStatus := { baselineCompetence := false controlIndependence := true @@ -372,20 +376,16 @@ def verifyCircuit (ctx : VerificationContext) (candidateHeads : Array HeadRef) : -- Candidate ablation let fwdAblated := runForwardAblatingHeads ctx.model candidateHeads ctx.cfg.causal let residualAblated := fwdAblated.finalOutput - let ablatedTargetLogit := - (logitAt residualAblated ctx.pos ctx.W_U ctx.targetToken).getD 0.0 - let ablatedNegLogit := - (logitAt residualAblated ctx.pos ctx.W_U ctx.negativeToken).getD 0.0 - let ablatedDelta := ablatedTargetLogit - ablatedNegLogit + let ablatedDelta := + deltaAt residualAblated ctx.pos ctx.W_U ctx.targetToken ctx.negativeToken let impact := baseDelta - ablatedDelta let relScore := if baseDelta > 0.0 then impact / baseDelta else 0.0 -- Control ablation (energy-matched, layer-matched) let fwdCtrl := runForwardAblatingHeads ctx.model controlHeads ctx.cfg.causal let residualCtrl := fwdCtrl.finalOutput - let ctrlTargetLogit := (logitAt residualCtrl ctx.pos ctx.W_U ctx.targetToken).getD 0.0 - let ctrlNegLogit := (logitAt residualCtrl ctx.pos ctx.W_U ctx.negativeToken).getD 
0.0 - let ctrlDelta := ctrlTargetLogit - ctrlNegLogit + let ctrlDelta := + deltaAt residualCtrl ctx.pos ctx.W_U ctx.targetToken ctx.negativeToken let controlImpact := baseDelta - ctrlDelta return { From 3c7e8dc48978c6bba46d3b2a4e8d92feebdddb7c Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:08:55 +0100 Subject: [PATCH 041/244] Use abbrev for induction head bounds --- Nfp/Induction.lean | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Nfp/Induction.lean b/Nfp/Induction.lean index 3fb91fe..417848f 100644 --- a/Nfp/Induction.lean +++ b/Nfp/Induction.lean @@ -287,11 +287,11 @@ noncomputable def virtual_head_score {h : TrueInductionHead (n := n) (d := d)} : inner_product (virtual_head_output (h := h)) h.target_logit_diff /-- The approximation error bound. -/ -def approx_error {h : TrueInductionHead (n := n) (d := d)} : ℝ := +abbrev approx_error {h : TrueInductionHead (n := n) (d := d)} : ℝ := h.epsilon /-- The functional guarantee on the virtual head. 
-/ -def min_logit_shift {h : TrueInductionHead (n := n) (d := d)} : ℝ := +abbrev min_logit_shift {h : TrueInductionHead (n := n) (d := d)} : ℝ := h.delta omit [DecidableEq n] [DecidableEq d] in From 07d729c5e81d1da20ea8994556fb37fb2f1cfe10 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:10:54 +0100 Subject: [PATCH 042/244] Use List.sum for reroute weights --- Nfp/Reroute/Heat.lean | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/Nfp/Reroute/Heat.lean b/Nfp/Reroute/Heat.lean index 2b6803b..9161d46 100644 --- a/Nfp/Reroute/Heat.lean +++ b/Nfp/Reroute/Heat.lean @@ -92,7 +92,7 @@ structure WeightedReroutePlan (S : Type*) [Fintype S] [DecidableEq S] where ∀ {A : Finset S} {w : NNReal}, (A, w) ∈ plan.increments.zip weights → A.card = 0 → w = 0 - weights_sum_pos : 0 < weights.foldr (fun w acc => w + acc) 0 + weights_sum_pos : 0 < weights.sum namespace WeightedReroutePlan @@ -100,7 +100,7 @@ variable (P : WeightedReroutePlan (S := S)) /-- Helper: the total step weight (used for normalization). 
-/ def weightsSum : NNReal := - P.weights.foldr (fun w acc => w + acc) 0 + P.weights.sum @[simp] lemma weightsSum_pos : 0 < P.weightsSum := P.weights_sum_pos @@ -236,7 +236,7 @@ private lemma sum_heatRawAux (parts : List (Finset S)) (weights : List NNReal) ∀ {A : Finset S} {w : NNReal}, (A, w) ∈ parts.zip weights → A.card = 0 → w = 0) : (∑ i : S, heatRawAux (S:=S) parts weights hlen i) - = weights.foldr (fun w acc => w + acc) 0 := by + = weights.sum := by classical revert weights hlen induction parts with @@ -291,18 +291,14 @@ private lemma sum_heatRawAux (parts : List (Finset S)) (weights : List NNReal) (∑ i : S, heatRawAux (S:=S) (A :: parts) (w :: weights) hlen i) = w + ∑ i : S, heatRawAux (S:=S) parts weights hlen' i := by simp [heatRawAux, Finset.sum_add_distrib, hsum_head] - have hfold : - (w :: weights).foldr (fun w acc => w + acc) 0 - = w + weights.foldr (fun w acc => w + acc) 0 := by - simp calc (∑ i : S, heatRawAux (S:=S) (A :: parts) (w :: weights) hlen i) = w + ∑ i : S, heatRawAux (S:=S) parts weights hlen' i := hsum_current - _ = w + weights.foldr (fun w acc => w + acc) 0 := by + _ = w + weights.sum := by simp [hsum_tail] - _ = (w :: weights).foldr (fun w acc => w + acc) 0 := - by simp [hfold] + _ = (w :: weights).sum := by + simp omit [Fintype S] in private lemma sum_heatRawAux_disjoint From 4d90ba0f43d8d3b2cbd229cbc88d0a03540ee4ee Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:12:24 +0100 Subject: [PATCH 043/244] Simplify head membership checks --- Nfp/Verification.lean | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/Nfp/Verification.lean b/Nfp/Verification.lean index 3b3ec7f..dbac96c 100644 --- a/Nfp/Verification.lean +++ b/Nfp/Verification.lean @@ -178,9 +178,6 @@ private def topNonTargetToken (residual : ConcreteMatrix) (pos : Nat) else return none -private def containsHead (hs : Array HeadRef) (h : HeadRef) : Bool := - hs.contains h - private def fullCircuit (model : ConcreteModel) : ConcreteCircuit 
:= Id.run do let headsPerLayer := Array.ofFn (fun l : Fin model.numLayers => (model.layers.getD l.1 #[]).size) @@ -220,7 +217,7 @@ private def selectEnergyMatchedControl (fwd : ForwardPassResult) let mut best : Option (HeadRef × Float × Float) := none for hIdx in [:layerOut.size] do let h : HeadRef := { layerIdx := cand.layerIdx, headIdx := hIdx } - if !containsHead exclude h then + if !exclude.contains h then if hh : hIdx < layerOut.size then let norm := (layerOut[hIdx]'hh).frobeniusNorm let diff := Float.abs (candNorm - norm) @@ -345,7 +342,7 @@ def verifyCircuit (ctx : VerificationContext) (candidateHeads : Array HeadRef) : let controlsComplete := selections.size = candidateHeads.size let controlHeads : Array HeadRef := selections.map (·.control) - let independence := !(controlHeads.any (fun h => containsHead candidateHeads h)) + let independence := !(controlHeads.any candidateHeads.contains) if !independence then failures := failures.push "Axiom2(control independence) failed: control overlaps candidate." From 6281bcba20ae1a6f10f296607d254f47eff53d13 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:15:43 +0100 Subject: [PATCH 044/244] Drop unused import in SignedMixer --- Nfp/SignedMixer.lean | 1 - 1 file changed, 1 deletion(-) diff --git a/Nfp/SignedMixer.lean b/Nfp/SignedMixer.lean index d5f9f44..65b10da 100644 --- a/Nfp/SignedMixer.lean +++ b/Nfp/SignedMixer.lean @@ -7,7 +7,6 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Data.Finset.Basic import Mathlib.Algebra.Group.Defs import Mathlib.Order.MinMax -import Nfp.Prob import Nfp.Mixer /-! 
From c916d40a514ff1ad0d1a979944d8d9028417c63e Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:17:30 +0100 Subject: [PATCH 045/244] Simplify circuit construction helper --- Nfp/Verification.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Nfp/Verification.lean b/Nfp/Verification.lean index dbac96c..a66a3c0 100644 --- a/Nfp/Verification.lean +++ b/Nfp/Verification.lean @@ -178,7 +178,7 @@ private def topNonTargetToken (residual : ConcreteMatrix) (pos : Nat) else return none -private def fullCircuit (model : ConcreteModel) : ConcreteCircuit := Id.run do +private def fullCircuit (model : ConcreteModel) : ConcreteCircuit := let headsPerLayer := Array.ofFn (fun l : Fin model.numLayers => (model.layers.getD l.1 #[]).size) let neuronsPerLayer := From 53ed865c2b81efc28d9fcfb395957e79201f4c54 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:29:12 +0100 Subject: [PATCH 046/244] Refactor SOUND array builders --- Nfp/Sound/Affine.lean | 11 +++------- Nfp/Sound/Bounds/Exp.lean | 22 ++++++++------------ Nfp/Sound/Fixed.lean | 43 +++++++++++++++++---------------------- 3 files changed, 30 insertions(+), 46 deletions(-) diff --git a/Nfp/Sound/Affine.lean b/Nfp/Sound/Affine.lean index 7090f8c..254744b 100644 --- a/Nfp/Sound/Affine.lean +++ b/Nfp/Sound/Affine.lean @@ -24,14 +24,9 @@ namespace AffineForm def const (x : Rat) : AffineForm := { center := x, coeffs := #[] } private def combineCoeffs (a b : Array Rat) (f : Rat → Rat → Rat) : Array Rat := - Id.run do - let n := max a.size b.size - let mut out := Array.mkEmpty n - for i in [:n] do - let ai := a.getD i 0 - let bi := b.getD i 0 - out := out.push (f ai bi) - return out + let n := max a.size b.size + Array.ofFn fun (i : Fin n) => + f (a.getD i.val 0) (b.getD i.val 0) /-- Add two affine forms, aligning noise terms by index. 
-/ def add (a b : AffineForm) : AffineForm := diff --git a/Nfp/Sound/Bounds/Exp.lean b/Nfp/Sound/Bounds/Exp.lean index 3d861f7..4e69cba 100644 --- a/Nfp/Sound/Bounds/Exp.lean +++ b/Nfp/Sound/Bounds/Exp.lean @@ -99,23 +99,17 @@ theorem expLBPortfolio_def : expLBPortfolio = #[(2, 4), (3, 6), (4, 8)] := rfl /-- Portfolio of `expLBScaledTaylor` candidates, truncated by effort. -/ def expLBCandidates (x : Rat) (effort : Nat) : Array Rat := - Id.run do - let limit := min effort expLBPortfolio.size - let mut out : Array Rat := Array.mkEmpty limit - for i in [:limit] do - let cand := expLBScaledTaylor x (expLBPortfolio[i]!).2 (expLBPortfolio[i]!).1 - out := out.push cand - return out + let limit := min effort expLBPortfolio.size + Array.ofFn fun (i : Fin limit) => + let pair := expLBPortfolio[i.val]! + expLBScaledTaylor x pair.2 pair.1 theorem expLBCandidates_def (x : Rat) (effort : Nat) : expLBCandidates x effort = - Id.run do - let limit := min effort expLBPortfolio.size - let mut out : Array Rat := Array.mkEmpty limit - for i in [:limit] do - let cand := expLBScaledTaylor x (expLBPortfolio[i]!).2 (expLBPortfolio[i]!).1 - out := out.push cand - return out := rfl + let limit := min effort expLBPortfolio.size + Array.ofFn fun (i : Fin limit) => + let pair := expLBPortfolio[i.val]! + expLBScaledTaylor x pair.2 pair.1 := rfl /-- Portfolio lower bound on `exp`, with a baseline `1 + x` candidate. -/ def expLB (x : Rat) (effort : Nat) : Rat := diff --git a/Nfp/Sound/Fixed.lean b/Nfp/Sound/Fixed.lean index a89d86b..cd30b8d 100644 --- a/Nfp/Sound/Fixed.lean +++ b/Nfp/Sound/Fixed.lean @@ -117,36 +117,31 @@ If `a,b` are in units of `1/S`, then their product is in units of `1/S^2`; we re with outward rounding to remain conservative. 
-/ def mul (cfg : Fixed10Cfg) (a b : Fixed10Interval) : Fixed10Interval := - Id.run do - let p1 := a.lo * b.lo - let p2 := a.lo * b.hi - let p3 := a.hi * b.lo - let p4 := a.hi * b.hi - let loSq := min (min p1 p2) (min p3 p4) - let hiSq := max (max p1 p2) (max p3 p4) - return rescaleFromSq cfg loSq hiSq + let p1 := a.lo * b.lo + let p2 := a.lo * b.hi + let p3 := a.hi * b.lo + let p4 := a.hi * b.hi + let loSq := min (min p1 p2) (min p3 p4) + let hiSq := max (max p1 p2) (max p3 p4) + rescaleFromSq cfg loSq hiSq /-- Add a constant vector to a vector of intervals. -/ def addConstVec (xs : Array Fixed10Interval) (c : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if xs.size ≠ c.size then - return xs - let mut out := Array.mkEmpty xs.size - for i in [:xs.size] do - out := out.push (add xs[i]! c[i]!) - return out + if xs.size = c.size then + let n := xs.size + Array.ofFn fun (i : Fin n) => + add xs[i] (c.getD i.val default) + else + xs /-- Elementwise union of two interval vectors. -/ def unionVec (a b : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array Fixed10Interval := Array.mkEmpty a.size - let mut i : Nat := 0 - while i < a.size do - out := out.push (union a[i]! b[i]!) - i := i + 1 - return out + if a.size = b.size then + let n := a.size + Array.ofFn fun (i : Fin n) => + union a[i] (b.getD i.val default) + else + a /-! 
### Specs -/ From bb3184389fd968a904e7bf9c3fc91141fb21a17b Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:35:48 +0100 Subject: [PATCH 047/244] Add timing hooks for analysis and induction --- Main.lean | 14 ++++++++------ Nfp/IO.lean | 34 ++++++++++++++++++++++++++-------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/Main.lean b/Main.lean index d022f64..8b87706 100644 --- a/Main.lean +++ b/Main.lean @@ -840,11 +840,12 @@ def runInduction (p : Parsed) : IO UInt32 := do IO.println s!"Target: {target.description}" printInductionSearchIntro args.minEffect let buildLayerNormBounds := args.diagnostics && (!args.adaptive) - let (heads, cache) := - findHeuristicInductionHeadsWithCache model target args.minEffect - (minInductionScore := 0.01) - (buildLayerNormBounds := buildLayerNormBounds) - (storeDiagnostics := args.diagnostics) + let (heads, cache) ← Nfp.timeIt "induction:search" (fun () => + pure <| + findHeuristicInductionHeadsWithCache model target args.minEffect + (minInductionScore := 0.01) + (buildLayerNormBounds := buildLayerNormBounds) + (storeDiagnostics := args.diagnostics)) let top ← printInductionCandidates heads args.verbose let sched? := buildAdaptiveScheduler cache args if args.adaptive && args.verbose then @@ -856,7 +857,8 @@ def runInduction (p : Parsed) : IO UInt32 := do if let some code := err? then return code if args.verify then - let err? ← runInductionVerification model heads args.correctOpt + let err? ← Nfp.timeIt "induction:verify" (fun () => + runInductionVerification model heads args.correctOpt) if let some code := err? then return code return 0 diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 6eaf85c..b527ed6 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -55,6 +55,19 @@ namespace Nfp open IO +/-- Run an IO action and emit timing when `NFP_TIMING` is set. 
-/ +def timeIt {α : Type} (label : String) (action : Unit → IO α) : IO α := do + let timingEnabled ← IO.getEnv "NFP_TIMING" + if timingEnabled.isNone then + action () + else + let t0 ← IO.monoNanosNow + let result ← action () + let t1 ← IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + IO.eprintln s!"timing:{label} {dtMs}ms" + return result + /-- Load a model from NFP text format content. -/ def loadFromText (_content : String) : IO LoadResult := do return .error "NFP_TEXT format is deprecated; use NFP_BINARY_V1" @@ -329,8 +342,9 @@ def loadInputBinary (path : System.FilePath) : IO (Except String InputBinary) := /-- Load a model from a file path. Supports .nfpt (binary) format. -/ def loadModel (path : System.FilePath) : IO LoadResult := do if path.extension = some "nfpt" then - IO.FS.withFile path .read fun h => - loadBinary h + timeIt "io:load-model" (fun () => + IO.FS.withFile path .read fun h => + loadBinary h) else return .error s!"Unsupported file format: {path.extension.getD "unknown"}" @@ -419,20 +433,23 @@ def analyzeModel (model : ConcreteModel) (modelName : String) IO.println "═══════════════════════════════════════════════════════════\n" IO.println "[1/2] Building precomputed cache..." - let cache := PrecomputedCache.build model + let cache ← timeIt "analysis:precompute-cache" (fun () => + pure <| PrecomputedCache.build model) IO.println "[2/2] Searching for deep circuit candidates (shared scan)..." -- Find deep circuit candidates (reuse cache) - let deepCircuits := findDeepCircuitCandidatesFromCache cache + let deepCircuits ← timeIt "analysis:deep-circuit-scan" (fun () => + pure <| findDeepCircuitCandidatesFromCache cache) let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ threshold) IO.println s!" Found {verifiedDeep.size} verified deep circuits \ (of {deepCircuits.size} candidates)" -- Derive induction-head candidates from the same scan to avoid repeating -- the expensive `checkInductionPattern` computation. 
- let inductionHeads := - (deepCircuits.filterMap (·.toInductionCandidate? cache)).qsort - (·.combinedError < ·.combinedError) + let inductionHeads ← timeIt "analysis:induction-candidates" (fun () => + pure <| + (deepCircuits.filterMap (·.toInductionCandidate? cache)).qsort + (·.combinedError < ·.combinedError)) let verifiedHeads := inductionHeads.filter (·.combinedError ≤ threshold) IO.println s!" Found {verifiedHeads.size} verified induction heads \ (of {inductionHeads.size} candidates)\n" @@ -464,7 +481,8 @@ def analyzeAndVerify (model : ConcreteModel) (modelName : String) IO.println "Running circuit discovery and ablation experiments..." -- Run circuit discovery and verification - let (_, verification) := discoverAndVerify model threshold + let (_, verification) ← timeIt "analysis:discover-and-verify" (fun () => + pure <| discoverAndVerify model threshold) IO.println "Verification complete!\n" return { baseReport with verification := some verification } From eb67a66aa570377d0fa63d26a4b52e0d61d62e9d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:43:32 +0100 Subject: [PATCH 048/244] Add bench subcommand for repeatable runs --- Main.lean | 174 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/Main.lean b/Main.lean index 8b87706..9a4d5e8 100644 --- a/Main.lean +++ b/Main.lean @@ -863,6 +863,163 @@ def runInduction (p : Parsed) : IO UInt32 := do return code return 0 +/-! 
## Bench command helpers -/ + +private inductive BenchMode + | analyze + | induction + deriving Repr + +private def parseBenchMode (s : String) : Option BenchMode := + match s.trim.toLower with + | "analysis" => some .analyze + | "analyze" => some .analyze + | "induction" => some .induction + | "induce" => some .induction + | _ => none + +private structure BenchArgs where + modelPath : System.FilePath + modelPathStr : String + mode : BenchMode + runs : Nat + threshold : Float + minEffect : Float + correctOpt : Option Nat + incorrectOpt : Option Nat + verbose : Bool + +private def parseBenchArgs (p : Parsed) : IO (Option BenchArgs) := do + let modelPathStr := p.positionalArg! "model" |>.as! String + let modeStr := p.flag? "mode" |>.map (·.as! String) |>.getD "analysis" + let some mode := parseBenchMode modeStr + | do + IO.eprintln s!"Error: Invalid --mode '{modeStr}' (analysis|induction)" + return none + let runs := p.flag? "runs" |>.map (·.as! Nat) |>.getD 5 + let thresholdStr := p.flag? "threshold" |>.map (·.as! String) |>.getD "0.1" + let minEffectStr := p.flag? "minEffect" |>.map (·.as! String) |>.getD "0.0" + let some threshold := Nfp.parseFloat thresholdStr + | do + IO.eprintln s!"Error: Invalid --threshold '{thresholdStr}'" + return none + let some minEffect := Nfp.parseFloat minEffectStr + | do + IO.eprintln s!"Error: Invalid --minEffect '{minEffectStr}'" + return none + let correctOpt := p.flag? "correct" |>.map (·.as! Nat) + let incorrectOpt := p.flag? "incorrect" |>.map (·.as! Nat) + let verbose := p.hasFlag "verbose" + return some { + modelPath := ⟨modelPathStr⟩ + modelPathStr := modelPathStr + mode := mode + runs := runs + threshold := threshold + minEffect := minEffect + correctOpt := correctOpt + incorrectOpt := incorrectOpt + verbose := verbose + } + +/-- Core analysis work for benchmarking (no IO). 
-/ +private def benchAnalyzeOnce (model : ConcreteModel) (threshold : Float) : Nat × Nat := + let cache := PrecomputedCache.build model + let deepCircuits := findDeepCircuitCandidatesFromCache cache + let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ threshold) + let inductionHeads := + (deepCircuits.filterMap (·.toInductionCandidate? cache)).qsort + (·.combinedError < ·.combinedError) + let verifiedHeads := inductionHeads.filter (·.combinedError ≤ threshold) + (verifiedHeads.size, verifiedDeep.size) + +/-- Core induction-head search work for benchmarking (no IO). -/ +private def benchInductionOnce (model : ConcreteModel) (target : TargetDirection) + (minEffect : Float) : Nat := + let (heads, _) := + findHeuristicInductionHeadsWithCache model target minEffect + (minInductionScore := 0.01) + (buildLayerNormBounds := false) + (storeDiagnostics := false) + heads.size + +private def summarizeBenchTimes (label : String) (times : Array Nat) : IO Unit := do + let t0 := times[0]! + let mut minT := t0 + let mut maxT := t0 + let mut sumT : Nat := 0 + for t in times do + if t < minT then + minT := t + if t > maxT then + maxT := t + sumT := sumT + t + let avgT := sumT / times.size + IO.println s!"{label} runs={times.size} min={minT}ms avg={avgT}ms max={maxT}ms" + +private def runBenchWithArgs (args : BenchArgs) : IO UInt32 := do + if args.runs = 0 then + IO.eprintln "Error: --runs must be > 0" + return 1 + setStdoutLogNameFromModelPath args.modelPathStr + let loadResult ← loadModel args.modelPath + let model ← + match loadResult with + | .error msg => + IO.eprintln s!"Error loading model: {msg}" + return 1 + | .ok model0 => pure (model0.trimTrailingZeroEmbeddings) + match args.mode with + | .analyze => + let mut times : Array Nat := Array.mkEmpty args.runs + let mut lastHeads : Nat := 0 + let mut lastCircuits : Nat := 0 + for i in [:args.runs] do + let t0 ← IO.monoNanosNow + let (heads, circuits) := benchAnalyzeOnce model args.threshold + let t1 ← IO.monoNanosNow + 
let dtMs := (t1 - t0) / 1000000 + times := times.push dtMs + lastHeads := heads + lastCircuits := circuits + if args.verbose then + IO.println s!"run {i + 1}: {dtMs}ms heads={heads} circuits={circuits}" + summarizeBenchTimes "bench:analysis" times + IO.println <| + s!"bench:analysis threshold={args.threshold} heads={lastHeads} " ++ + s!"circuits={lastCircuits}" + return 0 + | .induction => + let some W_U := model.unembedding + | do + IO.eprintln "Error: Model is missing unembedding matrix (needed for target direction)." + return 1 + let target? := deriveInductionTarget model W_U args.correctOpt args.incorrectOpt + let some target := target? + | do + IO.eprintln "Error: Use both --correct and --incorrect (or ensure TOKENS are present)." + return 1 + let mut times : Array Nat := Array.mkEmpty args.runs + let mut lastHeads : Nat := 0 + for i in [:args.runs] do + let t0 ← IO.monoNanosNow + let heads := benchInductionOnce model target args.minEffect + let t1 ← IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + times := times.push dtMs + lastHeads := heads + if args.verbose then + IO.println s!"run {i + 1}: {dtMs}ms heads={heads}" + summarizeBenchTimes "bench:induction" times + IO.println s!"bench:induction minEffect={args.minEffect} heads={lastHeads}" + return 0 + +/-- Run the bench command for repeatable performance measurements. -/ +def runBench (p : Parsed) : IO UInt32 := do + let some args ← parseBenchArgs p + | return 1 + runBenchWithArgs args + /-! ## SOUND command helpers -/ private structure CertifyArgs where @@ -1916,6 +2073,22 @@ def inductionCmd : Cmd := `[Cli| model : String; "Path to the model weights file (.nfpt)" ] +/-- The bench subcommand. -/ +def benchCmd : Cmd := `[Cli| + bench VIA runBench; + "Run repeatable microbenchmarks on analysis or induction search." 
+ FLAGS: + mode : String; "analysis|induction (default: analysis)" + runs : Nat; "Number of timed runs (default: 5)" + t, threshold : String; "Analyze threshold (default: 0.1)" + minEffect : String; "Induction minEffect (default: 0.0)" + c, correct : Nat; "Correct token ID (requires --incorrect)" + i, incorrect : Nat; "Incorrect token ID (requires --correct)" + v, verbose; "Print per-run timing details" + ARGS: + model : String; "Path to the model weights file (.nfpt)" +] + /-- The certify subcommand. -/ def certifyCmd : Cmd := `[Cli| certify VIA runCertify; @@ -2093,6 +2266,7 @@ def nfpCmd : Cmd := `[Cli| SUBCOMMANDS: analyzeCmd; inductionCmd; + benchCmd; certifyCmd; headBoundsCmd; headPatternCmd; From a1b2d9ebef7a226776543d26dfb27a5c58fa8be8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 14:50:22 +0100 Subject: [PATCH 049/244] Extend bench with repeats and breakdown --- Main.lean | 84 +++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 11 deletions(-) diff --git a/Main.lean b/Main.lean index 9a4d5e8..b9184a2 100644 --- a/Main.lean +++ b/Main.lean @@ -883,11 +883,13 @@ private structure BenchArgs where modelPathStr : String mode : BenchMode runs : Nat + repeatCount : Nat threshold : Float minEffect : Float correctOpt : Option Nat incorrectOpt : Option Nat verbose : Bool + breakdown : Bool private def parseBenchArgs (p : Parsed) : IO (Option BenchArgs) := do let modelPathStr := p.positionalArg! "model" |>.as! String @@ -897,6 +899,7 @@ private def parseBenchArgs (p : Parsed) : IO (Option BenchArgs) := do IO.eprintln s!"Error: Invalid --mode '{modeStr}' (analysis|induction)" return none let runs := p.flag? "runs" |>.map (·.as! Nat) |>.getD 5 + let repeatCount := p.flag? "repeats" |>.map (·.as! Nat) |>.getD 1 let thresholdStr := p.flag? "threshold" |>.map (·.as! String) |>.getD "0.1" let minEffectStr := p.flag? "minEffect" |>.map (·.as! 
String) |>.getD "0.0" let some threshold := Nfp.parseFloat thresholdStr @@ -910,16 +913,19 @@ private def parseBenchArgs (p : Parsed) : IO (Option BenchArgs) := do let correctOpt := p.flag? "correct" |>.map (·.as! Nat) let incorrectOpt := p.flag? "incorrect" |>.map (·.as! Nat) let verbose := p.hasFlag "verbose" + let breakdown := p.hasFlag "breakdown" return some { modelPath := ⟨modelPathStr⟩ modelPathStr := modelPathStr mode := mode runs := runs + repeatCount := repeatCount threshold := threshold minEffect := minEffect correctOpt := correctOpt incorrectOpt := incorrectOpt verbose := verbose + breakdown := breakdown } /-- Core analysis work for benchmarking (no IO). -/ @@ -943,7 +949,8 @@ private def benchInductionOnce (model : ConcreteModel) (target : TargetDirection (storeDiagnostics := false) heads.size -private def summarizeBenchTimes (label : String) (times : Array Nat) : IO Unit := do +private def summarizeBenchTimes (label : String) (times : Array Nat) (repeatCount : Nat) : + IO Unit := do let t0 := times[0]! 
let mut minT := t0 let mut maxT := t0 @@ -955,12 +962,24 @@ private def summarizeBenchTimes (label : String) (times : Array Nat) : IO Unit : maxT := t sumT := sumT + t let avgT := sumT / times.size - IO.println s!"{label} runs={times.size} min={minT}ms avg={avgT}ms max={maxT}ms" + IO.println <| + s!"{label} runs={times.size} repeat={repeatCount} " ++ + s!"min={minT}ms avg={avgT}ms max={maxT}ms" + +private def timeNs {α : Type} (action : Unit → IO α) : IO (α × Nat) := do + let t0 ← IO.monoNanosNow + let result ← action () + let t1 ← IO.monoNanosNow + let dtNs := t1 - t0 + return (result, dtNs) private def runBenchWithArgs (args : BenchArgs) : IO UInt32 := do if args.runs = 0 then IO.eprintln "Error: --runs must be > 0" return 1 + if args.repeatCount = 0 then + IO.eprintln "Error: --repeats must be > 0" + return 1 setStdoutLogNameFromModelPath args.modelPathStr let loadResult ← loadModel args.modelPath let model ← @@ -974,17 +993,57 @@ private def runBenchWithArgs (args : BenchArgs) : IO UInt32 := do let mut times : Array Nat := Array.mkEmpty args.runs let mut lastHeads : Nat := 0 let mut lastCircuits : Nat := 0 + let mut cacheNsTotal : Nat := 0 + let mut deepNsTotal : Nat := 0 + let mut candNsTotal : Nat := 0 for i in [:args.runs] do let t0 ← IO.monoNanosNow - let (heads, circuits) := benchAnalyzeOnce model args.threshold + if args.breakdown then + let mut localCacheNs : Nat := 0 + let mut localDeepNs : Nat := 0 + let mut localCandNs : Nat := 0 + for _ in [:args.repeatCount] do + let (cache, cacheNs) ← timeNs (fun () => + pure <| PrecomputedCache.build model) + let (deepCircuits, deepNs) ← timeNs (fun () => + pure <| findDeepCircuitCandidatesFromCache cache) + let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ args.threshold) + let (inductionHeads, candNs) ← timeNs (fun () => + pure <| + (deepCircuits.filterMap (·.toInductionCandidate? 
cache)).qsort + (·.combinedError < ·.combinedError)) + let verifiedHeads := inductionHeads.filter (·.combinedError ≤ args.threshold) + localCacheNs := localCacheNs + cacheNs + localDeepNs := localDeepNs + deepNs + localCandNs := localCandNs + candNs + lastHeads := verifiedHeads.size + lastCircuits := verifiedDeep.size + cacheNsTotal := cacheNsTotal + localCacheNs + deepNsTotal := deepNsTotal + localDeepNs + candNsTotal := candNsTotal + localCandNs + else + for _ in [:args.repeatCount] do + let (heads, circuits) := benchAnalyzeOnce model args.threshold + lastHeads := heads + lastCircuits := circuits let t1 ← IO.monoNanosNow let dtMs := (t1 - t0) / 1000000 times := times.push dtMs - lastHeads := heads - lastCircuits := circuits if args.verbose then - IO.println s!"run {i + 1}: {dtMs}ms heads={heads} circuits={circuits}" - summarizeBenchTimes "bench:analysis" times + IO.println s!"run {i + 1}: {dtMs}ms heads={lastHeads} circuits={lastCircuits}" + summarizeBenchTimes "bench:analysis" times args.repeatCount + if args.breakdown then + let runs := args.runs + let repeatCount := args.repeatCount + let cacheAvgNs := cacheNsTotal / (runs * repeatCount) + let deepAvgNs := deepNsTotal / (runs * repeatCount) + let candAvgNs := candNsTotal / (runs * repeatCount) + let cacheAvgUs := cacheAvgNs / 1000 + let deepAvgUs := deepAvgNs / 1000 + let candAvgUs := candAvgNs / 1000 + IO.println <| + s!"bench:analysis cacheAvg={cacheAvgUs}us " ++ + s!"scanAvg={deepAvgUs}us candAvg={candAvgUs}us" IO.println <| s!"bench:analysis threshold={args.threshold} heads={lastHeads} " ++ s!"circuits={lastCircuits}" @@ -1003,14 +1062,15 @@ private def runBenchWithArgs (args : BenchArgs) : IO UInt32 := do let mut lastHeads : Nat := 0 for i in [:args.runs] do let t0 ← IO.monoNanosNow - let heads := benchInductionOnce model target args.minEffect + for _ in [:args.repeatCount] do + let heads := benchInductionOnce model target args.minEffect + lastHeads := heads let t1 ← IO.monoNanosNow let dtMs := (t1 - t0) 
/ 1000000 times := times.push dtMs - lastHeads := heads if args.verbose then - IO.println s!"run {i + 1}: {dtMs}ms heads={heads}" - summarizeBenchTimes "bench:induction" times + IO.println s!"run {i + 1}: {dtMs}ms heads={lastHeads}" + summarizeBenchTimes "bench:induction" times args.repeatCount IO.println s!"bench:induction minEffect={args.minEffect} heads={lastHeads}" return 0 @@ -2080,11 +2140,13 @@ def benchCmd : Cmd := `[Cli| FLAGS: mode : String; "analysis|induction (default: analysis)" runs : Nat; "Number of timed runs (default: 5)" + repeats : Nat; "Repeat inner workload per run (default: 1)" t, threshold : String; "Analyze threshold (default: 0.1)" minEffect : String; "Induction minEffect (default: 0.0)" c, correct : Nat; "Correct token ID (requires --incorrect)" i, incorrect : Nat; "Incorrect token ID (requires --correct)" v, verbose; "Print per-run timing details" + breakdown; "Emit per-phase averages (analysis only)" ARGS: model : String; "Path to the model weights file (.nfpt)" ] From 8d39d7ac26a49cfe37a89a4b6f23debc953a6671 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 15:41:09 +0100 Subject: [PATCH 050/244] Streamline induction candidates and cache benchmarks --- Main.lean | 95 +++++++++++++----- Nfp/Discovery.lean | 238 ++++++++++++++++++++------------------------- Nfp/IO.lean | 21 ++-- 3 files changed, 195 insertions(+), 159 deletions(-) diff --git a/Main.lean b/Main.lean index b9184a2..9d5f60e 100644 --- a/Main.lean +++ b/Main.lean @@ -930,14 +930,21 @@ private def parseBenchArgs (p : Parsed) : IO (Option BenchArgs) := do /-- Core analysis work for benchmarking (no IO). -/ private def benchAnalyzeOnce (model : ConcreteModel) (threshold : Float) : Nat × Nat := - let cache := PrecomputedCache.build model - let deepCircuits := findDeepCircuitCandidatesFromCache cache - let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ threshold) - let inductionHeads := - (deepCircuits.filterMap (·.toInductionCandidate? 
cache)).qsort - (·.combinedError < ·.combinedError) - let verifiedHeads := inductionHeads.filter (·.combinedError ≤ threshold) - (verifiedHeads.size, verifiedDeep.size) + Id.run do + let cache := PrecomputedCache.build model + let deepCircuits := findDeepCircuitCandidatesFromCache cache + let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ threshold) + let mut verifiedHeads : Array CandidateInductionHead := Array.mkEmpty 0 + for circuit in deepCircuits do + match circuit.toInductionCandidateCore? cache with + | none => pure () + | some core => + if core.combinedError ≤ threshold then + match core.toInductionCandidate? cache with + | some cand => verifiedHeads := verifiedHeads.push cand + | none => pure () + let verifiedSorted := verifiedHeads.qsort (·.combinedError < ·.combinedError) + return (verifiedSorted.size, verifiedDeep.size) /-- Core induction-head search work for benchmarking (no IO). -/ private def benchInductionOnce (model : ConcreteModel) (target : TargetDirection) @@ -993,32 +1000,69 @@ private def runBenchWithArgs (args : BenchArgs) : IO UInt32 := do let mut times : Array Nat := Array.mkEmpty args.runs let mut lastHeads : Nat := 0 let mut lastCircuits : Nat := 0 - let mut cacheNsTotal : Nat := 0 + let mut fwdNsTotal : Nat := 0 + let mut headNsTotal : Nat := 0 + let mut normNsTotal : Nat := 0 let mut deepNsTotal : Nat := 0 let mut candNsTotal : Nat := 0 for i in [:args.runs] do let t0 ← IO.monoNanosNow if args.breakdown then - let mut localCacheNs : Nat := 0 + let mut localFwdNs : Nat := 0 + let mut localHeadNs : Nat := 0 + let mut localNormNs : Nat := 0 let mut localDeepNs : Nat := 0 let mut localCandNs : Nat := 0 for _ in [:args.repeatCount] do - let (cache, cacheNs) ← timeNs (fun () => - pure <| PrecomputedCache.build model) + let (fwdResult, fwdNs) ← timeNs (fun () => + pure <| model.runForward true) + let ((headData, ln1Inputs), headNs) ← timeNs (fun () => + pure <| + PrecomputedCache.buildHeadData model fwdResult true + 
ConcreteMatrix.BoundEffort.tier1 false) + let baseBounds := Array.replicate model.numLayers 0.0 + let baseCache : PrecomputedCache := { + model := model + forwardResult := fwdResult + ln1Inputs := ln1Inputs + headData := headData + layerNormBounds := baseBounds + layerNormBoundsComputed := false + } + let (layerNormBounds, normNs) ← timeNs (fun () => + pure <| + PrecomputedCache.computeLayerNormBounds baseCache + ConcreteMatrix.BoundEffort.tier1) + let cache : PrecomputedCache := { + baseCache with + layerNormBounds := layerNormBounds + layerNormBoundsComputed := true + } let (deepCircuits, deepNs) ← timeNs (fun () => pure <| findDeepCircuitCandidatesFromCache cache) let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ args.threshold) - let (inductionHeads, candNs) ← timeNs (fun () => - pure <| - (deepCircuits.filterMap (·.toInductionCandidate? cache)).qsort - (·.combinedError < ·.combinedError)) - let verifiedHeads := inductionHeads.filter (·.combinedError ≤ args.threshold) - localCacheNs := localCacheNs + cacheNs + let (verifiedHeads, candNs) ← timeNs (fun () => do + let mut verified : Array CandidateInductionHead := Array.mkEmpty 0 + for circuit in deepCircuits do + match circuit.toInductionCandidateCore? cache with + | none => pure () + | some core => + if core.combinedError ≤ args.threshold then + match core.toInductionCandidate? 
cache with + | some cand => verified := verified.push cand + | none => pure () + let verifiedSorted := verified.qsort (·.combinedError < ·.combinedError) + return verifiedSorted) + localFwdNs := localFwdNs + fwdNs + localHeadNs := localHeadNs + headNs + localNormNs := localNormNs + normNs localDeepNs := localDeepNs + deepNs localCandNs := localCandNs + candNs lastHeads := verifiedHeads.size lastCircuits := verifiedDeep.size - cacheNsTotal := cacheNsTotal + localCacheNs + fwdNsTotal := fwdNsTotal + localFwdNs + headNsTotal := headNsTotal + localHeadNs + normNsTotal := normNsTotal + localNormNs deepNsTotal := deepNsTotal + localDeepNs candNsTotal := candNsTotal + localCandNs else @@ -1035,15 +1079,20 @@ private def runBenchWithArgs (args : BenchArgs) : IO UInt32 := do if args.breakdown then let runs := args.runs let repeatCount := args.repeatCount - let cacheAvgNs := cacheNsTotal / (runs * repeatCount) + let fwdAvgNs := fwdNsTotal / (runs * repeatCount) + let headAvgNs := headNsTotal / (runs * repeatCount) + let normAvgNs := normNsTotal / (runs * repeatCount) let deepAvgNs := deepNsTotal / (runs * repeatCount) let candAvgNs := candNsTotal / (runs * repeatCount) - let cacheAvgUs := cacheAvgNs / 1000 + let fwdAvgUs := fwdAvgNs / 1000 + let headAvgUs := headAvgNs / 1000 + let normAvgUs := normAvgNs / 1000 let deepAvgUs := deepAvgNs / 1000 let candAvgUs := candAvgNs / 1000 IO.println <| - s!"bench:analysis cacheAvg={cacheAvgUs}us " ++ - s!"scanAvg={deepAvgUs}us candAvg={candAvgUs}us" + s!"bench:analysis fwdAvg={fwdAvgUs}us headAvg={headAvgUs}us " ++ + s!"normAvg={normAvgUs}us scanAvg={deepAvgUs}us " ++ + s!"candAvg={candAvgUs}us" IO.println <| s!"bench:analysis threshold={args.threshold} heads={lastHeads} " ++ s!"circuits={lastCircuits}" diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index b00462f..f53f936 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -4932,31 +4932,7 @@ def layerNormBoundAt (cache : PrecomputedCache) (layerIdx : Nat) let attnFrob : 
Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) let attnOpUb : Float := min attnFrob d.attentionOneInfBound let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let inputs : PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := d.inputOpBound - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := computePatternTermBound inputs + let patternTermUb : Float := d.patternTermBound attnPart := attnPart + d.ln1OpBound * (valueTermUb + patternTermUb) if hm : layerIdx < model.mlps.size then @@ -4994,19 +4970,15 @@ def computeLayerNormBounds (cache : PrecomputedCache) out := out.push (cache.layerNormBoundAt l effort) out -/-- Build a complete precomputed cache for a model. - -This precomputes all attention patterns, projections, and norms once. --/ - def build (model : ConcreteModel) (causal : Bool := true) - (computeLayerNormBounds : Bool := true) - (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) - (storeDiagnostics : Bool := false) : - PrecomputedCache := Id.run do - let fwdResult := model.runForward causal +/-- Build cached head data and pre-LN attention inputs for all layers. 
-/ +def buildHeadData (model : ConcreteModel) (fwdResult : ForwardPassResult) + (causal : Bool := true) + (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) + (storeDiagnostics : Bool := false) : + Array (Array PrecomputedHeadData) × Array ConcreteMatrix := Id.run do -- Parallelizing both layers and heads can lead to too many tasks; prefer layer-level parallelism. let useParallelLayers := model.numLayers >= 4 - let computeLayer (l : Nat) : (Array PrecomputedHeadData × (Float × ConcreteMatrix)) := Id.run do + let computeLayer (l : Nat) : (Array PrecomputedHeadData × ConcreteMatrix) := Id.run do let layerInput := fwdResult.getLayerInput l let attnInput := model.applyLn1 l layerInput let inputNorm := computeInputNorm attnInput @@ -5206,7 +5178,7 @@ This precomputes all attention patterns, projections, and norms once. let kqActOpDense1 : Float := kqActGram1.opNormUpperBoundDenseBrauer let kqActOpDense2 : Float := kqActGram2.opNormUpperBoundDenseBrauer let kqActOpBoundDense : Float := - Float.sqrt (max 0.0 (min kqActOpDense1 kqActOpDense2)) + Float.sqrt (max 0.0 (min kqActOpDense1 qkActOpDense2)) let kqActF2_1 : Float := ConcreteMatrix.traceMul kqActGram1 kqActGram1 let kqActF2_2 : Float := ConcreteMatrix.traceMul kqActGram2 kqActGram2 let kqActMoment1 : Float := @@ -5384,73 +5356,14 @@ This precomputes all attention patterns, projections, and norms once. else #[] - let norm : Float := - if computeLayerNormBounds then - -- OPTIMIZATION: compute per-layer residual Jacobian upper bounds from cached head data, - -- avoiding recomputation of attention weights / projections. 
- Id.run do - let y := fwdResult.getPostAttnResidual l - let ln2Bound := model.ln2OpBound l y - let mut attnPart : Float := 0.0 - for d in layerHeadData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let inputs : PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := d.inputOpBound - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := computePatternTermBound inputs - attnPart := attnPart + d.ln1OpBound * (valueTermUb + patternTermUb) - - if hm : l < model.mlps.size then - let mlp := model.mlps[l]'hm - let winNormUb := mlp.W_in.opNormUpperBoundRectGramEffort layerNormEffort - let woutNormUb := mlp.W_out.opNormUpperBoundRectGramEffort layerNormEffort - let geluDeriv := fwdResult.getMlpGeluDeriv l - let mlpUb := - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv winNormUb woutNormUb - else - let mlpInput := model.applyLn2 l y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct winNormUb woutNormUb - let mlpPart := ln2Bound * mlpUb - return attnPart + (1.0 + attnPart) * mlpPart - else - return attnPart - else - 0.0 - - return (layerHeadData, 
(norm, attnInput)) + return (layerHeadData, attnInput) -- Pure parallelism via tasks: layer cache construction is independent once the -- forward pass has produced all layer inputs. let useParallel := useParallelLayers - let layerResults : Array (Array PrecomputedHeadData × (Float × ConcreteMatrix)) := + let layerResults : Array (Array PrecomputedHeadData × ConcreteMatrix) := if useParallel then - let tasks : Array (Task (Array PrecomputedHeadData × (Float × ConcreteMatrix))) := + let tasks : Array (Task (Array PrecomputedHeadData × ConcreteMatrix)) := .ofFn fun i : Fin model.numLayers => Task.spawn (fun _ => computeLayer i.val) tasks.map Task.get @@ -5459,19 +5372,41 @@ This precomputes all attention patterns, projections, and norms once. computeLayer i.val let mut headData : Array (Array PrecomputedHeadData) := Array.mkEmpty model.numLayers - let mut layerNormBounds : Array Float := Array.mkEmpty model.numLayers let mut ln1Inputs : Array ConcreteMatrix := Array.mkEmpty model.numLayers - for (layerHeadData, (norm, attnInput)) in layerResults do + for (layerHeadData, attnInput) in layerResults do headData := headData.push layerHeadData - layerNormBounds := layerNormBounds.push norm ln1Inputs := ln1Inputs.push attnInput - { model := model + return (headData, ln1Inputs) + +/-- Build a complete precomputed cache for a model. + +This precomputes all attention patterns, projections, and norms once. 
+-/ +def build (model : ConcreteModel) (causal : Bool := true) + (computeLayerNormBounds : Bool := true) + (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) + (storeDiagnostics : Bool := false) : + PrecomputedCache := Id.run do + let fwdResult := model.runForward causal + let (headData, ln1Inputs) := + buildHeadData model fwdResult causal layerNormEffort storeDiagnostics + let baseBounds := Array.replicate model.numLayers 0.0 + let baseCache : PrecomputedCache := { + model := model forwardResult := fwdResult ln1Inputs := ln1Inputs headData := headData - layerNormBounds := layerNormBounds - layerNormBoundsComputed := computeLayerNormBounds } + layerNormBounds := baseBounds + layerNormBoundsComputed := false + } + if computeLayerNormBounds then + let layerNormBounds := PrecomputedCache.computeLayerNormBounds baseCache layerNormEffort + return { baseCache with + layerNormBounds := layerNormBounds + layerNormBoundsComputed := true } + else + return baseCache /-- Retrieve cached data for a specific head. -/ def getHeadData (cache : PrecomputedCache) (layerIdx headIdx : Nat) : @@ -6557,24 +6492,26 @@ def computeKCompositionScore else 0.0 -/-- Convert a 2-layer deep circuit candidate into an induction-head candidate. +/-- Core induction candidate data used for fast thresholding. -/ +structure InductionCandidateCore where + layer1Idx : Nat + layer2Idx : Nat + head1Idx : Nat + head2Idx : Nat + patternBound1 : Float + patternBound2 : Float + combinedError : Float + prevTokenStrength : Float + description : String -This is used to avoid re-running the expensive `checkInductionPattern` scan when both -induction heads and deep circuits are requested from the same cache. --/ -def DeepCircuitCandidate.toInductionCandidate? - (c : DeepCircuitCandidate) (cache : PrecomputedCache) : +namespace InductionCandidateCore + +/-- Finalize an induction candidate by computing expensive scores. -/ +def toInductionCandidate? 
(core : InductionCandidateCore) (cache : PrecomputedCache) : Option CandidateInductionHead := - if c.layerIndices.size = 2 && c.headIndices.size = 2 then - let l1 := c.layerIndices[0]! - let l2 := c.layerIndices[1]! - let h1 := c.headIndices[0]! - let h2 := c.headIndices[1]! - match cache.getHeadData l1 h1, cache.getHeadData l2 h2 with - | some d1, some d2 => - let ε1 := d1.faithfulnessRatio - let ε2 := d2.faithfulnessRatio - let combinedError := ε1 + ε2 + ε1 * ε2 + match cache.getHeadData core.layer1Idx core.head1Idx, + cache.getHeadData core.layer2Idx core.head2Idx with + | some d1, some d2 => let inductionScore : Float := match cache.model.inputTokens with | some tokens => @@ -6582,22 +6519,63 @@ def DeepCircuitCandidate.toInductionCandidate? | none => 1.0 let kComp := computeKCompositionScore cache.model d1 d2 some { - layer1Idx := l1 - layer2Idx := l2 - head1Idx := h1 - head2Idx := h2 - patternBound1 := ε1 - patternBound2 := ε2 - combinedError := combinedError - prevTokenStrength := d1.prevTokenStrength + layer1Idx := core.layer1Idx + layer2Idx := core.layer2Idx + head1Idx := core.head1Idx + head2Idx := core.head2Idx + patternBound1 := core.patternBound1 + patternBound2 := core.patternBound2 + combinedError := core.combinedError + prevTokenStrength := core.prevTokenStrength inductionScore := inductionScore kComp := kComp - description := s!"L{l1}H{h1}->L{l2}H{h2} (deep)" + description := core.description } + | _, _ => none + +end InductionCandidateCore + +/-- Convert a 2-layer deep circuit candidate into cheap induction-core data. -/ +def DeepCircuitCandidate.toInductionCandidateCore? + (c : DeepCircuitCandidate) (cache : PrecomputedCache) : + Option InductionCandidateCore := + if c.layerIndices.size = 2 && c.headIndices.size = 2 then + let l1 := c.layerIndices[0]! + let l2 := c.layerIndices[1]! + let h1 := c.headIndices[0]! + let h2 := c.headIndices[1]! 
+ match cache.getHeadData l1 h1, cache.getHeadData l2 h2 with + | some d1, some d2 => + let ε1 := d1.faithfulnessRatio + let ε2 := d2.faithfulnessRatio + let combinedError := ε1 + ε2 + ε1 * ε2 + some { + layer1Idx := l1 + layer2Idx := l2 + head1Idx := h1 + head2Idx := h2 + patternBound1 := ε1 + patternBound2 := ε2 + combinedError := combinedError + prevTokenStrength := d1.prevTokenStrength + description := s!"L{l1}H{h1}->L{l2}H{h2} (deep)" + } | _, _ => none else none +/-- Convert a 2-layer deep circuit candidate into an induction-head candidate. + +This is used to avoid re-running the expensive `checkInductionPattern` scan when both +induction heads and deep circuits are requested from the same cache. +-/ +def DeepCircuitCandidate.toInductionCandidate? + (c : DeepCircuitCandidate) (cache : PrecomputedCache) : + Option CandidateInductionHead := + match c.toInductionCandidateCore? cache with + | none => none + | some core => core.toInductionCandidate? cache + /-- Find candidate (L1, L2) induction-head pairs from a `PrecomputedCache`. This searches for the classic two-head induction circuit: diff --git a/Nfp/IO.lean b/Nfp/IO.lean index b527ed6..0a39306 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -446,13 +446,22 @@ def analyzeModel (model : ConcreteModel) (modelName : String) -- Derive induction-head candidates from the same scan to avoid repeating -- the expensive `checkInductionPattern` computation. - let inductionHeads ← timeIt "analysis:induction-candidates" (fun () => - pure <| - (deepCircuits.filterMap (·.toInductionCandidate? cache)).qsort - (·.combinedError < ·.combinedError)) - let verifiedHeads := inductionHeads.filter (·.combinedError ≤ threshold) + let (totalInduction, verifiedHeads) ← timeIt "analysis:induction-candidates" (fun () => do + let mut total : Nat := 0 + let mut verified : Array CandidateInductionHead := Array.mkEmpty 0 + for circuit in deepCircuits do + match circuit.toInductionCandidateCore? 
cache with + | none => pure () + | some core => + total := total + 1 + if core.combinedError ≤ threshold then + match core.toInductionCandidate? cache with + | some cand => verified := verified.push cand + | none => pure () + let verifiedSorted := verified.qsort (·.combinedError < ·.combinedError) + return (total, verifiedSorted)) IO.println s!" Found {verifiedHeads.size} verified induction heads \ - (of {inductionHeads.size} candidates)\n" + (of {totalInduction} candidates)\n" IO.println "Analysis complete!\n" From 3c24c325eba6ece50af9840a231a88f537dcffe3 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 16:08:15 +0100 Subject: [PATCH 051/244] Tune head task parallelism in cache build --- Nfp/Discovery.lean | 47 ++++++++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index f53f936..6560948 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -4976,7 +4976,7 @@ def buildHeadData (model : ConcreteModel) (fwdResult : ForwardPassResult) (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) (storeDiagnostics : Bool := false) : Array (Array PrecomputedHeadData) × Array ConcreteMatrix := Id.run do - -- Parallelizing both layers and heads can lead to too many tasks; prefer layer-level parallelism. + -- Prefer layer-level parallelism, but allow bounded head chunking to use spare cores. 
let useParallelLayers := model.numLayers >= 4 let computeLayer (l : Nat) : (Array PrecomputedHeadData × ConcreteMatrix) := Id.run do let layerInput := fwdResult.getLayerInput l @@ -5091,10 +5091,6 @@ def buildHeadData (model : ConcreteModel) (fwdResult : ForwardPassResult) let keyMean := meanVec keys let queryMean := meanVec queries let valueMean := meanVec values - let nF := seqLen.toFloat - let vMeanNormSq : Float := sumSquares valueMean - let vFrobBound : Float := - Float.sqrt (max 0.0 (vFrobBoundRaw * vFrobBoundRaw - nF * vMeanNormSq)) let vActGram := values.transpose.matmul values let vActGramCentered := centerGram vActGram valueMean seqLen let vActTrace : Float := gramTrace vActGramCentered @@ -5178,7 +5174,7 @@ def buildHeadData (model : ConcreteModel) (fwdResult : ForwardPassResult) let kqActOpDense1 : Float := kqActGram1.opNormUpperBoundDenseBrauer let kqActOpDense2 : Float := kqActGram2.opNormUpperBoundDenseBrauer let kqActOpBoundDense : Float := - Float.sqrt (max 0.0 (min kqActOpDense1 qkActOpDense2)) + Float.sqrt (max 0.0 (min kqActOpDense1 kqActOpDense2)) let kqActF2_1 : Float := ConcreteMatrix.traceMul kqActGram1 kqActGram1 let kqActF2_2 : Float := ConcreteMatrix.traceMul kqActGram2 kqActGram2 let kqActMoment1 : Float := @@ -5340,19 +5336,38 @@ def buildHeadData (model : ConcreteModel) (fwdResult : ForwardPassResult) voFactorGram := bnds.voFactorGram } - let useParallelHeads := (!useParallelLayers) && heads.size >= 4 - if useParallelHeads then - let tasks : Array (Task PrecomputedHeadData) := - .ofFn fun i : Fin heads.size => - Task.spawn (fun _ => computeHead i.val heads[i]) - tasks.map Task.get - else + let headTaskCount : Nat := + if heads.size < 4 then + 1 + else if useParallelLayers then + let maxHeadTasks : Nat := 48 + let budget := maxHeadTasks / model.numLayers + let target := Nat.max 1 budget + Nat.min heads.size target + else + heads.size + let computeHeadChunk (start stop : Nat) : Array PrecomputedHeadData := Id.run do + let mut out : Array 
PrecomputedHeadData := Array.mkEmpty (stop - start) + for h_idx in [start:stop] do + if hh : h_idx < heads.size then + out := out.push (computeHead h_idx (heads[h_idx]'hh)) + return out + if headTaskCount > 1 then Id.run do + let chunkSize := (heads.size + headTaskCount - 1) / headTaskCount + let chunkCount := (heads.size + chunkSize - 1) / chunkSize + let tasks : Array (Task (Array PrecomputedHeadData)) := + .ofFn fun i : Fin chunkCount => + let start := i.val * chunkSize + let stop := min heads.size (start + chunkSize) + Task.spawn (fun _ => computeHeadChunk start stop) let mut out : Array PrecomputedHeadData := Array.mkEmpty heads.size - for h_idx in [:heads.size] do - if hh : h_idx < heads.size then - out := out.push (computeHead h_idx (heads[h_idx]'hh)) + for chunk in tasks.map Task.get do + for item in chunk do + out := out.push item return out + else + computeHeadChunk 0 heads.size else #[] From 3dc9a7a6a0ea09c94286a5c291798ba45bf3f207 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 16:29:45 +0100 Subject: [PATCH 052/244] Refactor parsing loops and LayerNorm stats --- Nfp/Discovery.lean | 12 +++---- Nfp/Sound/BinaryPure.lean | 75 ++++++++++++++++----------------------- Nfp/Sound/TextPure.lean | 34 +++++++++--------- 3 files changed, 54 insertions(+), 67 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index 6560948..8a21860 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -1491,8 +1491,8 @@ def layerNormRowwise (X γ β : ConcreteMatrix) (eps : Float := 1e-5) : Concrete let betaData := β.data -- Per-row mean and inverse stddev (compute once for speed). 
- let mut means : Array Float := Array.mkEmpty rows - let mut invStds : Array Float := Array.mkEmpty rows + let mut means : Array Float := Array.replicate rows 0.0 + let mut invStds : Array Float := Array.replicate rows 0.0 for r in [:rows] do let mut sum : Float := 0.0 let rowBase := r * cols @@ -1506,8 +1506,8 @@ def layerNormRowwise (X γ β : ConcreteMatrix) (eps : Float := 1e-5) : Concrete -- In exact arithmetic, `var ≥ 0`. Clamp to avoid NaN from tiny negative float noise. let var := max 0.0 (varSum / colsF) let invσ := 1.0 / Float.sqrt (var + eps) - means := means.push μ - invStds := invStds.push invσ + means := means.set! r μ + invStds := invStds.set! r invσ return { numRows := rows @@ -1515,8 +1515,8 @@ def layerNormRowwise (X γ β : ConcreteMatrix) (eps : Float := 1e-5) : Concrete data := .ofFn fun idx : Fin (rows * cols) => let r := idx.val / cols let c := idx.val % cols - let μ := means.getD r 0.0 - let invσ := invStds.getD r 0.0 + let μ := means[r]! + let invσ := invStds[r]! let normalized := (X.data[r * cols + c]! - μ) * invσ (gammaData[c]!) * normalized + (betaData[c]!) 
size_eq := Array.size_ofFn diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index c12755c..f17fdfc 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -176,15 +176,10 @@ private def floatAbsCeilScaledCore (scaleInt : Int) (bits : UInt64) : Except Str private def floatAbsCeilScaled (scalePow10 : Nat) (bits : UInt64) : Except String Int := floatAbsCeilScaledCore (scaleIntOfPow10 scalePow10) bits -private def floatScaledCeilSignedCore (scaleInt : Int) (bits : UInt64) : Except String Int := - match floatAbsCeilScaledCore scaleInt bits with - | .error e => .error e - | .ok absScaled => - let signNeg : Bool := (bits >>> 63) = (1 : UInt64) - if signNeg then - .ok (-absScaled) - else - .ok absScaled +private def floatScaledCeilSignedCore (scaleInt : Int) (bits : UInt64) : Except String Int := do + let absScaled ← floatAbsCeilScaledCore scaleInt bits + let signNeg : Bool := (bits >>> 63) = (1 : UInt64) + return if signNeg then -absScaled else absScaled private def floatScaledCeilSigned (scalePow10 : Nat) (bits : UInt64) : Except String Int := floatScaledCeilSignedCore (scaleIntOfPow10 scalePow10) bits @@ -201,11 +196,9 @@ def vectorMaxAbsScaledFromBytes (bytes : ByteArray) (n scalePow10 : Nat) : let mut off : Nat := 0 while i < n do let bits := u64FromLE bytes off - match floatAbsCeilScaledCore scaleInt bits with - | .error e => throw e - | .ok absScaled => - if absScaled > maxAbs then - maxAbs := absScaled + let absScaled ← floatAbsCeilScaledCore scaleInt bits + if absScaled > maxAbs then + maxAbs := absScaled off := off + 8 i := i + 1 return maxAbs @@ -225,17 +218,15 @@ def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat let mut off : Nat := 0 while i < count do let bits := u64FromLE bytes off - match floatAbsCeilScaledCore scaleInt bits with - | .error e => throw e - | .ok absScaled => - curRowSum := curRowSum + absScaled - if colIdx + 1 = cols then - if curRowSum > maxRowSum then - maxRowSum := 
curRowSum - curRowSum := 0 - colIdx := 0 - else - colIdx := colIdx + 1 + let absScaled ← floatAbsCeilScaledCore scaleInt bits + curRowSum := curRowSum + absScaled + if colIdx + 1 = cols then + if curRowSum > maxRowSum then + maxRowSum := curRowSum + curRowSum := 0 + colIdx := 0 + else + colIdx := colIdx + 1 off := off + 8 i := i + 1 return maxRowSum @@ -289,9 +280,8 @@ def scaledFloatArrayFromBytes (bytes : ByteArray) (count scalePow10 : Nat) : let mut off : Nat := 0 while i < count do let bits := u64FromLE bytes off - match floatScaledCeilSignedCore scaleInt bits with - | .error e => throw e - | .ok v => out := out.set! i v + let v ← floatScaledCeilSignedCore scaleInt bits + out := out.set! i v off := off + 8 i := i + 1 return out @@ -301,9 +291,8 @@ def scaledFloatFromBytes (bytes : ByteArray) (scalePow10 : Nat) : if bytes.size < 8 then throw "unexpected EOF" let bits := u64FromLE bytes 0 - match floatScaledCeilSignedCore (scaleIntOfPow10 scalePow10) bits with - | .error e => throw e - | .ok v => return v + let v ← floatScaledCeilSignedCore (scaleIntOfPow10 scalePow10) bits + return v def i32ArrayFromBytes (bytes : ByteArray) (count : Nat) : Except String (Array Int) := do @@ -337,19 +326,17 @@ def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : let mut off : Nat := 0 while i < count do let bits := u64FromLE bytes off - match floatAbsCeilScaledCore scaleInt bits with - | .error e => throw e - | .ok absScaled => - let absNat := Int.toNat absScaled - curRowSum := curRowSum + absNat - colSums := colSums.set! colIdx (colSums[colIdx]! + absNat) - if colIdx + 1 = cols then - if curRowSum > maxRowSum then - maxRowSum := curRowSum - curRowSum := 0 - colIdx := 0 - else - colIdx := colIdx + 1 + let absScaled ← floatAbsCeilScaledCore scaleInt bits + let absNat := Int.toNat absScaled + curRowSum := curRowSum + absNat + colSums := colSums.set! colIdx (colSums[colIdx]! 
+ absNat) + if colIdx + 1 = cols then + if curRowSum > maxRowSum then + maxRowSum := curRowSum + curRowSum := 0 + colIdx := 0 + else + colIdx := colIdx + 1 off := off + 8 i := i + 1 let mut maxColSum : Nat := 0 diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean index 1ce8125..8ab5aa3 100644 --- a/Nfp/Sound/TextPure.lean +++ b/Nfp/Sound/TextPure.lean @@ -182,6 +182,15 @@ def modelWeightBoundsFromTextLines (lines : Array String) : Except String ModelW let mut ln1MaxAbsGamma : Array Rat := Array.replicate info.numLayers 0 let mut ln1MaxAbsBeta : Array Rat := Array.replicate info.numLayers 0 let mut ln2MaxAbsGamma : Array Rat := Array.replicate info.numLayers 0 + let updateAt := fun (arr : Array Rat) (idx : Nat) (f : Rat → Rat) => + if idx < arr.size then + arr.set! idx (f arr[idx]!) + else + arr + let setAt := fun (arr : Array Rat) (idx : Nat) (val : Rat) => + updateAt arr idx (fun _ => val) + let setMaxAt := fun (arr : Array Rat) (idx : Nat) (val : Rat) => + updateAt arr idx (fun cur => max cur val) while i < lines.size do let line := lines[i]!.trim if line.startsWith "LAYER" then @@ -203,16 +212,14 @@ def modelWeightBoundsFromTextLines (lines : Array String) : Except String ModelW match consumeMatrixNormInf lines (i + 1) info.modelDim info.headDim with | .error e => return .error e | .ok (nq, next) => - if r < wqMax.size then - wqMax := wqMax.set! r (max wqMax[r]! nq) + wqMax := setMaxAt wqMax r nq i := next else if line = "W_K" then let r := curLayer match consumeMatrixNormInf lines (i + 1) info.modelDim info.headDim with | .error e => return .error e | .ok (nk, next) => - if r < wkMax.size then - wkMax := wkMax.set! r (max wkMax[r]! 
nk) + wkMax := setMaxAt wkMax r nk i := next else if line = "W_V" then let r := curLayer @@ -227,49 +234,42 @@ def modelWeightBoundsFromTextLines (lines : Array String) : Except String ModelW match consumeMatrixNormInf lines (i + 1) info.headDim info.modelDim with | .error e => return .error e | .ok (no, next2) => - if r < attnValueCoeff.size then - attnValueCoeff := - attnValueCoeff.set! r (attnValueCoeff[r]! + (nv * no)) + attnValueCoeff := updateAt attnValueCoeff r (fun cur => cur + (nv * no)) i := next2 else if line = "W_in" then let r := curLayer match consumeMatrixNormInf lines (i + 1) info.modelDim info.hiddenDim with | .error e => return .error e | .ok (nwin, next) => - if r < mlpWinBound.size then - mlpWinBound := mlpWinBound.set! r nwin + mlpWinBound := setAt mlpWinBound r nwin i := next else if line = "W_out" then let r := curLayer match consumeMatrixNormInf lines (i + 1) info.hiddenDim info.modelDim with | .error e => return .error e | .ok (nwout, next) => - if r < mlpWoutBound.size then - mlpWoutBound := mlpWoutBound.set! r nwout + mlpWoutBound := setAt mlpWoutBound r nwout i := next else if line = "LN1_GAMMA" then let r := curLayer match consumeVectorMaxAbs lines (i + 1) info.modelDim with | .error e => return .error e | .ok (g, next) => - if r < ln1MaxAbsGamma.size then - ln1MaxAbsGamma := ln1MaxAbsGamma.set! r g + ln1MaxAbsGamma := setAt ln1MaxAbsGamma r g i := next else if line = "LN1_BETA" then let r := curLayer match consumeVectorMaxAbs lines (i + 1) info.modelDim with | .error e => return .error e | .ok (b, next) => - if r < ln1MaxAbsBeta.size then - ln1MaxAbsBeta := ln1MaxAbsBeta.set! r b + ln1MaxAbsBeta := setAt ln1MaxAbsBeta r b i := next else if line = "LN2_GAMMA" then let r := curLayer match consumeVectorMaxAbs lines (i + 1) info.modelDim with | .error e => return .error e | .ok (g, next) => - if r < ln2MaxAbsGamma.size then - ln2MaxAbsGamma := ln2MaxAbsGamma.set! 
r g + ln2MaxAbsGamma := setAt ln2MaxAbsGamma r g i := next else if line = "LN2_BETA" then match consumeVectorMaxAbs lines (i + 1) info.modelDim with From 9ab5b100399ebfaf61ee30c2995ac57b9afbca41 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 16:42:38 +0100 Subject: [PATCH 053/244] Optimize IO buffering and sound cache parsing --- Nfp/Discovery.lean | 19 ++++++++++--------- Nfp/IO.lean | 9 ++++----- Nfp/Sound/CachePure.lean | 12 +++++++++--- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index 8a21860..89f914e 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -5468,26 +5468,27 @@ def mkLayerNormJacobianCtx (X γ : ConcreteMatrix) (eps : Float := 1e-5) : Layer if !(γ.numRows = 1 ∧ γ.numCols = cols) then return { numRows := rows, numCols := cols, gamma := γ, invStds := #[], v := X } - let mut means : Array Float := Array.mkEmpty rows - let mut invStds : Array Float := Array.mkEmpty rows + let mut means : Array Float := Array.replicate rows 0.0 + let mut invStds : Array Float := Array.replicate rows 0.0 + let colsF := cols.toFloat for r in [:rows] do let mut sum : Float := 0.0 let rowBase := r * cols for c in [:cols] do sum := sum + X.data[rowBase + c]! - let μ := sum / cols.toFloat + let μ := sum / colsF let mut varSum : Float := 0.0 for c in [:cols] do let d := X.data[rowBase + c]! - μ varSum := varSum + d * d - let varRaw := varSum / cols.toFloat + let varRaw := varSum / colsF -- Clamp for numerical stability (avoid NaN from tiny negative float noise). let var := if Float.isNaN varRaw || Float.isInf varRaw then 0.0 else max 0.0 varRaw let invσ := 1.0 / Float.sqrt (var + eps) - means := means.push μ - invStds := invStds.push invσ + means := means.set! r μ + invStds := invStds.set! 
r invσ let v : ConcreteMatrix := { numRows := rows @@ -5495,8 +5496,8 @@ def mkLayerNormJacobianCtx (X γ : ConcreteMatrix) (eps : Float := 1e-5) : Layer data := .ofFn fun idx : Fin (rows * cols) => let r := idx.val / cols let c := idx.val % cols - let μ := means.getD r 0.0 - let invσ := invStds.getD r 0.0 + let μ := means[r]! + let invσ := invStds[r]! (X.data[r * cols + c]! - μ) * invσ size_eq := Array.size_ofFn } @@ -5527,7 +5528,7 @@ def LayerNormJacobianCtx.apply (ctx : LayerNormJacobianCtx) (dX : ConcreteMatrix let dx := dX.data[rowBase + c]! sumDx := sumDx + dx sumVDx := sumVDx + (ctx.v.data[rowBase + c]! * dx) - meanDx := meanDx.set! r (sumDx / cols.toFloat) + meanDx := meanDx.set! r (sumDx / colsF) vDotDx := vDotDx.set! r sumVDx { numRows := rows diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 0a39306..625ac46 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -84,16 +84,15 @@ private def readLine? (h : IO.FS.Handle) : IO (Option String) := do private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do if n = 0 then return ByteArray.empty - let mut out : Array UInt8 := Array.replicate n 0 + let mut out : ByteArray := ByteArray.mk (Array.replicate n 0) let mut off : Nat := 0 while off < n do let chunk ← h.read (USize.ofNat (n - off)) if chunk.isEmpty then throw (IO.userError "unexpected EOF") - for b in chunk.data do - out := out.set! 
off b - off := off + 1 - return ByteArray.mk out + out := chunk.copySlice 0 out off chunk.size + off := off + chunk.size + return out @[inline] private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := let b0 := (b[off]!).toUInt32 diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index 9698cad..b701231 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -210,7 +210,8 @@ private def consumeFixedBytes Id.run do let mut iLine := start let mut remaining := count - let mut buf : Array UInt8 := Array.mkEmpty (count * 4) + let mut buf : ByteArray := ByteArray.mk (Array.replicate (count * 4) 0) + let mut offBytes : Nat := 0 while remaining > 0 do if iLine ≥ lines.size then return .error "unexpected end of file while reading fixed tokens" @@ -233,9 +234,14 @@ private def consumeFixedBytes match parseFixed10Rounded scalePow10 bytes tokStart tokStop with | .error e => return .error e | .ok x => - buf := appendI32LE buf x + let ux : UInt32 := UInt32.ofInt x + buf := buf.set! offBytes (ux &&& 0xFF).toUInt8 + buf := buf.set! (offBytes + 1) ((ux >>> 8) &&& 0xFF).toUInt8 + buf := buf.set! (offBytes + 2) ((ux >>> 16) &&& 0xFF).toUInt8 + buf := buf.set! 
(offBytes + 3) ((ux >>> 24) &&& 0xFF).toUInt8 + offBytes := offBytes + 4 remaining := remaining - 1 - return .ok (ByteArray.mk buf, iLine) + return .ok (buf, iLine) private def readHeaderFromLines (lines : Array String) : Except String (Header × Nat) := Id.run do From 70d2cbed3386016a25f7d6365e7063d553f86dbc Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 16:49:23 +0100 Subject: [PATCH 054/244] Streamline verification scan and range folds --- Nfp/Linearization.lean | 10 ++++++---- Nfp/Verification.lean | 14 +++++++------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/Nfp/Linearization.lean b/Nfp/Linearization.lean index 4e685e0..c66da88 100644 --- a/Nfp/Linearization.lean +++ b/Nfp/Linearization.lean @@ -2102,16 +2102,19 @@ theorem DeepLinearization.layerJacobian_residual_bound (A := attnJac) (M := mlpJac) (a := A) (b := M) hA' hM' simpa [DeepLinearization.layerJacobian, attnJac, mlpJac] using hres +/-- Left-fold over `[0, count)` without allocating a list. -/ +private def foldRange {α : Type*} (count : Nat) (init : α) (f : α → Nat → α) : α := + Nat.rec (motive := fun _ => α) init (fun i acc => f acc i) count + /-- The composed Jacobian from layer `start` to layer `stop` (exclusive). -/ noncomputable def DeepLinearization.rangeJacobian (D : DeepLinearization (n := n) (d := d)) (start stop : ℕ) : SignedMixer (n × d) (n × d) := if _h : start < stop ∧ stop ≤ D.numLayers then - (List.range (stop - start)).foldl + foldRange (stop - start) SignedMixer.identity (fun acc i => if hi : start + i < D.numLayers then acc.comp (D.layerJacobian ⟨start + i, hi⟩) else acc) - SignedMixer.identity else SignedMixer.identity /-! 
### Virtual Attention Heads -/ @@ -2170,7 +2173,7 @@ noncomputable def DeepValueTerm (D : DeepLinearization (n := n) (d := d)) : SignedMixer (n × d) (n × d) := let core := if _h : 0 < D.numLayers then - (List.range D.numLayers).foldl + foldRange D.numLayers SignedMixer.identity (fun acc i => if hi : i < D.numLayers then let L := D.layers ⟨i, hi⟩ @@ -2178,7 +2181,6 @@ noncomputable def DeepValueTerm (D : DeepLinearization (n := n) (d := d)) : -- Pre-LN: absorb ln_1 linearization into the value path. acc.comp (SignedMixer.identity + ln.comp (valueTerm L)) else acc) - SignedMixer.identity else SignedMixer.identity -- Final normalization is applied after all blocks. core.comp D.lnFJacobian diff --git a/Nfp/Verification.lean b/Nfp/Verification.lean index a66a3c0..d4d56af 100644 --- a/Nfp/Verification.lean +++ b/Nfp/Verification.lean @@ -55,11 +55,11 @@ def inductionTargetTokenFromHistory (model : ConcreteModel) : Option Nat := do let lastIdx := tokens.size - 1 let tCurr := tokens[lastIdx]! let mut foundIdx : Option Nat := none - for offset in [:lastIdx] do - if foundIdx.isNone then - let idx := lastIdx - 1 - offset - if tokens[idx]! = tCurr then - foundIdx := some idx + let mut idx := lastIdx + while idx > 0 && foundIdx.isNone do + idx := idx - 1 + if tokens[idx]! = tCurr then + foundIdx := some idx let k ← foundIdx some (tokens[k + 1]!) 
@@ -311,7 +311,7 @@ def verifyCircuit (ctx : VerificationContext) (candidateHeads : Array HeadRef) : CircuitVerificationRow := Id.run do let baseDelta := ctx.baseDelta let competence := baseDelta > ctx.cfg.competenceEpsilon - let mut failures : Array String := #[] + let mut failures : Array String := Array.mkEmpty (candidateHeads.size + 3) if !competence then failures := failures.push s!"Axiom1(baseline competence) failed: Δ_base={baseDelta} ≤ \ ε={ctx.cfg.competenceEpsilon}" @@ -333,7 +333,7 @@ def verifyCircuit (ctx : VerificationContext) (candidateHeads : Array HeadRef) : } -- Choose one control head per candidate head (same layer, closest output norm). - let mut selections : Array ControlSelection := #[] + let mut selections : Array ControlSelection := Array.mkEmpty candidateHeads.size for cand in candidateHeads do match selectEnergyMatchedControl ctx.baselineForward cand candidateHeads with | some sel => selections := selections.push sel From de6222817d332752fff20dfe8940e33280e9d212 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 17:03:55 +0100 Subject: [PATCH 055/244] Reduce list allocations in Linearization bounds --- Nfp/Discovery.lean | 8 +-- Nfp/Linearization.lean | 102 +++++++++++++++++++-------------------- Nfp/Sound/CachePure.lean | 9 ++-- Nfp/Sound/IO.lean | 40 +++++++-------- 4 files changed, 80 insertions(+), 79 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index 89f914e..c4ec017 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -4965,9 +4965,9 @@ def computeLayerNormBounds (cache : PrecomputedCache) Task.spawn (fun _ => cache.layerNormBoundAt i.val effort) tasks.map Task.get else - let mut out : Array Float := Array.mkEmpty n + let mut out : Array Float := Array.replicate n 0.0 for l in [:n] do - out := out.push (cache.layerNormBoundAt l effort) + out := out.set! l (cache.layerNormBoundAt l effort) out /-- Build cached head data and pre-LN attention inputs for all layers. 
-/ @@ -5062,9 +5062,9 @@ def buildHeadData (model : ConcreteModel) (fwdResult : ForwardPassResult) for c in [:cols] do sums := sums.set! c (sums[c]! + M.data[rowBase + c]!) let invN := 1.0 / M.numRows.toFloat - let mut out : Array Float := Array.mkEmpty cols + let mut out : Array Float := Array.replicate cols 0.0 for c in [:cols] do - out := out.push (sums[c]! * invN) + out := out.set! c (sums[c]! * invN) return out let centerGram := fun (gram : ConcreteMatrix) (mean : Array Float) (n : Nat) => if gram.numRows = 0 || gram.numCols = 0 || n = 0 then diff --git a/Nfp/Linearization.lean b/Nfp/Linearization.lean index c66da88..01a913a 100644 --- a/Nfp/Linearization.lean +++ b/Nfp/Linearization.lean @@ -2353,12 +2353,11 @@ This product ignores `lnFJacobian`; if it is nontrivial, multiply by `operatorNormBound D.lnFJacobian` to bound end-to-end amplification. -/ noncomputable def amplificationFactor (D : DeepLinearization (n := n) (d := d)) : ℝ := -- Product of (1 + ‖layerJacobian - I‖) for all layers - (List.range D.numLayers).foldl + foldRange D.numLayers 1 (fun acc i => if hi : i < D.numLayers then acc * (1 + operatorNormBound (D.layerJacobian ⟨i, hi⟩ - SignedMixer.identity)) else acc) - 1 /-- **Two-layer composition theorem**: Explicit bound for 2-layer case. @@ -2414,25 +2413,21 @@ gets amplified by subsequent layers. When start = numLayers, this equals 1 (no amplification). -/ noncomputable def suffixAmplification (D : DeepLinearization (n := n) (d := d)) (start : ℕ) : ℝ := - (List.range (D.numLayers - start)).foldl + foldRange (D.numLayers - start) 1 (fun acc i => if hi : start + i < D.numLayers then acc * (1 + layerNormBounds D ⟨start + i, hi⟩) else acc) - 1 /-- Base case: suffix amplification starting at numLayers is 1. 
-/ theorem suffixAmplification_base (D : DeepLinearization (n := n) (d := d)) : suffixAmplification D D.numLayers = 1 := by - simp only [suffixAmplification, Nat.sub_self, List.range_zero, List.foldl_nil] + simp [suffixAmplification, foldRange] /-- The amplificationFactor equals suffixAmplification starting from 0. -/ theorem amplificationFactor_eq_suffix (D : DeepLinearization (n := n) (d := d)) : amplificationFactor D = suffixAmplification D 0 := by - simp only [amplificationFactor, suffixAmplification, layerNormBounds, Nat.sub_zero] - congr 1 - ext acc i - simp only [zero_add] + simp [amplificationFactor, suffixAmplification, layerNormBounds] /-- **Recursive total error formula**: Total error with amplification. @@ -2447,24 +2442,27 @@ theorem suffixAmplification_nonneg (D : DeepLinearization (n := n) (d := d)) (start : ℕ) (hNorm : ∀ i : Fin D.numLayers, 0 ≤ layerNormBounds D i) : 0 ≤ suffixAmplification D start := by unfold suffixAmplification - -- We prove a stronger statement: for any init ≥ 0, the foldl result is ≥ 0 - suffices h : ∀ init : ℝ, 0 ≤ init → - 0 ≤ (List.range (D.numLayers - start)).foldl - (fun acc i => if hi : start + i < D.numLayers then - acc * (1 + layerNormBounds D ⟨start + i, hi⟩) else acc) - init by - exact h 1 (by norm_num : (0 : ℝ) ≤ 1) - intro init hinit - generalize (List.range (D.numLayers - start)) = xs - induction xs generalizing init with - | nil => simp [hinit] - | cons x xs ih => - simp only [List.foldl_cons] - split_ifs with hi - · apply ih - apply mul_nonneg hinit - linarith [hNorm ⟨start + x, hi⟩] - · exact ih init hinit + -- We prove a stronger statement: for any init ≥ 0, the fold result is ≥ 0 + let f := fun acc i => + if hi : start + i < D.numLayers then + acc * (1 + layerNormBounds D ⟨start + i, hi⟩) + else acc + suffices h : ∀ count : Nat, ∀ init : ℝ, 0 ≤ init → 0 ≤ foldRange count init f by + exact h (D.numLayers - start) 1 (by norm_num : (0 : ℝ) ≤ 1) + intro count + induction count with + | zero => + intro init hinit 
+ simpa [foldRange] using hinit + | succ count ih => + intro init hinit + have hacc : 0 ≤ foldRange count init f := ih init hinit + by_cases hi : start + count < D.numLayers + · have hbound : 0 ≤ layerNormBounds D ⟨start + count, hi⟩ := hNorm _ + have hmul : 0 ≤ foldRange count init f * (1 + layerNormBounds D ⟨start + count, hi⟩) := + mul_nonneg hacc (by linarith [hbound]) + simpa [foldRange, f, hi] using hmul + · simpa [foldRange, f, hi] using hacc /- These lemmas don't need the `[Nonempty _]` section variables (they are in scope @@ -2518,42 +2516,44 @@ theorem n_layer_faithfulness_composition ∃ (ε_total : ℝ), 0 ≤ ε_total ∧ ε_total ≤ ∑ i : Fin D.numLayers, - εs i * (List.range (D.numLayers - (i.val + 1))).foldl + εs i * foldRange (D.numLayers - (i.val + 1)) 1 (fun acc j => if hj : i.val + 1 + j < D.numLayers then acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) - else acc) - 1 := by + else acc) := by -- The witness is exactly the bound formula let suffix_bound : Fin D.numLayers → ℝ := fun i => - (List.range (D.numLayers - (i.val + 1))).foldl + foldRange (D.numLayers - (i.val + 1)) 1 (fun acc j => if hj : i.val + 1 + j < D.numLayers then acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) else acc) - 1 -- Helper: suffix_bound is nonnegative have hsuffix_nonneg : ∀ i : Fin D.numLayers, 0 ≤ suffix_bound i := by intro i - simp only [suffix_bound] - -- We prove: for any init ≥ 0, foldl result is ≥ 0 - suffices h : ∀ init : ℝ, 0 ≤ init → - 0 ≤ (List.range (D.numLayers - (i.val + 1))).foldl - (fun acc j => if hj : i.val + 1 + j < D.numLayers then - acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) else acc) - init by - exact h 1 (by norm_num : (0 : ℝ) ≤ 1) - intro init hinit - generalize (List.range (D.numLayers - (i.val + 1))) = xs - induction xs generalizing init with - | nil => simp [hinit] - | cons x xs ih => - simp only [List.foldl_cons] - split_ifs with hj - · apply ih - apply mul_nonneg hinit - linarith [hC_pos ⟨i.val + 1 + x, hj⟩] - · exact ih init hinit + let f := fun acc j => + if hj : i.val + 1 + j < 
D.numLayers then + acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) + else acc + -- We prove: for any init ≥ 0, foldRange result is ≥ 0 + suffices h : ∀ count : Nat, ∀ init : ℝ, 0 ≤ init → 0 ≤ foldRange count init f by + have h1 : 0 ≤ foldRange (D.numLayers - (i.val + 1)) 1 f := + h (D.numLayers - (i.val + 1)) 1 (by norm_num : (0 : ℝ) ≤ 1) + simpa [suffix_bound, f] using h1 + intro count + induction count with + | zero => + intro init hinit + simpa [foldRange] using hinit + | succ count ih => + intro init hinit + have hacc : 0 ≤ foldRange count init f := ih init hinit + by_cases hj : i.val + 1 + count < D.numLayers + · have hbound : 0 ≤ Cs ⟨i.val + 1 + count, hj⟩ := hC_pos _ + have hmul : 0 ≤ foldRange count init f * (1 + Cs ⟨i.val + 1 + count, hj⟩) := + mul_nonneg hacc (by linarith [hbound]) + simpa [foldRange, f, hj] using hmul + · simpa [foldRange, f, hj] using hacc use ∑ i : Fin D.numLayers, εs i * suffix_bound i constructor · -- Nonnegativity diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index b701231..dd5b097 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -364,11 +364,12 @@ private def repeatBytes (b : ByteArray) (n : Nat) : ByteArray := Id.run do if n = 0 || b.size = 0 then return ByteArray.empty - let mut out : Array UInt8 := Array.mkEmpty (n * b.size) + let mut out : ByteArray := ByteArray.mk (Array.replicate (n * b.size) 0) + let mut off : Nat := 0 for _ in [:n] do - for byte in b.data do - out := out.push byte - return ByteArray.mk out + out := b.copySlice 0 out off b.size + off := off + b.size + return out def buildCacheBytes (lines : Array String) diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean index da9babf..7b8d4a3 100644 --- a/Nfp/Sound/IO.lean +++ b/Nfp/Sound/IO.lean @@ -64,17 +64,17 @@ private def recomputeModelWeightBoundsBinary match ← Nfp.Untrusted.SoundBinary.skipF64Array h (hdr.seqLen * hdr.modelDim) with | .error e => return .error e | .ok _ => pure () - let mut valuePairsLayers : Array (Array (Int × Int)) := 
Array.mkEmpty hdr.numLayers - let mut qkPairsLayers : Array (Array (Int × Int)) := Array.mkEmpty hdr.numLayers - let mut mlpWinBound : Array Rat := Array.mkEmpty hdr.numLayers - let mut mlpWoutBound : Array Rat := Array.mkEmpty hdr.numLayers - let mut ln1MaxAbsGamma : Array Rat := Array.mkEmpty hdr.numLayers - let mut ln1MaxAbsBeta : Array Rat := Array.mkEmpty hdr.numLayers - let mut ln2MaxAbsGamma : Array Rat := Array.mkEmpty hdr.numLayers - for _l in [:hdr.numLayers] do - let mut valuePairs : Array (Int × Int) := Array.mkEmpty hdr.numHeads - let mut qkPairs : Array (Int × Int) := Array.mkEmpty hdr.numHeads - for _h in [:hdr.numHeads] do + let mut valuePairsLayers : Array (Array (Int × Int)) := Array.replicate hdr.numLayers #[] + let mut qkPairsLayers : Array (Array (Int × Int)) := Array.replicate hdr.numLayers #[] + let mut mlpWinBound : Array Rat := Array.replicate hdr.numLayers 0 + let mut mlpWoutBound : Array Rat := Array.replicate hdr.numLayers 0 + let mut ln1MaxAbsGamma : Array Rat := Array.replicate hdr.numLayers 0 + let mut ln1MaxAbsBeta : Array Rat := Array.replicate hdr.numLayers 0 + let mut ln2MaxAbsGamma : Array Rat := Array.replicate hdr.numLayers 0 + for l in [:hdr.numLayers] do + let mut valuePairs : Array (Int × Int) := Array.replicate hdr.numHeads (0, 0) + let mut qkPairs : Array (Int × Int) := Array.replicate hdr.numHeads (0, 0) + for hIdx in [:hdr.numHeads] do let wqScaledE ← Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 @@ -112,8 +112,8 @@ private def recomputeModelWeightBoundsBinary match noScaledE with | .error e => return .error e | .ok v => pure v - qkPairs := qkPairs.push (wqScaled, wkScaled) - valuePairs := valuePairs.push (nvScaled, noScaled) + qkPairs := qkPairs.set! hIdx (wqScaled, wkScaled) + valuePairs := valuePairs.set! 
hIdx (nvScaled, noScaled) match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with | .error e => return .error e | .ok _ => pure () @@ -165,13 +165,13 @@ private def recomputeModelWeightBoundsBinary let ln1Gamma := ratOfScaledInt scalePow10 ln1GammaScaled let ln1Beta := ratOfScaledInt scalePow10 ln1BetaScaled let ln2Gamma := ratOfScaledInt scalePow10 ln2GammaScaled - mlpWinBound := mlpWinBound.push nWin - mlpWoutBound := mlpWoutBound.push nWout - ln1MaxAbsGamma := ln1MaxAbsGamma.push ln1Gamma - ln1MaxAbsBeta := ln1MaxAbsBeta.push ln1Beta - ln2MaxAbsGamma := ln2MaxAbsGamma.push ln2Gamma - valuePairsLayers := valuePairsLayers.push valuePairs - qkPairsLayers := qkPairsLayers.push qkPairs + mlpWinBound := mlpWinBound.set! l nWin + mlpWoutBound := mlpWoutBound.set! l nWout + ln1MaxAbsGamma := ln1MaxAbsGamma.set! l ln1Gamma + ln1MaxAbsBeta := ln1MaxAbsBeta.set! l ln1Beta + ln2MaxAbsGamma := ln2MaxAbsGamma.set! l ln2Gamma + valuePairsLayers := valuePairsLayers.set! l valuePairs + qkPairsLayers := qkPairsLayers.set! 
l qkPairs match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with | .error e => return .error e | .ok _ => pure () From 46260ad99eb7cd83fb8997c7da3994a54a8dd0a9 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 17:10:15 +0100 Subject: [PATCH 056/244] Preallocate sound and discovery arrays --- Nfp/Discovery.lean | 8 ++------ Nfp/Sound/BinaryPure.lean | 12 ++++++------ Nfp/Sound/CachePure.lean | 12 +++++++++--- Nfp/Sound/Interval.lean | 11 ++++++----- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index c4ec017..0576941 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -5386,12 +5386,8 @@ def buildHeadData (model : ConcreteModel) (fwdResult : ForwardPassResult) .ofFn fun i : Fin model.numLayers => computeLayer i.val - let mut headData : Array (Array PrecomputedHeadData) := Array.mkEmpty model.numLayers - let mut ln1Inputs : Array ConcreteMatrix := Array.mkEmpty model.numLayers - for (layerHeadData, attnInput) in layerResults do - headData := headData.push layerHeadData - ln1Inputs := ln1Inputs.push attnInput - + let headData := layerResults.map (·.1) + let ln1Inputs := layerResults.map (·.2) return (headData, ln1Inputs) /-- Build a complete precomputed cache for a model. 
diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean index f17fdfc..92fbab1 100644 --- a/Nfp/Sound/BinaryPure.lean +++ b/Nfp/Sound/BinaryPure.lean @@ -392,15 +392,15 @@ def attnWeightBoundsArraysFromScaledPairs (scalePow10 : Nat) if valuePairs.size ≠ qkPairs.size then return .error s!"attn weight bounds layer count mismatch: \ value={valuePairs.size}, qk={qkPairs.size}" - let mut coeffs : Array Rat := Array.mkEmpty valuePairs.size - let mut wqMaxs : Array Rat := Array.mkEmpty valuePairs.size - let mut wkMaxs : Array Rat := Array.mkEmpty valuePairs.size + let mut coeffs : Array Rat := Array.replicate valuePairs.size 0 + let mut wqMaxs : Array Rat := Array.replicate valuePairs.size 0 + let mut wkMaxs : Array Rat := Array.replicate valuePairs.size 0 for idx in [:valuePairs.size] do let coeff := attnValueCoeffFromScaledPairs scalePow10 valuePairs[idx]! let (wqMax, wkMax) := attnQKMaxFromScaledPairs scalePow10 qkPairs[idx]! - coeffs := coeffs.push coeff - wqMaxs := wqMaxs.push wqMax - wkMaxs := wkMaxs.push wkMax + coeffs := coeffs.set! idx coeff + wqMaxs := wqMaxs.set! idx wqMax + wkMaxs := wkMaxs.set! idx wkMax return .ok (coeffs, wqMaxs, wkMaxs) /-! ### Derived properties -/ diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean index dd5b097..bfb260d 100644 --- a/Nfp/Sound/CachePure.lean +++ b/Nfp/Sound/CachePure.lean @@ -355,10 +355,16 @@ private def collectLayerNormParamsFixed private def encodeIntArray (xs : Array Int) : ByteArray := Id.run do - let mut out : Array UInt8 := Array.mkEmpty (xs.size * 4) + let mut out : ByteArray := ByteArray.mk (Array.replicate (xs.size * 4) 0) + let mut off : Nat := 0 for x in xs do - out := appendI32LE out x - return ByteArray.mk out + let ux : UInt32 := UInt32.ofInt x + out := out.set! off (ux &&& 0xFF).toUInt8 + out := out.set! (off + 1) ((ux >>> 8) &&& 0xFF).toUInt8 + out := out.set! (off + 2) ((ux >>> 16) &&& 0xFF).toUInt8 + out := out.set! 
(off + 3) ((ux >>> 24) &&& 0xFF).toUInt8 + off := off + 4 + return out private def repeatBytes (b : ByteArray) (n : Nat) : ByteArray := Id.run do diff --git a/Nfp/Sound/Interval.lean b/Nfp/Sound/Interval.lean index 6a82fa2..60428fb 100644 --- a/Nfp/Sound/Interval.lean +++ b/Nfp/Sound/Interval.lean @@ -129,15 +129,16 @@ def varianceLowerBound (xs : Array RatInterval) : Rat := if n < 2 then return 0 -- Build sorted breakpoint lists for `lo` and `hi` with squared endpoints for O(1) evaluation. - let mut enters : Array (Rat × Rat) := Array.mkEmpty n - let mut leaves : Array (Rat × Rat) := Array.mkEmpty n + let mut enters : Array (Rat × Rat) := Array.replicate n (0, 0) + let mut leaves : Array (Rat × Rat) := Array.replicate n (0, 0) let mut sumLeft : Rat := 0 let mut sumLeftSq : Rat := 0 - for x in normed do + for i in [:n] do + let x := normed[i]! let lo := x.lo let hi := x.hi - enters := enters.push (lo, ratSq lo) - leaves := leaves.push (hi, ratSq hi) + enters := enters.set! i (lo, ratSq lo) + leaves := leaves.set! i (hi, ratSq hi) sumLeft := sumLeft + lo sumLeftSq := sumLeftSq + ratSq lo -- Exact minimization over the breakpoints (O(n log n)). 
From d82ecc16f5317b0abebe8084610c362bee494d4c Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 17:21:25 +0100 Subject: [PATCH 057/244] Tidy cert validity checks and logging --- Nfp/Discovery.lean | 14 +++++++++++++- Nfp/Sound/Bounds/Portfolio.lean | 27 +++++++++++---------------- Nfp/Sound/Cert.lean | 9 ++++----- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index 0576941..c49bb56 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -4000,7 +4000,19 @@ namespace DeepCircuitCandidate def toString (c : DeepCircuitCandidate) : String := let heads := c.layerIndices.zip c.headIndices |>.map fun (l, h) => s!"L{l}H{h}" - s!"{c.patternType}: {heads.toList} | " ++ + let headStr := + Id.run do + let mut out := "[" + let mut first := true + for h in heads do + if first then + first := false + else + out := out ++ ", " + out := out ++ h + out := out ++ "]" + return out + s!"{c.patternType}: {headStr} | " ++ s!"ε_simple={c.simpleErrorSum}, ε_amplified={c.amplifiedError}, amp={c.amplificationFactor}" instance : ToString DeepCircuitCandidate := ⟨toString⟩ diff --git a/Nfp/Sound/Bounds/Portfolio.lean b/Nfp/Sound/Bounds/Portfolio.lean index fddef5e..cbd3f4a 100644 --- a/Nfp/Sound/Bounds/Portfolio.lean +++ b/Nfp/Sound/Bounds/Portfolio.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.Order.Ring.Unbundled.Rat +import Init.Data.Array.Lemmas namespace Nfp.Sound @@ -20,15 +21,12 @@ theorem ubBest_def (base : Rat) (cands : Array Rat) : /-- `ubBest` never exceeds its baseline upper bound. 
-/ theorem ubBest_le_base (base : Rat) (cands : Array Rat) : ubBest base cands ≤ base := by classical - have hList : cands.toList.foldl min base ≤ base := by - induction cands.toList generalizing base with - | nil => simp - | cons x xs ih => - simp only [List.foldl] - have h := ih (base := min base x) - exact le_trans h (min_le_left _ _) have hArray : cands.foldl min base ≤ base := by - simpa [Array.foldl_toList] using hList + refine Array.foldl_induction (as := cands) + (motive := fun _ acc => acc ≤ base) (init := base) (f := fun acc x => min acc x) ?h0 ?hf + · exact le_rfl + · intro i acc hacc + exact le_trans (min_le_left _ _) hacc simpa [ubBest] using hArray /-- Best lower bound among candidates (never worse than `base`). -/ @@ -41,15 +39,12 @@ theorem lbBest_def (base : Rat) (cands : Array Rat) : /-- `lbBest` never undercuts its baseline lower bound. -/ theorem lbBest_ge_base (base : Rat) (cands : Array Rat) : base ≤ lbBest base cands := by classical - have hList : base ≤ cands.toList.foldl max base := by - induction cands.toList generalizing base with - | nil => simp - | cons x xs ih => - simp only [List.foldl] - have h := ih (base := max base x) - exact le_trans (le_max_left _ _) h have hArray : base ≤ cands.foldl max base := by - simpa [Array.foldl_toList] using hList + refine Array.foldl_induction (as := cands) + (motive := fun _ acc => base ≤ acc) (init := base) (f := fun acc x => max acc x) ?h0 ?hf + · exact le_rfl + · intro i acc hacc + exact le_trans hacc (le_max_left _ _) simpa [lbBest] using hArray end Nfp.Sound diff --git a/Nfp/Sound/Cert.lean b/Nfp/Sound/Cert.lean index 8761714..65c7ee9 100644 --- a/Nfp/Sound/Cert.lean +++ b/Nfp/Sound/Cert.lean @@ -339,11 +339,10 @@ def Valid (c : ModelCert) : Prop := 0 < c.eps ∧ c.softmaxJacobianNormInfWorst = Nfp.Sound.softmaxJacobianNormInfWorst ∧ c.actDerivBound = c.layers.foldl (fun acc l => max acc l.mlpActDerivBound) 0 ∧ - List.Forall₂ - (fun i l => - l.layerIdx = i ∧ - LayerAmplificationCert.Valid c.eps 
c.soundnessBits c.seqLen c.modelDim c.headDim l) - (List.range c.layers.size) c.layers.toList ∧ + (∀ i : Fin c.layers.size, + let l := c.layers[i] + l.layerIdx = i.val ∧ + LayerAmplificationCert.Valid c.eps c.soundnessBits c.seqLen c.modelDim c.headDim l) ∧ c.totalAmplificationFactor = c.layers.foldl (fun acc l => acc * (1 + l.C)) 1 From 83b33eda97b99e7f7e878c677599d02fb7f4c702 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 22:03:06 +0100 Subject: [PATCH 058/244] Optimize induction cert hot path --- Nfp/Untrusted/SoundCompute.lean | 363 ++++++++++++++++++++++++++------ 1 file changed, 299 insertions(+), 64 deletions(-) diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 15f502c..a8634fa 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -18,25 +18,26 @@ open Nfp.Sound open Nfp.Untrusted.SoundBinary /-! -# Untrusted SOUND `.nfpt` loader (exact Rat parsing, legacy text format) +# Untrusted SOUND computation helpers -This is a minimal, *sound* loader intended for certification on the legacy text format. +This module performs **IO-heavy witness generation** for SOUND certification. It parses `.nfpt` +models (binary, plus legacy text for some paths) and computes candidate certificates for: +- model-level residual amplification bounds, +- per-head contribution bounds, +- local head-pattern / best-match / induction certificates. It does **not** construct the full `ConcreteModel` (Float-based). Instead it parses only the weights needed for conservative residual amplification constants `Cᵢ` (bounds ‖layerJacobian - I‖), -using exact `Rat` arithmetic. +using exact `Rat` arithmetic or fixed-point interval arithmetic. -It can optionally consume an input `.nfpt` file (for `EMBEDDINGS`) to enable **local** -LayerNorm certification on a bounded region around that input. - -Global certification supports `NFP_BINARY_V1` via a sound fixed-point rounding bridge. 
-Local (input-dependent) certification supports `NFP_BINARY_V1` using a union-box fixed-point path. +All certificates produced here are **untrusted** and must be validated by the trusted checker +in `Nfp.Sound.IO`. Trusted base: - Parsing from text to `Rat` via `Nfp.Sound.parseRat`. - Exact accumulation of row-sum norms and max-abs values. -No `Float` arithmetic is used as an input to certification. +No `Float` arithmetic is *trusted* as an input to certification. -/ private def defaultBinaryScalePow10 : Nat := 9 @@ -1078,12 +1079,17 @@ private def fixedVarianceLowerBoundExact (cfg : Fixed10Cfg) (xs : Array Fixed10I private def fixedVarianceLowerBound (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := let rangeLB := fixedVarianceLowerBoundRange cfg xs let midLB := fixedVarianceLowerBoundMidpoint cfg xs - -- Avoid the exact Rat-based bound on large rows (expensive and stack-heavy). - if xs.size > 256 then - max rangeLB midLB - else + let approxLB := max rangeLB midLB + -- Avoid the exact Rat-based bound on large rows (expensive and stack-heavy), + -- but recover it when the fast bounds collapse to zero for medium sizes. + if xs.size ≤ 256 then + let exactLB := fixedVarianceLowerBoundExact cfg xs + max approxLB exactLB + else if approxLB = 0 && xs.size ≤ 1024 then let exactLB := fixedVarianceLowerBoundExact cfg xs - max rangeLB (max midLB exactLB) + max approxLB exactLB + else + approxLB private def fixedLayerNormRowApprox (cfg : Fixed10Cfg) @@ -2669,7 +2675,7 @@ private def certifyHeadValueLowerBoundLocalBinaryAt let keyOffsetNat? 
: Option Nat := if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read + let h ← ExceptT.lift <| IO.FS.Handle.mk path IO.FS.Mode.read let _ ← ExceptT.mk (readBinaryHeader h) let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) @@ -4923,21 +4929,37 @@ private def certifyHeadPatternBestMatchLocalBinary let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack let action : ExceptT String IO HeadBestMatchPatternCert := do + let timingEnabled ← ExceptT.lift <| IO.getEnv "NFP_TIMING" + let timing : Bool := timingEnabled.isSome + let timeIt {α : Type} (label : String) (work : ExceptT String IO α) : + ExceptT String IO α := do + if !timing then + work + else + let t0 ← ExceptT.lift IO.monoNanosNow + let r ← work + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + ExceptT.lift <| IO.eprintln s!"timing:{label} {dtMs}ms" + return r let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← - match shared? with - | some shared => + timeIt "load_shared" <| match shared? 
with + | some shared => do if shared.scalePow10 ≠ scalePow10 then throw "shared scalePow10 mismatch" if shared.inputDelta ≠ inputDelta then throw "shared inputDelta mismatch" pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) - | none => + | none => do let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + timeIt "load_ln_params" <| + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + timeIt "load_embeddings" <| + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← + timeIt "load_tokens" <| ExceptT.mk (loadTokensBinary inputPath) if hdrTok.seqLen ≠ hdr.seqLen then throw "token/embedding seq_len mismatch" pure (hdr, ln1Params, ln2Params, residuals0, tokens) @@ -4986,6 +5008,12 @@ private def certifyHeadPatternBestMatchLocalBinary fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits ln1Rows := ln1Rows.push ln1Out if l = layerIdx then + let tPattern0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let mut wq? : Option (Array Int) := none let mut bq? : Option (Array Int) := none let mut wk? : Option (Array Int) := none @@ -5215,6 +5243,10 @@ private def certifyHeadPatternBestMatchLocalBinary softmaxJacobianNormInfUpperBound := softmaxJacobianUB } if cert.check then + if let some t0 := tPattern0? 
then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:pattern {dtMs}ms" return cert throw "best-match head pattern certificate failed internal consistency checks" else @@ -6692,21 +6724,37 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 let slack : Int := fixedUlpSlack let action : ExceptT String IO InductionHeadBestMatchSoundCert := do + let timingEnabled ← ExceptT.lift <| IO.getEnv "NFP_TIMING" + let timing : Bool := timingEnabled.isSome + let timeIt {α : Type} (label : String) (work : ExceptT String IO α) : + ExceptT String IO α := do + if !timing then + work + else + let t0 ← ExceptT.lift IO.monoNanosNow + let r ← work + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + ExceptT.lift <| IO.eprintln s!"timing:{label} {dtMs}ms" + return r let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← - match shared? with - | some shared => + timeIt "load_shared" <| match shared? 
with + | some shared => do if shared.scalePow10 ≠ scalePow10 then throw "shared scalePow10 mismatch" if shared.inputDelta ≠ inputDelta then throw "shared inputDelta mismatch" pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) - | none => + | none => do let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) + timeIt "load_ln_params" <| + ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) + timeIt "load_embeddings" <| + ExceptT.mk + (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) + let (hdrTok, tokens) ← + timeIt "load_tokens" <| ExceptT.mk (loadTokensBinary inputPath) if hdrTok.seqLen ≠ hdr.seqLen then throw "token/embedding seq_len mismatch" pure (hdr, ln1Params, ln2Params, residuals0, tokens) @@ -6798,13 +6846,11 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (p : LayerNormParamsFixed) : Array (Array Fixed10Interval) := fixedLayerNormRowsApproxExact cfg rows p eps soundnessBits - let calcVOutRows + let calcVOutRowsIntervals (rows : Array (Array Fixed10Interval)) - (wv wo : Array Int) + (wvIntervals woIntervals : Array Fixed10Interval) (bV : Array Fixed10Interval) : Array (Array Fixed10Interval) := - let wvIntervals := intervalsFromScaled wv slack - let woIntervals := intervalsFromScaled wo slack let useTasks := rows.size > 32 if useTasks then Id.run do @@ -6846,15 +6892,15 @@ private def certifyInductionSoundBestMatchLocalBinaryPair hdr.headDim hdr.modelDim woIntervals vHidden out := out.push vOut return out - let calcVOut + let calcVOutIntervals (row : Array Fixed10Interval) - (wv wo : Array Int) + (wvIntervals woIntervals : Array Fixed10Interval) (bV : Array Fixed10Interval) : Array Fixed10Interval := - let vHidden0 := matMulIntervalsFromScaled cfg slack - 
hdr.modelDim hdr.headDim wv row + let vHidden0 := matMulIntervalsFromIntervalsNoTask cfg + hdr.modelDim hdr.headDim wvIntervals row let vHidden := addVecFixed vHidden0 bV - matMulIntervalsFromScaled cfg slack hdr.headDim hdr.modelDim wo vHidden + matMulIntervalsFromIntervalsNoTask cfg hdr.headDim hdr.modelDim woIntervals vHidden let bestMatchPattern (layerIdx headIdx : Nat) (ln1Rows : Array (Array Fixed10Interval)) @@ -7065,12 +7111,12 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let valueLogit (ln1Rows : Array (Array Fixed10Interval)) (matchWeightLowerBound : Rat) - (wv wo : Array Int) + (wvIntervals woIntervals : Array Fixed10Interval) (bV : Array Fixed10Interval) (targetOffset : Int) (keyOffset : Int) : ExceptT String IO HeadValueLogitCert := do - let vOutRows := calcVOutRows ln1Rows wv wo bV + let vOutRows := calcVOutRowsIntervals ln1Rows wvIntervals woIntervals bV let ti : Int := (Int.ofNat queryPos) + targetOffset if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then throw "query position has no valid target offset" @@ -7274,23 +7320,35 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (groupRows? : Option (Array (Array Fixed10Interval))) (attnRows? : Option (Array (Array Fixed10Interval))) (attnUnion? : Option (Array Fixed10Interval)) - (wv wo : Array Int) + (wvIntervals woIntervals : Array Fixed10Interval) (bV : Array Fixed10Interval) : ExceptT String IO (Option (Array (Array Fixed10Interval)) × Option (Array Fixed10Interval)) := do if useTight then if causalPattern then - let vOutRows := calcVOutRows ln1Rows wv wo bV - let headRows := prefixUnionRowsFixed vOutRows + let vOutRows := calcVOutRowsIntervals ln1Rows wvIntervals woIntervals bV match attnRows? with - | some rows => return (some (addRowsFixed rows headRows), attnUnion?) + | some rows => + if rows.size ≠ vOutRows.size then + return (some rows, attnUnion?) + if vOutRows.isEmpty then + return (some rows, attnUnion?) + let mut outRows := rows + let mut acc := vOutRows[0]! 
+ outRows := outRows.set! 0 (addVecFixed rows[0]! acc) + let mut i : Nat := 1 + while i < vOutRows.size do + acc := Fixed10Interval.unionVec acc vOutRows[i]! + outRows := outRows.set! i (addVecFixed rows[i]! acc) + i := i + 1 + return (some outRows, attnUnion?) | none => throw "missing attnRows" else let groupRows ← match groupRows? with | some rows => pure rows | none => throw "missing group rows" - let vOutRows := calcVOutRows groupRows wv wo bV + let vOutRows := calcVOutRowsIntervals groupRows wvIntervals woIntervals bV let vUnion := unionRowsFixed vOutRows match attnUnion? with | some u => return (attnRows?, some (addVecFixed u vUnion)) @@ -7300,7 +7358,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair match ln1Union? with | some row => pure row | none => throw "missing ln1Union" - let vOut := calcVOut ln1Union wv wo bV + let vOut := calcVOutIntervals ln1Union wvIntervals woIntervals bV match attnUnion? with | some u => return (attnRows?, some (addVecFixed u vOut)) | none => throw "missing attnUnion" @@ -7403,6 +7461,12 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let needRows2 := at2 || needUpdate2 let needRowsV := needRows2 let ln1P := ln1Params.getD l defP + let tLn10? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let mut ln1RowsShared? : Option (Array (Array Fixed10Interval)) := none if residualsSame && (needRows1 || needRows2) then ln1RowsShared? := some (calcLnRows residuals1 ln1P) @@ -7426,6 +7490,10 @@ private def certifyInductionSoundBestMatchLocalBinaryPair ln1RowsV? := ln1Rows2? else ln1RowsV? := some (calcLnRows residualsV ln1P) + if let some t0 := tLn10? 
then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:ln1 {dtMs}ms" let tightLayers : Nat := if tightPattern then Nat.max 1 tightPatternLayers else 0 let useTight1 := needUpdate1 && tightLayers > 0 && layer1 ≤ l + tightLayers @@ -7436,6 +7504,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair needUpdate2 && perRowPatternLayers > 0 && layer2 ≤ l + perRowPatternLayers let useTightV := useTight2 let usePerRowV := usePerRow2 + let needTightenNow : Bool := l == layer1 && useTight2 && causalPattern let skipAttnV := useTightV && causalPattern && seqLenEff < hdr.seqLen let shareUpdateV := residualsSameV && needUpdateV && !skipAttnV let shareUpdate := @@ -7505,33 +7574,60 @@ private def certifyInductionSoundBestMatchLocalBinaryPair ln1UnionV? := some (unionRowsFixed ln1RowsV) attnUnionV? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) let needUpdate := needUpdate1 || needUpdate2 + let tHeads0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none + let mut qkReadMs : Nat := 0 + let mut vReadMs : Nat := 0 + let mut addAttnMs : Nat := 0 + let mut tightenMs : Nat := 0 + let mut tightenVOutMs : Nat := 0 + let mut tightenPrefixMs : Nat := 0 + let mut tightenRowMs : Nat := 0 + let mut tightenWaitMs : Nat := 0 for hIdx in [:hdr.numHeads] do let needValue := at2 && hIdx = head2 let needV := needUpdate || needValue let needQK := (at1 && hIdx = head1) || (at2 && hIdx = head2) if needQK then + let tQK0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let wq ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 let bQ ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 let wk ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 let bK ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 + if let some t0 := tQK0? 
then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + qkReadMs := qkReadMs + dtMs let bQIntervals := intervalsFromScaled bQ slack let bKIntervals := intervalsFromScaled bK slack if at1 && hIdx = head1 then let ln1Rows1Exact := ln1Rows1Exact?.getD ln1Rows1 - if needV then + if needV && !needTightenNow then let task ← ExceptT.lift <| IO.asTask - (bestMatchPattern - layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 keyOffset1 - (useTasks := false)).run + (timeIt s!"layer{layer1}:pattern" <| + bestMatchPattern + layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 + keyOffset1 + (useTasks := false)).run p1Task? := some task else let p1 ← - bestMatchPattern - layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 keyOffset1 + timeIt s!"layer{layer1}:pattern" <| + bestMatchPattern + layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 keyOffset1 p1? := some p1 if at2 && hIdx = head2 then let ln1Rows2Exact := ln1Rows2Exact?.getD ln1Rows2 @@ -7539,14 +7635,17 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let task ← ExceptT.lift <| IO.asTask - (bestMatchPattern - layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 keyOffset2 - (useTasks := false)).run + (timeIt s!"layer{layer2}:pattern" <| + bestMatchPattern + layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 + keyOffset2 + (useTasks := false)).run p2Task? := some task else let p2 ← - bestMatchPattern - layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 keyOffset2 + timeIt s!"layer{layer2}:pattern" <| + bestMatchPattern + layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 keyOffset2 p2? 
:= some p2 else let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) @@ -7554,51 +7653,153 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) if needV then + let tV0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let wv ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 let bV ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 let wo ← ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 + if let some t0 := tV0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + vReadMs := vReadMs + dtMs let bVIntervals := intervalsFromScaled bV slack + let wvIntervals := intervalsFromScaled wv slack + let woIntervals := intervalsFromScaled wo slack if needUpdate then if shareUpdate then + let tAdd0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let (attnRows', attnUnion') ← addAttn useTight1 ln1RowsShared ln1UnionShared? groupRowsShared? - attnRowsShared? attnUnionShared? wv wo bVIntervals + attnRowsShared? attnUnionShared? wvIntervals woIntervals bVIntervals + if let some t0 := tAdd0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + addAttnMs := addAttnMs + dtMs attnRowsShared? := attnRows' attnUnionShared? := attnUnion' else if needUpdate1 then + let tAdd0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let (attnRows', attnUnion') ← addAttn useTight1 ln1Rows1 ln1Union1? groupRows1? - attnRows1? attnUnion1? wv wo bVIntervals + attnRows1? attnUnion1? wvIntervals woIntervals bVIntervals + if let some t0 := tAdd0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + addAttnMs := addAttnMs + dtMs attnRows1? := attnRows' attnUnion1? 
:= attnUnion' if needUpdate2 then if l == layer1 && hIdx == head1 && useTight2 && causalPattern then + let tTight0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none + let tWait0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let p1 ← awaitPattern p1? p1Task? "missing best-match pattern cert for tightening" + if let some t0 := tWait0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + tightenWaitMs := tightenWaitMs + dtMs p1? := some p1 - let vOutRows := calcVOutRows ln1Rows2 wv wo bVIntervals + let tVOut0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none + let vOutRows := calcVOutRowsIntervals ln1Rows2 wvIntervals woIntervals bVIntervals + if let some t0 := tVOut0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + tightenVOutMs := tightenVOutMs + dtMs + let tPrefix0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let mut headRows := prefixUnionRowsFixed vOutRows + if let some t0 := tPrefix0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + tightenPrefixMs := tightenPrefixMs + dtMs let baseRow := headRows[queryPos]! + let tRow0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let tightRow ← tightenQueryRowLower baseRow vOutRows p1.bestMatchWeightLowerBound offset1 keyOffset1 + if let some t0 := tRow0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + tightenRowMs := tightenRowMs + dtMs headRows := headRows.set! queryPos tightRow + if let some t0 := tTight0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + tightenMs := tightenMs + dtMs match attnRows2? with | some rows => attnRows2? := some (addRowsFixed rows headRows) | none => throw "missing attnRows" else + let tAdd0? 
← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let (attnRows', attnUnion') ← addAttn useTight2 ln1Rows2 ln1Union2? groupRows2? - attnRows2? attnUnion2? wv wo bVIntervals + attnRows2? attnUnion2? wvIntervals woIntervals bVIntervals + if let some t0 := tAdd0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + addAttnMs := addAttnMs + dtMs attnRows2? := attnRows' attnUnion2? := attnUnion' if needUpdateV && !shareUpdateV && !skipAttnV then + let tAdd0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let (attnRows', attnUnion') ← addAttn useTightV ln1RowsV ln1UnionV? groupRowsV? - attnRowsV? attnUnionV? wv wo bVIntervals + attnRowsV? attnUnionV? wvIntervals woIntervals bVIntervals + if let some t0 := tAdd0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + addAttnMs := addAttnMs + dtMs attnRowsV? := attnRows' attnUnionV? := attnUnion' if needValue then @@ -7606,13 +7807,27 @@ private def certifyInductionSoundBestMatchLocalBinaryPair awaitPattern p2? p2Task? "missing best-match pattern cert for value bound" p2? := some p2 let vlogit ← - valueLogit ln1RowsV p2.bestMatchWeightLowerBound wv wo bVIntervals offset2 - keyOffset2 + timeIt s!"layer{layer2}:value_logit" <| + valueLogit ln1RowsV p2.bestMatchWeightLowerBound wvIntervals woIntervals + bVIntervals offset2 keyOffset2 vlogit? := some vlogit else let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) let _ ← ExceptT.mk (skipF64Array h hdr.headDim) let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) + if let some t0 := tHeads0? 
then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:heads {dtMs}ms" + if timing then + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:qk_read {qkReadMs}ms" + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:v_read {vReadMs}ms" + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:add_attn {addAttnMs}ms" + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten {tightenMs}ms" + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten_vout {tightenVOutMs}ms" + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten_prefix {tightenPrefixMs}ms" + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten_row {tightenRowMs}ms" + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten_wait {tightenWaitMs}ms" if p1?.isSome && p2?.isSome && vlogit?.isSome && !(needUpdate1 || needUpdate2) then match p1?, p2?, vlogit? with | some p1, some p2, some vlogit => @@ -7628,6 +7843,12 @@ private def certifyInductionSoundBestMatchLocalBinaryPair throw "induction head certificate failed internal consistency checks" | _, _, _ => throw "induction head certificate failed internal consistency checks" if needUpdate1 || needUpdate2 then + let tAttn0? ← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) if shareUpdate then residuals1 ← applyAttn residuals1 useTight1 attnRowsShared? attnUnionShared? attnBias @@ -7639,6 +7860,16 @@ private def certifyInductionSoundBestMatchLocalBinaryPair residuals2 ← applyAttn residuals2 useTight2 attnRows2? attnUnion2? attnBias if needUpdateV && !shareUpdateV && !skipAttnV then residualsV ← applyAttn residualsV useTightV attnRowsV? attnUnionV? attnBias + if let some t0 := tAttn0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:attn_update {dtMs}ms" + let tMlp0? 
← + if timing then + let t0 ← ExceptT.lift IO.monoNanosNow + pure (some t0) + else + pure none let wIn ← ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) @@ -7668,6 +7899,10 @@ private def certifyInductionSoundBestMatchLocalBinaryPair residualsSameV := true else residualsSameV := false + if let some t0 := tMlp0? then + let t1 ← ExceptT.lift IO.monoNanosNow + let dtMs := (t1 - t0) / 1000000 + ExceptT.lift <| IO.eprintln s!"timing:layer{l}:mlp_update {dtMs}ms" let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) From 4193e190fdcefe99d293e8a96879aa872cb4e3d8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 22:04:36 +0100 Subject: [PATCH 059/244] Update documentation and CLI docstrings --- CLAIMS.md | 7 ++- Main.lean | 15 ++++++ Nfp/Discovery.lean | 12 +++-- Nfp/IO.lean | 5 ++ README.md | 113 +++++++++++++++++++++++++++++++++++++-- SOUNDNESS_LIMITATIONS.md | 12 +++-- 6 files changed, 150 insertions(+), 14 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index bef758a..613ab25 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -6,7 +6,10 @@ what is heuristic, and what is not yet proven. | Claim | Status | Where | | --- | --- | --- | | Definitions of mixers/signed mixers and linearizations (ReLU, GeLU, LayerNorm, softmax) with basic lemmas (composition, diagonality, etc.) 
| Proven in Lean | `Nfp/SignedMixer.lean`, `Nfp/Linearization.lean` | -| The sound certificate checker validates internal arithmetic consistency for layer bounds and the total amplification factor | Soundly checked (Lean) | `Nfp/Sound/Cert.lean` | -| Sound bound formulas use exact `Rat` arithmetic (LayerNorm/softmax/GeLU envelopes); witness values are produced in untrusted code and then checked | Soundly checked formulas; untrusted witnesses | `Nfp/Sound/Bounds.lean`, `Nfp/Untrusted/SoundCompute.lean` | +| Model-level SOUND certificate checker validates internal arithmetic consistency and recomputes weight-derived bounds from model files | Soundly checked (Lean) | `Nfp/Sound/Cert.lean`, `Nfp/Sound/IO.lean`, `Nfp/Sound/BinaryPure.lean`, `Nfp/Sound/TextPure.lean` | +| Per-head contribution, head-pattern, and induction-head certificates (including best-match variants) have internal consistency checks | Soundly checked (Lean) | `Nfp/Sound/HeadCert.lean`, `Nfp/Sound/IO.lean` | +| Sound bound formulas use exact `Rat` arithmetic (LayerNorm/softmax/GeLU envelopes); witness values are produced in untrusted code and then checked | Soundly checked formulas; untrusted witnesses | `Nfp/Sound/Bounds.lean`, `Nfp/Untrusted/SoundCompute.lean`, `Nfp/Untrusted/SoundBinary.lean` | +| Best-match margin tightening uses untrusted logit bounds; verification checks only internal margin/softmax consistency | Partially checked (internal consistency only) | `Nfp/Sound/HeadCert.lean`, `Nfp/Sound/IO.lean`, `Nfp/Untrusted/SoundCompute.lean` | | Heuristic discovery and ranking of induction-style candidates | Heuristic | `Nfp/Discovery.lean`, CLI `induction` | +| Empirical causal verification via head ablation (competence/control/energy checks) | Heuristic | `Nfp/Verification.lean`, CLI `analyze --verify` / `induction --verify` | | End-to-end statement that certificate validity implies `||layerJacobian - I|| <= C` for Lean-defined Jacobians | Not yet proven | See `SOUNDNESS_LIMITATIONS.md` | diff 
--git a/Main.lean b/Main.lean index 9d5f60e..9fdf738 100644 --- a/Main.lean +++ b/Main.lean @@ -33,6 +33,9 @@ lake exe nfp analyze model.nfpt --threshold 0.1 --output report.txt # Search for induction heads with diagnostics enabled lake exe nfp induction model.nfpt --diagnostics --diagTop 5 --adaptive +# Microbenchmarks for analysis or induction +lake exe nfp bench model.nfpt --mode analysis --runs 5 + # Generate a sound-mode certificate report lake exe nfp certify model.nfpt @@ -54,6 +57,18 @@ lake exe nfp induction_cert model.nfpt --layer1 0 --head1 0 --layer2 1 --head2 0 # Instantiate RoPE bounds for a specific shape lake exe nfp rope --seqLen 4 --pairs 8 +# Check SOUND cache soundness (CI/fixtures) +lake exe nfp sound_cache_check model.nfpt + +# Benchmark SOUND cache build +lake exe nfp sound_cache_bench model.nfpt --runs 3 + +# Dump a small forward-pass slice +lake exe nfp dump model.nfpt --layer 0 --pos 0 --kind afterLayer + +# Empirical logit-diff check +lake exe nfp logit_diff model.nfpt 42 17 --autoNegative + # Show version lake exe nfp --version ``` diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean index c49bb56..5d48e25 100644 --- a/Nfp/Discovery.lean +++ b/Nfp/Discovery.lean @@ -6,9 +6,12 @@ import Init.Data.Array.Extract /-! # Executable Circuit Discovery for Induction Heads -This module provides executable functions for discovering **certified induction heads** +This module provides executable functions for discovering **candidate induction heads** from concrete model weights. It bridges the theoretical framework (Frobenius norms, -pattern terms, faithfulness bounds) with practical verification of real neural networks. +pattern terms, faithfulness bounds) with practical, Float-based analysis of real networks. + +Important: these routines are **heuristic** and are not kernel-sound. Sound certification +lives in `Nfp.Sound.*`. 
## Key Components @@ -19,7 +22,7 @@ pattern terms, faithfulness bounds) with practical verification of real neural n bounds without materializing the full (N·D)² Jacobian matrix. 3. **Discovery Functions**: Search algorithms that iterate over layer pairs to find - certified virtual heads (e.g., induction heads). + candidate virtual heads (e.g., induction heads). ## Mathematical Background @@ -29,7 +32,8 @@ decomposes as: `fullJacobian = valueTerm + patternTerm` where: - `patternTerm` captures how A shifts when input changes (the error term) The **faithfulness bound** states: if ‖patternTerm‖_F ≤ ε, then the simple -attention-based interpretation is ε-accurate. +attention-based interpretation is ε-accurate. In this module those bounds are +computed with `Float`s for speed and should be treated as diagnostics. ## Performance Optimizations diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 625ac46..209329d 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -30,9 +30,14 @@ head_dim=64 hidden_dim=3072 vocab_size=50257 seq_len=1024 +layer_norm_eps=1e-5 +gelu_kind=tanh BINARY_START ``` +`layer_norm_eps` (or legacy `eps`) and `gelu_kind` (or legacy `gelu_deriv`) are required by the +SOUND certification path but are otherwise ignored by this loader. + Binary payload (little-endian, row-major, no markers): 1. TOKENS: `seq_len` × Int32 2. EMBEDDINGS: `seq_len` × `model_dim` × Float64 diff --git a/README.md b/README.md index 2944cb7..26676ba 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,8 @@ The Lean library defines the core math objects (finite probability, mixers, line At present, the checker does **not** include a bridge theorem that connects certificate validity to the Lean-defined Jacobian bounds (for example, a theorem of the form `||layerJacobian - I|| <= C`). Treat sound certificates as **internally consistent bound reports**, not as a fully formal end-to-end verification of transformer Jacobians. 
+Margin-based softmax tightening exists, but only **best-match margin evidence** is accepted today. Direct `--softmaxMargin` is rejected by the checker, and best-match logit bounds are generated in untrusted code and only checked for internal consistency. + For known gaps and ongoing upgrades, see `SOUNDNESS_LIMITATIONS.md`. ## North Star @@ -119,6 +121,8 @@ lake exe nfp --help ## Models The CLI expects a model file in **`.nfpt`** format (NFP_BINARY_V1). +Most commands (analysis/induction/diagnostics) require `NFP_BINARY_V1`; legacy `NFP_TEXT_V1/V2` +is supported only for local SOUND certification. - Create a local `models/` directory and place your `.nfpt` files there (the repo does not version model files; the author’s setup may have used local symlinks). - You can export GPT-2 weights from Hugging Face using the scripts in `scripts/`. @@ -134,13 +138,18 @@ head_dim=... hidden_dim=... vocab_size=... seq_len=... +layer_norm_eps=... +gelu_kind=... BINARY_START ``` The payload is raw little-endian bytes in a fixed order (tokens, embeddings, then weights). -Note: global sound certification supports `NFP_BINARY_V1`. Local sound certification -supports `NFP_BINARY_V1` (fixed-point union-box) and legacy `NFP_TEXT_V1/V2`. +Notes: +- `layer_norm_eps` (or legacy `eps`) and `gelu_kind` (or legacy `gelu_deriv`) are required for + SOUND certification. +- Global sound certification supports `NFP_BINARY_V1`. Local sound certification supports + `NFP_BINARY_V1` (fixed-point union-box) and legacy `NFP_TEXT_V1/V2`. ### Exporting GPT-2 to `.nfpt` @@ -278,8 +287,21 @@ If you want to override the embedded input, pass a separate input `.nfpt`: - LayerNorm ε is read from the model header (`layer_norm_eps`). - `gelu_kind` in the model header selects the GeLU derivative target (`tanh` or `exact`). - `--delta` sets the local ℓ∞ radius `δ` (default: `0`). Providing `--delta` enables local certification. 
+- `--input` optionally provides an input `.nfpt` file used for local certification; if omitted and the + model file embeds `EMBEDDINGS`, `certify` reuses the model file as its input source. +- `--softmaxMargin` provides a logit-margin lower bound, but it is currently **rejected** by the + verifier (use `--bestMatchMargins` instead). +- `--softmaxExpEffort` controls exp lower-bound effort used for margin-based softmax tightening (default: `1`). +- `--bestMatchMargins` runs a full best-match sweep (binary + local only) and tightens layer + softmax bounds using verified margin evidence. It is incompatible with `--softmaxMargin`. +- `--targetOffset` selects the target-token offset for best-match margins (default: `-1`). +- `--maxSeqLen` caps the sequence length used in best-match margin sweeps (default: `0` = full `seq_len`). +- `--tightPattern`, `--tightPatternLayers`, and `--perRowPatternLayers` control pattern tightening + during best-match sweeps. +- `--scalePow10` sets fixed-point scaling for best-match sweeps (default: `9`). +- `--noncausalPattern` disables the causal-prefix restriction (required for non-causal models). +- `--soundnessBits` sets dyadic sqrt precision for LayerNorm bounds (default: `20`). - `--partitionDepth` requests input partitioning depth (default: `0`; scaffold only, must remain `0` for now). -- `--input` optionally provides an input `.nfpt` file used for local certification. - `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). ### `head_bounds` @@ -298,6 +320,7 @@ lake exe nfp head_bounds models/gpt2_rigorous.nfpt --delta 0.01 - `--delta` enables local head bounds; `--input` can override the embedded input. - LayerNorm ε is read from the model header (`layer_norm_eps`). +- `--soundnessBits` sets dyadic sqrt precision for LayerNorm bounds (default: `20`). - `--scalePow10` controls fixed-point scaling for global bounds (default: `9`). - `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). 
@@ -316,13 +339,20 @@ lake exe nfp head_pattern models/gpt2_rigorous.nfpt --layer 0 --head 0 --delta 0 - `--offset` selects the target key position relative to the query (default: `-1` for previous token). - `--keyOffset` selects which key-position token is matched (default: `0` for the key token itself). - `--maxSeqLen` caps the sequence length analyzed for pattern bounds (default: `256`). +- `--input` optionally provides an input `.nfpt` file; required for legacy text models. - `--delta` sets the local input radius; LayerNorm ε is read from the model header (`layer_norm_eps`). +- `--soundnessBits` sets dyadic sqrt precision for LayerNorm bounds (default: `20`). - `--tightPattern` enables a slower but tighter pattern bound near the target layer. - `--tightPatternLayers` sets how many layers use tight bounds (default: `1`; implies `--tightPattern`). - `--perRowPatternLayers` sets how many layers use per-row MLP propagation (default: `0`). +- `--softmaxExpEffort` sets the exp lower-bound effort for margin-derived softmax bounds (default: `1`). +- `--scalePow10` sets fixed-point scaling for best-match bounds (default: `9`). +- `--noncausalPattern` disables the causal-prefix restriction (required for non-causal models). - `--bestMatch` switches to a single-query best-match bound (default query: last position). +- `--affine` uses affine Q/K dot bounds in best-match mode. - `--sweep` prints best-match bounds for all valid query positions (requires `--bestMatch`). - `--queryPos` chooses the query position for best-match bounds (default: last position). +- `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). ### `induction_cert` @@ -342,12 +372,21 @@ lake exe nfp induction_cert models/gpt2_rigorous.nfpt \ - `--keyOffset1/--keyOffset2` adjust the key-token offsets (default: `0`; use `--offset2 0 --keyOffset2 -1` for copy-next induction). - `--target/--negative` optionally add a logit-diff lower bound using unembedding columns. 
+- `--input` optionally provides an input `.nfpt` file; required for legacy text models. +- `--delta` sets the local input radius (default: `0`). +- `--soundnessBits` sets dyadic sqrt precision for LayerNorm bounds (default: `20`). - `--tightPattern` enables a slower but tighter pattern bound near the target layer. - `--tightPatternLayers` sets how many layers use tight bounds (default: `1`; implies `--tightPattern`). - `--perRowPatternLayers` sets how many layers use per-row MLP propagation (default: `0`). +- `--softmaxExpEffort` sets the exp lower-bound effort for margin-derived softmax bounds (default: `1`). +- `--maxSeqLen` caps the sequence length analyzed for best-match bounds (default: `256`). +- `--scalePow10` sets fixed-point scaling for best-match bounds (default: `9`). +- `--noncausalPattern` disables the causal-prefix restriction (required for non-causal models). - `--bestMatch` switches to single-query best-match bounds (default query: last position). +- `--affine` uses affine Q/K dot bounds in best-match mode. - `--queryPos` chooses the query position for best-match bounds (default: last position). - `--iterTighten` iteratively tightens best-match bounds (tight/per-row layers and scale precision). +- `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). ### `rope` @@ -360,6 +399,74 @@ lake exe nfp rope --seqLen 4 --pairs 8 - `--seqLen` instantiates the bound at the given sequence length (default: `4`). - `--pairs` sets the number of RoPE pairs; the dimension is `2 * pairs` (default: `8`). +### `bench` + +Runs repeatable microbenchmarks for analysis or induction search. + +```bash +lake exe nfp bench models/gpt2_rigorous.nfpt --mode analysis --runs 5 --repeats 1 +``` + +- `--mode` selects `analysis` or `induction` (default: `analysis`). +- `--runs` sets the number of timed runs (default: `5`). +- `--repeats` repeats the inner workload per run (default: `1`). +- `--threshold` sets the analyze threshold (default: `0.1`). 
+- `--minEffect` sets the induction minEffect (default: `0.0`). +- `--correct/--incorrect` override induction target tokens. +- `--verbose` prints per-run timing details. +- `--breakdown` emits per-phase averages (analysis only). + +### `sound_cache_check` + +Checks SOUND fixed-point cache soundness (CI / small fixtures). + +```bash +lake exe nfp sound_cache_check tests/fixtures/tiny_sound_binary.nfpt +``` + +- `--scalePow10` sets the fixed-point scale exponent (default: `9`). +- `--maxTokens` checks at most this many numeric tokens (default: `0` = all). + +### `sound_cache_bench` + +Benchmarks SOUND fixed-point cache build (text or binary). + +```bash +lake exe nfp sound_cache_bench models/gpt2_rigorous.nfpt --runs 3 +``` + +- `--scalePow10` sets the fixed-point scale exponent (default: `9`). +- `--runs` sets the number of benchmark runs (default: `1`). + +### `dump` + +Dumps a small forward-pass slice for PyTorch sanity checking. + +```bash +lake exe nfp dump models/gpt2_rigorous.nfpt --layer 0 --pos 0 --kind afterLayer +``` + +- `--layer` selects the layer index (default: `0`). +- `--pos` selects the token position / row index (default: `0`). +- `--take` limits columns from the start (default: `16`). +- `--kind` chooses `embeddings | layerInput | postAttn | afterLayer` (default: `afterLayer`). + +### `logit_diff` + +Computes an empirical logit-difference for a target vs. negative token. + +```bash +lake exe nfp logit_diff models/gpt2_rigorous.nfpt 42 17 --autoNegative +``` + +- `--pos` selects the token position (default: last position). +- `--input` provides an input `.nfpt` with TOKENS + EMBEDDINGS. +- `--autoNegative` uses the top non-target logit as the negative token. + +### `--version` + +Prints the CLI version string. 
+ ## What “rigorous” means here At a high level, the “rigorous” path avoids heuristic operator-norm estimation and instead uses **upper bounds** derived from standard inequalities (examples you may see in logs): diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 3700e24..46840a1 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -11,8 +11,10 @@ soundness upgrade. It is intentionally brief and human-readable. but it still treats softmax probability or margin evidence as external and does not derive those bounds from logits. - `partitionDepth > 0` is rejected with an explicit error (no partitioning logic yet). -- Affine arithmetic is only a scaffold (`Nfp/Sound/Affine.lean`) and not wired into SOUND certification. -- Softmax Jacobian bounds in the standard `certify` path now derive a probability interval from a +- Affine arithmetic is available via `--affine` for best-match Q/K dot bounds, but those dot-bound + computations are untrusted witness generation; the checker only validates the downstream + margin-to-probability derivations. +- Softmax Jacobian bounds in the standard `certify` path derive a probability interval from a global attention-score magnitude bound (LN1 max-abs + W_Q/W_K norms), but it is typically very loose and often collapses to `[0,1]`. Direct `--softmaxMargin` is still rejected because margin evidence is unverified. @@ -25,9 +27,9 @@ soundness upgrade. It is intentionally brief and human-readable. `softmaxExpEffort` chosen by iterative exp-portfolio tightening (early stop on low relative improvement). The verifier accepts any per-head effort ≤ the requested cap, but model-level certification still requires `--bestMatchMargins`. -- Best-match pattern certificates use a margin-derived softmax Jacobian bound with an - effort-indexed `expLB` (scaled Taylor + squaring). The lower-bound correctness of `expLB` - is not yet formalized in Lean. 
+- Best-match pattern certificates rely on untrusted interval/affine logit bounds to produce a + margin, and then use a margin-derived softmax Jacobian bound with an effort-indexed `expLB` + (scaled Taylor + squaring). The lower-bound correctness of `expLB` is not yet formalized in Lean. - GeLU derivative bounds are conservative envelopes; the exact interval supremum is not computed yet. - Attention Jacobian bounds now include an explicit pattern-term coefficient using max `W_Q/W_K` row-sum norms and a conservative LayerNorm output magnitude bound (`max|gamma|*sqrt(d)+max|beta|`), From 2da87eb4894db8b0b5e4cfe7acbe89770127920f Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 22:11:14 +0100 Subject: [PATCH 060/244] Clarify untrusted sound compute docstring --- Nfp/Untrusted/SoundCompute.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index a8634fa..643c680 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -6851,7 +6851,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (wvIntervals woIntervals : Array Fixed10Interval) (bV : Array Fixed10Interval) : Array (Array Fixed10Interval) := - let useTasks := rows.size > 32 + let useTasks := rows.size > 256 if useTasks then Id.run do let chunkSize : Nat := 16 From 1cafbfc6b85b3267443c22c6cc72b942c977f9ca Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 1 Jan 2026 22:30:18 +0100 Subject: [PATCH 061/244] Speed up pattern tasks and matmul pipeline --- Nfp/Untrusted/SoundCompute.lean | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 643c680..13c7fc2 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -6851,7 +6851,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (wvIntervals woIntervals : Array Fixed10Interval) (bV : Array 
Fixed10Interval) : Array (Array Fixed10Interval) := - let useTasks := rows.size > 256 + let useTasks := rows.size > 32 if useTasks then Id.run do let chunkSize : Nat := 16 @@ -6934,7 +6934,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let qRadii := addVecScaledInt qRadii0 bQRadii 1 let useTasksHere := useTasks && seqLenEff > 32 if useTasksHere then - let chunkSize : Nat := 16 + let chunkSize : Nat := 32 let numChunks : Nat := (seqLenEff + chunkSize - 1) / chunkSize let mut tasks : Array (Task (Option Int × Option Int)) := Array.mkEmpty numChunks let mut chunkIdx : Nat := 0 @@ -7504,6 +7504,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair needUpdate2 && perRowPatternLayers > 0 && layer2 ≤ l + perRowPatternLayers let useTightV := useTight2 let usePerRowV := usePerRow2 + let usePatternTasks : Bool := perRowPatternLayers = 0 let needTightenNow : Bool := l == layer1 && useTight2 && causalPattern let skipAttnV := useTightV && causalPattern && seqLenEff < hdr.seqLen let shareUpdateV := residualsSameV && needUpdateV && !skipAttnV @@ -7621,7 +7622,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair bestMatchPattern layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 keyOffset1 - (useTasks := false)).run + (useTasks := usePatternTasks)).run p1Task? := some task else let p1 ← @@ -7639,7 +7640,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair bestMatchPattern layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 keyOffset2 - (useTasks := false)).run + (useTasks := usePatternTasks)).run p2Task? 
:= some task else let p2 ← From b023dc26623423827123855216e7928a72ac7ef3 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 11:49:15 +0100 Subject: [PATCH 062/244] Add margin debug logging to sound compute --- Nfp/Untrusted/SoundCompute.lean | 280 ++++++++++++++++++++++++++------ 1 file changed, 226 insertions(+), 54 deletions(-) diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean index 13c7fc2..1bcbdb8 100644 --- a/Nfp/Untrusted/SoundCompute.lean +++ b/Nfp/Untrusted/SoundCompute.lean @@ -6726,6 +6726,8 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let action : ExceptT String IO InductionHeadBestMatchSoundCert := do let timingEnabled ← ExceptT.lift <| IO.getEnv "NFP_TIMING" let timing : Bool := timingEnabled.isSome + let debugMarginEnabled ← ExceptT.lift <| IO.getEnv "NFP_MARGIN_DEBUG" + let debugMargin : Bool := debugMarginEnabled.isSome let timeIt {α : Type} (label : String) (work : ExceptT String IO α) : ExceptT String IO α := do if !timing then @@ -6846,6 +6848,45 @@ private def certifyInductionSoundBestMatchLocalBinaryPair (p : LayerNormParamsFixed) : Array (Array Fixed10Interval) := fixedLayerNormRowsApproxExact cfg rows p eps soundnessBits + let logRowsWidth (tag label : String) (rows : Array (Array Fixed10Interval)) : + ExceptT String IO Unit := do + if debugMargin then + if rows.isEmpty then + ExceptT.lift <| IO.eprintln s!"{tag}:{label} empty" + else + let mut maxW : Rat := 0 + for row in rows do + let w := centeredAbsSumFixed cfg row + if w > maxW then + maxW := w + let qW := + if queryPos < rows.size then centeredAbsSumFixed cfg rows[queryPos]! 
else 0 + ExceptT.lift <| + IO.eprintln s!"{tag}:{label} rows={rows.size} queryWidth={qW} maxWidth={maxW}" + let logVecWidth (tag label : String) (row : Array Fixed10Interval) : + ExceptT String IO Unit := do + if debugMargin then + let w := centeredAbsSumFixed cfg row + ExceptT.lift <| IO.eprintln s!"{tag}:{label} width={w}" + let calcVOutRowsIntervalsNoTask + (cfg : Fixed10Cfg) + (modelDim headDim : Nat) + (wvIntervals woIntervals : Array Fixed10Interval) + (bV : Array Fixed10Interval) + (rows : Array (Array Fixed10Interval)) + (start stop : Nat) : + Array (Array Fixed10Interval) := + Id.run do + let mut out : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + let vHidden0 := matMulIntervalsFromIntervalsNoTask cfg modelDim headDim wvIntervals + (rows[i]!) + let vHidden := addVecFixed vHidden0 bV + let vOut := matMulIntervalsFromIntervalsNoTask cfg headDim modelDim woIntervals vHidden + out := out.push vOut + i := i + 1 + return out let calcVOutRowsIntervals (rows : Array (Array Fixed10Interval)) (wvIntervals woIntervals : Array Fixed10Interval) @@ -6863,18 +6904,8 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let stop := min rows.size (start + chunkSize) tasks := tasks.push <| Task.spawn (fun _ => - Id.run do - let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) - let mut i := start - while i < stop do - let vHidden0 := matMulIntervalsFromIntervalsNoTask cfg - hdr.modelDim hdr.headDim wvIntervals (rows[i]!) 
- let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromIntervalsNoTask cfg - hdr.headDim hdr.modelDim woIntervals vHidden - outChunk := outChunk.push vOut - i := i + 1 - return outChunk) + calcVOutRowsIntervalsNoTask cfg hdr.modelDim hdr.headDim wvIntervals + woIntervals bV rows start stop) chunkIdx := chunkIdx + 1 let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size for t in tasks do @@ -6882,16 +6913,8 @@ private def certifyInductionSoundBestMatchLocalBinaryPair out := out.push row return out else - Id.run do - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for row in rows do - let vHidden0 := matMulIntervalsFromIntervalsNoTask cfg - hdr.modelDim hdr.headDim wvIntervals row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromIntervalsNoTask cfg - hdr.headDim hdr.modelDim woIntervals vHidden - out := out.push vOut - return out + calcVOutRowsIntervalsNoTask cfg hdr.modelDim hdr.headDim wvIntervals woIntervals bV rows 0 + rows.size let calcVOutIntervals (row : Array Fixed10Interval) (wvIntervals woIntervals : Array Fixed10Interval) @@ -6920,6 +6943,14 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let keyOffsetNeg : Nat := Int.toNat (-keyOffset) let mut bestMatchLower? : Option Int := none let mut bestNonmatchUpper? : Option Int := none + let mut matchCount : Nat := 0 + let mut nonmatchCount : Nat := 0 + let mut matchLowerMin? : Option Int := none + let mut matchUpperMax? : Option Int := none + let mut nonmatchLowerMax? : Option Int := none + let mut nonmatchUpperMax? 
: Option Int := none + let mut matchWidthMax : Int := 0 + let mut nonmatchWidthMax : Int := 0 if useAffine then let bQCenters := bQ.map (fun x => (x.lo + x.hi).ediv (Int.ofNat 2)) let bKCenters := bK.map (fun x => (x.lo + x.hi).ediv (Int.ofNat 2)) @@ -6932,7 +6963,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair hdr.modelDim hdr.headDim wq qInputCenters qInputRadii let qCenters := addVecScaledInt qCenters0 bQCenters 1 let qRadii := addVecScaledInt qRadii0 bQRadii 1 - let useTasksHere := useTasks && seqLenEff > 32 + let useTasksHere := useTasks && !debugMargin && seqLenEff > 32 if useTasksHere then let chunkSize : Nat := 32 let numChunks : Nat := (seqLenEff + chunkSize - 1) / chunkSize @@ -7019,11 +7050,37 @@ private def certifyInductionSoundBestMatchLocalBinaryPair else tokens[j - keyOffsetNeg]! = targetTok if isMatch then + if debugMargin then + matchCount := matchCount + 1 + matchLowerMin? := + match matchLowerMin? with + | none => some dot.lo + | some v => some (min v dot.lo) + matchUpperMax? := + match matchUpperMax? with + | none => some dot.hi + | some v => some (max v dot.hi) + let width := dot.hi - dot.lo + if width > matchWidthMax then + matchWidthMax := width bestMatchLower? := match bestMatchLower? with | none => some dot.lo | some m => some (max m dot.lo) else + if debugMargin then + nonmatchCount := nonmatchCount + 1 + nonmatchLowerMax? := + match nonmatchLowerMax? with + | none => some dot.lo + | some v => some (max v dot.lo) + nonmatchUpperMax? := + match nonmatchUpperMax? with + | none => some dot.hi + | some v => some (max v dot.hi) + let width := dot.hi - dot.lo + if width > nonmatchWidthMax then + nonmatchWidthMax := width bestNonmatchUpper? := match bestNonmatchUpper? with | none => some dot.hi @@ -7035,7 +7092,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) 
let qRow := addVecFixed qRow0 bQ let kRows := - let useTasksHere := useTasks && ln1Rows.size > 32 + let useTasksHere := useTasks && !debugMargin && ln1Rows.size > 32 if useTasksHere then let tasks := ln1Rows.map (fun row => Task.spawn (fun _ => @@ -7065,11 +7122,37 @@ private def certifyInductionSoundBestMatchLocalBinaryPair else tokens[j - keyOffsetNeg]! = targetTok if isMatch then + if debugMargin then + matchCount := matchCount + 1 + matchLowerMin? := + match matchLowerMin? with + | none => some dot.lo + | some v => some (min v dot.lo) + matchUpperMax? := + match matchUpperMax? with + | none => some dot.hi + | some v => some (max v dot.hi) + let width := dot.hi - dot.lo + if width > matchWidthMax then + matchWidthMax := width bestMatchLower? := match bestMatchLower? with | none => some dot.lo | some m => some (max m dot.lo) else + if debugMargin then + nonmatchCount := nonmatchCount + 1 + nonmatchLowerMax? := + match nonmatchLowerMax? with + | none => some dot.lo + | some v => some (max v dot.lo) + nonmatchUpperMax? := + match nonmatchUpperMax? with + | none => some dot.hi + | some v => some (max v dot.hi) + let width := dot.hi - dot.lo + if width > nonmatchWidthMax then + nonmatchWidthMax := width bestNonmatchUpper? := match bestNonmatchUpper? 
with | none => some dot.hi @@ -7085,6 +7168,21 @@ private def certifyInductionSoundBestMatchLocalBinaryPair | none => bestMatchLower | some v => v let marginInt : Int := bestMatchLower - bestNonmatchUpper + if debugMargin then + let matchLowerMin := matchLowerMin?.getD bestMatchLower + let matchUpperMax := matchUpperMax?.getD bestMatchLower + let nonmatchLowerMax := nonmatchLowerMax?.getD bestNonmatchUpper + let nonmatchUpperMax := nonmatchUpperMax?.getD bestNonmatchUpper + let msg := + s!"pattern_debug:layer{layerIdx}:head{headIdx} " ++ + s!"targetTok={targetTok} queryPos={queryPos} offset={targetOffset} " ++ + s!"keyOffset={keyOffset} scalePow10={scalePow10} " ++ + s!"matches={matchCount} nonmatches={nonmatchCount} " ++ + s!"matchLoMaxInt={bestMatchLower} matchLoMinInt={matchLowerMin} " ++ + s!"matchHiMaxInt={matchUpperMax} nonmatchLoMaxInt={nonmatchLowerMax} " ++ + s!"nonmatchHiMaxInt={nonmatchUpperMax} marginInt={marginInt} " ++ + s!"matchWidthMaxInt={matchWidthMax} nonmatchWidthMaxInt={nonmatchWidthMax}" + ExceptT.lift <| IO.eprintln msg let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper let margin := ratOfScaledInt scalePow10 marginInt @@ -7363,6 +7461,7 @@ private def certifyInductionSoundBestMatchLocalBinaryPair | some u => return (attnRows?, some (addVecFixed u vOut)) | none => throw "missing attnUnion" let applyAttn + (label : String) (rows : Array (Array Fixed10Interval)) (useTight : Bool) (attnRows? : Option (Array (Array Fixed10Interval))) @@ -7373,22 +7472,30 @@ private def certifyInductionSoundBestMatchLocalBinaryPair match attnRows? with | some attnRows => let attnRows := addVecFixedRows attnRows attnBias - return addRowsFixed rows attnRows + logRowsWidth "attn_debug" s!"{label}:attn_rows" attnRows + let out := addRowsFixed rows attnRows + logRowsWidth "attn_debug" s!"{label}:out" out + return out | none => throw "missing attnRows" else match attnUnion? 
with | some attnUnion => let attnUnion := addVecFixed attnUnion attnBias - return addVecFixedRows rows attnUnion + logVecWidth "attn_debug" s!"{label}:attn_union" attnUnion + let out := addVecFixedRows rows attnUnion + logRowsWidth "attn_debug" s!"{label}:out" out + return out | none => throw "missing attnUnion" let applyMlp + (label : String) (rows : Array (Array Fixed10Interval)) (usePerRow : Bool) (p : LayerNormParamsFixed) (wIn wOut : Array Int) (bIn bOut : Array Fixed10Interval) : - Array (Array Fixed10Interval) := + ExceptT String IO (Array (Array Fixed10Interval)) := do let ln2Rows := calcLnRows rows p + logRowsWidth "ln2_debug" s!"{label}:ln2" ln2Rows let geluTargetUnion : GeluDerivTarget := if hdr.geluDerivTarget = .tanh then .exact else hdr.geluDerivTarget if usePerRow then @@ -7398,24 +7505,65 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let wOutIntervals := intervalsFromScaled wOut slack let mlpRows := mlpRowsFromIntervals cfg geluTargetUnion hdr.modelDim hdr.hiddenDim wInIntervals wOutIntervals bIn bOut ln2Rows - addRowsFixed rows mlpRows + logRowsWidth "mlp_debug" s!"{label}:mlp_rows" mlpRows + let out := addRowsFixed rows mlpRows + logRowsWidth "mlp_debug" s!"{label}:out" out + return out | some idxs => - Id.run do - let ln2Union := unionRowsFixed ln2Rows - let mlpUnion := mlpRowFromScaled cfg geluTargetUnion slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union - let mut out := addVecFixedRows rows mlpUnion - for idx in idxs do - if idx < ln2Rows.size then - let mlpRow := mlpRowFromScaledNoTask cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut (ln2Rows[idx]!) - out := out.set! idx (addVecFixed (rows[idx]!) 
mlpRow) - return out + let ln2Union := unionRowsFixed ln2Rows + let mlpUnion := mlpRowFromScaled cfg geluTargetUnion slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union + logVecWidth "mlp_debug" s!"{label}:mlp_union" mlpUnion + let idxsValid := idxs.filter (fun idx => idx < ln2Rows.size) + let useTasksHere := idxsValid.size > 4 + let out := + if useTasksHere then + Id.run do + let chunkSize : Nat := 8 + let numChunks : Nat := (idxsValid.size + chunkSize - 1) / chunkSize + let mut tasks : Array (Task (Array (Nat × Array Fixed10Interval))) := + Array.mkEmpty numChunks + let mut chunkIdx : Nat := 0 + while chunkIdx < numChunks do + let start := chunkIdx * chunkSize + let stop := min idxsValid.size (start + chunkSize) + tasks := tasks.push <| + Task.spawn (fun _ => + Id.run do + let mut outChunk : Array (Nat × Array Fixed10Interval) := + Array.mkEmpty (stop - start) + let mut i := start + while i < stop do + let idx := idxsValid[i]! + let mlpRow := mlpRowFromScaledNoTask cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut (ln2Rows[idx]!) + outChunk := outChunk.push (idx, mlpRow) + i := i + 1 + return outChunk) + chunkIdx := chunkIdx + 1 + let mut out := addVecFixedRows rows mlpUnion + for t in tasks do + for (idx, mlpRow) in t.get do + out := out.set! idx (addVecFixed (rows[idx]!) mlpRow) + return out + else + Id.run do + let mut out := addVecFixedRows rows mlpUnion + for idx in idxsValid do + let mlpRow := mlpRowFromScaledNoTask cfg hdr.geluDerivTarget slack + hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut (ln2Rows[idx]!) + out := out.set! idx (addVecFixed (rows[idx]!) 
mlpRow) + return out + logRowsWidth "mlp_debug" s!"{label}:out" out + return out else let ln2Union := unionRowsFixed ln2Rows let mlpOut := mlpRowFromScaled cfg geluTargetUnion slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union - addVecFixedRows rows mlpOut + logVecWidth "mlp_debug" s!"{label}:mlp_union" mlpOut + let out := addVecFixedRows rows mlpOut + logRowsWidth "mlp_debug" s!"{label}:out" out + return out let awaitPattern (pattern? : Option HeadBestMatchPatternCert) (task? : Option (Task (Except IO.Error (Except String HeadBestMatchPatternCert)))) @@ -7469,15 +7617,25 @@ private def certifyInductionSoundBestMatchLocalBinaryPair pure none let mut ln1RowsShared? : Option (Array (Array Fixed10Interval)) := none if residualsSame && (needRows1 || needRows2) then - ln1RowsShared? := some (calcLnRows residuals1 ln1P) + let rows := calcLnRows residuals1 ln1P + ln1RowsShared? := some rows + logRowsWidth "ln1_debug" s!"layer{l}:ln1_shared" rows let mut ln1Rows1? : Option (Array (Array Fixed10Interval)) := none let mut ln1Rows2? : Option (Array (Array Fixed10Interval)) := none if needRows1 then - ln1Rows1? := - some (ln1RowsShared?.getD (calcLnRows residuals1 ln1P)) + match ln1RowsShared? with + | some rows => ln1Rows1? := some rows + | none => + let rows := calcLnRows residuals1 ln1P + ln1Rows1? := some rows + logRowsWidth "ln1_debug" s!"layer{l}:ln1_1" rows if needRows2 then - ln1Rows2? := - some (ln1RowsShared?.getD (calcLnRows residuals2 ln1P)) + match ln1RowsShared? with + | some rows => ln1Rows2? := some rows + | none => + let rows := calcLnRows residuals2 ln1P + ln1Rows2? := some rows + logRowsWidth "ln1_debug" s!"layer{l}:ln1_2" rows let mut ln1Rows1Exact? : Option (Array (Array Fixed10Interval)) := none let mut ln1Rows2Exact? : Option (Array (Array Fixed10Interval)) := none if at1 then @@ -7489,7 +7647,9 @@ private def certifyInductionSoundBestMatchLocalBinaryPair if residualsSameV then ln1RowsV? := ln1Rows2? else - ln1RowsV? 
:= some (calcLnRows residualsV ln1P) + let rows := calcLnRows residualsV ln1P + ln1RowsV? := some rows + logRowsWidth "ln1_debug" s!"layer{l}:ln1_v" rows if let some t0 := tLn10? then let t1 ← ExceptT.lift IO.monoNanosNow let dtMs := (t1 - t0) / 1000000 @@ -7852,15 +8012,23 @@ private def certifyInductionSoundBestMatchLocalBinaryPair pure none let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) if shareUpdate then - residuals1 ← applyAttn residuals1 useTight1 attnRowsShared? attnUnionShared? attnBias + residuals1 ← + applyAttn s!"layer{l}:attn_shared" residuals1 useTight1 attnRowsShared? + attnUnionShared? attnBias residuals2 := residuals1 else if needUpdate1 then - residuals1 ← applyAttn residuals1 useTight1 attnRows1? attnUnion1? attnBias + residuals1 ← + applyAttn s!"layer{l}:attn_1" residuals1 useTight1 attnRows1? attnUnion1? + attnBias if needUpdate2 then - residuals2 ← applyAttn residuals2 useTight2 attnRows2? attnUnion2? attnBias + residuals2 ← + applyAttn s!"layer{l}:attn_2" residuals2 useTight2 attnRows2? attnUnion2? + attnBias if needUpdateV && !shareUpdateV && !skipAttnV then - residualsV ← applyAttn residualsV useTightV attnRowsV? attnUnionV? attnBias + residualsV ← + applyAttn s!"layer{l}:attn_v" residualsV useTightV attnRowsV? attnUnionV? + attnBias if let some t0 := tAttn0? 
then let t1 ← ExceptT.lift IO.monoNanosNow let dtMs := (t1 - t0) / 1000000 @@ -7879,18 +8047,22 @@ private def certifyInductionSoundBestMatchLocalBinaryPair let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) let ln2P := ln2Params.getD l defP if shareUpdate then - residuals1 := applyMlp residuals1 usePerRow1 ln2P wIn wOut bIn bOut + residuals1 ← + applyMlp s!"layer{l}:mlp_shared" residuals1 usePerRow1 ln2P wIn wOut bIn bOut residuals2 := residuals1 else if needUpdate1 then - residuals1 := applyMlp residuals1 usePerRow1 ln2P wIn wOut bIn bOut + residuals1 ← + applyMlp s!"layer{l}:mlp_1" residuals1 usePerRow1 ln2P wIn wOut bIn bOut if needUpdate2 then - residuals2 := applyMlp residuals2 usePerRow2 ln2P wIn wOut bIn bOut + residuals2 ← + applyMlp s!"layer{l}:mlp_2" residuals2 usePerRow2 ln2P wIn wOut bIn bOut if needUpdateV then if shareUpdateV then residualsV := residuals2 else - residualsV := applyMlp residualsV usePerRowV ln2P wIn wOut bIn bOut + residualsV ← + applyMlp s!"layer{l}:mlp_v" residualsV usePerRowV ln2P wIn wOut bIn bOut if shareUpdate then residualsSame := true else if needUpdate1 && needUpdate2 then From 94b9d22f27b3ba4ef7d560b3aee62d288aa24eae Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 16:53:30 +0100 Subject: [PATCH 063/244] tabula rasa core: DAG local system scaffold --- AGENTS.md | 137 +- {Nfp => Legacy/Nfp}/Abstraction.lean | 0 {Nfp => Legacy/Nfp}/Appendix.lean | 0 {Nfp => Legacy/Nfp}/Attribution.lean | 0 {Nfp => Legacy/Nfp}/Discovery.lean | 0 {Nfp => Legacy/Nfp}/IO.lean | 0 {Nfp => Legacy/Nfp}/IO/Pure.lean | 0 {Nfp => Legacy/Nfp}/Induction.lean | 0 {Nfp => Legacy/Nfp}/Influence.lean | 0 {Nfp => Legacy/Nfp}/Layers.lean | 0 {Nfp => Legacy/Nfp}/Linearization.lean | 0 {Nfp => Legacy/Nfp}/MixerLocalSystem.lean | 0 {Nfp => Legacy/Nfp}/PCC.lean | 0 {Nfp => Legacy/Nfp}/Reroute/Heat.lean | 0 {Nfp => Legacy/Nfp}/Reroute/Partition.lean | 0 {Nfp => Legacy/Nfp}/SignedMixer.lean | 0 {Nfp => 
Legacy/Nfp}/Sound/Activation.lean | 0 {Nfp => Legacy/Nfp}/Sound/Affine.lean | 0 {Nfp => Legacy/Nfp}/Sound/BinaryPure.lean | 0 {Nfp => Legacy/Nfp}/Sound/Bounds.lean | 0 .../Nfp}/Sound/Bounds/Attention.lean | 0 {Nfp => Legacy/Nfp}/Sound/Bounds/Basic.lean | 0 {Nfp => Legacy/Nfp}/Sound/Bounds/Effort.lean | 0 {Nfp => Legacy/Nfp}/Sound/Bounds/Exp.lean | 0 {Nfp => Legacy/Nfp}/Sound/Bounds/Gelu.lean | 0 .../Nfp}/Sound/Bounds/LayerNorm.lean | 0 .../Nfp}/Sound/Bounds/MatrixNorm.lean | 0 .../Nfp}/Sound/Bounds/Portfolio.lean | 0 {Nfp => Legacy/Nfp}/Sound/Bounds/Softmax.lean | 0 {Nfp => Legacy/Nfp}/Sound/Bridge.lean | 0 {Nfp => Legacy/Nfp}/Sound/CachePure.lean | 0 {Nfp => Legacy/Nfp}/Sound/Cert.lean | 0 {Nfp => Legacy/Nfp}/Sound/Decimal.lean | 0 {Nfp => Legacy/Nfp}/Sound/Demo.lean | 0 {Nfp => Legacy/Nfp}/Sound/Fixed.lean | 0 {Nfp => Legacy/Nfp}/Sound/HeadCert.lean | 0 {Nfp => Legacy/Nfp}/Sound/IO.lean | 0 {Nfp => Legacy/Nfp}/Sound/Interval.lean | 0 {Nfp => Legacy/Nfp}/Sound/ModelHeader.lean | 0 {Nfp => Legacy/Nfp}/Sound/TextPure.lean | 0 {Nfp => Legacy/Nfp}/Uniqueness.lean | 0 .../Nfp}/Untrusted/SoundBinary.lean | 0 .../Nfp}/Untrusted/SoundCacheIO.lean | 0 .../Nfp}/Untrusted/SoundCompute.lean | 0 {Nfp => Legacy/Nfp}/Verification.lean | 0 Main.lean | 2426 +---------------- Nfp.lean | 282 +- Nfp/Cli.lean | 42 + Nfp/Core.lean | 7 + Nfp/Core/Basic.lean | 14 + Nfp/Mixer.lean | 195 +- Nfp/Mixer/Basic.lean | 37 + Nfp/Mixer/Operations.lean | 71 + Nfp/Prob.lean | 88 +- Nfp/Prob/Basic.lean | 33 + Nfp/Prob/Operations.lean | 52 + Nfp/System.lean | 8 + Nfp/System/Dag.lean | 45 + Nfp/System/LocalSystem.lean | 47 + lakefile.toml | 1 + 60 files changed, 416 insertions(+), 3069 deletions(-) rename {Nfp => Legacy/Nfp}/Abstraction.lean (100%) rename {Nfp => Legacy/Nfp}/Appendix.lean (100%) rename {Nfp => Legacy/Nfp}/Attribution.lean (100%) rename {Nfp => Legacy/Nfp}/Discovery.lean (100%) rename {Nfp => Legacy/Nfp}/IO.lean (100%) rename {Nfp => Legacy/Nfp}/IO/Pure.lean (100%) rename {Nfp => 
Legacy/Nfp}/Induction.lean (100%) rename {Nfp => Legacy/Nfp}/Influence.lean (100%) rename {Nfp => Legacy/Nfp}/Layers.lean (100%) rename {Nfp => Legacy/Nfp}/Linearization.lean (100%) rename {Nfp => Legacy/Nfp}/MixerLocalSystem.lean (100%) rename {Nfp => Legacy/Nfp}/PCC.lean (100%) rename {Nfp => Legacy/Nfp}/Reroute/Heat.lean (100%) rename {Nfp => Legacy/Nfp}/Reroute/Partition.lean (100%) rename {Nfp => Legacy/Nfp}/SignedMixer.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Activation.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Affine.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/BinaryPure.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/Attention.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/Basic.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/Effort.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/Exp.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/Gelu.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/LayerNorm.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/MatrixNorm.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/Portfolio.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bounds/Softmax.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Bridge.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/CachePure.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Cert.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Decimal.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Demo.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Fixed.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/HeadCert.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/IO.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/Interval.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/ModelHeader.lean (100%) rename {Nfp => Legacy/Nfp}/Sound/TextPure.lean (100%) rename {Nfp => Legacy/Nfp}/Uniqueness.lean (100%) rename {Nfp => Legacy/Nfp}/Untrusted/SoundBinary.lean (100%) rename {Nfp => Legacy/Nfp}/Untrusted/SoundCacheIO.lean (100%) rename {Nfp => 
Legacy/Nfp}/Untrusted/SoundCompute.lean (100%) rename {Nfp => Legacy/Nfp}/Verification.lean (100%) create mode 100644 Nfp/Cli.lean create mode 100644 Nfp/Core.lean create mode 100644 Nfp/Core/Basic.lean create mode 100644 Nfp/Mixer/Basic.lean create mode 100644 Nfp/Mixer/Operations.lean create mode 100644 Nfp/Prob/Basic.lean create mode 100644 Nfp/Prob/Operations.lean create mode 100644 Nfp/System.lean create mode 100644 Nfp/System/Dag.lean create mode 100644 Nfp/System/LocalSystem.lean diff --git a/AGENTS.md b/AGENTS.md index 4c84131..dda03ed 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -240,107 +240,46 @@ Note: Recent Lean versions changed the story around well-founded recursion trans This is a *map*, not a prison. You may reshuffle if a better design emerges, but you **must** update this list in the same commit. -### 5.1 Core probability + mixing -- `Prob.lean` - - Probability vectors (`ProbVec`) on finite types; normalization + basic lemmas. -- `Mixer.lean` - - Row-stochastic operators (“mixers”), composition/pushforward/support tools. -- `Influence.lean` - - Influence specifications/families, capacities, scaling, and conversion into mixers. -- `Reroute/Partition.lean` - - Finite partitions + reroute planning structures. -- `Reroute/Heat.lean` - - Weighted reroute plans and induced “heat” distributions. -- `PCC.lean` - - Tracer/contribution utilities; discrete AUC / interval machinery. -- `Uniqueness.lean` - - `LocalSystem` for finite DAG mixing systems; uniqueness theorem(s) for tracers. -- `MixerLocalSystem.lean` - - Bridges mixers-on-DAGs to `LocalSystem` (interpreters using a topo order). -- `Appendix.lean` - - Supplemental lemmas and wrappers that don’t belong elsewhere. - -### 5.2 Interpretability / NN-oriented layers (mathematical, mostly proofs) -- `Layers.lean` - - Neural-network layer operations modeled as mixers; attribution/ablation/reachability laws. -- `Attribution.lean` - - Interpretability axioms and bridges from tracer-based notions. 
-- `Induction.lean` - - True induction head definitions and certification theorems (pattern + faithfulness + functional effect). -- `SignedMixer.lean` - - Signed/real-weight generalization (negative weights, affine maps, etc.). -- `Linearization.lean` - - Jacobian-based linearizations, decomposition results, deep composition/error theorems. -- `Abstraction.lean` - - Causal-consistency / intervention correspondence between “real” networks and abstract DAG views. - -### 5.3 Executable analysis & CLI surface -- `Discovery.lean` - - Executable discovery + bound computations and verification pipeline. - - May be performance-sensitive; keep proofs minimal and move them to proof modules when possible. -- `Sound/Decimal.lean` - - Exact parsing of decimal/scientific numerals into `Rat` for sound mode. -- `Sound/Activation.lean` - - Activation-derivative metadata + header parsing helpers for SOUND mode. -- `Sound/ModelHeader.lean` - - Pure header parsing helpers for SOUND metadata (e.g., LayerNorm epsilon). -- `Sound/BinaryPure.lean` - - Pure binary parsing/decoding helpers (IO-free, used by untrusted IO wrappers). -- `Sound/TextPure.lean` - - Pure text parsing helpers for model-weight norms used in sound verification. -- `Sound/CachePure.lean` - - Pure cache parsing/encoding helpers used by untrusted IO wrappers. -- `Untrusted/SoundBinary.lean` - - IO wrappers for the SOUND binary path (untrusted). -- `Untrusted/SoundCacheIO.lean` - - IO wrappers for the SOUND fixed-point cache (untrusted). -- `Sound/Bounds.lean` - - Umbrella import for SOUND bound utilities (exact `Rat` arithmetic, no Float). -- `Sound/Bounds/Basic.lean` - - Basic `Rat` helpers used across bounds modules. -- `Sound/Bounds/MatrixNorm.lean` - - Rat matrices, row-sum norms, and multiplicativity bounds. -- `Sound/Bounds/Gelu.lean` - - GeLU derivative envelopes (global). -- `Sound/Bounds/Exp.lean` - - exp lower bounds (scaled Taylor + squaring). 
-- `Sound/Bounds/Softmax.lean` - - Softmax Jacobian bounds and margin-derived weight helpers. -- `Sound/Bounds/Attention.lean` - - Attention pattern-term coefficient helpers. -- `Sound/Bounds/LayerNorm.lean` - - LayerNorm operator bounds (global/local). -- `Sound/Bounds/Portfolio.lean` - - Placeholder for portfolio combinators. -- `Sound/Bounds/Effort.lean` - - Placeholder for effort-tier records. -- `Sound/Affine.lean` - - Affine-form scaffolding for future local soundness improvements. -- `Sound/HeadCert.lean` - - Sound per-head contribution certificate structures. -- `Sound/Bridge.lean` - - Lemmas connecting `Rat`-level bounds to `SignedMixer` operator-norm bounds. -- `Sound/Cert.lean` - - Certificate/report structures and pretty-printing for SOUND-mode output. -- `Sound/IO.lean` - - Trusted IO wrappers: read inputs, call untrusted computation, and verify certificates. -- `Untrusted/SoundCompute.lean` - - IO-heavy witness generation for sound certificates (untrusted; verified by `Sound/IO`). -- `Sound/Demo.lean` - - Tiny end-to-end lemma demo bridging to `Linearization.operatorNormBound`. -- `Verification.lean` - - Executable **causal verification** via head ablation + runtime axiom checks (competence, control independence, energy matching). -- `IO.lean` - - Parsing/loading/tokenization/report formatting glue. - - **IO-only principle:** no heavy proofs; keep it as a bridge to filesystem/CLI. -- `IO/Pure.lean` - - Pure parsing, construction, and tokenization helpers used by `IO.lean`. +### 5.1 Core types +- `Nfp/Core/Basic.lean` + - `Mass` alias for nonnegative weights used throughout the rewrite. +- `Nfp/Core.lean` + - Aggregator for core shared definitions. + +### 5.2 Probability vectors +- `Nfp/Prob/Basic.lean` + - `ProbVec` definition + invariants. +- `Nfp/Prob/Operations.lean` + - `pure`, `mix`, and basic lemmas. +- `Nfp/Prob.lean` + - Aggregator for probability modules. 
+ +### 5.3 Mixers +- `Nfp/Mixer/Basic.lean` + - `Mixer` structure and row-stochastic invariant. +- `Nfp/Mixer/Operations.lean` + - `push`, `comp`, and `id` mixers. +- `Nfp/Mixer.lean` + - Aggregator for mixer modules. + +### 5.4 Systems (DAG + local mixing) +- `Nfp/System/Dag.lean` + - DAG relation + parent/child sets. +- `Nfp/System/LocalSystem.lean` + - `LocalSystem` with edge support and row-sum invariants. +- `Nfp/System.lean` + - Aggregator for system modules. + +### 5.5 CLI surface +- `Nfp/Cli.lean` + - CLI commands and `main` implementation. - `Main.lean` - - CLI entrypoint and subcommand wiring. Keep it thin: - - argument parsing + calling into `Nfp.IO` / `Discovery` / `Nfp.Sound.*` reporting helpers, - - minimal logic, minimal proof content. + - Thin entrypoint delegating to `Nfp.Cli.main`. - `Nfp.lean` - - Top-level reexports and an axioms check (`#print axioms` / trust dashboard). + - Top-level reexports and axioms dashboard (`#print axioms`). + +### 5.6 Legacy (tabula rasa transition) +- Legacy modules live under `Legacy/Nfp/` as reference only and are not built by default. 
If you introduce a new conceptual layer: - either extend the closest existing file, diff --git a/Nfp/Abstraction.lean b/Legacy/Nfp/Abstraction.lean similarity index 100% rename from Nfp/Abstraction.lean rename to Legacy/Nfp/Abstraction.lean diff --git a/Nfp/Appendix.lean b/Legacy/Nfp/Appendix.lean similarity index 100% rename from Nfp/Appendix.lean rename to Legacy/Nfp/Appendix.lean diff --git a/Nfp/Attribution.lean b/Legacy/Nfp/Attribution.lean similarity index 100% rename from Nfp/Attribution.lean rename to Legacy/Nfp/Attribution.lean diff --git a/Nfp/Discovery.lean b/Legacy/Nfp/Discovery.lean similarity index 100% rename from Nfp/Discovery.lean rename to Legacy/Nfp/Discovery.lean diff --git a/Nfp/IO.lean b/Legacy/Nfp/IO.lean similarity index 100% rename from Nfp/IO.lean rename to Legacy/Nfp/IO.lean diff --git a/Nfp/IO/Pure.lean b/Legacy/Nfp/IO/Pure.lean similarity index 100% rename from Nfp/IO/Pure.lean rename to Legacy/Nfp/IO/Pure.lean diff --git a/Nfp/Induction.lean b/Legacy/Nfp/Induction.lean similarity index 100% rename from Nfp/Induction.lean rename to Legacy/Nfp/Induction.lean diff --git a/Nfp/Influence.lean b/Legacy/Nfp/Influence.lean similarity index 100% rename from Nfp/Influence.lean rename to Legacy/Nfp/Influence.lean diff --git a/Nfp/Layers.lean b/Legacy/Nfp/Layers.lean similarity index 100% rename from Nfp/Layers.lean rename to Legacy/Nfp/Layers.lean diff --git a/Nfp/Linearization.lean b/Legacy/Nfp/Linearization.lean similarity index 100% rename from Nfp/Linearization.lean rename to Legacy/Nfp/Linearization.lean diff --git a/Nfp/MixerLocalSystem.lean b/Legacy/Nfp/MixerLocalSystem.lean similarity index 100% rename from Nfp/MixerLocalSystem.lean rename to Legacy/Nfp/MixerLocalSystem.lean diff --git a/Nfp/PCC.lean b/Legacy/Nfp/PCC.lean similarity index 100% rename from Nfp/PCC.lean rename to Legacy/Nfp/PCC.lean diff --git a/Nfp/Reroute/Heat.lean b/Legacy/Nfp/Reroute/Heat.lean similarity index 100% rename from Nfp/Reroute/Heat.lean rename to 
Legacy/Nfp/Reroute/Heat.lean diff --git a/Nfp/Reroute/Partition.lean b/Legacy/Nfp/Reroute/Partition.lean similarity index 100% rename from Nfp/Reroute/Partition.lean rename to Legacy/Nfp/Reroute/Partition.lean diff --git a/Nfp/SignedMixer.lean b/Legacy/Nfp/SignedMixer.lean similarity index 100% rename from Nfp/SignedMixer.lean rename to Legacy/Nfp/SignedMixer.lean diff --git a/Nfp/Sound/Activation.lean b/Legacy/Nfp/Sound/Activation.lean similarity index 100% rename from Nfp/Sound/Activation.lean rename to Legacy/Nfp/Sound/Activation.lean diff --git a/Nfp/Sound/Affine.lean b/Legacy/Nfp/Sound/Affine.lean similarity index 100% rename from Nfp/Sound/Affine.lean rename to Legacy/Nfp/Sound/Affine.lean diff --git a/Nfp/Sound/BinaryPure.lean b/Legacy/Nfp/Sound/BinaryPure.lean similarity index 100% rename from Nfp/Sound/BinaryPure.lean rename to Legacy/Nfp/Sound/BinaryPure.lean diff --git a/Nfp/Sound/Bounds.lean b/Legacy/Nfp/Sound/Bounds.lean similarity index 100% rename from Nfp/Sound/Bounds.lean rename to Legacy/Nfp/Sound/Bounds.lean diff --git a/Nfp/Sound/Bounds/Attention.lean b/Legacy/Nfp/Sound/Bounds/Attention.lean similarity index 100% rename from Nfp/Sound/Bounds/Attention.lean rename to Legacy/Nfp/Sound/Bounds/Attention.lean diff --git a/Nfp/Sound/Bounds/Basic.lean b/Legacy/Nfp/Sound/Bounds/Basic.lean similarity index 100% rename from Nfp/Sound/Bounds/Basic.lean rename to Legacy/Nfp/Sound/Bounds/Basic.lean diff --git a/Nfp/Sound/Bounds/Effort.lean b/Legacy/Nfp/Sound/Bounds/Effort.lean similarity index 100% rename from Nfp/Sound/Bounds/Effort.lean rename to Legacy/Nfp/Sound/Bounds/Effort.lean diff --git a/Nfp/Sound/Bounds/Exp.lean b/Legacy/Nfp/Sound/Bounds/Exp.lean similarity index 100% rename from Nfp/Sound/Bounds/Exp.lean rename to Legacy/Nfp/Sound/Bounds/Exp.lean diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Legacy/Nfp/Sound/Bounds/Gelu.lean similarity index 100% rename from Nfp/Sound/Bounds/Gelu.lean rename to Legacy/Nfp/Sound/Bounds/Gelu.lean diff --git 
a/Nfp/Sound/Bounds/LayerNorm.lean b/Legacy/Nfp/Sound/Bounds/LayerNorm.lean similarity index 100% rename from Nfp/Sound/Bounds/LayerNorm.lean rename to Legacy/Nfp/Sound/Bounds/LayerNorm.lean diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Legacy/Nfp/Sound/Bounds/MatrixNorm.lean similarity index 100% rename from Nfp/Sound/Bounds/MatrixNorm.lean rename to Legacy/Nfp/Sound/Bounds/MatrixNorm.lean diff --git a/Nfp/Sound/Bounds/Portfolio.lean b/Legacy/Nfp/Sound/Bounds/Portfolio.lean similarity index 100% rename from Nfp/Sound/Bounds/Portfolio.lean rename to Legacy/Nfp/Sound/Bounds/Portfolio.lean diff --git a/Nfp/Sound/Bounds/Softmax.lean b/Legacy/Nfp/Sound/Bounds/Softmax.lean similarity index 100% rename from Nfp/Sound/Bounds/Softmax.lean rename to Legacy/Nfp/Sound/Bounds/Softmax.lean diff --git a/Nfp/Sound/Bridge.lean b/Legacy/Nfp/Sound/Bridge.lean similarity index 100% rename from Nfp/Sound/Bridge.lean rename to Legacy/Nfp/Sound/Bridge.lean diff --git a/Nfp/Sound/CachePure.lean b/Legacy/Nfp/Sound/CachePure.lean similarity index 100% rename from Nfp/Sound/CachePure.lean rename to Legacy/Nfp/Sound/CachePure.lean diff --git a/Nfp/Sound/Cert.lean b/Legacy/Nfp/Sound/Cert.lean similarity index 100% rename from Nfp/Sound/Cert.lean rename to Legacy/Nfp/Sound/Cert.lean diff --git a/Nfp/Sound/Decimal.lean b/Legacy/Nfp/Sound/Decimal.lean similarity index 100% rename from Nfp/Sound/Decimal.lean rename to Legacy/Nfp/Sound/Decimal.lean diff --git a/Nfp/Sound/Demo.lean b/Legacy/Nfp/Sound/Demo.lean similarity index 100% rename from Nfp/Sound/Demo.lean rename to Legacy/Nfp/Sound/Demo.lean diff --git a/Nfp/Sound/Fixed.lean b/Legacy/Nfp/Sound/Fixed.lean similarity index 100% rename from Nfp/Sound/Fixed.lean rename to Legacy/Nfp/Sound/Fixed.lean diff --git a/Nfp/Sound/HeadCert.lean b/Legacy/Nfp/Sound/HeadCert.lean similarity index 100% rename from Nfp/Sound/HeadCert.lean rename to Legacy/Nfp/Sound/HeadCert.lean diff --git a/Nfp/Sound/IO.lean b/Legacy/Nfp/Sound/IO.lean similarity index 100% 
rename from Nfp/Sound/IO.lean rename to Legacy/Nfp/Sound/IO.lean diff --git a/Nfp/Sound/Interval.lean b/Legacy/Nfp/Sound/Interval.lean similarity index 100% rename from Nfp/Sound/Interval.lean rename to Legacy/Nfp/Sound/Interval.lean diff --git a/Nfp/Sound/ModelHeader.lean b/Legacy/Nfp/Sound/ModelHeader.lean similarity index 100% rename from Nfp/Sound/ModelHeader.lean rename to Legacy/Nfp/Sound/ModelHeader.lean diff --git a/Nfp/Sound/TextPure.lean b/Legacy/Nfp/Sound/TextPure.lean similarity index 100% rename from Nfp/Sound/TextPure.lean rename to Legacy/Nfp/Sound/TextPure.lean diff --git a/Nfp/Uniqueness.lean b/Legacy/Nfp/Uniqueness.lean similarity index 100% rename from Nfp/Uniqueness.lean rename to Legacy/Nfp/Uniqueness.lean diff --git a/Nfp/Untrusted/SoundBinary.lean b/Legacy/Nfp/Untrusted/SoundBinary.lean similarity index 100% rename from Nfp/Untrusted/SoundBinary.lean rename to Legacy/Nfp/Untrusted/SoundBinary.lean diff --git a/Nfp/Untrusted/SoundCacheIO.lean b/Legacy/Nfp/Untrusted/SoundCacheIO.lean similarity index 100% rename from Nfp/Untrusted/SoundCacheIO.lean rename to Legacy/Nfp/Untrusted/SoundCacheIO.lean diff --git a/Nfp/Untrusted/SoundCompute.lean b/Legacy/Nfp/Untrusted/SoundCompute.lean similarity index 100% rename from Nfp/Untrusted/SoundCompute.lean rename to Legacy/Nfp/Untrusted/SoundCompute.lean diff --git a/Nfp/Verification.lean b/Legacy/Nfp/Verification.lean similarity index 100% rename from Nfp/Verification.lean rename to Legacy/Nfp/Verification.lean diff --git a/Main.lean b/Main.lean index 9fdf738..6f5e59a 100644 --- a/Main.lean +++ b/Main.lean @@ -1,2425 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Cli -import Nfp.IO -import Nfp.Linearization -import Nfp.Untrusted.SoundCacheIO -import Nfp.Verification -import Nfp.Sound.IO -import Std.Time.Format +import Nfp.Cli -/-! -# NFP CLI: Circuit Verification Command-Line Tool - -This is the main entry point for the NFP circuit verification tool. 
- -## Usage - -Build the executable: -```bash -lake build nfp -``` - -List the available subcommands: -```bash -lake exe nfp --help -``` - -Example invocations (see README for full flag descriptions): -```bash -# Analyze a model and write a report to a file -lake exe nfp analyze model.nfpt --threshold 0.1 --output report.txt - -# Search for induction heads with diagnostics enabled -lake exe nfp induction model.nfpt --diagnostics --diagTop 5 --adaptive - -# Microbenchmarks for analysis or induction -lake exe nfp bench model.nfpt --mode analysis --runs 5 - -# Generate a sound-mode certificate report -lake exe nfp certify model.nfpt - -# Local (input-dependent) sound-mode certificate report -# (NFP_BINARY_V1 always embeds inputs; legacy text requires an EMBEDDINGS section.) -lake exe nfp certify model.nfpt --delta 1/100 - -# Sound per-head contribution bounds (global or local with --delta) -lake exe nfp head_bounds model.nfpt --delta 1/100 - -# Sound attention pattern bounds for a single head (binary only) -lake exe nfp head_pattern model.nfpt --layer 0 --head 0 --delta 1/100 --offset -1 - -# Sound induction head certificate (binary only) -lake exe nfp induction_cert model.nfpt --layer1 0 --head1 0 --layer2 1 --head2 0 \ - --coord 0 --delta 1/100 --offset1 -1 --offset2 0 --keyOffset2 -1 \ - --target 42 --negative 17 - -# Instantiate RoPE bounds for a specific shape -lake exe nfp rope --seqLen 4 --pairs 8 - -# Check SOUND cache soundness (CI/fixtures) -lake exe nfp sound_cache_check model.nfpt - -# Benchmark SOUND cache build -lake exe nfp sound_cache_bench model.nfpt --runs 3 - -# Dump a small forward-pass slice -lake exe nfp dump model.nfpt --layer 0 --pos 0 --kind afterLayer - -# Empirical logit-diff check -lake exe nfp logit_diff model.nfpt 42 17 --autoNegative - -# Show version -lake exe nfp --version -``` --/ - -open Cli Nfp - -private def fmtFloat (x : Float) : String := - toString x - -/-! 
## Stdout logging -/ - -private structure StdoutLogCtx where - handle : IO.FS.Handle - pathRef : IO.Ref System.FilePath - pendingRef : IO.Ref Bool - timestamp : String - -initialize stdoutLogCtxRef : IO.Ref (Option StdoutLogCtx) ← IO.mkRef none - -private def sanitizeFileComponent (s : String) : String := - s.map fun c => - if c.isAlphanum || c = '_' || c = '-' || c = '.' then c else '_' - -private def timestampNowForLog : IO String := do - let dt ← Std.Time.ZonedDateTime.now - let dateStr := s!"{dt.toPlainDate}" - let timeRaw := s!"{dt.toPlainTime}" - let timeNoFrac := (timeRaw.splitOn ".").getD 0 timeRaw - let timeStr := timeNoFrac.replace ":" "-" - return s!"{dateStr}-{timeStr}" - -private def openPendingStdoutLog : IO StdoutLogCtx := do - let logsDir : System.FilePath := "logs" - IO.FS.createDirAll logsDir - let ts ← timestampNowForLog - let path : System.FilePath := logsDir / s!"{ts}_pending.log" - let h ← IO.FS.Handle.mk path .write - let pathRef ← IO.mkRef path - let pendingRef ← IO.mkRef true - return { handle := h, pathRef := pathRef, pendingRef := pendingRef, timestamp := ts } - -private def mkTeeStream (out log : IO.FS.Stream) : IO.FS.Stream := - { flush := do out.flush; log.flush - read := fun n => out.read n - write := fun b => do out.write b; log.write b - getLine := out.getLine - putStr := fun s => do out.putStr s; log.putStr s - isTty := out.isTty } - -private def setStdoutLogName (name : String) : IO Unit := do - let some ctx ← stdoutLogCtxRef.get | return () - let pending ← ctx.pendingRef.get - if !pending then - return () - let oldPath ← ctx.pathRef.get - let logsDir : System.FilePath := "logs" - let safeName := sanitizeFileComponent name - let newPath : System.FilePath := logsDir / s!"{ctx.timestamp}_{safeName}.log" - try - IO.FS.rename oldPath newPath - ctx.pathRef.set newPath - ctx.pendingRef.set false - catch - | _ => - -- If rename fails, keep the pending filename but continue. 
- pure () - -private def setStdoutLogNameFromModelPath (modelPath : String) : IO Unit := do - let p : System.FilePath := modelPath - let stem := p.fileStem.getD (p.fileName.getD "model") - setStdoutLogName stem - -/-- Write a report to stdout or to a file if an output path is provided. -/ -private def writeReport (outputPath? : Option System.FilePath) (report : String) : IO Unit := do - match outputPath? with - | some path => - IO.FS.writeFile path report - IO.println s!"Report written to {path}" - | none => - IO.println report - -/-- Check whether a `.nfpt` file contains an `EMBEDDINGS` section before the first `LAYER`. - -For `NFP_BINARY_V1`, embeddings are always present, so this returns true. This is used to decide -whether `nfp certify --delta ...` can default to using the model file as its own input source -(so users don't have to pass `--input model.nfpt`). -/ -private def hasEmbeddingsBeforeLayers (path : System.FilePath) : IO Bool := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let line ← h.getLine - if line.isEmpty then - return false - let s := line.trim - if s = "NFP_BINARY_V1" then - return true - -- Header: read until blank line (text format only). - let mut seenHeader : Bool := false - if s.startsWith "NFP_TEXT" then - seenHeader := true - while true do - let line ← h.getLine - if line.isEmpty then - return false - let s := line.trim - if !seenHeader then - if s.startsWith "NFP_TEXT" then - seenHeader := true - continue - if s.isEmpty then - break - -- After the header, `EMBEDDINGS` (if present) must appear before any layer payload. 
- while true do - let line ← h.getLine - if line.isEmpty then - return false - let s := line.trim - if s = "EMBEDDINGS" then - return true - if s.startsWith "LAYER" then - return false - return false - -private def printHeadDiagnostics (label : String) (data : PrecomputedHeadData) : IO Unit := do - let attnFrob : Float := Float.sqrt data.attentionFrobeniusNormSq - let patternRecon : Float := - (data.softmaxJacobianOpEst / data.scaleFactor) * - data.inputNorm * data.queryKeyAlignSchurNorm * data.valueOutputProjSchurNorm - let valueRecon : Float := attnFrob * data.valueOutputProjNorm - let epsRecon : Float := - if valueRecon < 1e-10 then Float.inf else patternRecon / valueRecon - let patternCached := data.patternTermBoundCached - let valueCached := data.valueTermNormCached - let epsCached := data.faithfulnessRatioCached - IO.println s!" {label} L{data.layerIdx}H{data.headIdx}:" - IO.println s!" softmaxOpBound = {fmtFloat data.softmaxJacobianOpEst}" - IO.println s!" softmaxParts = rowMaxP={fmtFloat data.softmaxRowMaxP}" - IO.println s!" = rowTrace={fmtFloat data.softmaxRowTraceBound}" - IO.println s!" = rowMoment={fmtFloat data.softmaxRowMomentBound}" - IO.println s!" = rowGersh={fmtFloat data.softmaxRowGershBound}" - IO.println s!" = rowUsed={fmtFloat data.softmaxRowBoundUsed}" - IO.println s!" = fallbackRows={data.softmaxRowsFallback}" - IO.println s!" scaleFactor = {fmtFloat data.scaleFactor}" - IO.println s!" inputNorm = {fmtFloat data.inputNorm}" - IO.println s!" inputOpBound = {fmtFloat data.inputOpBound}" - IO.println s!" qkOpBoundUsed = {fmtFloat data.queryKeyAlignSchurNorm}" - IO.println s!" qkActFrob(c) = {fmtFloat data.qkActFrobBound}" - IO.println s!" kqActFrob(c) = {fmtFloat data.kqActFrobBound}" - IO.println s!" qkActOp(c) = {fmtFloat data.qkActOpBound}" - IO.println s!" kqActOp(c) = {fmtFloat data.kqActOpBound}" - let qkActOpUbStr : String := - match data.patternBoundParts? 
with - | some parts => fmtFloat parts.qkActOpUb - | none => "n/a" - let kqActOpUbStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.kqActOpUb - | none => "n/a" - IO.println s!" qkActOpUbUsed = {qkActOpUbStr}" - IO.println s!" kqActOpUbUsed = {kqActOpUbStr}" - IO.println s!" qkActOpSource = {data.qkActOpBoundSource}" - IO.println s!" kqActOpSource = {data.kqActOpBoundSource}" - let vOpUbStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.vOpUb - | none => "n/a" - let vOpUbWOStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.vOpUbWO - | none => "n/a" - IO.println s!" vOpUbUsed = {vOpUbStr}" - IO.println s!" vOpUbWOUsed = {vOpUbWOStr}" - IO.println s!" qOpBoundAct = {fmtFloat data.qOpBoundAct}" - IO.println s!" kOpBoundAct = {fmtFloat data.kOpBoundAct}" - IO.println s!" vOpBoundAct = {fmtFloat data.vOpBoundAct}" - IO.println s!" qkFrob = {fmtFloat data.queryKeyAlignNorm}" - IO.println s!" wqOpGram = {fmtFloat data.wqOpGram}" - IO.println s!" wkOpGram = {fmtFloat data.wkOpGram}" - IO.println s!" qkFactorGram = {fmtFloat data.qkFactorGram}" - let qkDenseSchurStr : String := - match data.diag? with - | some diag => fmtFloat diag.qkDenseSchur.get - | none => "n/a" - IO.println s!" qkCandidates = denseSchur={qkDenseSchurStr}" - IO.println s!" = denseFrob={fmtFloat data.qkDenseFrob}" - IO.println s!" = denseGram={fmtFloat data.qkDenseGram}" - IO.println s!" = denseBrauer={fmtFloat data.qkDenseBrauer}" - let qkBrauerOk : String := - if data.qkDenseBrauer ≤ data.qkDenseGram then "true" else "false" - IO.println s!" = denseBrauer≤denseGram={qkBrauerOk}" - IO.println s!" = factorSchur={fmtFloat data.qkFactorSchur}" - IO.println s!" = factorFrob={fmtFloat data.qkFactorFrob}" - IO.println s!" = factorGram={fmtFloat data.qkFactorGram}" - IO.println s!" voOpBoundUsed = {fmtFloat data.valueOutputProjSchurNorm}" - IO.println s!" 
voFrob = {fmtFloat data.valueOutputProjNorm}" - IO.println s!" wvOpGram = {fmtFloat data.wvOpGram}" - IO.println s!" woOpGram = {fmtFloat data.woOpGram}" - IO.println s!" voFactorGram = {fmtFloat data.voFactorGram}" - let voDenseSchurStr : String := - match data.diag? with - | some diag => fmtFloat diag.voDenseSchur.get - | none => "n/a" - IO.println s!" voCandidates = denseSchur={voDenseSchurStr}" - IO.println s!" = denseFrob={fmtFloat data.voDenseFrob}" - IO.println s!" = denseGram={fmtFloat data.voDenseGram}" - IO.println s!" = denseBrauer={fmtFloat data.voDenseBrauer}" - let voBrauerOk : String := - if data.voDenseBrauer ≤ data.voDenseGram then "true" else "false" - IO.println s!" = denseBrauer≤denseGram={voBrauerOk}" - IO.println s!" = factorSchur={fmtFloat data.voFactorSchur}" - IO.println s!" = factorFrob={fmtFloat data.voFactorFrob}" - IO.println s!" = factorGram={fmtFloat data.voFactorGram}" - IO.println s!" attnFrob = {fmtFloat attnFrob}" - IO.println s!" patternCached = {fmtFloat patternCached}" - IO.println s!" valueCached = {fmtFloat valueCached}" - IO.println s!" εCached = {fmtFloat epsCached}" - let candFrobStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.candFrob - | none => "n/a" - let candOpStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.candOp - | none => "n/a" - let candOpWOStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.candOpWO - | none => "n/a" - IO.println s!" candFrob = {candFrobStr}" - IO.println s!" candOp = {candOpStr}" - IO.println s!" candOpWO = {candOpWOStr}" - IO.println s!" patternRecon = {fmtFloat patternRecon}" - IO.println s!" valueRecon = {fmtFloat valueRecon}" - IO.println s!" 
εRecon = {fmtFloat epsRecon}" - let qkOk : String := - if data.queryKeyAlignSchurNorm ≤ data.queryKeyAlignNorm then "true" else "false" - let voOk : String := - if data.valueOutputProjSchurNorm ≤ data.valueOutputProjNorm then "true" else "false" - IO.println s!" checks = qkUsed≤qkFrob={qkOk}, voUsed≤voFrob={voOk}" - IO.println s!" reconDiff = Δpattern={fmtFloat (patternRecon - patternCached)}" - IO.println s!" Δvalue={fmtFloat (valueRecon - valueCached)}" - IO.println s!" Δε={fmtFloat (epsRecon - epsCached)}" - -/-! ## Analyze command helpers -/ - -private structure AnalyzeArgs where - modelPath : System.FilePath - modelPathStr : String - threshold : Float - outputPath? : Option System.FilePath - verify : Bool - verbose : Bool - -private def parseAnalyzeArgs (p : Parsed) : IO (Option AnalyzeArgs) := do - let modelPathStr := p.positionalArg! "model" |>.as! String - let thresholdStr := p.flag? "threshold" |>.map (·.as! String) |>.getD "0.1" - let some threshold := Nfp.parseFloat thresholdStr - | do - IO.eprintln s!"Error: Invalid threshold value '{thresholdStr}'" - return none - let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - let verify := p.hasFlag "verify" - let verbose := p.hasFlag "verbose" - return some { - modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - threshold := threshold - outputPath? := outputPath? - verify := verify - verbose := verbose - } - -private def runAnalyzeWithArgs (args : AnalyzeArgs) : IO UInt32 := do - setStdoutLogNameFromModelPath args.modelPathStr - if args.verbose then - IO.println s!"Loading model from {args.modelPathStr}..." 
- IO.println s!"Threshold: {args.threshold}" - if args.verify then - IO.println "Mode: Verification (with empirical validation)" - else - IO.println "Mode: Analysis (static bounds only)" - let loadResult ← loadModel args.modelPath - match loadResult with - | .error msg => - IO.eprintln s!"Error loading model: {msg}" - return 1 - | .ok model0 => - let model := model0.trimTrailingZeroEmbeddings - if args.verbose && model.seqLen ≠ model0.seqLen then - IO.println s!" Trimmed trailing zero embeddings: seqLen {model0.seqLen} -> {model.seqLen}" - if args.verbose then - IO.println s!"✓ Model loaded successfully" - IO.println s!" Layers: {model.numLayers}" - IO.println s!" Sequence Length: {model.seqLen}" - let vocabSize := - match model.unembedding with - | some u => u.numCols - | none => 0 - IO.println s!" Embedding Vocabulary: {vocabSize}" - IO.println s!" Model Dimension: {model.modelDim}" - IO.println "" - IO.println "Running analysis..." - let report ← if args.verify then - analyzeAndVerify model args.modelPathStr args.threshold none - else - analyzeModel model args.modelPathStr args.threshold - writeReport args.outputPath? (toString report) - return 0 - -/-- Run the analyze command - perform circuit analysis. -/ -def runAnalyze (p : Parsed) : IO UInt32 := do - let some args ← parseAnalyzeArgs p - | return 1 - runAnalyzeWithArgs args - -/-! ## Induction command helpers -/ - -private structure InductionArgs where - modelPath : System.FilePath - modelPathStr : String - correctOpt : Option Nat - incorrectOpt : Option Nat - minEffect : Float - verify : Bool - verbose : Bool - diagnostics : Bool - adaptive : Bool - targetSlack : Float - maxUpgrades : Nat - minRelImprove : Float - krylovSteps : Nat - adaptiveScope : Nfp.AdaptiveScope - adaptiveScopeStr : String - diagTop : Nat - -private def parseInductionArgs (p : Parsed) : IO (Option InductionArgs) := do - let modelPathStr := p.positionalArg! "model" |>.as! String - let correctOpt := p.flag? "correct" |>.map (·.as! 
Nat) - let incorrectOpt := p.flag? "incorrect" |>.map (·.as! Nat) - let thresholdStr := p.flag? "threshold" |>.map (·.as! String) |>.getD "0.0" - let verify := p.hasFlag "verify" - let verbose := p.hasFlag "verbose" - let diagnostics := p.hasFlag "diagnostics" - let adaptive := p.hasFlag "adaptive" - let targetSlackStr := p.flag? "targetSlack" |>.map (·.as! String) |>.getD "8.0" - let maxUpgrades := p.flag? "maxUpgrades" |>.map (·.as! Nat) |>.getD 120 - let minRelImproveStr := p.flag? "minRelImprove" |>.map (·.as! String) |>.getD "0.01" - let krylovSteps := p.flag? "krylovSteps" |>.map (·.as! Nat) |>.getD 2 - let adaptiveScopeStr := p.flag? "adaptiveScope" |>.map (·.as! String) |>.getD "layernorm" - let diagTop := p.flag? "diagTop" |>.map (·.as! Nat) |>.getD 5 - let some minEffect := Nfp.parseFloat thresholdStr - | do - IO.eprintln s!"Error: Invalid threshold value '{thresholdStr}'" - return none - let (targetSlack, minRelImprove, adaptiveScope) ← - if adaptive then - let some targetSlack := Nfp.parseFloat targetSlackStr - | do - IO.eprintln s!"Error: Invalid --targetSlack '{targetSlackStr}'" - return none - let some minRelImprove := Nfp.parseFloat minRelImproveStr - | do - IO.eprintln s!"Error: Invalid --minRelImprove '{minRelImproveStr}'" - return none - let adaptiveScope? : Option Nfp.AdaptiveScope := - match adaptiveScopeStr.trim.toLower with - | "layernorm" => some .layernorm - | "all" => some .all - | _ => none - let some adaptiveScope := adaptiveScope? 
- | do - IO.eprintln <| - s!"Error: Invalid --adaptiveScope '{adaptiveScopeStr}' " ++ - "(expected layernorm|all)" - return none - pure (targetSlack, minRelImprove, adaptiveScope) - else - pure (8.0, 0.01, Nfp.AdaptiveScope.layernorm) - return some { - modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - correctOpt := correctOpt - incorrectOpt := incorrectOpt - minEffect := minEffect - verify := verify - verbose := verbose - diagnostics := diagnostics - adaptive := adaptive - targetSlack := targetSlack - maxUpgrades := maxUpgrades - minRelImprove := minRelImprove - krylovSteps := krylovSteps - adaptiveScope := adaptiveScope - adaptiveScopeStr := adaptiveScopeStr - diagTop := diagTop - } - -private def deriveInductionTarget - (model : ConcreteModel) - (W_U : ConcreteMatrix) - (correctOpt incorrectOpt : Option Nat) : Option TargetDirection := - match correctOpt, incorrectOpt with - | some correct, some incorrect => - some (TargetDirection.fromLogitDiff W_U correct incorrect) - | some _, none | none, some _ => - none - | none, none => - match model.inputTokens with - | some _ => TargetDirection.fromInductionHistory model - | none => some (TargetDirection.fromLogitDiff W_U 0 1) - -private def printInductionSearchIntro (minEffect : Float) : IO Unit := do - IO.println s!"Searching for heads with Effect > {minEffect}... 
(heuristic)" - IO.println "Ranking: highest mechScore (= kComp·indScore·prevTok) first" - IO.println " Tie-break: Effect, kComp, δ, then lowest Error" - IO.println " circuitScore = Effect · mechScore" - IO.println " Effect = δ / (‖ln₁(X₂)‖_F · ‖u‖₂)" - IO.println " kComp_raw = ‖W_QK² · W_OV¹‖_F / (‖W_QK²‖_F · ‖W_OV¹‖_F)" - IO.println " kComp = kComp_raw - 1/√modelDim" - -private def printInductionCandidate (verbose : Bool) (h : HeuristicInductionHead) : IO Unit := do - let c := h.candidate - let mechScore := c.kComp * c.inductionScore * c.prevTokenStrength - let circuitScore := h.effect * mechScore - if verbose then - IO.println <| - s!"L{c.layer1Idx}H{c.head1Idx} -> L{c.layer2Idx}H{c.head2Idx} | " ++ - s!"Mech: {mechScore} | Circuit: {circuitScore} | " ++ - s!"Effect: {h.effect} | kComp: {c.kComp} | " ++ - s!"indScore: {c.inductionScore} | prevTok: {c.prevTokenStrength} | " ++ - s!"δ: {h.delta} | " ++ - s!"Error: {c.combinedError} | " ++ - s!"‖X₂‖_F: {h.layer2InputNorm} | " ++ - s!"‖ln₁(X₂)‖_F: {h.layer2Ln1InputNorm} " ++ - s!"(ε₁={c.patternBound1}, ε₂={c.patternBound2})" - else - IO.println <| - s!"L{c.layer1Idx}H{c.head1Idx} -> L{c.layer2Idx}H{c.head2Idx} | " ++ - s!"Mech: {mechScore} | Effect: {h.effect} | " ++ - s!"kComp: {c.kComp} | " ++ - s!"indScore: {c.inductionScore} | prevTok: {c.prevTokenStrength} | " ++ - s!"Error: {c.combinedError} | " ++ - s!"‖X₂‖_F: {h.layer2InputNorm}" - -private def printInductionCandidates (heads : Array HeuristicInductionHead) (verbose : Bool) : - IO (Array HeuristicInductionHead) := do - let top := heads.take 50 - IO.println s!"Top Induction Head Pairs by mechScore (top {top.size} of {heads.size})" - for h in top do - printInductionCandidate verbose h - return top - -private def buildAdaptiveScheduler - (cache : PrecomputedCache) - (args : InductionArgs) : Option AdaptiveSchedulerResult := - if args.adaptive && (args.verbose || args.diagnostics) then - let cfg : Nfp.AdaptiveSchedulerConfig := - { targetSlack := args.targetSlack 
- maxUpgrades := args.maxUpgrades - minRelImprove := args.minRelImprove - krylovSteps := args.krylovSteps - scope := args.adaptiveScope - debugMonotone := args.diagnostics } - some (Nfp.runAdaptiveScheduler cache cfg none) - else - none - -private def printAdaptiveSchedulerSteps - (sched : AdaptiveSchedulerResult) - (args : InductionArgs) : IO Unit := do - IO.println "" - IO.println "ADAPTIVE SCHEDULER" - IO.println <| - s!" targetSlack={fmtFloat args.targetSlack} maxUpgrades={args.maxUpgrades} " ++ - s!"minRelImprove={fmtFloat args.minRelImprove} krylovSteps={args.krylovSteps} " ++ - s!"scope={args.adaptiveScopeStr}" - for s in sched.steps do - IO.println <| - match s.kind with - | .ubTier => - s!" it={s.iter} L{s.layerIdx}: tier {s.tierFrom}->{s.tierTo} " ++ - s!"ub {fmtFloat s.ubBefore}->{fmtFloat s.ubAfter} " ++ - s!"lb≈{fmtFloat s.lb} (k={s.kTo}) " ++ - s!"slack {fmtFloat s.slackBefore}->{fmtFloat s.slackAfter}" - | .lbSteps => - s!" it={s.iter} L{s.layerIdx}: lb-steps {s.kFrom}->{s.kTo} " ++ - s!"ub {fmtFloat s.ubBefore} " ++ - s!"lb≈{fmtFloat s.lb} slack {fmtFloat s.slackBefore}->{fmtFloat s.slackAfter}" - -private def printInductionDiagnostics - (cache : PrecomputedCache) - (top : Array HeuristicInductionHead) - (args : InductionArgs) - (sched? : Option AdaptiveSchedulerResult) : IO (Option UInt32) := do - IO.println "" - let diagN := min args.diagTop top.size - if args.adaptive then - let some sched := sched? - | do - IO.eprintln "Error: internal scheduler state missing." - return some 1 - IO.println "" - IO.println "LAYER NORM DIAGNOSTICS (LOWER vs UPPER)" - IO.println " (lb is rigorous lower bound via Krylov steps; ub is rigorous upper bound)" - for l in [:sched.ub.size] do - let ub := sched.ub[l]! - let lb := sched.lb[l]! - let ratio : Float := if lb > 1e-12 then ub / lb else Float.inf - let tier := sched.tier[l]! - let k := sched.lbK.getD l 0 - IO.println <| - s!" 
L{l}: lb≈{fmtFloat lb} ub={fmtFloat ub} " ++ - s!"ub/lb={fmtFloat ratio} tier={tier} k={k}" - let x := cache.forwardResult.getLayerInput l - let y := cache.forwardResult.getPostAttnResidual l - let ln1p := cache.model.ln1Params l - let ln2p := cache.model.ln2Params l - let ln1 := ConcreteMatrix.layerNormRowwiseOpDiag x ln1p.gamma - let ln2 := ConcreteMatrix.layerNormRowwiseOpDiag y ln2p.gamma - IO.println <| - s!" ln1: op≈{fmtFloat (ln1.gammaMaxAbs * ln1.maxInvStd)} " ++ - s!"(γmax≈{fmtFloat ln1.gammaMaxAbs}, invStdMax≈{fmtFloat ln1.maxInvStd} " ++ - s!"@r={ln1.maxInvStdRow}, varMin≈{fmtFloat ln1.minVar} @r={ln1.minVarRow})" - IO.println <| - s!" ln2: op≈{fmtFloat (ln2.gammaMaxAbs * ln2.maxInvStd)} " ++ - s!"(γmax≈{fmtFloat ln2.gammaMaxAbs}, invStdMax≈{fmtFloat ln2.maxInvStd} " ++ - s!"@r={ln2.maxInvStdRow}, varMin≈{fmtFloat ln2.minVar} @r={ln2.minVarRow})" - if hm : l < cache.model.mlps.size then - let mlp := cache.model.mlps[l]'hm - let y := cache.forwardResult.getPostAttnResidual l - let ln2Bound := cache.model.ln2OpBound l y - let (attnPart, maxHeadIdx, maxHeadContrib, maxHeadValue, maxHeadPattern) : - (Float × Nat × Float × Float × Float) := Id.run do - let layerData := cache.headData.getD l #[] - let mut a : Float := 0.0 - let mut bestIdx : Nat := 0 - let mut best : Float := 0.0 - let mut bestValue : Float := 0.0 - let mut bestPattern : Float := 0.0 - let mut idx : Nat := 0 - for d in layerData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := d.ln1OpBound * (attnOpUb * d.valueOutputProjSchurNorm) - let inputs : Nfp.PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := d.inputOpBound - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - 
kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := d.ln1OpBound * (Nfp.computePatternTermBound inputs) - let contrib := valueTermUb + patternTermUb - a := a + contrib - if contrib > best then - bestIdx := idx - best := contrib - bestValue := valueTermUb - bestPattern := patternTermUb - idx := idx + 1 - (a, bestIdx, best, bestValue, bestPattern) - let mlpTotal : Float := max 0.0 (ub - attnPart) - let mlpOnly : Float := - if 1.0 + attnPart > 1e-12 then - mlpTotal / (1.0 + attnPart) - else - mlpTotal - let cross : Float := max 0.0 (mlpTotal - mlpOnly) - IO.println <| - s!" contrib: attn≈{fmtFloat attnPart} mlpOnly≈{fmtFloat mlpOnly} " ++ - s!"cross≈{fmtFloat cross} mlpTotal≈{fmtFloat mlpTotal} " ++ - s!"(maxHead=H{maxHeadIdx}≈{fmtFloat maxHeadContrib}, " ++ - s!"value≈{fmtFloat maxHeadValue}, pattern≈{fmtFloat maxHeadPattern})" - let mlpJacLegacy : Float := - let denom := ln2Bound * (1.0 + attnPart) - if denom > 1e-12 then - mlpTotal / denom - else - Float.nan - let geluDeriv := cache.forwardResult.getMlpGeluDeriv l - let diag := computeMLPOpAbsSchurDiag mlp geluDeriv - let chosen : Float := min diag.absSchur mlpJacLegacy - IO.println <| - s!" mlpDiag(L{l}): absSchur={fmtFloat diag.absSchur} " ++ - s!"legacy≈{fmtFloat mlpJacLegacy} chosen≈{fmtFloat chosen} " ++ - s!"dMax≈{fmtFloat diag.dMax}" - else - IO.println "LAYER NORM DIAGNOSTICS (PI vs rigorous)" - IO.println " (PI is diagnostics-only; rigorous is used for bounds)" - for l in [:cache.layerNormBounds.size] do - let ub := cache.layerNormBounds[l]! 
- let pi := estimateAttentionLayerNormHeuristicPI cache.model cache.forwardResult l true - let ratio : Float := if pi > 1e-12 then ub / pi else Float.inf - IO.println s!" L{l}: PI≈{fmtFloat pi} ub={fmtFloat ub} ub/PI={fmtFloat ratio}" - -- Rectangular Gram diagnostics for MLP weights (layer 0 only). - if h0 : 0 < cache.model.mlps.size then - let mlp0 : ConcreteMLPLayer := cache.model.mlps[0]'h0 - let dIn := mlp0.W_in.opNormUpperBoundRectGramDiag - let dOut := mlp0.W_out.opNormUpperBoundRectGramDiag - let chosenMsg (d : _) : String := - if d.usedGram then "chosen=signedGram" - else if d.usedAbsGram then "chosen=absGram" - else "chosen=cheap" - let signedGramMsg (d : _) : String := - if !d.signedGramEnabled then "signedGram=disabled" - else if d.gramDim > d.maxGramDimCap then "signedGram=skipped(maxGramDim cap)" - else if d.skippedGram then "signedGram=skipped(cost guard)" - else if d.computedGram then "signedGram=computed" - else "signedGram=not-attempted" - let absGramMsg (d : _) : String := - if !d.computedAbsGram then "absGram=disabled" - else if d.usedAbsGram then "absGram=chosen" - else "absGram=computed" - IO.println "" - IO.println "RECT-GRAM DIAGNOSTICS (MLP layer 0 weights)" - IO.println <| - s!" W_in: usedGram={dIn.usedGram} usedAbsGram={dIn.usedAbsGram} " ++ - s!"computedGram={dIn.computedGram} computedAbsGram={dIn.computedAbsGram} " ++ - s!"skippedGram={dIn.skippedGram}" - IO.println <| - s!" gramDim={dIn.gramDim} maxGramDimCap={dIn.maxGramDimCap} " ++ - s!"signedGramEnabled={dIn.signedGramEnabled}" - IO.println <| - s!" gramCost={dIn.gramCost} gramCostLimit={dIn.gramCostLimit} " ++ - s!"{chosenMsg dIn} {signedGramMsg dIn} {absGramMsg dIn}" - IO.println <| - s!" frob={fmtFloat dIn.frobBound} oneInf={fmtFloat dIn.oneInfBound} " ++ - s!"opBound={fmtFloat dIn.opBound}" - IO.println <| - s!" λ_abs_gersh={fmtFloat dIn.lambdaAbsGersh} " ++ - s!"λ_abs_brauer={fmtFloat dIn.lambdaAbsBrauer}" - IO.println s!" λ_gersh={fmtFloat dIn.lambdaGersh}" - IO.println s!" 
λ_brauer={fmtFloat dIn.lambdaBrauer}" - IO.println s!" λ_moment={fmtFloat dIn.lambdaMoment}" - IO.println s!" λ_used={fmtFloat dIn.lambdaUsed}" - IO.println <| - s!" W_out: usedGram={dOut.usedGram} usedAbsGram={dOut.usedAbsGram} " ++ - s!"computedGram={dOut.computedGram} computedAbsGram={dOut.computedAbsGram} " ++ - s!"skippedGram={dOut.skippedGram}" - IO.println <| - s!" gramDim={dOut.gramDim} maxGramDimCap={dOut.maxGramDimCap} " ++ - s!"signedGramEnabled={dOut.signedGramEnabled}" - IO.println <| - s!" gramCost={dOut.gramCost} gramCostLimit={dOut.gramCostLimit} " ++ - s!"{chosenMsg dOut} {signedGramMsg dOut} {absGramMsg dOut}" - IO.println <| - s!" frob={fmtFloat dOut.frobBound} oneInf={fmtFloat dOut.oneInfBound} " ++ - s!"opBound={fmtFloat dOut.opBound}" - IO.println <| - s!" λ_abs_gersh={fmtFloat dOut.lambdaAbsGersh} " ++ - s!"λ_abs_brauer={fmtFloat dOut.lambdaAbsBrauer}" - IO.println s!" λ_gersh={fmtFloat dOut.lambdaGersh}" - IO.println s!" λ_brauer={fmtFloat dOut.lambdaBrauer}" - IO.println s!" λ_moment={fmtFloat dOut.lambdaMoment}" - IO.println s!" λ_used={fmtFloat dOut.lambdaUsed}" - IO.println "" - IO.println s!"DIAGNOSTICS (ε decomposition) for top-{diagN} candidates" - for h in (top.take diagN) do - let c := h.candidate - IO.println s!"Candidate L{c.layer1Idx}H{c.head1Idx} -> L{c.layer2Idx}H{c.head2Idx}" - match cache.getHeadData c.layer1Idx c.head1Idx, - cache.getHeadData c.layer2Idx c.head2Idx with - | some d1, some d2 => - printHeadDiagnostics "Head1" d1 - printHeadDiagnostics "Head2" d2 - let ε1 := d1.faithfulnessRatioCached - let ε2 := d2.faithfulnessRatioCached - let combinedRecon := ε1 + ε2 + ε1 * ε2 - IO.println " Combined:" - IO.println s!" ε1 = {fmtFloat ε1} ε2 = {fmtFloat ε2}" - IO.println s!" combinedError = (1+ε1)(1+ε2)-1 = {fmtFloat combinedRecon}" - IO.println s!" combinedErrorCached = {fmtFloat c.combinedError}" - IO.println s!" 
reconDiff = {fmtFloat (combinedRecon - c.combinedError)}" - | _, _ => - IO.println " (diagnostics unavailable: missing cached head data)" - return none - -private def runInductionVerification - (model : ConcreteModel) - (heads : Array HeuristicInductionHead) - (correctOpt : Option Nat) : IO (Option UInt32) := do - IO.println "" - IO.println "Causal Verification (Head Ablation)" - IO.println "Metric: Δ = logit(target) - logit(top non-target) at last position" - IO.println "Impact = Δ_base - Δ_ablated" - IO.println "" - let targetToken? : Option Nat := - match correctOpt with - | some t => some t - | none => inductionTargetTokenFromHistory model - let some targetToken := targetToken? - | do - IO.eprintln "Error: Cannot infer induction target token (need TOKENS or --correct)." - return some 2 - match VerificationContext.build model targetToken {} with - | .error msg => - IO.eprintln s!"Error: {msg}" - return some 2 - | .ok ctx => - let verifyTop := heads.take 10 - IO.println s!"Top-{verifyTop.size} ablation checks (ranked by mechScore):" - IO.println s!" target={ctx.targetToken} | neg={ctx.negativeToken}" - IO.println s!" Δ_base={ctx.baseDelta}" - IO.println "" - IO.println "Rank | Candidate | Base Δ | Ablated Δ | Impact (Logits) | RelScore | \ - Control Impact | Axioms Verified?" 
- IO.println "-----|-----------|--------|----------|----------------|----------|\ - --------------|----------------" - let fmtOpt (x : Option Float) : String := - match x with - | some v => toString v - | none => "undef" - let mut rank : Nat := 1 - for h in verifyTop do - let c := h.candidate - let candHeads : Array HeadRef := #[ - { layerIdx := c.layer1Idx, headIdx := c.head1Idx }, - { layerIdx := c.layer2Idx, headIdx := c.head2Idx } - ] - let row := verifyCircuit ctx candHeads - let axiomsStr := - if row.axioms.verified then - "yes" - else - let reasons := String.intercalate "; " row.axioms.failures.toList - if reasons.isEmpty then "no" else s!"no ({reasons})" - IO.println s!"{rank} | {row.candidateLabel} | {row.baseDelta} | \ - {fmtOpt row.ablatedDelta} | {fmtOpt row.impact} | {fmtOpt row.relScore} | \ - {fmtOpt row.controlImpact} | {axiomsStr}" - rank := rank + 1 - return none - -/-- Run the induction command - discover induction heads ranked by effectiveness. -/ -def runInduction (p : Parsed) : IO UInt32 := do - let some args ← parseInductionArgs p - | return 1 - setStdoutLogNameFromModelPath args.modelPathStr - IO.println "Loading model..." - let loadResult ← loadModel args.modelPath - match loadResult with - | .error msg => - IO.eprintln s!"Error loading model: {msg}" - return 1 - | .ok model0 => - let model := model0.trimTrailingZeroEmbeddings - if args.verbose && model.seqLen ≠ model0.seqLen then - IO.println s!" Trimmed trailing zero embeddings: seqLen {model0.seqLen} -> {model.seqLen}" - match model.unembedding with - | none => - IO.eprintln "Error: Model is missing unembedding matrix (needed for logit directions)." - return 1 - | some W_U => - let target? := deriveInductionTarget model W_U args.correctOpt args.incorrectOpt - let some target := target? - | do - if args.correctOpt.isSome ∨ args.incorrectOpt.isSome then - IO.eprintln "Error: Use both --correct and --incorrect (or neither to auto-detect)." 
- return 1 - else - IO.eprintln "No valid induction target could be derived from TOKENS \ - (no prior repetition of last token)." - IO.eprintln "Hint: pass --correct/--incorrect to override, or export a prompt \ - where the last token repeats." - return 2 - if args.correctOpt.isNone ∧ args.incorrectOpt.isNone ∧ model.inputTokens.isNone then - IO.println "Warning: No TOKENS section found; using default target logit_diff(0-1)." - IO.println "Hint: export with TOKENS or pass --correct/--incorrect." - IO.println s!"Target: {target.description}" - printInductionSearchIntro args.minEffect - let buildLayerNormBounds := args.diagnostics && (!args.adaptive) - let (heads, cache) ← Nfp.timeIt "induction:search" (fun () => - pure <| - findHeuristicInductionHeadsWithCache model target args.minEffect - (minInductionScore := 0.01) - (buildLayerNormBounds := buildLayerNormBounds) - (storeDiagnostics := args.diagnostics)) - let top ← printInductionCandidates heads args.verbose - let sched? := buildAdaptiveScheduler cache args - if args.adaptive && args.verbose then - match sched? with - | some sched => printAdaptiveSchedulerSteps sched args - | none => pure () - if args.diagnostics then - let err? ← printInductionDiagnostics cache top args sched? - if let some code := err? then - return code - if args.verify then - let err? ← Nfp.timeIt "induction:verify" (fun () => - runInductionVerification model heads args.correctOpt) - if let some code := err? then - return code - return 0 - -/-! 
## Bench command helpers -/ - -private inductive BenchMode - | analyze - | induction - deriving Repr - -private def parseBenchMode (s : String) : Option BenchMode := - match s.trim.toLower with - | "analysis" => some .analyze - | "analyze" => some .analyze - | "induction" => some .induction - | "induce" => some .induction - | _ => none - -private structure BenchArgs where - modelPath : System.FilePath - modelPathStr : String - mode : BenchMode - runs : Nat - repeatCount : Nat - threshold : Float - minEffect : Float - correctOpt : Option Nat - incorrectOpt : Option Nat - verbose : Bool - breakdown : Bool - -private def parseBenchArgs (p : Parsed) : IO (Option BenchArgs) := do - let modelPathStr := p.positionalArg! "model" |>.as! String - let modeStr := p.flag? "mode" |>.map (·.as! String) |>.getD "analysis" - let some mode := parseBenchMode modeStr - | do - IO.eprintln s!"Error: Invalid --mode '{modeStr}' (analysis|induction)" - return none - let runs := p.flag? "runs" |>.map (·.as! Nat) |>.getD 5 - let repeatCount := p.flag? "repeats" |>.map (·.as! Nat) |>.getD 1 - let thresholdStr := p.flag? "threshold" |>.map (·.as! String) |>.getD "0.1" - let minEffectStr := p.flag? "minEffect" |>.map (·.as! String) |>.getD "0.0" - let some threshold := Nfp.parseFloat thresholdStr - | do - IO.eprintln s!"Error: Invalid --threshold '{thresholdStr}'" - return none - let some minEffect := Nfp.parseFloat minEffectStr - | do - IO.eprintln s!"Error: Invalid --minEffect '{minEffectStr}'" - return none - let correctOpt := p.flag? "correct" |>.map (·.as! Nat) - let incorrectOpt := p.flag? "incorrect" |>.map (·.as! 
Nat) - let verbose := p.hasFlag "verbose" - let breakdown := p.hasFlag "breakdown" - return some { - modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - mode := mode - runs := runs - repeatCount := repeatCount - threshold := threshold - minEffect := minEffect - correctOpt := correctOpt - incorrectOpt := incorrectOpt - verbose := verbose - breakdown := breakdown - } - -/-- Core analysis work for benchmarking (no IO). -/ -private def benchAnalyzeOnce (model : ConcreteModel) (threshold : Float) : Nat × Nat := - Id.run do - let cache := PrecomputedCache.build model - let deepCircuits := findDeepCircuitCandidatesFromCache cache - let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ threshold) - let mut verifiedHeads : Array CandidateInductionHead := Array.mkEmpty 0 - for circuit in deepCircuits do - match circuit.toInductionCandidateCore? cache with - | none => pure () - | some core => - if core.combinedError ≤ threshold then - match core.toInductionCandidate? cache with - | some cand => verifiedHeads := verifiedHeads.push cand - | none => pure () - let verifiedSorted := verifiedHeads.qsort (·.combinedError < ·.combinedError) - return (verifiedSorted.size, verifiedDeep.size) - -/-- Core induction-head search work for benchmarking (no IO). -/ -private def benchInductionOnce (model : ConcreteModel) (target : TargetDirection) - (minEffect : Float) : Nat := - let (heads, _) := - findHeuristicInductionHeadsWithCache model target minEffect - (minInductionScore := 0.01) - (buildLayerNormBounds := false) - (storeDiagnostics := false) - heads.size - -private def summarizeBenchTimes (label : String) (times : Array Nat) (repeatCount : Nat) : - IO Unit := do - let t0 := times[0]! 
- let mut minT := t0 - let mut maxT := t0 - let mut sumT : Nat := 0 - for t in times do - if t < minT then - minT := t - if t > maxT then - maxT := t - sumT := sumT + t - let avgT := sumT / times.size - IO.println <| - s!"{label} runs={times.size} repeat={repeatCount} " ++ - s!"min={minT}ms avg={avgT}ms max={maxT}ms" - -private def timeNs {α : Type} (action : Unit → IO α) : IO (α × Nat) := do - let t0 ← IO.monoNanosNow - let result ← action () - let t1 ← IO.monoNanosNow - let dtNs := t1 - t0 - return (result, dtNs) - -private def runBenchWithArgs (args : BenchArgs) : IO UInt32 := do - if args.runs = 0 then - IO.eprintln "Error: --runs must be > 0" - return 1 - if args.repeatCount = 0 then - IO.eprintln "Error: --repeats must be > 0" - return 1 - setStdoutLogNameFromModelPath args.modelPathStr - let loadResult ← loadModel args.modelPath - let model ← - match loadResult with - | .error msg => - IO.eprintln s!"Error loading model: {msg}" - return 1 - | .ok model0 => pure (model0.trimTrailingZeroEmbeddings) - match args.mode with - | .analyze => - let mut times : Array Nat := Array.mkEmpty args.runs - let mut lastHeads : Nat := 0 - let mut lastCircuits : Nat := 0 - let mut fwdNsTotal : Nat := 0 - let mut headNsTotal : Nat := 0 - let mut normNsTotal : Nat := 0 - let mut deepNsTotal : Nat := 0 - let mut candNsTotal : Nat := 0 - for i in [:args.runs] do - let t0 ← IO.monoNanosNow - if args.breakdown then - let mut localFwdNs : Nat := 0 - let mut localHeadNs : Nat := 0 - let mut localNormNs : Nat := 0 - let mut localDeepNs : Nat := 0 - let mut localCandNs : Nat := 0 - for _ in [:args.repeatCount] do - let (fwdResult, fwdNs) ← timeNs (fun () => - pure <| model.runForward true) - let ((headData, ln1Inputs), headNs) ← timeNs (fun () => - pure <| - PrecomputedCache.buildHeadData model fwdResult true - ConcreteMatrix.BoundEffort.tier1 false) - let baseBounds := Array.replicate model.numLayers 0.0 - let baseCache : PrecomputedCache := { - model := model - forwardResult := 
fwdResult - ln1Inputs := ln1Inputs - headData := headData - layerNormBounds := baseBounds - layerNormBoundsComputed := false - } - let (layerNormBounds, normNs) ← timeNs (fun () => - pure <| - PrecomputedCache.computeLayerNormBounds baseCache - ConcreteMatrix.BoundEffort.tier1) - let cache : PrecomputedCache := { - baseCache with - layerNormBounds := layerNormBounds - layerNormBoundsComputed := true - } - let (deepCircuits, deepNs) ← timeNs (fun () => - pure <| findDeepCircuitCandidatesFromCache cache) - let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ args.threshold) - let (verifiedHeads, candNs) ← timeNs (fun () => do - let mut verified : Array CandidateInductionHead := Array.mkEmpty 0 - for circuit in deepCircuits do - match circuit.toInductionCandidateCore? cache with - | none => pure () - | some core => - if core.combinedError ≤ args.threshold then - match core.toInductionCandidate? cache with - | some cand => verified := verified.push cand - | none => pure () - let verifiedSorted := verified.qsort (·.combinedError < ·.combinedError) - return verifiedSorted) - localFwdNs := localFwdNs + fwdNs - localHeadNs := localHeadNs + headNs - localNormNs := localNormNs + normNs - localDeepNs := localDeepNs + deepNs - localCandNs := localCandNs + candNs - lastHeads := verifiedHeads.size - lastCircuits := verifiedDeep.size - fwdNsTotal := fwdNsTotal + localFwdNs - headNsTotal := headNsTotal + localHeadNs - normNsTotal := normNsTotal + localNormNs - deepNsTotal := deepNsTotal + localDeepNs - candNsTotal := candNsTotal + localCandNs - else - for _ in [:args.repeatCount] do - let (heads, circuits) := benchAnalyzeOnce model args.threshold - lastHeads := heads - lastCircuits := circuits - let t1 ← IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - times := times.push dtMs - if args.verbose then - IO.println s!"run {i + 1}: {dtMs}ms heads={lastHeads} circuits={lastCircuits}" - summarizeBenchTimes "bench:analysis" times args.repeatCount - if args.breakdown then - let 
runs := args.runs - let repeatCount := args.repeatCount - let fwdAvgNs := fwdNsTotal / (runs * repeatCount) - let headAvgNs := headNsTotal / (runs * repeatCount) - let normAvgNs := normNsTotal / (runs * repeatCount) - let deepAvgNs := deepNsTotal / (runs * repeatCount) - let candAvgNs := candNsTotal / (runs * repeatCount) - let fwdAvgUs := fwdAvgNs / 1000 - let headAvgUs := headAvgNs / 1000 - let normAvgUs := normAvgNs / 1000 - let deepAvgUs := deepAvgNs / 1000 - let candAvgUs := candAvgNs / 1000 - IO.println <| - s!"bench:analysis fwdAvg={fwdAvgUs}us headAvg={headAvgUs}us " ++ - s!"normAvg={normAvgUs}us scanAvg={deepAvgUs}us " ++ - s!"candAvg={candAvgUs}us" - IO.println <| - s!"bench:analysis threshold={args.threshold} heads={lastHeads} " ++ - s!"circuits={lastCircuits}" - return 0 - | .induction => - let some W_U := model.unembedding - | do - IO.eprintln "Error: Model is missing unembedding matrix (needed for target direction)." - return 1 - let target? := deriveInductionTarget model W_U args.correctOpt args.incorrectOpt - let some target := target? - | do - IO.eprintln "Error: Use both --correct and --incorrect (or ensure TOKENS are present)." - return 1 - let mut times : Array Nat := Array.mkEmpty args.runs - let mut lastHeads : Nat := 0 - for i in [:args.runs] do - let t0 ← IO.monoNanosNow - for _ in [:args.repeatCount] do - let heads := benchInductionOnce model target args.minEffect - lastHeads := heads - let t1 ← IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - times := times.push dtMs - if args.verbose then - IO.println s!"run {i + 1}: {dtMs}ms heads={lastHeads}" - summarizeBenchTimes "bench:induction" times args.repeatCount - IO.println s!"bench:induction minEffect={args.minEffect} heads={lastHeads}" - return 0 - -/-- Run the bench command for repeatable performance measurements. -/ -def runBench (p : Parsed) : IO UInt32 := do - let some args ← parseBenchArgs p - | return 1 - runBenchWithArgs args - -/-! 
## SOUND command helpers -/ - -private structure CertifyArgs where - modelPath : System.FilePath - modelPathStr : String - inputPath? : Option System.FilePath - soundnessBits : Nat - partitionDepth : Nat - deltaFlag? : Option String - deltaStr : String - softmaxMarginStr : String - softmaxExpEffort : Nat - bestMatchMargins : Bool - targetOffset : Int - maxSeqLen : Nat - tightPattern : Bool - tightPatternLayers : Nat - perRowPatternLayers : Nat - causalPattern : Bool - scalePow10 : Nat - outputPath? : Option System.FilePath - -private def parseCertifyArgs (p : Parsed) : CertifyArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let inputPath? := p.flag? "input" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! Nat) |>.getD 20 - let partitionDepth := p.flag? "partitionDepth" |>.map (·.as! Nat) |>.getD 0 - let deltaFlag? := p.flag? "delta" |>.map (·.as! String) - let deltaStr := deltaFlag?.getD "0" - let softmaxMarginStr := p.flag? "softmaxMargin" |>.map (·.as! String) |>.getD "0" - let softmaxExpEffort := - p.flag? "softmaxExpEffort" |>.map (·.as! Nat) |>.getD Nfp.Sound.defaultSoftmaxExpEffort - let bestMatchMargins := p.flag? "bestMatchMargins" |>.isSome - let targetOffset := p.flag? "targetOffset" |>.map (·.as! Int) |>.getD (-1) - let maxSeqLen := p.flag? "maxSeqLen" |>.map (·.as! Nat) |>.getD 0 - let tightPattern := p.flag? "tightPattern" |>.isSome - let tightPatternLayers := p.flag? "tightPatternLayers" |>.map (·.as! Nat) |>.getD 1 - let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 - let causalPattern := !p.hasFlag "noncausalPattern" - let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 - let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - { modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - inputPath? := inputPath? - soundnessBits := soundnessBits - partitionDepth := partitionDepth - deltaFlag? 
:= deltaFlag? - deltaStr := deltaStr - softmaxMarginStr := softmaxMarginStr - softmaxExpEffort := softmaxExpEffort - bestMatchMargins := bestMatchMargins - targetOffset := targetOffset - maxSeqLen := maxSeqLen - tightPattern := tightPattern - tightPatternLayers := tightPatternLayers - perRowPatternLayers := perRowPatternLayers - causalPattern := causalPattern - scalePow10 := scalePow10 - outputPath? := outputPath? } - -private def runCertifyAction (args : CertifyArgs) : ExceptT String IO Nfp.Sound.ModelCert := do - let delta ← - match Nfp.Sound.parseRat args.deltaStr with - | .ok r => pure r - | .error e => throw s!"invalid --delta '{args.deltaStr}': {e}" - let softmaxMarginLowerBound ← - match Nfp.Sound.parseRat args.softmaxMarginStr with - | .ok r => pure r - | .error e => throw s!"invalid --softmaxMargin '{args.softmaxMarginStr}': {e}" - /- If `--input` is omitted but `--delta` is provided, try to use `modelPath` as the input file - (for `.nfpt` exports that embed `EMBEDDINGS` in the same file). This keeps `nfp certify` - ergonomic without changing the default behavior when `--delta` is absent. -/ - let inputPath? : Option System.FilePath ← - match args.inputPath? with - | some path => pure (some path) - | none => - match args.deltaFlag? with - | none => pure none - | some _ => - let hasEmbeddings ← - hasEmbeddingsBeforeLayers args.modelPath - if hasEmbeddings then - pure (some args.modelPath) - else - throw <| - "local certification requested via --delta, but the model file has no \ -EMBEDDINGS section before the first LAYER (legacy text format). Pass --input \ -containing EMBEDDINGS or omit --delta for global certification." - let inputPath? ← - if args.bestMatchMargins && inputPath?.isNone then - let hasEmbeddings ← hasEmbeddingsBeforeLayers args.modelPath - if hasEmbeddings then - pure (some args.modelPath) - else - throw <| - "best-match margin tightening requires local input with EMBEDDINGS. \ -Pass --input or use a model file that embeds EMBEDDINGS." 
- else - pure inputPath? - if args.bestMatchMargins && softmaxMarginLowerBound != 0 then - throw "best-match margin tightening is incompatible with --softmaxMargin" - if args.bestMatchMargins then - let cert ← ExceptT.mk <| - Nfp.Sound.certifyModelFileBestMatchMargins args.modelPath args.soundnessBits - (inputPath? := inputPath?) (inputDelta := delta) (partitionDepth := args.partitionDepth) - (targetOffset := args.targetOffset) (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) (scalePow10 := args.scalePow10) - (softmaxExpEffort := args.softmaxExpEffort) (causalPattern := args.causalPattern) - return cert - else - let cert ← ExceptT.mk <| - Nfp.Sound.certifyModelFile args.modelPath args.soundnessBits - (inputPath? := inputPath?) (inputDelta := delta) (partitionDepth := args.partitionDepth) - (softmaxMarginLowerBound := softmaxMarginLowerBound) - (softmaxExpEffort := args.softmaxExpEffort) - return cert - -private structure HeadBoundsArgs where - modelPath : System.FilePath - inputPath? : Option System.FilePath - deltaFlag? : Option String - deltaStr : String - soundnessBits : Nat - scalePow10 : Nat - outputPath? : Option System.FilePath - -private def parseHeadBoundsArgs (p : Parsed) : HeadBoundsArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let inputPath? := p.flag? "input" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - let deltaFlag? := p.flag? "delta" |>.map (·.as! String) - let deltaStr := deltaFlag?.getD "0" - let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! Nat) |>.getD 20 - let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 - let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - { modelPath := ⟨modelPathStr⟩ - inputPath? := inputPath? - deltaFlag? := deltaFlag? - deltaStr := deltaStr - soundnessBits := soundnessBits - scalePow10 := scalePow10 - outputPath? 
:= outputPath? } - -private def formatHeadBoundsLocal - (heads : Array Nfp.Sound.HeadLocalContributionCert) - (delta : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath) - (modelPath : System.FilePath) : String := - let header := - "SOUND per-head bounds (local): " ++ - s!"delta={delta}, soundnessBits={soundnessBits}, input={inputPath?.getD modelPath}\n" - let body := - heads.foldl (fun acc h => - acc ++ - s!"Layer {h.layerIdx} Head {h.headIdx}: " ++ - s!"ln1MaxAbsGamma={h.ln1MaxAbsGamma}, " ++ - s!"ln1VarLB={h.ln1VarianceLowerBound}, " ++ - s!"ln1Bound={h.ln1Bound}, " ++ - s!"wqOp={h.wqOpBound}, wkOp={h.wkOpBound}, " ++ - s!"qk={h.qkFactorBound}, " ++ - s!"softmaxJacobianNormInfUB={h.softmaxJacobianNormInfUpperBound}, " ++ - s!"wvOp={h.wvOpBound}, woOp={h.woOpBound}, " ++ - s!"attn={h.attnJacBound}\n") "" - header ++ body - -private def formatHeadBoundsGlobal - (heads : Array Nfp.Sound.HeadContributionCert) - (scalePow10 : Nat) : String := - let header := - s!"SOUND per-head bounds (weight-only): scalePow10={scalePow10}\n" - let body := - heads.foldl (fun acc h => - acc ++ - s!"Layer {h.layerIdx} Head {h.headIdx}: " ++ - s!"wqOp={h.wqOpBound}, wkOp={h.wkOpBound}, " ++ - s!"wvOp={h.wvOpBound}, woOp={h.woOpBound}, " ++ - s!"qk={h.qkFactorBound}, vo={h.voFactorBound}\n") "" - header ++ body - -private def runHeadBoundsAction (args : HeadBoundsArgs) : ExceptT String IO String := do - let delta ← - match Nfp.Sound.parseRat args.deltaStr with - | .ok r => pure r - | .error e => throw s!"invalid --delta '{args.deltaStr}': {e}" - let useLocal := (args.inputPath?.isSome || args.deltaFlag?.isSome) - if useLocal then - let inputPath? : Option System.FilePath ← - match args.inputPath? 
with - | some path => pure (some path) - | none => - let hasEmbeddings ← - hasEmbeddingsBeforeLayers args.modelPath - if hasEmbeddings then - pure (some args.modelPath) - else - throw <| - "local head bounds requested via --delta, but the model file has no \ -EMBEDDINGS section before the first LAYER (legacy text format). Pass --input \ -containing EMBEDDINGS or omit --delta for global head bounds." - let heads ← ExceptT.mk <| - Nfp.Sound.certifyHeadBoundsLocal args.modelPath - (inputPath? := inputPath?) (inputDelta := delta) (soundnessBits := args.soundnessBits) - return formatHeadBoundsLocal heads delta args.soundnessBits inputPath? args.modelPath - else - let heads ← ExceptT.mk <| - Nfp.Sound.certifyHeadBounds args.modelPath (scalePow10 := args.scalePow10) - return formatHeadBoundsGlobal heads args.scalePow10 - -private structure HeadPatternArgs where - modelPath : System.FilePath - layerIdx : Nat - headIdx : Nat - offset : Int - keyOffset : Int - soundnessBits : Nat - softmaxExpEffort : Nat - tightPatternLayers : Nat - tightPattern : Bool - perRowPatternLayers : Nat - causalPattern : Bool - bestMatch : Bool - useAffine : Bool - sweep : Bool - queryPos? : Option Nat - inputPath? : Option System.FilePath - deltaStr : String - maxSeqLen : Nat - outputPath? : Option System.FilePath - -private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let layerIdx := p.flag? "layer" |>.map (·.as! Nat) |>.getD 0 - let headIdx := p.flag? "head" |>.map (·.as! Nat) |>.getD 0 - let offset := p.flag? "offset" |>.map (·.as! Int) |>.getD (-1) - let keyOffset := p.flag? "keyOffset" |>.map (·.as! Int) |>.getD 0 - let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! Nat) |>.getD 20 - let softmaxExpEffort := - p.flag? "softmaxExpEffort" |>.map (·.as! Nat) - |>.getD Nfp.Sound.defaultSoftmaxExpEffort - let tightPatternLayers? := p.flag? "tightPatternLayers" |>.map (·.as! 
Nat) - let tightPatternLayers := tightPatternLayers?.getD 1 - let tightPattern := p.hasFlag "tightPattern" || tightPatternLayers?.isSome - let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 - let causalPattern := !p.hasFlag "noncausalPattern" - let bestMatch := p.hasFlag "bestMatch" - let useAffine := p.hasFlag "affine" - let sweep := p.hasFlag "sweep" - let queryPos? := p.flag? "queryPos" |>.map (·.as! Nat) - let inputPath? := p.flag? "input" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - let deltaStr := p.flag? "delta" |>.map (·.as! String) |>.getD "0" - let maxSeqLen := p.flag? "maxSeqLen" |>.map (·.as! Nat) |>.getD 256 - let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - { modelPath := ⟨modelPathStr⟩ - layerIdx := layerIdx - headIdx := headIdx - offset := offset - keyOffset := keyOffset - soundnessBits := soundnessBits - softmaxExpEffort := softmaxExpEffort - tightPatternLayers := tightPatternLayers - tightPattern := tightPattern - perRowPatternLayers := perRowPatternLayers - causalPattern := causalPattern - bestMatch := bestMatch - useAffine := useAffine - sweep := sweep - queryPos? := queryPos? - inputPath? := inputPath? - deltaStr := deltaStr - maxSeqLen := maxSeqLen - outputPath? := outputPath? 
} - -private def formatHeadPatternBestMatchSweep - (layerIdx headIdx : Nat) - (offset : Int) - (keyOffset : Int) - (certs : Array Nfp.Sound.HeadBestMatchPatternCert) : String := - let header := - "SOUND head pattern sweep (best-match): " ++ - s!"layer={layerIdx}, head={headIdx}, offset={offset}, keyOffset={keyOffset}\n" - let body := - certs.foldl (fun acc cert => - acc ++ - s!"queryPos={cert.queryPos} targetTok={cert.targetToken} " ++ - s!"marginLB={cert.marginLowerBound} " ++ - s!"weightLB={cert.bestMatchWeightLowerBound}\n") "" - header ++ body - -private def formatHeadPatternBestMatch - (cert : Nfp.Sound.HeadBestMatchPatternCert) : String := - "SOUND head pattern (best-match): " ++ - s!"layer={cert.layerIdx}, head={cert.headIdx}, " ++ - s!"offset={cert.targetOffset}, keyOffset={cert.keyOffset}, " ++ - s!"queryPos={cert.queryPos}\n" ++ - s!"seqLen={cert.seqLen}, targetTok={cert.targetToken}, " ++ - s!"bestMatchLogitLB={cert.bestMatchLogitLowerBound}, " ++ - s!"bestNonmatchLogitUB={cert.bestNonmatchLogitUpperBound}\n" ++ - s!"marginLB={cert.marginLowerBound}, " ++ - s!"bestMatchWeightLB={cert.bestMatchWeightLowerBound}, " ++ - s!"softmaxExpEffort={cert.softmaxExpEffort}\n" - -private def formatHeadPatternLocal - (cert : Nfp.Sound.HeadPatternCert) : String := - "SOUND head pattern (local): " ++ - s!"layer={cert.layerIdx}, head={cert.headIdx}, " ++ - s!"offset={cert.targetOffset}, keyOffset={cert.keyOffset}\n" ++ - s!"seqLen={cert.seqLen}, " ++ - s!"targetCountLB={cert.targetCountLowerBound}, " ++ - s!"targetLogitLB={cert.targetLogitLowerBound}, " ++ - s!"otherLogitUB={cert.otherLogitUpperBound}\n" ++ - s!"marginLB={cert.marginLowerBound}, " ++ - s!"targetWeightLB={cert.targetWeightLowerBound}, " ++ - s!"softmaxExpEffort={cert.softmaxExpEffort}\n" - -private def runHeadPatternAction (args : HeadPatternArgs) : ExceptT String IO String := do - let delta ← - match Nfp.Sound.parseRat args.deltaStr with - | .ok r => pure r - | .error e => throw s!"invalid --delta 
'{args.deltaStr}': {e}" - let inputPath? : Option System.FilePath ← - match args.inputPath? with - | some path => pure (some path) - | none => - let hasEmbeddings ← - hasEmbeddingsBeforeLayers args.modelPath - if hasEmbeddings then - pure (some args.modelPath) - else - throw <| - "head pattern bounds require EMBEDDINGS; pass --input for legacy text models." - if args.useAffine && !args.bestMatch then - throw "affine bounds are only supported with --bestMatch" - if args.useAffine && args.sweep then - throw "affine sweep is unsupported; use --bestMatch without --sweep" - if args.bestMatch then - if args.sweep then - let certs ← ExceptT.mk <| - Nfp.Sound.certifyHeadPatternBestMatchLocalSweep args.modelPath args.layerIdx args.headIdx - (inputPath? := inputPath?) (inputDelta := delta) - (soundnessBits := args.soundnessBits) (targetOffset := args.offset) - (keyOffset := args.keyOffset) - (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) - (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (useAffine := args.useAffine) - (softmaxExpEffort := args.softmaxExpEffort) - (causalPattern := args.causalPattern) - return formatHeadPatternBestMatchSweep args.layerIdx args.headIdx args.offset - args.keyOffset certs - else - let cert ← ExceptT.mk <| - Nfp.Sound.certifyHeadPatternBestMatchLocal args.modelPath args.layerIdx args.headIdx - (queryPos? := args.queryPos?) (inputPath? := inputPath?) 
- (inputDelta := delta) (soundnessBits := args.soundnessBits) - (targetOffset := args.offset) (keyOffset := args.keyOffset) - (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (useAffine := args.useAffine) - (softmaxExpEffort := args.softmaxExpEffort) - (causalPattern := args.causalPattern) - return formatHeadPatternBestMatch cert - else - if args.sweep then - throw "use --sweep with --bestMatch" - let cert ← ExceptT.mk <| - Nfp.Sound.certifyHeadPatternLocal args.modelPath args.layerIdx args.headIdx - (inputPath? := inputPath?) (inputDelta := delta) - (soundnessBits := args.soundnessBits) (targetOffset := args.offset) - (keyOffset := args.keyOffset) - (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (softmaxExpEffort := args.softmaxExpEffort) - (causalPattern := args.causalPattern) - return formatHeadPatternLocal cert - -/-- Run the certify command - compute conservative, exact bounds in sound mode. -/ -def runCertify (p : Parsed) : IO UInt32 := do - let args := parseCertifyArgs p - setStdoutLogNameFromModelPath args.modelPathStr - match ← (runCertifyAction args).run with - | .error msg => - IO.eprintln s!"Error: {msg}" - return 1 - | .ok cert => - writeReport args.outputPath? (toString cert) - return 0 - -/-- Run the head-bounds command - compute per-head contribution bounds in sound mode. -/ -def runHeadBounds (p : Parsed) : IO UInt32 := do - let args := parseHeadBoundsArgs p - match ← (runHeadBoundsAction args).run with - | .error msg => - IO.eprintln s!"Error: {msg}" - return 1 - | .ok s => - writeReport args.outputPath? s - return 0 - -/-- Run the head-pattern command - compute per-head attention pattern bounds in sound mode. 
-/ -def runHeadPattern (p : Parsed) : IO UInt32 := do - let args := parseHeadPatternArgs p - match ← (runHeadPatternAction args).run with - | .error msg => - IO.eprintln s!"Error: {msg}" - return 1 - | .ok s => - writeReport args.outputPath? s - return 0 - -/-- Parsed arguments for `induction-cert`, with resolved input path and delta. -/ -private structure InductionCertArgs where - modelPath : System.FilePath - layer1 : Nat - head1 : Nat - layer2 : Nat - head2 : Nat - coord : Nat - offset1 : Int - offset2 : Int - keyOffset1 : Int - keyOffset2 : Int - targetToken? : Option Nat - negativeToken? : Option Nat - soundnessBits : Nat - softmaxExpEffort : Nat - tightPatternLayers : Nat - tightPattern : Bool - perRowPatternLayers : Nat - iterTighten : Bool - causalPattern : Bool - bestMatch : Bool - useAffine : Bool - queryPos? : Option Nat - inputPath? : Option System.FilePath - delta : Rat - maxSeqLen : Nat - scalePow10 : Nat - outputPath? : Option System.FilePath - -/-- Parse and validate `induction-cert` arguments. -/ -private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCertArgs := do - let modelPathStr := p.positionalArg! "model" |>.as! String - let modelPath : System.FilePath := ⟨modelPathStr⟩ - let layer1 := p.flag? "layer1" |>.map (·.as! Nat) |>.getD 0 - let head1 := p.flag? "head1" |>.map (·.as! Nat) |>.getD 0 - let layer2 := p.flag? "layer2" |>.map (·.as! Nat) |>.getD 1 - let head2 := p.flag? "head2" |>.map (·.as! Nat) |>.getD 0 - let coord := p.flag? "coord" |>.map (·.as! Nat) |>.getD 0 - let offset1 := p.flag? "offset1" |>.map (·.as! Int) |>.getD (-1) - let offset2 := p.flag? "offset2" |>.map (·.as! Int) |>.getD (-1) - let keyOffset1 := p.flag? "keyOffset1" |>.map (·.as! Int) |>.getD 0 - let keyOffset2 := p.flag? "keyOffset2" |>.map (·.as! Int) |>.getD 0 - let targetToken := p.flag? "target" |>.map (·.as! Nat) - let negativeToken := p.flag? "negative" |>.map (·.as! Nat) - let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! 
Nat) |>.getD 20 - let softmaxExpEffort := - p.flag? "softmaxExpEffort" |>.map (·.as! Nat) - |>.getD Nfp.Sound.defaultSoftmaxExpEffort - let tightPatternLayers? := p.flag? "tightPatternLayers" |>.map (·.as! Nat) - let tightPatternLayers := tightPatternLayers?.getD 1 - let tightPattern := p.hasFlag "tightPattern" || tightPatternLayers?.isSome - let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 - let iterTighten := p.hasFlag "iterTighten" - let causalPattern := !p.hasFlag "noncausalPattern" - let bestMatch := p.hasFlag "bestMatch" - let useAffine := p.hasFlag "affine" - let queryPos := p.flag? "queryPos" |>.map (·.as! Nat) - let inputPath := p.flag? "input" |>.map (·.as! String) - let deltaStr := p.flag? "delta" |>.map (·.as! String) |>.getD "0" - let maxSeqLen := p.flag? "maxSeqLen" |>.map (·.as! Nat) |>.getD 256 - let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 - let outputPath := p.flag? "output" |>.map (·.as! String) - let delta ← - match Nfp.Sound.parseRat deltaStr with - | .ok r => pure r - | .error e => throw s!"invalid --delta '{deltaStr}': {e}" - let inputPath? : Option System.FilePath ← - match inputPath with - | some s => pure (some ⟨s⟩) - | none => - let hasEmbeddings ← - hasEmbeddingsBeforeLayers modelPath - if hasEmbeddings then - pure (some modelPath) - else - throw <| - "induction cert requires EMBEDDINGS; pass --input for legacy text models." - let outputPath? : Option System.FilePath := - outputPath.map (fun s => ⟨s⟩) - return { - modelPath := modelPath - layer1 := layer1 - head1 := head1 - layer2 := layer2 - head2 := head2 - coord := coord - offset1 := offset1 - offset2 := offset2 - keyOffset1 := keyOffset1 - keyOffset2 := keyOffset2 - targetToken? := targetToken - negativeToken? 
:= negativeToken - soundnessBits := soundnessBits - softmaxExpEffort := softmaxExpEffort - tightPatternLayers := tightPatternLayers - tightPattern := tightPattern - perRowPatternLayers := perRowPatternLayers - iterTighten := iterTighten - causalPattern := causalPattern - bestMatch := bestMatch - useAffine := useAffine - queryPos? := queryPos - inputPath? := inputPath? - delta := delta - maxSeqLen := maxSeqLen - scalePow10 := scalePow10 - outputPath? := outputPath? - } - -/-- Format the optional logit-diff line for local induction certs. -/ -private def formatInductionLogitLine - (logit? : Option Nfp.Sound.HeadLogitDiffLowerBoundCert) : String := - match logit? with - | none => "" - | some logit => - s!"logitDiffLB={logit.logitDiffLowerBound} " ++ - s!"targetTok={logit.targetToken} negTok={logit.negativeToken}\n" ++ - s!"logitMatchLB={logit.matchLogitLowerBound} " ++ - s!"logitNonmatchLB={logit.nonmatchLogitLowerBound}\n" - -/-- Format the optional logit-diff line for best-match induction certs. -/ -private def formatInductionLogitLinePos - (logit? : Option Nfp.Sound.HeadLogitDiffLowerBoundPosCert) : String := - match logit? with - | none => "" - | some logit => - s!"logitDiffLB={logit.logitDiffLowerBound} " ++ - s!"targetTok={logit.targetToken} negTok={logit.negativeToken}\n" ++ - s!"logitMatchLB={logit.matchLogitLowerBound} " ++ - s!"logitNonmatchLB={logit.nonmatchLogitLowerBound}\n" - -/-- Render a best-match induction certificate report. -/ -private def formatInductionBestMatch - (cert : Nfp.Sound.InductionHeadBestMatchSoundCert) : String := - let p1 := cert.layer1Pattern - let p2 := cert.layer2Pattern - let v := cert.layer2Value - let logitLine := formatInductionLogitLinePos cert.layer2Logit? 
- "SOUND induction cert (best-match):\n" ++ - s!"queryPos={p2.queryPos}\n" ++ - s!"layer1=L{p1.layerIdx} H{p1.headIdx} offset={p1.targetOffset} " ++ - s!"keyOffset={p1.keyOffset} " ++ - s!"targetTok={p1.targetToken} " ++ - s!"marginLB={p1.marginLowerBound} " ++ - s!"weightLB={p1.bestMatchWeightLowerBound} " ++ - s!"softmaxExpEffort={p1.softmaxExpEffort}\n" ++ - s!"layer2=L{p2.layerIdx} H{p2.headIdx} offset={p2.targetOffset} " ++ - s!"keyOffset={p2.keyOffset} " ++ - s!"targetTok={p2.targetToken} " ++ - s!"marginLB={p2.marginLowerBound} " ++ - s!"weightLB={p2.bestMatchWeightLowerBound} " ++ - s!"softmaxExpEffort={p2.softmaxExpEffort}\n" ++ - s!"coord={v.coord} matchCoordLB={v.matchCoordLowerBound} " ++ - s!"nonmatchCoordLB={v.nonmatchCoordLowerBound}\n" ++ - s!"deltaLB={cert.deltaLowerBound}\n" ++ - logitLine - -/-- Render a local induction certificate report. -/ -private def formatInductionLocal - (cert : Nfp.Sound.InductionHeadSoundCert) : String := - let p1 := cert.layer1Pattern - let p2 := cert.layer2Pattern - let v := cert.layer2Value - let logitLine := formatInductionLogitLine cert.layer2Logit? - "SOUND induction cert:\n" ++ - s!"layer1=L{p1.layerIdx} H{p1.headIdx} offset={p1.targetOffset} " ++ - s!"keyOffset={p1.keyOffset} " ++ - s!"marginLB={p1.marginLowerBound} weightLB={p1.targetWeightLowerBound} " ++ - s!"softmaxExpEffort={p1.softmaxExpEffort}\n" ++ - s!"layer2=L{p2.layerIdx} H{p2.headIdx} offset={p2.targetOffset} " ++ - s!"keyOffset={p2.keyOffset} " ++ - s!"marginLB={p2.marginLowerBound} weightLB={p2.targetWeightLowerBound} " ++ - s!"softmaxExpEffort={p2.softmaxExpEffort}\n" ++ - s!"coord={v.coord} matchCountLB={p2.targetCountLowerBound} " ++ - s!"matchCoordLB={v.matchCoordLowerBound} " ++ - s!"nonmatchCoordLB={v.nonmatchCoordLowerBound}\n" ++ - s!"deltaLB={cert.deltaLowerBound}\n" ++ - logitLine - -/-- Run the induction-cert action and return the report string. 
-/ -private def runInductionCertAction (args : InductionCertArgs) : ExceptT String IO String := do - if args.useAffine && !args.bestMatch then - throw "affine bounds are only supported with --bestMatch" - if args.bestMatch then - let cert ← ExceptT.mk <| - Nfp.Sound.certifyInductionSoundBestMatch args.modelPath - args.layer1 args.head1 args.layer2 args.head2 args.coord - (queryPos? := args.queryPos?) (inputPath? := args.inputPath?) - (inputDelta := args.delta) (soundnessBits := args.soundnessBits) - (offset1 := args.offset1) (offset2 := args.offset2) - (keyOffset1 := args.keyOffset1) (keyOffset2 := args.keyOffset2) - (maxSeqLen := args.maxSeqLen) - (scalePow10 := args.scalePow10) - (tightPattern := args.tightPattern) - (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (useAffine := args.useAffine) - (iterTighten := args.iterTighten) - (targetToken? := args.targetToken?) (negativeToken? := args.negativeToken?) - (softmaxExpEffort := args.softmaxExpEffort) - (causalPattern := args.causalPattern) - return formatInductionBestMatch cert - else - let cert ← ExceptT.mk <| - Nfp.Sound.certifyInductionSound args.modelPath - args.layer1 args.head1 args.layer2 args.head2 args.coord - (inputPath? := args.inputPath?) (inputDelta := args.delta) - (soundnessBits := args.soundnessBits) - (offset1 := args.offset1) (offset2 := args.offset2) - (keyOffset1 := args.keyOffset1) (keyOffset2 := args.keyOffset2) - (maxSeqLen := args.maxSeqLen) - (scalePow10 := args.scalePow10) - (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (targetToken? := args.targetToken?) (negativeToken? := args.negativeToken?) - (softmaxExpEffort := args.softmaxExpEffort) - (causalPattern := args.causalPattern) - return formatInductionLocal cert - -/-- Run the induction-cert command - compute a sound induction head certificate. 
-/ -def runInductionCert (p : Parsed) : IO UInt32 := do - let action : ExceptT String IO (String × Option System.FilePath) := do - let args ← parseInductionCertArgs p - let report ← runInductionCertAction args - return (report, args.outputPath?) - match ← action.run with - | .error msg => - IO.eprintln s!"Error: {msg}" - return 1 - | .ok (s, outputPath?) => - match outputPath? with - | some path => - IO.FS.writeFile path s - IO.println s!"Report written to {path}" - | none => - IO.println s - return 0 - -/-! ## Sound cache check helpers -/ - -private structure SoundCacheCheckArgs where - modelPath : System.FilePath - scalePow10 : Nat - maxTokens : Nat - -private def parseSoundCacheCheckArgs (p : Parsed) : SoundCacheCheckArgs := - let modelPath := p.positionalArg! "model" |>.as! String - let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 - let maxTokens := p.flag? "maxTokens" |>.map (·.as! Nat) |>.getD 0 - { modelPath := ⟨modelPath⟩, scalePow10 := scalePow10, maxTokens := maxTokens } - -private def runSoundCacheCheckWithArgs (args : SoundCacheCheckArgs) : IO UInt32 := do - let modelHash ← Nfp.Untrusted.SoundCacheIO.fnv1a64File args.modelPath - let cacheFp := Nfp.Sound.SoundCache.cachePath args.modelPath modelHash args.scalePow10 - match (← Nfp.Untrusted.SoundCacheIO.buildCacheFile args.modelPath cacheFp args.scalePow10) with - | .error e => - IO.eprintln s!"Error: cache build failed: {e}" - return 1 - | .ok _ => - let ch ← IO.FS.Handle.mk cacheFp IO.FS.Mode.read - let hdr ← Nfp.Untrusted.SoundCacheIO.readHeader ch - if hdr.modelHash ≠ modelHash then - IO.eprintln "Error: cache hash mismatch" - return 1 - match (← Nfp.Untrusted.SoundCacheIO.checkCacheFileSize cacheFp hdr) with - | .error e => - IO.eprintln s!"Error: {e}" - return 1 - | .ok _ => - match - (← Nfp.Untrusted.SoundCacheIO.checkTextTokenEnvelope - args.modelPath args.scalePow10 args.maxTokens) with - | .error e => - IO.eprintln s!"Error: {e}" - return 1 - | .ok _ => - IO.println <| - "OK: 
sound cache validated " ++ - s!"(scalePow10={args.scalePow10}, maxTokens={args.maxTokens})" - return 0 - -/-- Regression test for SOUND fixed-point cache soundness and consistency. - -This is intended for CI and small fixtures. It: -- builds a `.nfpc` cache (if needed), -- checks cache size matches the expected tensor stream length, -- checks the `±1`-ulp rounding envelope on up to `maxTokens` numeric tokens in the text file. --/ -def runSoundCacheCheck (p : Parsed) : IO UInt32 := do - let args := parseSoundCacheCheckArgs p - runSoundCacheCheckWithArgs args - -/-! ## Sound cache benchmark helpers -/ - -private structure SoundCacheBenchArgs where - modelPath : System.FilePath - scalePow10 : Nat - runs : Nat - -private def parseSoundCacheBenchArgs (p : Parsed) : SoundCacheBenchArgs := - let modelPath := p.positionalArg! "model" |>.as! String - let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 - let runs := p.flag? "runs" |>.map (·.as! Nat) |>.getD 1 - { modelPath := ⟨modelPath⟩, scalePow10 := scalePow10, runs := runs } - -private def runSoundCacheBenchWithArgs (args : SoundCacheBenchArgs) : IO UInt32 := do - if args.runs = 0 then - IO.eprintln "Error: --runs must be > 0" - return 1 - let modelHash ← Nfp.Untrusted.SoundCacheIO.fnv1a64File args.modelPath - let mdata ← args.modelPath.metadata - let modelSize : UInt64 := mdata.byteSize - let isBinaryE ← Nfp.Untrusted.SoundCacheIO.isBinaryModelFile args.modelPath - let isBinary ← - match isBinaryE with - | .error e => - IO.eprintln s!"Error: {e}" - return 1 - | .ok b => pure b - let formatStr := if isBinary then "binary" else "text" - let mut times : Array Nat := Array.mkEmpty args.runs - let mut lastBytes : Nat := 0 - for i in [:args.runs] do - let t0 ← IO.monoNanosNow - let bytesE ← - if isBinary then - Nfp.Untrusted.SoundCacheIO.buildCacheBytesBinary - args.modelPath args.scalePow10 modelHash modelSize - else - Nfp.Untrusted.SoundCacheIO.buildCacheBytesText - args.modelPath args.scalePow10 modelHash 
modelSize - let t1 ← IO.monoNanosNow - match bytesE with - | .error e => - IO.eprintln s!"Error: {e}" - return 1 - | .ok bytes => - let dtMs := (t1 - t0) / 1000000 - times := times.push dtMs - lastBytes := bytes.size - if args.runs > 1 then - IO.println s!"run {i + 1}: {dtMs}ms" - let t0 := times[0]! - let mut minT := t0 - let mut maxT := t0 - let mut sumT : Nat := 0 - for t in times do - if t < minT then - minT := t - if t > maxT then - maxT := t - sumT := sumT + t - let avgT := sumT / times.size - IO.println s!"cacheBuild format={formatStr} scalePow10={args.scalePow10} bytes={lastBytes}" - IO.println s!"cacheBuild runs={args.runs} min={minT}ms avg={avgT}ms max={maxT}ms" - return 0 - -def runSoundCacheBench (p : Parsed) : IO UInt32 := do - let args := parseSoundCacheBenchArgs p - runSoundCacheBenchWithArgs args - -/-- Run the rope command - print a proof-backed RoPE operator norm certificate. -/ -def runRoPE (p : Parsed) : IO UInt32 := do - let seqLen := p.flag? "seqLen" |>.map (·.as! Nat) |>.getD 4 - let pairs := p.flag? "pairs" |>.map (·.as! Nat) |>.getD 8 - match seqLen, pairs with - | Nat.succ n, Nat.succ m => - -- Instantiate the theorem at concrete sizes to ensure the report is proof-backed. - let θ : Fin (Nat.succ n) → Fin (Nat.succ m) → ℝ := fun _ _ => 0 - have _ : - Nfp.operatorNormBound - (n := Fin (Nat.succ n)) (d := Nfp.RoPEDim (Fin (Nat.succ m))) - (Nfp.ropeJacobian (pos := Fin (Nat.succ n)) (pair := Fin (Nat.succ m)) θ) - ≤ (2 : ℝ) := by - simpa using - (Nfp.rope_operatorNormBound_le_two - (pos := Fin (Nat.succ n)) (pair := Fin (Nat.succ m)) θ) - IO.println "RoPE certificate (static):" - IO.println s!" seqLen={seqLen}, pairs={pairs}, dim={2 * pairs}" - IO.println " Bound: operatorNormBound(ropeJacobian θ) ≤ 2" - IO.println " Meaning: max row-sum of absolute weights (ℓ1 induced for row-vectors)." - IO.println " Proof: Nfp.rope_operatorNormBound_le_two (uses |sin|,|cos| ≤ 1 from mathlib)." 
- return 0 - | _, _ => - IO.eprintln "Error: --seqLen and --pairs must be positive." - return 1 - -/-! ## Dump command helpers -/ - -private structure DumpArgs where - modelPath : System.FilePath - modelPathStr : String - layer : Nat - pos : Nat - take : Nat - kind : String - -private def parseDumpArgs (p : Parsed) : DumpArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let layer := p.flag? "layer" |>.map (·.as! Nat) |>.getD 0 - let pos := p.flag? "pos" |>.map (·.as! Nat) |>.getD 0 - let take := p.flag? "take" |>.map (·.as! Nat) |>.getD 16 - let kind := p.flag? "kind" |>.map (·.as! String) |>.getD "afterLayer" - { modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - layer := layer - pos := pos - take := take - kind := kind } - -private def selectDumpMatrix - (kind : String) (layer : Nat) (model : ConcreteModel) (fwd : ForwardPassResult) : - ConcreteMatrix := - match kind with - | "embeddings" => model.inputEmbeddings - | "layerInput" => fwd.getLayerInput layer - | "postAttn" => fwd.getPostAttnResidual layer - | "afterLayer" => fwd.getLayerInput (layer + 1) - | _ => fwd.getLayerInput (layer + 1) - -private def collectDumpRow (X : ConcreteMatrix) (pos n : Nat) : - (Array Float × Float × Float) := Id.run do - let mut xs : Array Float := Array.mkEmpty n - let mut sum : Float := 0.0 - let mut sumSq : Float := 0.0 - for j in [:n] do - let v := X.get pos j - xs := xs.push v - sum := sum + v - sumSq := sumSq + v * v - return (xs, sum, sumSq) - -private def runDumpWithArgs (args : DumpArgs) : IO UInt32 := do - setStdoutLogNameFromModelPath args.modelPathStr - let loadResult ← loadModel args.modelPath - match loadResult with - | .error msg => - IO.eprintln s!"Error loading model: {msg}" - return 1 - | .ok model0 => - let model := model0.trimTrailingZeroEmbeddings - let fwd := model.runForward true - let X := selectDumpMatrix args.kind args.layer model fwd - if X.numRows = 0 || X.numCols = 0 then - IO.eprintln s!"Error: empty matrix for 
kind={args.kind}" - return 1 - if args.pos ≥ X.numRows then - IO.eprintln s!"Error: pos={args.pos} out of bounds (rows={X.numRows})" - return 1 - let n := min args.take X.numCols - let (xs, sum, sumSq) := collectDumpRow X args.pos n - IO.println <| - s!"DUMP kind={args.kind} layer={args.layer} pos={args.pos} take={n} " ++ - s!"rows={X.numRows} cols={X.numCols}" - IO.println s!"sum={sum} sumSq={sumSq}" - IO.println (String.intercalate " " (xs.toList.map (fun x => s!"{x}"))) - return 0 - -/-- Dump a small slice of a forward pass for cross-implementation sanity checks. -/ -def runDump (p : Parsed) : IO UInt32 := do - let args := parseDumpArgs p - runDumpWithArgs args - -/-! ## Logit-difference helpers -/ - -private def logitAt (residual : ConcreteMatrix) (pos : Nat) - (W_U : ConcreteMatrix) (token : Nat) : Except String Float := - if residual.numCols ≠ W_U.numRows then - .error "dimension mismatch: residual.numCols != W_U.numRows" - else if pos ≥ residual.numRows then - .error "position out of range" - else if token ≥ W_U.numCols then - .error "token out of range" - else - .ok <| Id.run do - let d := residual.numCols - let vocab := W_U.numCols - let rowBase := pos * d - let mut acc : Float := 0.0 - for k in [:d] do - acc := acc + residual.data[rowBase + k]! * W_U.data[k * vocab + token]! 
- return acc - -private def topNonTargetToken (residual : ConcreteMatrix) (pos : Nat) - (W_U : ConcreteMatrix) (targetToken : Nat) : Except String (Nat × Float) := - if residual.numCols ≠ W_U.numRows then - .error "dimension mismatch: residual.numCols != W_U.numRows" - else if pos ≥ residual.numRows then - .error "position out of range" - else if targetToken ≥ W_U.numCols then - .error "target token out of range" - else if W_U.numCols < 2 then - .error "vocab size too small to select non-target token" - else - .ok <| Id.run do - let d := residual.numCols - let vocab := W_U.numCols - let rowBase := pos * d - let mut bestTok : Nat := 0 - let mut bestLogit : Float := (-Float.inf) - let mut found : Bool := false - for tok in [:vocab] do - if tok ≠ targetToken then - found := true - let mut acc : Float := 0.0 - for k in [:d] do - acc := acc + residual.data[rowBase + k]! * W_U.data[k * vocab + tok]! - if acc > bestLogit then - bestTok := tok - bestLogit := acc - if found then - return (bestTok, bestLogit) - else - return (0, bestLogit) - -private structure LogitDiffArgs where - modelPath : System.FilePath - modelPathStr : String - target : Nat - negative : Nat - pos? : Option Nat - inputPath? : Option System.FilePath - autoNegative : Bool - -private def parseLogitDiffArgs (p : Parsed) : LogitDiffArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let target := p.positionalArg! "target" |>.as! Nat - let negative := p.positionalArg! "negative" |>.as! Nat - let pos? := p.flag? "pos" |>.map (·.as! Nat) - let inputPath? := p.flag? "input" |>.map (System.FilePath.mk ∘ (·.as! String)) - let autoNegative := p.hasFlag "autoNegative" - { modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - target := target - negative := negative - pos? := pos? - inputPath? := inputPath? 
- autoNegative := autoNegative } - -private def runLogitDiff (p : Parsed) : IO UInt32 := do - let args := parseLogitDiffArgs p - setStdoutLogNameFromModelPath args.modelPathStr - let loadResult ← loadModel args.modelPath - match loadResult with - | .error msg => - IO.eprintln s!"Error loading model: {msg}" - return 1 - | .ok model0 => - let model ← - match args.inputPath? with - | none => pure model0 - | some inputPath => - match ← loadInputBinary inputPath with - | .error msg => - IO.eprintln s!"Error loading input: {msg}" - return 1 - | .ok input => - if input.modelDim ≠ model0.modelDim then - IO.eprintln s!"Input model_dim mismatch ({input.modelDim} != {model0.modelDim})" - return 1 - pure { - model0 with - seqLen := input.seqLen - inputTokens := some input.tokens - inputEmbeddings := input.embeddings - } - match model.unembedding with - | none => - IO.eprintln "Error: Model is missing unembedding matrix (needed for logits)." - return 1 - | some W_U => - if model.seqLen = 0 then - IO.eprintln "Error: seq_len = 0; cannot compute logits." 
- return 1 - let pos := args.pos?.getD (model.seqLen - 1) - if pos ≥ model.seqLen then - IO.eprintln s!"Error: pos={pos} out of bounds (seq_len={model.seqLen})" - return 1 - let fwd := model.runForward true - let residual := fwd.finalOutput - let negResult := - if args.autoNegative then - topNonTargetToken residual pos W_U args.target - else - match logitAt residual pos W_U args.negative with - | .ok logit => .ok (args.negative, logit) - | .error msg => .error msg - match logitAt residual pos W_U args.target, negResult with - | .ok targetLogit, .ok (negTok, negLogit) => - let diff := targetLogit - negLogit - IO.println s!"pos={pos} target={args.target} negative={negTok}" - if args.autoNegative then - IO.println "negativeSource=topNonTarget" - IO.println s!"logit(target)={targetLogit}" - IO.println s!"logit(negative)={negLogit}" - IO.println s!"logitDiff={diff}" - return 0 - | .error msg, _ => - IO.eprintln s!"Error computing target logit: {msg}" - return 1 - | _, .error msg => - IO.eprintln s!"Error computing negative logit: {msg}" - return 1 - -/-- The analyze subcommand. -/ -def analyzeCmd : Cmd := `[Cli| - analyze VIA runAnalyze; - "Analyze a neural network model for circuit discovery and verification." - FLAGS: - t, threshold : String; "Error threshold for verification (default: 0.1)" - o, output : String; "Write report to file instead of stdout" - verify; "Run empirical verification (requires input in model)" - v, verbose; "Enable verbose output" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The induction subcommand. -/ -def inductionCmd : Cmd := `[Cli| - induction VIA runInduction; - "Discover induction head pairs ranked by mechScore (kComp·indScore·prevTok)." 
- FLAGS: - c, correct : Nat; "Correct token ID (manual override; requires --incorrect)" - i, incorrect : Nat; "Incorrect token ID (manual override; requires --correct)" - t, threshold : String; "Minimum normalized Effect threshold (default: 0.0)" - verify; "Run causal verification via head ablation on the top-10 candidates" - v, verbose; "Enable verbose output" - d, diagnostics; "Print diagnostic breakdown of ε bounds (pattern/value decomposition)" - diagTop : Nat; "How many top candidates get diagnostics (default: 5)" - adaptive; "Enable adaptive bound scheduler (rigorous; deterministic)" - targetSlack : String; "Stop when ub/lb ≤ targetSlack (default: 8.0)" - maxUpgrades : Nat; "Maximum adaptive upgrades (default: 120)" - minRelImprove : String; "Stop upgrading a layer if improvement < this fraction (default: 0.01)" - krylovSteps : Nat; "Krylov steps for LOWER bounds only (default: 2)" - adaptiveScope : String; "Adaptive scope: layernorm | all (default: layernorm)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The bench subcommand. -/ -def benchCmd : Cmd := `[Cli| - bench VIA runBench; - "Run repeatable microbenchmarks on analysis or induction search." - FLAGS: - mode : String; "analysis|induction (default: analysis)" - runs : Nat; "Number of timed runs (default: 5)" - repeats : Nat; "Repeat inner workload per run (default: 1)" - t, threshold : String; "Analyze threshold (default: 0.1)" - minEffect : String; "Induction minEffect (default: 0.0)" - c, correct : Nat; "Correct token ID (requires --incorrect)" - i, incorrect : Nat; "Incorrect token ID (requires --correct)" - v, verbose; "Print per-run timing details" - breakdown; "Emit per-phase averages (analysis only)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The certify subcommand. -/ -def certifyCmd : Cmd := `[Cli| - certify VIA runCertify; - "SOUND mode: compute conservative bounds using exact Rat arithmetic (no Float trust). 
\ -LayerNorm epsilon and GeLU kind are read from the model header." - FLAGS: - input : String; "Optional input .nfpt file for local certification (must contain EMBEDDINGS \ -for legacy text)" - delta : String; "Input ℓ∞ radius δ for local certification (default: 0; \ -if --input is omitted, uses EMBEDDINGS in the model file when present)" - softmaxMargin : String; "Lower bound on softmax logit margin (default: 0)" - softmaxExpEffort : Nat; "Effort level for margin-based exp lower bounds (default: 1)" - bestMatchMargins; "Apply best-match margin tightening (binary + local only)" - targetOffset : Int; "Token-match offset for best-match margins (default: -1)" - maxSeqLen : Nat; "Max sequence length for best-match margins (default: 0 uses full seq_len)" - tightPattern; "Use tighter (slower) pattern bounds for best-match margins" - tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" - perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" - noncausalPattern; "Disable causal-prefix restriction for pattern/value bounds" - scalePow10 : Nat; "Fixed-point scale exponent for best-match margins (default: 9)" - soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" - partitionDepth : Nat; "Partition depth for input splitting (default: 0; >0 scaffold only)" - o, output : String; "Write report to file instead of stdout" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The head-bounds subcommand. -/ -def headBoundsCmd : Cmd := `[Cli| - head_bounds VIA runHeadBounds; - "SOUND mode: compute per-head contribution bounds. \ -LayerNorm epsilon is read from the model header." 
- FLAGS: - input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ -for legacy text)" - delta : String; "Input ℓ∞ radius δ for local bounds (default: 0; if --input is omitted, \ -uses EMBEDDINGS in the model file when present)" - soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" - scalePow10 : Nat; "Fixed-point scale exponent p in S=10^p for global bounds (default: 9)" - o, output : String; "Write report to file instead of stdout" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The head-pattern subcommand. -/ -def headPatternCmd : Cmd := `[Cli| - head_pattern VIA runHeadPattern; - "SOUND mode: compute per-head attention pattern bounds (binary only). \ -LayerNorm epsilon is read from the model header." - FLAGS: - layer : Nat; "Layer index (default: 0)" - head : Nat; "Head index (default: 0)" - offset : Int; "Token-match offset (default: -1 for previous token, 0 for self)" - keyOffset : Int; "Key-position token offset (default: 0; use -1 with offset=0 for copy-next)" - tightPattern; "Use tighter (slower) pattern bounds near the target layer" - tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" - perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" - noncausalPattern; "Disable causal-prefix restriction for pattern/value bounds" - bestMatch; "Use best-match (single-query) pattern bounds" - affine; "Use affine Q/K dot bounds for best-match (single-query only)" - sweep; "Sweep best-match bounds across all valid query positions" - queryPos : Nat; "Query position for best-match bounds (default: last position)" - input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ -for legacy text)" - delta : String; "Input ℓ∞ radius δ for local bounds (default: 0)" - soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" - softmaxExpEffort : Nat; "Effort level for 
margin-based exp lower bounds (default: 1)" - maxSeqLen : Nat; "Maximum sequence length to analyze (default: 256)" - scalePow10 : Nat; "Fixed-point scale exponent for best-match bounds (default: 9)" - o, output : String; "Write report to file instead of stdout" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The induction-cert subcommand. -/ -def inductionCertCmd : Cmd := `[Cli| - induction_cert VIA runInductionCert; - "SOUND mode: compute a minimal induction head certificate (binary only). \ -LayerNorm epsilon is read from the model header." - FLAGS: - layer1 : Nat; "Layer index for the previous-token head (default: 0)" - head1 : Nat; "Head index for the previous-token head (default: 0)" - layer2 : Nat; "Layer index for the token-match head (default: 1)" - head2 : Nat; "Head index for the token-match head (default: 0)" - coord : Nat; "Output coordinate for the value bound (default: 0)" - offset1 : Int; "Token-match offset for layer1 (default: -1)" - offset2 : Int; "Token-match offset for layer2 (default: -1)" - keyOffset1 : Int; "Key-position token offset for layer1 (default: 0)" - keyOffset2 : Int; "Key-position token offset for layer2 (default: 0; use -1 with \ -offset2=0 for copy-next)" - target : Nat; "Target token ID for logit-diff bound (optional; requires --negative)" - negative : Nat; "Negative token ID for logit-diff bound (optional; requires --target)" - tightPattern; "Use tighter (slower) pattern bounds near the target layer" - tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" - perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" - iterTighten; "Iteratively tighten best-match bounds (escalates tight/per-row layers to full)" - noncausalPattern; "Disable causal-prefix restriction for pattern/value bounds" - bestMatch; "Use best-match (single-query) pattern bounds" - affine; "Use affine Q/K dot bounds for best-match" - queryPos : Nat; "Query position for 
best-match bounds (default: last position)" - input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ -for legacy text)" - delta : String; "Input ℓ∞ radius δ for local bounds (default: 0)" - soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" - softmaxExpEffort : Nat; "Effort level for margin-based exp lower bounds (default: 1)" - maxSeqLen : Nat; "Maximum sequence length to analyze (default: 256)" - scalePow10 : Nat; "Fixed-point scale exponent for best-match bounds (default: 9)" - o, output : String; "Write report to file instead of stdout" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The sound-cache-check subcommand (CI regression test). -/ -def soundCacheCheckCmd : Cmd := `[Cli| - sound_cache_check VIA runSoundCacheCheck; - "Check SOUND fixed-point cache soundness (CI / small fixtures)." - FLAGS: - scalePow10 : Nat; "Fixed-point scale exponent p in S=10^p (default: 9)" - maxTokens : Nat; "Check at most this many numeric tokens (0=all; default: 0)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The sound-cache-bench subcommand. -/ -def soundCacheBenchCmd : Cmd := `[Cli| - sound_cache_bench VIA runSoundCacheBench; - "Benchmark SOUND fixed-point cache build (text or binary)." - FLAGS: - scalePow10 : Nat; "Fixed-point scale exponent p in S=10^p (default: 9)" - runs : Nat; "Number of benchmark runs (default: 1)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The rope subcommand. -/ -def ropeCmd : Cmd := `[Cli| - rope VIA runRoPE; - "Static certificate for RoPE (rotary position embedding) linearization bounds." - FLAGS: - seqLen : Nat; "Sequence length (>0) used for instantiation (default: 4)" - pairs : Nat; "Number of RoPE pairs (>0); dimension is 2*pairs (default: 8)" -] - -/-- The dump subcommand. 
-/ -def dumpCmd : Cmd := `[Cli| - dump VIA runDump; - "Dump a small forward-pass slice (for PyTorch sanity checking)." - FLAGS: - layer : Nat; "Layer index (default: 0)" - pos : Nat; "Token position / row index (default: 0)" - take : Nat; "How many columns to dump from the start (default: 16)" - kind : String; "embeddings | layerInput | postAttn | afterLayer (default: afterLayer)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The logit-diff subcommand. -/ -def logitDiffCmd : Cmd := `[Cli| - logit_diff VIA runLogitDiff; - "Compute empirical logit-difference for target vs. negative token." - FLAGS: - pos : Nat; "Token position (default: last position)" - input : String; "Optional input .nfpt file with TOKENS + EMBEDDINGS" - autoNegative; "Use top non-target logit as negative token (ignores provided negative)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" - target : Nat; "Target token ID" - negative : Nat; "Negative token ID" -] - -/-- The main CLI command. -/ -def nfpCmd : Cmd := `[Cli| - nfp NOOP; - "NFP: Neural Formal Pathways verification toolkit" - SUBCOMMANDS: - analyzeCmd; - inductionCmd; - benchCmd; - certifyCmd; - headBoundsCmd; - headPatternCmd; - inductionCertCmd; - soundCacheCheckCmd; - soundCacheBenchCmd; - ropeCmd; - dumpCmd; - logitDiffCmd -] - -/-- Main entry point. -/ -def main (args : List String) : IO UInt32 := do - let ctx ← openPendingStdoutLog - stdoutLogCtxRef.set (some ctx) - let out ← IO.getStdout - let log := IO.FS.Stream.ofHandle ctx.handle - let tee := mkTeeStream out log - IO.withStdout tee <| do - try - if args.contains "--version" then - setStdoutLogName "version" - IO.println "nfp version 0.1.0" - return (0 : UInt32) - nfpCmd.validate args - finally - let pending ← ctx.pendingRef.get - if pending then - setStdoutLogName "no_model" - stdoutLogCtxRef.set none +/-- CLI entry point. 
-/ +def main (args : List String) : IO UInt32 := + Nfp.main args diff --git a/Nfp.lean b/Nfp.lean index 1967251..d04adfa 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -1,287 +1,25 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Nfp.Core import Nfp.Prob -import Nfp.PCC -import Nfp.Uniqueness -import Nfp.Appendix import Nfp.Mixer -import Nfp.Reroute.Heat -import Nfp.Reroute.Partition -import Nfp.Influence -import Nfp.MixerLocalSystem -import Nfp.Attribution -import Nfp.Layers -import Nfp.SignedMixer -import Nfp.Linearization -import Nfp.Abstraction -import Nfp.Induction -import Nfp.Discovery -import Nfp.Verification -import Nfp.IO -import Nfp.Sound.IO -import Nfp.Sound.Bridge -import Nfp.Sound.Demo +import Nfp.System /-! -Axioms used by key theorems/definitions +Top-level reexports and trust dashboard for the NFP rewrite. +-/ + +/-! +Axioms used by key definitions/lemmas. These `#print axioms` lines help ensure we only depend on a small set of axioms (ideally a subset of: `propext`, `Classical.choice`, `Quot.sound`). 
-/ --- From Nfp.Prob #print axioms Nfp.ProbVec.sum_mass #print axioms Nfp.ProbVec.pure #print axioms Nfp.ProbVec.mix - --- From Nfp.PCC -#print axioms Nfp.tracerOfContrib -#print axioms Nfp.tracerOfContrib_mass -#print axioms Nfp.sum_monotone_chain -#print axioms Nfp.monotone_removed_mass --- From Nfp.Uniqueness -#print axioms Nfp.LocalSystem.tracer_unique --- From Nfp.Appendix (PCC-focused lemmas) -#print axioms Nfp.drop_eq_mass -#print axioms Nfp.pcc_upper_bound_for_any_mask -#print axioms Nfp.normalized_sum_monotone --- Residuals -#print axioms Nfp.lambdaEC -#print axioms Nfp.lambdaEC_sum_one -#print axioms Nfp.residual_lambda_from_norm --- Greedy optimality -#print axioms Nfp.greedy_topk_optimal --- Budgeted PCC sup -#print axioms Nfp.normMass -#print axioms Nfp.normMass_union -#print axioms Nfp.normMass_partitionUnion -#print axioms Nfp.normMass_partition_eq_one -#print axioms Nfp.ReroutePlan.masks -#print axioms Nfp.ReroutePlan.normMass_sum_one -#print axioms Nfp.unionParts -#print axioms Nfp.disjoint_unionParts -#print axioms Nfp.feasible -#print axioms Nfp.pccArg -#print axioms Nfp.pccMax -#print axioms Nfp.pccMax_le_tau -#print axioms Nfp.pccMax_monotone -#print axioms Nfp.pccMax_dominates -#print axioms Nfp.greedy_topmass_optimal - --- PCC(t) alias and properties -#print axioms Nfp.PCC_monotone -#print axioms Nfp.PCC_upper_bounds_masks - --- Normalization utilities and uniqueness -#print axioms Nfp.normalizeOn_outside -#print axioms Nfp.normalizeOn_inside -#print axioms Nfp.normalizeOn_sum_one -#print axioms Nfp.proportional_row_unique --- Residual characterization (Appendix A.3) -#print axioms Nfp.lambdaEC_scale_invariant_global - --- Appendix A.1 consolidated wrappers and packaged theorem -#print axioms Nfp.A1_residual_unique -#print axioms Nfp.A1_normalize_unique -#print axioms Nfp.A1_global_tracer_unique -#print axioms Nfp.A1 - --- Forward mixers (Appendix A.1) #print axioms Nfp.Mixer.push -#print axioms Nfp.Mixer.row -#print axioms 
Nfp.Mixer.push_pure -#print axioms Nfp.Mixer.push_mix #print axioms Nfp.Mixer.comp - --- Appendix A.2 wrappers and packaged theorem -#print axioms Nfp.A2_graph_faithful_comp -#print axioms Nfp.A2_residual_energy_consistent -#print axioms Nfp.A2_residual_unique -#print axioms Nfp.A2_normalize_row_sum_one -#print axioms Nfp.A2 - --- From Nfp.Mixer (support restriction) -#print axioms Nfp.Mixer.supported_comp -#print axioms Nfp.Mixer.supp_push_subset_image - --- From Nfp.Influence -#print axioms Nfp.Mixer.ofInfluenceSpec -#print axioms Nfp.Mixer.ofInfluenceSpec_supported - --- From Nfp.MixerLocalSystem -#print axioms Nfp.LocalSystem.ofMixerIdx -#print axioms Nfp.LocalSystem.ofMixer - --- From Nfp.Reroute.Partition -#print axioms Nfp.ReroutePlan.increments -#print axioms Nfp.ReroutePlan.increments_pairwise -#print axioms Nfp.sum_unionParts_eq - --- From Nfp.Reroute.Heat -#print axioms Nfp.WeightedReroutePlan.rerouteHeat -#print axioms Nfp.WeightedReroutePlan.heatRaw_sum_increment -#print axioms Nfp.WeightedReroutePlan.rerouteHeat_sum_increment - --- From Nfp.Attribution (attribution axioms for NN interpretation) -#print axioms Nfp.Attribution.Complete -#print axioms Nfp.Attribution.CompleteScalar -#print axioms Nfp.Attribution.Sensitive -#print axioms Nfp.Attribution.Dummy -#print axioms Nfp.Attribution.Symmetric -#print axioms Nfp.attributionOfProbVec -#print axioms Nfp.tracer_attribution_complete - --- From Nfp.Layers (NN layer mixers) -#print axioms Nfp.Mixer.identity -#print axioms Nfp.Mixer.identity_comp -#print axioms Nfp.Mixer.comp_identity -#print axioms Nfp.Mixer.comp_assoc -#print axioms Nfp.Mixer.attention -#print axioms Nfp.Mixer.selfAttention -#print axioms Nfp.Mixer.residual -#print axioms Nfp.Mixer.residual_one -#print axioms Nfp.Mixer.residual_zero -#print axioms Nfp.Mixer.comp3 -#print axioms Nfp.Mixer.comp3_path_decomposition -#print axioms Nfp.Mixer.transformerBlock -#print axioms Nfp.attentionFlow -#print axioms Nfp.attentionFlow_singleton -#print 
axioms Nfp.attentionFlow_nil -#print axioms Nfp.effectiveAttention_normalized -#print axioms Nfp.pathContrib_sum -#print axioms Nfp.Mixer.push_preserves_total_mass -#print axioms Nfp.Mixer.push_comp -#print axioms Nfp.transformer_attribution_unique - --- Residual stream analysis -#print axioms Nfp.Mixer.residual_decomposition -#print axioms Nfp.Mixer.residual_skip_dominance -#print axioms Nfp.Mixer.residual_off_diag_bound -#print axioms Nfp.Mixer.residual_off_diag_scaling - --- Attention concentration -#print axioms Nfp.Mixer.maxWeight -#print axioms Nfp.Mixer.weight_le_maxWeight -#print axioms Nfp.Mixer.push_concentration_bound - --- Ablation analysis -#print axioms Nfp.Mixer.maskFn -#print axioms Nfp.Mixer.maskFn_blocked -#print axioms Nfp.Mixer.maskFn_unblocked -#print axioms Nfp.blockedContribution -#print axioms Nfp.unblockedContribution -#print axioms Nfp.Mixer.ablation_decomposition - --- Composition depth and reachability -#print axioms Nfp.Mixer.reachable -#print axioms Nfp.Mixer.reach_comp -#print axioms Nfp.Mixer.path_contrib_le_comp -#print axioms Nfp.Mixer.comp_reachable_of_path - --- Information bounds -#print axioms Nfp.Mixer.supportSize -#print axioms Nfp.Mixer.exists_nonzero -#print axioms Nfp.Mixer.supportSize_pos - --- Gradient correspondence (chain rule) -#print axioms Nfp.Mixer.chain_rule_analog -#print axioms Nfp.Mixer.chain_rule_three - --- Multi-head attention -#print axioms Nfp.Mixer.multiHead -#print axioms Nfp.Mixer.multiHead_head_contrib_bound -#print axioms Nfp.Mixer.multiHead_zero_weight -#print axioms Nfp.Mixer.multiHead_single_head -#print axioms Nfp.Mixer.multiHead_convex - --- Causal (autoregressive) attention -#print axioms Nfp.isCausal -#print axioms Nfp.causal_reachable_dir -#print axioms Nfp.causal_comp -#print axioms Nfp.causal_future_attention_zero -#print axioms Nfp.causal_past_attention_one -#print axioms Nfp.causal_first_token_self - --- Attention head analysis -#print axioms Nfp.Mixer.attentionConcentration -#print axioms 
Nfp.Mixer.attentionConcentration_upper_bound -#print axioms Nfp.Mixer.isSparseAt -#print axioms Nfp.Mixer.isDiffuseAt -#print axioms Nfp.Mixer.attentionConcentration_one_hot - --- Residual dominance analysis -#print axioms Nfp.residual_diagonal_lower -#print axioms Nfp.residual_offdiag_scale -#print axioms Nfp.residual_offdiag_sum_bound - --- Deep composition analysis -#print axioms Nfp.comp_weight_le_one -#print axioms Nfp.comp_term_bound - --- Cross-attention for encoder-decoder models -#print axioms Nfp.Mixer.crossAttention -#print axioms Nfp.Mixer.crossAttention_normalized - --- Layer-wise attribution analysis -#print axioms Nfp.layerWiseAttribution -#print axioms Nfp.layerWiseAttribution_nil -#print axioms Nfp.layerWiseAttribution_singleton -#print axioms Nfp.layerWiseAttribution_le_one -#print axioms Nfp.layerWiseAttribution_sum_one - --- From Nfp.SignedMixer (generalized signed weight mixers) -#print axioms Nfp.SignedMixer -#print axioms Nfp.SignedMixer.comp -#print axioms Nfp.SignedMixer.comp_assoc -#print axioms Nfp.SignedMixer.identity -#print axioms Nfp.SignedMixer.positivePart -#print axioms Nfp.SignedMixer.negativePart -#print axioms Nfp.SignedMixer.decompose -#print axioms Nfp.SignedMixer.rowSum -#print axioms Nfp.SignedMixer.IsRowStochastic -#print axioms Nfp.SignedMixer.toMixer -#print axioms Nfp.SignedMixer.ofMixer -#print axioms Nfp.SignedMixer.ofMixer_isRowStochastic -#print axioms Nfp.SignedMixer.influence -#print axioms Nfp.SignedMixer.totalInfluenceFrom -#print axioms Nfp.SignedMixer.totalInfluenceOn -#print axioms Nfp.SignedMixer.apply -#print axioms Nfp.SignedMixer.apply_comp -#print axioms Nfp.SignedMixer.jacobianEntry_eq_weight - --- AffineMixer (linear + bias) -#print axioms Nfp.AffineMixer -#print axioms Nfp.AffineMixer.linear -#print axioms Nfp.AffineMixer.bias -#print axioms Nfp.AffineMixer.apply -#print axioms Nfp.AffineMixer.comp - --- From Nfp.Linearization (linearizing non-linear ops) -#print axioms Nfp.Linearization -#print axioms 
Nfp.Linearization.comp -#print axioms Nfp.relu -#print axioms Nfp.reluGrad -#print axioms Nfp.reluLinearization -#print axioms Nfp.reluLinearization_diagonal -#print axioms Nfp.reluLinearization_diag_binary -#print axioms Nfp.gelu -#print axioms Nfp.geluGrad -#print axioms Nfp.geluLinearization -#print axioms Nfp.layerNorm -#print axioms Nfp.layerNormJacobian -#print axioms Nfp.layerNorm_translation_invariant -#print axioms Nfp.softmax -#print axioms Nfp.softmaxJacobian -#print axioms Nfp.softmax_nonneg -#print axioms Nfp.softmax_sum_one -#print axioms Nfp.softmaxJacobian_diag_pos -#print axioms Nfp.softmaxJacobian_off_diag_neg -#print axioms Nfp.softmax_translation_invariant -#print axioms Nfp.ropeJacobian -#print axioms Nfp.rope -#print axioms Nfp.ropeJacobian_cross_pos -#print axioms Nfp.rope_operatorNormBound_le_two -#print axioms Nfp.gradientTimesInput -#print axioms Nfp.gradientTimesInput_complete -#print axioms Nfp.composed_attribution -#print axioms Nfp.integratedGradientsLinear -#print axioms Nfp.integratedGradients_linear_complete +#print axioms Nfp.Mixer.id +#print axioms Nfp.Dag.parents +#print axioms Nfp.LocalSystem.toMixer diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean new file mode 100644 index 0000000..3c109f2 --- /dev/null +++ b/Nfp/Cli.lean @@ -0,0 +1,42 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Cli + +/-! +Minimal CLI surface for the NFP rewrite. +-/ + +open Cli + +namespace Nfp + +/-- Human-readable version string for the CLI. -/ +def versionString : String := "0.1.0-tabula" + +/-- Print the CLI version. -/ +def runVersion (_p : Parsed) : IO UInt32 := do + IO.println s!"nfp version {versionString}" + return 0 + +/-- The version subcommand. -/ +def versionCmd : Cmd := `[Cli| + version VIA runVersion; + "Print the NFP version." +] + +/-- The root CLI command. -/ +def nfpCmd : Cmd := `[Cli| + nfp NOOP; + "NFP: Neural Formal Pathways (rewrite in progress)." + SUBCOMMANDS: + versionCmd +] + +/-- Main entry point for the CLI. 
-/ +def main (args : List String) : IO UInt32 := do + if args.contains "--version" then + IO.println s!"nfp version {versionString}" + return 0 + nfpCmd.validate args + +end Nfp diff --git a/Nfp/Core.lean b/Nfp/Core.lean new file mode 100644 index 0000000..e9e6bc6 --- /dev/null +++ b/Nfp/Core.lean @@ -0,0 +1,7 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Core.Basic + +/-! +Core shared definitions for the NFP rewrite. +-/ diff --git a/Nfp/Core/Basic.lean b/Nfp/Core/Basic.lean new file mode 100644 index 0000000..13dffe2 --- /dev/null +++ b/Nfp/Core/Basic.lean @@ -0,0 +1,14 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.NNReal.Basic + +/-! +Basic shared definitions for the NFP rewrite. +-/ + +namespace Nfp + +/-- Nonnegative mass used for probabilities and weights. -/ +abbrev Mass := NNReal + +end Nfp diff --git a/Nfp/Mixer.lean b/Nfp/Mixer.lean index ae10e35..747fc6f 100644 --- a/Nfp/Mixer.lean +++ b/Nfp/Mixer.lean @@ -1,197 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Data.Set.Basic -import Nfp.Prob +import Nfp.Mixer.Basic +import Nfp.Mixer.Operations /-! -Forward mixers (Appendix A.1): abstract, finite, row-stochastic operators and -their basic closure properties. This formalizes the “forward mixer” notion from -the Appendix and shows that they preserve probability mass and are closed under -composition. +Row-stochastic mixers. -/ - -namespace Nfp - -open scoped BigOperators - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- A row-stochastic nonnegative mixer from `S` to `T` (Appendix A.1). 
-/ -structure Mixer (S T : Type*) [Fintype S] [Fintype T] where - w : S → T → NNReal - row_sum_one : ∀ i : S, (∑ j, w i j) = 1 - -namespace Mixer - -variable (M : Mixer S T) - -@[ext] -theorem ext {M N : Mixer S T} (h : ∀ i j, M.w i j = N.w i j) : M = N := by - cases M; cases N; simp only [mk.injEq]; funext i j; exact h i j - -/-- Apply a mixer to a probability row `p` to produce a probability row on `T`. -This corresponds to pushing tracer mass through a forward mixer (Appendix A.1). -/ -noncomputable def push (p : ProbVec S) : ProbVec T := - { - mass := fun j => ∑ i, p.mass i * M.w i j, - norm_one := by - classical - -- Show ∑_j (∑_i p_i * w_{i,j}) = 1 by swapping sums and using row sums = 1 - have : (∑ j, (∑ i, p.mass i * M.w i j)) = 1 := by - calc - (∑ j, (∑ i, p.mass i * M.w i j)) - = (∑ i, (∑ j, p.mass i * M.w i j)) := by - simpa [mul_comm, mul_left_comm, mul_assoc] using Finset.sum_comm - _ = (∑ i, p.mass i * (∑ j, M.w i j)) := by - simp [Finset.mul_sum] - _ = (∑ i, p.mass i * 1) := by - simp [M.row_sum_one] - _ = (∑ i, p.mass i) := by - simp - _ = 1 := by - simp [ProbVec.sum_mass (ι:=S) p] - simp [this] } - -/-- The `i`-th row of a mixer as a probability vector on the target type. -/ -noncomputable def row (i : S) : ProbVec T := - { - mass := fun j => M.w i j - norm_one := by - classical - simpa using M.row_sum_one i - } - -@[simp] lemma row_mass (i : S) (j : T) : (M.row i).mass j = M.w i j := rfl - -/-- Pushing a point-mass probability vector through a mixer selects the corresponding row. -/ -theorem push_pure (i : S) : M.push (ProbVec.pure (ι := S) i) = M.row i := by - ext j - classical - simp [Mixer.push, Mixer.row, ProbVec.pure, Finset.sum_ite_eq', Finset.mem_univ] - -/-- `push` is affine: it commutes with convex mixtures of probability vectors. 
-/ -theorem push_mix (c : NNReal) (hc : c ≤ 1) (p q : ProbVec S) : - M.push (ProbVec.mix (ι := S) c hc p q) = - ProbVec.mix (ι := T) c hc (M.push p) (M.push q) := by - ext j - classical - simp [Mixer.push, ProbVec.mix, Finset.sum_add_distrib, add_mul, mul_assoc, Finset.mul_sum] - -/-- Composition of mixers is again a mixer (closure under composition). -/ -noncomputable def comp (N : Mixer T U) : Mixer S U := - { - w := fun i k => ∑ j, M.w i j * N.w j k, - row_sum_one := by - classical - intro i - -- ∑_k ∑_j M i j * N j k = ∑_j M i j * (∑_k N j k) = ∑_j M i j * 1 = 1 - calc - (∑ k, (∑ j, M.w i j * N.w j k)) = 1 := by - -- Same calculation as above but stated in one calc for direct closure - calc - (∑ k, (∑ j, M.w i j * N.w j k)) - = (∑ j, (∑ k, M.w i j * N.w j k)) := by - simpa [mul_comm, mul_left_comm, mul_assoc] using Finset.sum_comm - _ = (∑ j, M.w i j * (∑ k, N.w j k)) := by - simp [Finset.mul_sum] - _ = (∑ j, M.w i j * 1) := by - simp [N.row_sum_one] - _ = (∑ j, M.w i j) := by - simp - _ = 1 := by - simpa using M.row_sum_one i - } - -end Mixer - -end Nfp - -/-! -Support-restricted mixers (Appendix A.1, support restriction): compact helpers -to state and propagate that a mixer has zero weights outside a given binary -relation `R : S → T → Prop`. These do not change the existing proofs; they -encode the “only route along executed edges” constraint as a property. --/ - -namespace Nfp - -open scoped BigOperators - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -namespace Mixer - -/-- A mixer `M : Mixer S T` is `supported` by `R` if every weight outside `R` is zero. 
-/ -def supported (M : Mixer S T) (R : S → T → Prop) : Prop := - ∀ i j, ¬ R i j → M.w i j = 0 - -@[simp] lemma supported_zero {M : Mixer S T} {R : S → T → Prop} - (h : supported (S := S) (T := T) M R) {i : S} {j : T} (hij : ¬ R i j) : M.w i j = 0 := - h i j hij - -/-- Relational composition of supports: `R ⋆ Q` allows an edge `i → k` iff there -exists `j` with `i → j` allowed by `R` and `j → k` allowed by `Q`. -/ -def compSupport (R : S → T → Prop) (Q : T → U → Prop) : S → U → Prop := - fun i k => ∃ j, R i j ∧ Q j k - -/-- If `M` is supported by `R` and `N` by `Q`, then `M.comp N` is supported by `R ⋆ Q`. -/ -lemma supported_comp {M : Mixer S T} {N : Mixer T U} - {R : S → T → Prop} {Q : T → U → Prop} - (hM : supported (S := S) (T := T) M R) (hN : supported (S := T) (T := U) N Q) : - supported (S := S) (T := U) (M.comp N) (compSupport R Q) := by - classical - intro i k hnot - -- For every j, either ¬R i j or ¬Q j k; hence the product weight vanishes. - have hforall : ∀ j, M.w i j * N.w j k = 0 := by - intro j - by_cases hR : R i j - · have hQ : ¬ Q j k := by - intro hQ - exact hnot ⟨j, hR, hQ⟩ - have hN0 : N.w j k = 0 := hN j k hQ - simp [hN0] - · have hM0 : M.w i j = 0 := hM i j hR - simp [hM0] - have hsum : (∑ j, M.w i j * N.w j k) = 0 := by - have hfun : (fun j => M.w i j * N.w j k) = fun _ => (0 : NNReal) := by - funext j - simpa using hforall j - simp [hfun] - -- This is exactly the weight on `(i,k)` inside `M.comp N`. - simp [Mixer.comp, hsum] - -/-- The support (positions with nonzero mass) of a probability vector. -/ -def supp (p : ProbVec S) : Set S := fun i => p.mass i ≠ 0 - -/-- Image of a set of sources along a binary support relation. -/ -def image (R : S → T → Prop) (A : Set S) : Set T := fun k => ∃ i, A i ∧ R i k - -/-- Support propagation: if `M` is supported by `R`, then any nonzero pushed mass at `k` -originates from some source `i` with nonzero mass and an allowed edge `R i k`. 
-/ -lemma supp_push_subset_image {M : Mixer S T} {R : S → T → Prop} - (hM : supported (S := S) (T := T) M R) (p : ProbVec S) : - supp (S := T) (M.push p) ⊆ image (S := S) (T := T) R (supp (S := S) p) := by - classical - intro k hk - have hk_mass : (M.push p).mass k ≠ 0 := hk - have hk' : (∑ i, p.mass i * M.w i k) ≠ 0 := by - simpa [Mixer.push] using hk_mass - obtain ⟨i, _hi, hne⟩ := Finset.exists_ne_zero_of_sum_ne_zero hk' - have hpi : p.mass i ≠ 0 := by - intro h0 - exact hne (by simp [h0]) - have hwi : M.w i k ≠ 0 := by - intro h0 - exact hne (by simp [h0]) - have hR : R i k := by - by_contra hR - exact hwi (hM i k hR) - exact ⟨i, ⟨hpi, hR⟩⟩ - -end Mixer - -end Nfp diff --git a/Nfp/Mixer/Basic.lean b/Nfp/Mixer/Basic.lean new file mode 100644 index 0000000..44c0317 --- /dev/null +++ b/Nfp/Mixer/Basic.lean @@ -0,0 +1,37 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Prob.Basic + +/-! +Row-stochastic mixers. +-/ + +open scoped BigOperators + +namespace Nfp + +universe u + +/-- A row-stochastic mixer from `ι` to `κ`. -/ +structure Mixer (ι κ : Type u) [Fintype ι] [Fintype κ] where + /-- Nonnegative weights for each source/target pair. -/ + weight : ι → κ → Mass + /-- Each row is a probability vector. -/ + row_sum : ∀ i, (∑ k, weight i k) = 1 + +attribute [simp] Mixer.row_sum + +namespace Mixer + +variable {ι κ : Type u} [Fintype ι] [Fintype κ] + +instance : CoeFun (Mixer ι κ) (fun _ => ι → κ → Mass) := ⟨Mixer.weight⟩ + +/-- The row of a mixer as a probability vector. -/ +def row (M : Mixer ι κ) (i : ι) : ProbVec κ := + { mass := fun k => M.weight i k + sum_mass := M.row_sum i } + +end Mixer + +end Nfp diff --git a/Nfp/Mixer/Operations.lean b/Nfp/Mixer/Operations.lean new file mode 100644 index 0000000..ee66010 --- /dev/null +++ b/Nfp/Mixer/Operations.lean @@ -0,0 +1,71 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Mixer.Basic +import Nfp.Prob.Operations +import Mathlib.Algebra.BigOperators.Ring.Finset + +/-! 
+Mixer operations (pushforward, composition, identity). +-/ + +open scoped BigOperators + +namespace Nfp +namespace Mixer + +universe u + +variable {ι κ α : Type u} [Fintype ι] [Fintype κ] [Fintype α] + +/-- Push a probability vector forward along a mixer. -/ +def push (M : Mixer ι κ) (p : ProbVec ι) : ProbVec κ := + { mass := fun k => ∑ i, p.mass i * M.weight i k + sum_mass := by + classical + calc + ∑ k, ∑ i, p.mass i * M.weight i k + = ∑ i, ∑ k, p.mass i * M.weight i k := by + simpa using + (Finset.sum_comm : + (∑ k : κ, ∑ i : ι, p.mass i * M.weight i k) = + ∑ i : ι, ∑ k : κ, p.mass i * M.weight i k) + _ = ∑ i, p.mass i * ∑ k, M.weight i k := by + refine Finset.sum_congr rfl ?_ + intro i _ + simpa using + (Finset.mul_sum (a := p.mass i) (s := (Finset.univ : Finset κ)) + (f := fun k => M.weight i k)).symm + _ = ∑ i, p.mass i * 1 := by simp + _ = 1 := by simp } + +/-- Composition of two mixers. -/ +def comp (M : Mixer ι κ) (N : Mixer κ α) : Mixer ι α := + { weight := fun i a => ∑ k, M.weight i k * N.weight k a + row_sum := by + classical + intro i + calc + ∑ a, ∑ k, M.weight i k * N.weight k a + = ∑ k, ∑ a, M.weight i k * N.weight k a := by + simpa using + (Finset.sum_comm : + (∑ a : α, ∑ k : κ, M.weight i k * N.weight k a) = + ∑ k : κ, ∑ a : α, M.weight i k * N.weight k a) + _ = ∑ k, M.weight i k * ∑ a, N.weight k a := by + refine Finset.sum_congr rfl ?_ + intro k _ + simpa using + (Finset.mul_sum (a := M.weight i k) (s := (Finset.univ : Finset α)) + (f := fun a => N.weight k a)).symm + _ = ∑ k, M.weight i k * 1 := by simp + _ = 1 := by simp } + +/-- Identity mixer. 
-/ +def id (ι : Type u) [Fintype ι] [DecidableEq ι] : Mixer ι ι := + { weight := fun i j => (ProbVec.pure i).mass j + row_sum := by + intro i + exact (ProbVec.pure i).sum_mass } + +end Mixer +end Nfp diff --git a/Nfp/Prob.lean b/Nfp/Prob.lean index a04515a..292da09 100644 --- a/Nfp/Prob.lean +++ b/Nfp/Prob.lean @@ -1,88 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic +import Nfp.Prob.Basic +import Nfp.Prob.Operations -/- -Basic probability-friendly definitions used across the NFP development. -We work with finite types and nonnegative reals `NNReal` from mathlib. +/-! +Probability vectors. -/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-- A probability vector on a finite type `ι` is a nonnegative function summing to 1. -/ -structure ProbVec (ι : Type*) [Fintype ι] where - mass : ι → NNReal - norm_one : (∑ i, mass i) = (1 : NNReal) - -namespace ProbVec - -variable {ι : Type*} [Fintype ι] - -@[simp] theorem sum_mass (p : ProbVec ι) : (∑ i, p.mass i) = 1 := p.norm_one - -@[ext] -theorem ext {p q : ProbVec ι} (h : ∀ i, p.mass i = q.mass i) : p = q := by - cases p; cases q; simp only [mk.injEq]; funext i; exact h i - -theorem mass_le_one (p : ProbVec ι) (i : ι) : p.mass i ≤ 1 := by - have h := p.sum_mass - calc p.mass i ≤ ∑ j, p.mass j := Finset.single_le_sum (by simp) (Finset.mem_univ i) - _ = 1 := h - -/-- The Dirac/point-mass probability vector at `i0`. 
-/ -noncomputable def pure (i0 : ι) : ProbVec ι := by - classical - refine - { - mass := fun i => if i = i0 then 1 else 0 - norm_one := ?_ - } - simp - -@[simp] lemma pure_mass_self (i0 : ι) : (pure (ι := ι) i0).mass i0 = 1 := by - classical - simp [pure] - -@[simp] lemma pure_mass_ne_self {i0 i : ι} (h : i ≠ i0) : (pure (ι := ι) i0).mass i = 0 := by - classical - simp [pure, h] - -/-- Convex mixture of probability vectors using coefficient `c ∈ [0,1]`. -/ -noncomputable def mix (c : NNReal) (hc : c ≤ 1) (p q : ProbVec ι) : ProbVec ι := - { - mass := fun i => c * p.mass i + (1 - c) * q.mass i - norm_one := by - classical - have hp : (∑ i, c * p.mass i) = c * (∑ i, p.mass i) := by - simpa using - (Finset.mul_sum (s := (Finset.univ : Finset ι)) (f := fun i : ι => p.mass i) - (a := c)).symm - have hq : - (∑ i, (1 - c) * q.mass i) = (1 - c) * (∑ i, q.mass i) := by - simpa using - (Finset.mul_sum (s := (Finset.univ : Finset ι)) (f := fun i : ι => q.mass i) - (a := (1 - c))).symm - calc - (∑ i, (c * p.mass i + (1 - c) * q.mass i)) - = (∑ i, c * p.mass i) + (∑ i, (1 - c) * q.mass i) := by - simp [Finset.sum_add_distrib] - _ = c * (∑ i, p.mass i) + (1 - c) * (∑ i, q.mass i) := by - simp [hp, hq] - _ = c * 1 + (1 - c) * 1 := by - simp [ProbVec.sum_mass] - _ = 1 := by - simpa using (add_tsub_cancel_of_le hc) - } - -@[simp] lemma mix_mass (c : NNReal) (hc : c ≤ 1) (p q : ProbVec ι) (i : ι) : - (mix (ι := ι) c hc p q).mass i = c * p.mass i + (1 - c) * q.mass i := rfl - -end ProbVec - -end Nfp diff --git a/Nfp/Prob/Basic.lean b/Nfp/Prob/Basic.lean new file mode 100644 index 0000000..92a3bb4 --- /dev/null +++ b/Nfp/Prob/Basic.lean @@ -0,0 +1,33 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Core +import Mathlib.Data.Fintype.BigOperators + +/-! +Probability vectors on finite types. +-/ + +open scoped BigOperators + +namespace Nfp + +universe u + +/-- A probability vector on a finite type. 
-/ +structure ProbVec (ι : Type u) [Fintype ι] where + /-- Mass assigned to each point. -/ + mass : ι → Mass + /-- Total mass is exactly one. -/ + sum_mass : (∑ i, mass i) = 1 + +attribute [simp] ProbVec.sum_mass + +namespace ProbVec + +variable {ι : Type u} [Fintype ι] + +instance : CoeFun (ProbVec ι) (fun _ => ι → Mass) := ⟨ProbVec.mass⟩ + +end ProbVec + +end Nfp diff --git a/Nfp/Prob/Operations.lean b/Nfp/Prob/Operations.lean new file mode 100644 index 0000000..39bb283 --- /dev/null +++ b/Nfp/Prob/Operations.lean @@ -0,0 +1,52 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Prob.Basic +import Mathlib.Algebra.BigOperators.Ring.Finset + +/-! +Basic constructions on probability vectors. +-/ + +open scoped BigOperators + +namespace Nfp +namespace ProbVec + +universe u + +variable {ι : Type u} [Fintype ι] + +/-- The pure distribution at a single point. -/ +def pure (i0 : ι) [DecidableEq ι] : ProbVec ι := by + refine + { mass := fun i => if i = i0 then 1 else 0 + sum_mass := ?_ } + exact (Fintype.sum_ite_eq' (ι := ι) (i := i0) (f := fun _ => (1 : Mass))) + +@[simp] theorem mass_pure (i0 i : ι) [DecidableEq ι] : + (pure i0).mass i = if i = i0 then 1 else 0 := rfl + +/-- Convex combination of two probability vectors with weights that sum to one. 
-/ +def mix (a b : Mass) (h : a + b = 1) (p q : ProbVec ι) : ProbVec ι := + { mass := fun i => a * p.mass i + b * q.mass i + sum_mass := by + classical + calc + ∑ i, (a * p.mass i + b * q.mass i) + = (∑ i, a * p.mass i) + (∑ i, b * q.mass i) := by + simp [Finset.sum_add_distrib] + _ = a * ∑ i, p.mass i + b * ∑ i, q.mass i := by + have ha : (∑ i, a * p.mass i) = a * ∑ i, p.mass i := by + simpa using + (Finset.mul_sum (a := a) (s := (Finset.univ : Finset ι)) + (f := fun i => p.mass i)).symm + have hb : (∑ i, b * q.mass i) = b * ∑ i, q.mass i := by + simpa using + (Finset.mul_sum (a := b) (s := (Finset.univ : Finset ι)) + (f := fun i => q.mass i)).symm + simp [ha, hb] + _ = a * 1 + b * 1 := by simp + _ = 1 := by simp [h] } + +end ProbVec +end Nfp diff --git a/Nfp/System.lean b/Nfp/System.lean new file mode 100644 index 0000000..ab8e7ad --- /dev/null +++ b/Nfp/System.lean @@ -0,0 +1,8 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.System.Dag +import Nfp.System.LocalSystem + +/-! +DAG-based system foundations. +-/ diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean new file mode 100644 index 0000000..155b358 --- /dev/null +++ b/Nfp/System/Dag.lean @@ -0,0 +1,45 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Finset.Basic +import Mathlib.Data.Fintype.Basic + +/-! +Directed acyclic graph foundations. +-/ + +namespace Nfp + +universe u + +/-- A finite directed acyclic graph with edge relation `rel`. +`rel u v` means there is an edge from `u` to `v`. -/ +structure Dag (ι : Type u) [Fintype ι] where + rel : ι → ι → Prop + decRel : DecidableRel rel + wf : WellFounded rel + +attribute [instance] Dag.decRel + +namespace Dag + +variable {ι : Type u} [Fintype ι] + +/-- Parents (incoming neighbors) of a node. -/ +def parents (G : Dag ι) (i : ι) [DecidableEq ι] : Finset ι := + Finset.filter (fun j => G.rel j i) Finset.univ + +/-- Children (outgoing neighbors) of a node. 
-/ +def children (G : Dag ι) (i : ι) [DecidableEq ι] : Finset ι := + Finset.filter (fun j => G.rel i j) Finset.univ + +@[simp] theorem mem_parents {G : Dag ι} [DecidableEq ι] {i j : ι} : + j ∈ G.parents i ↔ G.rel j i := by + simp [Dag.parents] + +@[simp] theorem mem_children {G : Dag ι} [DecidableEq ι] {i j : ι} : + j ∈ G.children i ↔ G.rel i j := by + simp [Dag.children] + +end Dag + +end Nfp diff --git a/Nfp/System/LocalSystem.lean b/Nfp/System/LocalSystem.lean new file mode 100644 index 0000000..d5a447d --- /dev/null +++ b/Nfp/System/LocalSystem.lean @@ -0,0 +1,47 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Fintype.BigOperators +import Nfp.Mixer +import Nfp.System.Dag + +/-! +Local mixing systems on finite DAGs. +-/ + +open scoped BigOperators + +namespace Nfp + +universe u + +/-- A local mixing system on a DAG. +`weight i j` is the contribution from `j` into `i`. -/ +structure LocalSystem (ι : Type u) [Fintype ι] where + /-- The underlying DAG describing allowed dependencies. -/ + dag : Dag ι + /-- Mixing weights for each target/source pair. -/ + weight : ι → ι → Mass + /-- Weights vanish off the edge relation. -/ + support : ∀ i j, ¬ dag.rel j i → weight i j = 0 + /-- Each row is a probability vector. -/ + row_sum : ∀ i, (∑ j, weight i j) = 1 + +attribute [simp] LocalSystem.row_sum + +namespace LocalSystem + +variable {ι : Type u} [Fintype ι] + +/-- View a local system as a global mixer. -/ +def toMixer (L : LocalSystem ι) : Mixer ι ι := + { weight := L.weight + row_sum := L.row_sum } + +/-- Off-edge weights are zero. 
-/ +theorem weight_eq_zero_of_not_parent (L : LocalSystem ι) {i j : ι} (h : ¬ L.dag.rel j i) : + L.weight i j = 0 := + L.support i j h + +end LocalSystem + +end Nfp diff --git a/lakefile.toml b/lakefile.toml index 9973d13..b51a17e 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -19,6 +19,7 @@ rev = "stable" [[lean_lib]] name = "Nfp" +roots = ["Nfp"] [[lean_exe]] name = "nfp" From dfedecc7e7db32672918b67b7fd26eba03f60df5 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 17:06:49 +0100 Subject: [PATCH 064/244] refactor dag to wrap mathlib Digraph --- Nfp/System/Dag.lean | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean index 155b358..59ae739 100644 --- a/Nfp/System/Dag.lean +++ b/Nfp/System/Dag.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Combinatorics.Digraph.Basic import Mathlib.Data.Finset.Basic -import Mathlib.Data.Fintype.Basic /-! Directed acyclic graph foundations. @@ -11,32 +11,39 @@ namespace Nfp universe u -/-- A finite directed acyclic graph with edge relation `rel`. -`rel u v` means there is an edge from `u` to `v`. -/ +/-- A finite directed acyclic graph, built on top of `Digraph`. -/ structure Dag (ι : Type u) [Fintype ι] where - rel : ι → ι → Prop - decRel : DecidableRel rel - wf : WellFounded rel + /-- The underlying directed graph. -/ + graph : Digraph ι + /-- Decidable adjacency for `graph.Adj`. -/ + decAdj : DecidableRel graph.Adj + /-- The adjacency relation is well-founded. -/ + wf : WellFounded graph.Adj -attribute [instance] Dag.decRel +attribute [instance] Dag.decAdj namespace Dag variable {ι : Type u} [Fintype ι] +/-- The edge relation of a DAG. -/ +def rel (G : Dag ι) : ι → ι → Prop := G.graph.Adj + /-- Parents (incoming neighbors) of a node. 
-/ -def parents (G : Dag ι) (i : ι) [DecidableEq ι] : Finset ι := - Finset.filter (fun j => G.rel j i) Finset.univ +def parents (G : Dag ι) (i : ι) : Finset ι := by + let _ : DecidablePred (fun j => G.rel j i) := fun j => G.decAdj j i + exact Finset.filter (fun j => G.rel j i) Finset.univ /-- Children (outgoing neighbors) of a node. -/ -def children (G : Dag ι) (i : ι) [DecidableEq ι] : Finset ι := - Finset.filter (fun j => G.rel i j) Finset.univ +def children (G : Dag ι) (i : ι) : Finset ι := by + let _ : DecidablePred (fun j => G.rel i j) := fun j => G.decAdj i j + exact Finset.filter (fun j => G.rel i j) Finset.univ -@[simp] theorem mem_parents {G : Dag ι} [DecidableEq ι] {i j : ι} : +@[simp] theorem mem_parents {G : Dag ι} {i j : ι} : j ∈ G.parents i ↔ G.rel j i := by simp [Dag.parents] -@[simp] theorem mem_children {G : Dag ι} [DecidableEq ι] {i j : ι} : +@[simp] theorem mem_children {G : Dag ι} {i j : ι} : j ∈ G.children i ↔ G.rel i j := by simp [Dag.children] From 023e4a8b83fa269e56daf0ec0fa035310a59a58f Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 17:16:38 +0100 Subject: [PATCH 065/244] lean core reuse: pi.single pure, narrower imports --- Nfp/Prob/Operations.lean | 10 +++++++--- Nfp/System/LocalSystem.lean | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Nfp/Prob/Operations.lean b/Nfp/Prob/Operations.lean index 39bb283..ef0d1f2 100644 --- a/Nfp/Prob/Operations.lean +++ b/Nfp/Prob/Operations.lean @@ -19,12 +19,16 @@ variable {ι : Type u} [Fintype ι] /-- The pure distribution at a single point. 
-/ def pure (i0 : ι) [DecidableEq ι] : ProbVec ι := by refine - { mass := fun i => if i = i0 then 1 else 0 + { mass := Pi.single i0 (1 : Mass) sum_mass := ?_ } - exact (Fintype.sum_ite_eq' (ι := ι) (i := i0) (f := fun _ => (1 : Mass))) + exact (Fintype.sum_pi_single' (ι := ι) (i := i0) (a := (1 : Mass))) @[simp] theorem mass_pure (i0 i : ι) [DecidableEq ι] : - (pure i0).mass i = if i = i0 then 1 else 0 := rfl + (pure i0).mass i = if i = i0 then 1 else 0 := by + by_cases h : i = i0 + · subst h + simp [pure, Pi.single] + · simp [pure, Pi.single, h] /-- Convex combination of two probability vectors with weights that sum to one. -/ def mix (a b : Mass) (h : a + b = 1) (p q : ProbVec ι) : ProbVec ι := diff --git a/Nfp/System/LocalSystem.lean b/Nfp/System/LocalSystem.lean index d5a447d..59ea55a 100644 --- a/Nfp/System/LocalSystem.lean +++ b/Nfp/System/LocalSystem.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Data.Fintype.BigOperators -import Nfp.Mixer +import Nfp.Mixer.Basic import Nfp.System.Dag /-! 
From 9891adeced77ab7ef147b3c25ce8004283a08503 Mon Sep 17 00:00:00 2001
From: TheDarkchip
Date: Fri, 2 Jan 2026 17:17:42 +0100
Subject: [PATCH 066/244] enable missing-docs, unused-tactic, and omit linters

---
 lakefile.toml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/lakefile.toml b/lakefile.toml
index b51a17e..a058f0a 100644
--- a/lakefile.toml
+++ b/lakefile.toml
@@ -11,6 +11,9 @@ weak.linter.mathlibStandardSet = true
 maxSynthPendingDepth = 3
 linter.unusedVariables = true
 weak.linter.unreachableTactic = true
+weak.linter.missingDocs = true
+weak.linter.unusedTactic = true
+weak.linter.omit = true
 
 [[require]]
 name = "mathlib"

From 62eed2b17e8ad0ba9e6c40901674ce8ee2ffb181 Mon Sep 17 00:00:00 2001
From: TheDarkchip
Date: Fri, 2 Jan 2026 17:28:09 +0100
Subject: [PATCH 067/244] local system eval semantics and row-stochastic predicate

---
 AGENTS.md                   |  2 +-
 Nfp.lean                    |  2 ++
 Nfp/System/Dag.lean         |  2 +-
 Nfp/System/LocalSystem.lean | 33 +++++++++++++++++++++++++++------
 4 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/AGENTS.md b/AGENTS.md
index dda03ed..7450662 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -266,7 +266,7 @@ but you **must** update this list in the same commit.
 - `Nfp/System/Dag.lean`
   - DAG relation + parent/child sets.
 - `Nfp/System/LocalSystem.lean`
-  - `LocalSystem` with edge support and row-sum invariants.
+  - `LocalSystem` with edge support, row-stochastic predicate, and evaluation semantics.
 - `Nfp/System.lean`
   - Aggregator for system modules.
diff --git a/Nfp.lean b/Nfp.lean index d04adfa..b59c05b 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -23,3 +23,5 @@ These `#print axioms` lines help ensure we only depend on a small set of axioms #print axioms Nfp.Mixer.id #print axioms Nfp.Dag.parents #print axioms Nfp.LocalSystem.toMixer +#print axioms Nfp.LocalSystem.eval +#print axioms Nfp.LocalSystem.eval_eq diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean index 59ae739..e5caa68 100644 --- a/Nfp/System/Dag.lean +++ b/Nfp/System/Dag.lean @@ -27,7 +27,7 @@ namespace Dag variable {ι : Type u} [Fintype ι] /-- The edge relation of a DAG. -/ -def rel (G : Dag ι) : ι → ι → Prop := G.graph.Adj +abbrev rel (G : Dag ι) : ι → ι → Prop := G.graph.Adj /-- Parents (incoming neighbors) of a node. -/ def parents (G : Dag ι) (i : ι) : Finset ι := by diff --git a/Nfp/System/LocalSystem.lean b/Nfp/System/LocalSystem.lean index 59ea55a..249d337 100644 --- a/Nfp/System/LocalSystem.lean +++ b/Nfp/System/LocalSystem.lean @@ -23,25 +23,46 @@ structure LocalSystem (ι : Type u) [Fintype ι] where weight : ι → ι → Mass /-- Weights vanish off the edge relation. -/ support : ∀ i j, ¬ dag.rel j i → weight i j = 0 - /-- Each row is a probability vector. -/ - row_sum : ∀ i, (∑ j, weight i j) = 1 - -attribute [simp] LocalSystem.row_sum namespace LocalSystem variable {ι : Type u} [Fintype ι] +/-- Row-stochasticity for a local system. -/ +def IsRowStochastic (L : LocalSystem ι) : Prop := + ∀ i, (∑ j, L.weight i j) = 1 + /-- View a local system as a global mixer. -/ -def toMixer (L : LocalSystem ι) : Mixer ι ι := +def toMixer (L : LocalSystem ι) (h : IsRowStochastic L) : Mixer ι ι := { weight := L.weight - row_sum := L.row_sum } + row_sum := h } /-- Off-edge weights are zero. -/ theorem weight_eq_zero_of_not_parent (L : LocalSystem ι) {i j : ι} (h : ¬ L.dag.rel j i) : L.weight i j = 0 := L.support i j h +/-- One-step evaluation functional used by `eval`. 
-/ +def evalStep (L : LocalSystem ι) (input : ι → Mass) + (i : ι) (rec : ∀ j, L.dag.rel j i → Mass) : Mass := + input i + + ∑ j, (if h : L.dag.rel j i then L.weight i j * rec j h else 0) + +/-- Evaluate a local system with external input at each node. -/ +def eval (L : LocalSystem ι) (input : ι → Mass) : ι → Mass := + L.dag.wf.fix (fun i rec => evalStep L input i rec) + +/-- Unfolding equation for `eval`. -/ +theorem eval_eq (L : LocalSystem ι) (input : ι → Mass) (i : ι) : + eval L input i = + input i + + ∑ j, (if _ : L.dag.rel j i then L.weight i j * eval L input j else 0) := by + set F : ∀ i, (∀ j, L.dag.rel j i → Mass) → Mass := fun i rec => evalStep L input i rec + change L.dag.wf.fix F i = + input i + ∑ j, (if _ : L.dag.rel j i then L.weight i j * L.dag.wf.fix F j else 0) + rw [WellFounded.fix_eq] + dsimp [F, evalStep] + end LocalSystem end Nfp From ca6fa2be5a539033c9be328e4306967ae4202f24 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 17:33:13 +0100 Subject: [PATCH 068/244] Add Dag coercion to Digraph --- Nfp/System/Dag.lean | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean index e5caa68..25632a3 100644 --- a/Nfp/System/Dag.lean +++ b/Nfp/System/Dag.lean @@ -26,6 +26,9 @@ namespace Dag variable {ι : Type u} [Fintype ι] +/-- Coerce a DAG to its underlying digraph. -/ +instance : Coe (Dag ι) (Digraph ι) := ⟨Dag.graph⟩ + /-- The edge relation of a DAG. 
-/ abbrev rel (G : Dag ι) : ι → ι → Prop := G.graph.Adj From 6fee534b8de7b4eba2ef9359a23c7294ea72cc67 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 17:54:12 +0100 Subject: [PATCH 069/244] Add circuit core definitions and equivalence checker --- AGENTS.md | 14 +++++++++-- Nfp.lean | 3 +++ Nfp/Circuit.lean | 9 +++++++ Nfp/Circuit/Basic.lean | 24 +++++++++++++++++++ Nfp/Circuit/Cert.lean | 40 +++++++++++++++++++++++++++++++ Nfp/Circuit/Semantics.lean | 49 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 Nfp/Circuit.lean create mode 100644 Nfp/Circuit/Basic.lean create mode 100644 Nfp/Circuit/Cert.lean create mode 100644 Nfp/Circuit/Semantics.lean diff --git a/AGENTS.md b/AGENTS.md index 7450662..5c36282 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -270,7 +270,17 @@ but you **must** update this list in the same commit. - `Nfp/System.lean` - Aggregator for system modules. -### 5.5 CLI surface +### 5.5 Circuits (certification core) +- `Nfp/Circuit/Basic.lean` + - DAG-based circuit structure with inputs/outputs and gate semantics. +- `Nfp/Circuit/Semantics.lean` + - Well-founded evaluation semantics for circuits. +- `Nfp/Circuit/Cert.lean` + - Equivalence definition and finite checker. +- `Nfp/Circuit.lean` + - Aggregator for circuit modules. + +### 5.6 CLI surface - `Nfp/Cli.lean` - CLI commands and `main` implementation. - `Main.lean` @@ -278,7 +288,7 @@ but you **must** update this list in the same commit. - `Nfp.lean` - Top-level reexports and axioms dashboard (`#print axioms`). -### 5.6 Legacy (tabula rasa transition) +### 5.7 Legacy (tabula rasa transition) - Legacy modules live under `Legacy/Nfp/` as reference only and are not built by default. If you introduce a new conceptual layer: diff --git a/Nfp.lean b/Nfp.lean index b59c05b..073abcd 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -4,6 +4,7 @@ import Nfp.Core import Nfp.Prob import Nfp.Mixer import Nfp.System +import Nfp.Circuit /-! 
Top-level reexports and trust dashboard for the NFP rewrite. @@ -25,3 +26,5 @@ These `#print axioms` lines help ensure we only depend on a small set of axioms #print axioms Nfp.LocalSystem.toMixer #print axioms Nfp.LocalSystem.eval #print axioms Nfp.LocalSystem.eval_eq +#print axioms Nfp.Circuit.eval +#print axioms Nfp.Circuit.checkEquiv diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean new file mode 100644 index 0000000..d39c489 --- /dev/null +++ b/Nfp/Circuit.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Basic +import Nfp.Circuit.Semantics +import Nfp.Circuit.Cert + +/-! +Circuit definitions, semantics, and equivalence checking. +-/ diff --git a/Nfp/Circuit/Basic.lean b/Nfp/Circuit/Basic.lean new file mode 100644 index 0000000..944ccff --- /dev/null +++ b/Nfp/Circuit/Basic.lean @@ -0,0 +1,24 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.System.Dag + +/-! +Circuit foundations: a DAG with designated inputs/outputs and gate semantics. +-/ + +namespace Nfp + +universe u v + +/-- A finite circuit on a DAG with designated inputs/outputs and per-node gate semantics. -/ +structure Circuit (ι : Type u) [Fintype ι] (α : Type v) where + /-- The underlying DAG that orders dependencies. -/ + dag : Dag ι + /-- Input nodes read from the external assignment. -/ + inputs : Finset ι + /-- Output nodes observed after evaluation. -/ + outputs : Finset ι + /-- Gate semantics at each node, given values of its parents. -/ + gate : ∀ i, (∀ j, dag.rel j i → α) → α + +end Nfp diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean new file mode 100644 index 0000000..70b769c --- /dev/null +++ b/Nfp/Circuit/Cert.lean @@ -0,0 +1,40 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Semantics + +/-! +Circuit equivalence and a finite checker. 
+-/ + +namespace Nfp + +universe u v + +namespace Circuit + +variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {α : Type v} + +/-- Circuits share the same input/output interface. -/ +def SameInterface (C₁ C₂ : Circuit ι α) : Prop := + C₁.inputs = C₂.inputs ∧ C₁.outputs = C₂.outputs + +/-- Circuits are equivalent if they agree on outputs for all inputs on the same interface. -/ +def Equiv (C₁ C₂ : Circuit ι α) : Prop := + SameInterface C₁ C₂ ∧ + ∀ input, ∀ i ∈ C₁.outputs, eval C₁ input i = eval C₂ input i + +/-- Decide equivalence (classically); computational checkers can refine this. -/ +noncomputable def checkEquiv (C₁ C₂ : Circuit ι α) : Bool := by + classical + exact decide (Equiv C₁ C₂) + +/-- `checkEquiv` is sound and complete for `Equiv`. -/ +theorem checkEquiv_eq_true_iff (C₁ C₂ : Circuit ι α) : + checkEquiv C₁ C₂ = true ↔ Equiv C₁ C₂ := by + classical + simp [checkEquiv] + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Semantics.lean b/Nfp/Circuit/Semantics.lean new file mode 100644 index 0000000..e7fcd66 --- /dev/null +++ b/Nfp/Circuit/Semantics.lean @@ -0,0 +1,49 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Basic + +/-! +Evaluation semantics for finite circuits. +-/ + +namespace Nfp + +universe u v + +namespace Circuit + +variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {α : Type v} + +/-- One-step evaluation functional used by `eval`. -/ +def evalStep (C : Circuit ι α) (input : ι → α) + (i : ι) (rec : ∀ j, C.dag.rel j i → α) : α := + if _ : i ∈ C.inputs then input i else C.gate i rec + +/-- Evaluate a circuit with a given input assignment. -/ +def eval (C : Circuit ι α) (input : ι → α) : ι → α := + C.dag.wf.fix (fun i rec => evalStep C input i rec) + +/-- Unfolding equation for `eval`. 
-/ +theorem eval_eq (C : Circuit ι α) (input : ι → α) (i : ι) : + eval C input i = + if _ : i ∈ C.inputs then input i else C.gate i (fun j _ => eval C input j) := by + set F : ∀ i, (∀ j, C.dag.rel j i → α) → α := fun i rec => evalStep C input i rec + change C.dag.wf.fix F i = + if _ : i ∈ C.inputs then input i else C.gate i (fun j _ => C.dag.wf.fix F j) + rw [WellFounded.fix_eq] + dsimp [F, evalStep] + +/-- Input nodes evaluate to their assigned input value. -/ +theorem eval_eq_input (C : Circuit ι α) (input : ι → α) {i : ι} (h : i ∈ C.inputs) : + eval C input i = input i := by + simpa [h] using (eval_eq C input i) + +/-- Non-input nodes evaluate via their gate semantics. -/ +theorem eval_eq_gate (C : Circuit ι α) (input : ι → α) {i : ι} (h : i ∉ C.inputs) : + eval C input i = C.gate i (fun j _ => eval C input j) := by + simpa [h] using (eval_eq C input i) + +end Circuit + +end Nfp From ace7880412b8f3e830bb371a4a205f226a7857b2 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 18:09:53 +0100 Subject: [PATCH 070/244] Make circuit equivalence checker computable --- Nfp/Circuit/Cert.lean | 59 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index 70b769c..d501e6c 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -1,5 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Data.Finset.Fold +import Mathlib.Data.Finset.Insert +import Mathlib.Data.Fintype.Pi import Nfp.Circuit.Semantics /-! @@ -24,16 +27,62 @@ def Equiv (C₁ C₂ : Circuit ι α) : Prop := SameInterface C₁ C₂ ∧ ∀ input, ∀ i ∈ C₁.outputs, eval C₁ input i = eval C₂ input i -/-- Decide equivalence (classically); computational checkers can refine this. 
-/ -noncomputable def checkEquiv (C₁ C₂ : Circuit ι α) : Bool := by +section + +local instance : Std.Commutative (α := Bool) (· && ·) := ⟨Bool.and_comm⟩ +local instance : Std.Associative (α := Bool) (· && ·) := ⟨Bool.and_assoc⟩ + +/-- Boolean `all` over a finset. -/ +def finsetAll {β : Type v} (s : Finset β) (p : β → Bool) : Bool := + s.fold (· && ·) true p + +theorem finsetAll_eq_true_iff {β : Type v} {s : Finset β} {p : β → Bool} : + finsetAll s p = true ↔ ∀ a ∈ s, p a = true := by classical - exact decide (Equiv C₁ C₂) + induction s using Finset.induction_on with + | empty => + simp [finsetAll] + | @insert a s ha ih => + have hfold : finsetAll (insert a s) p = true ↔ p a = true ∧ finsetAll s p = true := by + simp [finsetAll, ha, Bool.and_eq_true] + calc + finsetAll (insert a s) p = true + ↔ p a = true ∧ finsetAll s p = true := hfold + _ ↔ p a = true ∧ ∀ a ∈ s, p a = true := by simp [ih] + _ ↔ ∀ x ∈ insert a s, p x = true := by + constructor + · intro h x hx + rcases h with ⟨ha', hs⟩ + by_cases hx' : x = a + · simpa [hx'] using ha' + · exact hs x (Finset.mem_of_mem_insert_of_ne hx hx') + · intro h + refine ⟨?_, ?_⟩ + · exact h a (Finset.mem_insert_self a s) + · intro x hx + exact h x (Finset.mem_insert_of_mem hx) + +/-- Boolean check for interface equality. -/ +def sameInterface (C₁ C₂ : Circuit ι α) : Bool := + decide (C₁.inputs = C₂.inputs) && decide (C₁.outputs = C₂.outputs) + +theorem sameInterface_eq_true_iff (C₁ C₂ : Circuit ι α) : + sameInterface C₁ C₂ = true ↔ SameInterface C₁ C₂ := by + simp [sameInterface, SameInterface, Bool.and_eq_true] + +/-- Decide equivalence by enumerating all inputs on a finite value type. -/ +def checkEquiv (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : Bool := + sameInterface C₁ C₂ && + finsetAll (Finset.univ : Finset (ι → α)) (fun input => + finsetAll C₁.outputs (fun i => decide (eval C₁ input i = eval C₂ input i))) /-- `checkEquiv` is sound and complete for `Equiv`. 
-/ -theorem checkEquiv_eq_true_iff (C₁ C₂ : Circuit ι α) : +theorem checkEquiv_eq_true_iff (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : checkEquiv C₁ C₂ = true ↔ Equiv C₁ C₂ := by classical - simp [checkEquiv] + simp [checkEquiv, Equiv, sameInterface_eq_true_iff, finsetAll_eq_true_iff, Bool.and_eq_true] + +end end Circuit From e2f25495e060edd367ec9ec0ac629fb881b3eb4d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 18:14:59 +0100 Subject: [PATCH 071/244] Enable additional style linters --- lakefile.toml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lakefile.toml b/lakefile.toml index a058f0a..3ecdd1e 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -14,6 +14,15 @@ weak.linter.unreachableTactic = true weak.linter.missingDocs = true weak.linter.unusedTactic = true weak.linter.omit = true +weak.linter.style.longFile = 1500 +weak.linter.style.setOption = true +weak.linter.style.missingEnd = true +weak.linter.style.openClassical = true +weak.linter.style.show = true +weak.linter.style.lambdaSyntax = true +weak.linter.style.dollarSyntax = true +weak.linter.style.cdot = true +weak.linter.style.longLine = true [[require]] name = "mathlib" From f1cc4dc1cd7141cd210338d50871a56912a8a59a Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 18:30:33 +0100 Subject: [PATCH 072/244] Add input-only circuit assignments and evaluation --- Nfp.lean | 1 + Nfp/Circuit/Basic.lean | 17 +++++++++++++++++ Nfp/Circuit/Cert.lean | 38 ++++++++++++++++++++++++++++++-------- Nfp/Circuit/Semantics.lean | 31 +++++++++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 8 deletions(-) diff --git a/Nfp.lean b/Nfp.lean index 073abcd..c915fba 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -27,4 +27,5 @@ These `#print axioms` lines help ensure we only depend on a small set of axioms #print axioms Nfp.LocalSystem.eval #print axioms Nfp.LocalSystem.eval_eq #print axioms Nfp.Circuit.eval +#print axioms Nfp.Circuit.evalInput #print axioms 
Nfp.Circuit.checkEquiv diff --git a/Nfp/Circuit/Basic.lean b/Nfp/Circuit/Basic.lean index 944ccff..2570bc8 100644 --- a/Nfp/Circuit/Basic.lean +++ b/Nfp/Circuit/Basic.lean @@ -21,4 +21,21 @@ structure Circuit (ι : Type u) [Fintype ι] (α : Type v) where /-- Gate semantics at each node, given values of its parents. -/ gate : ∀ i, (∀ j, dag.rel j i → α) → α +namespace Circuit + +variable {ι : Type u} [Fintype ι] {α : Type v} + +/-- External input assignment on the circuit's input nodes. -/ +abbrev InputAssignment (C : Circuit ι α) : Type (max u v) := + { i // i ∈ C.inputs } → α + +/-- Reinterpret input assignments along an equality of input sets. -/ +def InputAssignment.cast {C₁ C₂ : Circuit ι α} (h : C₁.inputs = C₂.inputs) : + InputAssignment C₁ → InputAssignment C₂ := by + intro input i + refine input ⟨i.1, ?_⟩ + simp [h] + +end Circuit + end Nfp diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index d501e6c..428d950 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -22,10 +22,19 @@ variable {α : Type v} def SameInterface (C₁ C₂ : Circuit ι α) : Prop := C₁.inputs = C₂.inputs ∧ C₁.outputs = C₂.outputs -/-- Circuits are equivalent if they agree on outputs for all inputs on the same interface. -/ +/-- `SameInterface` is decidable. -/ +instance (C₁ C₂ : Circuit ι α) : Decidable (SameInterface C₁ C₂) := by + dsimp [SameInterface] + infer_instance + +/-- Circuits agree on outputs for all input assignments on a fixed interface. -/ +def EquivOn (C₁ C₂ : Circuit ι α) (h : SameInterface C₁ C₂) : Prop := + ∀ input : C₁.InputAssignment, ∀ i ∈ C₁.outputs, + evalInput C₁ input i = evalInput C₂ (InputAssignment.cast h.1 input) i + +/-- Circuits are equivalent if they share an interface and agree on all inputs. 
-/ def Equiv (C₁ C₂ : Circuit ι α) : Prop := - SameInterface C₁ C₂ ∧ - ∀ input, ∀ i ∈ C₁.outputs, eval C₁ input i = eval C₂ input i + ∃ h : SameInterface C₁ C₂, EquivOn C₁ C₂ h section @@ -70,17 +79,30 @@ theorem sameInterface_eq_true_iff (C₁ C₂ : Circuit ι α) : sameInterface C₁ C₂ = true ↔ SameInterface C₁ C₂ := by simp [sameInterface, SameInterface, Bool.and_eq_true] -/-- Decide equivalence by enumerating all inputs on a finite value type. -/ +/-- Decide equivalence by enumerating all input assignments on a finite value type. -/ def checkEquiv (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : Bool := - sameInterface C₁ C₂ && - finsetAll (Finset.univ : Finset (ι → α)) (fun input => - finsetAll C₁.outputs (fun i => decide (eval C₁ input i = eval C₂ input i))) + if h : SameInterface C₁ C₂ then + finsetAll (Finset.univ : Finset C₁.InputAssignment) (fun input => + finsetAll C₁.outputs (fun i => + decide (evalInput C₁ input i = evalInput C₂ (InputAssignment.cast h.1 input) i))) + else + false /-- `checkEquiv` is sound and complete for `Equiv`. 
-/ theorem checkEquiv_eq_true_iff (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : checkEquiv C₁ C₂ = true ↔ Equiv C₁ C₂ := by classical - simp [checkEquiv, Equiv, sameInterface_eq_true_iff, finsetAll_eq_true_iff, Bool.and_eq_true] + by_cases h : SameInterface C₁ C₂ + · have hcheck : checkEquiv C₁ C₂ = true ↔ EquivOn C₁ C₂ h := by + simp [checkEquiv, h, EquivOn, finsetAll_eq_true_iff] + constructor + · intro hc + exact ⟨h, hcheck.mp hc⟩ + · intro hEquiv + rcases hEquiv with ⟨h', hEq⟩ + have hh : h' = h := Subsingleton.elim _ _ + exact hcheck.mpr (by simpa [hh] using hEq) + · simp [checkEquiv, h, Equiv] end diff --git a/Nfp/Circuit/Semantics.lean b/Nfp/Circuit/Semantics.lean index e7fcd66..2105fbf 100644 --- a/Nfp/Circuit/Semantics.lean +++ b/Nfp/Circuit/Semantics.lean @@ -44,6 +44,37 @@ theorem eval_eq_gate (C : Circuit ι α) (input : ι → α) {i : ι} (h : i ∉ eval C input i = C.gate i (fun j _ => eval C input j) := by simpa [h] using (eval_eq C input i) +/-- One-step evaluation functional used by `evalInput`. -/ +def evalInputStep (C : Circuit ι α) (input : C.InputAssignment) + (i : ι) (rec : ∀ j, C.dag.rel j i → α) : α := + if h : i ∈ C.inputs then input ⟨i, h⟩ else C.gate i rec + +/-- Evaluate a circuit with an input assignment defined on input nodes. -/ +def evalInput (C : Circuit ι α) (input : C.InputAssignment) : ι → α := + C.dag.wf.fix (fun i rec => evalInputStep C input i rec) + +/-- Unfolding equation for `evalInput`. -/ +theorem evalInput_eq (C : Circuit ι α) (input : C.InputAssignment) (i : ι) : + evalInput C input i = + if h : i ∈ C.inputs then input ⟨i, h⟩ else C.gate i (fun j _ => evalInput C input j) := by + set F : ∀ i, (∀ j, C.dag.rel j i → α) → α := fun i rec => evalInputStep C input i rec + change C.dag.wf.fix F i = + if h : i ∈ C.inputs then input ⟨i, h⟩ else C.gate i (fun j _ => C.dag.wf.fix F j) + rw [WellFounded.fix_eq] + dsimp [F, evalInputStep] + +/-- Input nodes evaluate to their assigned input value (input-only form). 
-/ +theorem evalInput_eq_input (C : Circuit ι α) (input : C.InputAssignment) {i : ι} + (h : i ∈ C.inputs) : + evalInput C input i = input ⟨i, h⟩ := by + simpa [h] using (evalInput_eq C input i) + +/-- Non-input nodes evaluate via their gate semantics (input-only form). -/ +theorem evalInput_eq_gate (C : Circuit ι α) (input : C.InputAssignment) {i : ι} + (h : i ∉ C.inputs) : + evalInput C input i = C.gate i (fun j _ => evalInput C input j) := by + simpa [h] using (evalInput_eq C input i) + end Circuit end Nfp From d67e6df51097cefa7775b7fcb563ecec58822b8b Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 18:39:25 +0100 Subject: [PATCH 073/244] Add typed circuit interfaces and well-formedness --- AGENTS.md | 4 ++++ Nfp.lean | 2 ++ Nfp/Circuit.lean | 2 ++ Nfp/Circuit/Cert.lean | 41 ++++++++++++++++++++++++++++++++- Nfp/Circuit/Interface.lean | 46 +++++++++++++++++++++++++++++++++++++ Nfp/Circuit/WellFormed.lean | 35 ++++++++++++++++++++++++++++ 6 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 Nfp/Circuit/Interface.lean create mode 100644 Nfp/Circuit/WellFormed.lean diff --git a/AGENTS.md b/AGENTS.md index 5c36282..492cdb7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -273,8 +273,12 @@ but you **must** update this list in the same commit. ### 5.5 Circuits (certification core) - `Nfp/Circuit/Basic.lean` - DAG-based circuit structure with inputs/outputs and gate semantics. +- `Nfp/Circuit/Interface.lean` + - Typed input/output interfaces and interface-based evaluation. - `Nfp/Circuit/Semantics.lean` - Well-founded evaluation semantics for circuits. +- `Nfp/Circuit/WellFormed.lean` + - Basic well-formedness conditions for circuit inputs. - `Nfp/Circuit/Cert.lean` - Equivalence definition and finite checker. 
- `Nfp/Circuit.lean` diff --git a/Nfp.lean b/Nfp.lean index c915fba..5b4254f 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -28,4 +28,6 @@ These `#print axioms` lines help ensure we only depend on a small set of axioms #print axioms Nfp.LocalSystem.eval_eq #print axioms Nfp.Circuit.eval #print axioms Nfp.Circuit.evalInput +#print axioms Nfp.Circuit.Interface.eval #print axioms Nfp.Circuit.checkEquiv +#print axioms Nfp.Circuit.checkEquivOnInterface diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index d39c489..43cc369 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Circuit.Basic +import Nfp.Circuit.Interface import Nfp.Circuit.Semantics +import Nfp.Circuit.WellFormed import Nfp.Circuit.Cert /-! diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index 428d950..24f69cc 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -3,6 +3,7 @@ import Mathlib.Data.Finset.Fold import Mathlib.Data.Finset.Insert import Mathlib.Data.Fintype.Pi +import Nfp.Circuit.Interface import Nfp.Circuit.Semantics /-! @@ -11,7 +12,7 @@ Circuit equivalence and a finite checker. namespace Nfp -universe u v +universe u v u' u_in u_out namespace Circuit @@ -36,6 +37,19 @@ def EquivOn (C₁ C₂ : Circuit ι α) (h : SameInterface C₁ C₂) : Prop := def Equiv (C₁ C₂ : Circuit ι α) : Prop := ∃ h : SameInterface C₁ C₂, EquivOn C₁ C₂ h +section Interface + +variable {ι₁ : Type u} [Fintype ι₁] [DecidableEq ι₁] +variable {ι₂ : Type u'} [Fintype ι₂] [DecidableEq ι₂] +variable {ι_in : Type u_in} {ι_out : Type u_out} + +/-- Circuits agree on outputs for all typed inputs on a shared interface. 
-/ +def EquivOnInterface (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) : Prop := + ∀ input : ι_in → α, ∀ o : ι_out, I₁.eval input o = I₂.eval input o + +end Interface + section local instance : Std.Commutative (α := Bool) (· && ·) := ⟨Bool.and_comm⟩ @@ -106,6 +120,31 @@ theorem checkEquiv_eq_true_iff (C₁ C₂ : Circuit ι α) [Fintype α] [Decidab end +section InterfaceCheck + +variable {ι₁ : Type u} [Fintype ι₁] [DecidableEq ι₁] +variable {ι₂ : Type u'} [Fintype ι₂] [DecidableEq ι₂] +variable {ι_in : Type u_in} [Fintype ι_in] [DecidableEq ι_in] +variable {ι_out : Type u_out} [Fintype ι_out] + +/-- Decide interface-based equivalence by enumerating typed inputs. -/ +def checkEquivOnInterface (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) + [Fintype α] [DecidableEq α] : Bool := + finsetAll (Finset.univ : Finset (ι_in → α)) (fun input => + finsetAll (Finset.univ : Finset ι_out) (fun o => + decide (I₁.eval input o = I₂.eval input o))) + +/-- `checkEquivOnInterface` is sound and complete for `EquivOnInterface`. -/ +theorem checkEquivOnInterface_eq_true_iff (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) + [Fintype α] [DecidableEq α] : + checkEquivOnInterface C₁ C₂ I₁ I₂ = true ↔ EquivOnInterface C₁ C₂ I₁ I₂ := by + classical + simp [checkEquivOnInterface, EquivOnInterface, finsetAll_eq_true_iff] + +end InterfaceCheck + end Circuit end Nfp diff --git a/Nfp/Circuit/Interface.lean b/Nfp/Circuit/Interface.lean new file mode 100644 index 0000000..4ec091a --- /dev/null +++ b/Nfp/Circuit/Interface.lean @@ -0,0 +1,46 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Semantics + +/-! +Typed input/output interfaces for circuits. 
+-/ + +namespace Nfp + +universe u v u_in u_out + +namespace Circuit + +variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {α : Type v} + +/-- A typed input/output interface for a circuit. -/ +structure Interface (C : Circuit ι α) (ι_in : Type u_in) (ι_out : Type u_out) where + /-- Input labels correspond exactly to the circuit's input nodes. -/ + inputs : ι_in ≃ { i // i ∈ C.inputs } + /-- Output labels correspond exactly to the circuit's output nodes. -/ + outputs : ι_out ≃ { i // i ∈ C.outputs } + +namespace Interface + +variable {C : Circuit ι α} {ι_in : Type u_in} {ι_out : Type u_out} + +/-- Convert a typed input assignment into an input-node assignment. -/ +def toInputAssignment (I : Interface C ι_in ι_out) (input : ι_in → α) : C.InputAssignment := + fun i => input (I.inputs.symm i) + +/-- Evaluate a circuit on a typed interface. -/ +def eval (I : Interface C ι_in ι_out) (input : ι_in → α) : ι_out → α := + fun o => evalInput C (I.toInputAssignment input) (I.outputs o).1 + +/-- Unfolding equation for `Interface.eval`. -/ +theorem eval_eq (I : Interface C ι_in ι_out) (input : ι_in → α) (o : ι_out) : + I.eval input o = evalInput C (I.toInputAssignment input) (I.outputs o).1 := + rfl + +end Interface + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/WellFormed.lean b/Nfp/Circuit/WellFormed.lean new file mode 100644 index 0000000..070f565 --- /dev/null +++ b/Nfp/Circuit/WellFormed.lean @@ -0,0 +1,35 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Basic + +/-! +Well-formedness conditions for circuits. +-/ + +namespace Nfp + +universe u v + +namespace Circuit + +variable {ι : Type u} [Fintype ι] +variable {α : Type v} + +/-- A circuit is well-formed if every input node has no incoming edges. -/ +def WellFormed (C : Circuit ι α) : Prop := + ∀ i ∈ C.inputs, ∀ j, ¬ C.dag.rel j i + +/-- Inputs have no parents in a well-formed circuit. 
-/ +theorem wellFormed_no_parent {C : Circuit ι α} (h : WellFormed C) {i j : ι} (hi : i ∈ C.inputs) : + ¬ C.dag.rel j i := + h i hi j + +/-- Input nodes have empty parent sets in a well-formed circuit. -/ +theorem wellFormed_parents_empty {C : Circuit ι α} (h : WellFormed C) {i : ι} (hi : i ∈ C.inputs) : + C.dag.parents i = ∅ := by + ext j + simp [Dag.mem_parents, h i hi j] + +end Circuit + +end Nfp From 57fe3b39752e16c0ef538111085e565dee923aad Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 18:43:10 +0100 Subject: [PATCH 074/244] Document tabula rasa rewrite status --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 26676ba..ccf14a8 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,17 @@ This repo contains: This is research tooling. Interfaces may change; please treat results as experimental unless they are backed by a certificate/check you trust. +## Tabula Rasa Rewrite (current state) + +The `tabula-rasa` branch is a fresh, minimal Lean 4 core focused on circuit certification. The legacy system remains in `Legacy/Nfp/` and is not built by default. + +Current core modules (new): +- `Nfp/Core`, `Nfp/Prob`, `Nfp/Mixer`, `Nfp/System` define basic mass/probability, mixers, and DAG-backed local systems. +- `Nfp/Circuit` defines DAG-based circuits with input-only evaluation, typed interfaces, well-formedness, and equivalence checkers. +- `Nfp/Cli` and `Main.lean` are thin placeholders (no full transformer pipeline yet). + +Module map and invariants are tracked in `AGENTS.md`. + ## Soundness statement (what is proven vs checked) The Lean library defines the core math objects (finite probability, mixers, linearizations, and operator-norm-style bounds) and proves a number of lemmas about them. The CLI sound path produces certificates using exact `Rat` arithmetic and a trusted checker that verifies internal arithmetic relationships between certificate fields. 
From c6f97134b79f7d0dd929a5cb0e2e7e88ee2b10dd Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 19:04:45 +0100 Subject: [PATCH 075/244] Add circuit combinators and gates --- AGENTS.md | 8 +++++ Nfp/Circuit.lean | 3 ++ Nfp/Circuit/Combinators.lean | 60 ++++++++++++++++++++++++++++++++++++ Nfp/Circuit/Gates.lean | 7 +++++ Nfp/Circuit/Gates/Basic.lean | 40 ++++++++++++++++++++++++ Nfp/Circuit/Typed.lean | 60 ++++++++++++++++++++++++++++++++++++ Nfp/System/Dag.lean | 17 +++++++++- 7 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 Nfp/Circuit/Combinators.lean create mode 100644 Nfp/Circuit/Gates.lean create mode 100644 Nfp/Circuit/Gates/Basic.lean create mode 100644 Nfp/Circuit/Typed.lean diff --git a/AGENTS.md b/AGENTS.md index 492cdb7..b008034 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -273,6 +273,8 @@ but you **must** update this list in the same commit. ### 5.5 Circuits (certification core) - `Nfp/Circuit/Basic.lean` - DAG-based circuit structure with inputs/outputs and gate semantics. +- `Nfp/Circuit/Combinators.lean` + - Core circuit combinators (relabeling, interface transport). - `Nfp/Circuit/Interface.lean` - Typed input/output interfaces and interface-based evaluation. - `Nfp/Circuit/Semantics.lean` @@ -281,6 +283,12 @@ but you **must** update this list in the same commit. - Basic well-formedness conditions for circuit inputs. - `Nfp/Circuit/Cert.lean` - Equivalence definition and finite checker. +- `Nfp/Circuit/Typed.lean` + - Typed circuit wrapper and interface-level equivalence checker. +- `Nfp/Circuit/Gates/Basic.lean` + - Basic gate combinators for aggregating parent values. +- `Nfp/Circuit/Gates.lean` + - Aggregator for gate combinator modules. - `Nfp/Circuit.lean` - Aggregator for circuit modules. 
diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index 43cc369..19c77db 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -1,10 +1,13 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Circuit.Basic +import Nfp.Circuit.Combinators import Nfp.Circuit.Interface import Nfp.Circuit.Semantics import Nfp.Circuit.WellFormed import Nfp.Circuit.Cert +import Nfp.Circuit.Typed +import Nfp.Circuit.Gates /-! Circuit definitions, semantics, and equivalence checking. diff --git a/Nfp/Circuit/Combinators.lean b/Nfp/Circuit/Combinators.lean new file mode 100644 index 0000000..b12b88a --- /dev/null +++ b/Nfp/Circuit/Combinators.lean @@ -0,0 +1,60 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Finset.Image +import Mathlib.Logic.Equiv.Basic +import Nfp.Circuit.Interface + +/-! +Circuit combinators such as relabeling. +-/ + +namespace Nfp + +universe u v u' u_in u_out + +namespace Circuit + +variable {ι : Type u} [Fintype ι] +variable {ι' : Type u'} [Fintype ι'] +variable {α : Type v} + +/-- Relabel the nodes of a circuit along an equivalence. -/ +def relabel (C : Circuit ι α) (e : ι ≃ ι') : Circuit ι' α := by + refine + { dag := C.dag.relabel e + inputs := C.inputs.map e.toEmbedding + outputs := C.outputs.map e.toEmbedding + gate := ?_ } + intro i rec + refine C.gate (e.symm i) ?_ + intro j h + refine rec (e j) ?_ + change C.dag.rel (e.symm (e j)) (e.symm i) + simpa using h + +namespace Interface + +variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {ι' : Type u'} [Fintype ι'] [DecidableEq ι'] +variable {α : Type v} +variable {ι_in : Type u_in} {ι_out : Type u_out} +variable {C : Circuit ι α} + +/-- Relabel a circuit interface along an equivalence of nodes. 
-/ +def relabel (I : Interface C ι_in ι_out) (e : ι ≃ ι') : + Interface (C.relabel e) ι_in ι_out := by + refine { inputs := ?_, outputs := ?_ } + · refine I.inputs.trans ?_ + refine (e.subtypeEquiv ?_) + intro a + simp [Circuit.relabel] + · refine I.outputs.trans ?_ + refine (e.subtypeEquiv ?_) + intro a + simp [Circuit.relabel] + +end Interface + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Gates.lean b/Nfp/Circuit/Gates.lean new file mode 100644 index 0000000..adc08df --- /dev/null +++ b/Nfp/Circuit/Gates.lean @@ -0,0 +1,7 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Gates.Basic + +/-! +Gate combinators for circuit semantics. +-/ diff --git a/Nfp/Circuit/Gates/Basic.lean b/Nfp/Circuit/Gates/Basic.lean new file mode 100644 index 0000000..b4ffcae --- /dev/null +++ b/Nfp/Circuit/Gates/Basic.lean @@ -0,0 +1,40 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Ring.Basic +import Mathlib.Data.Finset.Attach +import Mathlib.Data.Fintype.BigOperators + +/-! +Basic gate combinators for aggregating parent values. +-/ + +namespace Nfp + +namespace Circuit + +namespace Gates + +universe u v + +variable {ι : Type u} {α : Type v} + +/-- Sum of parent values. -/ +def sumParents (parents : Finset ι) (rec : ∀ j, j ∈ parents → α) + [AddCommMonoid α] : α := + parents.attach.sum fun j => rec j.1 j.2 + +/-- Weighted sum of parent values using weights `w`. -/ +def weightedSumParents (parents : Finset ι) (w : ι → α) + (rec : ∀ j, j ∈ parents → α) [Semiring α] : α := + parents.attach.sum fun j => w j.1 * rec j.1 j.2 + +/-- Affine combination of parent values with weights `w` and bias `b`. 
-/ +def affineParents (parents : Finset ι) (w : ι → α) (b : α) + (rec : ∀ j, j ∈ parents → α) [Semiring α] : α := + weightedSumParents parents w rec + b + +end Gates + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Typed.lean b/Nfp/Circuit/Typed.lean new file mode 100644 index 0000000..fe76905 --- /dev/null +++ b/Nfp/Circuit/Typed.lean @@ -0,0 +1,60 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Combinators +import Nfp.Circuit.Cert + +/-! +Typed circuit wrappers and typed equivalence checking. +-/ + +namespace Nfp + +universe u v u' u_in u_out + +namespace Circuit + +/-- A circuit bundled with a typed input/output interface. -/ +structure TypedCircuit (ι : Type u) [Fintype ι] [DecidableEq ι] (α : Type v) + (ι_in : Type u_in) (ι_out : Type u_out) where + /-- The underlying circuit. -/ + circuit : Circuit ι α + /-- Typed input/output interface for `circuit`. -/ + interface : Interface circuit ι_in ι_out + +namespace TypedCircuit + +variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {α : Type v} {ι_in : Type u_in} {ι_out : Type u_out} + +/-- Evaluate a typed circuit on a typed input. -/ +def eval (T : TypedCircuit ι α ι_in ι_out) (input : ι_in → α) : ι_out → α := + T.interface.eval input + +/-- Decide equivalence by enumerating typed inputs. -/ +def checkEquiv (T₁ T₂ : TypedCircuit ι α ι_in ι_out) + [Fintype ι_in] [DecidableEq ι_in] [Fintype ι_out] + [Fintype α] [DecidableEq α] : Bool := + Circuit.checkEquivOnInterface T₁.circuit T₂.circuit T₁.interface T₂.interface + +/-- `checkEquiv` is sound and complete for `EquivOnInterface`. 
-/ +theorem checkEquiv_eq_true_iff (T₁ T₂ : TypedCircuit ι α ι_in ι_out) + [Fintype ι_in] [DecidableEq ι_in] [Fintype ι_out] + [Fintype α] [DecidableEq α] : + checkEquiv T₁ T₂ = true ↔ + EquivOnInterface T₁.circuit T₂.circuit T₁.interface T₂.interface := by + simpa [checkEquiv] using + (checkEquivOnInterface_eq_true_iff T₁.circuit T₂.circuit T₁.interface T₂.interface) + +variable {ι' : Type u'} [Fintype ι'] [DecidableEq ι'] + +/-- Relabel the nodes of a typed circuit. -/ +def relabel (T : TypedCircuit ι α ι_in ι_out) (e : ι ≃ ι') : + TypedCircuit ι' α ι_in ι_out := + { circuit := T.circuit.relabel e + interface := T.interface.relabel e } + +end TypedCircuit + +end Circuit + +end Nfp diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean index 25632a3..1049fe7 100644 --- a/Nfp/System/Dag.lean +++ b/Nfp/System/Dag.lean @@ -9,7 +9,7 @@ Directed acyclic graph foundations. namespace Nfp -universe u +universe u u' /-- A finite directed acyclic graph, built on top of `Digraph`. -/ structure Dag (ι : Type u) [Fintype ι] where @@ -50,6 +50,21 @@ def children (G : Dag ι) (i : ι) : Finset ι := by j ∈ G.children i ↔ G.rel i j := by simp [Dag.children] +section Relabel + +variable {ι' : Type u'} [Fintype ι'] + +/-- Relabel a DAG along an equivalence of vertex types. 
-/ +def relabel (G : Dag ι) (e : ι ≃ ι') : Dag ι' := + { graph := { Adj := fun a b => G.rel (e.symm a) (e.symm b) } + decAdj := by + intro a b + exact G.decAdj (e.symm a) (e.symm b) + wf := by + simpa using (InvImage.wf (f := e.symm) (h := G.wf)) } + +end Relabel + end Dag end Nfp From 95e77231a57e3ddb739689aab08158f7f964fdcd Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 19:09:21 +0100 Subject: [PATCH 076/244] Use ASCII identifiers in circuit modules --- Nfp/Circuit/Combinators.lean | 22 ++++++++++---------- Nfp/Circuit/Gates/Basic.lean | 14 ++++++------- Nfp/Circuit/Typed.lean | 40 ++++++++++++++++++------------------ 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/Nfp/Circuit/Combinators.lean b/Nfp/Circuit/Combinators.lean index b12b88a..02bd981 100644 --- a/Nfp/Circuit/Combinators.lean +++ b/Nfp/Circuit/Combinators.lean @@ -14,12 +14,12 @@ universe u v u' u_in u_out namespace Circuit -variable {ι : Type u} [Fintype ι] -variable {ι' : Type u'} [Fintype ι'] -variable {α : Type v} +variable {Node : Type u} [Fintype Node] +variable {Node' : Type u'} [Fintype Node'] +variable {Val : Type v} /-- Relabel the nodes of a circuit along an equivalence. 
-/ -def relabel (C : Circuit ι α) (e : ι ≃ ι') : Circuit ι' α := by +def relabel (C : Circuit Node Val) (e : _root_.Equiv Node Node') : Circuit Node' Val := by refine { dag := C.dag.relabel e inputs := C.inputs.map e.toEmbedding @@ -34,15 +34,15 @@ def relabel (C : Circuit ι α) (e : ι ≃ ι') : Circuit ι' α := by namespace Interface -variable {ι : Type u} [Fintype ι] [DecidableEq ι] -variable {ι' : Type u'} [Fintype ι'] [DecidableEq ι'] -variable {α : Type v} -variable {ι_in : Type u_in} {ι_out : Type u_out} -variable {C : Circuit ι α} +variable {Node : Type u} [Fintype Node] [DecidableEq Node] +variable {Node' : Type u'} [Fintype Node'] [DecidableEq Node'] +variable {Val : Type v} +variable {Input : Type u_in} {Output : Type u_out} +variable {C : Circuit Node Val} /-- Relabel a circuit interface along an equivalence of nodes. -/ -def relabel (I : Interface C ι_in ι_out) (e : ι ≃ ι') : - Interface (C.relabel e) ι_in ι_out := by +def relabel (I : Interface C Input Output) (e : _root_.Equiv Node Node') : + Interface (C.relabel e) Input Output := by refine { inputs := ?_, outputs := ?_ } · refine I.inputs.trans ?_ refine (e.subtypeEquiv ?_) diff --git a/Nfp/Circuit/Gates/Basic.lean b/Nfp/Circuit/Gates/Basic.lean index b4ffcae..0cb28d4 100644 --- a/Nfp/Circuit/Gates/Basic.lean +++ b/Nfp/Circuit/Gates/Basic.lean @@ -16,21 +16,21 @@ namespace Gates universe u v -variable {ι : Type u} {α : Type v} +variable {Node : Type u} {Val : Type v} /-- Sum of parent values. -/ -def sumParents (parents : Finset ι) (rec : ∀ j, j ∈ parents → α) - [AddCommMonoid α] : α := +def sumParents (parents : Finset Node) (rec : ∀ j, j ∈ parents → Val) + [AddCommMonoid Val] : Val := parents.attach.sum fun j => rec j.1 j.2 /-- Weighted sum of parent values using weights `w`. 
-/ -def weightedSumParents (parents : Finset ι) (w : ι → α) - (rec : ∀ j, j ∈ parents → α) [Semiring α] : α := +def weightedSumParents (parents : Finset Node) (w : Node → Val) + (rec : ∀ j, j ∈ parents → Val) [Semiring Val] : Val := parents.attach.sum fun j => w j.1 * rec j.1 j.2 /-- Affine combination of parent values with weights `w` and bias `b`. -/ -def affineParents (parents : Finset ι) (w : ι → α) (b : α) - (rec : ∀ j, j ∈ parents → α) [Semiring α] : α := +def affineParents (parents : Finset Node) (w : Node → Val) (b : Val) + (rec : ∀ j, j ∈ parents → Val) [Semiring Val] : Val := weightedSumParents parents w rec + b end Gates diff --git a/Nfp/Circuit/Typed.lean b/Nfp/Circuit/Typed.lean index fe76905..590ee4c 100644 --- a/Nfp/Circuit/Typed.lean +++ b/Nfp/Circuit/Typed.lean @@ -14,42 +14,42 @@ universe u v u' u_in u_out namespace Circuit /-- A circuit bundled with a typed input/output interface. -/ -structure TypedCircuit (ι : Type u) [Fintype ι] [DecidableEq ι] (α : Type v) - (ι_in : Type u_in) (ι_out : Type u_out) where +structure TypedCircuit (Node : Type u) [Fintype Node] [DecidableEq Node] (Val : Type v) + (Input : Type u_in) (Output : Type u_out) where /-- The underlying circuit. -/ - circuit : Circuit ι α + circuit : Circuit Node Val /-- Typed input/output interface for `circuit`. -/ - interface : Interface circuit ι_in ι_out + interface : Interface circuit Input Output namespace TypedCircuit -variable {ι : Type u} [Fintype ι] [DecidableEq ι] -variable {α : Type v} {ι_in : Type u_in} {ι_out : Type u_out} +variable {Node : Type u} [Fintype Node] [DecidableEq Node] +variable {Val : Type v} {Input : Type u_in} {Output : Type u_out} /-- Evaluate a typed circuit on a typed input. -/ -def eval (T : TypedCircuit ι α ι_in ι_out) (input : ι_in → α) : ι_out → α := +def eval (T : TypedCircuit Node Val Input Output) (input : Input → Val) : Output → Val := T.interface.eval input /-- Decide equivalence by enumerating typed inputs. 
-/ -def checkEquiv (T₁ T₂ : TypedCircuit ι α ι_in ι_out) - [Fintype ι_in] [DecidableEq ι_in] [Fintype ι_out] - [Fintype α] [DecidableEq α] : Bool := - Circuit.checkEquivOnInterface T₁.circuit T₂.circuit T₁.interface T₂.interface +def checkEquiv (T1 T2 : TypedCircuit Node Val Input Output) + [Fintype Input] [DecidableEq Input] [Fintype Output] + [Fintype Val] [DecidableEq Val] : Bool := + Circuit.checkEquivOnInterface T1.circuit T2.circuit T1.interface T2.interface /-- `checkEquiv` is sound and complete for `EquivOnInterface`. -/ -theorem checkEquiv_eq_true_iff (T₁ T₂ : TypedCircuit ι α ι_in ι_out) - [Fintype ι_in] [DecidableEq ι_in] [Fintype ι_out] - [Fintype α] [DecidableEq α] : - checkEquiv T₁ T₂ = true ↔ - EquivOnInterface T₁.circuit T₂.circuit T₁.interface T₂.interface := by +theorem checkEquiv_eq_true_iff (T1 T2 : TypedCircuit Node Val Input Output) + [Fintype Input] [DecidableEq Input] [Fintype Output] + [Fintype Val] [DecidableEq Val] : + checkEquiv T1 T2 = true ↔ + EquivOnInterface T1.circuit T2.circuit T1.interface T2.interface := by simpa [checkEquiv] using - (checkEquivOnInterface_eq_true_iff T₁.circuit T₂.circuit T₁.interface T₂.interface) + (checkEquivOnInterface_eq_true_iff T1.circuit T2.circuit T1.interface T2.interface) -variable {ι' : Type u'} [Fintype ι'] [DecidableEq ι'] +variable {Node' : Type u'} [Fintype Node'] [DecidableEq Node'] /-- Relabel the nodes of a typed circuit. 
-/ -def relabel (T : TypedCircuit ι α ι_in ι_out) (e : ι ≃ ι') : - TypedCircuit ι' α ι_in ι_out := +def relabel (T : TypedCircuit Node Val Input Output) (e : _root_.Equiv Node Node') : + TypedCircuit Node' Val Input Output := { circuit := T.circuit.relabel e interface := T.interface.relabel e } From 6658bb354e79729ab9086e3ee034a4ecbf77470d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 19:17:41 +0100 Subject: [PATCH 077/244] Add tensor aliases and linear gates --- AGENTS.md | 4 ++++ Nfp/Circuit.lean | 1 + Nfp/Circuit/Gates.lean | 1 + Nfp/Circuit/Gates/Linear.lean | 33 +++++++++++++++++++++++++ Nfp/Circuit/Tensor.lean | 45 +++++++++++++++++++++++++++++++++++ 5 files changed, 84 insertions(+) create mode 100644 Nfp/Circuit/Gates/Linear.lean create mode 100644 Nfp/Circuit/Tensor.lean diff --git a/AGENTS.md b/AGENTS.md index b008034..9b0a115 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -287,8 +287,12 @@ but you **must** update this list in the same commit. - Typed circuit wrapper and interface-level equivalence checker. - `Nfp/Circuit/Gates/Basic.lean` - Basic gate combinators for aggregating parent values. +- `Nfp/Circuit/Gates/Linear.lean` + - Linear and affine gate combinators built from `Matrix.mulVec`. - `Nfp/Circuit/Gates.lean` - Aggregator for gate combinator modules. +- `Nfp/Circuit/Tensor.lean` + - Typed tensor indices and tensor aliases. - `Nfp/Circuit.lean` - Aggregator for circuit modules. diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index 19c77db..b0eb449 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -8,6 +8,7 @@ import Nfp.Circuit.WellFormed import Nfp.Circuit.Cert import Nfp.Circuit.Typed import Nfp.Circuit.Gates +import Nfp.Circuit.Tensor /-! Circuit definitions, semantics, and equivalence checking. 
diff --git a/Nfp/Circuit/Gates.lean b/Nfp/Circuit/Gates.lean index adc08df..2e96c14 100644 --- a/Nfp/Circuit/Gates.lean +++ b/Nfp/Circuit/Gates.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Circuit.Gates.Basic +import Nfp.Circuit.Gates.Linear /-! Gate combinators for circuit semantics. diff --git a/Nfp/Circuit/Gates/Linear.lean b/Nfp/Circuit/Gates/Linear.lean new file mode 100644 index 0000000..4f42f7c --- /dev/null +++ b/Nfp/Circuit/Gates/Linear.lean @@ -0,0 +1,33 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Matrix.Mul + +/-! +Linear and affine gate combinators built from `Matrix.mulVec`. +-/ + +namespace Nfp + +namespace Circuit + +namespace Gates + +universe u v + +variable {Row : Type u} {Col : Type u} {Val : Type v} + +/-- Linear map on vectors defined by a matrix. -/ +def linear [Fintype Row] [Fintype Col] [NonUnitalNonAssocSemiring Val] + (W : Matrix Row Col Val) (x : Col → Val) : Row → Val := + Matrix.mulVec W x + +/-- Affine map on vectors defined by a matrix and bias. -/ +def affine [Fintype Row] [Fintype Col] [NonUnitalNonAssocSemiring Val] + (W : Matrix Row Col Val) (b : Row → Val) (x : Col → Val) : Row → Val := + linear W x + b + +end Gates + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Tensor.lean b/Nfp/Circuit/Tensor.lean new file mode 100644 index 0000000..165cb26 --- /dev/null +++ b/Nfp/Circuit/Tensor.lean @@ -0,0 +1,45 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Matrix.Basic + +/-! +Typed tensor indices and tensor aliases. +-/ + +namespace Nfp + +namespace Circuit + +namespace Tensor + +universe u v + +/-- Index type for a length-`n` vector. -/ +abbrev VecIndex (n : Nat) : Type := Fin n + +/-- Index type for an `m × n` matrix. -/ +abbrev MatIndex (m n : Nat) : Type := Fin m × Fin n + +/-- Index type for a 3D tensor. -/ +abbrev Tensor3Index (a b c : Nat) : Type := Fin a × Fin b × Fin c + +/-- Index type for a 4D tensor. 
-/ +abbrev Tensor4Index (a b c d : Nat) : Type := Fin a × Fin b × Fin c × Fin d + +/-- A length-`n` vector of values. -/ +abbrev Vec (n : Nat) (Val : Type v) : Type v := VecIndex n → Val + +/-- An `m × n` matrix of values. -/ +abbrev Mat (m n : Nat) (Val : Type v) : Type v := Matrix (VecIndex m) (VecIndex n) Val + +/-- A 3D tensor of values. -/ +abbrev Tensor3 (a b c : Nat) (Val : Type v) : Type v := Tensor3Index a b c → Val + +/-- A 4D tensor of values. -/ +abbrev Tensor4 (a b c d : Nat) (Val : Type v) : Type v := Tensor4Index a b c d → Val + +end Tensor + +end Circuit + +end Nfp From 91e7aa525f60e45cd8f7699be3bcef420424e48d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 19:34:53 +0100 Subject: [PATCH 078/244] Add linear layer circuit combinators --- AGENTS.md | 4 + Nfp/Circuit.lean | 1 + Nfp/Circuit/Layers.lean | 7 + Nfp/Circuit/Layers/Linear.lean | 236 +++++++++++++++++++++++++++++++++ 4 files changed, 248 insertions(+) create mode 100644 Nfp/Circuit/Layers.lean create mode 100644 Nfp/Circuit/Layers/Linear.lean diff --git a/AGENTS.md b/AGENTS.md index 9b0a115..27f328c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -293,6 +293,10 @@ but you **must** update this list in the same commit. - Aggregator for gate combinator modules. - `Nfp/Circuit/Tensor.lean` - Typed tensor indices and tensor aliases. +- `Nfp/Circuit/Layers/Linear.lean` + - Linear/affine layer circuits with typed interfaces. +- `Nfp/Circuit/Layers.lean` + - Aggregator for circuit layer modules. - `Nfp/Circuit.lean` - Aggregator for circuit modules. diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index b0eb449..9c71f7b 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -9,6 +9,7 @@ import Nfp.Circuit.Cert import Nfp.Circuit.Typed import Nfp.Circuit.Gates import Nfp.Circuit.Tensor +import Nfp.Circuit.Layers /-! Circuit definitions, semantics, and equivalence checking. 
diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean new file mode 100644 index 0000000..9017695 --- /dev/null +++ b/Nfp/Circuit/Layers.lean @@ -0,0 +1,7 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Layers.Linear + +/-! +Circuit layer combinators. +-/ diff --git a/Nfp/Circuit/Layers/Linear.lean b/Nfp/Circuit/Layers/Linear.lean new file mode 100644 index 0000000..62c28e5 --- /dev/null +++ b/Nfp/Circuit/Layers/Linear.lean @@ -0,0 +1,236 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Finset.Image +import Mathlib.Logic.Embedding.Basic +import Nfp.Circuit.Basic +import Nfp.Circuit.Gates.Linear +import Nfp.Circuit.Typed + +/-! +Linear and affine layer circuits. +-/ + +namespace Nfp + +namespace Circuit + +namespace Layers + +open Function + +universe u v + +variable {Row Col : Type u} + +/-- Node type for a linear/affine layer from `Col` inputs to `Row` outputs. -/ +abbrev LinearNode (Row Col : Type u) : Type u := Sum Col Row + +/-- Rank function used to orient layer edges from inputs to outputs. -/ +def linearRank : LinearNode Row Col → Nat + | Sum.inl _ => 0 + | Sum.inr _ => 1 + +section Dag + +variable [Fintype Row] [Fintype Col] + +/-- DAG for a single linear/affine layer. -/ +def linearDag : Dag (LinearNode Row Col) := + { graph := { Adj := fun j i => linearRank (Row := Row) (Col := Col) j < + linearRank (Row := Row) (Col := Col) i } + decAdj := by + intro j i + infer_instance + wf := by + simpa using (InvImage.wf (f := linearRank (Row := Row) (Col := Col)) + (h := Nat.lt_wfRel.wf)) } + +/-- Every input node has an edge to every output node. -/ +theorem linearDag_rel_inl_inr (c : Col) (r : Row) : + (linearDag (Row := Row) (Col := Col)).rel (Sum.inl c) (Sum.inr r) := by + dsimp [linearDag, linearRank] + exact Nat.zero_lt_one + +end Dag + +section Inputs + +variable [Fintype Col] + +/-- Input nodes for a linear/affine layer circuit. 
-/ +def linearInputs : Finset (LinearNode Row Col) := + (Finset.univ : Finset Col).map Embedding.inl + +/-- Membership in the input nodes corresponds to being a left injection. -/ +theorem mem_linearInputs_iff {s : LinearNode Row Col} : + s ∈ linearInputs (Row := Row) (Col := Col) ↔ ∃ c, s = Sum.inl c := by + constructor + · intro hs + rcases (Finset.mem_map.1 hs) with ⟨c, _hc, hcs⟩ + exact ⟨c, hcs.symm⟩ + · rintro ⟨c, rfl⟩ + refine Finset.mem_map.2 ?_ + exact ⟨c, by simp, rfl⟩ + +/-- Right injections are not input nodes. -/ +theorem not_mem_linearInputs_inr (r : Row) : + Sum.inr r ∉ linearInputs (Row := Row) (Col := Col) := by + intro h + rcases (mem_linearInputs_iff (Row := Row) (Col := Col)).1 h with ⟨c, hcs⟩ + cases hcs + +/-- Input labels correspond to input nodes in a linear/affine layer. -/ +def linearInputEquiv : Col ≃ { i // i ∈ linearInputs (Row := Row) (Col := Col) } := + { toFun := fun c => + ⟨Sum.inl c, (mem_linearInputs_iff (Row := Row) (Col := Col)).2 ⟨c, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inl c, _⟩ => c + | ⟨Sum.inr r, h⟩ => False.elim + (not_mem_linearInputs_inr (Row := Row) (Col := Col) r h) + left_inv := by + intro c + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl c => rfl + | inr r => + cases (not_mem_linearInputs_inr (Row := Row) (Col := Col) r hs) } + +end Inputs + +section Outputs + +variable [Fintype Row] + +/-- Output nodes for a linear/affine layer circuit. -/ +def linearOutputs : Finset (LinearNode Row Col) := + (Finset.univ : Finset Row).map Embedding.inr + +/-- Membership in the output nodes corresponds to being a right injection. -/ +theorem mem_linearOutputs_iff {s : LinearNode Row Col} : + s ∈ linearOutputs (Row := Row) (Col := Col) ↔ ∃ r, s = Sum.inr r := by + constructor + · intro hs + rcases (Finset.mem_map.1 hs) with ⟨r, _hr, hrs⟩ + exact ⟨r, hrs.symm⟩ + · rintro ⟨r, rfl⟩ + refine Finset.mem_map.2 ?_ + exact ⟨r, by simp, rfl⟩ + +/-- Left injections are not output nodes. 
-/ +theorem not_mem_linearOutputs_inl (c : Col) : + Sum.inl c ∉ linearOutputs (Row := Row) (Col := Col) := by + intro h + rcases (mem_linearOutputs_iff (Row := Row) (Col := Col)).1 h with ⟨r, hrs⟩ + cases hrs + +/-- Output labels correspond to output nodes in a linear/affine layer. -/ +def linearOutputEquiv : Row ≃ { i // i ∈ linearOutputs (Row := Row) (Col := Col) } := + { toFun := fun r => + ⟨Sum.inr r, (mem_linearOutputs_iff (Row := Row) (Col := Col)).2 ⟨r, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr r, _⟩ => r + | ⟨Sum.inl c, h⟩ => False.elim + (not_mem_linearOutputs_inl (Row := Row) (Col := Col) c h) + left_inv := by + intro r + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inr r => rfl + | inl c => + cases (not_mem_linearOutputs_inl (Row := Row) (Col := Col) c hs) } + +end Outputs + +section Circuits + +variable [Fintype Row] [Fintype Col] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Gate semantics for a linear layer circuit. -/ +def linearGate (W : Matrix Row Col Val) : + ∀ i, (∀ j, (linearDag (Row := Row) (Col := Col)).rel j i → Val) → Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr r => + let x : Col → Val := fun c => + rec (Sum.inl c) (linearDag_rel_inl_inr (Row := Row) (Col := Col) c r) + exact Gates.linear W x r + +/-- Gate semantics for an affine layer circuit. -/ +def affineGate (W : Matrix Row Col Val) (b : Row → Val) : + ∀ i, (∀ j, (linearDag (Row := Row) (Col := Col)).rel j i → Val) → Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr r => + let x : Col → Val := fun c => + rec (Sum.inl c) (linearDag_rel_inl_inr (Row := Row) (Col := Col) c r) + exact Gates.affine W b x r + +/-- Circuit for a linear layer. 
-/ +def linearCircuit (W : Matrix Row Col Val) : Circuit (LinearNode Row Col) Val := + { dag := linearDag (Row := Row) (Col := Col) + inputs := linearInputs (Row := Row) (Col := Col) + outputs := linearOutputs (Row := Row) (Col := Col) + gate := linearGate (Row := Row) (Col := Col) W } + +/-- Circuit for an affine layer. -/ +def affineCircuit (W : Matrix Row Col Val) (b : Row → Val) : + Circuit (LinearNode Row Col) Val := + { dag := linearDag (Row := Row) (Col := Col) + inputs := linearInputs (Row := Row) (Col := Col) + outputs := linearOutputs (Row := Row) (Col := Col) + gate := affineGate (Row := Row) (Col := Col) W b } + +/-- Typed interface for a linear layer circuit. -/ +def linearInterface (W : Matrix Row Col Val) : + Interface (linearCircuit (Row := Row) (Col := Col) W) Col Row := + { inputs := linearInputEquiv (Row := Row) (Col := Col) + outputs := linearOutputEquiv (Row := Row) (Col := Col) } + +/-- Typed interface for an affine layer circuit. -/ +def affineInterface (W : Matrix Row Col Val) (b : Row → Val) : + Interface (affineCircuit (Row := Row) (Col := Col) W b) Col Row := + { inputs := linearInputEquiv (Row := Row) (Col := Col) + outputs := linearOutputEquiv (Row := Row) (Col := Col) } + +end Circuits + +section Typed + +variable [Fintype Row] [Fintype Col] +variable [DecidableEq Row] [DecidableEq Col] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Typed linear layer circuit. -/ +def linearTyped (W : Matrix Row Col Val) : + TypedCircuit (LinearNode Row Col) Val Col Row := + { circuit := linearCircuit (Row := Row) (Col := Col) W + interface := linearInterface (Row := Row) (Col := Col) W } + +/-- Typed affine layer circuit. 
-/ +def affineTyped (W : Matrix Row Col Val) (b : Row → Val) : + TypedCircuit (LinearNode Row Col) Val Col Row := + { circuit := affineCircuit (Row := Row) (Col := Col) W b + interface := affineInterface (Row := Row) (Col := Col) W b } + +end Typed + +end Layers + +end Circuit + +end Nfp From a00e4bfe385dfc3f01034f31d5d75645fc82c36b Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 19:43:17 +0100 Subject: [PATCH 079/244] Add batched linear layer combinators --- AGENTS.md | 2 + Nfp/Circuit/Layers.lean | 1 + Nfp/Circuit/Layers/Tensor.lean | 174 +++++++++++++++++++++++++++++++++ 3 files changed, 177 insertions(+) create mode 100644 Nfp/Circuit/Layers/Tensor.lean diff --git a/AGENTS.md b/AGENTS.md index 27f328c..ec85914 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -295,6 +295,8 @@ but you **must** update this list in the same commit. - Typed tensor indices and tensor aliases. - `Nfp/Circuit/Layers/Linear.lean` - Linear/affine layer circuits with typed interfaces. +- `Nfp/Circuit/Layers/Tensor.lean` + - Batched linear/affine layer circuits for tensor-shaped data. - `Nfp/Circuit/Layers.lean` - Aggregator for circuit layer modules. - `Nfp/Circuit.lean` diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean index 9017695..8867d41 100644 --- a/Nfp/Circuit/Layers.lean +++ b/Nfp/Circuit/Layers.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Circuit.Layers.Linear +import Nfp.Circuit.Layers.Tensor /-! Circuit layer combinators. diff --git a/Nfp/Circuit/Layers/Tensor.lean b/Nfp/Circuit/Layers/Tensor.lean new file mode 100644 index 0000000..7d5a21d --- /dev/null +++ b/Nfp/Circuit/Layers/Tensor.lean @@ -0,0 +1,174 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Layers.Linear + +/-! +Tensor-shaped layer builders (batched linear and affine layers). 
+-/ + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe u v + +variable {Batch Row Col : Type u} + +/-- Node type for a batched linear/affine layer. -/ +abbrev BatchedLinearNode (Batch Row Col : Type u) : Type u := + LinearNode (Batch × Row) (Batch × Col) + +/-- Adjacency for batched linear layers: inputs connect only to outputs in the same batch. -/ +def batchedLinearAdj (Batch Row Col : Type u) : + BatchedLinearNode Batch Row Col → BatchedLinearNode Batch Row Col → Prop + | Sum.inl (b, _), Sum.inr (b', _) => b = b' + | _, _ => False + +section Dag + +variable [Fintype Batch] [Fintype Row] [Fintype Col] +variable [DecidableEq Batch] + +/-- DAG for a batched linear/affine layer. -/ +def batchedLinearDag : Dag (BatchedLinearNode Batch Row Col) := + { graph := { Adj := batchedLinearAdj Batch Row Col } + decAdj := by + intro j i + cases j with + | inl bc => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr br => + exact (inferInstance : Decidable (bc.1 = br.1)) + | inr _ => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr _ => + exact isFalse (by intro h; cases h) + wf := by + have hsub : Subrelation (batchedLinearAdj Batch Row Col) + (fun j i => + linearRank (Row := Batch × Row) (Col := Batch × Col) j < + linearRank (Row := Batch × Row) (Col := Batch × Col) i) := by + intro j i h + cases j <;> cases i <;> simp [batchedLinearAdj, linearRank] at h ⊢ + have hwf : WellFounded (fun j i => + linearRank (Row := Batch × Row) (Col := Batch × Col) j < + linearRank (Row := Batch × Row) (Col := Batch × Col) i) := by + simpa using (InvImage.wf + (f := linearRank (Row := Batch × Row) (Col := Batch × Col)) + (h := Nat.lt_wfRel.wf)) + exact Subrelation.wf hsub hwf } + +/-- Edges connect each batch's inputs to its outputs. 
-/ +theorem batchedLinearDag_rel_inl_inr (b : Batch) (c : Col) (r : Row) : + (batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col)).rel + (Sum.inl (b, c)) (Sum.inr (b, r)) := by + change batchedLinearAdj Batch Row Col (Sum.inl (b, c)) (Sum.inr (b, r)) + simp [batchedLinearAdj] + +end Dag + +section Circuits + +variable [Fintype Batch] [Fintype Row] [Fintype Col] +variable [DecidableEq Batch] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Gate semantics for a batched linear layer circuit. -/ +def batchedLinearGate (W : Matrix Row Col Val) : + ∀ i, + (∀ j, + (batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col)).rel j i → Val) → + Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr br => + cases br with + | mk b r => + let x : Col → Val := fun c => + rec (Sum.inl (b, c)) + (batchedLinearDag_rel_inl_inr (Batch := Batch) (Row := Row) (Col := Col) b c r) + exact Gates.linear W x r + +/-- Gate semantics for a batched affine layer circuit. -/ +def batchedAffineGate (W : Matrix Row Col Val) (bias : Row → Val) : + ∀ i, + (∀ j, + (batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col)).rel j i → Val) → + Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr br => + cases br with + | mk b r => + let x : Col → Val := fun c => + rec (Sum.inl (b, c)) + (batchedLinearDag_rel_inl_inr (Batch := Batch) (Row := Row) (Col := Col) b c r) + exact Gates.affine W bias x r + +/-- Circuit for a batched linear layer. -/ +def batchedLinearCircuit (W : Matrix Row Col Val) : + Circuit (BatchedLinearNode Batch Row Col) Val := + { dag := batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col) + inputs := linearInputs (Row := Batch × Row) (Col := Batch × Col) + outputs := linearOutputs (Row := Batch × Row) (Col := Batch × Col) + gate := batchedLinearGate (Batch := Batch) (Row := Row) (Col := Col) W } + +/-- Circuit for a batched affine layer. 
-/ +def batchedAffineCircuit (W : Matrix Row Col Val) (bias : Row → Val) : + Circuit (BatchedLinearNode Batch Row Col) Val := + { dag := batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col) + inputs := linearInputs (Row := Batch × Row) (Col := Batch × Col) + outputs := linearOutputs (Row := Batch × Row) (Col := Batch × Col) + gate := batchedAffineGate (Batch := Batch) (Row := Row) (Col := Col) W bias } + +/-- Typed interface for a batched linear layer circuit. -/ +def batchedLinearInterface (W : Matrix Row Col Val) : + Interface (batchedLinearCircuit (Batch := Batch) (Row := Row) (Col := Col) W) + (Batch × Col) (Batch × Row) := + { inputs := linearInputEquiv (Row := Batch × Row) (Col := Batch × Col) + outputs := linearOutputEquiv (Row := Batch × Row) (Col := Batch × Col) } + +/-- Typed interface for a batched affine layer circuit. -/ +def batchedAffineInterface (W : Matrix Row Col Val) (bias : Row → Val) : + Interface (batchedAffineCircuit (Batch := Batch) (Row := Row) (Col := Col) W bias) + (Batch × Col) (Batch × Row) := + { inputs := linearInputEquiv (Row := Batch × Row) (Col := Batch × Col) + outputs := linearOutputEquiv (Row := Batch × Row) (Col := Batch × Col) } + +end Circuits + +section Typed + +variable [Fintype Batch] [Fintype Row] [Fintype Col] +variable [DecidableEq Batch] [DecidableEq Row] [DecidableEq Col] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Typed batched linear layer circuit. -/ +def batchedLinearTyped (W : Matrix Row Col Val) : + TypedCircuit (BatchedLinearNode Batch Row Col) Val (Batch × Col) (Batch × Row) := + { circuit := batchedLinearCircuit (Batch := Batch) (Row := Row) (Col := Col) W + interface := batchedLinearInterface (Batch := Batch) (Row := Row) (Col := Col) W } + +/-- Typed batched affine layer circuit. 
-/ +def batchedAffineTyped (W : Matrix Row Col Val) (bias : Row → Val) : + TypedCircuit (BatchedLinearNode Batch Row Col) Val (Batch × Col) (Batch × Row) := + { circuit := batchedAffineCircuit (Batch := Batch) (Row := Row) (Col := Col) W bias + interface := batchedAffineInterface (Batch := Batch) (Row := Row) (Col := Col) W bias } + +end Typed + +end Layers + +end Circuit + +end Nfp From 9fd30b1ab1c9252035a87aa539bf9c38449f46f6 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 19:46:54 +0100 Subject: [PATCH 080/244] Add reshape combinators for typed interfaces --- AGENTS.md | 2 ++ Nfp/Circuit/Layers.lean | 1 + Nfp/Circuit/Layers/Reshape.lean | 57 +++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 Nfp/Circuit/Layers/Reshape.lean diff --git a/AGENTS.md b/AGENTS.md index ec85914..09d6156 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -297,6 +297,8 @@ but you **must** update this list in the same commit. - Linear/affine layer circuits with typed interfaces. - `Nfp/Circuit/Layers/Tensor.lean` - Batched linear/affine layer circuits for tensor-shaped data. +- `Nfp/Circuit/Layers/Reshape.lean` + - Reshape combinators for product-typed circuit interfaces. - `Nfp/Circuit/Layers.lean` - Aggregator for circuit layer modules. - `Nfp/Circuit.lean` diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean index 8867d41..9feceac 100644 --- a/Nfp/Circuit/Layers.lean +++ b/Nfp/Circuit/Layers.lean @@ -2,6 +2,7 @@ import Nfp.Circuit.Layers.Linear import Nfp.Circuit.Layers.Tensor +import Nfp.Circuit.Layers.Reshape /-! Circuit layer combinators. diff --git a/Nfp/Circuit/Layers/Reshape.lean b/Nfp/Circuit/Layers/Reshape.lean new file mode 100644 index 0000000..0737007 --- /dev/null +++ b/Nfp/Circuit/Layers/Reshape.lean @@ -0,0 +1,57 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Logic.Equiv.Prod +import Nfp.Circuit.Typed + +/-! +Reshape combinators for product-typed circuit interfaces. 
+-/ + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe u v u_in u_out u_in' u_out' + +variable {Node : Type u} [Fintype Node] [DecidableEq Node] +variable {Val : Type v} +variable {α β γ : Type u_in} +variable {δ ε ζ : Type u_out} +variable {Input : Type u_in} {Output : Type u_out} +variable {Input' : Type u_in'} {Output' : Type u_out'} + +/-- Reassociate the input/output product structure of a typed circuit. -/ +def reassoc3 + (T : TypedCircuit Node Val ((α × β) × γ) ((δ × ε) × ζ)) : + TypedCircuit Node Val (α × β × γ) (δ × ε × ζ) := + { circuit := T.circuit + interface := + { inputs := (_root_.Equiv.prodAssoc α β γ).symm.trans T.interface.inputs + outputs := (_root_.Equiv.prodAssoc δ ε ζ).symm.trans T.interface.outputs } } + +/-- Swap the two factors of the input/output product structure. -/ +def swap12 + (T : TypedCircuit Node Val (α × β) (δ × ε)) : + TypedCircuit Node Val (β × α) (ε × δ) := + { circuit := T.circuit + interface := + { inputs := (_root_.Equiv.prodComm α β).symm.trans T.interface.inputs + outputs := (_root_.Equiv.prodComm δ ε).symm.trans T.interface.outputs } } + +/-- Apply equivalences to the input/output labels of a typed circuit. 
-/ +def mapInterface + (T : TypedCircuit Node Val Input Output) + (eIn : _root_.Equiv Input Input') (eOut : _root_.Equiv Output Output') : + TypedCircuit Node Val Input' Output' := + { circuit := T.circuit + interface := + { inputs := eIn.symm.trans T.interface.inputs + outputs := eOut.symm.trans T.interface.outputs } } + +end Layers + +end Circuit + +end Nfp From f2b76def11b29e8b66b4593f7174229a00c9a5c7 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 19:50:10 +0100 Subject: [PATCH 081/244] Add head split and merge combinators --- AGENTS.md | 2 + Nfp/Circuit/Layers.lean | 1 + Nfp/Circuit/Layers/Heads.lean | 83 +++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+) create mode 100644 Nfp/Circuit/Layers/Heads.lean diff --git a/AGENTS.md b/AGENTS.md index 09d6156..795f8d8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -299,6 +299,8 @@ but you **must** update this list in the same commit. - Batched linear/affine layer circuits for tensor-shaped data. - `Nfp/Circuit/Layers/Reshape.lean` - Reshape combinators for product-typed circuit interfaces. +- `Nfp/Circuit/Layers/Heads.lean` + - Head split/merge combinators for transformer-shaped indices. - `Nfp/Circuit/Layers.lean` - Aggregator for circuit layer modules. - `Nfp/Circuit.lean` diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean index 9feceac..64c80c2 100644 --- a/Nfp/Circuit/Layers.lean +++ b/Nfp/Circuit/Layers.lean @@ -3,6 +3,7 @@ import Nfp.Circuit.Layers.Linear import Nfp.Circuit.Layers.Tensor import Nfp.Circuit.Layers.Reshape +import Nfp.Circuit.Layers.Heads /-! Circuit layer combinators. diff --git a/Nfp/Circuit/Layers/Heads.lean b/Nfp/Circuit/Layers/Heads.lean new file mode 100644 index 0000000..84a7f4c --- /dev/null +++ b/Nfp/Circuit/Layers/Heads.lean @@ -0,0 +1,83 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Logic.Equiv.Fin.Basic +import Mathlib.Logic.Equiv.Prod +import Nfp.Circuit.Layers.Reshape + +/-! 
+Head split/merge combinators for transformer-style shapes. +-/ + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe u v u_in u_out + +variable {Node : Type u} [Fintype Node] [DecidableEq Node] +variable {Val : Type v} +variable {Batch : Type u} + +/-- Split a hidden index into `(heads, headDim)` using `Fin` product equivalence. -/ +def headSplitEquiv (heads dim : Nat) : + Fin (heads * dim) ≃ Fin heads × Fin dim := + (finProdFinEquiv (m := heads) (n := dim)).symm + +/-- Merge `(heads, headDim)` back into a hidden index. -/ +def headMergeEquiv (heads dim : Nat) : + Fin heads × Fin dim ≃ Fin (heads * dim) := + finProdFinEquiv (m := heads) (n := dim) + +/-- Split the hidden dimension inside a batched index. -/ +def batchHeadSplitEquiv (Batch : Type u) (heads dim : Nat) : + Batch × Fin (heads * dim) ≃ Batch × Fin heads × Fin dim := + _root_.Equiv.prodCongr (_root_.Equiv.refl Batch) (headSplitEquiv heads dim) + +/-- Merge the head and head-dimension inside a batched index. -/ +def batchHeadMergeEquiv (Batch : Type u) (heads dim : Nat) : + Batch × Fin heads × Fin dim ≃ Batch × Fin (heads * dim) := + (batchHeadSplitEquiv Batch heads dim).symm + +/-- Split heads on the input labels of a typed circuit. -/ +def splitHeadsInput {Output : Type u_out} (heads dim : Nat) + (T : TypedCircuit Node Val (Batch × Fin (heads * dim)) Output) : + TypedCircuit Node Val (Batch × Fin heads × Fin dim) Output := + mapInterface T (batchHeadSplitEquiv Batch heads dim) (_root_.Equiv.refl Output) + +/-- Split heads on the output labels of a typed circuit. -/ +def splitHeadsOutput {Input : Type u_in} (heads dim : Nat) + (T : TypedCircuit Node Val Input (Batch × Fin (heads * dim))) : + TypedCircuit Node Val Input (Batch × Fin heads × Fin dim) := + mapInterface T (_root_.Equiv.refl Input) (batchHeadSplitEquiv Batch heads dim) + +/-- Split heads on both input and output labels. 
-/ +def splitHeads (heads dim : Nat) + (T : TypedCircuit Node Val (Batch × Fin (heads * dim)) (Batch × Fin (heads * dim))) : + TypedCircuit Node Val (Batch × Fin heads × Fin dim) (Batch × Fin heads × Fin dim) := + mapInterface T (batchHeadSplitEquiv Batch heads dim) (batchHeadSplitEquiv Batch heads dim) + +/-- Merge heads on the input labels of a typed circuit. -/ +def mergeHeadsInput {Output : Type u_out} (heads dim : Nat) + (T : TypedCircuit Node Val (Batch × Fin heads × Fin dim) Output) : + TypedCircuit Node Val (Batch × Fin (heads * dim)) Output := + mapInterface T (batchHeadMergeEquiv Batch heads dim) (_root_.Equiv.refl Output) + +/-- Merge heads on the output labels of a typed circuit. -/ +def mergeHeadsOutput {Input : Type u_in} (heads dim : Nat) + (T : TypedCircuit Node Val Input (Batch × Fin heads × Fin dim)) : + TypedCircuit Node Val Input (Batch × Fin (heads * dim)) := + mapInterface T (_root_.Equiv.refl Input) (batchHeadMergeEquiv Batch heads dim) + +/-- Merge heads on both input and output labels. -/ +def mergeHeads (heads dim : Nat) + (T : TypedCircuit Node Val (Batch × Fin heads × Fin dim) (Batch × Fin heads × Fin dim)) : + TypedCircuit Node Val (Batch × Fin (heads * dim)) (Batch × Fin (heads * dim)) := + mapInterface T (batchHeadMergeEquiv Batch heads dim) (batchHeadMergeEquiv Batch heads dim) + +end Layers + +end Circuit + +end Nfp From 1ad0adafc7568028333d4eb4e31538fd0a0e6483 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 19:59:22 +0100 Subject: [PATCH 082/244] Add attention projection wiring --- AGENTS.md | 2 + Nfp/Circuit/Layers.lean | 1 + Nfp/Circuit/Layers/Attention.lean | 64 +++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+) create mode 100644 Nfp/Circuit/Layers/Attention.lean diff --git a/AGENTS.md b/AGENTS.md index 795f8d8..a5faf5d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -301,6 +301,8 @@ but you **must** update this list in the same commit. - Reshape combinators for product-typed circuit interfaces. 
- `Nfp/Circuit/Layers/Heads.lean` - Head split/merge combinators for transformer-shaped indices. +- `Nfp/Circuit/Layers/Attention.lean` + - Q/K/V and output projection wiring for attention layers. - `Nfp/Circuit/Layers.lean` - Aggregator for circuit layer modules. - `Nfp/Circuit.lean` diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean index 64c80c2..b09773e 100644 --- a/Nfp/Circuit/Layers.lean +++ b/Nfp/Circuit/Layers.lean @@ -4,6 +4,7 @@ import Nfp.Circuit.Layers.Linear import Nfp.Circuit.Layers.Tensor import Nfp.Circuit.Layers.Reshape import Nfp.Circuit.Layers.Heads +import Nfp.Circuit.Layers.Attention /-! Circuit layer combinators. diff --git a/Nfp/Circuit/Layers/Attention.lean b/Nfp/Circuit/Layers/Attention.lean new file mode 100644 index 0000000..e161eb4 --- /dev/null +++ b/Nfp/Circuit/Layers/Attention.lean @@ -0,0 +1,64 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Layers.Heads +import Nfp.Circuit.Layers.Tensor + +/-! +QKV and output projection wiring for attention layers. +-/ + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe v + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Hidden dimension type for a `heads × dim` factorization. -/ +abbrev Hidden (heads dim : Nat) : Type := Fin (heads * dim) + +/-- Node type for Q/K/V and output projection layers. -/ +abbrev QkvNode (Batch : Type) (heads dim : Nat) : Type := + BatchedLinearNode Batch (Hidden heads dim) (Hidden heads dim) + +/-- Q projection with head-split output. -/ +def qProj (heads dim : Nat) (Wq : Matrix (Hidden heads dim) (Hidden heads dim) Val) : + TypedCircuit (QkvNode Batch heads dim) Val (Batch × Hidden heads dim) + (Batch × Fin heads × Fin dim) := + splitHeadsOutput (Batch := Batch) heads dim + (batchedLinearTyped (Batch := Batch) + (Row := Hidden heads dim) (Col := Hidden heads dim) Wq) + +/-- K projection with head-split output. 
-/ +def kProj (heads dim : Nat) (Wk : Matrix (Hidden heads dim) (Hidden heads dim) Val) : + TypedCircuit (QkvNode Batch heads dim) Val (Batch × Hidden heads dim) + (Batch × Fin heads × Fin dim) := + splitHeadsOutput (Batch := Batch) heads dim + (batchedLinearTyped (Batch := Batch) + (Row := Hidden heads dim) (Col := Hidden heads dim) Wk) + +/-- V projection with head-split output. -/ +def vProj (heads dim : Nat) (Wv : Matrix (Hidden heads dim) (Hidden heads dim) Val) : + TypedCircuit (QkvNode Batch heads dim) Val (Batch × Hidden heads dim) + (Batch × Fin heads × Fin dim) := + splitHeadsOutput (Batch := Batch) heads dim + (batchedLinearTyped (Batch := Batch) + (Row := Hidden heads dim) (Col := Hidden heads dim) Wv) + +/-- Output projection with head-merged input. -/ +def outProj (heads dim : Nat) (Wo : Matrix (Hidden heads dim) (Hidden heads dim) Val) : + TypedCircuit (QkvNode Batch heads dim) Val (Batch × Fin heads × Fin dim) + (Batch × Hidden heads dim) := + splitHeadsInput (Batch := Batch) heads dim + (batchedLinearTyped (Batch := Batch) + (Row := Hidden heads dim) (Col := Hidden heads dim) Wo) + +end Layers + +end Circuit + +end Nfp From e4f535290093eef0b11d585a0505240111710607 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 22:36:59 +0100 Subject: [PATCH 083/244] Add circuit composition and induction-head spec --- AGENTS.md | 8 +- Nfp/Circuit.lean | 1 + Nfp/Circuit/Compose.lean | 390 +++++++++++++++++++++++ Nfp/Circuit/Layers.lean | 2 + Nfp/Circuit/Layers/Induction.lean | 111 +++++++ Nfp/Circuit/Layers/TransformerBlock.lean | 85 +++++ 6 files changed, 596 insertions(+), 1 deletion(-) create mode 100644 Nfp/Circuit/Compose.lean create mode 100644 Nfp/Circuit/Layers/Induction.lean create mode 100644 Nfp/Circuit/Layers/TransformerBlock.lean diff --git a/AGENTS.md b/AGENTS.md index a5faf5d..c8ff86e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -285,6 +285,8 @@ but you **must** update this list in the same commit. 
- Equivalence definition and finite checker. - `Nfp/Circuit/Typed.lean` - Typed circuit wrapper and interface-level equivalence checker. +- `Nfp/Circuit/Compose.lean` + - Sequential composition and residual wiring for typed circuits. - `Nfp/Circuit/Gates/Basic.lean` - Basic gate combinators for aggregating parent values. - `Nfp/Circuit/Gates/Linear.lean` @@ -302,7 +304,11 @@ but you **must** update this list in the same commit. - `Nfp/Circuit/Layers/Heads.lean` - Head split/merge combinators for transformer-shaped indices. - `Nfp/Circuit/Layers/Attention.lean` - - Q/K/V and output projection wiring for attention layers. + - Q/K/V, output projection wiring, and attention score/mixing core. +- `Nfp/Circuit/Layers/Induction.lean` + - Induction-head weight specs and attention-core output lemmas. +- `Nfp/Circuit/Layers/TransformerBlock.lean` + - GPT-style transformer block wiring from LN/attention/MLP circuits. - `Nfp/Circuit/Layers.lean` - Aggregator for circuit layer modules. - `Nfp/Circuit.lean` diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index 9c71f7b..51917ea 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -7,6 +7,7 @@ import Nfp.Circuit.Semantics import Nfp.Circuit.WellFormed import Nfp.Circuit.Cert import Nfp.Circuit.Typed +import Nfp.Circuit.Compose import Nfp.Circuit.Gates import Nfp.Circuit.Tensor import Nfp.Circuit.Layers diff --git a/Nfp/Circuit/Compose.lean b/Nfp/Circuit/Compose.lean new file mode 100644 index 0000000..88e654b --- /dev/null +++ b/Nfp/Circuit/Compose.lean @@ -0,0 +1,390 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Finset.Disjoint +import Mathlib.Data.Fintype.Sum +import Mathlib.Data.Sum.Order +import Mathlib.Logic.Embedding.Basic +import Nfp.Circuit.Typed + +/-! +Combinators for composing typed circuits (sequential and residual wiring). 
+-/ + +namespace Nfp + +universe u v u' u_in u_mid u_out + +namespace Circuit + +open Function + +section SumEquiv + +variable {Left : Type u} {Right : Type u'} + +/-- Embed a finset subtype into the left injection of a sum. -/ +def inlSubtypeEquiv (s : Finset Left) : + { i // i ∈ s } ≃ { i // i ∈ s.map (Embedding.inl : Left ↪ Left ⊕ Right) } := + { toFun := fun i => + ⟨Sum.inl i.1, by + refine Finset.mem_map.2 ?_ + exact ⟨i.1, i.2, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inl a, ha⟩ => + let h' : a ∈ s := by + rcases (Finset.mem_map.1 ha) with ⟨a', ha', h⟩ + cases h + exact ha' + ⟨a, h'⟩ + | ⟨Sum.inr b, hb⟩ => + False.elim <| by + rcases (Finset.mem_map.1 hb) with ⟨a, _ha, h⟩ + cases h + left_inv := by + intro i + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl a => + rfl + | inr b => + have : False := by + rcases (Finset.mem_map.1 hs) with ⟨a, _ha, h⟩ + cases h + cases this } + +/-- Embed a finset subtype into the right injection of a sum. 
-/ +def inrSubtypeEquiv (s : Finset Right) : + { i // i ∈ s } ≃ { i // i ∈ s.map (Embedding.inr : Right ↪ Left ⊕ Right) } := + { toFun := fun i => + ⟨Sum.inr i.1, by + refine Finset.mem_map.2 ?_ + exact ⟨i.1, i.2, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr a, ha⟩ => + let h' : a ∈ s := by + rcases (Finset.mem_map.1 ha) with ⟨a', ha', h⟩ + cases h + exact ha' + ⟨a, h'⟩ + | ⟨Sum.inl b, hb⟩ => + False.elim <| by + rcases (Finset.mem_map.1 hb) with ⟨a, _ha, h⟩ + cases h + left_inv := by + intro i + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inr a => + rfl + | inl b => + have : False := by + rcases (Finset.mem_map.1 hs) with ⟨a, _ha, h⟩ + cases h + cases this } + +end SumEquiv + +section Seq + +variable {Node₁ : Type u} [Fintype Node₁] +variable {Node₂ : Type u'} [Fintype Node₂] +variable {Val : Type v} +variable {Input : Type u_in} {Mid : Type u_mid} {Output : Type u_out} +variable {C1 : Circuit Node₁ Val} {C2 : Circuit Node₂ Val} +variable {I1 : Interface C1 Input Mid} {I2 : Interface C2 Mid Output} + +/-- Bridge edges from the outputs of `C1` to the inputs of `C2`. -/ +def seqBridge (j : Node₁) (i : Node₂) : Prop := + ∃ h : i ∈ C2.inputs, + j = (I1.outputs (I2.inputs.symm ⟨i, h⟩)).1 + +/-- Fixing an input witness reduces `seqBridge` to an equality. 
-/ +theorem seqBridge_iff_eq {j : Node₁} {i : Node₂} (hmem : i ∈ C2.inputs) : + seqBridge (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) j i ↔ + j = (I1.outputs (I2.inputs.symm ⟨i, hmem⟩)).1 := by + constructor + · rintro ⟨h, hEq⟩ + have hSubtype : + (⟨i, h⟩ : { i // i ∈ C2.inputs }) = ⟨i, hmem⟩ := by + apply Subtype.ext + rfl + have hMid : + I2.inputs.symm ⟨i, h⟩ = I2.inputs.symm ⟨i, hmem⟩ := by + exact congrArg I2.inputs.symm hSubtype + have hOut : + (I1.outputs (I2.inputs.symm ⟨i, h⟩)).1 = + (I1.outputs (I2.inputs.symm ⟨i, hmem⟩)).1 := by + exact congrArg Subtype.val (congrArg I1.outputs hMid) + exact hEq.trans hOut + · intro hEq + exact ⟨hmem, hEq⟩ + +/-- Adjacency for sequentially composed circuits. -/ +def seqAdj : Node₁ ⊕ Node₂ → Node₁ ⊕ Node₂ → Prop + | Sum.inl j, Sum.inl i => C1.dag.rel j i + | Sum.inl j, Sum.inr i => + seqBridge (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) j i + | Sum.inr j, Sum.inr i => C2.dag.rel j i + | _, _ => False + +variable [DecidableEq Node₁] [DecidableEq Node₂] + +/-- DAG for sequentially composed circuits. 
-/ +def seqDag : Dag (Node₁ ⊕ Node₂) := + { graph := { Adj := seqAdj (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) } + decAdj := by + intro j i + cases j with + | inl j => + cases i with + | inl i => + exact (inferInstance : Decidable (C1.dag.rel j i)) + | inr i => + by_cases hmem : i ∈ C2.inputs + · by_cases hEq : + j = (I1.outputs (I2.inputs.symm ⟨i, hmem⟩)).1 + · exact isTrue ⟨hmem, hEq⟩ + · exact isFalse (by + intro h + have hEq' : + j = (I1.outputs (I2.inputs.symm ⟨i, hmem⟩)).1 := + (seqBridge_iff_eq (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) + (j := j) (i := i) hmem).1 h + exact hEq hEq') + · exact isFalse (by + intro h + exact hmem h.1) + | inr j => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + exact (inferInstance : Decidable (C2.dag.rel j i)) + wf := by + have hsub : + Subrelation + (seqAdj (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2)) + (Sum.Lex C1.dag.rel C2.dag.rel) := by + intro j i h + cases j with + | inl j => + cases i with + | inl i => + exact Sum.Lex.inl h + | inr i => + exact Sum.Lex.sep _ _ + | inr j => + cases i with + | inl _ => + exact False.elim h + | inr i => + exact Sum.Lex.inr h + have hwf : WellFounded (Sum.Lex C1.dag.rel C2.dag.rel) := + Sum.lex_wf C1.dag.wf C2.dag.wf + exact Subrelation.wf hsub hwf } + +/-- Sequential composition of circuits at the node level. 
-/ +def seqCircuit : Circuit (Node₁ ⊕ Node₂) Val := + { dag := seqDag (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) + inputs := C1.inputs.map (Embedding.inl : Node₁ ↪ Node₁ ⊕ Node₂) + outputs := C2.outputs.map (Embedding.inr : Node₂ ↪ Node₁ ⊕ Node₂) + gate := by + intro i rec + cases i with + | inl i => + exact C1.gate i (fun j h => + rec (Sum.inl j) (by + change C1.dag.rel j i + exact h)) + | inr i => + by_cases hinput : i ∈ C2.inputs + · let mid : Mid := I2.inputs.symm ⟨i, hinput⟩ + let out : Node₁ := (I1.outputs mid).1 + exact rec (Sum.inl out) (by + refine ⟨hinput, rfl⟩) + · exact C2.gate i (fun j h => + rec (Sum.inr j) (by + change C2.dag.rel j i + exact h)) } + +/-- Interface for sequentially composed circuits. -/ +def seqInterface : + Interface + (seqCircuit (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2)) Input Output := + { inputs := + I1.inputs.trans (inlSubtypeEquiv (s := C1.inputs)) + outputs := + I2.outputs.trans (inrSubtypeEquiv (s := C2.outputs)) } + +end Seq + +section Residual + +variable {Node : Type u} +variable {Input : Type u_in} [Fintype Input] + +/-- Output nodes for residual wiring. -/ +def residualOutputs : Finset (Node ⊕ Input) := + (Finset.univ : Finset Input).map (Embedding.inr : Input ↪ Node ⊕ Input) + +/-- Output equivalence for residual wiring. 
-/ +def residualOutputEquiv : + Input ≃ { i // i ∈ residualOutputs (Node := Node) (Input := Input) } := + { toFun := fun o => + ⟨Sum.inr o, by + refine Finset.mem_map.2 ?_ + exact ⟨o, by simp, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr o, _⟩ => o + | ⟨Sum.inl _, h⟩ => + False.elim <| by + rcases (Finset.mem_map.1 h) with ⟨o, _ho, ho⟩ + cases ho + left_inv := by + intro o + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inr o => + rfl + | inl s => + have : False := by + rcases (Finset.mem_map.1 hs) with ⟨o, _ho, ho⟩ + cases ho + cases this } + +variable [Fintype Node] +variable {Val : Type v} [Add Val] +variable {C : Circuit Node Val} +variable {I : Interface C Input Input} + +/-- Adjacency for residual wiring on a typed circuit. -/ +def residualAdj : Node ⊕ Input → Node ⊕ Input → Prop + | Sum.inl j, Sum.inl i => C.dag.rel j i + | Sum.inl j, Sum.inr o => j = (I.inputs o).1 ∨ j = (I.outputs o).1 + | _, _ => False + +variable [DecidableEq Node] + +/-- DAG for residual wiring on a typed circuit. 
-/ +def residualDag : Dag (Node ⊕ Input) := + { graph := { Adj := residualAdj (C := C) (I := I) } + decAdj := by + intro j i + cases j with + | inl j => + cases i with + | inl i => + exact (inferInstance : Decidable (C.dag.rel j i)) + | inr o => + exact (inferInstance : + Decidable (j = (I.inputs o).1 ∨ j = (I.outputs o).1)) + | inr _ => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr _ => + exact isFalse (by intro h; cases h) + wf := by + have hsub : + Subrelation (residualAdj (C := C) (I := I)) + (Sum.Lex C.dag.rel (fun _ _ : Input => False)) := by + intro j i h + cases j with + | inl j => + cases i with + | inl i => + exact Sum.Lex.inl h + | inr _ => + exact Sum.Lex.sep _ _ + | inr _ => + cases i with + | inl _ => + exact False.elim h + | inr _ => + exact False.elim h + have hfalse : WellFounded (fun _ _ : Input => False) := by + refine ⟨?_⟩ + intro a + refine Acc.intro a ?_ + intro b h + cases h + have hwf : WellFounded (Sum.Lex C.dag.rel (fun _ _ : Input => False)) := + Sum.lex_wf C.dag.wf hfalse + exact Subrelation.wf hsub hwf } + +/-- Circuit that adds a residual connection to a typed circuit. -/ +def residualCircuit : Circuit (Node ⊕ Input) Val := + { dag := residualDag (C := C) (I := I) + inputs := C.inputs.map (Embedding.inl : Node ↪ Node ⊕ Input) + outputs := residualOutputs (Node := Node) (Input := Input) + gate := by + intro i rec + cases i with + | inl i => + exact C.gate i (fun j h => + rec (Sum.inl j) (by simpa [residualAdj] using h)) + | inr o => + let inNode : Node := (I.inputs o).1 + let outNode : Node := (I.outputs o).1 + let inVal := rec (Sum.inl inNode) (by + change inNode = (I.inputs o).1 ∨ inNode = (I.outputs o).1 + exact Or.inl rfl) + let outVal := rec (Sum.inl outNode) (by + change outNode = (I.inputs o).1 ∨ outNode = (I.outputs o).1 + exact Or.inr rfl) + exact inVal + outVal } + +/-- Interface for residual wiring. 
-/ +def residualInterface : + Interface (residualCircuit (C := C) (I := I)) Input Input := + { inputs := I.inputs.trans (inlSubtypeEquiv (s := C.inputs)) + outputs := residualOutputEquiv (Node := Node) (Input := Input) } + +end Residual + +namespace TypedCircuit + +variable {Node₁ : Type u} [Fintype Node₁] [DecidableEq Node₁] +variable {Node₂ : Type u'} [Fintype Node₂] [DecidableEq Node₂] +variable {Val : Type v} +variable {Input : Type u_in} {Mid : Type u_mid} {Output : Type u_out} + +/-- Sequential composition of typed circuits. -/ +def seq (T₁ : TypedCircuit Node₁ Val Input Mid) + (T₂ : TypedCircuit Node₂ Val Mid Output) : + TypedCircuit (Node₁ ⊕ Node₂) Val Input Output := + { circuit := seqCircuit (C1 := T₁.circuit) (C2 := T₂.circuit) + (I1 := T₁.interface) (I2 := T₂.interface) + interface := seqInterface (C1 := T₁.circuit) (C2 := T₂.circuit) + (I1 := T₁.interface) (I2 := T₂.interface) } + +variable [Add Val] +variable [Fintype Input] +variable [DecidableEq Input] + +/-- Add a residual connection to a typed circuit. -/ +def residual (T : TypedCircuit Node₁ Val Input Input) : + TypedCircuit (Node₁ ⊕ Input) Val Input Input := + { circuit := residualCircuit (C := T.circuit) (I := T.interface) + interface := residualInterface (C := T.circuit) (I := T.interface) } + +end TypedCircuit + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean index b09773e..dfbe548 100644 --- a/Nfp/Circuit/Layers.lean +++ b/Nfp/Circuit/Layers.lean @@ -5,6 +5,8 @@ import Nfp.Circuit.Layers.Tensor import Nfp.Circuit.Layers.Reshape import Nfp.Circuit.Layers.Heads import Nfp.Circuit.Layers.Attention +import Nfp.Circuit.Layers.Induction +import Nfp.Circuit.Layers.TransformerBlock /-! Circuit layer combinators. 
diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean new file mode 100644 index 0000000..ab5b13f --- /dev/null +++ b/Nfp/Circuit/Layers/Induction.lean @@ -0,0 +1,111 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Layers.Attention + +/-! +Induction-head specifications for attention cores. +-/ + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe v + +section Weights + +variable {Val : Type v} [NonAssocSemiring Val] +variable {seq : Nat} + +/-- Induction weights are one-hot at `prev` for each query position. -/ +def InductionWeights (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop := + ∀ q, weights q = Pi.single (prev q) 1 + +/-- A one-hot weight vector selects the corresponding value in a dot product. -/ +theorem dotProduct_eq_of_oneHot (k : Fin seq) (vals : Fin seq → Val) : + dotProduct (Pi.single k 1) vals = vals k := by + simp + +/-- Induction weights select the `prev` value in each dot product. -/ +theorem dotProduct_eq_prev (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) (vals : Fin seq → Fin seq → Val) + (hweights : InductionWeights (Val := Val) prev weights) (q : Fin seq) : + dotProduct (weights q) (vals q) = vals q (prev q) := by + have hq : weights q = Pi.single (prev q) 1 := hweights q + simp [hq] + +end Weights + +section Attention + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {seq heads dim : Nat} +variable {Val : Type v} [NonAssocSemiring Val] + +/-- Weight function feeding an attention output node. 
-/ +def attentionOutWeights (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) : + Fin seq → Val := + fun k => + rec (attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) + (attentionDag_rel_weight_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q k d) + +/-- Value function feeding an attention output node. -/ +def attentionOutValues (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) : + Fin seq → Val := + fun k => + rec (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d)) + (attentionDag_rel_v_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b k h d q) + +/-- One-hot attention weights force the output to copy the selected value. 
-/ +theorem attentionGate_out_eq_of_oneHot (scale : Val) + (softmax : (Fin seq → Val) → Fin seq → Val) (prev : Fin seq → Fin seq) + (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) + (hweights : + attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec = + Pi.single (prev q) 1) : + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) rec = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec (prev q) := by + simp only [attentionGate] + change + dotProduct + (attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) + (attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec (prev q) + rw [hweights] + exact dotProduct_eq_of_oneHot (Val := Val) (seq := seq) (k := prev q) + (vals := attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) + +end Attention + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/TransformerBlock.lean b/Nfp/Circuit/Layers/TransformerBlock.lean new file mode 100644 index 0000000..657cd59 --- /dev/null +++ b/Nfp/Circuit/Layers/TransformerBlock.lean @@ -0,0 +1,85 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Compose +import Nfp.Circuit.Layers.Attention + +/-! +Transformer block wiring built from sequential composition and residual links. 
+-/ + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe u v u_in + +variable {Val : Type v} [Add Val] +variable {Input : Type u_in} [Fintype Input] [DecidableEq Input] + +variable {NodeLn1 : Type u} [Fintype NodeLn1] [DecidableEq NodeLn1] +variable {NodeAttn : Type u} [Fintype NodeAttn] [DecidableEq NodeAttn] +variable {NodeLn2 : Type u} [Fintype NodeLn2] [DecidableEq NodeLn2] +variable {NodeMlp : Type u} [Fintype NodeMlp] [DecidableEq NodeMlp] + +/-- Node type for the attention subpath (LN1 ∘ attention). -/ +abbrev AttnPathNode (NodeLn1 NodeAttn : Type u) : Type u := + NodeLn1 ⊕ NodeAttn + +/-- Node type for the attention residual wrapper. -/ +abbrev AttnResidualNode (NodeLn1 NodeAttn : Type u) (Input : Type u_in) : + Type (max u u_in) := + (AttnPathNode NodeLn1 NodeAttn) ⊕ Input + +/-- Node type for the MLP subpath (LN2 ∘ MLP). -/ +abbrev MlpPathNode (NodeLn2 NodeMlp : Type u) : Type u := + NodeLn2 ⊕ NodeMlp + +/-- Node type for the MLP residual wrapper. -/ +abbrev MlpResidualNode (NodeLn2 NodeMlp : Type u) (Input : Type u_in) : + Type (max u u_in) := + (MlpPathNode NodeLn2 NodeMlp) ⊕ Input + +/-- Node type for a full transformer block. -/ +abbrev TransformerBlockNode (NodeLn1 NodeAttn NodeLn2 NodeMlp : Type u) (Input : Type u_in) : + Type (max u u_in) := + (AttnResidualNode NodeLn1 NodeAttn Input) ⊕ (MlpResidualNode NodeLn2 NodeMlp Input) + +/-- Compose a GPT-style transformer block from LN/attention/MLP circuits. 
-/ +def transformerBlock + (ln1 : TypedCircuit NodeLn1 Val Input Input) + (attn : TypedCircuit NodeAttn Val Input Input) + (ln2 : TypedCircuit NodeLn2 Val Input Input) + (mlp : TypedCircuit NodeMlp Val Input Input) : + TypedCircuit (TransformerBlockNode NodeLn1 NodeAttn NodeLn2 NodeMlp Input) Val Input Input := + let attnPath := TypedCircuit.seq ln1 attn + let attnRes := TypedCircuit.residual attnPath + let mlpPath := TypedCircuit.seq ln2 mlp + let mlpRes := TypedCircuit.residual mlpPath + TypedCircuit.seq attnRes mlpRes + +/-- Token hidden state type for GPT-2 style blocks. -/ +abbrev BlockInput (Batch : Type) (seq heads dim : Nat) : Type := + Batch × Fin seq × Hidden heads dim + +/-- Transformer block specialized to GPT-style hidden states. -/ +def gpt2Block {Batch : Type} [Fintype Batch] [DecidableEq Batch] {seq heads dim : Nat} + (ln1 : TypedCircuit NodeLn1 Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim)) + (attn : TypedCircuit NodeAttn Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim)) + (ln2 : TypedCircuit NodeLn2 Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim)) + (mlp : TypedCircuit NodeMlp Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim)) : + TypedCircuit (TransformerBlockNode NodeLn1 NodeAttn NodeLn2 NodeMlp + (BlockInput Batch seq heads dim)) Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim) := + transformerBlock ln1 attn ln2 mlp + +end Layers + +end Circuit + +end Nfp From 962534856006f2fc7acab28304163f10f9b57200 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 22:38:17 +0100 Subject: [PATCH 084/244] Extend attention core wiring --- Nfp/Circuit/Layers/Attention.lean | 683 +++++++++++++++++++++++++++++- 1 file changed, 681 insertions(+), 2 deletions(-) diff --git a/Nfp/Circuit/Layers/Attention.lean b/Nfp/Circuit/Layers/Attention.lean index e161eb4..2b4e478 100644 --- a/Nfp/Circuit/Layers/Attention.lean +++ 
b/Nfp/Circuit/Layers/Attention.lean @@ -1,10 +1,13 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Data.Finset.Image +import Mathlib.Data.Matrix.Mul +import Mathlib.Logic.Embedding.Basic import Nfp.Circuit.Layers.Heads import Nfp.Circuit.Layers.Tensor /-! -QKV and output projection wiring for attention layers. +QKV and output projection wiring for attention layers, plus attention score/mixing core. -/ namespace Nfp @@ -13,9 +16,11 @@ namespace Circuit namespace Layers +open Function + universe v -variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {Batch : Type} [Fintype Batch] variable {Val : Type v} [NonUnitalNonAssocSemiring Val] /-- Hidden dimension type for a `heads × dim` factorization. -/ @@ -25,6 +30,10 @@ abbrev Hidden (heads dim : Nat) : Type := Fin (heads * dim) abbrev QkvNode (Batch : Type) (heads dim : Nat) : Type := BatchedLinearNode Batch (Hidden heads dim) (Hidden heads dim) +section Projections + +variable [DecidableEq Batch] + /-- Q projection with head-split output. -/ def qProj (heads dim : Nat) (Wq : Matrix (Hidden heads dim) (Hidden heads dim) Val) : TypedCircuit (QkvNode Batch heads dim) Val (Batch × Hidden heads dim) @@ -57,6 +66,676 @@ def outProj (heads dim : Nat) (Wo : Matrix (Hidden heads dim) (Hidden heads dim) (batchedLinearTyped (Batch := Batch) (Row := Hidden heads dim) (Col := Hidden heads dim) Wo) +end Projections + +/- +Attention score/mixing core. +-/ + +variable {seq heads dim : Nat} + +/-- Index for per-token head-split vectors. -/ +abbrev QkvIndex (Batch : Type) (seq heads dim : Nat) : Type := + Batch × Fin seq × Fin heads × Fin dim + +/-- Index for attention score/weight entries. -/ +abbrev ScoreIndex (Batch : Type) (seq heads : Nat) : Type := + Batch × Fin heads × Fin seq × Fin seq + +/-- Index for attention weight entries (same shape as scores). 
-/ +abbrev WeightIndex (Batch : Type) (seq heads : Nat) : Type := + ScoreIndex Batch seq heads + +/-- Input-node labels for attention core circuits. -/ +abbrev AttentionInputNode (Batch : Type) (seq heads dim : Nat) : Type := + Sum (QkvIndex Batch seq heads dim) + (Sum (QkvIndex Batch seq heads dim) (QkvIndex Batch seq heads dim)) + +/-- Typed input labels for attention core circuits (Q/K/V). -/ +abbrev AttentionInput (Batch : Type) (seq heads dim : Nat) : Type := + AttentionInputNode Batch seq heads dim + +/-- Typed output labels for attention core circuits. -/ +abbrev AttentionOutput (Batch : Type) (seq heads dim : Nat) : Type := + QkvIndex Batch seq heads dim + +/-- Node type for the attention score/mixing core. -/ +abbrev AttentionNode (Batch : Type) (seq heads dim : Nat) : Type := + Sum (AttentionInputNode Batch seq heads dim) + (Sum (ScoreIndex Batch seq heads) (Sum (WeightIndex Batch seq heads) + (AttentionOutput Batch seq heads dim))) + +section Decidable + +variable [DecidableEq Batch] + +/-- Decidable equality for attention-core input nodes. -/ +instance instDecidableEqAttentionInputNode : + DecidableEq (AttentionInputNode Batch seq heads dim) := by + intro x y + cases x with + | inl qx => + cases y with + | inl qy => + simpa using (inferInstance : Decidable (qx = qy)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr x => + cases x with + | inl kx => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl ky => + simpa using (inferInstance : Decidable (kx = ky)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr vx => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr vy => + simpa using (inferInstance : Decidable (vx = vy)) + +/-- Decidable equality for attention-core nodes. 
-/ +instance instDecidableEqAttentionNode : + DecidableEq (AttentionNode Batch seq heads dim) := by + intro x y + cases x with + | inl x => + cases y with + | inl y => + simpa using (inferInstance : Decidable (x = y)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr x => + cases x with + | inl sx => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl sy => + simpa using (inferInstance : Decidable (sx = sy)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr x => + cases x with + | inl wx => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl wy => + simpa using (inferInstance : Decidable (wx = wy)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr ox => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr oy => + simpa using (inferInstance : Decidable (ox = oy)) + +end Decidable + +/-- Inject a Q input node into the attention core node type. -/ +abbrev attnQ (q : QkvIndex Batch seq heads dim) : AttentionNode Batch seq heads dim := + Sum.inl (Sum.inl q) + +/-- Inject a K input node into the attention core node type. -/ +abbrev attnK (k : QkvIndex Batch seq heads dim) : AttentionNode Batch seq heads dim := + Sum.inl (Sum.inr (Sum.inl k)) + +/-- Inject a V input node into the attention core node type. -/ +abbrev attnV (v : QkvIndex Batch seq heads dim) : AttentionNode Batch seq heads dim := + Sum.inl (Sum.inr (Sum.inr v)) + +/-- Inject a score node into the attention core node type. -/ +abbrev attnScore (s : ScoreIndex Batch seq heads) : AttentionNode Batch seq heads dim := + Sum.inr (Sum.inl s) + +/-- Inject a weight node into the attention core node type. 
-/ +abbrev attnWeight (w : WeightIndex Batch seq heads) : AttentionNode Batch seq heads dim := + Sum.inr (Sum.inr (Sum.inl w)) + +/-- Inject an output node into the attention core node type. -/ +abbrev attnOut (o : AttentionOutput Batch seq heads dim) : AttentionNode Batch seq heads dim := + Sum.inr (Sum.inr (Sum.inr o)) + +/-- Rank function used to orient attention-core edges from inputs to outputs. -/ +def attentionRank : AttentionNode Batch seq heads dim → Nat + | Sum.inl _ => 0 + | Sum.inr (Sum.inl _) => 1 + | Sum.inr (Sum.inr (Sum.inl _)) => 2 + | Sum.inr (Sum.inr (Sum.inr _)) => 3 + +/-- Adjacency relation for attention-core wiring. -/ +def attentionAdj : AttentionNode Batch seq heads dim → AttentionNode Batch seq heads dim → Prop + | Sum.inl (Sum.inl (b, q, h, _)), + Sum.inr (Sum.inl (b', h', q', _)) => + b = b' ∧ h = h' ∧ q = q' + | Sum.inl (Sum.inr (Sum.inl (b, k, h, _))), + Sum.inr (Sum.inl (b', h', _, k')) => + b = b' ∧ h = h' ∧ k = k' + | Sum.inr (Sum.inl (b, h, q, _)), + Sum.inr (Sum.inr (Sum.inl (b', h', q', _))) => + b = b' ∧ h = h' ∧ q = q' + | Sum.inr (Sum.inr (Sum.inl (b, h, q, _))), + Sum.inr (Sum.inr (Sum.inr (b', q', h', _))) => + b = b' ∧ h = h' ∧ q = q' + | Sum.inl (Sum.inr (Sum.inr (b, _, h, d))), + Sum.inr (Sum.inr (Sum.inr (b', _, h', d'))) => + b = b' ∧ h = h' ∧ d = d' + | _, _ => False + +section Dag + +variable [DecidableEq Batch] + +/-- DAG for the attention score/mixing core. 
-/ +def attentionDag : Dag (AttentionNode Batch seq heads dim) := + { graph := { Adj := attentionAdj (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } + decAdj := by + intro j i + cases j with + | inl j => + cases j with + | inl q => + rcases q with ⟨b, q, h, _d⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl s => + rcases s with ⟨b', h', q', _k⟩ + exact (inferInstance : Decidable (b = b' ∧ h = h' ∧ q = q')) + | inr _ => + exact isFalse (by intro h; cases h) + | inr j => + cases j with + | inl k => + rcases k with ⟨b, k, h, _d⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl s => + rcases s with ⟨b', h', _q, k'⟩ + exact (inferInstance : Decidable (b = b' ∧ h = h' ∧ k = k')) + | inr _ => + exact isFalse (by intro h; cases h) + | inr v => + rcases v with ⟨b, _k, h, d⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr o => + rcases o with ⟨b', _q, h', d'⟩ + exact (inferInstance : Decidable (b = b' ∧ h = h' ∧ d = d')) + | inr j => + cases j with + | inl s => + rcases s with ⟨b, h, q, _k⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl w => + rcases w with ⟨b', h', q', _k'⟩ + exact (inferInstance : Decidable (b = b' ∧ h = h' ∧ q = q')) + | inr _ => + exact isFalse (by intro h; cases h) + | inr j => + cases j with + | inl w => + rcases w with ⟨b, h, q, _k⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr o => + rcases o with ⟨b', q', h', _d⟩ + exact (inferInstance 
: Decidable (b = b' ∧ h = h' ∧ q = q')) + | inr _ => + exact isFalse (by intro h; cases h) + wf := by + have hsub : + Subrelation (attentionAdj (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)) + (fun j i => + attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) j < + attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) i) := by + intro j i h + cases j with + | inl j => + cases j with + | inl _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + simp [attentionRank] + | inr i => + cases i with + | inl _ => + cases h + | inr _ => + cases h + | inr j => + cases j with + | inl _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + simp [attentionRank] + | inr i => + cases i with + | inl _ => + cases h + | inr _ => + cases h + | inr _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr _ => + simp [attentionRank] + | inr j => + cases j with + | inl _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + simp [attentionRank] + | inr _ => + cases h + | inr j => + cases j with + | inl _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr _ => + simp [attentionRank] + | inr _ => + cases h + have hwf : WellFounded (fun j i => + attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) j < + attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) i) := by + simpa using (InvImage.wf + (f := attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)) + (h := Nat.lt_wfRel.wf)) + exact Subrelation.wf hsub hwf } + +/-- Q nodes feed score nodes for matching batch/head/query. 
-/ +theorem attentionDag_rel_q_score (b : Batch) (q : Fin seq) (h : Fin heads) (d : Fin dim) + (k : Fin seq) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnQ (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + (attnScore (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +/-- K nodes feed score nodes for matching batch/head/key. -/ +theorem attentionDag_rel_k_score (b : Batch) (k : Fin seq) (h : Fin heads) (d : Fin dim) + (q : Fin seq) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnK (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d)) + (attnScore (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +/-- Score nodes feed weight nodes for matching batch/head/query. -/ +theorem attentionDag_rel_score_weight (b : Batch) (h : Fin heads) (q : Fin seq) (k k' : Fin seq) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnScore (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k')) + (attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +/-- Weight nodes feed output nodes for matching batch/head/query. -/ +theorem attentionDag_rel_weight_out (b : Batch) (h : Fin heads) (q : Fin seq) + (k : Fin seq) (d : Fin dim) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +/-- V nodes feed output nodes for matching batch/head/dimension. 
-/ +theorem attentionDag_rel_v_out (b : Batch) (k : Fin seq) (h : Fin heads) (d : Fin dim) + (q : Fin seq) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d)) + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +end Dag + +section Inputs + +/-- Input nodes for the attention core. -/ +def attentionInputs : Finset (AttentionNode Batch seq heads dim) := + (Finset.univ : Finset (AttentionInput Batch seq heads dim)).map Embedding.inl + +open scoped Classical in +/-- Membership in attention inputs corresponds to being a left injection. -/ +theorem mem_attentionInputs_iff {s : AttentionNode Batch seq heads dim} : + s ∈ attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) ↔ + ∃ a, s = Sum.inl a := by + constructor + · intro hs + rcases (Finset.mem_map.1 hs) with ⟨a, _ha, hsa⟩ + exact ⟨a, hsa.symm⟩ + · rintro ⟨a, rfl⟩ + refine Finset.mem_map.2 ?_ + exact ⟨a, by simp, rfl⟩ + +open scoped Classical in +/-- Right injections are not attention input nodes. -/ +theorem not_mem_attentionInputs_inr (s : Sum (ScoreIndex Batch seq heads) + (Sum (WeightIndex Batch seq heads) (AttentionOutput Batch seq heads dim))) : + Sum.inr s ∉ attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + intro h + rcases (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).1 h + with ⟨a, ha⟩ + cases ha + +open scoped Classical in +/-- Input labels correspond to input nodes in the attention core. 
-/ +def attentionInputEquiv : + AttentionInput Batch seq heads dim ≃ + { i // i ∈ attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } := + { toFun := fun a => + ⟨Sum.inl a, + (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).2 + ⟨a, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inl a, _⟩ => a + | ⟨Sum.inr s, h⟩ => + False.elim + (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + left_inv := by + intro a + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl a => rfl + | inr s => + cases (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s hs) } + +end Inputs + +section Outputs + +/-- Output nodes for the attention core. -/ +def attentionOutputs : Finset (AttentionNode Batch seq heads dim) := + (Finset.univ : Finset (AttentionOutput Batch seq heads dim)).map + { toFun := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + inj' := by + intro a b h + cases h + rfl } + +open scoped Classical in +/-- Membership in attention outputs corresponds to being an output injection. -/ +theorem mem_attentionOutputs_iff {s : AttentionNode Batch seq heads dim} : + s ∈ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) ↔ + ∃ o, s = attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) o := by + constructor + · intro hs + rcases (Finset.mem_map.1 hs) with ⟨o, _ho, hso⟩ + exact ⟨o, hso.symm⟩ + · rintro ⟨o, rfl⟩ + refine Finset.mem_map.2 ?_ + exact ⟨o, by simp, rfl⟩ + +open scoped Classical in +/-- Left injections are not attention output nodes. 
-/ +theorem not_mem_attentionOutputs_inl (s : AttentionInputNode Batch seq heads dim) : + Sum.inl s ∉ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + intro h + rcases (Finset.mem_map.1 h) with ⟨o, _ho, ho⟩ + cases ho + +open scoped Classical in +/-- Score nodes are not attention output nodes. -/ +theorem not_mem_attentionOutputs_score (s : ScoreIndex Batch seq heads) : + Sum.inr (Sum.inl s) ∉ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) := by + intro h + rcases (Finset.mem_map.1 h) with ⟨o, _ho, ho⟩ + cases ho + +open scoped Classical in +/-- Weight nodes are not attention output nodes. -/ +theorem not_mem_attentionOutputs_weight (w : WeightIndex Batch seq heads) : + Sum.inr (Sum.inr (Sum.inl w)) ∉ + attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + intro h + rcases (Finset.mem_map.1 h) with ⟨o, _ho, ho⟩ + cases ho + +open scoped Classical in +/-- Output labels correspond to output nodes in the attention core. 
-/ +def attentionOutputEquiv : + AttentionOutput Batch seq heads dim ≃ + { i // i ∈ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } := + { toFun := fun o => + ⟨attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) o, + (mem_attentionOutputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).2 + ⟨o, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr (Sum.inr (Sum.inr o)), _⟩ => o + | ⟨Sum.inl s, h⟩ => + False.elim + (not_mem_attentionOutputs_inl (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + | ⟨Sum.inr (Sum.inl s), h⟩ => + False.elim + (not_mem_attentionOutputs_score (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + | ⟨Sum.inr (Sum.inr (Sum.inl w)), h⟩ => + False.elim + (not_mem_attentionOutputs_weight (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) w h) + left_inv := by + intro o + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl s => + cases (not_mem_attentionOutputs_inl (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s hs) + | inr s => + cases s with + | inl s => + cases (not_mem_attentionOutputs_score (Batch := Batch) (seq := seq) + (heads := heads) (dim := dim) s hs) + | inr s => + cases s with + | inl w => + cases (not_mem_attentionOutputs_weight (Batch := Batch) (seq := seq) + (heads := heads) (dim := dim) w hs) + | inr _ => + rfl } + +end Outputs + +section Circuits + +variable [DecidableEq Batch] + +/-- Gate semantics for attention score/mixing circuits. 
-/ +def attentionGate (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + ∀ i, + (∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j i → + Val) → + Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr i => + cases i with + | inl s => + rcases s with ⟨b, h, q, k⟩ + let qNode : Fin dim → AttentionNode Batch seq heads dim := fun d => + attnQ (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, q, h, d) + let kNode : Fin dim → AttentionNode Batch seq heads dim := fun d => + attnK (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, k, h, d) + let qVals : Fin dim → Val := fun d => + rec (qNode d) + (attentionDag_rel_q_score (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b q h d k) + let kVals : Fin dim → Val := fun d => + rec (kNode d) + (attentionDag_rel_k_score (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b k h d q) + exact scale * dotProduct qVals kVals + | inr i => + cases i with + | inl w => + rcases w with ⟨b, h, q, k⟩ + let scoreNode : Fin seq → AttentionNode Batch seq heads dim := fun k' => + attnScore (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, h, q, k') + let scores : Fin seq → Val := fun k' => + rec (scoreNode k') + (attentionDag_rel_score_weight (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) b h q k k') + exact softmax scores k + | inr o => + rcases o with ⟨b, q, h, d⟩ + let weightNode : Fin seq → AttentionNode Batch seq heads dim := fun k => + attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, h, q, k) + let valueNode : Fin seq → AttentionNode Batch seq heads dim := fun k => + attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, k, h, d) + let weights : Fin seq → Val := fun k => + rec (weightNode k) + (attentionDag_rel_weight_out (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) b h q k d) + let vals : Fin seq → Val := fun k => + rec (valueNode 
k) + (attentionDag_rel_v_out (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) b k h d q) + exact dotProduct weights vals + +/-- Circuit for attention score/mixing. -/ +def attentionCircuit (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + Circuit (AttentionNode Batch seq heads dim) Val := + { dag := attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + inputs := attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + outputs := attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + gate := + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax } + +/-- Typed interface for attention score/mixing circuits. -/ +def attentionInterface (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + Interface + (attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax) + (AttentionInput Batch seq heads dim) (AttentionOutput Batch seq heads dim) := + { inputs := attentionInputEquiv (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + outputs := attentionOutputEquiv (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } + +end Circuits + +section Typed + +variable [DecidableEq Batch] + +/-- Typed attention score/mixing circuit. 
-/ +def attentionTyped (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + TypedCircuit (AttentionNode Batch seq heads dim) Val + (AttentionInput Batch seq heads dim) (AttentionOutput Batch seq heads dim) := + { circuit := attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax + interface := attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax } + +end Typed + end Layers end Circuit From 9d2307ae3ce4f96c5d634af3b9ce117ec42ee345 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 22:41:24 +0100 Subject: [PATCH 085/244] Update tabula rasa status --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ccf14a8..5ae409d 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,12 @@ The `tabula-rasa` branch is a fresh, minimal Lean 4 core focused on circuit cert Current core modules (new): - `Nfp/Core`, `Nfp/Prob`, `Nfp/Mixer`, `Nfp/System` define basic mass/probability, mixers, and DAG-backed local systems. -- `Nfp/Circuit` defines DAG-based circuits with input-only evaluation, typed interfaces, well-formedness, and equivalence checkers. -- `Nfp/Cli` and `Main.lean` are thin placeholders (no full transformer pipeline yet). +- `Nfp/Circuit` defines DAG-based circuits with typed interfaces, well-formedness, and equivalence checkers. +- `Nfp/Circuit/Compose` adds sequential and residual wiring combinators for typed circuits. +- `Nfp/Circuit/Layers/Attention` contains Q/K/V projection wiring plus an attention score/mixing core. +- `Nfp/Circuit/Layers/Induction` provides induction-head specs and the core attention one-hot lemma. +- `Nfp/Circuit/Layers/TransformerBlock` wires LN/attention/MLP into a GPT-style block skeleton. +- `Nfp/Cli` and `Main.lean` remain thin placeholders (no full transformer pipeline yet). Module map and invariants are tracked in `AGENTS.md`. 
From dd56c5be34f72f86271ff49e8c4d2e2a09101a11 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 22:52:15 +0100 Subject: [PATCH 086/244] Add induction spec and typed eval lemma --- Nfp/Circuit/Layers/Induction.lean | 122 ++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index ab5b13f..aa7ede7 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ -39,12 +39,35 @@ theorem dotProduct_eq_prev (prev : Fin seq → Fin seq) end Weights +section Spec + +variable {Val : Type v} +variable {n : Nat} + +/-- Induction-head spec: for nonzero queries, outputs copy `prev` values. -/ +def InductionSpec (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, q ≠ 0 → out q = vals (prev q) + +/-- Concrete `prev` map on `Fin (n + 1)` (with `0 ↦ 0`). -/ +def prevIndex : Fin (Nat.succ n) → Fin (Nat.succ n) + | ⟨0, _⟩ => 0 + | ⟨Nat.succ k, hk⟩ => + ⟨k, Nat.lt_trans (Nat.lt_of_succ_lt_succ hk) (Nat.lt_succ_self n)⟩ + +end Spec + section Attention variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] variable {seq heads dim : Nat} variable {Val : Type v} [NonAssocSemiring Val] +/-- Typed V-input label for attention cores. -/ +abbrev attnInputV (v : QkvIndex Batch seq heads dim) : + AttentionInput Batch seq heads dim := + Sum.inr (Sum.inr v) + /-- Weight function feeding an attention output node. -/ def attentionOutWeights (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) (rec : @@ -102,6 +125,105 @@ theorem attentionGate_out_eq_of_oneHot (scale : Val) (vals := attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) b h q d rec) +section Typed + +variable (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) + +/-- Attention output equals the selected V input when weights are one-hot. 
-/ +theorem attentionTyped_eval_out_eq_of_oneHot (prev : Fin seq → Fin seq) + (input : AttentionInput Batch seq heads dim → Val) + (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (hweights : + attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax) + ((attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval + input (b, q, h, d) = + input + (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := by + let C := + attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + let I := + attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + let inputAssign := I.toInputAssignment input + have hnot : + attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d) ∉ + attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + simpa using + (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (s := Sum.inr (Sum.inr (b, q, h, d)))) + have hgate : + Circuit.evalInput C inputAssign + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) = + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + (fun j _ => Circuit.evalInput C inputAssign j) := by + exact Circuit.evalInput_eq_gate (C := C) (input := inputAssign) + (i := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + hnot + have hcopy : + Circuit.evalInput C inputAssign + (attnOut (Batch := 
Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := by + have hgate' : + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + (fun j _ => Circuit.evalInput C inputAssign j) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := + attentionGate_out_eq_of_oneHot (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (scale := scale) (softmax := softmax) (prev := prev) (b := b) (h := h) (q := q) (d := d) + (rec := fun j _ => Circuit.evalInput C inputAssign j) hweights + exact hgate.trans hgate' + have hmem : + attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d) ∈ + attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + refine (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim)).2 ?_ + exact ⟨attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d), rfl⟩ + have hinput : + Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d)) = + input (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := by + have h := + Circuit.evalInput_eq_input (C := C) (input := inputAssign) + (i := attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) hmem + simpa [inputAssign, I, attentionInterface, attnInputV] using h + have hvals : + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) = + Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := 
seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := rfl + calc + (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval + input (b, q, h, d) = + Circuit.evalInput C inputAssign (I.outputs (b, q, h, d)).1 := by + simp [TypedCircuit.eval, Interface.eval, C, I, inputAssign, attentionTyped] + _ = Circuit.evalInput C inputAssign + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by + rfl + _ = attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := hcopy + _ = Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := hvals + _ = input (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := hinput + +end Typed + end Attention end Layers From 969a8b314eaf404d6df9505982150067219a6819 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 23:22:05 +0100 Subject: [PATCH 087/244] Remove legacy tree --- AGENTS.md | 3 - Legacy/Nfp/Abstraction.lean | 424 - Legacy/Nfp/Appendix.lean | 956 --- Legacy/Nfp/Attribution.lean | 237 - Legacy/Nfp/Discovery.lean | 9929 ----------------------- Legacy/Nfp/IO.lean | 503 -- Legacy/Nfp/IO/Pure.lean | 399 - Legacy/Nfp/Induction.lean | 498 -- Legacy/Nfp/Influence.lean | 342 - Legacy/Nfp/Layers.lean | 1046 --- Legacy/Nfp/Linearization.lean | 2780 ------- Legacy/Nfp/MixerLocalSystem.lean | 68 - Legacy/Nfp/PCC.lean | 227 - Legacy/Nfp/Reroute/Heat.lean | 524 -- Legacy/Nfp/Reroute/Partition.lean | 413 - Legacy/Nfp/SignedMixer.lean | 636 -- Legacy/Nfp/Sound/Activation.lean | 43 - Legacy/Nfp/Sound/Affine.lean | 96 - Legacy/Nfp/Sound/BinaryPure.lean | 479 -- Legacy/Nfp/Sound/Bounds.lean | 19 - Legacy/Nfp/Sound/Bounds/Attention.lean | 83 - Legacy/Nfp/Sound/Bounds/Basic.lean | 19 - Legacy/Nfp/Sound/Bounds/Effort.lean | 11 - Legacy/Nfp/Sound/Bounds/Exp.lean | 197 - 
Legacy/Nfp/Sound/Bounds/Gelu.lean | 19 - Legacy/Nfp/Sound/Bounds/LayerNorm.lean | 164 - Legacy/Nfp/Sound/Bounds/MatrixNorm.lean | 130 - Legacy/Nfp/Sound/Bounds/Portfolio.lean | 50 - Legacy/Nfp/Sound/Bounds/Softmax.lean | 231 - Legacy/Nfp/Sound/Bridge.lean | 759 -- Legacy/Nfp/Sound/CachePure.lean | 1011 --- Legacy/Nfp/Sound/Cert.lean | 609 -- Legacy/Nfp/Sound/Decimal.lean | 253 - Legacy/Nfp/Sound/Demo.lean | 103 - Legacy/Nfp/Sound/Fixed.lean | 400 - Legacy/Nfp/Sound/HeadCert.lean | 748 -- Legacy/Nfp/Sound/IO.lean | 654 -- Legacy/Nfp/Sound/Interval.lean | 448 - Legacy/Nfp/Sound/ModelHeader.lean | 167 - Legacy/Nfp/Sound/TextPure.lean | 313 - Legacy/Nfp/Uniqueness.lean | 97 - Legacy/Nfp/Untrusted/SoundBinary.lean | 141 - Legacy/Nfp/Untrusted/SoundCacheIO.lean | 256 - Legacy/Nfp/Untrusted/SoundCompute.lean | 8588 -------------------- Legacy/Nfp/Verification.lean | 399 - README.md | 2 +- 46 files changed, 1 insertion(+), 35473 deletions(-) delete mode 100644 Legacy/Nfp/Abstraction.lean delete mode 100644 Legacy/Nfp/Appendix.lean delete mode 100644 Legacy/Nfp/Attribution.lean delete mode 100644 Legacy/Nfp/Discovery.lean delete mode 100644 Legacy/Nfp/IO.lean delete mode 100644 Legacy/Nfp/IO/Pure.lean delete mode 100644 Legacy/Nfp/Induction.lean delete mode 100644 Legacy/Nfp/Influence.lean delete mode 100644 Legacy/Nfp/Layers.lean delete mode 100644 Legacy/Nfp/Linearization.lean delete mode 100644 Legacy/Nfp/MixerLocalSystem.lean delete mode 100644 Legacy/Nfp/PCC.lean delete mode 100644 Legacy/Nfp/Reroute/Heat.lean delete mode 100644 Legacy/Nfp/Reroute/Partition.lean delete mode 100644 Legacy/Nfp/SignedMixer.lean delete mode 100644 Legacy/Nfp/Sound/Activation.lean delete mode 100644 Legacy/Nfp/Sound/Affine.lean delete mode 100644 Legacy/Nfp/Sound/BinaryPure.lean delete mode 100644 Legacy/Nfp/Sound/Bounds.lean delete mode 100644 Legacy/Nfp/Sound/Bounds/Attention.lean delete mode 100644 Legacy/Nfp/Sound/Bounds/Basic.lean delete mode 100644 Legacy/Nfp/Sound/Bounds/Effort.lean 
delete mode 100644 Legacy/Nfp/Sound/Bounds/Exp.lean delete mode 100644 Legacy/Nfp/Sound/Bounds/Gelu.lean delete mode 100644 Legacy/Nfp/Sound/Bounds/LayerNorm.lean delete mode 100644 Legacy/Nfp/Sound/Bounds/MatrixNorm.lean delete mode 100644 Legacy/Nfp/Sound/Bounds/Portfolio.lean delete mode 100644 Legacy/Nfp/Sound/Bounds/Softmax.lean delete mode 100644 Legacy/Nfp/Sound/Bridge.lean delete mode 100644 Legacy/Nfp/Sound/CachePure.lean delete mode 100644 Legacy/Nfp/Sound/Cert.lean delete mode 100644 Legacy/Nfp/Sound/Decimal.lean delete mode 100644 Legacy/Nfp/Sound/Demo.lean delete mode 100644 Legacy/Nfp/Sound/Fixed.lean delete mode 100644 Legacy/Nfp/Sound/HeadCert.lean delete mode 100644 Legacy/Nfp/Sound/IO.lean delete mode 100644 Legacy/Nfp/Sound/Interval.lean delete mode 100644 Legacy/Nfp/Sound/ModelHeader.lean delete mode 100644 Legacy/Nfp/Sound/TextPure.lean delete mode 100644 Legacy/Nfp/Uniqueness.lean delete mode 100644 Legacy/Nfp/Untrusted/SoundBinary.lean delete mode 100644 Legacy/Nfp/Untrusted/SoundCacheIO.lean delete mode 100644 Legacy/Nfp/Untrusted/SoundCompute.lean delete mode 100644 Legacy/Nfp/Verification.lean diff --git a/AGENTS.md b/AGENTS.md index c8ff86e..921e6cc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -322,9 +322,6 @@ but you **must** update this list in the same commit. - `Nfp.lean` - Top-level reexports and axioms dashboard (`#print axioms`). -### 5.7 Legacy (tabula rasa transition) -- Legacy modules live under `Legacy/Nfp/` as reference only and are not built by default. 
- If you introduce a new conceptual layer: - either extend the closest existing file, - or add a new module with a clear name + top docstring, diff --git a/Legacy/Nfp/Abstraction.lean b/Legacy/Nfp/Abstraction.lean deleted file mode 100644 index 486cb92..0000000 --- a/Legacy/Nfp/Abstraction.lean +++ /dev/null @@ -1,424 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Set.Basic -import Mathlib.Logic.Equiv.Defs -import Mathlib.Data.Fin.Basic -import Nfp.Linearization -import Nfp.Uniqueness - -/-! -# Causal Consistency of Circuit Abstractions - -This module bridges the gap between **real neural network computations** -(`DeepLinearization`, containing weights and Jacobians) and **abstract causal graphs** -(`LocalSystem`, containing topology and mixing coefficients). - -## Main Results - -1. **Projection**: `DeepLinearization.toLocalSystem` extracts a causal DAG from a - network's `DeepValueTerm` (the "Attention Rollout" component). - -2. **Interventions**: `SignedMixer.ablate` and `DeepLinearization.ablate` formalize - node removal / path zeroing interventions. - -3. **Causal Consistency Theorem**: If the `DeepPatternTerm` (linearization error) - is bounded by ε, then the real network's output under ablation matches the - `LocalSystem`'s prediction within O(ε). - -## Significance - -This transforms the library from a descriptive tool into a **verification engine**. -Practitioners can input real model weights and receive a mathematical certificate -that a discovered "induction head" or "circuit" is a **genuine mechanism** and not -an interpretability illusion. 
- -The key insight is: -- `LocalSystem` computes via: T(i) = Σ_{u ∈ Pa(i)} c(i,u) · T(u) -- `DeepValueTerm` approximates the Jacobian via attention flow -- When `DeepPatternTerm` is small, interventions on the abstract graph - accurately predict interventions on the real network. --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Signed Mixer Ablation -/ - -section SignedMixerAblation - -variable {S T : Type*} [Fintype S] [Fintype T] - -/-- Ablate (zero out) specific source positions in a SignedMixer. - -This models the intervention "what if we remove position i from contributing?" -Used for causal intervention analysis. -/ -noncomputable def SignedMixer.ablate (M : SignedMixer S T) (blocked : Set S) - [DecidablePred blocked] : SignedMixer S T where - w := fun i j => if blocked i then 0 else M.w i j - -@[simp] lemma SignedMixer.ablate_blocked (M : SignedMixer S T) (blocked : Set S) - [DecidablePred blocked] {i : S} (hi : blocked i) (j : T) : - (M.ablate blocked).w i j = 0 := by - simp [SignedMixer.ablate, hi] - -@[simp] lemma SignedMixer.ablate_unblocked (M : SignedMixer S T) (blocked : Set S) - [DecidablePred blocked] {i : S} (hi : ¬blocked i) (j : T) : - (M.ablate blocked).w i j = M.w i j := by - simp [SignedMixer.ablate, hi] - -/-- The effect of ablation on application to a vector. -/ -theorem SignedMixer.apply_ablate (M : SignedMixer S T) (blocked : Set S) - [DecidablePred blocked] (v : S → ℝ) (j : T) : - (M.ablate blocked).apply v j = ∑ i : S, if blocked i then 0 else v i * M.w i j := by - simp only [SignedMixer.apply_def, SignedMixer.ablate] - apply Finset.sum_congr rfl - intro i _ - split_ifs <;> ring - -/-- Ablation decomposes application into blocked and unblocked contributions. 
-/ -theorem SignedMixer.apply_ablate_decomposition (M : SignedMixer S T) - (blocked : Set S) [DecidablePred blocked] (v : S → ℝ) (j : T) : - M.apply v j = (M.ablate blocked).apply v j + - ∑ i : S, if blocked i then v i * M.w i j else 0 := by - simp only [SignedMixer.apply_def, SignedMixer.ablate] - rw [← Finset.sum_add_distrib] - apply Finset.sum_congr rfl - intro i _ - split_ifs <;> ring - -end SignedMixerAblation - -/-! ## Deep Linearization Ablation -/ - -section DeepLinearizationAblation - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- Ablate the `DeepValueTerm` by zeroing out contributions from blocked positions. -/ -noncomputable def DeepLinearization.ablateValueTerm - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] : - SignedMixer (n × d) (n × d) := - (DeepValueTerm D).ablate blocked - -/-- Ablate the full `composedJacobian` by zeroing out contributions from blocked positions. -/ -noncomputable def DeepLinearization.ablateJacobian - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] : - SignedMixer (n × d) (n × d) := - D.composedJacobian.ablate blocked - -/-- The difference between ablating the full Jacobian vs the value term. -/ -noncomputable def DeepLinearization.ablationError - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] : - SignedMixer (n × d) (n × d) := - D.ablateJacobian blocked - D.ablateValueTerm blocked - -/-- Ablation error equals ablated pattern term. 
-/ -theorem DeepLinearization.ablationError_eq_ablatedPatternTerm - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] : - D.ablationError blocked = (DeepPatternTerm D).ablate blocked := by - ext i j - simp only [ablationError, ablateJacobian, ablateValueTerm, SignedMixer.sub_w, - SignedMixer.ablate, DeepPatternTerm] - split_ifs with h - · simp - · rfl - -end DeepLinearizationAblation - -/-! ## Projection to LocalSystem -/ - -section Projection - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- Extract an NNReal coefficient from the absolute value of a real weight. - -This is used when projecting a `SignedMixer` (which has real weights) to a -`LocalSystem` (which uses NNReal coefficients for mixing). The absolute value -captures the "influence magnitude" regardless of sign. -/ -noncomputable def absToNNReal (x : ℝ) : NNReal := ⟨|x|, abs_nonneg x⟩ - -/-- A position-level signed mixer extracted from collapsing the dimension axes. -/ -noncomputable def positionMixer (M : SignedMixer (n × d) (n × d)) : - SignedMixer n n where - w := fun i j => ∑ di : d, ∑ dj : d, M.w (i, di) (j, dj) - -/-- The magnitude of position-to-position flow (for LocalSystem coefficients). -/ -noncomputable def positionFlowMagnitude (M : SignedMixer (n × d) (n × d)) - (i j : n) : NNReal := - absToNNReal ((positionMixer M).w i j) - -/-- Extract a `LocalSystem` from a `DeepLinearization` using its `DeepValueTerm`. - -The resulting graph has: -- Nodes: positions (type `n`) numbered by `e : n ≃ Fin k` -- Parents: positions with nonzero attention flow -- Coefficients: magnitude of position-to-position value term flow - -This represents the "attention rollout" approximation as a causal DAG. 
-/ -noncomputable def DeepLinearization.toLocalSystem - (D : DeepLinearization (n := n) (d := d)) - {k : ℕ} (e : n ≃ Fin k) - (acyclic : ∀ i j : n, (positionMixer (DeepValueTerm D)).w i j ≠ 0 → e i < e j) : - LocalSystem k := by - classical - let posOf : Fin k → n := e.symm - exact { - Pa := fun idx => - Finset.univ.filter fun u : Fin k => - (positionMixer (DeepValueTerm D)).w (posOf u) (posOf idx) ≠ 0 - c := fun idx u => - positionFlowMagnitude (DeepValueTerm D) (posOf u) (posOf idx) - topo := by - intro idx u hu - have hmem := Finset.mem_filter.mp hu - have hweight : (positionMixer (DeepValueTerm D)).w (posOf u) (posOf idx) ≠ 0 := hmem.2 - have htopo : e (posOf u) < e (posOf idx) := acyclic _ _ hweight - simpa [posOf] using htopo - } - -/-- The parent set of position `i` in the extracted LocalSystem. -/ -theorem DeepLinearization.toLocalSystem_Pa - (D : DeepLinearization (n := n) (d := d)) - {k : ℕ} (e : n ≃ Fin k) - (acyclic : ∀ i j : n, (positionMixer (DeepValueTerm D)).w i j ≠ 0 → e i < e j) - (idx : Fin k) : - (D.toLocalSystem e acyclic).Pa idx = - Finset.univ.filter fun u : Fin k => - (positionMixer (DeepValueTerm D)).w (e.symm u) (e.symm idx) ≠ 0 := rfl - -/-- The coefficient for parent `u` of position `idx` in the extracted LocalSystem. -/ -theorem DeepLinearization.toLocalSystem_c - (D : DeepLinearization (n := n) (d := d)) - {k : ℕ} (e : n ≃ Fin k) - (acyclic : ∀ i j : n, (positionMixer (DeepValueTerm D)).w i j ≠ 0 → e i < e j) - (idx u : Fin k) : - (D.toLocalSystem e acyclic).c idx u = - positionFlowMagnitude (DeepValueTerm D) (e.symm u) (e.symm idx) := rfl - -end Projection - -/-! ## Causal Consistency -/ - -section CausalConsistency - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- The error norm when comparing ablated computations. 
-/ -noncomputable def ablationDiscrepancy - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) (j : n × d) : ℝ := - |(D.ablationError blocked).apply v j| - -/-- **Causal Consistency Bound**: The ablation discrepancy is bounded by the -pattern term's influence on the input. - -If position `i` is blocked, the discrepancy at output `j` is bounded by: - |ablated_real - ablated_abstract| ≤ Σ_{i ∉ blocked} |v_i| · |PatternTerm_{i,j}| - -This shows that when the pattern term is small, interventions on the abstract -`LocalSystem` accurately predict interventions on the real network. -/ -theorem causal_consistency_bound - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) (j : n × d) : - ablationDiscrepancy D blocked v j ≤ - ∑ i : n × d, if blocked i then 0 else |v i| * |(DeepPatternTerm D).w i j| := by - simp only [ablationDiscrepancy] - -- The key insight: ablation error = ablated pattern term - have h := D.ablationError_eq_ablatedPatternTerm blocked - calc |(D.ablationError blocked).apply v j| - = |((DeepPatternTerm D).ablate blocked).apply v j| := by - rw [h] - _ = |∑ i : n × d, if blocked i then 0 else v i * (DeepPatternTerm D).w i j| := by - simp only [SignedMixer.apply_ablate] - _ ≤ ∑ i : n × d, |if blocked i then 0 else v i * (DeepPatternTerm D).w i j| := - abs_sum_le_sum_abs _ _ - _ = ∑ i : n × d, if blocked i then 0 else |v i| * |(DeepPatternTerm D).w i j| := by - apply Finset.sum_congr rfl - intro i _ - by_cases hb : blocked i - · simp [hb] - · simp [hb, abs_mul] - -/-- **Simplified Frobenius-style bound**: - -The total squared ablation discrepancy is bounded by the product of -input energy and pattern term Frobenius norm squared. - -This is a cleaner statement that captures the key O(ε) relationship. 
-/ -theorem causal_consistency_frobenius_simple - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) : - ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 ≤ - (∑ i : n × d, (v i) ^ 2) * (∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2) := by - -- We use a direct bound via the pointwise bounds - calc ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 - ≤ ∑ j : n × d, (∑ i : n × d, |v i| * |(DeepPatternTerm D).w i j|) ^ 2 := by - apply Finset.sum_le_sum - intro j _ - apply sq_le_sq' - · have hnn := Finset.sum_nonneg - (fun i (_ : i ∈ Finset.univ) => mul_nonneg (abs_nonneg (v i)) - (abs_nonneg ((DeepPatternTerm D).w i j))) - calc -(∑ i : n × d, |v i| * |(DeepPatternTerm D).w i j|) - ≤ 0 := neg_nonpos.mpr hnn - _ ≤ ablationDiscrepancy D blocked v j := abs_nonneg _ - · have hbound := causal_consistency_bound D blocked v j - refine le_trans hbound ?_ - apply Finset.sum_le_sum - intro i _ - by_cases hb : blocked i - · have hnonneg : - 0 ≤ |v i| * |(DeepPatternTerm D).w i j| := - mul_nonneg (abs_nonneg _) (abs_nonneg _) - simpa [hb] using hnonneg - · simp [hb] - _ ≤ ∑ j : n × d, (∑ i : n × d, (v i) ^ 2) * (∑ i : n × d, ((DeepPatternTerm D).w i j) ^ 2) := by - apply Finset.sum_le_sum - intro j _ - -- Cauchy-Schwarz: (Σ ab)² ≤ (Σ a²)(Σ b²) - have cs : (∑ i : n × d, |v i| * |(DeepPatternTerm D).w i j|) ^ 2 ≤ - (∑ i : n × d, (|v i|) ^ 2) * (∑ i : n × d, (|(DeepPatternTerm D).w i j|) ^ 2) := - Finset.sum_mul_sq_le_sq_mul_sq Finset.univ (fun i => |v i|) - (fun i => |(DeepPatternTerm D).w i j|) - simp only [sq_abs] at cs - exact cs - _ = (∑ i : n × d, (v i) ^ 2) * (∑ j : n × d, ∑ i : n × d, ((DeepPatternTerm D).w i j) ^ 2) := by - rw [Finset.mul_sum] - _ = (∑ i : n × d, (v i) ^ 2) * (∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2) := by - congr 1 - rw [Finset.sum_comm] - -/-- **Main Causal Consistency Theorem**: - -If a network's `DeepPatternTerm` has Frobenius norm bounded by ε, then 
-interventions (ablations) on the extracted `LocalSystem` predict the -real network's behavior within O(ε) error. - -Specifically: Σⱼ (discrepancy_j)² ≤ ε² · Σᵢ (vᵢ)² - -This is the key result that turns the library into a verification engine: -- Small pattern term → attention rollout is faithful -- Faithful rollout → discovered circuits are genuine mechanisms -- Genuine mechanisms → interventions have predictable effects -/ -theorem causal_consistency_frobenius - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) - (ε : ℝ) (hε : frobeniusNorm (DeepPatternTerm D) ≤ ε) : - ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 ≤ - ε ^ 2 * (∑ i : n × d, (v i) ^ 2) := by - have h1 := causal_consistency_frobenius_simple D blocked v - have h2 : ∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2 ≤ ε ^ 2 := by - calc ∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2 - = (frobeniusNorm (DeepPatternTerm D)) ^ 2 := by - simp only [frobeniusNorm] - rw [Real.sq_sqrt] - exact Finset.sum_nonneg (fun i _ => Finset.sum_nonneg (fun j _ => sq_nonneg _)) - _ ≤ ε ^ 2 := by - apply sq_le_sq' - · calc -ε ≤ 0 := by - by_contra hne - push_neg at hne - have : frobeniusNorm (DeepPatternTerm D) ≤ ε := hε - have hpos : 0 ≤ frobeniusNorm (DeepPatternTerm D) := by - simp only [frobeniusNorm] - exact Real.sqrt_nonneg _ - linarith - _ ≤ frobeniusNorm (DeepPatternTerm D) := by - simp only [frobeniusNorm] - exact Real.sqrt_nonneg _ - · exact hε - calc ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 - ≤ (∑ i : n × d, (v i) ^ 2) * (∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2) := h1 - _ ≤ (∑ i : n × d, (v i) ^ 2) * ε ^ 2 := by - apply mul_le_mul_of_nonneg_left h2 - exact Finset.sum_nonneg (fun i _ => sq_nonneg _) - _ = ε ^ 2 * (∑ i : n × d, (v i) ^ 2) := by ring - -end CausalConsistency - -/-! 
## Mechanism Certification -/ - -section MechanismCertification - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- A circuit is **causally certified** if its pattern term error is below threshold. - -This means interventions on the abstract `LocalSystem` derived from the circuit -will accurately predict the real network's behavior. -/ -def isCausallyCertified (D : DeepLinearization (n := n) (d := d)) (threshold : ℝ) : Prop := - frobeniusNorm (DeepPatternTerm D) ≤ threshold - -/-- A certified circuit's ablations are faithful within the error bound. -/ -theorem certified_ablation_faithful - (D : DeepLinearization (n := n) (d := d)) - (threshold : ℝ) - (hcert : isCausallyCertified D threshold) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) : - ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 ≤ - threshold ^ 2 * (∑ i : n × d, (v i) ^ 2) := - causal_consistency_frobenius D blocked v threshold hcert - -/-- A mechanism discovered via interpretability is **verified** if: -1. The extracted `LocalSystem` has the expected structure (e.g., induction head pattern) -2. The circuit is causally certified (small pattern term) - -When both hold, the mechanism is a genuine causal explanation of the network's behavior, -not an interpretability illusion. -/ -structure VerifiedMechanism (D : DeepLinearization (n := n) (d := d)) where - /-- The certification threshold -/ - threshold : ℝ - /-- The threshold is positive -/ - threshold_pos : 0 < threshold - /-- The circuit meets the certification bound -/ - certified : isCausallyCertified D threshold - /-- Description of the discovered mechanism (for documentation) -/ - description : String - -/-- Any verified mechanism satisfies causal consistency. 
-/ -theorem VerifiedMechanism.causal_consistency - {D : DeepLinearization (n := n) (d := d)} - (M : VerifiedMechanism D) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) : - ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 ≤ - M.threshold ^ 2 * (∑ i : n × d, (v i) ^ 2) := - certified_ablation_faithful D M.threshold M.certified blocked v - -/-- The RMS discrepancy is bounded by threshold times RMS input. -/ -theorem VerifiedMechanism.rms_bound - {D : DeepLinearization (n := n) (d := d)} - (M : VerifiedMechanism D) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) : - Real.sqrt (∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2) ≤ - M.threshold * Real.sqrt (∑ i : n × d, (v i) ^ 2) := by - have h := M.causal_consistency blocked v - have hpos : 0 ≤ ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 := - Finset.sum_nonneg (fun j _ => sq_nonneg _) - calc Real.sqrt (∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2) - ≤ Real.sqrt (M.threshold ^ 2 * (∑ i : n × d, (v i) ^ 2)) := Real.sqrt_le_sqrt h - _ = |M.threshold| * Real.sqrt (∑ i : n × d, (v i) ^ 2) := by - rw [Real.sqrt_mul (sq_nonneg _), Real.sqrt_sq_eq_abs] - _ = M.threshold * Real.sqrt (∑ i : n × d, (v i) ^ 2) := by - rw [abs_of_pos M.threshold_pos] - -end MechanismCertification - -end Nfp diff --git a/Legacy/Nfp/Appendix.lean b/Legacy/Nfp/Appendix.lean deleted file mode 100644 index e927a0d..0000000 --- a/Legacy/Nfp/Appendix.lean +++ /dev/null @@ -1,956 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Indicator -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Mathlib.Order.MinMax -import Mathlib.Data.Finset.Lattice.Basic -import Nfp.Prob -import Nfp.PCC -import Nfp.Uniqueness -import Nfp.Mixer -import Nfp.Influence -import Nfp.Reroute.Partition -import 
Nfp.Reroute.Heat - -namespace Nfp - -open scoped BigOperators -open Finset - -namespace List - -lemma take_succ_drop_append {α : Type*} : - ∀ {xs : List α} {k : ℕ}, - xs.take (k + 1) = xs.take k ++ (xs.drop k).take 1 - | [], 0 => by simp - | [], k+1 => by simp - | _ :: _, 0 => by simp - | x :: xs, k + 1 => by - have ih := - take_succ_drop_append (xs:=xs) (k:=k) - simp [ih] - -end List - -/-! -# Appendix A mapping (Lean formalization) - -This file collects the statements used to mirror the Appendix A results: - -- Appendix A.4 (PCC curve): `normMass`, `feasible`, `pccArg`, `pccMax`, - `PCC`, monotonicity (`PCC_monotone`), and the dominance/upper-bound lemmas - (`pccMax_dominates`, `PCC_upper_bounds_masks`), plus the existence/greedy view - `greedy_topmass_optimal`. -- Appendix A.3 (Residual rule): `lambdaEC`, `lambdaEC_sum_one`, - `residual_lambda_from_norm`, and the global characterization - `lambdaEC_scale_invariant_global`. -- Appendix A.2 (Normalization on a subset): `normalizeOn` and its properties - (`normalizeOn_outside`, `normalizeOn_inside`, `normalizeOn_sum_one`), and the - uniqueness theorem `proportional_row_unique`. - -All proofs avoid `sorry`. Some results use classical reasoning (`classical`). --/ - -section PCC - -variable {S : Type*} [Fintype S] - -variable (m : Nfp.Contrib S) - -lemma drop_eq_mass (A : Finset S) : - (A.sum m) / (∑ i, m i) = (A.sum (fun i => m i / (∑ j, m j))) := by - classical - -- Division distributes over finite sums in `NNReal`. - simp [Finset.sum_div] - -/-- Appendix A.4 (upper bound): any mask with tracer budget ≤ τ removes at most τ -of normalized logit. -/ -lemma pcc_upper_bound_for_any_mask (A : Finset S) (τ : NNReal) - (hA : (A.sum (fun i => m i / (∑ j, m j))) ≤ τ) : - (A.sum m) / (∑ j, m j) ≤ τ := by - simpa [drop_eq_mass (m:=m) (A:=A)] using hA - -/-- Appendix A.4 (monotonicity helper): normalized removed mass is monotone along -any nested mask chain. 
-/ -lemma normalized_sum_monotone (A : ℕ → Finset S) - (hchain : ∀ k, A k ⊆ A (k + 1)) : - Monotone (fun k => (A k).sum (fun i => m i / (∑ j, m j))) := by - classical - -- Apply monotonicity of raw sums to the pointwise scaled weights. - have := Nfp.sum_monotone_chain (A:=A) (w:=fun i => m i / (∑ j, m j)) hchain - simpa using this - -/-- Appendix A.4: normalized mass of a mask `A`. -/ -noncomputable def normMass (m : S → NNReal) (A : Finset S) : NNReal := - (A.sum m) / (∑ j, m j) - -@[simp] lemma normMass_empty (m : S → NNReal) : normMass m (∅ : Finset S) = 0 := by - simp [normMass] - -lemma normMass_union [DecidableEq S] (m : S → NNReal) (A B : Finset S) - (hdisj : Disjoint A B) : - normMass m (A ∪ B) = normMass m A + normMass m B := by - classical - have hsum := Finset.sum_union hdisj (f := m) - simp [normMass, hsum, add_div] - -lemma normMass_partitionUnion [DecidableEq S] (m : S → NNReal) - (parts : List (Finset S)) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) : - normMass m (unionParts parts) = - parts.foldr (fun A acc => normMass m A + acc) 0 := by - classical - induction parts with - | nil => - simp [unionParts] - | cons A parts ih => - rcases List.pairwise_cons.mp hpair with ⟨hA, htail⟩ - have hdisj : Disjoint A (unionParts parts) := - disjoint_unionParts A parts (by - intro B hB - exact hA _ (by simpa using hB)) - have hAdd := - normMass_union (m:=m) A (unionParts parts) hdisj - simp [unionParts, hAdd, ih htail, List.foldr] - -lemma normMass_partition_eq_one [DecidableEq S] (m : S → NNReal) - (parts : List (Finset S)) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) - (hcover : unionParts parts = (Finset.univ : Finset S)) - (htot : (∑ j, m j) ≠ 0) : - parts.foldr (fun A acc => normMass m A + acc) 0 = 1 := by - classical - have hsum : - parts.foldr (fun A acc => normMass m A + acc) 0 = normMass m (unionParts parts) := by - simpa using (normMass_partitionUnion (m:=m) parts hpair).symm - simpa [hcover, normMass, htot] using hsum - -namespace 
ReroutePlan - -section - -variable {S : Type*} [Fintype S] [DecidableEq S] - -lemma normMass_sum_one (m : S → NNReal) (P : ReroutePlan (S := S)) - (htot : (∑ j, m j) ≠ 0) : - P.masks.foldr (fun A acc => normMass m A + acc) 0 = 1 := by - simpa [ReroutePlan.masks] using - normMass_partition_eq_one (S := S) m P.masks P.pairwise_disjoint P.covers_univ htot - -end - -end ReroutePlan - -lemma normMass_rerouteHeat_increment {S : Type*} [Fintype S] [DecidableEq S] - (P : WeightedReroutePlan (S := S)) {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ P.plan.increments.zip P.weights) : - normMass (fun i => (P.rerouteHeat).mass i) A = w / P.weightsSum := by - classical - have hsum := WeightedReroutePlan.rerouteHeat_sum_increment (P:=P) hmem - have hdenom : (∑ i, (P.rerouteHeat).mass i) = 1 := ProbVec.sum_mass _ - have hsum' : - (∑ i ∈ A, P.heatRaw i / P.weightsSum) = w / P.weightsSum := by - simpa [WeightedReroutePlan.rerouteHeat_mass] using hsum - have hdenom' : - (∑ i, P.heatRaw i / P.weightsSum) = 1 := by - simpa [WeightedReroutePlan.rerouteHeat_mass] using hdenom - simp [normMass, hsum', hdenom'] - - -/-- Appendix A.4: feasible masks under budget `τ` (by normalized mass). -/ -noncomputable def feasible (m : S → NNReal) (τ : NNReal) : Finset (Finset S) := - (Finset.univ : Finset S).powerset.filter (fun A => normMass m A ≤ τ) - -lemma feasible_nonempty (m : S → NNReal) (τ : NNReal) : - (feasible (S:=S) m τ).Nonempty := by - classical - refine ⟨∅, ?_⟩ - simp [feasible, normMass] - --- Appendix A.4: argmax mask under budget τ w.r.t. normalized mass. 
-noncomputable def pccArg (m : S → NNReal) (τ : NNReal) : Finset S := by - classical - exact Classical.choose - (Finset.exists_max_image (feasible (S:=S) m τ) (fun A => normMass m A) - (feasible_nonempty (S:=S) m τ)) - -private lemma pccArg_mem (m : S → NNReal) (τ : NNReal) : - pccArg (S:=S) m τ ∈ feasible (S:=S) m τ := by - classical - have h := Classical.choose_spec - (Finset.exists_max_image (feasible (S:=S) m τ) (fun A => normMass m A) - (feasible_nonempty (S:=S) m τ)) - exact h.left - -private lemma pccArg_is_max (m : S → NNReal) (τ : NNReal) : - ∀ B ∈ feasible (S:=S) m τ, normMass m B ≤ normMass m (pccArg (S:=S) m τ) := by - classical - have h := Classical.choose_spec - (Finset.exists_max_image (feasible (S:=S) m τ) (fun A => normMass m A) - (feasible_nonempty (S:=S) m τ)) - exact h.right - -/-- Appendix A.4: maximal normalized drop (the PCC value at budget `τ`). -/ -noncomputable def pccMax (m : S → NNReal) (τ : NNReal) : NNReal := - normMass m (pccArg (S:=S) m τ) - -lemma pccMax_le_tau (m : S → NNReal) (τ : NNReal) : - pccMax (S:=S) m τ ≤ τ := by - classical - have hmem := pccArg_mem (S:=S) m τ - simpa [pccMax] using (Finset.mem_filter.mp hmem).2 - -lemma pccMax_monotone (m : S → NNReal) : - Monotone (fun τ : NNReal => pccMax (S:=S) m τ) := by - classical - intro τ₁ τ₂ hle - -- Any τ₁-feasible set is τ₂-feasible when τ₁ ≤ τ₂. - have hsubset : feasible (S:=S) m τ₁ ⊆ feasible (S:=S) m τ₂ := by - intro A hA - rcases Finset.mem_filter.mp hA with ⟨hpow, hbudget⟩ - have hbudget' : normMass m A ≤ τ₂ := hbudget.trans hle - exact (Finset.mem_filter.mpr ⟨hpow, hbudget'⟩) - -- The τ₂-argmax is ≥ any τ₁-feasible value, in particular the τ₁-argmax. 
- have hAτ1 := pccArg_mem (S:=S) m τ₁ - have hAτ1_in : pccArg (S:=S) m τ₁ ∈ feasible (S:=S) m τ₂ := hsubset hAτ1 - have hmaxτ2 := pccArg_is_max (S:=S) m τ₂ (pccArg (S:=S) m τ₁) hAτ1_in - simpa [pccMax] using hmaxτ2 - - -/-- Appendix A.4 (dominance/upper bound): for any mask `B` that removes at most -`τ` of tracer mass (normalized), its normalized drop is upper-bounded by `pccMax m τ`. -Equivalently, `pccMax` is the supremum over all feasible masks at budget `τ`. -/ -lemma pccMax_dominates (m : S → NNReal) (τ : NNReal) (B : Finset S) - (hB : normMass m B ≤ τ) : - normMass m B ≤ pccMax (S:=S) m τ := by - classical - -- Feasibility of B at budget τ - have hB_mem : B ∈ feasible (S:=S) m τ := by - have hpow : B ∈ (Finset.univ : Finset S).powerset := by - simp - exact (Finset.mem_filter.mpr ⟨hpow, hB⟩) - -- Maximality of the arg at τ - have hmax := pccArg_is_max (S:=S) m τ B hB_mem - simpa [pccMax] using hmax - -/-- Appendix A.4 (PCC curve): normalized logit drop achievable at tracer-mass -budget `t`. This matches the budgeted maximum `pccMax` over feasible masks. -/ -noncomputable abbrev PCC (m : S → NNReal) (t : NNReal) : NNReal := pccMax (S:=S) m t - -@[simp] lemma PCC_def (m : S → NNReal) (t : NNReal) : PCC (S:=S) m t = pccMax (S:=S) m t := rfl - -/-- Appendix A.4 (monotonicity): the PCC curve is monotone in the budget `t`. -/ -lemma PCC_monotone (m : S → NNReal) : Monotone (fun t : NNReal => PCC (S:=S) m t) := by - simpa [PCC_def] using pccMax_monotone (S:=S) m - -/-- Appendix A.4 (upper bound, mask phrasing): for any mask `B` that -preserves at least `1 - t` tracer mass (i.e., removes ≤ `t`), the normalized -logit drop of `B` is at most `PCC m t`. 
-/ -lemma PCC_upper_bounds_masks (m : S → NNReal) (t : NNReal) (B : Finset S) - (hBudget : normMass m B ≤ t) : - normMass m B ≤ PCC (S:=S) m t := by - simpa [PCC_def] using pccMax_dominates (S:=S) m t B hBudget - -lemma rerouteHeat_pcc_drop {S : Type*} [Fintype S] [DecidableEq S] - (P : WeightedReroutePlan (S := S)) - {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ P.plan.increments.zip P.weights) : - PCC (S:=S) (fun i => (P.rerouteHeat).mass i) (w / P.weightsSum) = w / P.weightsSum := by - classical - set m := fun i => (P.rerouteHeat).mass i - have hnorm' : normMass m A = w / P.weightsSum := by - simpa [m] using normMass_rerouteHeat_increment (P:=P) hmem - have hfeas : normMass m A ≤ w / P.weightsSum := by - simp [hnorm'] - have hdom := pccMax_dominates (S:=S) m (w / P.weightsSum) A hfeas - have hupper : PCC (S:=S) m (w / P.weightsSum) ≤ w / P.weightsSum := by - simpa [PCC_def] using pccMax_le_tau (S:=S) m (w / P.weightsSum) - have hdom' : w / P.weightsSum ≤ pccMax m (w / P.weightsSum) := by - simpa [hnorm'] using hdom - have hlower : w / P.weightsSum ≤ PCC (S:=S) m (w / P.weightsSum) := by - simpa [PCC_def] using hdom' - exact le_antisymm hupper hlower - -/-- Convenience corollary: the `k`-th reroute increment (mask/weight pair) -from a weighted plan achieves equality with the PCC budget that matches its -normalized weight. This version avoids spelunking through `zip` membership -and is tailored for the exported segment lists used in the Python notebooks. 
-/ -lemma rerouteHeat_pcc_drop_index {S : Type*} [Fintype S] [DecidableEq S] - (P : WeightedReroutePlan (S := S)) {k : Nat} - (hk : k < P.plan.increments.length) : - PCC (S:=S) (fun i => (P.rerouteHeat).mass i) - (P.weights.get ⟨k, by simpa [P.length_eq_increments] using hk⟩ / P.weightsSum) - = P.weights.get ⟨k, by simpa [P.length_eq_increments] using hk⟩ / P.weightsSum := by - classical - have hlen : P.plan.increments.length = P.weights.length := (P.length_eq_increments).symm - have hw : k < P.weights.length := hlen ▸ hk - have hmem : - (P.plan.increments.get ⟨k, hk⟩, - P.weights.get ⟨k, hw⟩) ∈ P.plan.increments.zip P.weights := by - simpa [hw, hlen] using - List.get_mem_zip (xs:=P.plan.increments) (ys:=P.weights) - (k:=k) hk hw - have h := rerouteHeat_pcc_drop (P:=P) (hmem := hmem) - exact h - -namespace WeightedReroutePlan - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- Stage-4 prefix objective: restrict the AUC evaluation to the first `k` -intervals induced by `weights`. -/ -noncomputable def rerouteHeatObjectivePrefix - (P : WeightedReroutePlan (S := S)) (k : ℕ) : NNReal := - PCC.evalFromWeights - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - P.weightsSum (P.weights.take k) - -/-- Objective used in Stage 4: the PCC AUC achieved by the reroute plan’s -`rerouteHeat` distribution, evaluated via the generic weights-to-interval map. 
-/ -noncomputable def rerouteHeatObjective (P : WeightedReroutePlan (S := S)) : NNReal := - PCC.evalFromWeights - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - P.weightsSum P.weights - -lemma rerouteHeatObjectivePrefix_as_auc (P : WeightedReroutePlan (S := S)) (k : ℕ) : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals).take k) - = P.rerouteHeatObjectivePrefix k := by - classical - simpa [rerouteHeatObjectivePrefix, aucIntervals] using - PCC.AUC_intervalsFromWeights_take - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - P.weightsSum P.weights k - -lemma rerouteHeatObjective_as_auc (P : WeightedReroutePlan (S := S)) : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals) - = P.rerouteHeatObjective := by - simp [rerouteHeatObjective, auc_eval] - -lemma rerouteHeatObjectivePrefix_le (P : WeightedReroutePlan (S := S)) (k : ℕ) : - P.rerouteHeatObjectivePrefix k ≤ P.rerouteHeatObjective := by - classical - have hmono := - PCC.AUC_monotone_append - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) (P.aucIntervals.drop k) - have hsplit := List.take_append_drop k P.aucIntervals - have hle : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) - ≤ - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals) := by - simpa [hsplit] using hmono - have hprefix := (rerouteHeatObjectivePrefix_as_auc (P:=P) (k:=k)).symm - have htotal := (rerouteHeatObjective_as_auc (P:=P)).symm - simpa [hprefix, htotal] using hle - -lemma rerouteHeatObjective_split (P : WeightedReroutePlan (S := S)) (k : ℕ) : - P.rerouteHeatObjective = - P.rerouteHeatObjectivePrefix k + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.drop k) := by - classical - have h := - PCC.AUC_append - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - 
(P.aucIntervals.take k) (P.aucIntervals.drop k) - have hsplit := List.take_append_drop k P.aucIntervals - have hsum : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals) - = - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.drop k) := by - simpa [hsplit] using h - have hprefix := (rerouteHeatObjectivePrefix_as_auc (P:=P) (k:=k)).symm - have htotal := (rerouteHeatObjective_as_auc (P:=P)).symm - simpa [hprefix, htotal] using hsum - -/-- Prefix addition lemma: the `(k+1)`-st prefix equals the `k`-prefix plus the -contribution coming from the `(k+1)`-st interval (if present). -/ -lemma rerouteHeatObjectivePrefix_succ (P : WeightedReroutePlan (S := S)) (k : ℕ) : - P.rerouteHeatObjectivePrefix (k + 1) = - P.rerouteHeatObjectivePrefix k + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := by - classical - have hsum : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take (k + 1)) - = - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := by - simpa [List.take_succ_drop_append (xs:=P.aucIntervals) (k:=k)] using - PCC.AUC_append - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) - ((P.aucIntervals.drop k).take 1) - have hprefix := rerouteHeatObjectivePrefix_as_auc (P:=P) (k:=k) - have hprefix_succ := rerouteHeatObjectivePrefix_as_auc (P:=P) (k:=k+1) - have hsum' := hsum - rw [hprefix_succ] at hsum' - have hsum'' := hsum' - rw [hprefix] at hsum'' - exact hsum'' - -lemma rerouteHeatObjectivePrefix_succ_le (P : WeightedReroutePlan (S := S)) (k : ℕ) : - P.rerouteHeatObjectivePrefix k ≤ 
P.rerouteHeatObjectivePrefix (k + 1) := by - classical - have h := rerouteHeatObjectivePrefix_succ (P:=P) (k:=k) - have hnonneg : - 0 ≤ - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := - PCC.AUC_nonneg - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) - have hle : - P.rerouteHeatObjectivePrefix k ≤ - P.rerouteHeatObjectivePrefix k + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := - le_add_of_nonneg_right hnonneg - calc - P.rerouteHeatObjectivePrefix k - ≤ P.rerouteHeatObjectivePrefix k + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := hle - _ = P.rerouteHeatObjectivePrefix (k + 1) := by - simp [h] - -lemma rerouteHeatObjectivePrefix_mono (P : WeightedReroutePlan (S := S)) - {k₁ k₂ : ℕ} (hle : k₁ ≤ k₂) : - P.rerouteHeatObjectivePrefix k₁ ≤ P.rerouteHeatObjectivePrefix k₂ := by - classical - have hmono : - ∀ d k, P.rerouteHeatObjectivePrefix k ≤ P.rerouteHeatObjectivePrefix (k + d) := by - intro d - induction d with - | zero => - intro k - simp - | succ d ih => - intro k - have hsucc := - rerouteHeatObjectivePrefix_succ_le - (P:=P) (k:=k + d) - have := ih k - have := this.trans hsucc - simpa [Nat.add_comm, Nat.add_left_comm, Nat.add_assoc] using this - obtain ⟨d, rfl⟩ := Nat.exists_eq_add_of_le hle - simpa using hmono d k₁ - -/-- Feasibility predicate for Stage 4: each incremental mask carries the exact -normalized mass dictated by its weight ratio. 
-/ -def rerouteFeasible (m : S → NNReal) (P : WeightedReroutePlan (S := S)) : Prop := - ∀ {A : Finset S} {w : NNReal}, - (A, w) ∈ P.plan.increments.zip P.weights → - normMass m A = w / P.weightsSum - -lemma rerouteHeat_feasible (P : WeightedReroutePlan (S := S)) : - rerouteFeasible (S:=S) (fun i => (P.rerouteHeat).mass i) P := by - intro A w hmem - simpa using normMass_rerouteHeat_increment (P:=P) (hmem:=hmem) - -/-- Stage-4 goal statement (stub): a “delta-weighted” optimal plan maximizes -`rerouteHeatObjective` among feasible plans. Proof to be supplied once the -majorization lemmas are in place. -/ -def optimal_delta_weighting_statement (m : S → NNReal) : Prop := - ∃ P : WeightedReroutePlan (S := S), - rerouteFeasible (S:=S) m P ∧ - ∀ Q : WeightedReroutePlan (S := S), - rerouteFeasible (S:=S) m Q → - Q.rerouteHeatObjective ≤ P.rerouteHeatObjective - -end WeightedReroutePlan - - -/-- Appendix A.4 (greedy optimality for mass budgets): there exists a -subset `A` (a feasible mask under budget `τ`) that maximizes normalized -removed mass among all feasible masks. Concretely we can take `A = pccArg m τ`. -This ties the constructive “pick a best feasible mask” view to our -`pccArg/pccMax` definitions; ties are handled by the argmax choice. -/ -lemma greedy_topmass_optimal (m : S → NNReal) (τ : NNReal) : - ∃ A : Finset S, - A ∈ feasible (S:=S) m τ ∧ - (∀ B : Finset S, B ∈ feasible (S:=S) m τ → - normMass m B ≤ normMass m A) := by - classical - refine ⟨pccArg (S:=S) m τ, ?_, ?_⟩ - · exact pccArg_mem (S:=S) m τ - · intro B hB - exact pccArg_is_max (S:=S) m τ B hB - - -/-- Greedy top-k optimality (helper for A.4): among all subsets with cardinality ≤ k, -there exists -an optimal set A maximizing `A.sum m` such that no outside element has larger weight -than an inside element (swap-optimality). This captures the compact majorization-style -“top-k is optimal” property without constructing an explicit sort. 
-/ -lemma greedy_topk_optimal (m : S → NNReal) (k : ℕ) : - ∃ A : Finset S, - A.card ≤ k ∧ - (∀ B : Finset S, B.card ≤ k → B.sum m ≤ A.sum m) ∧ - (∀ {i j}, i ∈ A → j ∉ A → m i ≥ m j) := by - classical - -- Candidates: all subsets of `univ` with card ≤ k. - let C : Finset (Finset S) := (Finset.univ : Finset S).powerset.filter (fun A => A.card ≤ k) - have hC_nonempty : C.Nonempty := by - refine ⟨∅, ?_⟩ - simp [C] - -- Pick A ∈ C maximizing the sum. - obtain ⟨A, hA_mem, hAmax⟩ := Finset.exists_max_image C (fun A : Finset S => A.sum m) hC_nonempty - -- Basic facts about A - have hA_card_le : A.card ≤ k := by - rcases Finset.mem_filter.mp hA_mem with ⟨hA_ps, hk⟩ - exact hk - -- Show optimality among all B with card ≤ k. - have hA_opt : ∀ B : Finset S, B.card ≤ k → B.sum m ≤ A.sum m := by - intro B hBk - have hB_mem : B ∈ C := by - -- Any B ⊆ univ is in the powerset, and `B.card ≤ k` ensures it passes the filter. - simp [C, hBk] - exact hAmax B hB_mem - -- Swap optimality: no beneficial swap exists. 
- have hswap : ∀ {i j}, i ∈ A → j ∉ A → m i ≥ m j := by - intro i j hiA hjA - by_contra hij - have hij' : m i < m j := lt_of_not_ge hij - -- Construct the swapped set B = insert j (A.erase i) - let B := insert j (A.erase i) - have hj_not : j ∉ A.erase i := by - simp [Finset.mem_erase, hjA] - have hcard_eq : B.card = A.card := by - have hIns : insert i (A.erase i) = A := by simpa using Finset.insert_erase hiA - have hi_not : i ∉ A.erase i := by simp - have hA_card' : (A.erase i).card + 1 = A.card := by - -- From card(insert i (erase i A)) = card(erase i A) + 1 - calc - (A.erase i).card + 1 = (insert i (A.erase i)).card := by - simp [Finset.card_insert_of_notMem hi_not] - _ = A.card := by simp [hIns] - have hB_card : B.card = (A.erase i).card + 1 := by - simp [B, Finset.card_insert_of_notMem hj_not] - simpa [hB_card] using hA_card' - have hB_card_le : B.card ≤ k := by simpa [hcard_eq] using hA_card_le - -- Compare sums via erase/insert decomposition - have hsumA : A.sum m = m i + (A.erase i).sum m := by - have hi_not : i ∉ A.erase i := by simp - have hIns : insert i (A.erase i) = A := by simpa using Finset.insert_erase hiA - have sumIns := Finset.sum_insert (s:=A.erase i) (a:=i) (f:=m) hi_not - simpa [hIns, add_comm] using sumIns - have hsumB : B.sum m = m j + (A.erase i).sum m := by - have sumInsB := Finset.sum_insert (s:=A.erase i) (a:=j) (f:=m) hj_not - simpa [B, add_comm] using sumInsB - have : A.sum m < B.sum m := by - have h := add_lt_add_right hij' ((A.erase i).sum m) - simpa [hsumA, hsumB, add_comm, add_left_comm, add_assoc] using h - have hB_le := hA_opt B hB_card_le - exact (not_le_of_gt this) hB_le - exact ⟨A, hA_card_le, hA_opt, hswap⟩ - -end PCC - -/-! Residual additions: energy-consistent coefficients -/ - -section Residual - -/-- Energy-consistent residual coefficient (Appendix A.3). 
-/ -noncomputable def lambdaEC (x a : NNReal) : NNReal := x / (x + a) - -lemma lambdaEC_sum_one (x a : NNReal) (hx : x + a ≠ 0) : - lambdaEC x a + a / (x + a) = 1 := by - unfold lambdaEC - have : x / (x + a) + a / (x + a) = (x + a) / (x + a) := by - simp [add_div] - simp [this, div_self hx] - -/-- If a residual coefficient `λ` satisfies `λ + a/(x+a) = 1` (with `x+a ≠ 0`), -then necessarily `λ = x/(x+a)` in `NNReal`. -/ -lemma residual_lambda_from_norm - (lcoeff : NNReal → NNReal → NNReal) - (x a : NNReal) (hx : x + a ≠ 0) - (hnorm : lcoeff x a + a / (x + a) = 1) : - lcoeff x a = x / (x + a) := by - -- Work in ℝ via coercions to use field operations. - have hR : (lcoeff x a : ℝ) + (a : ℝ) / (x + a) = 1 := by - simpa using congrArg (fun t : NNReal => (t : ℝ)) hnorm - have hxneR : (x + a : ℝ) ≠ 0 := by exact_mod_cast hx - have hval : (lcoeff x a : ℝ) = x / (x + a) := by - -- From hR: l + a/(x+a) = 1 ⇒ l = 1 - a/(x+a) = x/(x+a) - have : (lcoeff x a : ℝ) = 1 - (a : ℝ) / (x + a) := by - exact (eq_sub_iff_add_eq).mpr hR - have hunit : (1 : ℝ) = (x + a) / (x + a) := by simp [div_self hxneR] - have hdiff : (x + a : ℝ) / (x + a) - (a : ℝ) / (x + a) = ((x + a : ℝ) - a) / (x + a) := by - simpa [sub_eq_add_neg] using (sub_div ((x + a : ℝ)) (a : ℝ) (x + a)).symm - have : (lcoeff x a : ℝ) = ((x + a : ℝ) - a) / (x + a) := by simp [this, hunit, hdiff] - simpa [sub_eq_add_neg, add_comm] - -- Conclude equality in NNReal by extensionality on coercions to ℝ. - apply Subtype.ext - simpa using hval - -/-- Global characterization: any residual rule which is pointwise row-normalized -(`lcoeff x a + a/(x+a) = 1` whenever `x+a ≠ 0`) must be `x/(x+a)` (Appendix A.3). -The Appendix mentions scale-invariance for motivation, but the normalization -equation alone determines the coefficients, so we keep the statement minimal. 
-/ -lemma lambdaEC_scale_invariant_global - (lcoeff : NNReal → NNReal → NNReal) - (hnorm : ∀ (x a : NNReal), x + a ≠ 0 → lcoeff x a + a / (x + a) = 1) : - ∀ (x a : NNReal), x + a ≠ 0 → lcoeff x a = x / (x + a) := by - intro x a hx - exact residual_lambda_from_norm lcoeff x a hx (hnorm x a hx) - -end Residual - -/-! Normalization uniqueness for local mixers -/ - -section Normalize - -variable {S : Type*} [DecidableEq S] - -/-- Appendix A.2: normalize nonnegative weights on a finite subset `A` to a probability row. -/ -noncomputable def normalizeOn (A : Finset S) (w : S → NNReal) : S → NNReal := - fun i => if i ∈ A then w i / (A.sum w) else 0 - -/-- Appendix A.2: outside the support `A`, `normalizeOn A w` is zero. -/ -@[simp] lemma normalizeOn_outside (A : Finset S) (w : S → NNReal) {i : S} (hi : i ∉ A) : - normalizeOn (S:=S) A w i = 0 := by - classical - simp [normalizeOn, hi] - -/-- Appendix A.2: inside the support `A`, `normalizeOn A w` is proportional to `w`. -/ -@[simp] lemma normalizeOn_inside (A : Finset S) (w : S → NNReal) {i : S} (hi : i ∈ A) : - normalizeOn (S:=S) A w i = w i / (A.sum w) := by - classical - simp [normalizeOn, hi] - -/-- Appendix A.2: `normalizeOn A w` is a probability row (sums to 1). -/ -lemma normalizeOn_sum_one [Fintype S] (A : Finset S) (w : S → NNReal) (h : A.sum w ≠ 0) : - (∑ i, normalizeOn (S:=S) A w i) = 1 := by - classical - -- Sum over `univ` of an indicator-style function equals the sum over `A`. 
- have hsumA := - (Finset.sum_indicator_subset (s:=A) (t:=(Finset.univ : Finset S)) - (f:=fun i => w i / (A.sum w)) (Finset.subset_univ A)) - -- simplify the indicator - have hsumA' : (∑ i, normalizeOn (S:=S) A w i) = A.sum (fun i => w i / (A.sum w)) := by - convert hsumA using 1 - simp [normalizeOn, Set.indicator] - have hsumA'' : A.sum (fun i => w i / (A.sum w)) = (A.sum w) / (A.sum w) := by - simp [Finset.sum_div] - have hsumA1 : A.sum (fun i => w i / (A.sum w)) = 1 := by - simpa [div_self h] using hsumA'' - simpa [hsumA'] using hsumA1 - --- (To be extended) A local uniqueness lemma will connect proportional rows supported on `A` --- and the `normalizeOn` construction via the budget/sum constraint. - -/-- Appendix A.2 (uniqueness): if `p` is supported on `A`, -and proportional to `w` on `A` with total mass 1, then `p` is exactly `normalizeOn A w`. -We also derive the proportionality constant as `1 / A.sum w` using the mass constraint. -/ -lemma proportional_row_unique [Fintype S] - (A : Finset S) (w p : S → NNReal) (k : NNReal) - (hAw : A.sum w ≠ 0) - (hout : ∀ i ∉ A, p i = 0) - (hin : ∀ i ∈ A, p i = k * w i) - (hsum : (∑ i, p i) = 1) : - p = normalizeOn (S:=S) A w := by - classical - -- Represent `p` as an indicator-style function using `hin`/`hout`. - have hrepr : (fun i => if i ∈ A then k * w i else 0) = p := by - funext i - by_cases hi : i ∈ A - · simp [hi, hin i hi] - · simp [hi, hout i hi] - -- The total mass constraint identifies `k`. - -- Sum of indicator over `univ` reduces to a sum over `A`. - have hsumA0 := - (Finset.sum_indicator_subset (s:=A) (t:=(Finset.univ : Finset S)) - (f:=fun i => k * w i) (Finset.subset_univ A)) - have hsumA0' : - (∑ i, ((↑A : Set S).indicator (fun i => k * w i) i : NNReal)) - = A.sum (fun i => k * w i) := by - simpa using hsumA0 - -- Identify the indicator function with the `if-then-else` representation and then with `p`. 
- have hind : (fun i => ((↑A : Set S).indicator (fun i => k * w i) i : NNReal)) - = (fun i => if i ∈ A then k * w i else 0) := by - funext i; simp [Set.indicator] - have hfun : (fun i => ((↑A : Set S).indicator (fun i => k * w i) i : NNReal)) = p := by - simpa [hrepr] using hind - have hsumA : (∑ i, p i) = A.sum (fun i => k * w i) := by - have : (∑ i, p i) = (∑ i, ((↑A : Set S).indicator (fun i => k * w i) i : NNReal)) := by - simp [hfun] - exact this.trans hsumA0' - have hsumA1 : A.sum (fun i => k * w i) = 1 := by simpa [hsumA] using hsum - -- Move to ℝ to use field-style algebra and cancel the nonzero sum. - have hsumA1R : (A.sum (fun i => ((k : ℝ) * (w i : ℝ)))) = 1 := by - simpa using congrArg (fun t : NNReal => (t : ℝ)) hsumA1 - have hR : (k : ℝ) * (((A.sum w : NNReal) : ℝ)) = 1 := by - -- factor out the constant from the sum and identify the coerced sum - simpa [Finset.mul_sum] using hsumA1R - have hAwR : (((A.sum w : NNReal) : ℝ)) ≠ 0 := by exact_mod_cast hAw - have k_eq : (k : ℝ) = 1 / (((A.sum w : NNReal) : ℝ)) := by - -- from k * S = 1 - exact (eq_div_iff_mul_eq hAwR).2 (by simpa [mul_comm] using hR) - -- Conclude pointwise equality with normalizeOn. 
- apply funext; intro i; by_cases hiA : i ∈ A - · -- inside A: p i = k * w i = (1 / sum) * w i = w i / sum - have : p i = k * w i := hin i hiA - have : (p i : ℝ) = (k : ℝ) * (w i : ℝ) := by - simpa using congrArg (fun t : NNReal => (t : ℝ)) this - have : (p i : ℝ) = (1 / (((A.sum w : NNReal) : ℝ))) * (w i : ℝ) := by - simpa [k_eq, mul_comm, mul_left_comm, mul_assoc] - using this - -- convert back to NNReal - apply Subtype.ext - -- use commutativity to write as division - have : (p i : ℝ) = (w i : ℝ) / (((A.sum w : NNReal) : ℝ)) := by - simpa [div_eq_mul_inv, mul_comm] using this - -- now identify with normalizeOn - have hnorm : (normalizeOn (S:=S) A w i : ℝ) = (w i : ℝ) / (((A.sum w : NNReal) : ℝ)) := by - -- simplify normalizeOn since i ∈ A - have := normalizeOn_inside (S:=S) A w (i:=i) hiA - -- coerce and rewrite - simp [this] - simpa [hnorm] - · -- outside A: both are 0 - have : p i = 0 := hout i hiA - simp [normalizeOn_outside (S:=S) A w hiA, this] - -end Normalize - -end Nfp - - -/-! -## Consolidated Theorem A.1 (wrapper statements) - -We collect three core ingredients used in Appendix A.1’s uniqueness story as -named wrappers. These are restatements (aliases) of lemmas proved above or in -`Nfp.Uniqueness` so downstream developments can refer to a single place. - -- Residual coefficient uniqueness from row-normalization. -- Normalization uniqueness for proportional rows supported on `A`. -- Global uniqueness of tracer families for a linear local system on a finite DAG. - -These wrappers introduce no new proof obligations; they simply expose the -relevant facts under Appendix A.1’s umbrella. --/ - -namespace Nfp - -/-! Residual coefficient uniqueness (Appendix A.3 inside A.1) -/ - -theorem A1_residual_unique - (lcoeff : NNReal → NNReal → NNReal) - (x a : NNReal) (hx : x + a ≠ 0) - (hnorm : lcoeff x a + a / (x + a) = 1) : - lcoeff x a = x / (x + a) := - residual_lambda_from_norm lcoeff x a hx hnorm - -/-! 
Normalization uniqueness on a support (Appendix A.2 inside A.1) -/ - -theorem A1_normalize_unique - {S : Type*} [Fintype S] [DecidableEq S] - (A : Finset S) (w p : S → NNReal) (k : NNReal) - (hAw : A.sum w ≠ 0) - (hout : ∀ i ∉ A, p i = 0) - (hin : ∀ i ∈ A, p i = k * w i) - (hsum : (∑ i, p i) = 1) : - p = normalizeOn (S:=S) A w := - proportional_row_unique (S:=S) A w p k hAw hout hin hsum - -/-! Global linear uniqueness on a DAG (Appendix A.1 via `LocalSystem`) -/ - -theorem A1_global_tracer_unique - {S : Type*} {n : ℕ} - (L : LocalSystem n) - {T T' : LocalSystem.TracerFamily (S := S) n} - (hT : LocalSystem.Satisfies (S := S) L T) - (hT' : LocalSystem.Satisfies (S := S) L T') : - T = T' := - LocalSystem.tracer_unique (S:=S) L hT hT' - -/-! -## Packaged Appendix A.1 theorem - -We bundle the three core uniqueness components of Appendix A.1 into a single -exported statement `A1` returning a triple of universally quantified facts: - -- residual coefficient uniqueness from row-normalization (A.3 used in A.1), -- normalization uniqueness on a given support (A.2 used in A.1), -- global tracer uniqueness for a linear local system on a finite DAG (A.1). --/ - -/-- Appendix A.1 (packaged): a conjunction of the three core uniqueness results -used in the Appendix narrative. Each component is universally quantified over -its parameters and follows directly from the dedicated wrapper theorems above. 
-/ -theorem A1 : - (∀ (lcoeff : NNReal → NNReal → NNReal) (x a : NNReal), - x + a ≠ 0 → lcoeff x a + a / (x + a) = 1 → - lcoeff x a = x / (x + a)) ∧ - (∀ {S : Type*} [Fintype S] [DecidableEq S] - (A : Finset S) (w p : S → NNReal) (k : NNReal), - A.sum w ≠ 0 → - (∀ i ∉ A, p i = 0) → - (∀ i ∈ A, p i = k * w i) → - (∑ i, p i) = 1 → - p = normalizeOn (S:=S) A w) ∧ - (∀ {S : Type*} {n : ℕ} - (L : LocalSystem n) - {T T' : LocalSystem.TracerFamily (S := S) n}, - LocalSystem.Satisfies (S := S) L T → - LocalSystem.Satisfies (S := S) L T' → - T = T') := by - refine ⟨?residual, ?normalize, ?global⟩ - · intro lcoeff x a hx hnorm; exact A1_residual_unique lcoeff x a hx hnorm - · intro S _ _ A w p k hAw hout hin hsum; - exact A1_normalize_unique (S:=S) A w p k hAw hout hin hsum - · intro S n L T T' hT hT'; exact A1_global_tracer_unique (S:=S) L hT hT' - -end Nfp - -/-! -## Appendix A.2 (axioms-as-theorems wrappers) - -We expose the A.2 "axioms" as formal properties derived from our basic -constructions. Rather than assuming them as axioms, we state them as theorems -that hold for our generic primitives: - -- Graph faithfulness: composition preserves support restrictions. -- Residual additivity: the energy-consistent coefficients sum to 1 and are - uniquely determined by row-normalization. -- Row-normalization: `normalizeOn` produces rows that sum to 1 over a finite - support. - -These are lightweight restatements of previously proven lemmas so the Appendix -can reference stable names. --/ - -namespace Nfp - -/-! Graph faithfulness under composition (Axiom 1) -/ - -/- -Appendix A.2 (Graph-faithfulness under composition). -If mixers `M : S → T` and `N : T → U` are supported by relations `R` and `Q` -respectively, then their composition is supported by the composed relation. -This is a wrapper around `Mixer.supported_comp` so Appendix A can refer to a -stable name. See docs/APPENDIX.md §A.2. 
--/ -theorem A2_graph_faithful_comp - {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - (M : Mixer S T) (N : Mixer T U) - (R : S → T → Prop) (Q : T → U → Prop) - (hM : Mixer.supported (S := S) (T := T) M R) - (hN : Mixer.supported (S := T) (T := U) N Q) : - Mixer.supported (S := S) (T := U) (M.comp N) (Mixer.compSupport R Q) := - Mixer.supported_comp (S := S) (T := T) (U := U) (M := M) (N := N) hM hN - -/-! Residual additivity and energy-consistency (Axiom 2) -/ - -/- -Appendix A.2 (Residual additivity / energy consistency). -For residual coefficient `λ := lambdaEC x a = x/(x+a)`, we have -`λ + a/(x+a) = 1`. See docs/APPENDIX.md §A.2. --/ -theorem A2_residual_energy_consistent (x a : NNReal) (hx : x + a ≠ 0) : - lambdaEC x a + a / (x + a) = 1 := - lambdaEC_sum_one x a hx - -/- -Appendix A.2 (Residual uniqueness from row-normalization). -Any residual rule that satisfies the row-normalization equation is uniquely -determined as `x/(x+a)`. Wrapper of `residual_lambda_from_norm`. -See docs/APPENDIX.md §A.2. --/ -theorem A2_residual_unique - (lcoeff : NNReal → NNReal → NNReal) - (x a : NNReal) (hx : x + a ≠ 0) - (hnorm : lcoeff x a + a / (x + a) = 1) : - lcoeff x a = x / (x + a) := - residual_lambda_from_norm lcoeff x a hx hnorm - -/-! Row-normalization on a finite support (part of Axioms 1/3/6) -/ - -/- -Appendix A.2 (Row-normalization on a finite support). -The `normalizeOn` construction sums to 1 provided the total mass on `A` is -nonzero. Wrapper of `normalizeOn_sum_one`. See docs/APPENDIX.md §A.2. --/ -theorem A2_normalize_row_sum_one {S : Type*} [Fintype S] [DecidableEq S] - (A : Finset S) (w : S → NNReal) (h : A.sum w ≠ 0) : - (∑ i, normalizeOn (S := S) A w i) = 1 := - normalizeOn_sum_one (S := S) A w h - -/-! -### Packaged Appendix A.2 statement - -We bundle three generic A.2 facets into a single theorem `A2`: - -1. Residual uniqueness from row-normalization at a residual node. -2. 
Row-normalization of proportional weights on a finite support via `normalizeOn`. -3. Preservation of support/faithfulness under composition of mixers. - -This mirrors the intention that the "axioms" are consequences of our formal -definitions rather than assumptions. --/ - -/- -Appendix A.2 (packaged statement). -Conjunction of: residual uniqueness, row-normalization via `normalizeOn`, and -graph-faithfulness of mixer composition. See docs/APPENDIX.md §A.2. --/ -theorem A2 : - (∀ (lcoeff : NNReal → NNReal → NNReal) (x a : NNReal), - x + a ≠ 0 → lcoeff x a + a / (x + a) = 1 → - lcoeff x a = x / (x + a)) ∧ - (∀ {S : Type*} [Fintype S] [DecidableEq S] - (A : Finset S) (w : S → NNReal), - A.sum w ≠ 0 → (∑ i, normalizeOn (S := S) A w i) = 1) ∧ - (∀ {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - (M : Mixer S T) (N : Mixer T U) - (R : S → T → Prop) (Q : T → U → Prop), - Mixer.supported (S := S) (T := T) M R → - Mixer.supported (S := T) (T := U) N Q → - Mixer.supported (S := S) (T := U) (M.comp N) (Mixer.compSupport R Q)) := by - refine ⟨?resid, ?row, ?faith⟩ - · intro lcoeff x a hx hnorm; exact A2_residual_unique lcoeff x a hx hnorm - · intro S _ _ A w h; exact A2_normalize_row_sum_one (S := S) A w h - · intro S T U _ _ _ M N R Q hM hN; - exact A2_graph_faithful_comp (M:=M) (N:=N) R Q hM hN - -/-- Appendix A.2 compatibility: an `InfluenceSpec` with nonzero rows yields a -graph-faithful mixer via `ofInfluenceSpec`, ready to feed into `A2` results. 
-/ -lemma A2_for_ofInfluenceSpec {Site : Type*} [Fintype Site] [DecidableEq Site] - (I : InfluenceSpec Site) (hZ : ∀ s, InfluenceSpec.rowTotal (Site := Site) I s ≠ 0) : - Mixer.supported (S := Site) (T := Site) (Mixer.ofInfluenceSpec (Site := Site) I) I.adj := - Mixer.ofInfluenceSpec_supported (Site := Site) (I := I) hZ - -end Nfp diff --git a/Legacy/Nfp/Attribution.lean b/Legacy/Nfp/Attribution.lean deleted file mode 100644 index b672537..0000000 --- a/Legacy/Nfp/Attribution.lean +++ /dev/null @@ -1,237 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Nfp.Prob -import Nfp.Mixer - -/-! -# Attribution Axioms for Neural Network Interpretation - -This module formalizes key properties (axioms) that attribution methods for -neural networks should satisfy. These axioms are used to characterize and -compare different interpretation methods such as Integrated Gradients, LRP, -Shapley values, and path-based attribution. 
- -## Main definitions - -* `Attribution` – an attribution assigns a contribution score to each input feature -* `Completeness` – contributions sum to the output difference from baseline -* `Sensitivity` – if an input affects output, it receives nonzero attribution -* `Implementation Invariance` – attributions depend only on input-output behavior -* `Linearity` – attribution is linear in the function being explained - -## Key theorems - -* `completeness_of_conservation` – conservation mixers induce complete attributions -* `tracer_attribution_complete` – tracer-based attributions satisfy completeness - -## References - -* Sundararajan, Taly, Yan: "Axiomatic Attribution for Deep Networks" (ICML 2017) -* Ancona et al.: "Towards better understanding of gradient-based attribution" (2018) -* Lundstrom et al.: "A Rigorous Study of Integrated Gradients" (2022) --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Attribution structure -/ - -section Attribution - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] - -/-- An attribution method assigns a contribution score to each input feature -for a given output. Scores are signed reals to allow negative contributions. -/ -structure Attribution (Input Output : Type*) [Fintype Input] [Fintype Output] where - /-- The contribution of input `i` to output `o`. -/ - contrib : Input → Output → ℝ - -namespace Attribution - -variable (A : Attribution Input Output) - -/-- Total contribution to a specific output. -/ -noncomputable def totalContrib (o : Output) : ℝ := - ∑ i, A.contrib i o - -/-- An attribution is nonnegative if all contributions are nonnegative. -/ -def Nonneg : Prop := - ∀ i o, 0 ≤ A.contrib i o - -end Attribution - -end Attribution - -/-! ## Completeness axiom -/ - -section Completeness - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] - -/-- The completeness axiom: contributions sum to the difference between the -output at input `x` and the output at baseline `x₀`. 
This is the core axiom -from Integrated Gradients and Shapley value attribution. - -For neural networks: `f(x) - f(x₀) = ∑ᵢ attribution(i)` -/ -def Attribution.Complete - (A : Attribution Input Output) - (f : (Input → ℝ) → Output → ℝ) - (x x₀ : Input → ℝ) - (o : Output) : Prop := - A.totalContrib o = f x o - f x₀ o - -/-- Completeness for a single-output function (common case). -/ -def Attribution.CompleteScalar - (A : Attribution Input Unit) - (f : (Input → ℝ) → ℝ) - (x x₀ : Input → ℝ) : Prop := - A.totalContrib () = f x - f x₀ - -end Completeness - -/-! ## Sensitivity axiom -/ - -section Sensitivity - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] - -/-- An input feature `i` influences output `o` for function `f` at point `x` -relative to baseline `x₀` if changing just that feature changes the output. -/ -def Influences - (f : (Input → ℝ) → Output → ℝ) - (x x₀ : Input → ℝ) - (i : Input) - (o : Output) : Prop := - ∃ (x' : Input → ℝ), - (∀ j, j ≠ i → x' j = x₀ j) ∧ - x' i = x i ∧ - f x' o ≠ f x₀ o - -/-- The sensitivity axiom: if a feature influences the output, it receives -nonzero attribution. -/ -def Attribution.Sensitive - (A : Attribution Input Output) - (f : (Input → ℝ) → Output → ℝ) - (x x₀ : Input → ℝ) : Prop := - ∀ i o, Influences f x x₀ i o → A.contrib i o ≠ 0 - -/-- Dummy axiom: features that don't influence output get zero attribution. -/ -def Attribution.Dummy - (A : Attribution Input Output) - (f : (Input → ℝ) → Output → ℝ) - (x x₀ : Input → ℝ) : Prop := - ∀ i o, ¬ Influences f x x₀ i o → A.contrib i o = 0 - -end Sensitivity - -/-! ## Conservation and tracer-based attribution -/ - -section Conservation - -variable {S : Type*} [Fintype S] - -/-- A mixer `M` is mass-conserving if the total outgoing mass equals -the total incoming mass. For row-stochastic mixers, this is automatic. 
-/ -lemma Mixer.conserves_mass (M : Mixer S S) (p : ProbVec S) : - (∑ i, (M.push p).mass i) = ∑ i, p.mass i := by - simp only [ProbVec.sum_mass] - -/-- Attribution derived from a probability distribution (tracer mass). -/ -noncomputable def attributionOfProbVec [Fintype S] - (p : ProbVec S) : Attribution S Unit where - contrib := fun i _ => (p.mass i : ℝ) - -/-- Tracer-based attributions automatically satisfy completeness with respect -to total mass (which is 1). -/ -theorem tracer_attribution_complete (p : ProbVec S) : - (attributionOfProbVec p).totalContrib () = 1 := by - simp only [Attribution.totalContrib, attributionOfProbVec] - have h := p.norm_one - simp only [← NNReal.coe_sum] at h ⊢ - exact_mod_cast h - -end Conservation - -/-! ## Linearity axiom -/ - -section Linearity - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] - -/-- A method that produces attributions for any function. -/ -abbrev AttributionMethod (Input Output : Type*) [Fintype Input] [Fintype Output] := - ((Input → ℝ) → Output → ℝ) → (Input → ℝ) → (Input → ℝ) → Attribution Input Output - -/-- The linearity axiom: attribution of a sum of functions equals the sum -of attributions. -/ -def AttributionMethod.Linear (method : AttributionMethod Input Output) : Prop := - ∀ (f g : (Input → ℝ) → Output → ℝ) (x x₀ : Input → ℝ) (i : Input) (o : Output), - (method (fun inp => fun o => f inp o + g inp o) x x₀).contrib i o = - (method f x x₀).contrib i o + (method g x x₀).contrib i o - -/-- Scale invariance: scaling the function scales the attribution. -/ -def AttributionMethod.ScaleInvariant (method : AttributionMethod Input Output) : Prop := - ∀ (f : (Input → ℝ) → Output → ℝ) (c : ℝ) (x x₀ : Input → ℝ) (i : Input) (o : Output), - (method (fun inp => fun o => c * f inp o) x x₀).contrib i o = - c * (method f x x₀).contrib i o - -end Linearity - -/-! 
## Symmetry and efficiency -/ - -section Symmetry - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] [DecidableEq Input] - -/-- Swap two input coordinates in a feature vector. -/ -def swapInputs (x : Input → ℝ) (i j : Input) : Input → ℝ := - fun k => if k = i then x j else if k = j then x i else x k - -/-- Two inputs are symmetric for a function if swapping them preserves the output. -/ -def SymmetricInputs - (f : (Input → ℝ) → Output → ℝ) - (i j : Input) : Prop := - ∀ (x : Input → ℝ) (o : Output), - f x o = f (swapInputs x i j) o - -/-- The symmetry axiom: symmetric inputs receive equal attribution. -/ -def Attribution.Symmetric - (A : Attribution Input Output) - (f : (Input → ℝ) → Output → ℝ) : Prop := - ∀ i j o, SymmetricInputs f i j → A.contrib i o = A.contrib j o - -end Symmetry - -/-! ## Path-based attribution -/ - -section PathAttribution - -variable {Input : Type*} - -/-- A path from baseline `x₀` to input `x` parameterized by `t ∈ [0,1]`. -/ -abbrev Path (Input : Type*) := (t : NNReal) → Input → ℝ - -/-- A straight-line (linear interpolation) path between two points. -/ -noncomputable def straightPath (x x₀ : Input → ℝ) : Path Input := - fun t i => (1 - (t : ℝ)) * x₀ i + (t : ℝ) * x i - -/-- A valid path starts at baseline and ends at input. -/ -structure Path.Valid (γ : Path Input) (x x₀ : Input → ℝ) : Prop where - at_zero : ∀ i, γ 0 i = x₀ i - at_one : ∀ i, γ 1 i = x i - -lemma straightPath_valid (x x₀ : Input → ℝ) : (straightPath x x₀).Valid x x₀ := by - constructor - · intro i; simp [straightPath] - · intro i; simp [straightPath] - -end PathAttribution - -end Nfp diff --git a/Legacy/Nfp/Discovery.lean b/Legacy/Nfp/Discovery.lean deleted file mode 100644 index 5d48e25..0000000 --- a/Legacy/Nfp/Discovery.lean +++ /dev/null @@ -1,9929 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Batteries.Lean.Float -import Init.Data.Array.Extract - -/-! 
-# Executable Circuit Discovery for Induction Heads - -This module provides executable functions for discovering **candidate induction heads** -from concrete model weights. It bridges the theoretical framework (Frobenius norms, -pattern terms, faithfulness bounds) with practical, Float-based analysis of real networks. - -Important: these routines are **heuristic** and are not kernel-sound. Sound certification -lives in `Nfp.Sound.*`. - -## Key Components - -1. **Concrete Model Structures**: Computable representations of attention layers using - Arrays and Floats instead of abstract types. Suitable for export from PyTorch/JAX. - -2. **Efficient Bound Calculation**: Algorithms to compute `patternTerm` and `valueTerm` - bounds without materializing the full (N·D)² Jacobian matrix. - -3. **Discovery Functions**: Search algorithms that iterate over layer pairs to find - candidate virtual heads (e.g., induction heads). - -## Mathematical Background - -For an attention layer with weights (W_Q, W_K, W_V, W_O), the Jacobian at input x -decomposes as: `fullJacobian = valueTerm + patternTerm` where: -- `valueTerm` depends only on attention weights A and projections W_V·W_O -- `patternTerm` captures how A shifts when input changes (the error term) - -The **faithfulness bound** states: if ‖patternTerm‖_F ≤ ε, then the simple -attention-based interpretation is ε-accurate. In this module those bounds are -computed with `Float`s for speed and should be treated as diagnostics. - -## Performance Optimizations - -This module contains critical hot paths for circuit discovery. Key optimizations: - -1. **Array pre-allocation**: Use `Array.ofFn` instead of repeated `push` operations - - Avoids O(n²) copying from array reallocations - - Memory usage: O(n) instead of O(n²) during construction - -2. **Direct loops over List.range.foldl**: Replace `(List.range n).foldl f acc` with - `for i in [:n] do ...` - eliminates intermediate list construction (10-100× faster) - -3. 
**Bounds-checked array access**: Use `array[i]!` which panics on out-of-bounds - instead of `getD` which silently returns default values - - Makes bugs explicit rather than silent - - Compiler can optimize bounds checks in loops - -4. **Matrix operations**: Pre-allocated `Array.ofFn` with `Id.run do` blocks for - complex computations (matmul, matVecMul, power iteration) - -**Benchmark impact** (GPT-2 Small, 12 layers × 12 heads): -- Matrix operations: 10-50× faster (direct loops vs List.range.foldl) -- Array construction: 2-5× faster (pre-allocation vs repeated push) -- Memory: 50% reduction (no intermediate copies) - -These optimizations make circuit discovery practical on real models (seconds instead -of minutes for full network analysis). --/ - -namespace Nfp - -@[inline] private def sumSquares (xs : Array Float) : Float := Id.run do - let mut acc : Float := 0.0 - for x in xs do - acc := acc + x * x - return acc - -@[inline] private def sumFloatArray (xs : Array Float) : Float := Id.run do - let mut acc : Float := 0.0 - for x in xs do - acc := acc + x - return acc - -@[inline] private def sumNatArray (xs : Array Nat) : Nat := Id.run do - let mut acc : Nat := 0 - for x in xs do - acc := acc + x - return acc - -@[inline] private def sumSizes {α : Type} (chunks : Array (Array α)) : Nat := Id.run do - let mut acc : Nat := 0 - for cs in chunks do - acc := acc + cs.size - return acc - -@[inline] private def maxArray (xs : Array Float) : Float := Id.run do - let mut m : Float := 0.0 - for x in xs do - if x > m then - m := x - return m - -@[inline] private def countTrue (xs : Array Bool) : Nat := Id.run do - let mut acc : Nat := 0 - for b in xs do - if b then - acc := acc + 1 - return acc - -@[inline] private def countTrueNested (xs : Array (Array Bool)) : Nat := Id.run do - let mut acc : Nat := 0 - for row in xs do - acc := acc + countTrue row - return acc - -/-! ## Concrete Weight Representations -/ - -/-- A concrete weight matrix stored as nested Arrays. 
-This is the computable representation for export from PyTorch/JAX. -/ -structure ConcreteMatrix where - /-- Number of rows -/ - numRows : Nat - /-- Number of columns -/ - numCols : Nat - /-- Row-major data storage. data[i * numCols + j] = entry (i, j) -/ - data : Array Float - /-- Data has the correct size -/ - size_eq : data.size = numRows * numCols - -namespace ConcreteMatrix - -/-- Access element (i, j) of the matrix. Returns 0 if out of bounds. - -PERFORMANCE: This is the guarded accessor; hot loops should prefer `getUnsafe` -once bounds are established to avoid the per-access branch. --/ -def get (M : ConcreteMatrix) (i j : Nat) : Float := - if i < M.numRows ∧ j < M.numCols then - -- Index is in-bounds by `size_eq` and the guard above. - M.data[i * M.numCols + j]! - else 0.0 - -/-- Fast access to element `(i, j)` assuming `i < numRows` and `j < numCols`. -/ -@[inline] def getUnsafe (M : ConcreteMatrix) (i j : Nat) : Float := - M.data[i * M.numCols + j]! - -/-- Set element `(i,j)` of the matrix. If out of bounds, returns the original matrix. - -PERFORMANCE: This is intended for small matrices (e.g. `headDim×headDim` Grams) where copying the -underlying array is cheap. Prefer bulk construction for large matrices. --/ -def set (M : ConcreteMatrix) (i j : Nat) (val : Float) : ConcreteMatrix := - if h : i < M.numRows ∧ j < M.numCols then - let idx := i * M.numCols + j - { numRows := M.numRows - numCols := M.numCols - data := M.data.set! idx val - size_eq := by simpa [idx] using M.size_eq } - else - M - -/-- Maximum absolute entry in a given row. Returns 0 if the row is out of bounds. -/ -def rowMaxAbs (M : ConcreteMatrix) (r : Nat) : Float := - if r < M.numRows then - Id.run do - let mut m : Float := 0.0 - let rowBase := r * M.numCols - for c in [:M.numCols] do - let a := Float.abs (M.data[rowBase + c]!) - if a > m then - m := a - return m - else - 0.0 - -/-- Take the first `n` rows of a matrix (keeping all columns). 
-/ -def takeRows (M : ConcreteMatrix) (n : Nat) : ConcreteMatrix := - if h : n ≥ M.numRows then - M - else - let rowCount := n * M.numCols - { numRows := n - numCols := M.numCols - data := M.data.extract 0 rowCount - size_eq := by - have hrows : n ≤ M.numRows := Nat.le_of_lt (Nat.lt_of_not_ge h) - have hsize : rowCount ≤ M.data.size := by - simpa [rowCount, M.size_eq] using Nat.mul_le_mul_right M.numCols hrows - simpa [rowCount] using - (Array.size_extract_of_le (as := M.data) (i := 0) (j := rowCount) hsize) } - -/-- Create a zero matrix of given dimensions. -/ -def zeros (rows cols : Nat) : ConcreteMatrix where - numRows := rows - numCols := cols - data := .ofFn fun _ : Fin (rows * cols) => (0.0 : Float) - size_eq := Array.size_ofFn - -/-- Create an all-ones matrix of given dimensions. -/ -def ones (rows cols : Nat) : ConcreteMatrix where - numRows := rows - numCols := cols - data := .ofFn fun _ : Fin (rows * cols) => (1.0 : Float) - size_eq := Array.size_ofFn - -/-- Create an identity matrix. -/ -def identity (n : Nat) : ConcreteMatrix where - numRows := n - numCols := n - data := .ofFn fun idx : Fin (n * n) => - let i := idx.val / n - let j := idx.val % n - if i = j then 1.0 else 0.0 - size_eq := Array.size_ofFn - -/-- Matrix multiplication. - -PERFORMANCE CRITICAL: This is the hottest path in circuit discovery. -- Pre-allocates result array with `Array.ofFn` (no intermediate copies) -- Direct `for k in [:A.numCols]` instead of `List.range.foldl` (10-50× faster) -- Uses `Id.run do` to enable mutable accumulator in pure context -- Uses deterministic task-parallelism for very large products (preserving evaluation order - *within* each dot product). 
--/ -private def matmulSeqCore (A B : ConcreteMatrix) : ConcreteMatrix := - { - numRows := A.numRows - numCols := B.numCols - data := .ofFn fun idx : Fin (A.numRows * B.numCols) => Id.run do - let i := idx.val / B.numCols - let j := idx.val % B.numCols - let mut acc : Float := 0.0 - let aRowBase := i * A.numCols - for k in [:A.numCols] do - -- SAFETY: within this branch `i < A.numRows` and `k < A.numCols`, - -- and `A.size_eq` implies `aRowBase + k < A.data.size`. - let a := A.data[aRowBase + k]! - -- SAFETY: within this branch `k < B.numRows` and `j < B.numCols`, - -- and `B.size_eq` implies `k * B.numCols + j < B.data.size`. - let b := B.data[k * B.numCols + j]! - acc := acc + a * b - return acc - size_eq := Array.size_ofFn - } - -private def matmulParFlopThreshold : Nat := 10_000_000 -private def matmulParMaxTasks : Nat := 16 -private def matmulParMinInnerDim : Nat := 256 -private def matmulParMinOutCols : Nat := 256 -private def matmulParMinRows : Nat := 2 - -private def shouldUseMatmulPar (A B : ConcreteMatrix) : Bool := - let flops := A.numRows * B.numCols * A.numCols - flops ≥ matmulParFlopThreshold && - A.numCols ≥ matmulParMinInnerDim && - B.numCols ≥ matmulParMinOutCols && - A.numRows ≥ matmulParMinRows - -private def matmulPar (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numRows = 0 || B.numCols = 0 then - matmulSeqCore A B - else - let numTasks := min matmulParMaxTasks A.numRows - let q := A.numRows / numTasks - let r := A.numRows % numTasks - - let tasks : Array (Task (Array Float)) := - .ofFn fun t : Fin numTasks => - Task.spawn (fun _ => - let tid := t.val - let extra := if tid < r then 1 else 0 - let rowsHere := q + extra - let startRow := tid * q + min tid r - let chunkSize := rowsHere * B.numCols - .ofFn fun idx : Fin chunkSize => Id.run do - let localRow := idx.val / B.numCols - let j := idx.val % B.numCols - let i := startRow + localRow - let mut acc : Float := 0.0 - let aRowBase := i * A.numCols - for k in [:A.numCols] do - -- SAFETY: `i < 
A.numRows` by chunk construction and `k < A.numCols` by loop bound. - let a := A.data[aRowBase + k]! - -- SAFETY: `k < B.numRows` and `j < B.numCols` by loop bounds and chunk indexing. - let b := B.data[k * B.numCols + j]! - acc := acc + a * b - return acc) - - -- Join in increasing task index order (deterministic). - let chunks := tasks.map Task.get - let cutoff := (q + 1) * r - - { - numRows := A.numRows - numCols := B.numCols - data := .ofFn fun idx : Fin (A.numRows * B.numCols) => - let row := idx.val / B.numCols - let col := idx.val % B.numCols - let taskIdx := - if row < cutoff then - row / (q + 1) - else - r + (row - cutoff) / q - let localRow := - if row < cutoff then - row % (q + 1) - else - (row - cutoff) % q - -- SAFETY: `taskIdx < numTasks` by construction; chunks are in task order. - let chunk := chunks[taskIdx]! - -- SAFETY: `localRow < rowsHere` for this chunk and `col < B.numCols`. - chunk[localRow * B.numCols + col]! - size_eq := Array.size_ofFn - } - -def matmul (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numCols = B.numRows then - if shouldUseMatmulPar A B then - matmulPar A B - else - matmulSeqCore A B - else zeros 0 0 - -/-- Compute Frobenius norm squared: Σᵢⱼ M[i,j]². -/ -def frobeniusNormSq (M : ConcreteMatrix) : Float := - sumSquares M.data - -/-- Compute Frobenius norm: √(Σᵢⱼ M[i,j]²). -/ -def frobeniusNorm (M : ConcreteMatrix) : Float := - Float.sqrt M.frobeniusNormSq - -/-- Compute `trace(A · B)` without allocating the product. - -This uses `trace(A·B) = ∑_{i,j} A[i,j] · B[j,i]`. -Returns 0.0 if the dimensions do not line up. --/ -def traceMul (A B : ConcreteMatrix) : Float := Id.run do - if A.numCols ≠ B.numRows then return 0.0 - if A.numRows ≠ B.numCols then return 0.0 - if A.numRows ≠ A.numCols then return 0.0 - let n := A.numRows - let mut acc : Float := 0.0 - for i in [:n] do - let aRowBase := i * A.numCols - for j in [:n] do - -- SAFETY: `i,j < n` and both matrices are `n×n`. - acc := acc + A.data[aRowBase + j]! 
* B.data[j * B.numCols + i]! - return acc - -/-! ### Float numerics (heuristics) - -The definitions in this section use `Float` arithmetic for speed. - -Important: these are **not** kernel-sound upper bounds in general. -They are best-effort numerical estimates (rounding may under- or over-estimate). -Sound certification lives in `Nfp.Sound.*`. --/ - -/-! #### Non-certified estimates - -The functions in this subsection are **not** mathematically certified upper bounds. -They may be useful for diagnostics, but must not feed into the repo's “rigorous” -error / ε pipelines. --/ - -/-- Heuristic operator-norm estimate via power iteration. - -The operator norm ‖M‖₂ = max‖x‖=1 ‖x·M‖ (row-vector convention) is the largest singular value. -We approximate it using power iteration on M^T M. - -This is a fast **heuristic estimate** of how much `M` can stretch a vector. - -PERFORMANCE: Power iteration is O(iterations × n²) but heavily optimized: -- Pre-allocated vectors reused across iterations (`Array.replicate` + `set!`) -- Direct loops instead of `List.range.foldl` (10× faster) -- Bounds-checked access `v[j]!` and `Mv[i]!` (compiler optimizes in loops) --/ -def operatorNormHeuristicPI (M : ConcreteMatrix) (numIterations : Nat := 20) : Float := Id.run do - let numRows := M.numRows - let numCols := M.numCols - if numRows = 0 || numCols = 0 then return 0.0 - - -- Initialize with a normalized vector of ones (avoids a fold + map). 
- let initScale := 1.0 / Float.sqrt numCols.toFloat - let mut v : Array Float := Array.replicate numCols initScale - let mut Mv : Array Float := Array.replicate numRows 0.0 - let mut MTMv : Array Float := Array.replicate numCols 0.0 - - -- Power iteration: v ← (M^T M) v / ‖(M^T M) v‖ - let mut sigma : Float := 0.0 - for _ in [:numIterations] do - -- Compute M v - let mut mvNormSq : Float := 0.0 - for i in [:numRows] do - let mut acc : Float := 0.0 - let rowBase := i * numCols - for j in [:numCols] do - -- SAFETY: v has size M.numCols, guaranteed by Array.replicate. - acc := acc + M.data[rowBase + j]! * v[j]! - Mv := Mv.set! i acc - mvNormSq := mvNormSq + acc * acc - - -- Compute M^T (M v) = (M^T M) v - let mut mtmvNormSq : Float := 0.0 - for j in [:numCols] do - let mut acc : Float := 0.0 - for i in [:numRows] do - -- SAFETY: Mv has size M.numRows, guaranteed by Array.replicate. - acc := acc + M.data[i * numCols + j]! * Mv[i]! - MTMv := MTMv.set! j acc - mtmvNormSq := mtmvNormSq + acc * acc - - -- Compute norm of MTMv (this is σ² times ‖v‖, and ‖v‖ ≈ 1) - let norm := Float.sqrt mtmvNormSq - - if norm < 1e-15 then - return 0.0 - - -- σ² ≈ ‖MTMv‖ / ‖v‖ ≈ ‖MTMv‖ - -- So σ ≈ ‖Mv‖ - sigma := Float.sqrt mvNormSq - - -- Normalize for next iteration - for j in [:numCols] do - v := v.set! j (MTMv[j]! / norm) - - -- Heuristic safety margin for numerical errors - sigma * 1.01 - - -/-! ### Provable eigenvalue upper bounds (PSD moments) - -The helper below is used to tighten bounds of the form `λ_max(G)` where `G` is -symmetric positive semidefinite (PSD), using only the first two spectral moments. - -Let `λ₁ ≥ λ₂ ≥ ... ≥ λₙ ≥ 0` be the eigenvalues of `G`. -Write -- `tr = Σᵢ λᵢ = trace(G)` -- `f2 = Σᵢ λᵢ² = ‖G‖_F²`. - -Among all nonnegative spectra with fixed `tr` and `f2`, the maximum possible `λ₁` -is achieved when the remaining `n-1` eigenvalues are equal. Solving that extremal -case yields the closed-form upper bound: - - λ₁ ≤ (tr + sqrt((n-1) * (n*f2 - tr^2))) / n. 
- -We defensively clamp the radicand to `≥ 0` to avoid negative values caused by -Float roundoff. --/ - -/-- PSD moment bound on the maximum eigenvalue from trace and Frobenius-squared. - -Inputs: -- `n`: matrix dimension -- `tr = trace(G)` -- `f2 = ‖G‖_F²` - -Output: a deterministic `Float` expression corresponding to the real bound above. --/ -def psdLambdaMaxUpperMoment (n : Nat) (tr f2 : Float) : Float := - if n = 0 then - 0.0 - else if n = 1 then - -- For 1×1 PSD matrices, λ_max = tr. - max 0.0 tr - else - let nF : Float := n.toFloat - let rad := max 0.0 ((n - 1).toFloat * (nF * f2 - tr * tr)) - let root := Float.sqrt rad - max 0.0 ((tr + root) / nF) - - -/-- Estimate maximum absolute row sum (induced ℓ1 for row-vector convention, -induced ℓ∞ for column-vector convention). - -Mathematically, the real-valued quantity `maxᵢ Σⱼ |M[i,j]|` is an induced matrix norm. -This function computes a `Float` approximation and is therefore a heuristic estimate. --/ -def maxAbsRowSumEst (M : ConcreteMatrix) : Float := Id.run do - let mut maxSum : Float := 0.0 - for i in [:M.numRows] do - let mut rowSum : Float := 0.0 - let rowBase := i * M.numCols - for j in [:M.numCols] do - rowSum := rowSum + Float.abs (M.data[rowBase + j]!) - maxSum := max maxSum rowSum - maxSum - - -/-- Estimate maximum absolute column sum (induced ℓ∞ for row-vector convention, -induced ℓ1 for column-vector convention). - -Mathematically, the real-valued quantity `maxⱼ Σᵢ |M[i,j]|` is an induced matrix norm. -This function computes a `Float` approximation and is therefore a heuristic estimate. --/ -def maxAbsColSumEst (M : ConcreteMatrix) : Float := Id.run do - let mut maxSum : Float := 0.0 - for j in [:M.numCols] do - let mut colSum : Float := 0.0 - for i in [:M.numRows] do - let rowBase := i * M.numCols - colSum := colSum + Float.abs (M.data[rowBase + j]!) - maxSum := max maxSum colSum - maxSum - - -/-- Rigorous (inequality-based) one/inf upper bound on `‖M‖₂`. 
- -In exact real arithmetic: -`‖M‖₂ ≤ sqrt(‖M‖₁ · ‖M‖∞)`. - -We compute max row/column sums from the stored Float entries; `‖·‖₁`/`‖·‖∞` -swap under transpose, so the bound is convention-agnostic. --/ -def opNormUpperBoundOneInf (M : ConcreteMatrix) : Float := - Float.sqrt (M.maxAbsRowSumEst * M.maxAbsColSumEst) - -/-- Heuristic Schur-type estimate: `sqrt(‖M‖₁ · ‖M‖∞)` computed in `Float`. - -In exact real arithmetic, `sqrt(‖M‖₁ · ‖M‖∞)` upper-bounds the spectral norm. -This implementation uses `Float`, so it should be treated as an estimate. --/ -def schurNormEst (M : ConcreteMatrix) : Float := - M.opNormUpperBoundOneInf - -/-- Cheap operator-norm upper bound formula for a concrete real matrix. - -In exact real arithmetic, we have the standard inequalities: -- `‖M‖₂ ≤ ‖M‖_F` -- `‖M‖₂ ≤ sqrt(‖M‖₁ · ‖M‖∞)` - -In exact real arithmetic, taking `min` can only tighten a valid upper bound. -Here we compute both in `Float`, so treat the result as an estimate. --/ -def opNormUpperBoundCheap (M : ConcreteMatrix) : Float := - let frob := M.frobeniusNorm - let schur := M.schurNormEst - min frob schur - -/-- Dense Frobenius upper bound on `‖M‖₂`: `‖M‖₂ ≤ ‖M‖_F`. - -This is cheap and always valid in exact real arithmetic. --/ -def opNormUpperBoundDenseFrob (M : ConcreteMatrix) : Float := - M.frobeniusNorm - -/-- Dense Schur upper bound on `‖M‖₂`: `‖M‖₂ ≤ sqrt(‖M‖₁‖M‖∞)`. - -This is cheap and always valid in exact real arithmetic. --/ -def opNormUpperBoundDenseSchur (M : ConcreteMatrix) : Float := - M.opNormUpperBoundOneInf - -/-- Induced `∞` norm with absolute values: `max_i Σ_j |M[i,j]|`. - -This is the standard induced matrix norm `‖M‖_∞`. -We compute it deterministically in `Float` and interpret the result as a real-number -expression over the concrete Float entries. 
--/ -def infNormAbs (M : ConcreteMatrix) : Float := Id.run do - let mut maxSum : Float := 0.0 - for i in [:M.numRows] do - let mut rowSum : Float := 0.0 - let rowBase := i * M.numCols - for j in [:M.numCols] do - rowSum := rowSum + Float.abs (M.data[rowBase + j]!) - maxSum := max maxSum rowSum - maxSum - -/-- Upper bound on the operator norm via the Gram matrix and the induced `∞` norm. - -For a real matrix `W` we have: -- `‖W‖₂² = λ_max(WᵀW)` -- `‖G‖₂ ≤ ‖G‖∞` for any real matrix `G` - -Therefore `‖W‖₂ ≤ sqrt(‖WᵀW‖∞)`. - -This computes the quantity `sqrt(max_i Σ_j |(WᵀW)[i,j]|)` **without allocating** `WᵀW`. - -Note: This is computed using `Float` arithmetic; we use it as a deterministic bound for the -matrix obtained by interpreting the Float entries as real numbers. - -PERFORMANCE: O(numRows * numCols^2). Intended for factor matrices with small `numCols` -(e.g. `modelDim×headDim`). --/ -def opNormUpperBoundViaGramInf (W : ConcreteMatrix) : Float := Id.run do - if W.numCols = 0 then - return 0.0 - - let mut maxRowSum : Float := 0.0 - for i in [:W.numCols] do - let mut rowSum : Float := 0.0 - for j in [:W.numCols] do - let mut gij : Float := 0.0 - for k in [:W.numRows] do - let rowBase := k * W.numCols - -- SAFETY: `k < numRows` and `i,j < numCols`, so indices are within `data.size`. - let wi := W.data[rowBase + i]! - let wj := W.data[rowBase + j]! - gij := gij + wi * wj - rowSum := rowSum + Float.abs gij - maxRowSum := max maxRowSum rowSum - - -- Guard against negative zero / NaNs propagating into sqrt. - Float.sqrt (max 0.0 maxRowSum) - -/-- Transpose a matrix. -/ -def transpose (M : ConcreteMatrix) : ConcreteMatrix where - numRows := M.numCols - numCols := M.numRows - data := .ofFn fun idx : Fin (M.numCols * M.numRows) => - let i := idx.val / M.numRows - let j := idx.val % M.numRows - M.data[j * M.numCols + i]! - size_eq := Array.size_ofFn - - -/-- Diagnostics for the rectangular Gram-based operator-norm bound. 
- -If `usedGram=true`, then we formed the smaller Gram matrix `G` and bounded -`λ_max(G)` by several PSD inequalities, returning `sqrt(λ_upper)`. -Otherwise (size-capped), we fall back to cheap dense bounds. --/ -structure RectGramDiag where - usedGram : Bool - /-- Used the absolute-Gram fallback (no materialized Gram). -/ - usedAbsGram : Bool - /-- True if we computed the absolute-Gram bounds (even if they weren't chosen). -/ - computedAbsGram : Bool - /-- True if we computed a signed Gram matrix candidate (even if it wasn't chosen). -/ - computedGram : Bool - /-- True if we would have formed a signed Gram (by `maxGramDim`) but skipped it via a deterministic cost guard. -/ - skippedGram : Bool - /-- The `maxGramDim` cap from the `BoundEffort` used. -/ - maxGramDimCap : Nat - /-- Whether signed-Gram candidates were enabled in the `BoundEffort` used. -/ - signedGramEnabled : Bool - /-- Estimated cost for materializing Gram (`k*k*max(m,n)`). -/ - gramCost : Nat - /-- Deterministic Gram-cost limit used for scalability guard. -/ - gramCostLimit : Nat - gramDim : Nat - lambdaBrauer : Float - lambdaMoment : Float - lambdaGersh : Float - /-- Gershgorin upper bound computed without forming the Gram matrix. -/ - lambdaAbsGersh : Float - /-- Brauer/Cassini upper bound computed without forming the Gram matrix. -/ - lambdaAbsBrauer : Float - lambdaUsed : Float - opBound : Float - frobBound : Float - oneInfBound : Float - deriving Repr - -/-- How much work to spend computing rigorous upper bounds. - -Higher-effort modes must be **monotone**: they may add additional rigorous candidates and then -take the minimum, but they must never remove candidates. This guarantees that "upgrading" -cannot increase the returned upper bound (up to tiny Float noise). --/ -structure BoundEffort where - /-- Maximum Gram dimension allowed for materializing a (signed) Gram matrix. 
-/ - maxGramDim : Nat := 256 - /-- Deterministic cost guard for materializing a `k×k` signed Gram matrix, measured as - `k*k*max(m,n)` where `m×n` is the matrix size and `k = min(m,n)`. -/ - gramCostLimit : Nat := 200000000 - /-- Allow the absolute-Gram fallback (no materialized Gram). -/ - enableAbsGram : Bool := true - /-- Allow materializing a signed Gram matrix when `gramDim ≤ maxGramDim`. -/ - enableSignedGram : Bool := true - /-- Allow PSD moment bounds (when available). -/ - enableMoment : Bool := true - deriving Repr, Inhabited - -namespace BoundEffort - -/-- Tier-0: cheap bounds only. -/ -def tier0 : BoundEffort := - { maxGramDim := 0 - gramCostLimit := 0 - enableAbsGram := false - enableSignedGram := false - enableMoment := false } - -/-- Tier-1: enable abs-Gram fallback; keep default signed-Gram cap. -/ -def tier1 : BoundEffort := - { maxGramDim := 256 - gramCostLimit := 200000000 - enableAbsGram := true - enableSignedGram := true - enableMoment := true } - -/-- Tier-2: allow larger signed-Gram (may be expensive). -/ -def tier2 : BoundEffort := - { maxGramDim := 768 - gramCostLimit := 2000000000 - enableAbsGram := true - enableSignedGram := true - enableMoment := true } - -/-- Tier-3: placeholder for future rigorous tightenings (currently same as tier2). -/ -def tier3 : BoundEffort := tier2 - -/-- Ordered list of effort tiers (monotone increasing compute). -/ -def tiers : Array BoundEffort := #[tier0, tier1, tier2, tier3] - -end BoundEffort - -/-- Brauer/Cassini upper bound on `λ_max(G)` for an explicit symmetric matrix `G`. - -This is the same Cassini-ovals formula used for Gram matrices, but computed from -the materialized matrix `G` (intended for small `k×k`). - -Guardrails: -- Clamp discriminants to `≥ 0`. -- If NaN/Inf appears, fall back to the induced-∞ bound `‖G‖_∞`. 
--/ -def symmLambdaMaxUpperBrauer (G : ConcreteMatrix) : Float := Id.run do - let n := G.numRows - if n = 0 || G.numCols ≠ n then - return 0.0 - - let mut maxDiag : Float := 0.0 - let mut infBound : Float := 0.0 - let mut bad : Bool := false - let mut ds : Array Float := Array.mkEmpty n - let mut rs : Array Float := Array.mkEmpty n - - for i in [:n] do - let di := G.data[i * n + i]! - ds := ds.push di - maxDiag := max maxDiag di - - for i in [:n] do - let mut ri : Float := 0.0 - let rowBase := i * n - for j in [:n] do - if j = i then - continue - ri := ri + Float.abs (G.data[rowBase + j]!) - rs := rs.push ri - let di := ds[i]! - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - infBound := max infBound (Float.abs di + ri) - - if bad then - return infBound - if n < 2 then - return maxDiag - - let mut maxPair : Float := 0.0 - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - for j in [i + 1:n] do - let dj := ds[j]! - let rj := rs[j]! - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - if bad then - return infBound - else - return max maxDiag maxPair - -/-- Gershgorin and Brauer/Cassini upper bounds for the reduced Gram matrix, without forming it. - -Let `A` be `m×n`. We reduce to `G = A Aᵀ` if `m ≤ n` and `G = Aᵀ A` otherwise, so -`‖A‖₂² = λ_max(G)` (exact real arithmetic). - -We avoid forming `G` by bounding absolute row sums via: -- If `m > n` (`G = AᵀA`): use `s_row[k] = Σ_j |A[k,j]|` and - `Σ_j |G[i,j]| ≤ Σ_k |A[k,i]| * s_row[k]`. -- If `m ≤ n` (`G = AAᵀ`): use `s_col[t] = Σ_i |A[i,t]|` and - `Σ_j |G[i,j]| ≤ Σ_t |A[i,t]| * s_col[t]`. - -The first component is the Gershgorin/∞ bound `max_i rowSumUpper[i]`. -The second is the Brauer/Cassini bound computed from `diag[i]` and -`offSumUpper[i] = max(0, rowSumUpper[i] - diag[i])`. 
- -If NaN/Inf appears in the Brauer calculation, we conservatively fall back to the -Gershgorin bound. --/ -def rectAbsGramLambdaUpperGershBrauer (A : ConcreteMatrix) : Float × Float := Id.run do - let m := A.numRows - let n := A.numCols - if m = 0 || n = 0 then - return (0.0, 0.0) - - if m > n then - -- `G = AᵀA` (size `n×n`), using row absolute sums of `A`. - let mut sRow : Array Float := Array.replicate m 0.0 - for k in [:m] do - let mut acc : Float := 0.0 - let rowBase := k * n - for j in [:n] do - acc := acc + Float.abs (A.data[rowBase + j]!) - sRow := sRow.set! k acc - - let mut diag : Array Float := Array.replicate n 0.0 - let mut rowSumUpper : Array Float := Array.replicate n 0.0 - for k in [:m] do - let s := sRow[k]! - let rowBase := k * n - for i in [:n] do - let a := A.data[rowBase + i]! - diag := diag.set! i (diag[i]! + a * a) - rowSumUpper := rowSumUpper.set! i (rowSumUpper[i]! + Float.abs a * s) - - let mut lambdaAbsGersh : Float := 0.0 - for i in [:n] do - lambdaAbsGersh := max lambdaAbsGersh rowSumUpper[i]! - - if n < 2 then - return (lambdaAbsGersh, diag[0]!) - - let mut maxPair : Float := 0.0 - let mut bad : Bool := false - for i in [:n] do - let di := diag[i]! - let ri := max 0.0 (rowSumUpper[i]! - di) - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - for j in [i + 1:n] do - let dj := diag[j]! - let rj := max 0.0 (rowSumUpper[j]! - dj) - if Float.isNaN dj || Float.isInf dj || Float.isNaN rj || Float.isInf rj then - bad := true - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - let lambdaAbsBrauer := if bad then lambdaAbsGersh else maxPair - return (lambdaAbsGersh, lambdaAbsBrauer) - else - -- `G = AAᵀ` (size `m×m`), using column absolute sums of `A`. 
- let mut sCol : Array Float := Array.replicate n 0.0 - for i in [:m] do - let rowBase := i * n - for t in [:n] do - let a := A.data[rowBase + t]! - sCol := sCol.set! t (sCol[t]! + Float.abs a) - - let mut diag : Array Float := Array.replicate m 0.0 - let mut rowSumUpper : Array Float := Array.replicate m 0.0 - for i in [:m] do - let mut di : Float := 0.0 - let mut ru : Float := 0.0 - let rowBase := i * n - for t in [:n] do - let a := A.data[rowBase + t]! - di := di + a * a - ru := ru + Float.abs a * sCol[t]! - diag := diag.set! i di - rowSumUpper := rowSumUpper.set! i ru - - let mut lambdaAbsGersh : Float := 0.0 - for i in [:m] do - lambdaAbsGersh := max lambdaAbsGersh rowSumUpper[i]! - - if m < 2 then - return (lambdaAbsGersh, diag[0]!) - - let mut maxPair : Float := 0.0 - let mut bad : Bool := false - for i in [:m] do - let di := diag[i]! - let ri := max 0.0 (rowSumUpper[i]! - di) - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - for j in [i + 1:m] do - let dj := diag[j]! - let rj := max 0.0 (rowSumUpper[j]! - dj) - if Float.isNaN dj || Float.isInf dj || Float.isNaN rj || Float.isInf rj then - bad := true - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - let lambdaAbsBrauer := if bad then lambdaAbsGersh else maxPair - return (lambdaAbsGersh, lambdaAbsBrauer) - -/-- Effort-configurable variant of `opNormUpperBoundRectGramDiag`. - -This is the entrypoint used by the adaptive scheduler: it can add expensive-but-rigorous -candidates while remaining monotone (the final `opBound` is always the minimum of enabled -rigorous candidates). 
--/ -def opNormUpperBoundRectGramDiagEffort (A : ConcreteMatrix) (effort : BoundEffort) : RectGramDiag := - let frob := A.frobeniusNorm - let oneInf := A.opNormUpperBoundOneInf - let cheap := min frob oneInf - let m := A.numRows - let n := A.numCols - let k := min m n - let costLimit : Nat := effort.gramCostLimit - let cost : Nat := k * k * (max m n) - - if k = 0 then - { usedGram := true - usedAbsGram := false - computedAbsGram := false - computedGram := false - skippedGram := false - maxGramDimCap := effort.maxGramDim - signedGramEnabled := effort.enableSignedGram - gramCost := 0 - gramCostLimit := costLimit - gramDim := 0 - lambdaBrauer := 0.0 - lambdaMoment := 0.0 - lambdaGersh := 0.0 - lambdaAbsGersh := 0.0 - lambdaAbsBrauer := 0.0 - lambdaUsed := 0.0 - opBound := 0.0 - frobBound := frob - oneInfBound := oneInf } - else Id.run do - let mut bestOp : Float := cheap - let mut usedAbs : Bool := false - let mut usedGram : Bool := false - let mut computedAbsGram : Bool := false - let mut computedGram : Bool := false - let mut skippedGram : Bool := false - let mut lambdaAbsGersh : Float := 0.0 - let mut lambdaAbsBrauer : Float := 0.0 - let mut lambdaGersh : Float := 0.0 - let mut lambdaBrauer : Float := 0.0 - let mut lambdaMoment : Float := 0.0 - let mut lambdaUsed : Float := max 0.0 (cheap * cheap) - - if effort.enableAbsGram then - computedAbsGram := true - let (lG, lB) := rectAbsGramLambdaUpperGershBrauer A - lambdaAbsGersh := lG - lambdaAbsBrauer := lB - let lUpper := min lG lB - let opAbsRaw := Float.sqrt (max 0.0 lUpper) - let opAbs : Float := - if Float.isNaN opAbsRaw || Float.isInf opAbsRaw then - Float.inf - else - opAbsRaw - if opAbs < bestOp then - bestOp := opAbs - lambdaUsed := max 0.0 lUpper - usedAbs := true - usedGram := false - - if effort.enableSignedGram && k ≤ effort.maxGramDim then - -- Deterministic scalability guard: forming `k×k` Gram matrices can be prohibitively expensive - -- for large rectangular matrices (e.g. 768×3072). 
We therefore skip this candidate when the - -- estimated matmul cost is too high. - -- - -- Correctness is preserved: skipping a candidate can only loosen the final bound. - let allowGram : Bool := cost ≤ costLimit - if allowGram then - computedGram := true - let G : ConcreteMatrix := - if m ≤ n then - A.matmul A.transpose - else - A.transpose.matmul A - lambdaGersh := G.infNormAbs - lambdaBrauer := symmLambdaMaxUpperBrauer G - let mut lUpper : Float := min lambdaGersh lambdaBrauer - if effort.enableMoment then - -- For Gram `G`, `trace(G) = ‖A‖_F²`. - let tr : Float := A.frobeniusNormSq - let f2 : Float := G.frobeniusNormSq - lambdaMoment := psdLambdaMaxUpperMoment k tr f2 - lUpper := min lUpper lambdaMoment - else - lambdaMoment := 0.0 - let op := Float.sqrt (max 0.0 lUpper) - if op < bestOp then - bestOp := op - lambdaUsed := max 0.0 lUpper - usedAbs := false - usedGram := true - else - skippedGram := true - - { usedGram := usedGram - usedAbsGram := usedAbs - computedAbsGram := computedAbsGram - computedGram := computedGram - skippedGram := skippedGram - maxGramDimCap := effort.maxGramDim - signedGramEnabled := effort.enableSignedGram - gramCost := cost - gramCostLimit := costLimit - gramDim := k - lambdaBrauer := lambdaBrauer - lambdaMoment := lambdaMoment - lambdaGersh := lambdaGersh - lambdaAbsGersh := lambdaAbsGersh - lambdaAbsBrauer := lambdaAbsBrauer - lambdaUsed := lambdaUsed - opBound := min cheap bestOp - frobBound := frob - oneInfBound := oneInf } - -/-- Backwards-compatible wrapper around `opNormUpperBoundRectGramDiagEffort`. -/ -def opNormUpperBoundRectGramDiag (A : ConcreteMatrix) (maxGramDim : Nat := 256) : RectGramDiag := - opNormUpperBoundRectGramDiagEffort A { maxGramDim := maxGramDim } - -/-- Rectangular Gram-based operator-norm bound (with a size-cap fallback). 
-/ -def opNormUpperBoundRectGram (A : ConcreteMatrix) (maxGramDim : Nat := 256) : Float := - (A.opNormUpperBoundRectGramDiag maxGramDim).opBound - -/-- Effort-configurable variant of `opNormUpperBoundRectGram`. -/ -def opNormUpperBoundRectGramEffort (A : ConcreteMatrix) (effort : BoundEffort) : Float := - (A.opNormUpperBoundRectGramDiagEffort effort).opBound - -/-- Gram-matrix induced-∞ upper bound on the spectral norm. - -In exact real arithmetic: -`‖M‖₂² = λ_max(MᵀM) ≤ ‖MᵀM‖_∞`, hence `‖M‖₂ ≤ sqrt(‖MᵀM‖_∞)`. - -This allocates `MᵀM`, so it is intended for small matrices. -If the computation produces NaN/Inf, we conservatively fall back to `‖M‖_F`. --/ -def gramInfOpBound (M : ConcreteMatrix) : Float := - if M.numRows = 0 || M.numCols = 0 then - 0.0 - else - let g := M.transpose.matmul M - let v := Float.sqrt (max 0.0 g.infNormAbs) - if Float.isNaN v || Float.isInf v then - M.frobeniusNorm - else - v - -/-- Compute an entry of the Gram matrix `MᵀM` without materializing it. - -`gramMatrixEntry M i j = (MᵀM)[i,j] = Σ_k M[k,i] * M[k,j]`. - -This is intended for small `numCols` (e.g. `headDim=64`). --/ -def gramMatrixEntry (M : ConcreteMatrix) (i j : Nat) : Float := Id.run do - if M.numRows = 0 || M.numCols = 0 then - return 0.0 - if i ≥ M.numCols || j ≥ M.numCols then - return 0.0 - let mut acc : Float := 0.0 - for k in [:M.numRows] do - let rowBase := k * M.numCols - let mi := M.data[rowBase + i]! - let mj := M.data[rowBase + j]! - acc := acc + mi * mj - return acc - -/-- Gram diagonal entry `d_i = (MᵀM)[i,i] = Σ_k M[k,i]^2`. - -For real `M`, these are nonnegative in exact arithmetic. --/ -def gramDiag (M : ConcreteMatrix) (i : Nat) : Float := Id.run do - if M.numRows = 0 || M.numCols = 0 then - return 0.0 - if i ≥ M.numCols then - return 0.0 - let mut acc : Float := 0.0 - for k in [:M.numRows] do - let rowBase := k * M.numCols - let mi := M.data[rowBase + i]! - acc := acc + mi * mi - return acc - -/-- Off-diagonal absolute row sum of the Gram matrix. 
- -Let `G = MᵀM`. This computes -`R_i = Σ_{j ≠ i} |G[i,j]|`. --/ -def gramRowAbsSumExclDiag (M : ConcreteMatrix) (i : Nat) : Float := Id.run do - if M.numCols = 0 || i ≥ M.numCols then - return 0.0 - let mut acc : Float := 0.0 - for j in [:M.numCols] do - if j = i then - continue - acc := acc + Float.abs (M.gramMatrixEntry i j) - return acc - -/-- Frobenius norm squared of the Gram matrix `G = MᵀM`, computed without allocating `G`. - -This returns `‖G‖_F² = Σ_{i,j} G[i,j]^2`. -Intended for small `numCols` (e.g. `headDim=64`). --/ -def gramFrobeniusNormSq (M : ConcreteMatrix) : Float := Id.run do - let n := M.numCols - if n = 0 then - return 0.0 - let mut acc : Float := 0.0 - for i in [:n] do - for j in [:n] do - let gij := M.gramMatrixEntry i j - acc := acc + gij * gij - return acc - -/-- Brauer/Cassini upper bound on `λ_max(G)` for a Gram matrix `G = MᵀM`. - -Mathematical facts (exact real arithmetic): - -Let `G` be real symmetric (Gram matrices are symmetric PSD). Define: -- `d_i = G[i,i]` -- `R_i = Σ_{j≠i} |G[i,j]|` - -Brauer bound (Cassini ovals): -`λ_max(G) ≤ max_{i≠j} b_ij`, where -`b_ij = (d_i + d_j + sqrt((d_i - d_j)^2 + 4*R_i*R_j)) / 2`. - -We also have the induced-∞ / Gershgorin bound: -`λ_max(G) ≤ max_i (d_i + R_i) = ‖G‖_∞`. - -Guardrails: -- Clamp the discriminant inside `sqrt` to `≥ 0`. -- If any NaN/Inf appears, conservatively fall back to `‖G‖_∞`. -- If `n < 2`, return `max_i d_i`. --/ -def gramLambdaMaxUpperBrauer (M : ConcreteMatrix) : Float := Id.run do - let n := M.numCols - if n = 0 then - return 0.0 - - let ds : Array Float := .ofFn fun i : Fin n => M.gramDiag i.val - let rs : Array Float := .ofFn fun i : Fin n => M.gramRowAbsSumExclDiag i.val - - let mut maxDiag : Float := 0.0 - let mut infBound : Float := 0.0 - let mut bad : Bool := false - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! 
- if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - maxDiag := max maxDiag di - infBound := max infBound (Float.abs di + ri) - - if bad then - return infBound - - if n < 2 then - return maxDiag - - let mut maxPair : Float := 0.0 - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - for j in [i + 1:n] do - let dj := ds[j]! - let rj := rs[j]! - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - if bad then - return infBound - else - -- Brauer/Cassini bound candidate - let lambdaBrauerUpper := max maxDiag maxPair - -- Moment bound candidate for PSD `G = MᵀM`. - -- Here `trace(G) = ‖M‖_F²`. - let tr : Float := M.frobeniusNormSq - let f2 : Float := M.gramFrobeniusNormSq - let lambdaMomentUpper := psdLambdaMaxUpperMoment n tr f2 - -- Combine cheap valid upper bounds by taking `min`. - return min infBound (min lambdaBrauerUpper lambdaMomentUpper) - -/-- Brauer/Cassini upper bound on the spectral radius for a general square matrix. - -For any eigenvalue `λ` there exist `i ≠ j` with -`|λ - a_ii| · |λ - a_jj| ≤ R_i R_j`, where `R_i` is the off-diagonal row sum. -By the reverse triangle inequality, this yields a real upper bound on `|λ|`. --/ -def lambdaMaxUpperBrauer (M : ConcreteMatrix) : Float := Id.run do - let n := M.numRows - if n = 0 || M.numCols ≠ n then - return 0.0 - - let mut maxDiag : Float := 0.0 - let mut infBound : Float := 0.0 - let mut bad : Bool := false - let mut ds : Array Float := Array.mkEmpty n - let mut rs : Array Float := Array.mkEmpty n - - for i in [:n] do - let di := Float.abs (M.data[i * n + i]!) - ds := ds.push di - maxDiag := max maxDiag di - - for i in [:n] do - let mut ri : Float := 0.0 - let rowBase := i * n - for j in [:n] do - if j = i then - continue - ri := ri + Float.abs (M.data[rowBase + j]!) 
- rs := rs.push ri - let di := ds[i]! - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - infBound := max infBound (di + ri) - - if bad then - return infBound - if n < 2 then - return maxDiag - - let mut maxPair : Float := 0.0 - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - for j in [i + 1:n] do - let dj := ds[j]! - let rj := rs[j]! - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - if bad then - return infBound - else - return max maxDiag maxPair - -/-- Gershgorin upper bound on `λ_max` for matrices with nonnegative real eigenvalues. - -Uses `λ ≤ max_i (a_ii + Σ_{j≠i} |a_ij|)` (no absolute on the diagonal). --/ -def lambdaMaxUpperGershNonneg (M : ConcreteMatrix) : Float := Id.run do - let n := M.numRows - if n = 0 || M.numCols ≠ n then - return 0.0 - let mut maxBound : Float := 0.0 - for i in [:n] do - let rowBase := i * n - let di := M.data[rowBase + i]! - let mut ri : Float := 0.0 - for j in [:n] do - if j = i then - continue - ri := ri + Float.abs (M.data[rowBase + j]!) - let bound := max 0.0 (di + ri) - if Float.isNaN bound || Float.isInf bound then - return M.infNormAbs - if bound > maxBound then - maxBound := bound - return maxBound - -/-- Brauer/Cassini upper bound on `λ_max` for matrices with nonnegative real eigenvalues. - -This mirrors `lambdaMaxUpperBrauer` but keeps signed diagonals. Since the eigenvalues are -assumed nonnegative, using `a_ii` (not `|a_ii|`) can tighten the bound. 
--/ -def lambdaMaxUpperBrauerNonneg (M : ConcreteMatrix) : Float := Id.run do - let n := M.numRows - if n = 0 || M.numCols ≠ n then - return 0.0 - - let mut maxDiag : Float := 0.0 - let mut gersh : Float := 0.0 - let mut bad : Bool := false - let mut ds : Array Float := Array.mkEmpty n - let mut rs : Array Float := Array.mkEmpty n - - for i in [:n] do - let di := M.data[i * n + i]! - ds := ds.push di - maxDiag := max maxDiag di - - for i in [:n] do - let mut ri : Float := 0.0 - let rowBase := i * n - for j in [:n] do - if j = i then - continue - ri := ri + Float.abs (M.data[rowBase + j]!) - rs := rs.push ri - let di := ds[i]! - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - gersh := max gersh (max 0.0 (di + ri)) - - if bad then - return gersh - if n < 2 then - return max gersh (max 0.0 maxDiag) - - let mut maxPair : Float := 0.0 - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - for j in [i + 1:n] do - let dj := ds[j]! - let rj := rs[j]! - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - if bad then - return gersh - else - return max gersh (max 0.0 maxPair) - -/-- Dense (small) spectral-norm upper bound using the Brauer/Cassini Gram bound. - -`‖M‖₂² = λ_max(MᵀM) ≤ gramLambdaMaxUpperBrauer(M)`. --/ -def opNormUpperBoundDenseBrauer (M : ConcreteMatrix) : Float := - Float.sqrt (max 0.0 (M.gramLambdaMaxUpperBrauer)) - -/-- Matrix addition. Returns zero matrix if dimensions don't match. -/ -def add (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numRows = B.numRows ∧ A.numCols = B.numCols then - { - numRows := A.numRows - numCols := A.numCols - data := .ofFn fun idx : Fin (A.numRows * A.numCols) => - A.data[idx.val]! + B.data[idx.val]! 
- size_eq := Array.size_ofFn - } - else zeros 0 0 - -/-- Sum a list of same-shape matrices in a single pass. - -This avoids repeated `add` allocations (which would traverse the matrix multiple times). -When `allowParallel=true` and the matrices are sufficiently large, it uses row-chunk parallelism. --/ -def sumMatrices (ms : Array ConcreteMatrix) (allowParallel : Bool := true) : ConcreteMatrix := - if ms.isEmpty then - zeros 0 0 - else - letI : Inhabited ConcreteMatrix := ⟨zeros 0 0⟩ - let base := ms[0]! - let rows := base.numRows - let cols := base.numCols - let okDims := - ms.all fun M => M.numRows = rows ∧ M.numCols = cols - if !okDims then - zeros 0 0 - else - let entries := rows * cols - let shouldPar : Bool := - allowParallel && - entries ≥ 200_000 && - rows ≥ 2 && - cols ≥ 256 && - ms.size ≥ 4 - if !shouldPar then - { - numRows := rows - numCols := cols - data := .ofFn fun idx : Fin entries => Id.run do - let mut acc : Float := 0.0 - for M in ms do - acc := acc + M.data[idx.val]! - return acc - size_eq := Array.size_ofFn - } - else - let numTasks : Nat := min 16 rows - let q := rows / numTasks - let r := rows % numTasks - let tasks : Array (Task (Array Float)) := - .ofFn fun t : Fin numTasks => - Task.spawn (fun _ => - let tid := t.val - let extra := if tid < r then 1 else 0 - let rowsHere := q + extra - let startRow := tid * q + min tid r - let chunkSize := rowsHere * cols - .ofFn fun idx : Fin chunkSize => Id.run do - let localRow := idx.val / cols - let j := idx.val % cols - let i := startRow + localRow - let globalIdx := i * cols + j - let mut acc : Float := 0.0 - for M in ms do - acc := acc + M.data[globalIdx]! - return acc) - -- Join in increasing task index order (deterministic). 
- let chunks := tasks.map Task.get - let cutoff := (q + 1) * r - { - numRows := rows - numCols := cols - data := .ofFn fun idx : Fin entries => - let row := idx.val / cols - let col := idx.val % cols - let taskIdx := - if row < cutoff then - row / (q + 1) - else - r + (row - cutoff) / q - let localRow := - if row < cutoff then - row % (q + 1) - else - (row - cutoff) % q - let chunk := chunks[taskIdx]! - chunk[localRow * cols + col]! - size_eq := Array.size_ofFn - } - -/-- Scalar multiplication. -/ -def scale (c : Float) (M : ConcreteMatrix) : ConcreteMatrix where - numRows := M.numRows - numCols := M.numCols - data := M.data.map (c * ·) - size_eq := by simp [M.size_eq] - -/-- Get row i as a 1×numCols matrix. -/ -def getRow (M : ConcreteMatrix) (i : Nat) : ConcreteMatrix := - if i < M.numRows then - { - numRows := 1 - numCols := M.numCols - data := .ofFn fun j : Fin M.numCols => M.getUnsafe i j.val - size_eq := by simp - } - else zeros 1 M.numCols - -/-- Set row i from a 1×numCols matrix. Returns original if dimensions wrong. -/ -def setRow (M : ConcreteMatrix) (i : Nat) (row : ConcreteMatrix) : ConcreteMatrix := - if i < M.numRows ∧ row.numRows = 1 ∧ row.numCols = M.numCols then - { - numRows := M.numRows - numCols := M.numCols - data := .ofFn fun idx : Fin (M.numRows * M.numCols) => - let r := idx.val / M.numCols - let c := idx.val % M.numCols - if r = i then row.getUnsafe 0 c else M.getUnsafe r c - size_eq := Array.size_ofFn - } - else M - -/-- Element-wise application of a function. -/ -def map (f : Float → Float) (M : ConcreteMatrix) : ConcreteMatrix where - numRows := M.numRows - numCols := M.numCols - data := M.data.map f - size_eq := by simp [M.size_eq] - -/-- Broadcast add: add a 1×numCols bias to each row. 
-/ -def addBias (M : ConcreteMatrix) (bias : ConcreteMatrix) : ConcreteMatrix := - if bias.numRows = 1 ∧ bias.numCols = M.numCols then - { - numRows := M.numRows - numCols := M.numCols - data := .ofFn fun idx : Fin (M.numRows * M.numCols) => - let c := idx.val % M.numCols - M.data[idx.val]! + bias.data[c]! - size_eq := Array.size_ofFn - } - else M - -/-- Row-wise LayerNorm with learnable scale γ and bias β (both 1×numCols). - -This implements the Pre-LN transformer normalization convention: each token (row) -is normalized across model dimension (columns), then scaled and shifted. --/ -def layerNormRowwise (X γ β : ConcreteMatrix) (eps : Float := 1e-5) : ConcreteMatrix := Id.run do - let rows := X.numRows - let cols := X.numCols - if rows = 0 || cols = 0 then - return ConcreteMatrix.zeros rows cols - if !(γ.numRows = 1 ∧ γ.numCols = cols ∧ β.numRows = 1 ∧ β.numCols = cols) then - return X - let colsF := cols.toFloat - let gammaData := γ.data - let betaData := β.data - - -- Per-row mean and inverse stddev (compute once for speed). - let mut means : Array Float := Array.replicate rows 0.0 - let mut invStds : Array Float := Array.replicate rows 0.0 - for r in [:rows] do - let mut sum : Float := 0.0 - let rowBase := r * cols - for c in [:cols] do - sum := sum + X.data[rowBase + c]! - let μ := sum / colsF - let mut varSum : Float := 0.0 - for c in [:cols] do - let d := X.data[rowBase + c]! - μ - varSum := varSum + d * d - -- In exact arithmetic, `var ≥ 0`. Clamp to avoid NaN from tiny negative float noise. - let var := max 0.0 (varSum / colsF) - let invσ := 1.0 / Float.sqrt (var + eps) - means := means.set! r μ - invStds := invStds.set! r invσ - - return { - numRows := rows - numCols := cols - data := .ofFn fun idx : Fin (rows * cols) => - let r := idx.val / cols - let c := idx.val % cols - let μ := means[r]! - let invσ := invStds[r]! - let normalized := (X.data[r * cols + c]! - μ) * invσ - (gammaData[c]!) * normalized + (betaData[c]!) 
- size_eq := Array.size_ofFn - } - -/-- Heuristic estimate for the operator norm of the Jacobian of row-wise LayerNorm. - -We bound the **global** operator norm on the block-diagonal Jacobian (one block per row) -by the maximum over rows of a **tight spectral-norm bound**. - -For a single row `x : ℝ^d` (ignoring β), LayerNorm is: -`LN(x) = γ ⊙ ((x - μ) / σ)` with `σ = sqrt(var + eps)`. -Its Jacobian has the closed form: - -`J = diag(γ) * (1/σ) * (I - (1/d)11ᵀ - (1/d)vvᵀ)` - -where `v` is the centered vector scaled by `1/σ`. The symmetric matrix in parentheses -has eigenvalues `{0, 1, eps/(var+eps)}` so its spectral norm is exactly `1`. -Therefore `‖J‖₂ ≤ max |γ| / σ` in exact real arithmetic. - -This avoids the previous row-sum bound which could overestimate by orders of magnitude -and made downstream certification thresholds unusable. --/ -def layerNormRowwiseOpEst (X γ : ConcreteMatrix) (eps : Float := 1e-5) : Float := Id.run do - let rows := X.numRows - let cols := X.numCols - if rows = 0 || cols = 0 then return 0.0 - if !(γ.numRows = 1 ∧ γ.numCols = cols) then return 0.0 - let colsF := cols.toFloat - let gammaData := γ.data - - -- max |γ| - let mut gammaMaxAbs : Float := 0.0 - for c in [:cols] do - let g := Float.abs (gammaData[c]!) - if Float.isNaN g || Float.isInf g then - gammaMaxAbs := Float.inf - else if g > gammaMaxAbs then - gammaMaxAbs := g - - -- max_r (1/σ_r) - let mut maxInvStd : Float := 0.0 - for r in [:rows] do - -- Mean - let mut sum : Float := 0.0 - let rowBase := r * cols - for c in [:cols] do - sum := sum + X.data[rowBase + c]! - let μ := sum / colsF - - -- Variance - let mut varSum : Float := 0.0 - for c in [:cols] do - let centered := X.data[rowBase + c]! - μ - varSum := varSum + centered * centered - let varRaw := varSum / colsF - -- In exact arithmetic, `var ≥ 0`. If we see NaN/Inf, conservatively treat as 0 - -- so that `1/σ ≤ 1/sqrt(eps)`. 
- let var := - if Float.isNaN varRaw || Float.isInf varRaw then 0.0 - else max 0.0 varRaw - let σ := Float.sqrt (var + eps) - if σ > 0.0 && !(Float.isNaN σ || Float.isInf σ) then - let invσ := 1.0 / σ - if invσ > maxInvStd then maxInvStd := invσ - - if maxInvStd ≤ 0.0 || Float.isNaN maxInvStd || Float.isInf maxInvStd then - maxInvStd := 1.0 / Float.sqrt eps - else - let invStdMax := 1.0 / Float.sqrt eps - if maxInvStd > invStdMax then - maxInvStd := invStdMax - - return gammaMaxAbs * maxInvStd - -structure LayerNormOpDiag where - gammaMaxAbs : Float - maxInvStd : Float - maxInvStdRow : Nat - minVar : Float - minVarRow : Nat - deriving Repr - -/-- Diagnostics for `layerNormRowwiseOpEst`: reports `max |γ|`, the worst-case row inverse-std, -and the minimum variance row (useful for spotting padding/degenerate rows). -/ -def layerNormRowwiseOpDiag (X γ : ConcreteMatrix) (eps : Float := 1e-5) : LayerNormOpDiag := Id.run do - let rows := X.numRows - let cols := X.numCols - if rows = 0 || cols = 0 then - return { gammaMaxAbs := 0.0, maxInvStd := 0.0, maxInvStdRow := 0, minVar := 0.0, minVarRow := 0 } - if !(γ.numRows = 1 ∧ γ.numCols = cols) then - return { gammaMaxAbs := 0.0, maxInvStd := 0.0, maxInvStdRow := 0, minVar := 0.0, minVarRow := 0 } - let colsF := cols.toFloat - let gammaData := γ.data - - let mut gammaMaxAbs : Float := 0.0 - for c in [:cols] do - let g := Float.abs (gammaData[c]!) - if Float.isNaN g || Float.isInf g then - gammaMaxAbs := Float.inf - else if g > gammaMaxAbs then - gammaMaxAbs := g - - let mut maxInvStd : Float := 0.0 - let mut maxInvStdRow : Nat := 0 - let mut minVar : Float := Float.inf - let mut minVarRow : Nat := 0 - - for r in [:rows] do - let mut sum : Float := 0.0 - let rowBase := r * cols - for c in [:cols] do - sum := sum + X.data[rowBase + c]! - let μ := sum / colsF - let mut varSum : Float := 0.0 - for c in [:cols] do - let centered := X.data[rowBase + c]! 
- μ - varSum := varSum + centered * centered - let varRaw := varSum / colsF - let var := - if Float.isNaN varRaw || Float.isInf varRaw then 0.0 - else max 0.0 varRaw - if var < minVar then - minVar := var - minVarRow := r - let σ := Float.sqrt (var + eps) - if σ > 0.0 then - let invσ := 1.0 / σ - if invσ > maxInvStd then - maxInvStd := invσ - maxInvStdRow := r - - if Float.isInf minVar || Float.isNaN minVar then - minVar := 0.0 - if maxInvStd ≤ 0.0 || Float.isNaN maxInvStd || Float.isInf maxInvStd then - maxInvStd := 1.0 / Float.sqrt eps - - return { - gammaMaxAbs := gammaMaxAbs - maxInvStd := maxInvStd - maxInvStdRow := maxInvStdRow - minVar := minVar - minVarRow := minVarRow - } - -/-- Get column j as a vector (stored as numRows×1 matrix). -/ -def getCol (M : ConcreteMatrix) (j : Nat) : ConcreteMatrix := - if j < M.numCols then - { - numRows := M.numRows - numCols := 1 - data := .ofFn fun i : Fin M.numRows => M.getUnsafe i.val j - size_eq := by simp - } - else zeros M.numRows 1 - -/-- Compute matrix-vector product M * v where v is stored as numCols×1 matrix. - Returns a numRows×1 matrix. -/ -def matVecMul (M : ConcreteMatrix) (v : ConcreteMatrix) : ConcreteMatrix := - if M.numCols = v.numRows ∧ v.numCols = 1 then - { - numRows := M.numRows - numCols := 1 - data := .ofFn fun i : Fin M.numRows => Id.run do - let mut acc : Float := 0.0 - let rowBase := i.val * M.numCols - for k in [:M.numCols] do - acc := acc + M.data[rowBase + k]! * v.data[k]! - return acc - size_eq := by simp - } - else zeros M.numRows 1 - -/-- Compute dot product of two vectors (stored as n×1 matrices). -/ -def dot (v1 v2 : ConcreteMatrix) : Float := - if v1.numRows = v2.numRows ∧ v1.numCols = 1 ∧ v2.numCols = 1 then Id.run do - let mut acc : Float := 0.0 - for i in [:v1.numRows] do - acc := acc + v1.data[i]! * v2.data[i]! - return acc - else 0.0 - -/-- Compute L2 norm of a vector (stored as n×1 matrix). 
-/ -def vecNorm (v : ConcreteMatrix) : Float := - if v.numCols = 1 then - Float.sqrt (sumSquares v.data) - else 0.0 - -/-- Vector subtraction for n×1 matrices. -/ -def vecSub (v1 v2 : ConcreteMatrix) : ConcreteMatrix := - if v1.numRows = v2.numRows ∧ v1.numCols = 1 ∧ v2.numCols = 1 then - { - numRows := v1.numRows - numCols := 1 - data := .ofFn fun i : Fin v1.numRows => v1.data[i.val]! - v2.data[i.val]! - size_eq := by simp - } - else zeros v1.numRows 1 - -end ConcreteMatrix - -/-! ## Concrete LayerNorm Parameters -/ - -/-- Concrete LayerNorm parameters for Pre-LN transformers (scale γ and bias β). -/ -structure ConcreteLayerNormParams where - /-- Scale γ (1×modelDim) -/ - gamma : ConcreteMatrix - /-- Bias β (1×modelDim) -/ - beta : ConcreteMatrix - -namespace ConcreteLayerNormParams - -/-- Identity LayerNorm affine parameters: γ=1, β=0. -/ -def identity (modelDim : Nat) : ConcreteLayerNormParams := - { gamma := ConcreteMatrix.ones 1 modelDim, beta := ConcreteMatrix.zeros 1 modelDim } - -end ConcreteLayerNormParams - -/-! ## Concrete Attention Layer -/ - -/-- A concrete attention layer with exported weights. - -This structure holds the four projection matrices that define a single attention head: -- W_Q: Query projection (d × d_head) -- W_K: Key projection (d × d_head) -- W_V: Value projection (d × d_head) -- W_O: Output projection (d_head × d) --/ -structure ConcreteAttentionLayer where - /-- Model dimension (embedding size) -/ - modelDim : Nat - /-- Head dimension -/ - headDim : Nat - /-- Query projection matrix (modelDim × headDim) -/ - W_Q : ConcreteMatrix - /-- Query bias (1×headDim). -/ - b_Q : ConcreteMatrix - /-- Key projection matrix (modelDim × headDim) -/ - W_K : ConcreteMatrix - /-- Key bias (1×headDim). -/ - b_K : ConcreteMatrix - /-- Value projection matrix (modelDim × headDim) -/ - W_V : ConcreteMatrix - /-- Value bias (1×headDim). 
-/ - b_V : ConcreteMatrix - /-- Output projection matrix (headDim × modelDim) -/ - W_O : ConcreteMatrix - /-- Dimension consistency for W_Q -/ - W_Q_dims : W_Q.numRows = modelDim ∧ W_Q.numCols = headDim - /-- Dimension consistency for b_Q -/ - b_Q_dims : b_Q.numRows = 1 ∧ b_Q.numCols = headDim - /-- Dimension consistency for W_K -/ - W_K_dims : W_K.numRows = modelDim ∧ W_K.numCols = headDim - /-- Dimension consistency for b_K -/ - b_K_dims : b_K.numRows = 1 ∧ b_K.numCols = headDim - /-- Dimension consistency for W_V -/ - W_V_dims : W_V.numRows = modelDim ∧ W_V.numCols = headDim - /-- Dimension consistency for b_V -/ - b_V_dims : b_V.numRows = 1 ∧ b_V.numCols = headDim - /-- Dimension consistency for W_O -/ - W_O_dims : W_O.numRows = headDim ∧ W_O.numCols = modelDim - -namespace ConcreteAttentionLayer - -/-- Compute the value-output projection `W_V · W_O` (dense `modelDim×modelDim`). - -This is **diagnostics-only** on the main discovery/verification paths. Prefer using -small `headDim×headDim` Gram matrices and trace identities to compute scalar bounds. --/ -def valueOutputProjection (layer : ConcreteAttentionLayer) : ConcreteMatrix := - layer.W_V.matmul layer.W_O - -/-- Compute the query-key alignment `W_Q · W_Kᵀ` (dense `modelDim×modelDim`). - -This is **diagnostics-only** on the main discovery/verification paths. Prefer using -small `headDim×headDim` Gram matrices and trace identities to compute scalar bounds. --/ -def queryKeyAlignment (layer : ConcreteAttentionLayer) : ConcreteMatrix := - layer.W_Q.matmul layer.W_K.transpose - -/-- Compute a tighter rigorous bound on `‖W_Q W_Kᵀ‖₂` by centering `W_K` across its rows. - -Key observation: attention keys are computed from the Pre-LN activation `u = ln₁(x)`, and each row -of `u` has (approximately) zero mean across the `modelDim` features. 
Therefore, replacing -`W_K` by `W_K - 1·μᵀ` where `μ = (1/modelDim)·(1ᵀ W_K)` does not change `u·W_K`, hence does not -change the attention logits or their Jacobian, but can reduce the operator norm used in bounds. - -We avoid materializing `W_K - 1·μᵀ` by applying the corresponding rank-1 correction to the Gram: -`(W_K - 1·μᵀ)ᵀ (W_K - 1·μᵀ) = W_KᵀW_K - (1/modelDim)·uᵀu` where `u = 1ᵀ W_K`. --/ -def centeredQKOpBound (layer : ConcreteAttentionLayer) : Float := Id.run do - let kRows := layer.modelDim - let kCols := layer.headDim - if kRows = 0 || kCols = 0 then - return 0.0 - - -- 1) u = 1ᵀ W_K (column sums of W_K). - let mut colSums : Array Float := Array.replicate kCols 0.0 - for r in [:kRows] do - let rowBase := r * kCols - for c in [:kCols] do - colSums := colSums.set! c (colSums[c]! + layer.W_K.data[rowBase + c]!) - - -- 2) Centered Gram: G' = W_KᵀW_K - (1/kRows)·uᵀu. - let invDim := 1.0 / kRows.toFloat - let wkGram := layer.W_K.transpose.matmul layer.W_K - let wkGramCentered : ConcreteMatrix := - { numRows := kCols - numCols := kCols - data := .ofFn fun idx : Fin (kCols * kCols) => - let i := idx.val / kCols - let j := idx.val % kCols - let ui := colSums.getD i 0.0 - let uj := colSums.getD j 0.0 - wkGram.getUnsafe i j - (invDim * ui * uj) - size_eq := Array.size_ofFn } - - let wqGram := layer.W_Q.transpose.matmul layer.W_Q - - -- Candidate A: factorized Gram bound using `‖·‖∞` on Grams. - let wqOp := Float.sqrt (max 0.0 (wqGram.infNormAbs)) - let wkOpCentered := Float.sqrt (max 0.0 (wkGramCentered.infNormAbs)) - let boundFactor := wqOp * wkOpCentered - - -- Candidate B: Frobenius candidate via trace identity. - let frobSq := ConcreteMatrix.traceMul wqGram wkGramCentered - let boundFrob := Float.sqrt (max 0.0 frobSq) - - return min boundFactor boundFrob - -/-! ### No-dense scalar bounds for `W_Q·W_Kᵀ` and `W_V·W_O` - -We avoid materializing `modelDim×modelDim` products by working with `headDim×headDim` Grams. 
- -Key trace identities (for real matrices): -- For `W_V : d×h` and `W_O : h×d`, with `vo = W_V·W_O`: - `‖vo‖_F² = trace((W_VᵀW_V) · (W_O W_Oᵀ))`. -- For `W_Q : d×h` and `W_K : d×h`, with `qk = W_Q·W_Kᵀ`: - `‖qk‖_F² = trace((W_QᵀW_Q) · (W_KᵀW_K))`. - -These let us compute exact (in ℝ, given the Float-evaluated entries) Frobenius candidates -using only `headDim×headDim` matrices. --/ - -structure NoDenseProductBounds where - /-- `W_QᵀW_Q` (headDim×headDim). -/ - wqGram : ConcreteMatrix - /-- `W_KᵀW_K` (headDim×headDim). -/ - wkGram : ConcreteMatrix - /-- `W_VᵀW_V` (headDim×headDim). -/ - wvGram : ConcreteMatrix - /-- `W_O W_Oᵀ` (headDim×headDim). -/ - woGram : ConcreteMatrix - - /-- Frobenius-squared of `W_Q·W_Kᵀ`, computed via a `headDim×headDim` trace identity. -/ - qkFrobNormSq : Float - /-- Frobenius bound candidate `‖W_Q·W_Kᵀ‖_F`. -/ - qkDenseFrob : Float - /-- Tight Gram-product candidate derived from `headDim×headDim` Grams. -/ - qkDenseGram : Float - /-- Brauer/Cassini candidate computed on a `headDim×headDim` product with matching singular values. -/ - qkDenseBrauer : Float - qkFactorSchur : Float - qkFactorFrob : Float - wqOpGram : Float - wkOpGram : Float - qkFactorGram : Float - /-- Final chosen upper bound (min of rigorous candidates). -/ - qkOpBound : Float - - /-- Frobenius-squared of `W_V·W_O`, computed via a `headDim×headDim` trace identity. -/ - voFrobNormSq : Float - /-- Frobenius bound candidate `‖W_V·W_O‖_F`. -/ - voDenseFrob : Float - /-- Tight Gram-product candidate derived from `headDim×headDim` Grams. -/ - voDenseGram : Float - /-- Brauer/Cassini candidate computed on a `headDim×headDim` product with matching singular values. -/ - voDenseBrauer : Float - voFactorSchur : Float - voFactorFrob : Float - wvOpGram : Float - woOpGram : Float - voFactorGram : Float - /-- Final chosen upper bound (min of rigorous candidates). 
-/ - voOpBound : Float - -private def safeSqrt (x : Float) : Float := - Float.sqrt (max 0.0 x) - -/-- Tight Gram-based operator bound from an explicit PSD Gram matrix. -/ -private def opBoundFromGram (gram : ConcreteMatrix) (trace : Float) : Float := - if gram.numRows = 0 || gram.numCols = 0 then - 0.0 - else - let lambdaGersh := gram.infNormAbs - let lambdaBrauer := ConcreteMatrix.symmLambdaMaxUpperBrauer gram - let lambdaMoment := - ConcreteMatrix.psdLambdaMaxUpperMoment gram.numRows trace gram.frobeniusNormSq - let lambdaUpper := min lambdaGersh (min lambdaBrauer lambdaMoment) - let op := safeSqrt lambdaUpper - if Float.isNaN op || Float.isInf op then - safeSqrt lambdaGersh - else - op - -/-- Compute rigorous (in exact ℝ) scalar candidates for `‖W_Q·W_Kᵀ‖₂` and `‖W_V·W_O‖₂` -without forming any dense `modelDim×modelDim` products. -/ -def noDenseProductBounds (layer : ConcreteAttentionLayer) : NoDenseProductBounds := Id.run do - let wqGram := layer.W_Q.transpose.matmul layer.W_Q - let wkGram := layer.W_K.transpose.matmul layer.W_K - let wvGram := layer.W_V.transpose.matmul layer.W_V - let woGram := layer.W_O.matmul layer.W_O.transpose - - -- Frobenius candidates via trace identities (see docstring above). - let qkFrobSq := ConcreteMatrix.traceMul wqGram wkGram - let qkDenseFrob := safeSqrt qkFrobSq - let voFrobSq := ConcreteMatrix.traceMul wvGram woGram - let voDenseFrob := safeSqrt voFrobSq - - -- Gram-product spectral candidates (headDim×headDim): - -- ‖W_Q W_Kᵀ‖₂² = λ_max((W_QᵀW_Q)(W_KᵀW_K)) ≤ ‖(W_QᵀW_Q)(W_KᵀW_K)‖_∞. - let qkDenseGram := safeSqrt ((wqGram.matmul wkGram).infNormAbs) - -- ‖W_V W_O‖₂² = λ_max((W_VᵀW_V)(W_O W_Oᵀ)) ≤ ‖(W_VᵀW_V)(W_O W_Oᵀ)‖_∞. - let voDenseGram := safeSqrt ((wvGram.matmul woGram).infNormAbs) - - -- Brauer/Cassini candidates on `headDim×headDim` products with matching singular values. 
- let qkSmall := layer.W_K.transpose.matmul layer.W_Q - let qkDenseBrauer := qkSmall.opNormUpperBoundDenseBrauer - let voSmall := layer.W_O.matmul layer.W_V - let voDenseBrauer := voSmall.opNormUpperBoundDenseBrauer - - -- Factor bounds from submultiplicativity. - let qkFactorSchur := layer.W_Q.schurNormEst * layer.W_K.schurNormEst - let qkFactorFrob := layer.W_Q.frobeniusNorm * layer.W_K.frobeniusNorm - let wqOpGram0 := layer.W_Q.opNormUpperBoundViaGramInf - let wkOpGram0 := layer.W_K.opNormUpperBoundViaGramInf - let wqOpGram := min wqOpGram0 (opBoundFromGram wqGram layer.W_Q.frobeniusNormSq) - let wkOpGram := min wkOpGram0 (opBoundFromGram wkGram layer.W_K.frobeniusNormSq) - let qkFactorGram := wqOpGram * wkOpGram - - let voFactorSchur := layer.W_V.schurNormEst * layer.W_O.schurNormEst - let voFactorFrob := layer.W_V.frobeniusNorm * layer.W_O.frobeniusNorm - let wvOpGram0 := layer.W_V.opNormUpperBoundViaGramInf - -- PERFORMANCE: `W_O` is typically wide (`headDim×modelDim`), so compute on `W_Oᵀ` instead. 
- let woOpGram0 := layer.W_O.transpose.opNormUpperBoundViaGramInf - let wvOpGram := min wvOpGram0 (opBoundFromGram wvGram layer.W_V.frobeniusNormSq) - let woOpGram := min woOpGram0 (opBoundFromGram woGram layer.W_O.frobeniusNormSq) - let voFactorGram := wvOpGram * woOpGram - - let qkOpBoundCentered := layer.centeredQKOpBound - let qkOpBound := - min qkDenseFrob <| - min qkDenseGram <| - min qkDenseBrauer <| - min qkOpBoundCentered <| - min qkFactorSchur <| - min qkFactorFrob qkFactorGram - - let voOpBound := - min voDenseFrob <| - min voDenseGram <| - min voDenseBrauer <| - min voFactorSchur <| - min voFactorFrob voFactorGram - - return { - wqGram := wqGram - wkGram := wkGram - wvGram := wvGram - woGram := woGram - qkFrobNormSq := qkFrobSq - qkDenseFrob := qkDenseFrob - qkDenseGram := qkDenseGram - qkDenseBrauer := qkDenseBrauer - qkFactorSchur := qkFactorSchur - qkFactorFrob := qkFactorFrob - wqOpGram := wqOpGram - wkOpGram := wkOpGram - qkFactorGram := qkFactorGram - qkOpBound := qkOpBound - voFrobNormSq := voFrobSq - voDenseFrob := voDenseFrob - voDenseGram := voDenseGram - voDenseBrauer := voDenseBrauer - voFactorSchur := voFactorSchur - voFactorFrob := voFactorFrob - wvOpGram := wvOpGram - woOpGram := woOpGram - voFactorGram := voFactorGram - voOpBound := voOpBound - } - -private def opBoundMinOfMany (dense : ConcreteMatrix) - (leftFactor rightFactor : ConcreteMatrix) : Float := - let denseSchur := dense.schurNormEst - let denseFrob := dense.frobeniusNorm - let factorSchur := leftFactor.schurNormEst * rightFactor.schurNormEst - let factorFrob := leftFactor.frobeniusNorm * rightFactor.frobeniusNorm - min denseFrob (min denseSchur (min factorSchur factorFrob)) - -/-- A tighter (still sound-in-ℝ) Float upper bound on ‖W_Q · W_Kᵀ‖₂. 
- -We take the minimum of several valid upper bounds: - ‖M‖₂ ≤ schurNormEst(M) - ‖M‖₂ ≤ ‖M‖_F - ‖W_Q W_Kᵀ‖₂ ≤ ‖W_Q‖₂‖W_K‖₂ ≤ schur(W_Q)·schur(W_K) - ‖W_Q W_Kᵀ‖₂ ≤ ‖W_Q‖_F‖W_K‖_F - -This is computed in `Float` and is therefore a deterministic heuristic estimate. -In exact real arithmetic, each candidate is an upper bound, so taking `min` -can only tighten the bound. --/ -def queryKeyAlignmentOpBoundFrom (layer : ConcreteAttentionLayer) (qk : ConcreteMatrix) : Float := - let base := opBoundMinOfMany qk layer.W_Q layer.W_K - let wqOpGram := layer.W_Q.opNormUpperBoundViaGramInf - let wkOpGram := layer.W_K.opNormUpperBoundViaGramInf - let qkFactorGram := wqOpGram * wkOpGram - -- Low-rank Gram-product tightening (64×64): - -- ‖W_Q W_Kᵀ‖₂² = λ_max((W_QᵀW_Q)(W_KᵀW_K)) ≤ ‖(W_QᵀW_Q)(W_KᵀW_K)‖_∞. - let wqGram := layer.W_Q.transpose.matmul layer.W_Q - let wkGram := layer.W_K.transpose.matmul layer.W_K - let qkDenseGram := Float.sqrt (max 0.0 ((wqGram.matmul wkGram).infNormAbs)) - -- Brauer/Cassini Gram bound on a 64×64 product with the same singular values: - -- For `A = W_Q` and `B = W_K`, `‖A Bᵀ‖₂ = ‖Bᵀ A‖₂`. - let qkSmall := layer.W_K.transpose.matmul layer.W_Q - let qkDenseBrauer := qkSmall.opNormUpperBoundDenseBrauer - min (min base qkFactorGram) (min (min qkDenseGram qkDenseBrauer) qk.frobeniusNorm) - -/-- A tighter (still sound-in-ℝ) Float upper bound on ‖W_Q · W_Kᵀ‖₂. - -Convenience wrapper that materializes `W_Q·W_Kᵀ`. -Prefer `queryKeyAlignmentOpBoundFrom` when `qk` is already available. --/ -def queryKeyAlignmentOpBound (layer : ConcreteAttentionLayer) : Float := - -- Prefer the no-dense path by default to avoid allocating a `modelDim×modelDim` matrix. - (layer.noDenseProductBounds).qkOpBound - -/-- A tighter (still sound-in-ℝ) Float upper bound on ‖W_V · W_O‖₂. - -This is the minimum of Schur / Frobenius / factorized Schur / factorized Frobenius -upper bounds, analogous to `queryKeyAlignmentOpBoundFrom`. 
--/ -def valueOutputProjectionOpBoundFrom - (layer : ConcreteAttentionLayer) (vo : ConcreteMatrix) : Float := - let base := opBoundMinOfMany vo layer.W_V layer.W_O - let wvOpGram := layer.W_V.opNormUpperBoundViaGramInf - -- PERFORMANCE: `W_O` is typically wide (`headDim×modelDim`), so we compute the same - -- bound on `W_Oᵀ` instead (‖W_O‖₂ = ‖W_Oᵀ‖₂) to avoid an O(modelDim²) loop. - let woOpGram := layer.W_O.transpose.opNormUpperBoundViaGramInf - let voFactorGram := wvOpGram * woOpGram - -- Low-rank Gram-product tightening (64×64): - -- For `M = W_V W_O`, ‖M‖₂² = λ_max((W_VᵀW_V)(W_O W_Oᵀ)) up to reordering. - let wvGram := layer.W_V.transpose.matmul layer.W_V - let woGram := layer.W_O.matmul layer.W_O.transpose - let voDenseGram := Float.sqrt (max 0.0 ((wvGram.matmul woGram).infNormAbs)) - -- Brauer/Cassini Gram bound on a 64×64 product with the same singular values: - -- For `A = W_V` and `B = W_Oᵀ`, `‖A B‖₂ = ‖B A‖₂`. - let voSmall := layer.W_O.matmul layer.W_V - let voDenseBrauer := voSmall.opNormUpperBoundDenseBrauer - min (min base voFactorGram) (min (min voDenseGram voDenseBrauer) vo.frobeniusNorm) - - -/-- A tighter (still sound-in-ℝ) Float upper bound on ‖W_V · W_O‖₂. - -Convenience wrapper that materializes `W_V·W_O`. -Prefer `valueOutputProjectionOpBoundFrom` when `vo` is already available. --/ -def valueOutputProjectionOpBound (layer : ConcreteAttentionLayer) : Float := - -- Prefer the no-dense path by default to avoid allocating a `modelDim×modelDim` matrix. - (layer.noDenseProductBounds).voOpBound - -end ConcreteAttentionLayer - -/-! ### Local Gram-based operator bounds -/ - -/-- Tight Gram-based operator bound from an explicit PSD Gram matrix. - -This mirrors the logic used for attention-layer weight bounds but is scoped for -data-dependent activation Grams computed in discovery. 
--/ -private def opBoundFromGramLocal (gram : ConcreteMatrix) (trace : Float) : Float := - if gram.numRows = 0 || gram.numCols = 0 then - 0.0 - else - let lambdaGersh := gram.infNormAbs - let lambdaBrauer := ConcreteMatrix.symmLambdaMaxUpperBrauer gram - let lambdaMoment := - ConcreteMatrix.psdLambdaMaxUpperMoment gram.numRows trace gram.frobeniusNormSq - let lambdaPow4 : Float := - if gram.numRows ≤ 128 then - let gramSq := gram.matmul gram - let trace4 := ConcreteMatrix.traceMul gramSq gramSq - Float.sqrt (Float.sqrt (max 0.0 trace4)) - else - Float.inf - let lambdaUpper := min lambdaGersh (min lambdaBrauer (min lambdaMoment lambdaPow4)) - let op := Float.sqrt (max 0.0 lambdaUpper) - if Float.isNaN op || Float.isInf op then - Float.sqrt (max 0.0 lambdaGersh) - else - op - -/-! ## Concrete MLP Layer -/ - -/-- A concrete MLP (Feed-Forward) layer with exported weights. - -Standard transformer MLP: `output = W_out · activation(W_in · x + b_in) + b_out` - -For interpretability, we analyze individual **neurons** (columns of W_in / rows of W_out). -Each neuron i computes: `activation(W_in[:,i]·x + b_in[i])` and writes `W_out[i,:]` to output. 
--/ -structure ConcreteMLPLayer where - /-- Model dimension (embedding size) -/ - modelDim : Nat - /-- Hidden dimension (number of neurons) -/ - hiddenDim : Nat - /-- Input projection matrix (modelDim × hiddenDim): maps input to hidden activations -/ - W_in : ConcreteMatrix - /-- Output projection matrix (hiddenDim × modelDim): maps hidden to output -/ - W_out : ConcreteMatrix - /-- Input bias (hiddenDim) stored as 1×hiddenDim matrix for uniformity -/ - b_in : ConcreteMatrix - /-- Output bias (modelDim) stored as 1×modelDim matrix for uniformity -/ - b_out : ConcreteMatrix - /-- Dimension consistency for W_in -/ - W_in_dims : W_in.numRows = modelDim ∧ W_in.numCols = hiddenDim - /-- Dimension consistency for W_out -/ - W_out_dims : W_out.numRows = hiddenDim ∧ W_out.numCols = modelDim - /-- Dimension consistency for b_in -/ - b_in_dims : b_in.numRows = 1 ∧ b_in.numCols = hiddenDim - /-- Dimension consistency for b_out -/ - b_out_dims : b_out.numRows = 1 ∧ b_out.numCols = modelDim - -namespace ConcreteMLPLayer - -/-- Get the input weight vector for neuron i (column i of W_in). -/ -def neuronInputWeights (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Array Float := - if neuronIdx < layer.hiddenDim then - .ofFn fun row : Fin layer.modelDim => layer.W_in.getUnsafe row.val neuronIdx - else #[] - -/-- Get the output weight vector for neuron i (row i of W_out). -/ -def neuronOutputWeights (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Array Float := - if neuronIdx < layer.hiddenDim then - .ofFn fun col : Fin layer.modelDim => - layer.W_out.data[neuronIdx * layer.modelDim + col.val]! - else #[] - -/-- Compute the L2 norm of input weights for a neuron. -/ -def neuronInputNorm (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := - let weights := layer.neuronInputWeights neuronIdx - Float.sqrt (sumSquares weights) - -/-- Compute the L2 norm of output weights for a neuron. 
-/ -def neuronOutputNorm (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := - let weights := layer.neuronOutputWeights neuronIdx - Float.sqrt (sumSquares weights) - -/-- Compute the "influence magnitude" of a neuron: ‖W_in[:,i]‖ · ‖W_out[i,:]‖ - -This measures how much information can flow through neuron i. -For ReLU networks, this bounds the neuron's contribution to the output. --/ -def neuronInfluence (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := - layer.neuronInputNorm neuronIdx * layer.neuronOutputNorm neuronIdx - -/-- Get the bias for neuron i. -/ -def getBias (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := - if neuronIdx < layer.hiddenDim then - layer.b_in.data[neuronIdx]! - else 0.0 - -end ConcreteMLPLayer - -/-! ## Interval Bound Propagation (IBP) for MLP Activation Stability - -When analyzing circuit faithfulness, we need to know if ablating upstream components -can cause MLP neurons to "flip" their activation states. A neuron is **stable** if -its pre-activation stays positive (always active) or negative (always inactive) under -all perturbations bounded by ε. - -**Mathematical Setup:** -- Pre-activation: z = W_in^T · x + b -- For perturbation δx with ‖δx‖₂ ≤ ε: - - z' = W_in^T · (x + δx) + b = z + W_in^T · δx - - By Cauchy-Schwarz: |W_in^T · δx| ≤ ‖W_in‖₂ · ‖δx‖₂ ≤ ‖W_in‖₂ · ε -- Therefore: z - ε·‖W_in‖ ≤ z' ≤ z + ε·‖W_in‖ - -**Stability Criterion:** -- Neuron is "stably ON" if z - ε·‖W_in‖ > 0 -- Neuron is "stably OFF" if z + ε·‖W_in‖ < 0 -- Otherwise, neuron is "unstable" (may flip) - -**Pattern Term for Unstable Neurons:** -When a ReLU neuron flips, the linearization error is bounded by the magnitude -of the output weight times the activation change: - ‖ΔOutput‖ ≤ ‖W_out[i,:]‖ · |activation_change| - ≤ ‖W_out[i,:]‖ · max(|z + ε·‖W_in‖|, |z - ε·‖W_in‖|) - -For GeLU, the bound is tighter but we use ReLU-style conservative bounds. --/ - -/-- Result of interval bound propagation for a single neuron. 
-/ -structure NeuronIntervalBound where - /-- Neuron index within the layer -/ - neuronIdx : Nat - /-- Lower bound on pre-activation (z - ε·‖W_in‖) -/ - preActLower : Float - /-- Upper bound on pre-activation (z + ε·‖W_in‖) -/ - preActUpper : Float - /-- Nominal pre-activation (z = W_in^T · x + b) -/ - preActNominal : Float - /-- Input weight norm ‖W_in[:,i]‖ -/ - inputNorm : Float - /-- Output weight norm ‖W_out[i,:]‖ -/ - outputNorm : Float - deriving Repr - -namespace NeuronIntervalBound - -/-- Is this neuron stably active (always ON) under the perturbation bound? -/ -def isStablyActive (b : NeuronIntervalBound) : Bool := - b.preActLower > 0.0 - -/-- Is this neuron stably inactive (always OFF) under the perturbation bound? -/ -def isStablyInactive (b : NeuronIntervalBound) : Bool := - b.preActUpper < 0.0 - -/-- Is this neuron stable (won't flip activation state)? -/ -def isStable (b : NeuronIntervalBound) : Bool := - b.isStablyActive || b.isStablyInactive - -/-- Is this neuron unstable (may flip activation state)? -/ -def isUnstable (b : NeuronIntervalBound) : Bool := - ¬b.isStable - -/-- The "flip margin" - how close the pre-activation interval is to zero. - -For stable neurons this is positive (distance from zero). -For unstable neurons this is negative (interval crosses zero). --/ -def flipMargin (b : NeuronIntervalBound) : Float := - min b.preActLower (-b.preActUpper) - -/-- Bound on the activation change if the neuron flips (ReLU). - -For ReLU, if the neuron flips from ON to OFF, the activation changes from z to 0. -If it flips from OFF to ON, the activation changes from 0 to z. -The maximum change magnitude is bounded by max(|z_lower|, |z_upper|). --/ -def maxActivationChange (b : NeuronIntervalBound) : Float := - if b.isStable then 0.0 - else max (Float.abs b.preActLower) (Float.abs b.preActUpper) - -/-- Bound on the output error due to potential activation flip. 
- -This is the pattern term contribution for an unstable neuron: - ‖ΔOutput‖ ≤ ‖W_out[i,:]‖ · max_activation_change --/ -def patternTermBound (b : NeuronIntervalBound) : Float := - b.outputNorm * b.maxActivationChange - -end NeuronIntervalBound - -/-- Result of IBP analysis for an entire MLP layer. -/ -structure MLPIntervalAnalysis where - /-- Layer index -/ - layerIdx : Nat - /-- Per-neuron interval bounds -/ - neuronBounds : Array NeuronIntervalBound - /-- Input perturbation norm (ε) used for analysis -/ - perturbationNorm : Float - /-- Number of stable neurons -/ - numStable : Nat - /-- Number of unstable neurons -/ - numUnstable : Nat - /-- Total pattern term bound for unstable neurons -/ - totalPatternBound : Float - deriving Repr - -namespace MLPIntervalAnalysis - -/-- Fraction of neurons that are stable. -/ -def stabilityRatio (a : MLPIntervalAnalysis) : Float := - if a.neuronBounds.size = 0 then 1.0 - else a.numStable.toFloat / a.neuronBounds.size.toFloat - -/-- Is the layer "fully stable" (all neurons stable)? -/ -def isFullyStable (a : MLPIntervalAnalysis) : Bool := - a.numUnstable = 0 - -/-- Get bounds for a specific neuron. -/ -def getNeuronBound (a : MLPIntervalAnalysis) (idx : Nat) : Option NeuronIntervalBound := - if h : idx < a.neuronBounds.size then some a.neuronBounds[idx] else none - -end MLPIntervalAnalysis - -namespace ConcreteMLPLayer - -/-- Compute pre-activations for all neurons given an input vector (single position). - -Returns array of pre-activations: z[i] = W_in[:,i]^T · x + b[i] --/ -def computePreActivations (layer : ConcreteMLPLayer) (input : Array Float) : - Array Float := - .ofFn fun i : Fin layer.hiddenDim => Id.run do - let mut z : Float := layer.b_in.data[i.val]! - for j in [:layer.modelDim] do - -- SAFETY: input should have size modelDim, but getD provides safe fallback - let x_j := input.getD j 0.0 - let w_ji := layer.W_in.data[j * layer.hiddenDim + i.val]! 
- z := z + w_ji * x_j - return z - -/-- Compute interval bounds for all neurons given input and perturbation bound. - -**Algorithm:** -For each neuron i: -1. Compute nominal pre-activation: z = W_in[:,i]^T · x + b[i] -2. Compute Δz = ε · ‖W_in[:,i]‖₂ (maximum change due to perturbation) -3. Set bounds: [z - Δz, z + Δz] -4. Determine stability based on whether interval crosses zero - -**Parameters:** -- `input`: Nominal input vector (modelDim elements) -- `perturbationNorm`: L2 norm bound on input perturbation (ε) --/ -def computeIntervalBounds (layer : ConcreteMLPLayer) - (input : Array Float) (perturbationNorm : Float) : Array NeuronIntervalBound := - let preActs := layer.computePreActivations input - .ofFn fun i : Fin layer.hiddenDim => - -- SAFETY: preActs has size hiddenDim by construction from computePreActivations - let z := preActs[i.val]! - let inputNorm := layer.neuronInputNorm i.val - let outputNorm := layer.neuronOutputNorm i.val - let delta := perturbationNorm * inputNorm - { - neuronIdx := i.val - preActLower := z - delta - preActUpper := z + delta - preActNominal := z - inputNorm := inputNorm - outputNorm := outputNorm - } - -/-- Run full IBP analysis on an MLP layer. - -Returns comprehensive analysis including stability counts and total pattern bound. 
--/ -def analyzeIntervalBounds (layer : ConcreteMLPLayer) (layerIdx : Nat) - (input : Array Float) (perturbationNorm : Float) : MLPIntervalAnalysis := Id.run do - let bounds := layer.computeIntervalBounds input perturbationNorm - let mut numStable : Nat := 0 - let mut numUnstable : Nat := 0 - let mut totalPattern : Float := 0.0 - - for b in bounds do - if b.isStable then - numStable := numStable + 1 - else - numUnstable := numUnstable + 1 - totalPattern := totalPattern + b.patternTermBound - - { - layerIdx := layerIdx - neuronBounds := bounds - perturbationNorm := perturbationNorm - numStable := numStable - numUnstable := numUnstable - totalPatternBound := totalPattern - } - -/-- Compute the pattern term bound for a single neuron using IBP. - -This is a convenient wrapper for single-neuron queries, useful for -`computeNeuronImportance`. - -**Parameters:** -- `neuronIdx`: Index of the neuron to analyze -- `input`: Nominal input vector (from forward pass) -- `perturbationNorm`: L2 norm bound on input perturbation - -**Returns:** Pattern term bound (0 if stable, output_norm * max_change if unstable) --/ -def neuronPatternTermBoundIBP (layer : ConcreteMLPLayer) (neuronIdx : Nat) - (input : Array Float) (perturbationNorm : Float) : Float := - if neuronIdx ≥ layer.hiddenDim then 0.0 - else - let z := Id.run do - let mut acc : Float := layer.b_in.data[neuronIdx]! - for j in [:layer.modelDim] do - let x_j := input.getD j 0.0 - let w_ji := layer.W_in.data[j * layer.hiddenDim + neuronIdx]! 
- acc := acc + w_ji * x_j - acc - let inputNorm := layer.neuronInputNorm neuronIdx - let outputNorm := layer.neuronOutputNorm neuronIdx - let delta := perturbationNorm * inputNorm - - -- Check stability - let lower := z - delta - let upper := z + delta - if lower > 0.0 || upper < 0.0 then - -- Stable: no pattern term error - 0.0 - else - -- Unstable: bound by output weight times max activation change - outputNorm * max (Float.abs lower) (Float.abs upper) - -end ConcreteMLPLayer - -/-- Create a ConcreteMLPLayer from raw Float arrays. -/ -def mkConcreteMLPLayer - (modelDim hiddenDim : Nat) - (w_in w_out b_in b_out : Array Float) - (hw_in : w_in.size = modelDim * hiddenDim) - (hw_out : w_out.size = hiddenDim * modelDim) - (hb_in : b_in.size = hiddenDim) - (hb_out : b_out.size = modelDim) : ConcreteMLPLayer where - modelDim := modelDim - hiddenDim := hiddenDim - W_in := { numRows := modelDim, numCols := hiddenDim, data := w_in, size_eq := hw_in } - W_out := { numRows := hiddenDim, numCols := modelDim, data := w_out, size_eq := hw_out } - b_in := { numRows := 1, numCols := hiddenDim, data := b_in, - size_eq := by simp [hb_in] } - b_out := { numRows := 1, numCols := modelDim, data := b_out, - size_eq := by simp [hb_out] } - W_in_dims := ⟨rfl, rfl⟩ - W_out_dims := ⟨rfl, rfl⟩ - b_in_dims := ⟨rfl, rfl⟩ - b_out_dims := ⟨rfl, rfl⟩ - -/-! ## Sparse Autoencoders (SAEs) for Feature-Level Analysis - -Sparse Autoencoders decompose MLP activations into interpretable **features**. -Instead of analyzing raw neurons (which are often polysemantic), we analyze -sparse linear combinations that correspond to semantic concepts. 
- -**Architecture:** -- Encoder: `f = ReLU(W_enc · x + b_enc)` maps residual stream to sparse features -- Decoder: `x' = W_dec · f + b_dec` reconstructs the residual stream -- Sparsity: Only a small number of features are active (f[k] > 0) for any input - -**Key Insight for Circuit Discovery:** -The Jacobian of an MLP approximated through an SAE becomes a sum of rank-1 matrices: - `J ≈ Σ_{k ∈ active} W_dec[:,k] ⊗ W_enc[k,:]` - -This allows us to: -1. Identify which **features** (not neurons) are responsible for behavior -2. Discover cleaner circuits for complex behaviors -3. Handle polysemantic neurons by decomposing them into monosemantic features - -**Reconstruction Error:** -The SAE approximation introduces error: `‖MLP(x) - SAE(x)‖_F` -This must be accounted for in the total faithfulness bound. --/ - -/-- A Sparse Autoencoder for analyzing MLP features. - -Trained to reconstruct residual stream activations with sparse latent codes. -Typically has many more features than the residual stream dimension (e.g., 16x). 
--/ -structure ConcreteSAE where - /-- Input/output dimension (residual stream size) -/ - inputDim : Nat - /-- Number of features (typically >> inputDim for overcomplete SAEs) -/ - numFeatures : Nat - /-- Encoder weights (inputDim × numFeatures): W_enc[i,k] = weight from input i to feature k -/ - W_enc : ConcreteMatrix - /-- Decoder weights (numFeatures × inputDim): W_dec[k,j] = weight from feature k to output j -/ - W_dec : ConcreteMatrix - /-- Encoder bias (numFeatures): b_enc[k] = bias for feature k -/ - b_enc : ConcreteMatrix - /-- Decoder bias (inputDim): b_dec[j] = bias for output j -/ - b_dec : ConcreteMatrix - /-- Dimension constraints -/ - W_enc_dims : W_enc.numRows = inputDim ∧ W_enc.numCols = numFeatures - W_dec_dims : W_dec.numRows = numFeatures ∧ W_dec.numCols = inputDim - b_enc_dims : b_enc.numRows = 1 ∧ b_enc.numCols = numFeatures - b_dec_dims : b_dec.numRows = 1 ∧ b_dec.numCols = inputDim - -namespace ConcreteSAE - -/-- ReLU activation for SAE encoder. -/ -private def relu (x : Float) : Float := if x > 0.0 then x else 0.0 - -/-- Encode input to sparse feature activations: f = ReLU(W_enc^T · x + b_enc) - -Note: W_enc is stored as (inputDim × numFeatures), so we compute x^T · W_enc. - -PERFORMANCE: Pre-allocates output array (critical for SAE-based circuit discovery). --/ -def encode (sae : ConcreteSAE) (input : Array Float) : Array Float := - .ofFn fun k : Fin sae.numFeatures => Id.run do - let mut z : Float := sae.b_enc.data[k.val]! - for i in [:sae.inputDim] do - let x_i := input.getD i 0.0 - let w_ik := sae.W_enc.data[i * sae.numFeatures + k.val]! - z := z + x_i * w_ik - return relu z - -/-- Decode sparse features back to residual stream: x' = W_dec^T · f + b_dec - -Note: W_dec is stored as (numFeatures × inputDim), so we compute f^T · W_dec. - -PERFORMANCE: Pre-allocates output array (critical for SAE-based circuit discovery). 
--/ -def decode (sae : ConcreteSAE) (features : Array Float) : Array Float := - .ofFn fun j : Fin sae.inputDim => Id.run do - let mut y : Float := sae.b_dec.data[j.val]! - for k in [:sae.numFeatures] do - let f_k := features.getD k 0.0 - let w_kj := sae.W_dec.data[k * sae.inputDim + j.val]! - y := y + f_k * w_kj - return y - -/-- Full forward pass: encode then decode. -/ -def forward (sae : ConcreteSAE) (input : Array Float) : Array Float := - sae.decode (sae.encode input) - -/-- Compute reconstruction error ‖x - SAE(x)‖₂ for a single input. -/ -def reconstructionError (sae : ConcreteSAE) (input : Array Float) : Float := Id.run do - let reconstructed := sae.forward input - let mut errSq : Float := 0.0 - for i in [:sae.inputDim] do - let diff := input.getD i 0.0 - reconstructed.getD i 0.0 - errSq := errSq + diff * diff - Float.sqrt errSq - -/-- Compute reconstruction error for a matrix (multiple positions). -/ -def reconstructionErrorMatrix (sae : ConcreteSAE) (input : ConcreteMatrix) : Float := Id.run do - let mut totalErrSq : Float := 0.0 - for pos in [:input.numRows] do - let inputVec : Array Float := .ofFn fun d : Fin input.numCols => input.getUnsafe pos d.val - let err := sae.reconstructionError inputVec - totalErrSq := totalErrSq + err * err - Float.sqrt totalErrSq - -/-- Get indices of active features (f[k] > threshold). -/ -def activeFeatures (sae : ConcreteSAE) (input : Array Float) - (threshold : Float := 0.0) : Array Nat := Id.run do - let features := sae.encode input - let mut active : Array Nat := #[] - for k in [:sae.numFeatures] do - if features.getD k 0.0 > threshold then - active := active.push k - active - -/-- Count active features for an input. -/ -def numActiveFeatures (sae : ConcreteSAE) (input : Array Float) - (threshold : Float := 0.0) : Nat := - (sae.activeFeatures input threshold).size - -/-- Get the encoder weight vector for feature k (column k of W_enc). 
-/ -def encoderWeights (sae : ConcreteSAE) (featureIdx : Nat) : Array Float := - if featureIdx < sae.numFeatures then - .ofFn fun i : Fin sae.inputDim => sae.W_enc.getUnsafe i.val featureIdx - else #[] - -/-- Get the decoder weight vector for feature k (row k of W_dec). -/ -def decoderWeights (sae : ConcreteSAE) (featureIdx : Nat) : Array Float := - if featureIdx < sae.numFeatures then - .ofFn fun j : Fin sae.inputDim => - sae.W_dec.data[featureIdx * sae.inputDim + j.val]! - else #[] - -/-- Compute the L2 norm of encoder weights for feature k. -/ -def encoderNorm (sae : ConcreteSAE) (featureIdx : Nat) : Float := - let weights := sae.encoderWeights featureIdx - Float.sqrt (sumSquares weights) - -/-- Compute the L2 norm of decoder weights for feature k. -/ -def decoderNorm (sae : ConcreteSAE) (featureIdx : Nat) : Float := - let weights := sae.decoderWeights featureIdx - Float.sqrt (sumSquares weights) - -/-- Compute the "influence magnitude" of feature k: ‖W_enc[:,k]‖ · ‖W_dec[k,:]‖ - -This bounds how much information can flow through the feature. -Analogous to `neuronInfluence` for MLP neurons. --/ -def featureInfluence (sae : ConcreteSAE) (featureIdx : Nat) : Float := - sae.encoderNorm featureIdx * sae.decoderNorm featureIdx - -/-- Get encoder bias for feature k. -/ -def encoderBias (sae : ConcreteSAE) (featureIdx : Nat) : Float := - if featureIdx < sae.numFeatures then sae.b_enc.data[featureIdx]! else 0.0 - -/-- Compute the pre-activation for feature k given input. -/ -def featurePreActivation (sae : ConcreteSAE) (featureIdx : Nat) - (input : Array Float) : Float := Id.run do - if featureIdx ≥ sae.numFeatures then return 0.0 - let mut z : Float := sae.b_enc.data[featureIdx]! - for i in [:sae.inputDim] do - let x_i := input.getD i 0.0 - let w_ik := sae.W_enc.data[i * sae.numFeatures + featureIdx]! - z := z + x_i * w_ik - z - -/-- Check if feature k is active (pre-activation > 0) for given input. 
-/ -def isFeatureActive (sae : ConcreteSAE) (featureIdx : Nat) (input : Array Float) : Bool := - sae.featurePreActivation featureIdx input > 0.0 - -end ConcreteSAE - -/-- Create a ConcreteSAE from raw Float arrays. -/ -def mkConcreteSAE - (inputDim numFeatures : Nat) - (w_enc w_dec b_enc b_dec : Array Float) - (hw_enc : w_enc.size = inputDim * numFeatures) - (hw_dec : w_dec.size = numFeatures * inputDim) - (hb_enc : b_enc.size = numFeatures) - (hb_dec : b_dec.size = inputDim) : ConcreteSAE where - inputDim := inputDim - numFeatures := numFeatures - W_enc := { numRows := inputDim, numCols := numFeatures, data := w_enc, size_eq := hw_enc } - W_dec := { numRows := numFeatures, numCols := inputDim, data := w_dec, size_eq := hw_dec } - b_enc := { numRows := 1, numCols := numFeatures, data := b_enc, - size_eq := by simp [hb_enc] } - b_dec := { numRows := 1, numCols := inputDim, data := b_dec, - size_eq := by simp [hb_dec] } - W_enc_dims := ⟨rfl, rfl⟩ - W_dec_dims := ⟨rfl, rfl⟩ - b_enc_dims := ⟨rfl, rfl⟩ - b_dec_dims := ⟨rfl, rfl⟩ - -/-! ### SAE Interval Bound Propagation - -Like MLP neurons, SAE features can flip activation states under perturbation. -We extend IBP to track feature stability. --/ - -/-- Result of IBP analysis for a single SAE feature. -/ -structure SAEFeatureIntervalBound where - /-- Feature index -/ - featureIdx : Nat - /-- Lower bound on pre-activation -/ - preActLower : Float - /-- Upper bound on pre-activation -/ - preActUpper : Float - /-- Nominal pre-activation -/ - preActNominal : Float - /-- Encoder weight norm ‖W_enc[:,k]‖ -/ - encoderNorm : Float - /-- Decoder weight norm ‖W_dec[k,:]‖ -/ - decoderNorm : Float - deriving Repr - -namespace SAEFeatureIntervalBound - -/-- Is this feature stably active? -/ -def isStablyActive (b : SAEFeatureIntervalBound) : Bool := b.preActLower > 0.0 - -/-- Is this feature stably inactive? 
-/ -def isStablyInactive (b : SAEFeatureIntervalBound) : Bool := b.preActUpper < 0.0 - -/-- Is this feature stable (won't flip)? -/ -def isStable (b : SAEFeatureIntervalBound) : Bool := b.isStablyActive || b.isStablyInactive - -/-- Pattern term bound if this feature flips. -/ -def patternTermBound (b : SAEFeatureIntervalBound) : Float := - if b.isStable then 0.0 - else b.decoderNorm * max (Float.abs b.preActLower) (Float.abs b.preActUpper) - -end SAEFeatureIntervalBound - -/-- Result of IBP analysis for an entire SAE. -/ -structure SAEIntervalAnalysis where - /-- Per-feature bounds -/ - featureBounds : Array SAEFeatureIntervalBound - /-- Perturbation norm used -/ - perturbationNorm : Float - /-- Number of stable features -/ - numStable : Nat - /-- Number of unstable features -/ - numUnstable : Nat - /-- Total pattern term bound from unstable features -/ - totalPatternBound : Float - /-- Reconstruction error (SAE approximation) -/ - reconstructionError : Float - deriving Repr - -namespace SAEIntervalAnalysis - -/-- Stability ratio. -/ -def stabilityRatio (a : SAEIntervalAnalysis) : Float := - if a.featureBounds.size = 0 then 1.0 - else a.numStable.toFloat / a.featureBounds.size.toFloat - -/-- Total error bound (pattern + reconstruction). -/ -def totalErrorBound (a : SAEIntervalAnalysis) : Float := - a.totalPatternBound + a.reconstructionError - -end SAEIntervalAnalysis - -namespace ConcreteSAE - -/-- Compute interval bounds for all features given input and perturbation. - -PERFORMANCE: Pre-allocates result array with `Array.ofFn` to avoid O(n) reallocations. 
--/ -def computeFeatureIntervalBounds (sae : ConcreteSAE) - (input : Array Float) (perturbationNorm : Float) : Array SAEFeatureIntervalBound := - .ofFn fun k : Fin sae.numFeatures => - let z := sae.featurePreActivation k.val input - let encNorm := sae.encoderNorm k.val - let decNorm := sae.decoderNorm k.val - let delta := perturbationNorm * encNorm - { - featureIdx := k.val - preActLower := z - delta - preActUpper := z + delta - preActNominal := z - encoderNorm := encNorm - decoderNorm := decNorm - } - -/-- Run full IBP analysis on an SAE. -/ -def analyzeIntervalBounds (sae : ConcreteSAE) (input : Array Float) - (perturbationNorm : Float) : SAEIntervalAnalysis := Id.run do - let bounds := sae.computeFeatureIntervalBounds input perturbationNorm - let mut numStable : Nat := 0 - let mut numUnstable : Nat := 0 - let mut totalPattern : Float := 0.0 - - for b in bounds do - if b.isStable then - numStable := numStable + 1 - else - numUnstable := numUnstable + 1 - totalPattern := totalPattern + b.patternTermBound - - let reconErr := sae.reconstructionError input - - { - featureBounds := bounds - perturbationNorm := perturbationNorm - numStable := numStable - numUnstable := numUnstable - totalPatternBound := totalPattern - reconstructionError := reconErr - } - -/-- Pattern term bound for a single feature using IBP. -/ -def featurePatternTermBoundIBP (sae : ConcreteSAE) (featureIdx : Nat) - (input : Array Float) (perturbationNorm : Float) : Float := - if featureIdx ≥ sae.numFeatures then 0.0 - else - let z := sae.featurePreActivation featureIdx input - let encNorm := sae.encoderNorm featureIdx - let decNorm := sae.decoderNorm featureIdx - let delta := perturbationNorm * encNorm - let lower := z - delta - let upper := z + delta - if lower > 0.0 || upper < 0.0 then 0.0 - else decNorm * max (Float.abs lower) (Float.abs upper) - -end ConcreteSAE - -/-! ## Attention Weights Computation -/ - -/-- Concrete attention weights for a sequence. 
-A[q][k] = attention weight from query position q to key position k. -/ -structure ConcreteAttentionWeights where - /-- Sequence length -/ - seqLen : Nat - /-- Attention weights stored row-major: weights[q * seqLen + k] = A[q,k] -/ - weights : Array Float - /-- Size check -/ - size_eq : weights.size = seqLen * seqLen - -namespace ConcreteAttentionWeights - -/-- Access A[q, k]. -/ -def get (A : ConcreteAttentionWeights) (q k : Nat) : Float := - if q < A.seqLen ∧ k < A.seqLen then - -- Index is in-bounds by `size_eq` and the guard above. - A.weights[q * A.seqLen + k]! - else 0.0 - -/-- Fast access to `A[q, k]` assuming `q < seqLen` and `k < seqLen`. -/ -@[inline] def getUnsafe (A : ConcreteAttentionWeights) (q k : Nat) : Float := - A.weights[q * A.seqLen + k]! - -/-- Convert attention weights to a `ConcreteMatrix` for use with matrix multiplication. -/ -def toMatrix (A : ConcreteAttentionWeights) : ConcreteMatrix where - numRows := A.seqLen - numCols := A.seqLen - data := A.weights - size_eq := by - simpa using A.size_eq - -/-- Compute softmax for a row of logits. -/ -@[inline] -def softmaxRow (logits : Array Float) : Array Float := - Id.run do - -- PERFORMANCE: keep arrays linear to enable in-place updates - -- (see Lean Reference Manual: runtime reference counting + array performance). - let mut expVals : Array Float := logits - let mut maxVal : Float := -1e30 - for i in [:expVals.size] do - let v := expVals[i]! - if v > maxVal then maxVal := v - let mut sumExp : Float := 0.0 - for i in [:expVals.size] do - let v := Float.exp (expVals[i]! - maxVal) - expVals := expVals.set! i v - sumExp := sumExp + v - if sumExp > 0.0 then - for i in [:expVals.size] do - expVals := expVals.set! i (expVals[i]! / sumExp) - return expVals - -/-- Compute attention weights given queries, keys, and scaling. 
-/ -def compute (queries keys : ConcreteMatrix) (scale : Float) - (seqLen : Nat) - (causal : Bool := true) : ConcreteAttentionWeights := Id.run do - -- PERFORMANCE: avoid allocating an `Array (Array Float)` of rows; write the flattened - -- weights row-by-row into a single mutable array. - let cols := min queries.numCols keys.numCols - let n := seqLen * seqLen - let mut weights : { w : Array Float // w.size = n } := - ⟨Array.replicate n 0.0, by simp [n]⟩ - -- Reuse a single row buffer to avoid per-row allocations. - let mut rowScores : Array Float := Array.replicate seqLen (-1e30) - for q in [:seqLen] do - -- Initialize to -∞ and only fill the causal prefix when `causal = true`. - for i in [:seqLen] do - rowScores := rowScores.set! i (-1e30) - let stop := if causal then min (q + 1) seqLen else seqLen - let qBase := q * queries.numCols - for j in [:stop] do - if q < queries.numRows ∧ j < keys.numRows then - let jBase := j * keys.numCols - let mut dotProd : Float := 0.0 - for d in [:cols] do - -- SAFETY: within this branch `q < queries.numRows` and `j < keys.numRows`, - -- and `d < cols ≤ queries.numCols/keys.numCols`. - dotProd := dotProd + queries.data[qBase + d]! * keys.data[jBase + d]! - rowScores := rowScores.set! j (dotProd / scale) - rowScores := softmaxRow rowScores - let rowBase := q * seqLen - for k in [:stop] do - let weights' := weights.1.set! (rowBase + k) (rowScores[k]!) - have weights'SizeEq : weights'.size = n := by - have hsize : weights'.size = weights.1.size := by - -- `set!` is `setIfInBounds`, which preserves size. - simp [weights'] - exact hsize.trans weights.2 - weights := ⟨weights', weights'SizeEq⟩ - return { - seqLen := seqLen - weights := weights.1 - size_eq := by simpa [n] using weights.2 - } - -end ConcreteAttentionWeights - -/-! ## Softmax Jacobian Sparsity Analysis - -For a probability vector p, the softmax Jacobian J has entries J_ij = p_i(δ_ij - p_j). 
-The Frobenius norm squared of J for a single row is: - - ‖J‖²_F = Σᵢⱼ p_i²(δ_ij - p_j)² = Σᵢ p_i²(1-p_i)² + Σᵢ≠ⱼ p_i²p_j² - = Σᵢ p_i² - 2Σᵢ p_i³ + (Σᵢ p_i²)² - -A simpler upper bound is `Σᵢ p_i(1-p_i) = 1 - Σᵢ p_i²` (the Gini impurity). - -For sparse (one-hot) distributions: Σ p_i² ≈ 1, so the bound is ≈ 0 -For uniform distributions: Σ p_i² = 1/n, so the bound is √((n-1)/n) - -This allows us to compute much tighter pattern term bounds for sharp attention heads. --/ - -/-- Compute the "effective softmax derivative norm" for a single probability row. - -For probability vector p, this computes `sqrt(Σᵢ p_i(1-p_i)) = sqrt(1 - Σᵢ p_i²)`. -This bounds the Frobenius norm of the softmax Jacobian for that row. - -- One-hot distribution → 0 (no gradient flow through softmax) -- Uniform over n → sqrt((n-1)/n) ≈ 1 for large n --/ -def softmaxRowJacobianNorm (row : Array Float) : Float := - let sumSq := sumSquares row - Float.sqrt (max 0.0 (1.0 - sumSq)) - -/-- Compute the average softmax Jacobian norm across all rows of attention weights. - -This provides a data-dependent bound on the softmax Jacobian that is much tighter -than a coarse constant bound (e.g. 1.0), especially for sharp attention patterns. --/ -def ConcreteAttentionWeights.avgSoftmaxJacobianNorm (A : ConcreteAttentionWeights) : Float := - if A.seqLen = 0 then 0.0 - else Id.run do - let mut totalNormSq : Float := 0.0 - for q in [:A.seqLen] do - -- Extract row q - let mut sumSq : Float := 0.0 - let rowBase := q * A.seqLen - for k in [:A.seqLen] do - -- SAFETY: `q < seqLen` and `k < seqLen` by loop bounds. - let p := A.weights[rowBase + k]! - sumSq := sumSq + p * p - -- Jacobian norm squared for this row is bounded by 1 - sumSq - totalNormSq := totalNormSq + max 0.0 (1.0 - sumSq) - -- Return RMS (root mean square) of per-row norms - Float.sqrt (totalNormSq / A.seqLen.toFloat) - -/-- Compute the maximum softmax Jacobian norm across all rows. 
- -More conservative than avg, but still much tighter than a coarse constant bound for sparse attention. --/ -def ConcreteAttentionWeights.maxSoftmaxJacobianNorm (A : ConcreteAttentionWeights) : Float := - if A.seqLen = 0 then 0.0 - else Id.run do - let mut maxNorm : Float := 0.0 - for q in [:A.seqLen] do - let mut sumSq : Float := 0.0 - let rowBase := q * A.seqLen - for k in [:A.seqLen] do - -- SAFETY: `q < seqLen` and `k < seqLen` by loop bounds. - let p := A.weights[rowBase + k]! - sumSq := sumSq + p * p - let rowNorm := Float.sqrt (max 0.0 (1.0 - sumSq)) - maxNorm := max maxNorm rowNorm - maxNorm - -/-! ### Attention Jacobian heuristics - -The following quantities are computed from `Float` attention weights produced by a `Float` -softmax. They must be treated as **heuristic estimates**, not sound certificates. --/ - -/-- Diagnostics for the softmax-Jacobian operator norm bound. - -The full softmax Jacobian over an attention matrix is block-diagonal over rows, -so the overall operator norm is the maximum of the per-row operator norms. - -We record statistics for the row that attains this maximum. --/ -structure SoftmaxJacobianOpDiag where - /-- Overall (block-diagonal) operator norm upper bound, `max_q rowBound(q)`. -/ - opBound : Float - /-- `max_i p_i` for the maximizing row. -/ - maxRowMaxP : Float - /-- `tr(J) = 1 - ∑ p_i^2` for the maximizing row. -/ - maxRowTraceBound : Float - /-- PSD moment bound for the maximizing row, derived from `tr(J)` and `‖J‖_F²`. -/ - maxRowMomentBound : Float - /-- Gershgorin / induced-∞ bound `max_i 2 p_i (1 - p_i)` for the maximizing row. -/ - maxRowGersh : Float - /-- The final per-row bound used for the maximizing row. -/ - maxRowBoundUsed : Float - /-- Number of rows that triggered a conservative fallback (NaN/Inf/zero-sum). -/ - numRowsFallback : Nat - deriving Repr - -/-- Heuristic estimate of the softmax-Jacobian operator norm per row (then maxed). 
- -For a probability row `p`, the softmax Jacobian is: -`J = diag(p) - p pᵀ`. -It is positive semidefinite and satisfies: -- `J ≤ diag(p)` so `‖J‖₂ ≤ maxᵢ pᵢ` -- `‖J‖₂ ≤ tr(J) = 1 - Σᵢ pᵢ²` - -We take the tighter bound -`min(maxᵢ pᵢ, 1 - Σᵢ pᵢ²)` for each row, -and then take the maximum over rows. -This is especially sharp for nearly one-hot or nearly uniform rows. --/ -def ConcreteAttentionWeights.softmaxJacobianOpDiag - (A : ConcreteAttentionWeights) : SoftmaxJacobianOpDiag := - if A.seqLen = 0 then - { opBound := 0.0 - maxRowMaxP := 0.0 - maxRowTraceBound := 0.0 - maxRowMomentBound := 0.0 - maxRowGersh := 0.0 - maxRowBoundUsed := 0.0 - numRowsFallback := 0 } - else Id.run do - let mut maxBound : Float := 0.0 - let mut bestMaxP : Float := 0.0 - let mut bestTrace : Float := 0.0 - let mut bestMoment : Float := 0.0 - let mut bestGersh : Float := 0.0 - let mut bestUsed : Float := 0.0 - let mut fallbackCount : Nat := 0 - - for q in [:A.seqLen] do - let rowBase := q * A.seqLen - - -- Pass 1: clamp negatives to 0 and compute sum. - let mut sumP : Float := 0.0 - for k in [:A.seqLen] do - let p0 := A.weights[rowBase + k]! - let p := if p0 < 0.0 then 0.0 else p0 - sumP := sumP + p - - let mut rowBound : Float := 0.5 - let mut rowMaxP : Float := 0.0 - let mut rowTrace : Float := 0.0 - let mut rowMoment : Float := 0.0 - let mut rowGersh : Float := 0.0 - - if sumP ≤ 0.0 || Float.isNaN sumP || Float.isInf sumP then - -- Conservative fallback: global bound for any probability row. - fallbackCount := fallbackCount + 1 - rowBound := 0.5 - else - -- Pass 2: renormalize and compute per-row bounds. - let invSum := 1.0 / sumP - let mut sumSq : Float := 0.0 - let mut sumCube : Float := 0.0 - let mut maxP : Float := 0.0 - let mut gersh : Float := 0.0 - for k in [:A.seqLen] do - let p0 := A.weights[rowBase + k]! 
- let pClamped := if p0 < 0.0 then 0.0 else p0 - let p := pClamped * invSum - sumSq := sumSq + p * p - sumCube := sumCube + p * p * p - if p > maxP then maxP := p - let g := 2.0 * p * (1.0 - p) - if g > gersh then gersh := g - let traceBound := max 0.0 (1.0 - sumSq) - -- Moment bound: J is PSD with - -- tr(J) = 1 - Σ p_i² - -- ‖J‖_F² = Σ (p_i - p_i²)² + Σ_{i≠j} (p_i p_j)² - -- = (Σ p_i²) - 2(Σ p_i³) + (Σ p_i²)². - let frob2 := max 0.0 (sumSq - 2.0 * sumCube + sumSq * sumSq) - let momentBound := ConcreteMatrix.psdLambdaMaxUpperMoment A.seqLen traceBound frob2 - -- Rigorous (for probability rows): - -- λ_max(J) ≤ maxP - -- λ_max(J) ≤ tr(J) - -- λ_max(J) ≤ ‖J‖_∞ = max_i 2 p_i (1-p_i) - -- λ_max(J) ≤ 1/2 - let bound0 := min maxP (min traceBound (min gersh momentBound)) - let bound1 := min bound0 0.5 - let bound2 := max 0.0 (min 0.5 bound1) - let bound := if Float.isNaN bound2 || Float.isInf bound2 then 0.5 else bound2 - rowBound := bound - rowMaxP := maxP - rowTrace := traceBound - rowMoment := momentBound - rowGersh := gersh - - if rowBound > maxBound then - maxBound := rowBound - bestMaxP := rowMaxP - bestTrace := rowTrace - bestMoment := rowMoment - bestGersh := rowGersh - bestUsed := rowBound - - { opBound := maxBound - maxRowMaxP := bestMaxP - maxRowTraceBound := bestTrace - maxRowMomentBound := bestMoment - maxRowGersh := bestGersh - maxRowBoundUsed := bestUsed - numRowsFallback := fallbackCount } - -/-- Backwards-compatible accessor: the bound value. -/ -def ConcreteAttentionWeights.softmaxJacobianOpEst (A : ConcreteAttentionWeights) : Float := - (A.softmaxJacobianOpDiag).opBound - -/-- Compute the overall Frobenius norm of the softmax Jacobian across all rows. - -For a probability row `p`, the Jacobian is `J = diag(p) - p pᵀ` and -`‖J‖_F² = (Σ pᵢ²) - 2(Σ pᵢ³) + (Σ pᵢ²)²`. -We sum this per-row Frobenius norm squared and take a final square root. 
--/ -def ConcreteAttentionWeights.softmaxJacobianFrobeniusNorm - (A : ConcreteAttentionWeights) : Float := - if A.seqLen = 0 then 0.0 - else Id.run do - let mut totalFrobSq : Float := 0.0 - for q in [:A.seqLen] do - let mut sumSq : Float := 0.0 - let mut sumCube : Float := 0.0 - let rowBase := q * A.seqLen - for k in [:A.seqLen] do - -- SAFETY: `q < seqLen` and `k < seqLen` by loop bounds. - let p := A.weights[rowBase + k]! - sumSq := sumSq + p * p - sumCube := sumCube + p * p * p - let rowFrobSq := max 0.0 (sumSq - 2.0 * sumCube + sumSq * sumSq) - totalFrobSq := totalFrobSq + rowFrobSq - Float.sqrt totalFrobSq - -/-! ## Forward Pass Implementations - -These methods compute the actual forward pass through attention and MLP layers, -accumulating the residual stream as in real transformers. --/ - -/-- Compute attention weights for a layer given an input matrix. -/ -def ConcreteAttentionLayer.computeAttentionWeights (layer : ConcreteAttentionLayer) - (input : ConcreteMatrix) (causal : Bool := true) : ConcreteAttentionWeights := - let queries := (input.matmul layer.W_Q).addBias layer.b_Q - let keys := (input.matmul layer.W_K).addBias layer.b_K - let scale := Float.sqrt layer.headDim.toFloat - ConcreteAttentionWeights.compute queries keys scale input.numRows causal - -/-- Forward pass for a single attention head. - -Input: X (seqLen × modelDim) -Output: Y (seqLen × modelDim) where Y = A·V·W_O - -This computes the attention output (before residual connection): -1. Q = X·W_Q, K = X·W_K, V = X·W_V -2. A = softmax(Q·K^T / √d_head) -3. 
Y = A·V·W_O --/ -def ConcreteAttentionLayer.forward (layer : ConcreteAttentionLayer) (input : ConcreteMatrix) - (causal : Bool := true) : ConcreteMatrix := - let attn := layer.computeAttentionWeights input causal - let values := (input.matmul layer.W_V).addBias layer.b_V -- seqLen × headDim - -- Compute A · V using direct construction - let attnValues : ConcreteMatrix := { - numRows := input.numRows - numCols := layer.headDim - data := .ofFn fun idx : Fin (input.numRows * layer.headDim) => Id.run do - let q := idx.val / layer.headDim - let d := idx.val % layer.headDim - let mut acc : Float := 0.0 - let n := input.numRows - let attnRowBase := q * n - for k in [:n] do - -- SAFETY: `q < n` and `k < n` by loop bounds, and `attn.seqLen = n`. - let a := attn.weights[attnRowBase + k]! - -- SAFETY: `k < values.numRows = n` and `d < values.numCols = headDim`. - let v := values.data[k * layer.headDim + d]! - acc := acc + a * v - return acc - size_eq := Array.size_ofFn - } - -- Project back to model dimension: (A·V) · W_O - attnValues.matmul layer.W_O - -/-- ReLU activation function for Float. -/ -def reluFloat (x : Float) : Float := if x > 0.0 then x else 0.0 - -/-- GeLU activation function (approximate) for Float. -/ -def geluFloat (x : Float) : Float := - let pi : Float := 3.14159265358979323846 - -- GPT-2 uses the "gelu_new" tanh approximation: - -- 0.5*x*(1 + tanh( sqrt(2/pi) * (x + 0.044715*x^3) )) - let a := Float.sqrt (2.0 / pi) - 0.5 * x * (1.0 + Float.tanh (a * (x + 0.044715 * x * x * x))) - -/-- Derivative of `geluFloat` with respect to `x`. - -This matches the tanh-based approximate GeLU used by `geluFloat`. 
--/ -def geluDerivFloat (x : Float) : Float := - let pi : Float := 3.14159265358979323846 - let a := Float.sqrt (2.0 / pi) - let b : Float := 0.044715 - let t := a * (x + b * x * x * x) - let tanhT := Float.tanh t - let sech2 := 1.0 - tanhT * tanhT - let t' := a * (1.0 + 3.0 * b * x * x) - 0.5 * (1.0 + tanhT) + 0.5 * x * sech2 * t' - -/-- Data-dependent upper bound on the MLP Jacobian operator norm over a batch of tokens. - -For a token with GeLU derivative vector `d` (from the pre-activations `z`), -the MLP Jacobian is `J = W_in · diag(d) · W_out`. - -Computing `J` explicitly is intractable (`modelDim×modelDim`). Instead we upper-bound `‖J‖₂` -via the standard inequality `‖J‖₂ ≤ sqrt(‖J‖₁ · ‖J‖∞)`, where `‖·‖₁`/`‖·‖∞` are induced norms. - -For a fixed token, row/column absolute sums can be bounded without forming `J`. -In row-vector convention, `‖J‖₁` is the max row sum and `‖J‖∞` is the max column sum: - -* max row sum ≤ `max_r ∑_k |W_in[r,k]| · |d_k| · (∑_c |W_out[k,c]|)` -* max column sum ≤ `max_c ∑_k |W_out[k,c]| · |d_k| · (∑_r |W_in[r,k]|)` - -To avoid a per-token `O(modelDim·hiddenDim)` loop, we take `dMax[k] = max_token |d_token[k]|` and -use it in place of `|d_k|`, which is sound for both `‖·‖₁` and `‖·‖∞`. --/ -def computeMLPLayerOpNormFromGeluDerivWithOpBounds - (layer : ConcreteMLPLayer) (geluDeriv : ConcreteMatrix) (winUb woutUb : Float) : Float := Id.run do - let d := layer.modelDim - let h := layer.hiddenDim - let legacy : Float := winUb * 1.7 * woutUb - if d = 0 || h = 0 || geluDeriv.numRows = 0 then - -- No tokens or empty layer: treat as zero contribution. - return 0.0 - if geluDeriv.numCols ≠ h then - -- Dimension mismatch indicates an inconsistent forward-pass cache; fall back conservatively. - return legacy - if layer.W_in.numRows ≠ d || layer.W_in.numCols ≠ h then - return legacy - if layer.W_out.numRows ≠ h || layer.W_out.numCols ≠ d then - return legacy - - let rows := geluDeriv.numRows - - -- dMax[k] = max_token |gelu'(z)[token,k]|. 
- let dMax : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for i in [:rows] do - let base := i * h - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - if a > out[k]! then - out := out.set! k a - out - let globalDmax : Float := maxArray dMax - if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then - -- If derivative information is degenerate, we can still use the global GeLU' upper bound (≈1.7). - return legacy - - -- sOut[k] = ∑_c |W_out[k,c]| (row sums of |W_out|). - let (sOut, woutRowSqSum) : (Array Float × Array Float) := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - let mut sq : Array Float := Array.replicate h 0.0 - for k in [:h] do - let rowBase := k * d - let mut s : Float := 0.0 - let mut ss : Float := 0.0 - for c in [:d] do - let w := layer.W_out.data[rowBase + c]! - s := s + Float.abs w - ss := ss + w * w - out := out.set! k s - sq := sq.set! k ss - (out, sq) - - -- sIn[k] = ∑_r |W_in[r,k]| (column sums of |W_in|). - let (sIn, winColSqSum) : (Array Float × Array Float) := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - let mut sq : Array Float := Array.replicate h 0.0 - for r in [:d] do - let rowBase := r * h - for k in [:h] do - let w := layer.W_in.data[rowBase + k]! - out := out.set! k (out[k]! + Float.abs w) - sq := sq.set! k (sq[k]! + w * w) - (out, sq) - - /- - Token-wise Frobenius scaling can be strictly tighter than using the coordinatewise maxima `dMax`: - the vector `dMax` may not be realized by any single token, but the MLP Jacobian is block-diagonal - across tokens, so `‖J‖₂ = max_token ‖J_token‖₂`. - -/ - let (winScaledFrobTokMax, woutScaledFrobTokMax) : (Float × Float) := Id.run do - let mut winMaxSq : Float := 0.0 - let mut woutMaxSq : Float := 0.0 - for i in [:rows] do - let base := i * h - let mut winSq : Float := 0.0 - let mut woutSq : Float := 0.0 - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) 
- let aa := a * a - winSq := winSq + aa * winColSqSum[k]! - woutSq := woutSq + aa * woutRowSqSum[k]! - winMaxSq := max winMaxSq winSq - woutMaxSq := max woutMaxSq woutSq - (Float.sqrt (max 0.0 winMaxSq), Float.sqrt (max 0.0 woutMaxSq)) - - -- Fast candidates that exploit the full per-unit maxima `dMax[k]` (not just `max_k dMax[k]`): - -- - -- ‖W_in·diag(d)·W_out‖₂ ≤ ‖W_in·diag(dMax)‖₂ · ‖W_out‖₂ - -- ‖W_in·diag(d)·W_out‖₂ ≤ ‖W_in‖₂ · ‖diag(dMax)·W_out‖₂ - -- - -- where `dMax[k] = max_token |gelu'(z)[token,k]|`. Both are rigorous because induced - -- norms / Frobenius norms are monotone under entrywise scaling by `dMax`. - let maxWinScaledCol : Float := Id.run do - let mut m : Float := 0.0 - for k in [:h] do - m := max m (dMax[k]! * sIn[k]!) - m - let maxWoutScaledRow : Float := Id.run do - let mut m : Float := 0.0 - for k in [:h] do - m := max m (dMax[k]! * sOut[k]!) - m - let winScaledFrob : Float := winScaledFrobTokMax - let woutScaledFrob : Float := woutScaledFrobTokMax - - -- Max row-sum bound (‖J‖₁ in row-vector convention). - let (boundInf, maxWinRowScaled) : (Float × Float) := Id.run do - let mut maxRow : Float := 0.0 - let mut maxScaled : Float := 0.0 - for r in [:d] do - let rowBase := r * h - let mut s : Float := 0.0 - let mut sScaled : Float := 0.0 - for k in [:h] do - let coeff := dMax[k]! * sOut[k]! - let a := Float.abs (layer.W_in.data[rowBase + k]!) - s := s + a * coeff - sScaled := sScaled + a * dMax[k]! - if s > maxRow then - maxRow := s - if sScaled > maxScaled then - maxScaled := sScaled - (maxRow, maxScaled) - - -- Max column-sum bound (‖J‖∞ in row-vector convention). - let (boundOne, maxWoutColScaled) : (Float × Float) := Id.run do - let mut maxCol : Float := 0.0 - let mut maxScaled : Float := 0.0 - for c in [:d] do - let mut s : Float := 0.0 - let mut sScaled : Float := 0.0 - for k in [:h] do - let coeff := dMax[k]! * sIn[k]! - let a := Float.abs (layer.W_out.data[k * d + c]!) - s := s + a * coeff - sScaled := sScaled + a * dMax[k]! 
- if s > maxCol then - maxCol := s - if sScaled > maxScaled then - maxScaled := sScaled - (maxCol, maxScaled) - - let absSchur := Float.sqrt (max 0.0 (boundInf * boundOne)) - let legacy := winUb * globalDmax * woutUb - let winScaledOneInf := Float.sqrt (max 0.0 (maxWinScaledCol * maxWinRowScaled)) - let woutScaledOneInf := Float.sqrt (max 0.0 (maxWoutScaledRow * maxWoutColScaled)) - let winScaledUb := min winScaledFrob winScaledOneInf - let woutScaledUb := min woutScaledFrob woutScaledOneInf - let scaledViaWin := winScaledUb * woutUb - let scaledViaWout := winUb * woutScaledUb - let scaled0 := - if scaledViaWin ≤ 0.0 || Float.isNaN scaledViaWin || Float.isInf scaledViaWin then - scaledViaWout - else if scaledViaWout ≤ 0.0 || Float.isNaN scaledViaWout || Float.isInf scaledViaWout then - scaledViaWin - else - min scaledViaWin scaledViaWout - let scaled := - if scaled0 ≤ 0.0 || Float.isNaN scaled0 || Float.isInf scaled0 then legacy else scaled0 - let out := - if absSchur ≤ 0.0 || Float.isNaN absSchur || Float.isInf absSchur then - min legacy scaled - else - min absSchur (min legacy scaled) - return out - -/-- Diagnostics for the MLP absolute-Schur candidate bound `sqrt(‖J‖₁‖J‖∞)`. -/ -structure MLPOpAbsSchurDiag where - dMax : Float - boundInf : Float - boundOne : Float - absSchur : Float - -/-- Compute the absolute-Schur candidate for `‖W_in·diag(gelu'(z))·W_out‖₂` without using -any operator-norm bounds on the weights (diagnostics helper). - -This matches the `absSchur` computation inside `computeMLPLayerOpNormFromGeluDerivWithOpBounds`. 
--/ -def computeMLPOpAbsSchurDiag (layer : ConcreteMLPLayer) (geluDeriv : ConcreteMatrix) : MLPOpAbsSchurDiag := Id.run do - let d := layer.modelDim - let h := layer.hiddenDim - if d = 0 || h = 0 || geluDeriv.numRows = 0 || geluDeriv.numCols ≠ h then - return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } - if layer.W_in.numRows ≠ d || layer.W_in.numCols ≠ h then - return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } - if layer.W_out.numRows ≠ h || layer.W_out.numCols ≠ d then - return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } - - let rows := geluDeriv.numRows - let dMaxVec : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for i in [:rows] do - let base := i * h - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - if a > out[k]! then - out := out.set! k a - out - let globalDmax : Float := maxArray dMaxVec - if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then - return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } - - let sOut : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for k in [:h] do - let rowBase := k * d - let mut s : Float := 0.0 - for c in [:d] do - s := s + Float.abs (layer.W_out.data[rowBase + c]!) - out := out.set! k s - out - let sIn : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for r in [:d] do - let rowBase := r * h - for k in [:h] do - out := out.set! k (out[k]! + Float.abs (layer.W_in.data[rowBase + k]!)) - out - - let boundInf : Float := Id.run do - let mut maxRow : Float := 0.0 - for r in [:d] do - let rowBase := r * h - let mut s : Float := 0.0 - for k in [:h] do - let coeff := dMaxVec[k]! * sOut[k]! - s := s + Float.abs (layer.W_in.data[rowBase + k]!) 
* coeff - if s > maxRow then - maxRow := s - maxRow - - let boundOne : Float := Id.run do - let mut maxCol : Float := 0.0 - for c in [:d] do - let mut s : Float := 0.0 - for k in [:h] do - let coeff := dMaxVec[k]! * sIn[k]! - s := s + Float.abs (layer.W_out.data[k * d + c]!) * coeff - if s > maxCol then - maxCol := s - maxCol - - let absSchur := Float.sqrt (max 0.0 (boundInf * boundOne)) - return { dMax := globalDmax, boundInf := boundInf, boundOne := boundOne, absSchur := absSchur } - -/-- Precomputed, token-batch–dependent factors for bounding `‖W_in · diag(gelu'(z)) · W_out‖₂`. - -This is used by the adaptive scheduler to avoid recomputing weight-dependent summaries -(`absSchur`, scaled Frobenius/Schur bounds) across effort-tier upgrades. - -Correctness: the final upper bound is still computed as a minimum of rigorous candidates; -this structure merely memoizes pieces of those candidates. --/ -structure MLPJacobianBoundCore where - /-- `max_k dMax[k]` where `dMax[k] = max_token |gelu'(z)[token,k]|`. -/ - globalDmax : Float - /-- Candidate `sqrt(‖J‖₁‖J‖∞)` computed from absolute row/col sums (independent of `winUb/woutUb`). -/ - absSchur : Float - /-- Upper bound on `‖W_in · diag(dMax)‖₂` (min of scaled Frobenius and scaled Schur/one-inf). -/ - winScaledUb : Float - /-- Upper bound on `‖diag(dMax) · W_out‖₂` (min of scaled Frobenius and scaled Schur/one-inf). -/ - woutScaledUb : Float - deriving Repr - -/-- Precompute `MLPJacobianBoundCore` for a layer and a given `geluDeriv` matrix. - -Returns `none` if dimensions don't match; callers should conservatively fall back. 
--/ -def ConcreteMLPLayer.precomputeJacobianBoundCore (layer : ConcreteMLPLayer) - (geluDeriv : ConcreteMatrix) : Option MLPJacobianBoundCore := Id.run do - let d := layer.modelDim - let h := layer.hiddenDim - if d = 0 || h = 0 || geluDeriv.numRows = 0 || geluDeriv.numCols ≠ h then - return none - if layer.W_in.numRows ≠ d || layer.W_in.numCols ≠ h then - return none - if layer.W_out.numRows ≠ h || layer.W_out.numCols ≠ d then - return none - - let rows := geluDeriv.numRows - - -- dMax[k] = max_token |gelu'(z)[token,k]|. - let dMax : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for i in [:rows] do - let base := i * h - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - if a > out[k]! then - out := out.set! k a - out - let globalDmax : Float := maxArray dMax - if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then - return none - - -- sOut[k] = ∑_c |W_out[k,c]| (row sums of |W_out|). - let (sOut, woutRowSqSum) : (Array Float × Array Float) := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - let mut sq : Array Float := Array.replicate h 0.0 - for k in [:h] do - let rowBase := k * d - let mut s : Float := 0.0 - let mut ss : Float := 0.0 - for c in [:d] do - let w := layer.W_out.data[rowBase + c]! - s := s + Float.abs w - ss := ss + w * w - out := out.set! k s - sq := sq.set! k ss - (out, sq) - - -- sIn[k] = ∑_r |W_in[r,k]| (column sums of |W_in|). - let (sIn, winColSqSum) : (Array Float × Array Float) := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - let mut sq : Array Float := Array.replicate h 0.0 - for r in [:d] do - let rowBase := r * h - for k in [:h] do - let w := layer.W_in.data[rowBase + k]! - out := out.set! k (out[k]! + Float.abs w) - sq := sq.set! k (sq[k]! 
+ w * w) - (out, sq) - - let (winScaledFrobTokMax, woutScaledFrobTokMax) : (Float × Float) := Id.run do - let mut winMaxSq : Float := 0.0 - let mut woutMaxSq : Float := 0.0 - for i in [:rows] do - let base := i * h - let mut winSq : Float := 0.0 - let mut woutSq : Float := 0.0 - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - let aa := a * a - winSq := winSq + aa * winColSqSum[k]! - woutSq := woutSq + aa * woutRowSqSum[k]! - winMaxSq := max winMaxSq winSq - woutMaxSq := max woutMaxSq woutSq - (Float.sqrt (max 0.0 winMaxSq), Float.sqrt (max 0.0 woutMaxSq)) - - -- Max row-sum bound (‖J‖₁ in row-vector convention). - let boundInf : Float := Id.run do - let mut maxRow : Float := 0.0 - for r in [:d] do - let rowBase := r * h - let mut s : Float := 0.0 - for k in [:h] do - let coeff := dMax[k]! * sOut[k]! - let a := Float.abs (layer.W_in.data[rowBase + k]!) - s := s + a * coeff - if s > maxRow then - maxRow := s - maxRow - - -- Max column-sum bound (‖J‖∞ in row-vector convention). - let boundOne : Float := Id.run do - let mut maxCol : Float := 0.0 - for c in [:d] do - let mut s : Float := 0.0 - for k in [:h] do - let coeff := dMax[k]! * sIn[k]! - let a := Float.abs (layer.W_out.data[k * d + c]!) - s := s + a * coeff - if s > maxCol then - maxCol := s - maxCol - - let absSchur : Float := Float.sqrt (max 0.0 (boundInf * boundOne)) - - let winScaledFrob : Float := winScaledFrobTokMax - let woutScaledFrob : Float := woutScaledFrobTokMax - - let maxWinScaledCol : Float := Id.run do - let mut m : Float := 0.0 - for k in [:h] do - m := max m (dMax[k]! * sIn[k]!) - m - let maxWoutScaledRow : Float := Id.run do - let mut m : Float := 0.0 - for k in [:h] do - m := max m (dMax[k]! * sOut[k]!) - m - let maxWinRowScaled : Float := Id.run do - let mut m : Float := 0.0 - for r in [:d] do - let rowBase := r * h - let mut s : Float := 0.0 - for k in [:h] do - s := s + Float.abs (layer.W_in.data[rowBase + k]!) * dMax[k]! 
- m := max m s - m - let maxWoutColScaled : Float := Id.run do - let mut m : Float := 0.0 - for c in [:d] do - let mut s : Float := 0.0 - for k in [:h] do - s := s + Float.abs (layer.W_out.data[k * d + c]!) * dMax[k]! - m := max m s - m - - let winScaledOneInf := Float.sqrt (max 0.0 (maxWinScaledCol * maxWinRowScaled)) - let woutScaledOneInf := Float.sqrt (max 0.0 (maxWoutScaledRow * maxWoutColScaled)) - let winScaledUb := min winScaledFrob winScaledOneInf - let woutScaledUb := min woutScaledFrob woutScaledOneInf - - if absSchur ≤ 0.0 || Float.isNaN absSchur || Float.isInf absSchur then - return none - if winScaledUb ≤ 0.0 || Float.isNaN winScaledUb || Float.isInf winScaledUb then - return none - if woutScaledUb ≤ 0.0 || Float.isNaN woutScaledUb || Float.isInf woutScaledUb then - return none - - return some { globalDmax := globalDmax, absSchur := absSchur, - winScaledUb := winScaledUb, woutScaledUb := woutScaledUb } - -/-- Fused MLP Jacobian upper bound, given the cached GeLU derivative matrix `gelu'(z)`. -/ -def computeMLPLayerOpNormFromGeluDeriv (layer : ConcreteMLPLayer) (geluDeriv : ConcreteMatrix) : Float := Id.run do - let winUb := layer.W_in.opNormUpperBoundRectGram - let woutUb := layer.W_out.opNormUpperBoundRectGram - computeMLPLayerOpNormFromGeluDerivWithOpBounds layer geluDeriv winUb woutUb - -/-- Compute the MLP Jacobian bound by re-evaluating `gelu'(z)` from a concrete input. 
-/ -def computeMLPLayerOpNorm (layer : ConcreteMLPLayer) (input : ConcreteMatrix) : Float := Id.run do - let d := layer.modelDim - let h := layer.hiddenDim - if d = 0 || h = 0 || input.numRows = 0 then - return 0.0 - if layer.W_in.numRows ≠ d || layer.W_in.numCols ≠ h then - return 0.0 - if layer.W_out.numRows ≠ h || layer.W_out.numCols ≠ d then - return 0.0 - let hidden := (input.matmul layer.W_in).addBias layer.b_in - let geluDeriv := hidden.map geluDerivFloat - let winUb := layer.W_in.opNormUpperBoundRectGram - let woutUb := layer.W_out.opNormUpperBoundRectGram - computeMLPLayerOpNormFromGeluDerivWithOpBounds layer geluDeriv winUb woutUb - -/-- Forward pass for an MLP layer with GeLU activation. - -Input: X (seqLen × modelDim) -Output: Y (seqLen × modelDim) where Y = W_out · GeLU(W_in · X + b_in) + b_out - -For each position, this computes the standard transformer FFN. --/ -def ConcreteMLPLayer.forward (layer : ConcreteMLPLayer) (input : ConcreteMatrix) : ConcreteMatrix := - -- hidden = input · W_in + b_in (seqLen × hiddenDim) - let hidden := (input.matmul layer.W_in).addBias layer.b_in - -- Apply GeLU activation - let activated := hidden.map geluFloat - -- output = activated · W_out + b_out (seqLen × modelDim) - (activated.matmul layer.W_out).addBias layer.b_out - -/-- Forward pass plus a data-dependent bound on `max |gelu'(z)|` over this layer's preactivations. - -The returned derivative maximum is exact for the computed `hidden` matrix entries (interpreting -Float arithmetic as defining a concrete real expression, consistent with this file's conventions). -If NaN/Inf is encountered, callers should conservatively fall back to a global bound. 
--/ -def ConcreteMLPLayer.forwardWithGeluDerivMax (layer : ConcreteMLPLayer) - (input : ConcreteMatrix) : (ConcreteMatrix × Float) := Id.run do - let hidden := (input.matmul layer.W_in).addBias layer.b_in - let mut maxDeriv : Float := 0.0 - for z in hidden.data do - let d := Float.abs (geluDerivFloat z) - if d > maxDeriv then - maxDeriv := d - let activated := hidden.map geluFloat - let out := (activated.matmul layer.W_out).addBias layer.b_out - return (out, maxDeriv) - -/-! ## Efficient Pattern Term Bound Calculation - -The core insight: the valueTerm Frobenius norm factors as `‖A‖_F · ‖W_V·W_O‖_F`. -This avoids computing the full (N·D)² matrix! --/ - -/-- Compute ‖valueTerm‖_F efficiently via factorization. -/ -def computeValueTermNorm (attn : ConcreteAttentionWeights) - (valueOutputProjFrobNormSq : Float) : Float := - let attnNormSq := sumSquares attn.weights - Float.sqrt (attnNormSq * valueOutputProjFrobNormSq) - -/-- Information needed to bound the pattern term. -/ -structure PatternTermBoundInputs where - /-- Attention weights -/ - attention : ConcreteAttentionWeights - /-- Input embedding norm (‖X‖_F) -/ - inputNorm : Float - /-- Input embedding operator-norm upper bound (‖X‖₂), used for tighter pattern-term bounds. -/ - inputOpBound : Float - /-- Optional direct Frobenius norm bound for Q = X·W_Q + b_Q, if available. -/ - qFrobBound : Float := 0.0 - /-- Optional direct Frobenius norm bound for K = X·W_K + b_K, if available. -/ - kFrobBound : Float := 0.0 - /-- Optional Frobenius norm bound for centered V = X·W_V + b_V, if available. -/ - vFrobBound : Float := 0.0 - /-- Optional direct operator-norm bound for Q = X·W_Q + b_Q, if available. -/ - qOpBoundAct : Float := 0.0 - /-- Optional direct operator-norm bound for K = X·W_K + b_K, if available. -/ - kOpBoundAct : Float := 0.0 - /-- Optional operator-norm bound for centered V = X·W_V + b_V, if available. 
-/ - vOpBoundAct : Float := 0.0 - /-- Optional Frobenius bound for `‖W_Q·Kᵀ‖_F` after centering keys, if available. -/ - qkActFrobBound : Float := 0.0 - /-- Optional Frobenius bound for `‖W_K·Qᵀ‖_F` after centering queries, if available. -/ - kqActFrobBound : Float := 0.0 - /-- Optional operator-norm bound for `‖W_Q·Kᵀ‖₂` after centering keys, if available. -/ - qkActOpBound : Float := 0.0 - /-- Optional operator-norm bound for `‖W_K·Qᵀ‖₂` after centering queries, if available. -/ - kqActOpBound : Float := 0.0 - /-- Scaling factor (√d_head) -/ - scaleFactor : Float - /-- Deterministic Float upper bound on ‖W_Q‖₂. -/ - wqOpBound : Float - /-- Deterministic Float upper bound on ‖W_K‖₂. -/ - wkOpBound : Float - /-- Deterministic Float upper bound on ‖W_V‖₂. -/ - wvOpBound : Float - /-- Deterministic Float upper bound on ‖W_O‖₂. -/ - woOpBound : Float - /-- Deterministic Float upper bound on ‖W_V·W_O‖₂, if available. -/ - voOpBound : Float := 0.0 - /-- Query bias Frobenius norm (for the 1×headDim bias row). -/ - bqFrob : Float - /-- Key bias Frobenius norm (for the 1×headDim bias row). -/ - bkFrob : Float - /-- Value bias Frobenius norm (for the 1×headDim bias row). -/ - bvFrob : Float - -/-- Selected intermediate bounds for the pattern-term calculation. -/ -structure PatternTermBoundParts where - /-- Selected operator-norm bound for `‖W_Q·Kᵀ‖₂` after all min-combination. -/ - qkActOpUb : Float - /-- Selected operator-norm bound for `‖W_K·Qᵀ‖₂` after all min-combination. -/ - kqActOpUb : Float - /-- Selected operator-norm bound for centered `V`. -/ - vOpUb : Float - /-- Selected operator-norm bound for centered `V·W_O`. -/ - vOpUbWO : Float - /-- Frobenius-style candidate bound. -/ - candFrob : Float - /-- Operator-style candidate bound (using `vOpUb`). -/ - candOp : Float - /-- Operator-style candidate bound (using `vOpUbWO`). -/ - candOpWO : Float - /-- Final chosen pattern-term bound. 
-/ - patternBound : Float - -/-- Compute pattern-term bound and expose the selected intermediate bounds. -/ -def computePatternTermBoundParts (inputs : PatternTermBoundInputs) : PatternTermBoundParts := - -- Data-dependent bound on the softmax Jacobian operator norm. - -- Provable global bound: for any probability vector p, J = diag(p) - p pᵀ has ‖J‖₂ ≤ 1/2. - -- Clamp defensively so callers cannot accidentally exceed this. - let softmaxBound := min inputs.attention.softmaxJacobianOpEst 0.5 - let n := inputs.attention.seqLen - let sqrtN := Float.sqrt n.toFloat - -- Bias-aware Frobenius upper bounds on the concrete Q/K/V activations. - -- If the caller provides direct (data-dependent) bounds on ‖Q‖_F or ‖K‖_F, we - -- take the minimum (still a rigorous upper bound in exact ℝ arithmetic). - let qFrobUb0 := inputs.inputNorm * inputs.wqOpBound + sqrtN * inputs.bqFrob - let kFrobUb0 := inputs.inputNorm * inputs.wkOpBound + sqrtN * inputs.bkFrob - let qFrobUb := - if inputs.qFrobBound ≤ 0.0 || Float.isNaN inputs.qFrobBound || Float.isInf inputs.qFrobBound then - qFrobUb0 - else - min qFrobUb0 inputs.qFrobBound - let kFrobUb := - if inputs.kFrobBound ≤ 0.0 || Float.isNaN inputs.kFrobBound || Float.isInf inputs.kFrobBound then - kFrobUb0 - else - min kFrobUb0 inputs.kFrobBound - let vFrobUb0 := inputs.inputNorm * inputs.wvOpBound + sqrtN * inputs.bvFrob - let vFrobUb := - if inputs.vFrobBound ≤ 0.0 || Float.isNaN inputs.vFrobBound || Float.isInf inputs.vFrobBound then - vFrobUb0 - else - min vFrobUb0 inputs.vFrobBound - -- Bias-aware operator-norm upper bounds on Q/K/V: - -- ‖X·W_Q + 1·b_Q‖₂ ≤ ‖X‖₂‖W_Q‖₂ + √n·‖b_Q‖₂. 
- let qOpUb0 := inputs.inputOpBound * inputs.wqOpBound + sqrtN * inputs.bqFrob - let kOpUb0 := inputs.inputOpBound * inputs.wkOpBound + sqrtN * inputs.bkFrob - let qOpUb := - if inputs.qOpBoundAct ≤ 0.0 || Float.isNaN inputs.qOpBoundAct || Float.isInf inputs.qOpBoundAct then - qOpUb0 - else - min qOpUb0 inputs.qOpBoundAct - let kOpUb := - if inputs.kOpBoundAct ≤ 0.0 || Float.isNaN inputs.kOpBoundAct || Float.isInf inputs.kOpBoundAct then - kOpUb0 - else - min kOpUb0 inputs.kOpBoundAct - let qkActOpUb0 := inputs.wqOpBound * kOpUb - let qkActOpUb1 := - if inputs.qkActOpBound ≤ 0.0 || Float.isNaN inputs.qkActOpBound || - Float.isInf inputs.qkActOpBound then - qkActOpUb0 - else - min qkActOpUb0 inputs.qkActOpBound - let qkActOpUb := - if inputs.qkActFrobBound ≤ 0.0 || Float.isNaN inputs.qkActFrobBound || - Float.isInf inputs.qkActFrobBound then - qkActOpUb1 - else - min qkActOpUb1 inputs.qkActFrobBound - let kqActOpUb0 := inputs.wkOpBound * qOpUb - let kqActOpUb1 := - if inputs.kqActOpBound ≤ 0.0 || Float.isNaN inputs.kqActOpBound || - Float.isInf inputs.kqActOpBound then - kqActOpUb0 - else - min kqActOpUb0 inputs.kqActOpBound - let kqActOpUb := - if inputs.kqActFrobBound ≤ 0.0 || Float.isNaN inputs.kqActFrobBound || - Float.isInf inputs.kqActFrobBound then - kqActOpUb1 - else - min kqActOpUb1 inputs.kqActFrobBound - let vOpUb0 := inputs.inputOpBound * inputs.wvOpBound + sqrtN * inputs.bvFrob - let vOpUb := - if inputs.vOpBoundAct ≤ 0.0 || Float.isNaN inputs.vOpBoundAct || Float.isInf inputs.vOpBoundAct then - vOpUb0 - else - min vOpUb0 inputs.vOpBoundAct - let vOpUbWO0 := inputs.inputOpBound * inputs.voOpBound + - sqrtN * inputs.bvFrob * inputs.woOpBound - let vOpUbWO := - if inputs.voOpBound ≤ 0.0 || Float.isNaN inputs.voOpBound || Float.isInf inputs.voOpBound then - vOpUb * inputs.woOpBound - else - min (vOpUb * inputs.woOpBound) vOpUbWO0 - -- Logits sensitivity: - -- dS = (dQ)Kᵀ + Q(dK)ᵀ, with ‖dQ‖_F ≤ ‖dX‖_F‖W_Q‖₂ and ‖dK‖_F ≤ ‖dX‖_F‖W_K‖₂. 
- let sCoeffFrob := (inputs.wqOpBound * kFrobUb + inputs.wkOpBound * qFrobUb) / inputs.scaleFactor - let sCoeffOp := (qkActOpUb + kqActOpUb) / inputs.scaleFactor - -- Two rigorous candidates: - -- (A) Frobenius-style activation bounds: - -- ‖dA·V·W_O‖_F ≤ ‖dA‖_F‖V‖_F‖W_O‖₂. - let candFrob := (softmaxBound * sCoeffFrob) * vFrobUb * inputs.woOpBound - -- (B) Operator-style activation bounds (often much tighter): - -- ‖dS‖_F ≤ ‖dX‖_F · (‖W_Q‖₂‖K‖₂ + ‖W_K‖₂‖Q‖₂)/scale, - -- ‖dA·V‖_F ≤ ‖dA‖_F‖V‖₂, - -- so ‖dA·V·W_O‖_F ≤ ‖J_softmax‖₂‖dS‖_F‖V‖₂‖W_O‖₂. - let candOp := (softmaxBound * sCoeffOp) * vOpUb * inputs.woOpBound - let candOpWO := (softmaxBound * sCoeffOp) * vOpUbWO - let patternBound := min candFrob (min candOp candOpWO) - { qkActOpUb := qkActOpUb - kqActOpUb := kqActOpUb - vOpUb := vOpUb - vOpUbWO := vOpUbWO - candFrob := candFrob - candOp := candOp - candOpWO := candOpWO - patternBound := patternBound } - -/-- Bound ‖patternTerm‖_F without expanding the full Jacobian. - -The pattern term arises from how attention weights A change when input X changes: - patternTerm = (∂A/∂X) ⊗ (V·W_O) - -The key insight is that ∂A/∂X involves the softmax Jacobian, which is bounded by -the "softness" of the attention distribution. For sparse (one-hot) attention, -the softmax Jacobian is nearly zero, giving much tighter bounds. - -**Sparsity-aware bound**: - ‖patternTerm‖_F ≤ (‖J_softmax‖₂ / scale) · ‖X‖_F · ‖W_Q·W_K^T‖₂ · ‖W_V·W_O‖₂ - -We use a tight, data-dependent bound on `‖J_softmax‖₂` per row of attention: -for a probability row `p`, `J = diag(p) - p pᵀ` and -`‖J‖₂ ≤ min(maxᵢ pᵢ, 1 - Σᵢ pᵢ²)`. - -- Perfectly one-hot rows give `‖J‖₂ = 0`. -- Uniform rows give `‖J‖₂ = 1/n`. -- Worst-case (all n): `‖J‖₂ ≤ 0.5`. --/ -def computePatternTermBound (inputs : PatternTermBoundInputs) : Float := - (computePatternTermBoundParts inputs).patternBound - -/-- Bound ‖patternTerm‖_F using the old pessimistic constant bound. 
- -This uses the worst-case softmax Jacobian spectral-norm bound of 0.5, which is valid but loose. -Prefer `computePatternTermBound` for tighter data-dependent bounds. --/ -def computePatternTermBoundPessimistic (inputs : PatternTermBoundInputs) : Float := - let softmaxBound : Float := 0.5 -- Worst-case softmax Jacobian spectral norm - let n := inputs.attention.seqLen - let sqrtN := Float.sqrt n.toFloat - let qFrobUb := inputs.inputNorm * inputs.wqOpBound + sqrtN * inputs.bqFrob - let kFrobUb := inputs.inputNorm * inputs.wkOpBound + sqrtN * inputs.bkFrob - let vFrobUb0 := inputs.inputNorm * inputs.wvOpBound + sqrtN * inputs.bvFrob - let vFrobUb := - if inputs.vFrobBound ≤ 0.0 || Float.isNaN inputs.vFrobBound || Float.isInf inputs.vFrobBound then - vFrobUb0 - else - min vFrobUb0 inputs.vFrobBound - let sCoeff := (inputs.wqOpBound * kFrobUb + inputs.wkOpBound * qFrobUb) / inputs.scaleFactor - (softmaxBound * sCoeff) * vFrobUb * inputs.woOpBound - -/-- Compute faithfulness ratio: ‖patternTerm‖_F / ‖valueTerm‖_F. -/ -def computeFaithfulnessRatio (inputs : PatternTermBoundInputs) (valueOutputProjFrobNormSq : Float) : Float := - let patternBound := computePatternTermBound inputs - let valueNorm := computeValueTermNorm inputs.attention valueOutputProjFrobNormSq - if valueNorm < 1e-10 then Float.inf else patternBound / valueNorm - -/-! ## Discovery Structures -/ - -/-- Result of discovering a potential induction head pair. 
-/ -structure CandidateInductionHead where - /-- Layer index of the "previous token" head (L1) -/ - layer1Idx : Nat - /-- Layer index of the "induction" head (L2) -/ - layer2Idx : Nat - /-- Head index within layer 1 -/ - head1Idx : Nat - /-- Head index within layer 2 -/ - head2Idx : Nat - /-- Faithfulness ratio ε₁ for L1: ‖PatternTerm‖_F / ‖ValueTerm‖_F -/ - patternBound1 : Float - /-- Faithfulness ratio ε₂ for L2: ‖PatternTerm‖_F / ‖ValueTerm‖_F -/ - patternBound2 : Float - /-- Combined relative error: (1+ε₁)(1+ε₂) - 1 = ε₁ + ε₂ + ε₁·ε₂ -/ - combinedError : Float - /-- Previous-token strength: avg A₁[i, i-1] -/ - prevTokenStrength : Float - /-- Induction "copy-next" pattern score for head 2 (prompt-dependent). -/ - inductionScore : Float - /-- K-composition score between head 1 and head 2, as in the circuits framework paper: - - `kComp_raw = ‖W_QK² · W_OV¹‖_F / (‖W_QK²‖_F · ‖W_OV¹‖_F)`, - - then we subtract the random-baseline `1/√modelDim`: - - `kComp = kComp_raw - 1/√modelDim`. - - This measures how strongly head 1 can feed information into head 2's QK circuit, - i.e. whether head 1 plausibly acts as a **pattern enabler** for head 2. - -/ - kComp : Float - /-- Description of the discovered pattern -/ - description : String - -/-- A verified induction head that meets the certification threshold. -/ -structure VerifiedInductionHead where - /-- The candidate that was verified -/ - candidate : CandidateInductionHead - /-- The certification threshold used -/ - threshold : Float - /-- Combined error is below threshold (runtime-checked) -/ - errorChecked : Bool - -/-- An induction head candidate with an explicit effectiveness score `δ` on a target direction. - -This is produced by the Float-based discovery pipeline and should be interpreted as a -**heuristic ranking**, not a proof-grade certification. 
--/ -structure HeuristicInductionHead where - /-- The discovered candidate pair (pattern-checked, heuristically) -/ - candidate : CandidateInductionHead - /-- Raw effectiveness score `δ` on the target direction (Float). -/ - delta : Float - /-- Scale-invariant effectiveness score (Float): - - `effect = δ / (‖ln₁(X₂)‖_F · ‖u‖₂)`, - - where `X₂` is the layer-2 input residual stream and `u` is the target direction. - This isolates the **mechanism** (virtual-head computation) from residual-stream energy. - -/ - effect : Float - /-- Frobenius norm of the layer-2 input residual stream `‖X₂‖_F` (Float). -/ - layer2InputNorm : Float - /-- Frobenius norm of the Pre-LN attention input `‖ln₁(X₂)‖_F` (Float). -/ - layer2Ln1InputNorm : Float - -/-- Result of discovering a multi-layer circuit with N-layer error bounds. - -This extends `CandidateInductionHead` with rigorous N-layer amplification bounds -from the theorem `n_layer_faithfulness_composition`: - - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) --/ -structure DeepCircuitCandidate where - /-- Layer indices involved in the circuit (sorted) -/ - layerIndices : Array Nat - /-- Head indices at each layer -/ - headIndices : Array Nat - /-- Per-layer pattern term bounds (εᵢ) -/ - patternBounds : Array Float - /-- Per-layer residual operator norm upper bounds (Cᵢ) -/ - operatorNormUbs : Array Float - /-- Simple error sum: Σᵢ εᵢ (no amplification) -/ - simpleErrorSum : Float - /-- N-layer amplified error: Σᵢ εᵢ · ∏_{j>i}(1+Cⱼ) -/ - amplifiedError : Float - /-- Suffix amplification factor from the earliest layer in the circuit: - `∏_{j ≥ min layer}(1 + Cⱼ)`. 
-/ - amplificationFactor : Float - /-- Pattern type description (e.g., "induction", "composition") -/ - patternType : String - /-- Human-readable description -/ - description : String - -namespace DeepCircuitCandidate - -def toString (c : DeepCircuitCandidate) : String := - let heads := c.layerIndices.zip c.headIndices |>.map fun (l, h) => s!"L{l}H{h}" - let headStr := - Id.run do - let mut out := "[" - let mut first := true - for h in heads do - if first then - first := false - else - out := out ++ ", " - out := out ++ h - out := out ++ "]" - return out - s!"{c.patternType}: {headStr} | " ++ - s!"ε_simple={c.simpleErrorSum}, ε_amplified={c.amplifiedError}, amp={c.amplificationFactor}" - -instance : ToString DeepCircuitCandidate := ⟨toString⟩ - -end DeepCircuitCandidate - -/-! ## Discovery Algorithm -/ - -/-- Check if a layer exhibits "previous-token" attention pattern. -/ -def checkPrevTokenPattern (attn : ConcreteAttentionWeights) - (minStrength : Float := 0.3) : Option Float := - if attn.seqLen < 2 then none - else - let sum : Float := Id.run do - let n := attn.seqLen - let w := attn.weights - let mut acc : Float := 0.0 - for i in [:n - 1] do - -- SAFETY: `i < n-1` implies `i+1 < n`, so `(i+1,i)` is in-bounds. - acc := acc + w[(i + 1) * n + i]! - return acc - let avgStrength := sum / (attn.seqLen - 1).toFloat - if avgStrength ≥ minStrength then some avgStrength else none - -/-- Check if a head exhibits a **content-addressable** attention pattern. - -We say a head is content-addressable on a prompt when, for many query positions `q`, -it places substantial attention mass on *previous occurrences of the same token*: - -`score(q) = ∑_{k < q, tokens[k] = tokens[q]} A[q, k]`. - -The returned score is the average of `score(q)` over query positions that have at least -one previous occurrence. This is variable-lag by construction (no fixed positional lag). 
--/ -def checkContentAddressablePattern (tokens : Array Nat) (attn : ConcreteAttentionWeights) - (minScore : Float := 0.1) : Option Float := - if tokens.size ≠ attn.seqLen then none - else if attn.seqLen < 2 then none - else - let n := attn.seqLen - let w := attn.weights - let (sumScore, count) : (Float × Nat) := Id.run do - let mut sumScore : Float := 0.0 - let mut count : Nat := 0 - for q in [1:n] do - let tq := tokens[q]! - let rowBase := q * n - let mut hasPrev : Bool := false - let mut rowScore : Float := 0.0 - for k in [:q] do - if tokens[k]! = tq then - hasPrev := true - -- SAFETY: `q < n` and `k < q ≤ n` by loop bounds. - rowScore := rowScore + w[rowBase + k]! - if hasPrev then - sumScore := sumScore + rowScore - count := count + 1 - return (sumScore, count) - - if count = 0 then none - else - let avgScore := sumScore / count.toFloat - if avgScore ≥ minScore then some avgScore else none - -/-- Check if a head exhibits an **induction** ("copy-next") attention pattern. - -We say a head is induction-like on a prompt when, for many query positions `q`, it places -substantial attention mass on tokens *immediately after* previous occurrences of the same token: - -`score(q) = ∑_{k+1 < q, tokens[k] = tokens[q]} A[q, k+1]`. - -This is the token-level signature of the induction mechanism described in the transformer-circuits -literature: when the current token repeats, attend to the successor of the previous occurrence so -the head can **copy** that successor forward. - -The returned score is the average of `score(q)` over query positions that have at least one -previous occurrence with an in-bounds successor (`k+1 < q`). This is variable-lag by construction. 
--/ -def checkInductionCopyNextPattern (tokens : Array Nat) (attn : ConcreteAttentionWeights) - (minScore : Float := 0.1) : Option Float := - if tokens.size ≠ attn.seqLen then none - else if attn.seqLen < 3 then none - else - let n := attn.seqLen - let w := attn.weights - let (sumScore, count) : (Float × Nat) := Id.run do - let mut sumScore : Float := 0.0 - let mut count : Nat := 0 - -- Need `q ≥ 2` so there is room for a predecessor `k` with successor `k+1 < q`. - for q in [2:n] do - let tq := tokens[q]! - let rowBase := q * n - let mut hasPrevSucc : Bool := false - let mut rowScore : Float := 0.0 - -- Scan all earlier positions `k` whose successor is still < q. - for k in [:q - 1] do - if tokens[k]! = tq then - hasPrevSucc := true - -- Attend to the *successor* position `k+1`. - rowScore := rowScore + w[rowBase + (k + 1)]! - if hasPrevSucc then - sumScore := sumScore + rowScore - count := count + 1 - return (sumScore, count) - - if count = 0 then none - else - let avgScore := sumScore / count.toFloat - if avgScore ≥ minScore then some avgScore else none - -/-- Compute composed attention score for induction pattern detection. - -Generalizes over all possible repetition lags from 2 to n/2, computing the -"induction score" (average attention mass transferred from q to q+lag via -the two-layer circuit). Returns the maximum score across all lags. - -This enables detection of induction heads with arbitrary repetition periods, -not just lag-2 patterns. --/ -def checkInductionPattern (attn1 attn2 : ConcreteAttentionWeights) - (minScore : Float := 0.1) : Option Float := - if attn1.seqLen ≠ attn2.seqLen then none - else if attn1.seqLen < 3 then none - else - let n := attn1.seqLen - let maxLag := n / 2 - - -- Try all possible lags and find the maximum induction score. - -- - -- PERFORMANCE: this is called in an O(L²H²) search, so avoid `List.range.foldl`. 
- let maxScore : Float := Id.run do - let w1 := attn1.weights - let w2 := attn2.weights - let mut currentMax : Float := 0.0 - for lagIdx in [:maxLag - 1] do - let lag := lagIdx + 2 -- Start from lag=2 - if lag < n then - -- Compute average induction score for this lag - let validPositions := n - lag - let mut composedSum : Float := 0.0 - for q in [:validPositions] do - let q' := q + lag - let row2Base := q' * n - let mut composedToQ : Float := 0.0 - -- Column access into `w1` is strided, so avoid repeated multiplications. - let mut col1Idx := q - for j in [:n] do - -- SAFETY: `q' < n` and `j < n` by loop bounds. - let a2 := w2[row2Base + j]! - -- SAFETY: `j < n` and `q < n` by loop bounds. - let a1 := w1[col1Idx]! - composedToQ := composedToQ + a2 * a1 - col1Idx := col1Idx + n - composedSum := composedSum + composedToQ - let avgScore := composedSum / validPositions.toFloat - if avgScore > currentMax then - currentMax := avgScore - return currentMax - - if maxScore ≥ minScore then some maxScore else none - -/-- Multi-layer model with concrete weights. -/ -structure ConcreteModel where - /-- Number of layers -/ - numLayers : Nat - /-- Attention layers with their heads: layers[l] is array of heads in layer l -/ - layers : Array (Array ConcreteAttentionLayer) - /-- Attention output projection bias (c_proj.bias), one per layer (1×modelDim). - - In GPT-2, this bias is added **once per layer** after combining all heads, i.e.: - `attn_out = c_proj(concat(heads)) + bias`. - -/ - attnProjBias : Array ConcreteMatrix := #[] - /-- MLP layers: mlps[l] is the MLP in layer l (one per layer) -/ - mlps : Array ConcreteMLPLayer - /-- Pre-LN LayerNorm parameters before attention (ln_1), one per layer. -/ - ln1 : Array ConcreteLayerNormParams := #[] - /-- Pre-LN LayerNorm parameters before MLP (ln_2), one per layer. -/ - ln2 : Array ConcreteLayerNormParams := #[] - /-- Final LayerNorm parameters (ln_f) before unembedding. 
-/ - lnf : ConcreteLayerNormParams := ConcreteLayerNormParams.identity 0 - /-- Sequence length for analysis -/ - seqLen : Nat - /-- Optional ground-truth input token IDs for the prompt being analyzed. - - When present, this enables **self-supervised induction targeting** by choosing the - correct next-token prediction target from sequence history (see - `TargetDirection.fromInductionHistory`). - -/ - inputTokens : Option (Array Nat) := none - /-- Input embeddings (seqLen × modelDim) -/ - inputEmbeddings : ConcreteMatrix - /-- Unembedding (decoder) matrix (modelDim × vocabSize) for logit computation. - Maps final residual stream to vocabulary logits: logits = residual · W_U - Optional: if not provided, target-aware analysis is unavailable. -/ - unembedding : Option ConcreteMatrix := none - -namespace ConcreteModel - -/-- Model dimension (d), inferred from input embeddings. -/ -def modelDim (model : ConcreteModel) : Nat := - model.inputEmbeddings.numCols - -/-- Attention output bias for a layer, defaulting to zero. -/ -def attnProjBiasAt (model : ConcreteModel) (layerIdx : Nat) : ConcreteMatrix := - model.attnProjBias.getD layerIdx (ConcreteMatrix.zeros 1 model.modelDim) - -/-- Trim trailing all-zero embedding rows (common when `.nfpt` uses a fixed `seqLen` with padding). - -This is a **semantic no-op** for causal analysis of the *prefix* of real tokens: padded positions -occur after the prompt and are never attended to by earlier queries, but they can badly inflate -LayerNorm Jacobian bounds because zero rows have variance `0` (hence `1/sqrt(eps)` scaling). - -We only trim when we detect a sufficiently long suffix of (near-)zero rows to avoid accidental -truncation on legitimate data. 
--/ -def trimTrailingZeroEmbeddings (model : ConcreteModel) - (minTrailing : Nat := 8) (eps : Float := 1e-12) : ConcreteModel := Id.run do - let X := model.inputEmbeddings - if X.numRows = 0 || X.numCols = 0 then - return model - let mut k : Nat := 0 - let mut r : Nat := X.numRows - while r > 0 do - let r' := r - 1 - let m := ConcreteMatrix.rowMaxAbs X r' - if m ≤ eps then - k := k + 1 - r := r' - else - break - if k < minTrailing then - return model - let newLen := X.numRows - k - let toks? := model.inputTokens.map (fun ts => ts.take newLen) - return { model with - seqLen := newLen - inputEmbeddings := X.takeRows newLen - inputTokens := toks? } - -/-- Get ln_1 parameters for a layer, defaulting to identity. -/ -def ln1Params (model : ConcreteModel) (layerIdx : Nat) : ConcreteLayerNormParams := - model.ln1.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - -/-- Get ln_2 parameters for a layer, defaulting to identity. -/ -def ln2Params (model : ConcreteModel) (layerIdx : Nat) : ConcreteLayerNormParams := - model.ln2.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - -/-- Apply ln_1 to a residual stream (row-wise, per token). -/ -def applyLn1 (model : ConcreteModel) (layerIdx : Nat) (X : ConcreteMatrix) : ConcreteMatrix := - let p := model.ln1Params layerIdx - ConcreteMatrix.layerNormRowwise X p.gamma p.beta - -/-- Apply ln_2 to a residual stream (row-wise, per token). -/ -def applyLn2 (model : ConcreteModel) (layerIdx : Nat) (X : ConcreteMatrix) : ConcreteMatrix := - let p := model.ln2Params layerIdx - ConcreteMatrix.layerNormRowwise X p.gamma p.beta - -/-- Apply final ln_f to a residual stream (row-wise, per token). -/ -def applyLnf (model : ConcreteModel) (X : ConcreteMatrix) : ConcreteMatrix := - ConcreteMatrix.layerNormRowwise X model.lnf.gamma model.lnf.beta - -/-- Heuristic estimate for ln_1 Jacobian operator norm at a specific activation. 
-/ -def ln1OpBound (model : ConcreteModel) (layerIdx : Nat) (X : ConcreteMatrix) : Float := - let p := model.ln1Params layerIdx - ConcreteMatrix.layerNormRowwiseOpEst X p.gamma - -/-- Heuristic estimate for ln_2 Jacobian operator norm at a specific activation. -/ -def ln2OpBound (model : ConcreteModel) (layerIdx : Nat) (X : ConcreteMatrix) : Float := - let p := model.ln2Params layerIdx - ConcreteMatrix.layerNormRowwiseOpEst X p.gamma - -end ConcreteModel - -/-- Get the number of neurons in the MLP at a given layer. -/ -def ConcreteModel.numNeuronsAtLayer (model : ConcreteModel) (layerIdx : Nat) : Nat := - if h : layerIdx < model.mlps.size then - model.mlps[layerIdx].hiddenDim - else 0 - -/-- Result of running a forward pass: the residual stream after each layer. - -`layerInputs[l]` is the input to layer l (the accumulated residual stream). -`layerInputs[0]` = inputEmbeddings (initial token embeddings) -`layerInputs[l+1]` = x_{l+1} in the Pre-LN recurrence. --/ -structure ForwardPassResult where - /-- Input to each layer. layerInputs[l] is what layer l receives. -/ - layerInputs : Array ConcreteMatrix - /-- Post-attention residual for each layer: `y_l = x_l + attn_out` (including attention output bias). -/ - postAttnResiduals : Array ConcreteMatrix - /-- Attention outputs per layer per head: attnOutputs[l][h] = output of head h at layer l -/ - attnOutputs : Array (Array ConcreteMatrix) - /-- MLP outputs per layer: mlpOutputs[l] = output of MLP at layer l -/ - mlpOutputs : Array ConcreteMatrix - /-- Per-layer GeLU derivative matrices over MLP preactivations. - - If a layer has an MLP, this is the matrix `gelu'(z)` where `z = ln₂(y_l)·W_in + b_in`. - Otherwise this entry is a `0×0` placeholder. - -/ - mlpActDeriv : Array ConcreteMatrix - /-- Per-layer maximum absolute GeLU derivative over MLP preactivations. - - Length is `numLayers`. For layers without an MLP, the entry is `0.0`. 
- -/ - mlpActDerivMax : Array Float - /-- Final output after all layers (after ln_f, i.e. what goes into unembedding). -/ - finalOutput : ConcreteMatrix - -/-- Run a full forward pass through the model, computing the residual stream at each layer. - -This is the key function that enables deep circuit analysis: layer N sees the accumulated -output of layers 0..N-1, not just the raw embeddings. - - For each layer l (Pre-LN, GPT-2 style): - 1. u = ln_1(x_l) - 2. y_l = x_l + Σₕ AttentionHead[l,h].forward(u) - 3. v = ln_2(y_l) - 4. x_{l+1} = y_l + MLP[l].forward(v) - - After all layers: output = ln_f(x_L) --/ -def ConcreteModel.runForward (model : ConcreteModel) - (causal : Bool := true) : ForwardPassResult := Id.run do - let mut layerInputs : Array ConcreteMatrix := Array.mkEmpty (model.numLayers + 1) - let mut postAttnResiduals : Array ConcreteMatrix := Array.mkEmpty model.numLayers - let mut attnOutputs : Array (Array ConcreteMatrix) := Array.mkEmpty model.numLayers - let mut mlpOutputs : Array ConcreteMatrix := Array.mkEmpty model.numLayers - let mut mlpActDeriv : Array ConcreteMatrix := Array.mkEmpty model.numLayers - let mut mlpActDerivMax : Array Float := Array.mkEmpty model.numLayers - let mut residual := model.inputEmbeddings - layerInputs := layerInputs.push residual - - for l in [:model.numLayers] do - -- Pre-LN: attention sees ln_1(residual) - let attnInput := model.applyLn1 l residual - -- Compute attention outputs for all heads in this layer - let mut layerAttnOutputs : Array ConcreteMatrix := #[] - let rows := residual.numRows - let cols := residual.numCols - - if hl : l < model.layers.size then - let layerHeads := model.layers[l] - let useParallelHeads := layerHeads.size >= 4 - layerAttnOutputs := - if useParallelHeads then - let tasks : Array (Task ConcreteMatrix) := - .ofFn fun i : Fin layerHeads.size => - Task.spawn (fun _ => (layerHeads[i]).forward attnInput causal) - tasks.map Task.get - else - Id.run do - let mut outs : Array ConcreteMatrix := 
Array.mkEmpty layerHeads.size - for head in layerHeads do - outs := outs.push (head.forward attnInput causal) - return outs - - attnOutputs := attnOutputs.push layerAttnOutputs - - -- Add attention residual - let residualAfterAttn := - if layerAttnOutputs.isEmpty then - residual - else - let attnSum := ConcreteMatrix.sumMatrices layerAttnOutputs - let attnBias := model.attnProjBiasAt l - residual.add ((attnSum.addBias attnBias)) - postAttnResiduals := postAttnResiduals.push residualAfterAttn - - -- Compute MLP output - -- Pre-LN: MLP sees ln_2(residualAfterAttn) - let mlpInput := model.applyLn2 l residualAfterAttn - let mut mlpOut : ConcreteMatrix := ConcreteMatrix.zeros residual.numRows residual.numCols - if hm : l < model.mlps.size then - let mlp := model.mlps[l] - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let geluDeriv := hidden.map geluDerivFloat - let mut dmax : Float := 0.0 - for z in geluDeriv.data do - let d := Float.abs z - if d > dmax then - dmax := d - let activated := hidden.map geluFloat - let out := (activated.matmul mlp.W_out).addBias mlp.b_out - let dmax' : Float := - if Float.isNaN dmax || Float.isInf dmax then 0.0 else dmax - mlpActDeriv := mlpActDeriv.push geluDeriv - mlpActDerivMax := mlpActDerivMax.push dmax' - mlpOut := out - else - -- No MLP at this layer. 
- mlpActDeriv := mlpActDeriv.push (ConcreteMatrix.zeros 0 0) - mlpActDerivMax := mlpActDerivMax.push 0.0 - mlpOut := ConcreteMatrix.zeros residual.numRows residual.numCols - - mlpOutputs := mlpOutputs.push mlpOut - - -- Add MLP residual - residual := residualAfterAttn.add mlpOut - - -- Store input for next layer - layerInputs := layerInputs.push residual - - let finalOutput := model.applyLnf residual - { - layerInputs := layerInputs - postAttnResiduals := postAttnResiduals - attnOutputs := attnOutputs - mlpOutputs := mlpOutputs - mlpActDeriv := mlpActDeriv - mlpActDerivMax := mlpActDerivMax - finalOutput := finalOutput - } - -/-- Get the input to a specific layer from a forward pass result. -/ -def ForwardPassResult.getLayerInput (result : ForwardPassResult) - (layerIdx : Nat) : ConcreteMatrix := - if h : layerIdx < result.layerInputs.size then - result.layerInputs[layerIdx] - else ConcreteMatrix.zeros 0 0 - -/-- Get the cached MLP GeLU-derivative matrix for a layer, if present. -/ -def ForwardPassResult.getMlpGeluDeriv (result : ForwardPassResult) (layerIdx : Nat) : ConcreteMatrix := - if h : layerIdx < result.mlpActDeriv.size then - result.mlpActDeriv[layerIdx] - else - ConcreteMatrix.zeros 0 0 - -/-- Get the post-attention residual `y_l = x_l + attn_sum` for a layer. - -This is the input to `ln_2` in a Pre-LN transformer block. --/ -def ForwardPassResult.getPostAttnResidual (result : ForwardPassResult) - (layerIdx : Nat) : ConcreteMatrix := Id.run do - if h : layerIdx < result.postAttnResiduals.size then - result.postAttnResiduals[layerIdx] - else - let x := result.getLayerInput layerIdx - if h2 : layerIdx < result.attnOutputs.size then - let heads := result.attnOutputs[layerIdx] - if heads.isEmpty then x else x.add (ConcreteMatrix.sumMatrices heads) - else - x - -/-! ## N-Layer Error Amplification Computation - -These functions implement the N-layer faithfulness composition formula from -`Linearization.lean`. 
They must be defined early so they can be used by -discovery functions like `findDeepCircuitCandidates`. --/ - -/-- Compute suffix amplification factor: ∏_{j≥start} (1 + C_j) - -This is how much error from layer `start` gets amplified by subsequent layers. -When start ≥ normBounds.size, returns 1 (no amplification). --/ -def computeSuffixAmplification (normBounds : Array Float) (start : Nat) : Float := Id.run do - let mut product : Float := 1.0 - for j in [start:normBounds.size] do - if hj : j < normBounds.size then - product := product * (1.0 + normBounds[j]) - product - -/-- Compute total amplified error: Σᵢ εᵢ · suffixAmplification(i+1) - -This implements the N-layer faithfulness composition formula from -`Linearization.lean` theorem `n_layer_faithfulness_composition`. --/ -def computeTotalAmplifiedError (patternBounds normBounds : Array Float) : Float := Id.run do - if patternBounds.size = 0 then return 0.0 - let mut total : Float := 0.0 - for i in [:patternBounds.size] do - if hi : i < patternBounds.size then - let epsilon_i := patternBounds[i] - -- Suffix amplification from layer i+1 onwards - let suffix := computeSuffixAmplification normBounds (i + 1) - total := total + epsilon_i * suffix - total - -/-- Estimate the operator norm bound for a single attention layer. - -For an attention layer, the Jacobian includes both the attention pattern term -and the value projection. We estimate: - ‖J‖ ≤ ‖A‖_F · ‖W_V·W_O‖_op + ‖∂A/∂x‖ · ‖V·W_O‖, -so ‖I + J‖ ≤ 1 + ‖J‖. - -For simplicity, we use Frobenius norms as upper bounds. --/ -def estimateAttentionLayerNorm (model : ConcreteModel) (fwdResult : ForwardPassResult) - (layerIdx : Nat) (causal : Bool := true) : Float := Id.run do - if h : layerIdx < model.layers.size then - let heads := model.layers[layerIdx] - let mut totalNorm : Float := 0.0 - - -- Pre-LN: attention and MLP see normalized activations; account for LN Jacobian scaling. 
- let x := fwdResult.getLayerInput layerIdx - let y := fwdResult.getPostAttnResidual layerIdx - let ln1Bound := model.ln1OpBound layerIdx x - let ln2Bound := model.ln2OpBound layerIdx y - - -- Attention pattern/value bounds are computed at the Pre-LN attention input. - let attnInput := model.applyLn1 layerIdx x - let inputNorm := attnInput.frobeniusNorm - let inputOpBound := attnInput.opNormUpperBoundOneInf - - -- Sum contributions from all heads in this layer - for hidx in [:heads.size] do - if hh : hidx < heads.size then - let head := heads[hidx] - - -- QK operator norm bound via a 64×64 companion product. - -- `‖W_Q W_Kᵀ‖₂ = ‖W_Kᵀ W_Q‖₂` because `AB` and `BA` have the same nonzero - -- singular values. - let qkSmall : ConcreteMatrix := head.W_K.transpose.matmul head.W_Q - let qkNorm : Float := - min (qkSmall.opNormUpperBoundDenseBrauer) - (min (qkSmall.opNormUpperBoundDenseSchur) (qkSmall.opNormUpperBoundDenseFrob)) - - -- VO operator norm bound via a 64×64 companion product. - -- `‖W_V W_O‖₂ = ‖W_O W_V‖₂` because `AB` and `BA` have the same nonzero - -- singular values. - let voSmall : ConcreteMatrix := head.W_O.matmul head.W_V - let valueOutputProjNorm : Float := - min (voSmall.opNormUpperBoundDenseBrauer) - (min (voSmall.opNormUpperBoundDenseSchur) (voSmall.opNormUpperBoundDenseFrob)) - - -- Data-dependent softmax Jacobian operator bound (per-row max, clamped to 1/2). - let attn := head.computeAttentionWeights attnInput causal - let softmaxOpBound := min attn.softmaxJacobianOpDiag.opBound 0.5 - let scaleFactor := Float.sqrt head.headDim.toFloat - - -- Value-term operator bound: - -- ‖X ↦ A·X·(W_VW_O)‖ ≤ ‖A‖₂ · ‖W_VW_O‖₂. - -- For the attention matrix `A` (nonnegative row-stochastic), the max row sum is 1 - -- (`‖A‖₁` in row-vector convention, `‖A‖∞` in column-vector convention), so - -- `sqrt(‖A‖₁‖A‖∞)` is typically far tighter than `‖A‖_F`. 
- let attnFrob : Float := Id.run do - let mut s : Float := 0.0 - for w in attn.weights do - s := s + w * w - Float.sqrt (max 0.0 s) - let attnMaxRowSum : Float := Id.run do - let n := attn.seqLen - let mut maxSum : Float := 0.0 - for q in [:n] do - let mut s : Float := 0.0 - let rowBase := q * n - for k in [:n] do - s := s + Float.abs (attn.weights[rowBase + k]!) - maxSum := max maxSum s - maxSum - let attnMaxColSum : Float := Id.run do - let n := attn.seqLen - let mut maxSum : Float := 0.0 - for k in [:n] do - let mut s : Float := 0.0 - for q in [:n] do - s := s + Float.abs (attn.weights[q * n + k]!) - maxSum := max maxSum s - maxSum - let attnOneInf : Float := Float.sqrt (attnMaxRowSum * attnMaxColSum) - let attnOpUb : Float := min attnFrob attnOneInf - let valueTermUb := attnOpUb * valueOutputProjNorm - - let bnds := head.noDenseProductBounds - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - scaleFactor := scaleFactor - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternTermUb := computePatternTermBound inputs - - totalNorm := totalNorm + ln1Bound * (valueTermUb + patternTermUb) - - -- Add MLP contribution if present - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - -- Use the fused Jacobian bound (data-dependent on `gelu'(z)`), but keep the legacy - -- product-of-norms as a (looser) fallback. 
- let winNormUb := mlp.W_in.opNormUpperBoundRectGram - let woutNormUb := mlp.W_out.opNormUpperBoundRectGram - let geluDeriv := fwdResult.getMlpGeluDeriv layerIdx - let mlpUb := - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv winNormUb woutNormUb - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct winNormUb woutNormUb - totalNorm := totalNorm + ln2Bound * mlpUb - - totalNorm - else - return 0.0 - -/-- Diagnostics-only variant of `estimateAttentionLayerNorm`. - -This uses power iteration (`operatorNormHeuristicPI`) for the key operator norms, -to provide an "old PI vs new rigorous" comparison under `-d`. - -WARNING: This is **not** a certified upper bound and must never be used in -bound/certification codepaths. --/ -def estimateAttentionLayerNormHeuristicPI (model : ConcreteModel) (fwdResult : ForwardPassResult) - (layerIdx : Nat) (causal : Bool := true) : Float := Id.run do - if h : layerIdx < model.layers.size then - let heads := model.layers[layerIdx] - let mut totalNorm : Float := 0.0 - - let x := fwdResult.getLayerInput layerIdx - let y := fwdResult.getPostAttnResidual layerIdx - let ln1Bound := model.ln1OpBound layerIdx x - let ln2Bound := model.ln2OpBound layerIdx y - let attnInput := model.applyLn1 layerIdx x - let inputNorm := attnInput.frobeniusNorm - let inputOpBound := attnInput.opNormUpperBoundOneInf - - for hidx in [:heads.size] do - if hh : hidx < heads.size then - let head := heads[hidx] - let qkSmall : ConcreteMatrix := head.W_K.transpose.matmul head.W_Q - let qkPi := qkSmall.operatorNormHeuristicPI 20 - let voSmall : ConcreteMatrix := head.W_O.matmul head.W_V - let voPi := voSmall.operatorNormHeuristicPI 20 - - let attn := head.computeAttentionWeights attnInput causal - let softmaxOpBound := min 
attn.softmaxJacobianOpDiag.opBound 0.5 - let scaleFactor := Float.sqrt head.headDim.toFloat - let attnFrob : Float := Id.run do - let mut s : Float := 0.0 - for w in attn.weights do - s := s + w * w - Float.sqrt (max 0.0 s) - let attnMaxRowSum : Float := Id.run do - let n := attn.seqLen - let mut maxSum : Float := 0.0 - for q in [:n] do - let mut s : Float := 0.0 - let rowBase := q * n - for k in [:n] do - s := s + Float.abs (attn.weights[rowBase + k]!) - maxSum := max maxSum s - maxSum - let attnMaxColSum : Float := Id.run do - let n := attn.seqLen - let mut maxSum : Float := 0.0 - for k in [:n] do - let mut s : Float := 0.0 - for q in [:n] do - s := s + Float.abs (attn.weights[q * n + k]!) - maxSum := max maxSum s - maxSum - let attnOneInf : Float := Float.sqrt (attnMaxRowSum * attnMaxColSum) - let attnOpUb : Float := min attnFrob attnOneInf - let valueTermEst := attnOpUb * voPi - let patternTermEst := (softmaxOpBound / scaleFactor) * inputNorm * qkPi * voPi - totalNorm := totalNorm + ln1Bound * (valueTermEst + patternTermEst) - - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - -- Keep PI iterations low-ish here to avoid expensive diagnostics runs. - let winPi := mlp.W_in.operatorNormHeuristicPI 5 - let woutPi := mlp.W_out.operatorNormHeuristicPI 5 - let geluDerivBound : Float := - let d := fwdResult.mlpActDerivMax.getD layerIdx 1.7 - if d ≤ 0.0 || Float.isNaN d || Float.isInf d then 1.7 else d - totalNorm := totalNorm + ln2Bound * (winPi * geluDerivBound * woutPi) - - totalNorm - else - return 0.0 - -/-- Compute attention weights for a given layer and head using the correct layer input. - -This is the corrected version that uses the accumulated residual stream. 
--/ -def ConcreteModel.computeAttentionWithInput (model : ConcreteModel) - (layerIdx headIdx : Nat) (input : ConcreteMatrix) : Option ConcreteAttentionWeights := - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx]'h2 - -- Pre-LN: attention weights are computed from ln_1(input). - let attnInput := model.applyLn1 layerIdx input - some (head.computeAttentionWeights attnInput) - else none - else none - -/-- Compute attention weights for a given layer and head. - -This is a legacy helper that only uses the model's `inputEmbeddings`. -Prefer `computeAttentionWithInput` for Pre-LN-correct layer inputs. --/ -def ConcreteModel.computeAttention (model : ConcreteModel) - (layerIdx headIdx : Nat) : Option ConcreteAttentionWeights := - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx] - let attnInput := model.applyLn1 layerIdx model.inputEmbeddings - some (head.computeAttentionWeights attnInput) - else none - else none - -/-- Compute input norm for bound calculations. -/ -def computeInputNorm (embeddings : ConcreteMatrix) : Float := - embeddings.frobeniusNorm - -/-! ## Precomputation Cache Structures - -To optimize the O(L²H²) nested loop in deep circuit discovery, we precompute and cache: -1. Attention patterns (A = softmax(QK^T/√d)) for each layer-head -2. Value-output projections (V·W_O) for each head -3. Query-key alignments (Q·K^T) for each head -4. Operator norm bounds and Frobenius norms - -This reduces redundant computation from O(L²H²) to O(LH), a massive improvement -for models like GPT-2 Small (12 layers × 12 heads = 144 heads → 20,736 → 144 calls). --/ - -/-- Precomputed data for a single attention head at a specific layer input. -/ -structure HeadDiagnosticsLazy where - /-- Dense `W_V·W_O` (diagnostics-only). 
-/ - voProj : Thunk ConcreteMatrix - /-- Dense `W_Q·W_Kᵀ` (diagnostics-only). -/ - qkAlign : Thunk ConcreteMatrix - /-- Dense Schur candidate `sqrt(‖M‖₁‖M‖∞)` for `W_V·W_O` (diagnostics-only). -/ - voDenseSchur : Thunk Float - /-- Dense Schur candidate `sqrt(‖M‖₁‖M‖∞)` for `W_Q·W_Kᵀ` (diagnostics-only). -/ - qkDenseSchur : Thunk Float - -structure PrecomputedHeadData where - /-- Layer index -/ - layerIdx : Nat - /-- Head index within layer -/ - headIdx : Nat - /-- Attention weights (A = softmax(QK^T/√d)) -/ - attention : ConcreteAttentionWeights - /-- Average previous-token attention strength: `(1/(n-1)) * Σᵢ A[i+1, i]`. -/ - prevTokenStrength : Float - /-- Cached softmax Jacobian operator-norm estimate for this head's attention rows. -/ - softmaxJacobianOpEst : Float - /-- Softmax Jacobian diagnostics: `max_i p_i` for the maximizing row. -/ - softmaxRowMaxP : Float - /-- Softmax Jacobian diagnostics: `tr(J) = 1 - ∑ p_i^2` for the maximizing row. -/ - softmaxRowTraceBound : Float - /-- Softmax Jacobian diagnostics: PSD moment bound for the maximizing row. -/ - softmaxRowMomentBound : Float - /-- Softmax Jacobian diagnostics: Gershgorin / `‖J‖_∞ = max_i 2 p_i (1-p_i)` - for the maximizing row. -/ - softmaxRowGershBound : Float - /-- Softmax Jacobian diagnostics: final per-row bound used for the maximizing row. -/ - softmaxRowBoundUsed : Float - /-- Number of rows that triggered a conservative fallback (NaN/Inf/zero-sum). -/ - softmaxRowsFallback : Nat - /-- Cached Frobenius norm squared of the attention matrix: `‖A‖_F²`. -/ - attentionFrobeniusNormSq : Float - /-- Cached `sqrt(‖A‖₁‖A‖∞)` upper bound on `‖A‖₂` (computed from max row/col sums). -/ - attentionOneInfBound : Float - /-- Cached pattern-term bound `‖PatternTerm‖_F` for this head at the cached input. -/ - patternTermBoundCached : Float - /-- Cached value-term Frobenius norm `‖ValueTerm‖_F` for this head. 
-/ - valueTermNormCached : Float - /-- Cached dimensionless faithfulness ratio `‖PatternTerm‖_F / ‖ValueTerm‖_F`. -/ - faithfulnessRatioCached : Float - /-- Optional detailed pattern-term bound diagnostics. -/ - patternBoundParts? : Option PatternTermBoundParts := none - /-- Optional lazy diagnostics. Never forced on the main path. -/ - diag? : Option HeadDiagnosticsLazy - /-- Cached Gram matrix `W_Qᵀ · W_Q` (headDim × headDim). -/ - wqGram : ConcreteMatrix - /-- Cached Gram matrix `W_Vᵀ · W_V` (headDim × headDim). -/ - wvGram : ConcreteMatrix - /-- Input norm ‖X‖_F for this layer -/ - inputNorm : Float - /-- Operator-norm upper bound ‖X‖₂ for this layer input (computed via 1/∞). -/ - inputOpBound : Float - /-- Direct (data-dependent) Frobenius norm of `Q = X·W_Q + b_Q` for this head. -/ - qFrobBound : Float - /-- Direct (data-dependent) Frobenius norm of `K = X·W_K + b_K` for this head. -/ - kFrobBound : Float - /-- Direct (data-dependent) Frobenius norm of centered `V = X·W_V + b_V` for this head. -/ - vFrobBound : Float - /-- Direct (data-dependent) operator-norm upper bound on `Q` (computed via a small Gram). -/ - qOpBoundAct : Float - /-- Direct (data-dependent) operator-norm upper bound on `K` (computed via a small Gram). -/ - kOpBoundAct : Float - /-- Direct (data-dependent) operator-norm upper bound on centered `V` (computed via a small Gram). -/ - vOpBoundAct : Float - /-- Frobenius bound for `‖W_Q·Kᵀ‖_F` from centered activations. -/ - qkActFrobBound : Float - /-- Frobenius bound for `‖W_K·Qᵀ‖_F` from centered activations. -/ - kqActFrobBound : Float - /-- Operator-norm bound for `‖W_Q·Kᵀ‖₂` from centered activation Grams. -/ - qkActOpBound : Float - /-- Operator-norm bound for `‖W_K·Qᵀ‖₂` from centered activation Grams. -/ - kqActOpBound : Float - /-- Selected candidate for `‖W_Q·Kᵀ‖₂` activation bound (diagnostics-only). -/ - qkActOpBoundSource : String := "unknown" - /-- Selected candidate for `‖W_K·Qᵀ‖₂` activation bound (diagnostics-only). 
-/ - kqActOpBoundSource : String := "unknown" - /-- Operator-norm bound for ln_1 Jacobian at this layer input. -/ - ln1OpBound : Float - /-- Scaling factor √d_head -/ - scaleFactor : Float - /-- Cached Frobenius norm of V·W_O -/ - valueOutputProjNorm : Float - /-- Cached Frobenius norm of Q·K^T -/ - queryKeyAlignNorm : Float - /-- Cached deterministic Float upper bound on ‖V·W_O‖₂. - - This is computed as the minimum of several valid upper bounds - (Schur / Frobenius / factor bounds). - -/ - valueOutputProjSchurNorm : Float - /-- Cached deterministic Float upper bound on ‖Q·K^T‖₂. - - This is computed as the minimum of several valid upper bounds - (Schur / Frobenius / factor bounds). - -/ - queryKeyAlignSchurNorm : Float - /-- Candidate bounds for `‖Q·Kᵀ‖₂` used in diagnostics. -/ - qkDenseFrob : Float - /-- Tight Gram-product candidate derived from 64×64 Gram matrices. -/ - qkDenseGram : Float - /-- Brauer/Cassini Gram candidate computed on a 64×64 matrix with matching singular values. -/ - qkDenseBrauer : Float - qkFactorSchur : Float - qkFactorFrob : Float - - /-- Gram-based operator bound for `W_Q` (min of Gersh/Brauer/moment candidates). -/ - wqOpGram : Float - /-- Gram-based operator bound for `W_K` (min of Gersh/Brauer/moment candidates). -/ - wkOpGram : Float - /-- Frobenius norm of the 1×headDim query bias row `b_Q`. -/ - bqFrob : Float - /-- Frobenius norm of the 1×headDim key bias row `b_K`. -/ - bkFrob : Float - /-- Frobenius norm of the 1×headDim value bias row `b_V`. -/ - bvFrob : Float - /-- Factorized Gram bound for `‖W_Q·W_Kᵀ‖₂`: `wqOpGram * wkOpGram`. -/ - qkFactorGram : Float - - /-- Candidate bounds for `‖W_V·W_O‖₂` used in diagnostics. -/ - voDenseFrob : Float - /-- Tight Gram-product candidate derived from 64×64 Gram matrices. -/ - voDenseGram : Float - /-- Brauer/Cassini Gram candidate computed on a 64×64 matrix with matching singular values. 
-/ - voDenseBrauer : Float - voFactorSchur : Float - voFactorFrob : Float - /-- Gram-based operator bound for `W_V` (min of Gersh/Brauer/moment candidates). -/ - wvOpGram : Float - /-- Gram-based operator bound for `W_O` (min of Gersh/Brauer/moment candidates). -/ - woOpGram : Float - /-- Factorized Gram bound for `‖W_V·W_O‖₂`: `wvOpGram * woOpGram`. -/ - voFactorGram : Float - -namespace PrecomputedHeadData - -/-- Precomputed pattern term bound for a head (cached computation). -/ -def patternTermBound (data : PrecomputedHeadData) : Float := - data.patternTermBoundCached - -/-- Frobenius norm of the Value Term of this head's Jacobian. - -For the attention linearization, the Value Term factorizes as `A ⊗ (W_V·W_O)`, -so `‖ValueTerm‖_F = ‖A‖_F · ‖W_V·W_O‖_F`. --/ -def valueTermNorm (data : PrecomputedHeadData) : Float := - data.valueTermNormCached - -/-- Dimensionless faithfulness ratio: `‖PatternTerm‖_F / ‖ValueTerm‖_F`. - -This matches `relativeApproximationError` from `Nfp.Linearization` and is the -quantity that should be compared to user-facing thresholds like `0.1`. --/ -def faithfulnessRatio (data : PrecomputedHeadData) : Float := - data.faithfulnessRatioCached - -end PrecomputedHeadData - -/-- Cache for all precomputed head data across all layers. - -Structure: `cache[layerIdx][headIdx]` gives the PrecomputedHeadData for that head. --/ -structure PrecomputedCache where - /-- Model this cache was built for -/ - model : ConcreteModel - /-- Forward pass result with layer inputs -/ - forwardResult : ForwardPassResult - /-- Cached Pre-LN attention inputs `ln_1(x_l)` for each layer `l`. -/ - ln1Inputs : Array ConcreteMatrix - /-- Precomputed data: cache[layerIdx][headIdx] -/ - headData : Array (Array PrecomputedHeadData) - /-- Precomputed operator norm bounds for each layer (for N-layer error amplification) -/ - layerNormBounds : Array Float - /-- Whether `layerNormBounds` were computed (otherwise the array is a placeholder). 
-/ - layerNormBoundsComputed : Bool - -namespace PrecomputedCache - -/-! ### Layer Jacobian-norm upper bounds (cached-head fast path) -/ - -/-- Compute the per-layer residual Jacobian norm upper bound `C_l` using cached head data. - -This matches `estimateAttentionLayerNorm` but avoids recomputing attention weights and -small-factor norms. The only tier-dependent cost is the MLP `W_in/W_out` operator bounds. --/ -def layerNormBoundAt (cache : PrecomputedCache) (layerIdx : Nat) - (effort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) : Float := Id.run do - let model := cache.model - let fwd := cache.forwardResult - let y := fwd.getPostAttnResidual layerIdx - let ln2Bound := model.ln2OpBound layerIdx y - let layerData := cache.headData.getD layerIdx #[] - let mut attnPart : Float := 0.0 - for d in layerData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let patternTermUb : Float := d.patternTermBound - attnPart := attnPart + d.ln1OpBound * (valueTermUb + patternTermUb) - - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let winUb := mlp.W_in.opNormUpperBoundRectGramEffort effort - let woutUb := mlp.W_out.opNormUpperBoundRectGramEffort effort - let geluDeriv := fwd.getMlpGeluDeriv layerIdx - let mlpUb := - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv winUb woutUb - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct winUb woutUb - let mlpPart := ln2Bound * mlpUb - return attnPart + (1.0 + attnPart) * mlpPart - else - return attnPart - -/-- Compute all per-layer residual Jacobian upper bounds `C_l` under a uniform 
effort tier. -/ -def computeLayerNormBounds (cache : PrecomputedCache) - (effort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) : - Array Float := Id.run do - let n := cache.model.numLayers - let useParallel := n >= 4 - if useParallel then - let tasks : Array (Task Float) := - .ofFn fun i : Fin n => - Task.spawn (fun _ => cache.layerNormBoundAt i.val effort) - tasks.map Task.get - else - let mut out : Array Float := Array.replicate n 0.0 - for l in [:n] do - out := out.set! l (cache.layerNormBoundAt l effort) - out - -/-- Build cached head data and pre-LN attention inputs for all layers. -/ -def buildHeadData (model : ConcreteModel) (fwdResult : ForwardPassResult) - (causal : Bool := true) - (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) - (storeDiagnostics : Bool := false) : - Array (Array PrecomputedHeadData) × Array ConcreteMatrix := Id.run do - -- Prefer layer-level parallelism, but allow bounded head chunking to use spare cores. - let useParallelLayers := model.numLayers >= 4 - let computeLayer (l : Nat) : (Array PrecomputedHeadData × ConcreteMatrix) := Id.run do - let layerInput := fwdResult.getLayerInput l - let attnInput := model.applyLn1 l layerInput - let inputNorm := computeInputNorm attnInput - -- Use a Gram-based bound here (tier-controlled) because the cheap 1/∞ estimate can be - -- extremely loose on dense `seqLen×modelDim` matrices and will not meaningfully tighten - -- the attention pattern-term bounds. - let inputOpBound := attnInput.opNormUpperBoundRectGramEffort layerNormEffort - let ln1Bound := model.ln1OpBound l layerInput - - let layerHeadData : Array PrecomputedHeadData := - if h : l < model.layers.size then - let heads := model.layers[l]'h - let computeHead (hIdx : Nat) (head : ConcreteAttentionLayer) : PrecomputedHeadData := Id.run do - -- Compute Q/K once (needed for attention weights), and also cache tight, data-dependent - -- bounds on their norms for use in pattern-term bounds. 
- let queries := (attnInput.matmul head.W_Q).addBias head.b_Q - let keys := (attnInput.matmul head.W_K).addBias head.b_K - let scaleFactor := Float.sqrt head.headDim.toFloat - let attn := ConcreteAttentionWeights.compute queries keys scaleFactor attnInput.numRows causal - let values := (attnInput.matmul head.W_V).addBias head.b_V - let qFrobBound : Float := queries.frobeniusNorm - let kFrobBound : Float := keys.frobeniusNorm - let vFrobBoundRaw : Float := values.frobeniusNorm - let smallEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1 - let qOpBoundAct : Float := queries.opNormUpperBoundRectGramEffort smallEffort - let kOpBoundAct : Float := keys.opNormUpperBoundRectGramEffort smallEffort - let vOpBoundActRaw : Float := values.opNormUpperBoundRectGramEffort smallEffort - let softmaxDiag := attn.softmaxJacobianOpDiag - let softmaxOpBound := min softmaxDiag.opBound 0.5 - let n := attn.seqLen - let mut attnFrobNormSq : Float := 0.0 - let mut maxRowAbsSum : Float := 0.0 - let mut colAbsSums : Array Float := Array.replicate n 0.0 - for q in [:n] do - let rowBase := q * n - let mut rowAbsSum : Float := 0.0 - for k in [:n] do - let w := attn.weights[rowBase + k]! - attnFrobNormSq := attnFrobNormSq + w * w - let a := Float.abs w - rowAbsSum := rowAbsSum + a - colAbsSums := colAbsSums.set! k (colAbsSums[k]! + a) - if rowAbsSum > maxRowAbsSum then - maxRowAbsSum := rowAbsSum - let mut maxColAbsSum : Float := 0.0 - for k in [:n] do - let s := colAbsSums[k]! - if s > maxColAbsSum then - maxColAbsSum := s - let attnOneInf : Float := Float.sqrt (max 0.0 (maxRowAbsSum * maxColAbsSum)) - - let prevTokenStrength : Float := - if attn.seqLen < 2 then - 0.0 - else - Id.run do - let n := attn.seqLen - let w := attn.weights - let mut sum : Float := 0.0 - for i in [:n - 1] do - sum := sum + w[(i + 1) * n + i]! - return sum / (n - 1).toFloat - - -- No-dense scalar bounds: avoid allocating per-head `modelDim×modelDim` products. 
- let bnds := head.noDenseProductBounds - let wqGram := bnds.wqGram - let wvGram := bnds.wvGram - let qkNorm := bnds.qkDenseFrob - let voNorm := bnds.voDenseFrob - let qkOpBound := bnds.qkOpBound - let voOpBound := bnds.voOpBound - let qActGram := queries.transpose.matmul queries - let kActGram := keys.transpose.matmul keys - let meanVec := fun (M : ConcreteMatrix) => Id.run do - let cols := M.numCols - if M.numRows = 0 || cols = 0 then - return Array.replicate cols 0.0 - let mut sums : Array Float := Array.replicate cols 0.0 - for r in [:M.numRows] do - let rowBase := r * cols - for c in [:cols] do - sums := sums.set! c (sums[c]! + M.data[rowBase + c]!) - let invN := 1.0 / M.numRows.toFloat - let mut out : Array Float := Array.replicate cols 0.0 - for c in [:cols] do - out := out.set! c (sums[c]! * invN) - return out - let centerGram := fun (gram : ConcreteMatrix) (mean : Array Float) (n : Nat) => - if gram.numRows = 0 || gram.numCols = 0 || n = 0 then - gram - else - let nF := n.toFloat - { numRows := gram.numRows - numCols := gram.numCols - data := .ofFn fun idx : Fin (gram.numRows * gram.numCols) => Id.run do - let i := idx.val / gram.numCols - let j := idx.val % gram.numCols - gram.data[idx.val]! - nF * mean[i]! * mean[j]! - size_eq := Array.size_ofFn - } - let gramTrace := fun (M : ConcreteMatrix) => Id.run do - if M.numRows = 0 || M.numCols = 0 then - return 0.0 - let mut acc : Float := 0.0 - for i in [:M.numRows] do - acc := acc + M.data[i * M.numCols + i]! - return acc - -- Center activations to drop row-constant logit shifts (softmax-invariant). 
- let seqLen := keys.numRows - let keyMean := meanVec keys - let queryMean := meanVec queries - let valueMean := meanVec values - let vActGram := values.transpose.matmul values - let vActGramCentered := centerGram vActGram valueMean seqLen - let vActTrace : Float := gramTrace vActGramCentered - let vActTracePos := max 0.0 vActTrace - let vFrobBoundCentered : Float := Float.sqrt vActTracePos - let vFrobBound : Float := - if vFrobBoundCentered ≤ 0.0 || Float.isNaN vFrobBoundCentered || - Float.isInf vFrobBoundCentered then - vFrobBoundRaw - else - min vFrobBoundRaw vFrobBoundCentered - let vOpBoundActCentered : Float := opBoundFromGramLocal vActGramCentered vActTracePos - let vOpBoundAct : Float := - if vOpBoundActRaw ≤ 0.0 || Float.isNaN vOpBoundActRaw || Float.isInf vOpBoundActRaw then - vOpBoundActCentered - else - min vOpBoundActRaw vOpBoundActCentered - let kActGramCentered := centerGram kActGram keyMean seqLen - let qActGramCentered := centerGram qActGram queryMean seqLen - let kActTracePos := max 0.0 (gramTrace kActGramCentered) - let qActTracePos := max 0.0 (gramTrace qActGramCentered) - let kCenteredOpBound : Float := opBoundFromGramLocal kActGramCentered kActTracePos - let qCenteredOpBound : Float := opBoundFromGramLocal qActGramCentered qActTracePos - let chooseMinLabel := fun (a b : (String × Float)) => - if a.2 ≤ b.2 then a else b - let qkActTrace : Float := ConcreteMatrix.traceMul wqGram kActGramCentered - let qkActGram1 := kActGramCentered.matmul wqGram - let qkActGram2 := wqGram.matmul kActGramCentered - let qkActOpDense1 : Float := qkActGram1.opNormUpperBoundDenseBrauer - let qkActOpDense2 : Float := qkActGram2.opNormUpperBoundDenseBrauer - let qkActOpBoundDense : Float := - Float.sqrt (max 0.0 (min qkActOpDense1 qkActOpDense2)) - let qkActF2_1 : Float := ConcreteMatrix.traceMul qkActGram1 qkActGram1 - let qkActF2_2 : Float := ConcreteMatrix.traceMul qkActGram2 qkActGram2 - -- `KᵀK·W_QᵀW_Q` has nonnegative real eigenvalues; use trace/trace-square 
moment bound. - let qkActMoment1 : Float := - ConcreteMatrix.psdLambdaMaxUpperMoment qkActGram1.numRows qkActTrace (max 0.0 qkActF2_1) - let qkActMoment2 : Float := - ConcreteMatrix.psdLambdaMaxUpperMoment qkActGram2.numRows qkActTrace (max 0.0 qkActF2_2) - let qkActOpBoundMoment : Float := - Float.sqrt (max 0.0 (min qkActMoment1 qkActMoment2)) - let qkActGram1Sq := qkActGram1.matmul qkActGram1 - let qkActGram2Sq := qkActGram2.matmul qkActGram2 - let qkActTrace4_1 : Float := ConcreteMatrix.traceMul qkActGram1Sq qkActGram1Sq - let qkActTrace4_2 : Float := ConcreteMatrix.traceMul qkActGram2Sq qkActGram2Sq - let qkActOpBoundPow4 : Float := - Float.sqrt (Float.sqrt (max 0.0 (min qkActTrace4_1 qkActTrace4_2))) - let qkActLambdaBrauer1 : Float := ConcreteMatrix.lambdaMaxUpperBrauer qkActGram1 - let qkActLambdaBrauer2 : Float := ConcreteMatrix.lambdaMaxUpperBrauer qkActGram2 - let qkActOpBoundBrauer : Float := - Float.sqrt (max 0.0 (min qkActLambdaBrauer1 qkActLambdaBrauer2)) - let qkActLambdaGershNN1 : Float := ConcreteMatrix.lambdaMaxUpperGershNonneg qkActGram1 - let qkActLambdaGershNN2 : Float := ConcreteMatrix.lambdaMaxUpperGershNonneg qkActGram2 - let qkActOpBoundGershNN : Float := - Float.sqrt (max 0.0 (min qkActLambdaGershNN1 qkActLambdaGershNN2)) - let qkActLambdaBrauerNN1 : Float := ConcreteMatrix.lambdaMaxUpperBrauerNonneg qkActGram1 - let qkActLambdaBrauerNN2 : Float := ConcreteMatrix.lambdaMaxUpperBrauerNonneg qkActGram2 - let qkActOpBoundBrauerNN : Float := - Float.sqrt (max 0.0 (min qkActLambdaBrauerNN1 qkActLambdaBrauerNN2)) - let qkActOpBound0 : Float := - Float.sqrt (max 0.0 (min qkActGram1.infNormAbs qkActGram2.infNormAbs)) - let qkActCanFactor : Bool := - !(kCenteredOpBound ≤ 0.0 || Float.isNaN kCenteredOpBound || Float.isInf kCenteredOpBound) - let qkActBase0 : (String × Float) := ("gramInf", qkActOpBound0) - let qkActBase1 := chooseMinLabel qkActBase0 ("gramGershNN", qkActOpBoundGershNN) - let qkActBase2 := chooseMinLabel qkActBase1 ("gramBrauerNN", 
qkActOpBoundBrauerNN) - let qkActBase3 := chooseMinLabel qkActBase2 ("gramBrauer", qkActOpBoundBrauer) - let qkActBase4 := chooseMinLabel qkActBase3 ("denseBrauer", qkActOpBoundDense) - let qkActBase5 := chooseMinLabel qkActBase4 ("gramMoment", qkActOpBoundMoment) - let qkActBase6 := chooseMinLabel qkActBase5 ("gramPow4", qkActOpBoundPow4) - let qkActBest := - if qkActCanFactor then - chooseMinLabel qkActBase6 ("factor", bnds.wqOpGram * kCenteredOpBound) - else - qkActBase6 - let qkActOpBound : Float := qkActBest.2 - let qkActOpBoundSource : String := qkActBest.1 - let kqActTrace : Float := ConcreteMatrix.traceMul bnds.wkGram qActGramCentered - let kqActGram1 := qActGramCentered.matmul bnds.wkGram - let kqActGram2 := bnds.wkGram.matmul qActGramCentered - let kqActOpDense1 : Float := kqActGram1.opNormUpperBoundDenseBrauer - let kqActOpDense2 : Float := kqActGram2.opNormUpperBoundDenseBrauer - let kqActOpBoundDense : Float := - Float.sqrt (max 0.0 (min kqActOpDense1 kqActOpDense2)) - let kqActF2_1 : Float := ConcreteMatrix.traceMul kqActGram1 kqActGram1 - let kqActF2_2 : Float := ConcreteMatrix.traceMul kqActGram2 kqActGram2 - let kqActMoment1 : Float := - ConcreteMatrix.psdLambdaMaxUpperMoment kqActGram1.numRows kqActTrace (max 0.0 kqActF2_1) - let kqActMoment2 : Float := - ConcreteMatrix.psdLambdaMaxUpperMoment kqActGram2.numRows kqActTrace (max 0.0 kqActF2_2) - let kqActOpBoundMoment : Float := - Float.sqrt (max 0.0 (min kqActMoment1 kqActMoment2)) - let kqActGram1Sq := kqActGram1.matmul kqActGram1 - let kqActGram2Sq := kqActGram2.matmul kqActGram2 - let kqActTrace4_1 : Float := ConcreteMatrix.traceMul kqActGram1Sq kqActGram1Sq - let kqActTrace4_2 : Float := ConcreteMatrix.traceMul kqActGram2Sq kqActGram2Sq - let kqActOpBoundPow4 : Float := - Float.sqrt (Float.sqrt (max 0.0 (min kqActTrace4_1 kqActTrace4_2))) - let kqActLambdaBrauer1 : Float := ConcreteMatrix.lambdaMaxUpperBrauer kqActGram1 - let kqActLambdaBrauer2 : Float := ConcreteMatrix.lambdaMaxUpperBrauer 
kqActGram2 - let kqActOpBoundBrauer : Float := - Float.sqrt (max 0.0 (min kqActLambdaBrauer1 kqActLambdaBrauer2)) - let kqActLambdaGershNN1 : Float := ConcreteMatrix.lambdaMaxUpperGershNonneg kqActGram1 - let kqActLambdaGershNN2 : Float := ConcreteMatrix.lambdaMaxUpperGershNonneg kqActGram2 - let kqActOpBoundGershNN : Float := - Float.sqrt (max 0.0 (min kqActLambdaGershNN1 kqActLambdaGershNN2)) - let kqActLambdaBrauerNN1 : Float := ConcreteMatrix.lambdaMaxUpperBrauerNonneg kqActGram1 - let kqActLambdaBrauerNN2 : Float := ConcreteMatrix.lambdaMaxUpperBrauerNonneg kqActGram2 - let kqActOpBoundBrauerNN : Float := - Float.sqrt (max 0.0 (min kqActLambdaBrauerNN1 kqActLambdaBrauerNN2)) - let kqActOpBound0 : Float := - Float.sqrt (max 0.0 (min kqActGram1.infNormAbs kqActGram2.infNormAbs)) - let kqActCanFactor : Bool := - !(qCenteredOpBound ≤ 0.0 || Float.isNaN qCenteredOpBound || Float.isInf qCenteredOpBound) - let kqActBase0 : (String × Float) := ("gramInf", kqActOpBound0) - let kqActBase1 := chooseMinLabel kqActBase0 ("gramGershNN", kqActOpBoundGershNN) - let kqActBase2 := chooseMinLabel kqActBase1 ("gramBrauerNN", kqActOpBoundBrauerNN) - let kqActBase3 := chooseMinLabel kqActBase2 ("gramBrauer", kqActOpBoundBrauer) - let kqActBase4 := chooseMinLabel kqActBase3 ("denseBrauer", kqActOpBoundDense) - let kqActBase5 := chooseMinLabel kqActBase4 ("gramMoment", kqActOpBoundMoment) - let kqActBase6 := chooseMinLabel kqActBase5 ("gramPow4", kqActOpBoundPow4) - let kqActBest := - if kqActCanFactor then - chooseMinLabel kqActBase6 ("factor", bnds.wkOpGram * qCenteredOpBound) - else - kqActBase6 - let kqActOpBound : Float := kqActBest.2 - let kqActOpBoundSource : String := kqActBest.1 - let qkActFrobBound : Float := Float.sqrt (max 0.0 qkActTrace) - let kqActFrobBound : Float := Float.sqrt (max 0.0 kqActTrace) - - let diag? 
: Option HeadDiagnosticsLazy := - if storeDiagnostics then - let voProjT : Thunk ConcreteMatrix := Thunk.mk (fun _ => head.valueOutputProjection) - let qkAlignT : Thunk ConcreteMatrix := Thunk.mk (fun _ => head.queryKeyAlignment) - let voDenseSchurT : Thunk Float := Thunk.mk (fun _ => (voProjT.get).schurNormEst) - let qkDenseSchurT : Thunk Float := Thunk.mk (fun _ => (qkAlignT.get).schurNormEst) - some { voProj := voProjT, qkAlign := qkAlignT - voDenseSchur := voDenseSchurT, qkDenseSchur := qkDenseSchurT } - else - none - - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - qFrobBound := qFrobBound - kFrobBound := kFrobBound - vFrobBound := vFrobBound - qOpBoundAct := qOpBoundAct - kOpBoundAct := kOpBoundAct - vOpBoundAct := vOpBoundAct - qkActFrobBound := qkActFrobBound - kqActFrobBound := kqActFrobBound - qkActOpBound := qkActOpBound - kqActOpBound := kqActOpBound - scaleFactor := scaleFactor - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternParts? : Option PatternTermBoundParts := - if storeDiagnostics then - some (computePatternTermBoundParts inputs) - else - none - let patternBound : Float := - match patternParts? 
with - | some parts => parts.patternBound - | none => computePatternTermBound inputs - let valueNorm : Float := - Float.sqrt attnFrobNormSq * voNorm - let ratio : Float := - if valueNorm < 1e-10 then Float.inf else patternBound / valueNorm - - return { - layerIdx := l - headIdx := hIdx - attention := attn - prevTokenStrength := prevTokenStrength - softmaxJacobianOpEst := softmaxOpBound - softmaxRowMaxP := softmaxDiag.maxRowMaxP - softmaxRowTraceBound := softmaxDiag.maxRowTraceBound - softmaxRowMomentBound := softmaxDiag.maxRowMomentBound - softmaxRowGershBound := softmaxDiag.maxRowGersh - softmaxRowBoundUsed := softmaxDiag.maxRowBoundUsed - softmaxRowsFallback := softmaxDiag.numRowsFallback - attentionFrobeniusNormSq := attnFrobNormSq - attentionOneInfBound := attnOneInf - patternTermBoundCached := patternBound - valueTermNormCached := valueNorm - faithfulnessRatioCached := ratio - patternBoundParts? := patternParts? - diag? := diag? - wqGram := wqGram - wvGram := wvGram - inputNorm := inputNorm - inputOpBound := inputOpBound - qFrobBound := qFrobBound - kFrobBound := kFrobBound - vFrobBound := vFrobBound - qOpBoundAct := qOpBoundAct - kOpBoundAct := kOpBoundAct - vOpBoundAct := vOpBoundAct - qkActFrobBound := qkActFrobBound - kqActFrobBound := kqActFrobBound - qkActOpBound := qkActOpBound - kqActOpBound := kqActOpBound - qkActOpBoundSource := qkActOpBoundSource - kqActOpBoundSource := kqActOpBoundSource - ln1OpBound := ln1Bound - scaleFactor := scaleFactor - valueOutputProjNorm := voNorm - queryKeyAlignNorm := qkNorm - valueOutputProjSchurNorm := voOpBound - queryKeyAlignSchurNorm := qkOpBound - - qkDenseFrob := bnds.qkDenseFrob - qkDenseGram := bnds.qkDenseGram - qkDenseBrauer := bnds.qkDenseBrauer - qkFactorSchur := bnds.qkFactorSchur - qkFactorFrob := bnds.qkFactorFrob - - wqOpGram := bnds.wqOpGram - wkOpGram := bnds.wkOpGram - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - qkFactorGram := 
bnds.qkFactorGram - - voDenseFrob := bnds.voDenseFrob - voDenseGram := bnds.voDenseGram - voDenseBrauer := bnds.voDenseBrauer - voFactorSchur := bnds.voFactorSchur - voFactorFrob := bnds.voFactorFrob - wvOpGram := bnds.wvOpGram - woOpGram := bnds.woOpGram - voFactorGram := bnds.voFactorGram - } - - let headTaskCount : Nat := - if heads.size < 4 then - 1 - else if useParallelLayers then - let maxHeadTasks : Nat := 48 - let budget := maxHeadTasks / model.numLayers - let target := Nat.max 1 budget - Nat.min heads.size target - else - heads.size - let computeHeadChunk (start stop : Nat) : Array PrecomputedHeadData := Id.run do - let mut out : Array PrecomputedHeadData := Array.mkEmpty (stop - start) - for h_idx in [start:stop] do - if hh : h_idx < heads.size then - out := out.push (computeHead h_idx (heads[h_idx]'hh)) - return out - if headTaskCount > 1 then - Id.run do - let chunkSize := (heads.size + headTaskCount - 1) / headTaskCount - let chunkCount := (heads.size + chunkSize - 1) / chunkSize - let tasks : Array (Task (Array PrecomputedHeadData)) := - .ofFn fun i : Fin chunkCount => - let start := i.val * chunkSize - let stop := min heads.size (start + chunkSize) - Task.spawn (fun _ => computeHeadChunk start stop) - let mut out : Array PrecomputedHeadData := Array.mkEmpty heads.size - for chunk in tasks.map Task.get do - for item in chunk do - out := out.push item - return out - else - computeHeadChunk 0 heads.size - else - #[] - - return (layerHeadData, attnInput) - - -- Pure parallelism via tasks: layer cache construction is independent once the - -- forward pass has produced all layer inputs. 
- let useParallel := useParallelLayers - let layerResults : Array (Array PrecomputedHeadData × ConcreteMatrix) := - if useParallel then - let tasks : Array (Task (Array PrecomputedHeadData × ConcreteMatrix)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeLayer i.val - - let headData := layerResults.map (·.1) - let ln1Inputs := layerResults.map (·.2) - return (headData, ln1Inputs) - -/-- Build a complete precomputed cache for a model. - -This precomputes all attention patterns, projections, and norms once. --/ -def build (model : ConcreteModel) (causal : Bool := true) - (computeLayerNormBounds : Bool := true) - (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) - (storeDiagnostics : Bool := false) : - PrecomputedCache := Id.run do - let fwdResult := model.runForward causal - let (headData, ln1Inputs) := - buildHeadData model fwdResult causal layerNormEffort storeDiagnostics - let baseBounds := Array.replicate model.numLayers 0.0 - let baseCache : PrecomputedCache := { - model := model - forwardResult := fwdResult - ln1Inputs := ln1Inputs - headData := headData - layerNormBounds := baseBounds - layerNormBoundsComputed := false - } - if computeLayerNormBounds then - let layerNormBounds := PrecomputedCache.computeLayerNormBounds baseCache layerNormEffort - return { baseCache with - layerNormBounds := layerNormBounds - layerNormBoundsComputed := true } - else - return baseCache - -/-- Retrieve cached data for a specific head. -/ -def getHeadData (cache : PrecomputedCache) (layerIdx headIdx : Nat) : - Option PrecomputedHeadData := - if h1 : layerIdx < cache.headData.size then - let layerCache := cache.headData[layerIdx] - if h2 : headIdx < layerCache.size then - some layerCache[headIdx] - else none - else none - -/-- Retrieve cached Pre-LN attention input `ln_1(x_l)` for a specific layer. 
-/ -def getLn1Input (cache : PrecomputedCache) (layerIdx : Nat) : ConcreteMatrix := - if h : layerIdx < cache.ln1Inputs.size then - cache.ln1Inputs[layerIdx] - else - ConcreteMatrix.zeros 0 0 - -end PrecomputedCache - -/-! ## Krylov-style rigorous lower bounds for layer residual Jacobians -/ - -namespace ConcreteMatrix - -/-- A precomputed context for applying the Jacobian of row-wise LayerNorm to perturbations. -/ -structure LayerNormJacobianCtx where - numRows : Nat - numCols : Nat - gamma : ConcreteMatrix - invStds : Array Float - v : ConcreteMatrix - -/-- Build a `LayerNormJacobianCtx` for `layerNormRowwise X γ β`. - -We cache the per-row `invStd` and the centered+scaled `v = (x-μ)/σ` used in the Jacobian -formula. (`β` does not affect the Jacobian.) --/ -def mkLayerNormJacobianCtx (X γ : ConcreteMatrix) (eps : Float := 1e-5) : LayerNormJacobianCtx := - Id.run do - let rows := X.numRows - let cols := X.numCols - if rows = 0 || cols = 0 then - return { numRows := rows, numCols := cols, gamma := γ, invStds := #[], v := X } - if !(γ.numRows = 1 ∧ γ.numCols = cols) then - return { numRows := rows, numCols := cols, gamma := γ, invStds := #[], v := X } - - let mut means : Array Float := Array.replicate rows 0.0 - let mut invStds : Array Float := Array.replicate rows 0.0 - let colsF := cols.toFloat - for r in [:rows] do - let mut sum : Float := 0.0 - let rowBase := r * cols - for c in [:cols] do - sum := sum + X.data[rowBase + c]! - let μ := sum / colsF - let mut varSum : Float := 0.0 - for c in [:cols] do - let d := X.data[rowBase + c]! - μ - varSum := varSum + d * d - let varRaw := varSum / colsF - -- Clamp for numerical stability (avoid NaN from tiny negative float noise). - let var := - if Float.isNaN varRaw || Float.isInf varRaw then 0.0 - else max 0.0 varRaw - let invσ := 1.0 / Float.sqrt (var + eps) - means := means.set! r μ - invStds := invStds.set! 
r invσ - - let v : ConcreteMatrix := - { numRows := rows - numCols := cols - data := .ofFn fun idx : Fin (rows * cols) => - let r := idx.val / cols - let c := idx.val % cols - let μ := means[r]! - let invσ := invStds[r]! - (X.data[r * cols + c]! - μ) * invσ - size_eq := Array.size_ofFn } - - return { numRows := rows, numCols := cols, gamma := γ, invStds := invStds, v := v } - -/-- Apply the LayerNorm Jacobian `J` at the cached row statistics: `δy = δx · J` (row-wise). -/ -def LayerNormJacobianCtx.apply (ctx : LayerNormJacobianCtx) (dX : ConcreteMatrix) : ConcreteMatrix := - Id.run do - if dX.numRows ≠ ctx.numRows || dX.numCols ≠ ctx.numCols then - return ConcreteMatrix.zeros ctx.numRows ctx.numCols - let rows := ctx.numRows - let cols := ctx.numCols - let colsF := cols.toFloat - let gammaData := ctx.gamma.data - if rows = 0 || cols = 0 then - return ConcreteMatrix.zeros rows cols - if ctx.invStds.size ≠ rows then - return ConcreteMatrix.zeros rows cols - - -- Precompute per-row mean(dX) and v⋅dX. - let mut meanDx : Array Float := Array.replicate rows 0.0 - let mut vDotDx : Array Float := Array.replicate rows 0.0 - for r in [:rows] do - let rowBase := r * cols - let mut sumDx : Float := 0.0 - let mut sumVDx : Float := 0.0 - for c in [:cols] do - let dx := dX.data[rowBase + c]! - sumDx := sumDx + dx - sumVDx := sumVDx + (ctx.v.data[rowBase + c]! * dx) - meanDx := meanDx.set! r (sumDx / colsF) - vDotDx := vDotDx.set! r sumVDx - - { numRows := rows - numCols := cols - data := .ofFn fun idx : Fin (rows * cols) => - let r := idx.val / cols - let c := idx.val % cols - let rowBase := r * cols - let invσ := ctx.invStds[r]! - let g := gammaData[c]! - let dx := dX.data[rowBase + c]! - let mdx := meanDx[r]! - let vdx := vDotDx[r]! - let vrc := ctx.v.data[rowBase + c]! - g * invσ * (dx - mdx - (vrc * (vdx / colsF))) - size_eq := Array.size_ofFn } - -/-- Elementwise (Hadamard) product of two matrices with the same shape. 
-/ -def hadamard (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numRows = B.numRows ∧ A.numCols = B.numCols then - { numRows := A.numRows - numCols := A.numCols - data := .ofFn fun idx : Fin (A.numRows * A.numCols) => - A.data[idx.val]! * B.data[idx.val]! - size_eq := Array.size_ofFn } - else - ConcreteMatrix.zeros 0 0 - -end ConcreteMatrix - -namespace ConcreteAttentionWeights - -/-- Multiply an attention matrix `A` (as weights) by a matrix `V` (n×d), returning `A·V`. -/ -def mul (A : ConcreteAttentionWeights) (V : ConcreteMatrix) : ConcreteMatrix := - if A.seqLen = 0 then - ConcreteMatrix.zeros 0 0 - else if V.numRows ≠ A.seqLen then - ConcreteMatrix.zeros A.seqLen V.numCols - else - let n := A.seqLen - { numRows := n - numCols := V.numCols - data := .ofFn fun idx : Fin (n * V.numCols) => Id.run do - let q := idx.val / V.numCols - let d := idx.val % V.numCols - let mut acc : Float := 0.0 - let rowBase := q * n - for k in [:n] do - let a := A.weights[rowBase + k]! - let v := V.data[k * V.numCols + d]! - acc := acc + a * v - return acc - size_eq := Array.size_ofFn } - -end ConcreteAttentionWeights - -/-- Cached data for applying a single-head attention Jacobian to perturbations. 
-/ -structure HeadJacobianCtx where - head : ConcreteAttentionLayer - attn : ConcreteAttentionWeights - input : ConcreteMatrix - Q : ConcreteMatrix - K : ConcreteMatrix - V : ConcreteMatrix - KT : ConcreteMatrix - AV : ConcreteMatrix - invScale : Float - -namespace HeadJacobianCtx - -def build (head : ConcreteAttentionLayer) (input : ConcreteMatrix) (attn : ConcreteAttentionWeights) : - HeadJacobianCtx := - let Q := (input.matmul head.W_Q).addBias head.b_Q - let K := (input.matmul head.W_K).addBias head.b_K - let V := (input.matmul head.W_V).addBias head.b_V - let KT := K.transpose - let AV := attn.mul V - let invScale := 1.0 / Float.sqrt head.headDim.toFloat - { head := head, attn := attn, input := input, Q := Q, K := K, V := V, KT := KT, AV := AV, - invScale := invScale } - -/-- Apply the attention-head Jacobian at the cached `(input, attn)` to a perturbation `dInput`. -/ -def apply (ctx : HeadJacobianCtx) (dInput : ConcreteMatrix) : ConcreteMatrix := Id.run do - let n := ctx.attn.seqLen - if n = 0 then - return ConcreteMatrix.zeros 0 0 - if dInput.numRows ≠ n || dInput.numCols ≠ ctx.head.modelDim then - return ConcreteMatrix.zeros n ctx.head.modelDim - - let dQ := dInput.matmul ctx.head.W_Q - let dK := dInput.matmul ctx.head.W_K - let dV := dInput.matmul ctx.head.W_V - - -- Value term: A · dV - let dAV_value := ctx.attn.mul dV - - -- Pattern term: (dA) · V where dA = J_softmax(S) dS. - let dS1 := dQ.matmul ctx.KT - let dS2 := ctx.Q.matmul dK.transpose - let dS := (ConcreteMatrix.scale ctx.invScale (dS1.add dS2)) - - -- cRow[q] = ⟨p_q, dS_q⟩ - let mut cRow : Array Float := Array.replicate n 0.0 - for q in [:n] do - let rowBase := q * n - let mut c : Float := 0.0 - for k in [:n] do - let p := ctx.attn.weights[rowBase + k]! - let ds := dS.data[rowBase + k]! - c := c + p * ds - cRow := cRow.set! 
q c - - let dAV_pattern : ConcreteMatrix := - { numRows := n - numCols := ctx.head.headDim - data := .ofFn fun idx : Fin (n * ctx.head.headDim) => Id.run do - let q := idx.val / ctx.head.headDim - let d := idx.val % ctx.head.headDim - let rowBase := q * n - let c := cRow[q]! - let mut acc : Float := 0.0 - for k in [:n] do - let p := ctx.attn.weights[rowBase + k]! - let ds := dS.data[rowBase + k]! - let alpha := p * (ds - c) - let v := ctx.V.data[k * ctx.head.headDim + d]! - acc := acc + alpha * v - return acc - size_eq := Array.size_ofFn } - - let dAV := dAV_value.add dAV_pattern - -- Project back: (dAV) · W_O - return dAV.matmul ctx.head.W_O - -end HeadJacobianCtx - -/-- Cached data for applying an MLP Jacobian to perturbations. -/ -structure MLPJacobianCtx where - layer : ConcreteMLPLayer - input : ConcreteMatrix - geluDeriv : ConcreteMatrix - -namespace MLPJacobianCtx - -def build (layer : ConcreteMLPLayer) (input : ConcreteMatrix) : MLPJacobianCtx := - let hidden := (input.matmul layer.W_in).addBias layer.b_in - let geluDeriv := hidden.map geluDerivFloat - { layer := layer, input := input, geluDeriv := geluDeriv } - -def apply (ctx : MLPJacobianCtx) (dInput : ConcreteMatrix) : ConcreteMatrix := - let dHidden := dInput.matmul ctx.layer.W_in - let dHiddenAct := ConcreteMatrix.hadamard dHidden ctx.geluDeriv - dHiddenAct.matmul ctx.layer.W_out - -end MLPJacobianCtx - -/-- Cached context for applying the residual Jacobian of one transformer block (excluding identity). -/ -structure LayerResidualJacobianCtx where - ln1 : ConcreteMatrix.LayerNormJacobianCtx - ln2 : ConcreteMatrix.LayerNormJacobianCtx - heads : Array HeadJacobianCtx - mlp? 
: Option MLPJacobianCtx - -namespace LayerResidualJacobianCtx - -def build (cache : PrecomputedCache) (layerIdx : Nat) : LayerResidualJacobianCtx := Id.run do - let model := cache.model - let fwd := cache.forwardResult - let x := fwd.getLayerInput layerIdx - let y := fwd.getPostAttnResidual layerIdx - - let ln1p := model.ln1.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - let ln2p := model.ln2.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - - let ln1Ctx := ConcreteMatrix.mkLayerNormJacobianCtx x ln1p.gamma - let ln2Ctx := ConcreteMatrix.mkLayerNormJacobianCtx y ln2p.gamma - - let attnInput := model.applyLn1 layerIdx x - let layerHeads := model.layers.getD layerIdx #[] - let layerHeadData := cache.headData.getD layerIdx #[] - let mut headCtxs : Array HeadJacobianCtx := Array.mkEmpty layerHeads.size - for h in [:layerHeads.size] do - if hh : h < layerHeads.size then - let head := layerHeads[h]'hh - let attn := - if hd : h < layerHeadData.size then - (layerHeadData[h]'hd).attention - else - head.computeAttentionWeights attnInput true - headCtxs := headCtxs.push (HeadJacobianCtx.build head attnInput attn) - - let mlp? : Option MLPJacobianCtx := - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let mlpInput := model.applyLn2 layerIdx y - some (MLPJacobianCtx.build mlp mlpInput) - else - none - - return { ln1 := ln1Ctx, ln2 := ln2Ctx, heads := headCtxs, mlp? := mlp? } - - /-- Build a `LayerResidualJacobianCtx` restricted to the first `prefixLen` token positions. - - For causal attention, the outputs at positions `< prefixLen` depend only on inputs at - positions `< prefixLen`. Therefore, computing a Krylov lower bound on the restricted - operator (and measuring norms only on this prefix) yields a valid lower bound for the - full-sequence Jacobian operator norm. 
- -/ - def buildPrefix (cache : PrecomputedCache) (layerIdx : Nat) (prefixLen : Nat) : LayerResidualJacobianCtx := Id.run do - let model := cache.model - let fwd := cache.forwardResult - let xFull := fwd.getLayerInput layerIdx - let yFull := fwd.getPostAttnResidual layerIdx - let p := min prefixLen xFull.numRows - let x := xFull.takeRows p - let y := yFull.takeRows p - - let ln1p := model.ln1.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - let ln2p := model.ln2.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - let ln1Ctx := ConcreteMatrix.mkLayerNormJacobianCtx x ln1p.gamma - let ln2Ctx := ConcreteMatrix.mkLayerNormJacobianCtx y ln2p.gamma - - let attnInput := model.applyLn1 layerIdx x - let layerHeads := model.layers.getD layerIdx #[] - let mut headCtxs : Array HeadJacobianCtx := Array.mkEmpty layerHeads.size - for h in [:layerHeads.size] do - if hh : h < layerHeads.size then - let head := layerHeads[h]'hh - let attn := head.computeAttentionWeights attnInput true - headCtxs := headCtxs.push (HeadJacobianCtx.build head attnInput attn) - - let mlp? : Option MLPJacobianCtx := - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let mlpInput := model.applyLn2 layerIdx y - some (MLPJacobianCtx.build mlp mlpInput) - else - none - - return { ln1 := ln1Ctx, ln2 := ln2Ctx, heads := headCtxs, mlp? := mlp? } - - /-- Apply the residual Jacobian `J_resid` (excluding identity): `δres = δx · J_resid`. - - If `parallelHeads=true`, per-head Jacobian-vector products are computed in parallel and then - summed in **head-index order** to preserve deterministic Float semantics. 
- -/ - def apply (ctx : LayerResidualJacobianCtx) (dX : ConcreteMatrix) - (parallelHeads : Bool := false) : ConcreteMatrix := Id.run do - let dU := ctx.ln1.apply dX - let dAttnSum : ConcreteMatrix := - if parallelHeads && ctx.heads.size >= 4 then - let tasks : Array (Task ConcreteMatrix) := - .ofFn fun i : Fin ctx.heads.size => - Task.spawn (fun _ => (ctx.heads[i]).apply dU) - let outs := tasks.map Task.get - ConcreteMatrix.sumMatrices outs (allowParallel := parallelHeads) - else - Id.run do - let mut outs : Array ConcreteMatrix := Array.mkEmpty ctx.heads.size - for hctx in ctx.heads do - outs := outs.push (hctx.apply dU) - return ConcreteMatrix.sumMatrices outs (allowParallel := false) - let dY := dX.add dAttnSum - let dV := ctx.ln2.apply dY - let dMlp := - match ctx.mlp? with - | some m => m.apply dV - | none => ConcreteMatrix.zeros dX.numRows dX.numCols - return dAttnSum.add dMlp - -end LayerResidualJacobianCtx - -private def deterministicInitVec (rows cols layerIdx : Nat) : ConcreteMatrix := - if rows = 0 || cols = 0 then - ConcreteMatrix.zeros rows cols - else - let n := rows * cols - let idx := ((layerIdx + 1) * 2654435761) % n - { numRows := rows - numCols := cols - data := .ofFn fun i : Fin n => if i.val = idx then 1.0 else 0.0 - size_eq := Array.size_ofFn } - -private def normalizeFrob (X : ConcreteMatrix) : ConcreteMatrix := - let n := X.frobeniusNorm - if n ≤ 0.0 || Float.isNaN n || Float.isInf n then - X - else - ConcreteMatrix.scale (1.0 / n) X - -private def maxAbsIndex (X : ConcreteMatrix) : Nat := Id.run do - let mut bestIdx : Nat := 0 - let mut best : Float := 0.0 - for i in [:X.data.size] do - let a := Float.abs (X.data[i]!) 
- if a > best then - best := a - bestIdx := i - bestIdx - -private def basisAt (rows cols idx : Nat) : ConcreteMatrix := - if rows = 0 || cols = 0 then - ConcreteMatrix.zeros rows cols - else - let n := rows * cols - if idx < n then - { numRows := rows - numCols := cols - data := .ofFn fun i : Fin n => if i.val = idx then 1.0 else 0.0 - size_eq := Array.size_ofFn } - else - ConcreteMatrix.zeros rows cols - -private def signLike (X : ConcreteMatrix) : ConcreteMatrix := - { numRows := X.numRows - numCols := X.numCols - data := .ofFn fun i : Fin (X.numRows * X.numCols) => - let x := X.data[i.val]! - if x > 0.0 then 1.0 else if x < 0.0 then (-1.0) else 0.0 - size_eq := Array.size_ofFn } - -/-- Improve the lower bound by trying a small set of deterministic initial vectors. - -This is still rigorous: for any nonzero `v`, `‖J v‖/‖v‖ ≤ ‖J‖`. We simply take the max -over a few starts to avoid a "bad" basis vector giving a vacuous bound. --/ -private def krylovLowerBoundFromInit (ctx : LayerResidualJacobianCtx) (v0 : ConcreteMatrix) - (k : Nat) (parallelHeads : Bool) : Float := Id.run do - let mut v := v0 - let mut best : Float := 0.0 - for _ in [:k] do - let vNorm := v.frobeniusNorm - if vNorm ≤ 0.0 || Float.isNaN vNorm || Float.isInf vNorm then - break - let w := ctx.apply v (parallelHeads := parallelHeads) - let wNorm := w.frobeniusNorm - let r : Float := wNorm / vNorm - if r > best then - best := r - if wNorm ≤ 0.0 || Float.isNaN wNorm || Float.isInf wNorm then - break - v := normalizeFrob w - return best - -/-- Rigorous (in exact real arithmetic) lower bound on `‖J_resid‖₂` from `k` Krylov steps. 
-/ -def layerResidualJacobianLowerBound (cache : PrecomputedCache) (layerIdx : Nat) (k : Nat := 4) - (parallelHeads : Bool := false) : - Float := Id.run do - let x := cache.forwardResult.getLayerInput layerIdx - let rows := x.numRows - let cols := x.numCols - if rows = 0 || cols = 0 then - return 0.0 - - -- PERFORMANCE: the full Jacobian-apply cost scales like `O(seqLen^2)` per head. - -- For causal attention we can restrict to a small prefix and still obtain a rigorous - -- lower bound by projecting the output norm to that prefix. - -- Keep this small: the Jacobian-apply cost is roughly quadratic in `prefixLen` due to softmax terms. - let prefixLen : Nat := min rows 16 - let xP := x.takeRows prefixLen - let ctx := - if prefixLen = rows then - LayerResidualJacobianCtx.build cache layerIdx - else - LayerResidualJacobianCtx.buildPrefix cache layerIdx prefixLen - let vHash := deterministicInitVec prefixLen cols layerIdx - let vData := normalizeFrob xP - let vMax := basisAt prefixLen cols (maxAbsIndex xP) - let vSign := normalizeFrob (signLike xP) - let useParallelInits : Bool := parallelHeads && (prefixLen * cols ≥ 50000) && k > 0 - if useParallelInits then - let t1 := Task.spawn (fun _ => krylovLowerBoundFromInit ctx vHash k parallelHeads) - let t2 := Task.spawn (fun _ => krylovLowerBoundFromInit ctx vData k parallelHeads) - let t3 := Task.spawn (fun _ => krylovLowerBoundFromInit ctx vMax k parallelHeads) - let t4 := Task.spawn (fun _ => krylovLowerBoundFromInit ctx vSign k parallelHeads) - let b1 := t1.get - let b2 := t2.get - let b3 := t3.get - let b4 := t4.get - return max (max b1 b2) (max b3 b4) - else - let b1 := krylovLowerBoundFromInit ctx vHash k parallelHeads - let b2 := krylovLowerBoundFromInit ctx vData k parallelHeads - let b3 := krylovLowerBoundFromInit ctx vMax k parallelHeads - let b4 := krylovLowerBoundFromInit ctx vSign k parallelHeads - return max (max b1 b2) (max b3 b4) - -/-! 
## Adaptive bound scheduler (rigorous upper bounds, deterministic) -/ - -inductive AdaptiveScope where - | layernorm - | all - deriving Repr, BEq, Inhabited - -structure AdaptiveSchedulerConfig where - targetSlack : Float := 8.0 - maxUpgrades : Nat := 200 - minRelImprove : Float := 0.01 - krylovSteps : Nat := 4 - scope : AdaptiveScope := .layernorm - debugMonotone : Bool := false - deriving Repr - -inductive AdaptiveUpgradeKind where - | ubTier - | lbSteps - deriving Repr, BEq, Inhabited - -structure AdaptiveSchedulerStep where - iter : Nat - layerIdx : Nat - kind : AdaptiveUpgradeKind := .ubTier - tierFrom : Nat - tierTo : Nat - kFrom : Nat := 0 - kTo : Nat := 0 - ubBefore : Float - ubAfter : Float - lb : Float - slackBefore : Float - slackAfter : Float - deriving Repr - -structure AdaptiveSchedulerResult where - /-- Final per-layer rigorous upper bounds. -/ - ub : Array Float - /-- Per-layer rigorous lower bounds from Krylov steps. -/ - lb : Array Float - /-- Krylov steps used for each layer lower bound. -/ - lbK : Array Nat - /-- Final effort tier index per layer. -/ - tier : Array Nat - /-- Upgrade log (one entry per accepted upgrade). -/ - steps : Array AdaptiveSchedulerStep - deriving Repr - -private structure AdaptiveUbCaches where - winUb : Array (Array (Option Float)) - woutUb : Array (Array (Option Float)) - /-- Cached operator-norm upper bounds `‖ln₁(X_l)‖₂` per layer and effort tier. -/ - ln1InputOpUb : Array (Array (Option Float)) - /-- Cached attention-only residual Jacobian upper bounds per layer and effort tier. -/ - attnPartUb : Array (Array (Option Float)) - ubAt : Array (Array (Option Float)) - mlpCore : Array (Option MLPJacobianBoundCore) - deriving Inhabited - -private def safeSlack (ub lb : Float) : Float := - if lb ≤ 0.0 || Float.isNaN lb || Float.isInf lb then - Float.inf - else - ub / lb - -/-- Run the adaptive scheduler for per-layer norm bounds. - -Correctness: every `ub[l]` returned is a minimum of **rigorous** upper bound candidates. 
-The adaptive logic only changes which candidates are computed; it never relaxes the final bound. --/ -def runAdaptiveScheduler (cache : PrecomputedCache) (cfg : AdaptiveSchedulerConfig) - (activeLayers? : Option (Array Bool) := none) : - AdaptiveSchedulerResult := Id.run do - let model := cache.model - let tiers := ConcreteMatrix.BoundEffort.tiers - let startTier : Nat := 1 -- current default behavior corresponds to tier1 - let lbStep : Nat := if cfg.krylovSteps = 0 then 0 else max 1 cfg.krylovSteps - let maxKrylov : Nat := - if cfg.krylovSteps = 0 then 0 - else max cfg.krylovSteps (min 32 (max 8 (cfg.krylovSteps * 4))) - let active : Array Bool := - match activeLayers? with - | some layers => - if layers.size = model.numLayers then layers - else Array.replicate model.numLayers true - | none => Array.replicate model.numLayers true - - -- Precompute tier-1 attention part and per-layer ln₂ Jacobian bound. - -- - -- We cache tier-1 attention bounds directly from `cache.headData` to preserve the - -- current default behavior and avoid recomputing `‖ln₁(X_l)‖₂` up front. - let nLayers := model.numLayers - let useParallelInit := nLayers >= 4 - let computeLayerInit (l : Nat) : (Float × Float × Float) := Id.run do - if !active[l]! 
then - return (0.0, 0.0, 0.0) - let layerData := cache.headData.getD l #[] - let mut a : Float := 0.0 - let inputOpBoundTier1 : Float := - if h : 0 < layerData.size then - (layerData[0]'h).inputOpBound - else - let emptyMat : ConcreteMatrix := { numRows := 0, numCols := 0, data := #[], size_eq := by simp } - let x := cache.ln1Inputs.getD l emptyMat - x.opNormUpperBoundRectGramEffort (tiers.getD startTier ConcreteMatrix.BoundEffort.tier1) - for d in layerData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let inputs : PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := inputOpBoundTier1 - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := computePatternTermBound inputs - a := a + d.ln1OpBound * (valueTermUb + patternTermUb) - - let y := cache.forwardResult.getPostAttnResidual l - let ln2Bound := model.ln2OpBound l y - return (a, ln2Bound, inputOpBoundTier1) - - let init : Array (Float × Float × Float) := - if useParallelInit then - let tasks : Array (Task (Float × Float × Float)) := - .ofFn fun i : Fin nLayers => Task.spawn (fun _ => computeLayerInit i.val) - tasks.map Task.get - else - .ofFn fun i : Fin nLayers => computeLayerInit i.val - - let attnPartTier1 : Array Float := init.map (·.1) - let ln2Bound : Array Float := 
init.map (·.2.1) - let ln1InputOpTier1 : Array Float := init.map (·.2.2) - - let mut lbK : Array Nat := Array.replicate model.numLayers 0 - for l in [:model.numLayers] do - if active[l]! then - lbK := lbK.set! l cfg.krylovSteps - let mut lb : Array Float := Array.replicate model.numLayers 0.0 - let useParallelLb := model.numLayers >= 4 && cfg.krylovSteps > 0 - if useParallelLb then - let mut tasks : Array (Option (Task Float)) := Array.replicate model.numLayers none - for l in [:model.numLayers] do - if active[l]! then - tasks := tasks.set! l (some (Task.spawn (fun _ => - layerResidualJacobianLowerBound cache l cfg.krylovSteps (parallelHeads := true)))) - let mut out : Array Float := Array.replicate model.numLayers 0.0 - for l in [:model.numLayers] do - match tasks[l]! with - | some t => out := out.set! l t.get - | none => pure () - lb := out - else - for l in [:model.numLayers] do - if active[l]! then - lb := lb.set! l (layerResidualJacobianLowerBound cache l cfg.krylovSteps (parallelHeads := true)) - - let mlpUbFromCore (core : MLPJacobianBoundCore) (winUb woutUb : Float) : Float := - let legacy := winUb * core.globalDmax * woutUb - let scaledViaWin := core.winScaledUb * woutUb - let scaledViaWout := winUb * core.woutScaledUb - let scaled0 := - if scaledViaWin ≤ 0.0 || Float.isNaN scaledViaWin || Float.isInf scaledViaWin then - scaledViaWout - else if scaledViaWout ≤ 0.0 || Float.isNaN scaledViaWout || Float.isInf scaledViaWout then - scaledViaWin - else - min scaledViaWin scaledViaWout - let scaled := - if scaled0 ≤ 0.0 || Float.isNaN scaled0 || Float.isInf scaled0 then legacy else scaled0 - let absSchur := core.absSchur - if absSchur ≤ 0.0 || Float.isNaN absSchur || Float.isInf absSchur then - min legacy scaled - else - min absSchur (min legacy scaled) - - let initRow : Array (Option Float) := Array.replicate tiers.size none - let initCaches : AdaptiveUbCaches := - { winUb := Array.replicate model.numLayers (Array.replicate tiers.size none) - woutUb := 
Array.replicate model.numLayers (Array.replicate tiers.size none) - ln1InputOpUb := Array.replicate model.numLayers initRow - attnPartUb := Array.replicate model.numLayers initRow - ubAt := Array.replicate model.numLayers (Array.replicate tiers.size none) - mlpCore := Array.replicate model.numLayers none } - let mut caches : AdaptiveUbCaches := initCaches - - -- Seed tier-1 caches from precomputation (fast path). - for l in [:model.numLayers] do - if active[l]! && startTier < tiers.size then - let rowIn := (caches.ln1InputOpUb[l]!).set! startTier (some (ln1InputOpTier1.getD l 0.0)) - let rowA := (caches.attnPartUb[l]!).set! startTier (some (attnPartTier1.getD l 0.0)) - caches := - { caches with - ln1InputOpUb := caches.ln1InputOpUb.set! l rowIn - attnPartUb := caches.attnPartUb.set! l rowA } - - let getWinWoutM (layerIdx tierIdx : Nat) : StateM AdaptiveUbCaches (Float × Float) := do - let st ← get - if !(layerIdx < model.numLayers ∧ tierIdx < tiers.size) then - return (0.0, 0.0) - if hm : layerIdx < model.mlps.size then - let win? := st.winUb[layerIdx]![tierIdx]! - let wout? := st.woutUb[layerIdx]![tierIdx]! - match win?, wout? with - | some win, some wout => return (win, wout) - | _, _ => - let eff := tiers[tierIdx]! - let mlp := model.mlps[layerIdx]'hm - let big := mlp.W_in.numRows * mlp.W_in.numCols + mlp.W_out.numRows * mlp.W_out.numCols - let useParallel := big >= 200000 && tierIdx > 0 - let (win, wout) := - if useParallel then - let t1 := Task.spawn (fun _ => mlp.W_in.opNormUpperBoundRectGramEffort eff) - let t2 := Task.spawn (fun _ => mlp.W_out.opNormUpperBoundRectGramEffort eff) - (t1.get, t2.get) - else - (mlp.W_in.opNormUpperBoundRectGramEffort eff, mlp.W_out.opNormUpperBoundRectGramEffort eff) - let rowWin := (st.winUb[layerIdx]!).set! tierIdx (some win) - let rowWout := (st.woutUb[layerIdx]!).set! tierIdx (some wout) - set { st with winUb := st.winUb.set! layerIdx rowWin, woutUb := st.woutUb.set! 
layerIdx rowWout } - return (win, wout) - else - return (0.0, 0.0) - - let getMlpCoreM (layerIdx : Nat) : StateM AdaptiveUbCaches (Option MLPJacobianBoundCore) := do - let st ← get - if !(layerIdx < model.numLayers) then - return none - match st.mlpCore[layerIdx]! with - | some core => return some core - | none => - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let y := cache.forwardResult.getPostAttnResidual layerIdx - let geluDeriv := cache.forwardResult.getMlpGeluDeriv layerIdx - let core? := - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - mlp.precomputeJacobianBoundCore geluDeriv - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - mlp.precomputeJacobianBoundCore dAct - set { st with mlpCore := st.mlpCore.set! layerIdx core? } - return core? - else - return none - - let rec getLn1InputOpUbM (layerIdx tierIdx : Nat) : StateM AdaptiveUbCaches Float := do - let st ← get - if !(layerIdx < model.numLayers ∧ tierIdx < tiers.size) then - return 0.0 - match st.ln1InputOpUb[layerIdx]![tierIdx]! with - | some v => return v - | none => - -- Enforce monotonicity across tiers by always including the previous-tier bound. - let prev : Float ← - if tierIdx = 0 then - pure Float.inf - else - getLn1InputOpUbM layerIdx (tierIdx - 1) - let emptyMat : ConcreteMatrix := { numRows := 0, numCols := 0, data := #[], size_eq := by simp } - let x := cache.ln1Inputs.getD layerIdx emptyMat - let eff := tiers[tierIdx]! - let raw := x.opNormUpperBoundRectGramEffort eff - let v := min prev raw - let row := (st.ln1InputOpUb[layerIdx]!).set! tierIdx (some v) - set { st with ln1InputOpUb := st.ln1InputOpUb.set! 
layerIdx row } - return v - - let rec getAttnPartUbM (layerIdx tierIdx : Nat) : StateM AdaptiveUbCaches Float := do - let st ← get - if !(layerIdx < model.numLayers ∧ tierIdx < tiers.size) then - return 0.0 - match st.attnPartUb[layerIdx]![tierIdx]! with - | some v => return v - | none => - let prev : Float ← - if tierIdx = 0 then - pure Float.inf - else - getAttnPartUbM layerIdx (tierIdx - 1) - let inputOpBound ← getLn1InputOpUbM layerIdx tierIdx - let layerData := cache.headData.getD layerIdx #[] - let mut a : Float := 0.0 - for d in layerData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let inputs : PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := inputOpBound - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := computePatternTermBound inputs - a := a + d.ln1OpBound * (valueTermUb + patternTermUb) - let v := min prev a - let row := (st.attnPartUb[layerIdx]!).set! tierIdx (some v) - set { st with attnPartUb := st.attnPartUb.set! layerIdx row } - return v - - let computeUbAtM (layerIdx tierIdx : Nat) : StateM AdaptiveUbCaches Float := do - let st ← get - if !(layerIdx < model.numLayers ∧ tierIdx < tiers.size) then - return 0.0 - match st.ubAt[layerIdx]![tierIdx]! 
with - | some v => return v - | none => - let base ← getAttnPartUbM layerIdx tierIdx - let v : Float ← - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let (win, wout) ← getWinWoutM layerIdx tierIdx - let l2 := ln2Bound.getD layerIdx 0.0 - let mlpUb : Float ← - if tierIdx ≤ startTier then - -- Tier0/tier1 are only evaluated once per layer; avoid extra passes just to build core. - let y := cache.forwardResult.getPostAttnResidual layerIdx - let geluDeriv := cache.forwardResult.getMlpGeluDeriv layerIdx - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - pure (computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv win wout) - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - pure (computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct win wout) - else - match (← getMlpCoreM layerIdx) with - | some core => - pure (mlpUbFromCore core win wout) - | none => - let y := cache.forwardResult.getPostAttnResidual layerIdx - let geluDeriv := cache.forwardResult.getMlpGeluDeriv layerIdx - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - pure (computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv win wout) - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - pure (computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct win wout) - let mlpPart := l2 * mlpUb - pure (base + (1.0 + base) * mlpPart) - else - pure base - let row := (st.ubAt[layerIdx]!).set! tierIdx (some v) - set { st with ubAt := st.ubAt.set! layerIdx row } - return v - - -- Initialize to current default (tier1) to ensure adaptive never worsens bounds. 
- let mut tier : Array Nat := Array.replicate model.numLayers startTier - let mut ub : Array Float := Array.mkEmpty model.numLayers - for l in [:model.numLayers] do - if active[l]! then - let (v, st) := (computeUbAtM l startTier).run caches - caches := st - let baseUb : Float := - if cache.layerNormBoundsComputed then - min (cache.layerNormBounds.getD l v) v - else - v - ub := ub.push baseUb - else - ub := ub.push 0.0 - - let mut steps : Array AdaptiveSchedulerStep := #[] - let mut stalledUb : Array Bool := Array.ofFn fun i : Fin model.numLayers => !active[i.val]! - let mut stalledLb : Array Bool := Array.ofFn fun i : Fin model.numLayers => !active[i.val]! - let mut upgradesUsed : Nat := 0 - let epsNoise : Float := 1e-6 - - while upgradesUsed < cfg.maxUpgrades do - -- Find worst slack among layers that can still be upgraded (ub tier or lb steps). - let mut bestLayer : Nat := 0 - let mut bestSlack : Float := 0.0 - let mut bestUb : Float := 0.0 - let mut found : Bool := false - for l in [:model.numLayers] do - if !active[l]! then - continue - let t := tier[l]! - let canUb := (!stalledUb[l]!) && (t + 1 < tiers.size) - let canLb := (!stalledLb[l]!) && (lbStep > 0) && (lbK[l]! < maxKrylov) - if (!canUb) && (!canLb) then - continue - let s := safeSlack (ub[l]!) (lb[l]!) - if (!found) || s > bestSlack || (s == bestSlack && ub[l]! > bestUb) then - found := true - bestLayer := l - bestSlack := s - bestUb := ub[l]! - - if !found then - break - if bestSlack ≤ cfg.targetSlack then - break - - let l := bestLayer - let oldUb := ub[l]! - let oldLb := lb[l]! - let oldSlack := safeSlack oldUb oldLb - let t := tier[l]! - let canUb := (!stalledUb[l]!) && (t + 1 < tiers.size) - let canLb := (!stalledLb[l]!) && (lbStep > 0) && (lbK[l]! < maxKrylov) - - -- Heuristic choice when both upgrades are possible: - -- if slack is dominated by a tiny lower bound, prioritize strengthening `lb` first. 
- if canLb && canUb && (oldLb ≤ 0.0 || oldSlack > cfg.targetSlack * 4.0) then - let fromK := lbK[l]! - let toK := min (fromK + lbStep) maxKrylov - let newLb := layerResidualJacobianLowerBound cache l toK (parallelHeads := true) - let relImprove : Float := - if oldLb ≤ 0.0 then (if newLb > 0.0 then 1.0 else 0.0) else (newLb - oldLb) / oldLb - if relImprove < cfg.minRelImprove then - stalledLb := stalledLb.set! l true - else - lbK := lbK.set! l toK - lb := lb.set! l newLb - steps := steps.push - { iter := upgradesUsed - layerIdx := l - kind := .lbSteps - tierFrom := tier[l]! - tierTo := tier[l]! - kFrom := fromK - kTo := toK - ubBefore := oldUb - ubAfter := oldUb - lb := lb[l]! - slackBefore := oldSlack - slackAfter := safeSlack oldUb (lb[l]!) } - upgradesUsed := upgradesUsed + 1 - continue - - if canUb then - let fromTier := t - let toTier := fromTier + 1 - let (newUb, st) := (computeUbAtM l toTier).run caches - caches := st - if cfg.debugMonotone && newUb > oldUb + epsNoise then - panic! s!"Adaptive monotonicity violated at layer {l}: old={oldUb} new={newUb}" - let relImprove : Float := - if oldUb ≤ 0.0 then 0.0 else (oldUb - newUb) / oldUb - if relImprove < cfg.minRelImprove then - stalledUb := stalledUb.set! l true - else - tier := tier.set! l toTier - ub := ub.set! l (min oldUb newUb) - steps := steps.push - { iter := upgradesUsed - layerIdx := l - kind := .ubTier - tierFrom := fromTier - tierTo := toTier - kFrom := lbK[l]! - kTo := lbK[l]! - ubBefore := oldUb - ubAfter := ub[l]! - lb := lb[l]! - slackBefore := oldSlack - slackAfter := safeSlack (ub[l]!) (lb[l]!) } - upgradesUsed := upgradesUsed + 1 - continue - - if canLb then - let fromK := lbK[l]! - let toK := min (fromK + lbStep) maxKrylov - let newLb := layerResidualJacobianLowerBound cache l toK (parallelHeads := true) - let relImprove : Float := - if oldLb ≤ 0.0 then (if newLb > 0.0 then 1.0 else 0.0) else (newLb - oldLb) / oldLb - if relImprove < cfg.minRelImprove then - stalledLb := stalledLb.set! 
l true - else - lbK := lbK.set! l toK - lb := lb.set! l newLb - steps := steps.push - { iter := upgradesUsed - layerIdx := l - kind := .lbSteps - tierFrom := tier[l]! - tierTo := tier[l]! - kFrom := fromK - kTo := toK - ubBefore := oldUb - ubAfter := oldUb - lb := lb[l]! - slackBefore := oldSlack - slackAfter := safeSlack (ub[l]!) (lb[l]!) } - upgradesUsed := upgradesUsed + 1 - continue - - stalledUb := stalledUb.set! l true - stalledLb := stalledLb.set! l true - upgradesUsed := upgradesUsed + 1 - - return { ub := ub, lb := lb, lbK := lbK, tier := tier, steps := steps } - - -/-! ## Head Composition Metrics -/ - -/-- Compute the **K-composition** score between two attention heads. - -This follows the definition in *"A Mathematical Framework for Transformer Circuits"* -(see the "composition diagram caption"): - -`kComp(h₁→h₂) = ‖W_QK² · W_OV¹‖_F / (‖W_QK²‖_F · ‖W_OV¹‖_F)`. - -In this codebase's row-vector convention, we store `W_QK = W_Q · W_Kᵀ` and -`W_OV,row = W_V · W_O`. The paper's `W_OV` corresponds to `(W_OV,row)ᵀ`, but -`‖W_OV‖_F = ‖W_OV,row‖_F` and we compute the numerator via low-rank factors without -materializing any `modelDim×modelDim` products. - -By default (as in the paper), we subtract the expected composition of random matrices -of shape `modelDim×modelDim`, which is approximately `1/√modelDim`. - -We return 0.0 on dimension mismatch or missing heads. 
--/ -def computeKCompositionScore - (model : ConcreteModel) - (data1 data2 : PrecomputedHeadData) : Float := - if h1 : data1.layerIdx < model.layers.size then - let heads1 := model.layers[data1.layerIdx]'h1 - if h2 : data2.layerIdx < model.layers.size then - let heads2 := model.layers[data2.layerIdx]'h2 - if hh1 : data1.headIdx < heads1.size then - let head1 := heads1[data1.headIdx]'hh1 - if hh2 : data2.headIdx < heads2.size then - let head2 := heads2[data2.headIdx]'hh2 - let denom := data2.queryKeyAlignNorm * data1.valueOutputProjNorm - if denom < 1e-10 then 0.0 - else - -- K-composition numerator: - -- ‖W_QK² · W_OV¹‖_F where W_QK² = W_Q²W_K²ᵀ and W_OV¹ = (W_V¹W_O¹)ᵀ. - -- Using low-rank factorization: - -- W_Q²W_K²ᵀ (W_V¹W_O¹)ᵀ = W_Q² (W_K²ᵀ W_O¹ᵀ) W_V¹ᵀ - -- and M := W_O¹ W_K² so W_K²ᵀ W_O¹ᵀ = Mᵀ. - let M := head1.W_O.matmul head2.W_K -- (d_head × d_head) - let T := data1.wvGram.matmul M -- (d_head × d_head) - let S := M.transpose.matmul T -- Mᵀ · (W_V¹ᵀW_V¹) · M - let numeratorSq := ConcreteMatrix.traceMul S data2.wqGram - let numerator := Float.sqrt (max numeratorSq 0.0) - let raw := numerator / denom - let baseline : Float := - if head1.modelDim = 0 then 0.0 - else 1.0 / Float.sqrt head1.modelDim.toFloat - raw - baseline - else - 0.0 - else - 0.0 - else - 0.0 - else - 0.0 - -/-- Core induction candidate data used for fast thresholding. -/ -structure InductionCandidateCore where - layer1Idx : Nat - layer2Idx : Nat - head1Idx : Nat - head2Idx : Nat - patternBound1 : Float - patternBound2 : Float - combinedError : Float - prevTokenStrength : Float - description : String - -namespace InductionCandidateCore - -/-- Finalize an induction candidate by computing expensive scores. -/ -def toInductionCandidate? 
(core : InductionCandidateCore) (cache : PrecomputedCache) : - Option CandidateInductionHead := - match cache.getHeadData core.layer1Idx core.head1Idx, - cache.getHeadData core.layer2Idx core.head2Idx with - | some d1, some d2 => - let inductionScore : Float := - match cache.model.inputTokens with - | some tokens => - (checkInductionCopyNextPattern tokens d2.attention (minScore := 0.0)).getD 0.0 - | none => 1.0 - let kComp := computeKCompositionScore cache.model d1 d2 - some { - layer1Idx := core.layer1Idx - layer2Idx := core.layer2Idx - head1Idx := core.head1Idx - head2Idx := core.head2Idx - patternBound1 := core.patternBound1 - patternBound2 := core.patternBound2 - combinedError := core.combinedError - prevTokenStrength := core.prevTokenStrength - inductionScore := inductionScore - kComp := kComp - description := core.description - } - | _, _ => none - -end InductionCandidateCore - -/-- Convert a 2-layer deep circuit candidate into cheap induction-core data. -/ -def DeepCircuitCandidate.toInductionCandidateCore? - (c : DeepCircuitCandidate) (cache : PrecomputedCache) : - Option InductionCandidateCore := - if c.layerIndices.size = 2 && c.headIndices.size = 2 then - let l1 := c.layerIndices[0]! - let l2 := c.layerIndices[1]! - let h1 := c.headIndices[0]! - let h2 := c.headIndices[1]! - match cache.getHeadData l1 h1, cache.getHeadData l2 h2 with - | some d1, some d2 => - let ε1 := d1.faithfulnessRatio - let ε2 := d2.faithfulnessRatio - let combinedError := ε1 + ε2 + ε1 * ε2 - some { - layer1Idx := l1 - layer2Idx := l2 - head1Idx := h1 - head2Idx := h2 - patternBound1 := ε1 - patternBound2 := ε2 - combinedError := combinedError - prevTokenStrength := d1.prevTokenStrength - description := s!"L{l1}H{h1}->L{l2}H{h2} (deep)" - } - | _, _ => none - else - none - -/-- Convert a 2-layer deep circuit candidate into an induction-head candidate. 
- -This is used to avoid re-running the expensive `checkInductionPattern` scan when both -induction heads and deep circuits are requested from the same cache. --/ -def DeepCircuitCandidate.toInductionCandidate? - (c : DeepCircuitCandidate) (cache : PrecomputedCache) : - Option CandidateInductionHead := - match c.toInductionCandidateCore? cache with - | none => none - | some core => core.toInductionCandidate? cache - -/-- Find candidate (L1, L2) induction-head pairs from a `PrecomputedCache`. - -This searches for the classic two-head induction circuit: -- Layer 1 (L1): a strong **previous-token** head, -- Layer 2 (L2): an **induction** head (token-level "copy-next" attention). --/ -def findInductionHeadCandidatesFromCache (cache : PrecomputedCache) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array CandidateInductionHead := Id.run do - let model := cache.model - let tokens? := model.inputTokens - -- Precompute induction scores per head (layer2-only), since these are reused across head1 loops. - let headInductionScores : Array (Array (Option Float)) := - match tokens? with - | some tokens => - cache.headData.map (fun layer => - layer.map (fun data2 => - checkInductionCopyNextPattern tokens data2.attention minInductionScore)) - | none => - cache.headData.map (fun layer => layer.map (fun _ => some 1.0)) - - let computeForLayer (l1 : Nat) : Array CandidateInductionHead := Id.run do - let layer1Cache := cache.headData.getD l1 #[] - let layer1Candidates := - if minPrevTokenStrength <= 0.0 then - layer1Cache - else - layer1Cache.filter (fun data1 => data1.prevTokenStrength ≥ minPrevTokenStrength) - let mut layerCandidates : Array CandidateInductionHead := #[] - - -- Preserve the original traversal order: l1, l2, head1, head2. 
- for l2 in [l1 + 1:model.numLayers] do - let layer2Cache := cache.headData.getD l2 #[] - let layer2Scores := headInductionScores.getD l2 #[] - let layer2Candidates : Array (PrecomputedHeadData × Float) := Id.run do - let mut out : Array (PrecomputedHeadData × Float) := #[] - for (data2, score?) in layer2Cache.zip layer2Scores do - match score? with - | some inductionScore => out := out.push (data2, inductionScore) - | none => pure () - return out - for data1 in layer1Candidates do - for (data2, inductionScore) in layer2Candidates do - -- Use dimensionless faithfulness ratios (relative approximation errors). - let ε1 := data1.faithfulnessRatio - let ε2 := data2.faithfulnessRatio - let combinedError := ε1 + ε2 + ε1 * ε2 - let kComp := computeKCompositionScore model data1 data2 - - layerCandidates := layerCandidates.push { - layer1Idx := l1 - layer2Idx := l2 - head1Idx := data1.headIdx - head2Idx := data2.headIdx - patternBound1 := ε1 - patternBound2 := ε2 - combinedError := combinedError - prevTokenStrength := data1.prevTokenStrength - inductionScore := inductionScore - kComp := kComp - description := s!"L{l1}H{data1.headIdx}->L{l2}H{data2.headIdx}" - } - - return layerCandidates - - -- Pure parallelism via tasks: layer-1 index computations are independent. - let useParallel := model.numLayers >= 4 - let chunks : Array (Array CandidateInductionHead) := - if useParallel then - let tasks : Array (Task (Array CandidateInductionHead)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeForLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeForLayer i.val - - -- Join without quadratic copying. - let total := sumSizes chunks - let mut candidates : Array CandidateInductionHead := Array.mkEmpty total - for cs in chunks do - for c in cs do - candidates := candidates.push c - - candidates.qsort (·.combinedError < ·.combinedError) - -/-- Search for induction heads using proper layer-wise residual stream computation. 
- -This method performs multi-layer analysis with correct forward pass: -- Layer 1 attention is computed on the residual stream *after* layer 0 -- Layer 2 attention is computed on the residual stream *after* layers 0-1 - -This enables detection of true induction heads where layer 2 attends to -information created by layer 1's "previous-token" head. - -**Performance Note**: Uses `PrecomputedCache` to avoid O(L²H²) redundant attention -computations, reducing to O(LH) for typical models. --/ -def findInductionHeadCandidates (model : ConcreteModel) - (_threshold : Float) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array CandidateInductionHead := Id.run do - let cache := PrecomputedCache.build model - findInductionHeadCandidatesFromCache cache minPrevTokenStrength minInductionScore - -/-- Filter candidates to only those meeting the threshold. - -Uses proper layer-wise residual stream computation. --/ -def findVerifiedInductionHeads (model : ConcreteModel) - (threshold : Float) : Array VerifiedInductionHead := Id.run do - let candidates := findInductionHeadCandidates model threshold - let mut verified : Array VerifiedInductionHead := #[] - - for candidate in candidates do - if candidate.combinedError ≤ threshold then - verified := verified.push { - candidate := candidate - threshold := threshold - errorChecked := true - } - - verified - -/-- Find induction head candidates with rigorous N-layer amplification bounds. - -This replaces the ad-hoc `ε₁ + ε₂ + ε₁·ε₂` formula with the correct N-layer -composition theorem: - - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -For a 2-layer induction head with layers l1 < l2: -- Layer l1 contributes: ε₁ · (1 + C_l2) · (1 + C_{l2+1}) · ... -- Layer l2 contributes: ε₂ · (1 + C_{l2+1}) · ... - -The amplification factors Cⱼ are estimated from residual Jacobian norms -(`layerJacobian - I`). 
--/ -def findDeepCircuitCandidatesFromCache (cache : PrecomputedCache) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array DeepCircuitCandidate := Id.run do - let model := cache.model - - -- OPTIMIZATION: Use cached operator norm bounds (already computed in cache.build) - let allNormBounds := cache.layerNormBounds - - -- OPTIMIZATION: precompute suffix amplification products: - -- `suffixAmp[i] = ∏_{j≥i} (1 + C_j)` and `suffixAmp[size] = 1`. - let suffixAmp : Array Float := Id.run do - let n := allNormBounds.size - let mut out : Array Float := Array.replicate (n + 1) 1.0 - let mut prod : Float := 1.0 - for offset in [:n] do - let i := n - 1 - offset - prod := prod * (1.0 + allNormBounds[i]!) - out := out.set! i prod - return out - - let computeForLayer (l1 : Nat) : Array DeepCircuitCandidate := Id.run do - let layer1Cache := cache.headData.getD l1 #[] - let suffix1 := suffixAmp.getD (l1 + 1) 1.0 - let totalAmpFactor := suffixAmp.getD l1 1.0 - let mut layerCandidates : Array DeepCircuitCandidate := #[] - - -- Preserve the original traversal order: l1, l2, head1, head2. 
- for l2 in [l1 + 1:model.numLayers] do - let layer2Cache := cache.headData.getD l2 #[] - let suffix2 := suffixAmp.getD (l2 + 1) 1.0 - let relevantNormBounds := allNormBounds.extract l1 (l2 + 1) - - for data1 in layer1Cache do - if data1.prevTokenStrength ≥ minPrevTokenStrength then - for data2 in layer2Cache do - match checkInductionPattern data1.attention data2.attention minInductionScore with - | some _ => - let bound1 := data1.patternTermBound - let bound2 := data2.patternTermBound - let amplifiedError := bound1 * suffix1 + bound2 * suffix2 - - layerCandidates := layerCandidates.push { - layerIndices := #[l1, l2] - headIndices := #[data1.headIdx, data2.headIdx] - patternBounds := #[bound1, bound2] - operatorNormUbs := relevantNormBounds - simpleErrorSum := bound1 + bound2 - amplifiedError := amplifiedError - amplificationFactor := totalAmpFactor - patternType := "induction" - description := s!"L{l1}H{data1.headIdx}->L{l2}H{data2.headIdx}" - } - | none => pure () - else - pure () - - return layerCandidates - - -- Pure parallelism via tasks: layer-1 index computations are independent. - let useParallel := model.numLayers >= 4 - let chunks : Array (Array DeepCircuitCandidate) := - if useParallel then - let tasks : Array (Task (Array DeepCircuitCandidate)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeForLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeForLayer i.val - - -- Join without quadratic copying. - let total := sumSizes chunks - let mut candidates : Array DeepCircuitCandidate := Array.mkEmpty total - for cs in chunks do - for c in cs do - candidates := candidates.push c - - candidates.qsort (·.amplifiedError < ·.amplifiedError) - -/-- Find deep circuit candidates with rigorous N-layer amplification bounds. - -This is a wrapper around `findDeepCircuitCandidatesFromCache` that builds the cache. 
--/ -def findDeepCircuitCandidates (model : ConcreteModel) - (_threshold : Float) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array DeepCircuitCandidate := Id.run do - let cache := PrecomputedCache.build model - findDeepCircuitCandidatesFromCache cache minPrevTokenStrength minInductionScore - -/-- Filter deep circuit candidates by N-layer error threshold. -/ -def findVerifiedDeepCircuits (model : ConcreteModel) - (threshold : Float) : Array DeepCircuitCandidate := Id.run do - let candidates := findDeepCircuitCandidates model threshold - let mut verified : Array DeepCircuitCandidate := #[] - - for candidate in candidates do - if candidate.amplifiedError ≤ threshold then - verified := verified.push candidate - - verified - -/-! ## Pretty Printing -/ - -/-- Pretty print a candidate induction head. -/ -def CandidateInductionHead.toString (c : CandidateInductionHead) : String := - s!"InductionHead L{c.layer1Idx}H{c.head1Idx}->L{c.layer2Idx}H{c.head2Idx} " ++ - s!"(error={c.combinedError}, prev-token={c.prevTokenStrength}, " ++ - s!"induction={c.inductionScore}, kComp={c.kComp})" - -instance : ToString CandidateInductionHead := ⟨CandidateInductionHead.toString⟩ - -/-- Pretty print a verified induction head. -/ -def VerifiedInductionHead.toString (vh : VerifiedInductionHead) : String := - s!"Verified: {vh.candidate} [threshold={vh.threshold}]" - -instance : ToString VerifiedInductionHead := ⟨VerifiedInductionHead.toString⟩ - -/-! ## Summary Statistics -/ - -/-- Summary statistics for discovery results. -/ -structure DiscoverySummary where - /-- Number of candidates found -/ - candidateCount : Nat - /-- Number meeting the threshold -/ - verifiedCount : Nat - /-- Best (lowest) error bound found -/ - bestError : Float - -/-- Compute summary statistics for discovery results. 
-/ -def computeDiscoverySummary (model : ConcreteModel) (threshold : Float) : - DiscoverySummary := Id.run do - let candidates := findInductionHeadCandidates model threshold - let mut verifiedCount : Nat := 0 - for c in candidates do - if c.combinedError ≤ threshold then - verifiedCount := verifiedCount + 1 - - let bestErr := if candidates.isEmpty then Float.inf - else candidates.foldl (fun acc c => min acc c.combinedError) Float.inf - - { - candidateCount := candidates.size - verifiedCount := verifiedCount - bestError := bestErr - } - -/-! ## Convenience Functions -/ - -/-- Create a ConcreteAttentionLayer from raw Float arrays. -/ -def mkConcreteAttentionLayer - (modelDim headDim : Nat) - (wq wk wv wo : Array Float) - (hq : wq.size = modelDim * headDim) - (hk : wk.size = modelDim * headDim) - (hv : wv.size = modelDim * headDim) - (ho : wo.size = headDim * modelDim) : ConcreteAttentionLayer where - modelDim := modelDim - headDim := headDim - W_Q := { numRows := modelDim, numCols := headDim, data := wq, size_eq := hq } - b_Q := { numRows := 1, numCols := headDim, data := Array.replicate headDim 0.0, size_eq := by simp } - W_K := { numRows := modelDim, numCols := headDim, data := wk, size_eq := hk } - b_K := { numRows := 1, numCols := headDim, data := Array.replicate headDim 0.0, size_eq := by simp } - W_V := { numRows := modelDim, numCols := headDim, data := wv, size_eq := hv } - b_V := { numRows := 1, numCols := headDim, data := Array.replicate headDim 0.0, size_eq := by simp } - W_O := { numRows := headDim, numCols := modelDim, data := wo, size_eq := ho } - W_Q_dims := ⟨rfl, rfl⟩ - b_Q_dims := ⟨rfl, rfl⟩ - W_K_dims := ⟨rfl, rfl⟩ - b_K_dims := ⟨rfl, rfl⟩ - W_V_dims := ⟨rfl, rfl⟩ - b_V_dims := ⟨rfl, rfl⟩ - W_O_dims := ⟨rfl, rfl⟩ - -/-! ## Generic Circuit Discovery - -This section provides a framework for discovering arbitrary circuits (sparse subgraphs) -that are certifiably responsible for model behavior on a given input. 
- -The key insight is that we can bound the error introduced by pruning a component -without running forward passes—using only weight matrices and attention patterns. --/ - -/-! ### Circuit Mask Representation -/ - -/-- A component identifier: attention head, MLP neuron, or SAE feature. -/ -inductive ComponentId where - /-- Attention head at (layer, head index) -/ - | head : (layerIdx : Nat) → (headIdx : Nat) → ComponentId - /-- MLP neuron at (layer, neuron index within that layer's MLP) -/ - | mlpNeuron : (layerIdx : Nat) → (neuronIdx : Nat) → ComponentId - /-- SAE feature at (layer, feature index within that layer's SAE) -/ - | saeFeature : (layerIdx : Nat) → (featureIdx : Nat) → ComponentId - deriving DecidableEq, Repr - -namespace ComponentId - -/-- Pretty print a component ID. -/ -def toString : ComponentId → String - | head l h => s!"L{l}H{h}" - | mlpNeuron l n => s!"L{l}N{n}" - | saeFeature l f => s!"L{l}F{f}" - -/-- Check if this is an attention head. -/ -def isHead : ComponentId → Bool - | head _ _ => true - | _ => false - -/-- Check if this is an MLP neuron. -/ -def isNeuron : ComponentId → Bool - | mlpNeuron _ _ => true - | _ => false - -/-- Check if this is an SAE feature. -/ -def isSAEFeature : ComponentId → Bool - | saeFeature _ _ => true - | _ => false - -/-- Get the layer index. -/ -def layerIdx : ComponentId → Nat - | head l _ => l - | mlpNeuron l _ => l - | saeFeature l _ => l - -instance : ToString ComponentId := ⟨ComponentId.toString⟩ - -end ComponentId - -/-- A circuit is a sparse subgraph mask over model components. - -The mask indicates which components are **included** in the circuit. -Components not in the mask are considered "ablated" (zeroed out). - -This is a simplified structure without proof obligations - validity is checked at runtime. 
--/ -structure ConcreteCircuit where - /-- Number of layers in the model -/ - numLayers : Nat - /-- Number of heads per layer -/ - headsPerLayer : Array Nat - /-- Number of MLP neurons per layer -/ - neuronsPerLayer : Array Nat - /-- Included attention heads: includedHeads[l][h] = true iff head h in layer l is active -/ - includedHeads : Array (Array Bool) - /-- Included MLP neurons: includedNeurons[l][n] = true iff neuron n in layer l is active -/ - includedNeurons : Array (Array Bool) - -namespace ConcreteCircuit - -/-- Check if a specific head is included in the circuit. -/ -def isHeadIncluded (circuit : ConcreteCircuit) (layerIdx headIdx : Nat) : Bool := - if layerIdx < circuit.includedHeads.size then - let layerMask := circuit.includedHeads.getD layerIdx #[] - layerMask.getD headIdx false - else false - -/-- Check if a specific MLP neuron is included in the circuit. -/ -def isNeuronIncluded (circuit : ConcreteCircuit) (layerIdx neuronIdx : Nat) : Bool := - if layerIdx < circuit.includedNeurons.size then - let layerMask := circuit.includedNeurons.getD layerIdx #[] - layerMask.getD neuronIdx false - else false - -/-- Check if any component is included (dispatches on ComponentId type). -/ -def isIncluded (circuit : ConcreteCircuit) (comp : ComponentId) : Bool := - match comp with - | ComponentId.head l h => circuit.isHeadIncluded l h - | ComponentId.mlpNeuron l n => circuit.isNeuronIncluded l n - | ComponentId.saeFeature _ _ => false -- SAE features handled by SAECircuit - -/-- Count total number of included attention heads. -/ -def countIncludedHeads (circuit : ConcreteCircuit) : Nat := - countTrueNested circuit.includedHeads - -/-- Count total number of included MLP neurons. -/ -def countIncludedNeurons (circuit : ConcreteCircuit) : Nat := - countTrueNested circuit.includedNeurons - -/-- Count total number of included components. 
-/ -def countIncluded (circuit : ConcreteCircuit) : Nat := - circuit.countIncludedHeads + circuit.countIncludedNeurons - -/-- Count total number of attention heads (included + excluded). -/ -def totalHeads (circuit : ConcreteCircuit) : Nat := - sumNatArray circuit.headsPerLayer - -/-- Count total number of MLP neurons (included + excluded). -/ -def totalNeurons (circuit : ConcreteCircuit) : Nat := - sumNatArray circuit.neuronsPerLayer - -/-- Count total number of components (included + excluded). -/ -def totalComponents (circuit : ConcreteCircuit) : Nat := - circuit.totalHeads + circuit.totalNeurons - -/-- List all included component IDs. -/ -def includedComponents (circuit : ConcreteCircuit) : Array ComponentId := Id.run do - let mut result : Array ComponentId := #[] - -- Attention heads - for l in [:circuit.numLayers] do - let layerMask := circuit.includedHeads.getD l #[] - for h_idx in [:layerMask.size] do - if layerMask.getD h_idx false then - result := result.push (ComponentId.head l h_idx) - -- MLP neurons - for l in [:circuit.numLayers] do - let layerMask := circuit.includedNeurons.getD l #[] - for n_idx in [:layerMask.size] do - if layerMask.getD n_idx false then - result := result.push (ComponentId.mlpNeuron l n_idx) - result - -/-- List all excluded component IDs. -/ -def excludedComponents (circuit : ConcreteCircuit) : Array ComponentId := Id.run do - let mut result : Array ComponentId := #[] - -- Attention heads - for l in [:circuit.numLayers] do - let layerMask := circuit.includedHeads.getD l #[] - for h_idx in [:layerMask.size] do - if !layerMask.getD h_idx false then - result := result.push (ComponentId.head l h_idx) - -- MLP neurons - for l in [:circuit.numLayers] do - let layerMask := circuit.includedNeurons.getD l #[] - for n_idx in [:layerMask.size] do - if !layerMask.getD n_idx false then - result := result.push (ComponentId.mlpNeuron l n_idx) - result - -/-- Create a full circuit (all components included). 
-/ -def full (numLayers : Nat) (headsPerLayer neuronsPerLayer : Array Nat) : ConcreteCircuit where - numLayers := numLayers - headsPerLayer := headsPerLayer - neuronsPerLayer := neuronsPerLayer - includedHeads := headsPerLayer.map fun numHeads => - .ofFn fun _ : Fin numHeads => true - includedNeurons := neuronsPerLayer.map fun numNeurons => - .ofFn fun _ : Fin numNeurons => true - -/-- Create an empty circuit (no components included). -/ -def empty (numLayers : Nat) (headsPerLayer neuronsPerLayer : Array Nat) : ConcreteCircuit where - numLayers := numLayers - headsPerLayer := headsPerLayer - neuronsPerLayer := neuronsPerLayer - includedHeads := headsPerLayer.map fun numHeads => - .ofFn fun _ : Fin numHeads => false - includedNeurons := neuronsPerLayer.map fun numNeurons => - .ofFn fun _ : Fin numNeurons => false - -/-- Remove a single component from the circuit (returns new circuit). -/ -def removeComponent (circuit : ConcreteCircuit) (comp : ComponentId) : ConcreteCircuit := - match comp with - | ComponentId.head layerIdx headIdx => - if layerIdx < circuit.includedHeads.size then - let newIncluded := circuit.includedHeads.modify layerIdx fun layerMask => - if headIdx < layerMask.size then - layerMask.modify headIdx fun _ => false - else layerMask - { circuit with includedHeads := newIncluded } - else circuit - | ComponentId.mlpNeuron layerIdx neuronIdx => - if layerIdx < circuit.includedNeurons.size then - let newIncluded := circuit.includedNeurons.modify layerIdx fun layerMask => - if neuronIdx < layerMask.size then - layerMask.modify neuronIdx fun _ => false - else layerMask - { circuit with includedNeurons := newIncluded } - else circuit - | ComponentId.saeFeature _ _ => circuit -- SAE features handled by SAECircuit - -/-- Add a single component to the circuit (returns new circuit). 
-/ -def addComponent (circuit : ConcreteCircuit) (comp : ComponentId) : ConcreteCircuit := - match comp with - | ComponentId.head layerIdx headIdx => - if layerIdx < circuit.includedHeads.size then - let newIncluded := circuit.includedHeads.modify layerIdx fun layerMask => - if headIdx < layerMask.size then - layerMask.modify headIdx fun _ => true - else layerMask - { circuit with includedHeads := newIncluded } - else circuit - | ComponentId.mlpNeuron layerIdx neuronIdx => - if layerIdx < circuit.includedNeurons.size then - let newIncluded := circuit.includedNeurons.modify layerIdx fun layerMask => - if neuronIdx < layerMask.size then - layerMask.modify neuronIdx fun _ => true - else layerMask - { circuit with includedNeurons := newIncluded } - else circuit - | ComponentId.saeFeature _ _ => circuit -- SAE features handled by SAECircuit - -/-- Pretty print the circuit. -/ -def toString (circuit : ConcreteCircuit) : String := - let heads := circuit.countIncludedHeads - let neurons := circuit.countIncludedNeurons - let totalH := circuit.totalHeads - let totalN := circuit.totalNeurons - s!"Circuit(heads={heads}/{totalH}, neurons={neurons}/{totalN})" - -instance : ToString ConcreteCircuit := ⟨ConcreteCircuit.toString⟩ - -end ConcreteCircuit - -/-! ## SAE-Enhanced Circuit Discovery - -When using Sparse Autoencoders, we replace MLP neuron masks with SAE feature masks. -This enables discovering circuits in terms of interpretable features rather than -polysemantic neurons. --/ - -/-- A circuit mask that operates on SAE features instead of MLP neurons. - -This extends ConcreteCircuit by replacing MLP neuron masks with SAE feature masks. -The attention head masks remain the same. 
--/ -structure SAECircuit where - /-- Number of layers in the model -/ - numLayers : Nat - /-- Number of heads per layer -/ - headsPerLayer : Array Nat - /-- Number of SAE features per layer -/ - featuresPerLayer : Array Nat - /-- Included attention heads -/ - includedHeads : Array (Array Bool) - /-- Included SAE features: includedFeatures[l][f] = true iff feature f in layer l is active -/ - includedFeatures : Array (Array Bool) - -namespace SAECircuit - -/-- Check if a specific head is included. -/ -def isHeadIncluded (circuit : SAECircuit) (layerIdx headIdx : Nat) : Bool := - if layerIdx < circuit.includedHeads.size then - let layerMask := circuit.includedHeads.getD layerIdx #[] - layerMask.getD headIdx false - else false - -/-- Check if a specific SAE feature is included. -/ -def isFeatureIncluded (circuit : SAECircuit) (layerIdx featureIdx : Nat) : Bool := - if layerIdx < circuit.includedFeatures.size then - let layerMask := circuit.includedFeatures.getD layerIdx #[] - layerMask.getD featureIdx false - else false - -/-- Check if any component is included. -/ -def isIncluded (circuit : SAECircuit) (comp : ComponentId) : Bool := - match comp with - | ComponentId.head l h => circuit.isHeadIncluded l h - | ComponentId.mlpNeuron _ _ => false -- SAE circuits don't track neurons - | ComponentId.saeFeature l f => circuit.isFeatureIncluded l f - -/-- Count included heads. -/ -def countIncludedHeads (circuit : SAECircuit) : Nat := - countTrueNested circuit.includedHeads - -/-- Count included features. -/ -def countIncludedFeatures (circuit : SAECircuit) : Nat := - countTrueNested circuit.includedFeatures - -/-- Total heads. -/ -def totalHeads (circuit : SAECircuit) : Nat := - sumNatArray circuit.headsPerLayer - -/-- Total features. -/ -def totalFeatures (circuit : SAECircuit) : Nat := - sumNatArray circuit.featuresPerLayer - -/-- Create a full circuit (all components included). 
-/ -def full (numLayers : Nat) (headsPerLayer featuresPerLayer : Array Nat) : SAECircuit where - numLayers := numLayers - headsPerLayer := headsPerLayer - featuresPerLayer := featuresPerLayer - includedHeads := headsPerLayer.map fun numHeads => - .ofFn fun _ : Fin numHeads => true - includedFeatures := featuresPerLayer.map fun numFeats => - .ofFn fun _ : Fin numFeats => true - -/-- Create an empty circuit. -/ -def empty (numLayers : Nat) (headsPerLayer featuresPerLayer : Array Nat) : SAECircuit where - numLayers := numLayers - headsPerLayer := headsPerLayer - featuresPerLayer := featuresPerLayer - includedHeads := headsPerLayer.map fun numHeads => - .ofFn fun _ : Fin numHeads => false - includedFeatures := featuresPerLayer.map fun numFeats => - .ofFn fun _ : Fin numFeats => false - -/-- Remove a component. -/ -def removeComponent (circuit : SAECircuit) (comp : ComponentId) : SAECircuit := - match comp with - | ComponentId.head layerIdx headIdx => - if layerIdx < circuit.includedHeads.size then - let newIncluded := circuit.includedHeads.modify layerIdx fun layerMask => - if headIdx < layerMask.size then - layerMask.modify headIdx fun _ => false - else layerMask - { circuit with includedHeads := newIncluded } - else circuit - | ComponentId.saeFeature layerIdx featureIdx => - if layerIdx < circuit.includedFeatures.size then - let newIncluded := circuit.includedFeatures.modify layerIdx fun layerMask => - if featureIdx < layerMask.size then - layerMask.modify featureIdx fun _ => false - else layerMask - { circuit with includedFeatures := newIncluded } - else circuit - | ComponentId.mlpNeuron _ _ => circuit -- Not supported in SAE circuits - -/-- Add a component. 
-/ -def addComponent (circuit : SAECircuit) (comp : ComponentId) : SAECircuit := - match comp with - | ComponentId.head layerIdx headIdx => - if layerIdx < circuit.includedHeads.size then - let newIncluded := circuit.includedHeads.modify layerIdx fun layerMask => - if headIdx < layerMask.size then - layerMask.modify headIdx fun _ => true - else layerMask - { circuit with includedHeads := newIncluded } - else circuit - | ComponentId.saeFeature layerIdx featureIdx => - if layerIdx < circuit.includedFeatures.size then - let newIncluded := circuit.includedFeatures.modify layerIdx fun layerMask => - if featureIdx < layerMask.size then - layerMask.modify featureIdx fun _ => true - else layerMask - { circuit with includedFeatures := newIncluded } - else circuit - | ComponentId.mlpNeuron _ _ => circuit - -def toString (circuit : SAECircuit) : String := - let heads := circuit.countIncludedHeads - let features := circuit.countIncludedFeatures - let totalH := circuit.totalHeads - let totalF := circuit.totalFeatures - s!"SAECircuit(heads={heads}/{totalH}, features={features}/{totalF})" - -instance : ToString SAECircuit := ⟨SAECircuit.toString⟩ - -end SAECircuit - -/-- Model with SAEs attached at each layer's MLP. - -Replaces `ConcreteModel.mlps` with SAEs for feature-level analysis. --/ -structure SAEEnhancedModel where - /-- Number of layers -/ - numLayers : Nat - /-- Attention layers with their heads -/ - layers : Array (Array ConcreteAttentionLayer) - /-- Pre-LN LayerNorm parameters before attention (ln_1), one per layer. -/ - ln1 : Array ConcreteLayerNormParams := #[] - /-- Pre-LN LayerNorm parameters before SAE/MLP (ln_2), one per layer. -/ - ln2 : Array ConcreteLayerNormParams := #[] - /-- Final LayerNorm parameters (ln_f) before unembedding. 
-/ - lnf : ConcreteLayerNormParams := ConcreteLayerNormParams.identity 0 - /-- SAEs for MLP analysis: saes[l] is the SAE for layer l's MLP -/ - saes : Array ConcreteSAE - /-- Sequence length -/ - seqLen : Nat - /-- Input embeddings -/ - inputEmbeddings : ConcreteMatrix - /-- Unembedding matrix -/ - unembedding : Option ConcreteMatrix := none - -namespace SAEEnhancedModel - -/-- Model dimension (d), inferred from input embeddings. -/ -def modelDim (model : SAEEnhancedModel) : Nat := - model.inputEmbeddings.numCols - -/-- Get ln_1 parameters for a layer, defaulting to identity. -/ -def ln1Params (model : SAEEnhancedModel) (layerIdx : Nat) : ConcreteLayerNormParams := - model.ln1.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - -/-- Get ln_2 parameters for a layer, defaulting to identity. -/ -def ln2Params (model : SAEEnhancedModel) (layerIdx : Nat) : ConcreteLayerNormParams := - model.ln2.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - -/-- Apply ln_1 to a residual stream (row-wise, per token). -/ -def applyLn1 (model : SAEEnhancedModel) (layerIdx : Nat) (X : ConcreteMatrix) : ConcreteMatrix := - let p := model.ln1Params layerIdx - ConcreteMatrix.layerNormRowwise X p.gamma p.beta - -/-- Apply ln_2 to a residual stream (row-wise, per token). -/ -def applyLn2 (model : SAEEnhancedModel) (layerIdx : Nat) (X : ConcreteMatrix) : ConcreteMatrix := - let p := model.ln2Params layerIdx - ConcreteMatrix.layerNormRowwise X p.gamma p.beta - -/-- Conservative operator-norm bound for ln_1 Jacobian at a specific activation. -/ -def ln1OpBound (model : SAEEnhancedModel) (layerIdx : Nat) (X : ConcreteMatrix) : Float := - let p := model.ln1Params layerIdx - ConcreteMatrix.layerNormRowwiseOpEst X p.gamma - -/-- Create from ConcreteModel with externally trained SAEs. 
-/ -def fromModel (model : ConcreteModel) (saes : Array ConcreteSAE) : Option SAEEnhancedModel := - if saes.size = model.numLayers then - some { - numLayers := model.numLayers - layers := model.layers - ln1 := model.ln1 - ln2 := model.ln2 - lnf := model.lnf - saes := saes - seqLen := model.seqLen - inputEmbeddings := model.inputEmbeddings - unembedding := model.unembedding - } - else none - -/-- Total reconstruction error across all SAEs for given forward pass. -/ -def totalReconstructionError (model : SAEEnhancedModel) - (fwd : ForwardPassResult) : Float := Id.run do - let mut totalErr : Float := 0.0 - for l in [:model.numLayers] do - if hl : l < model.saes.size then - let sae := model.saes[l] - let layerInput := fwd.getLayerInput l - let err := sae.reconstructionErrorMatrix layerInput - totalErr := totalErr + err * err - Float.sqrt totalErr - -end SAEEnhancedModel - -/-- Importance metrics for SAE features. -/ -structure SAEFeatureImportance where - /-- Component identifier -/ - component : ComponentId - /-- Value term norm (feature influence) -/ - valueTermNorm : Float - /-- Pattern term bound (activation instability) -/ - patternTermBound : Float - /-- Faithfulness ratio -/ - faithfulnessRatio : Float - /-- Reconstruction error contribution (SAE approximation) -/ - reconstructionError : Float - -namespace SAEFeatureImportance - -def toString (imp : SAEFeatureImportance) : String := - s!"{imp.component}: value={imp.valueTermNorm}, pattern={imp.patternTermBound}, " ++ - s!"recon={imp.reconstructionError}, ratio={imp.faithfulnessRatio}" - -instance : ToString SAEFeatureImportance := ⟨toString⟩ - -end SAEFeatureImportance - -/-- Compute importance metrics for a single SAE feature. 
-/ -def computeSAEFeatureImportance (sae : ConcreteSAE) (layerIdx featureIdx : Nat) - (layerInput : ConcreteMatrix) (perturbationNorm : Float) : Option SAEFeatureImportance := - if featureIdx < sae.numFeatures then - let influence := sae.featureInfluence featureIdx - - -- Compute IBP-based pattern term bound across positions - let patternBound := Id.run do - let mut totalPatternSq : Float := 0.0 - for pos in [:layerInput.numRows] do - let inputVec : Array Float := .ofFn fun d : Fin layerInput.numCols => - layerInput.getUnsafe pos d.val - let posBound := sae.featurePatternTermBoundIBP featureIdx inputVec perturbationNorm - totalPatternSq := totalPatternSq + posBound * posBound - Float.sqrt (totalPatternSq / (max 1 layerInput.numRows).toFloat) - - -- Estimate per-feature contribution to reconstruction error - -- Approximation: uniform distribution across features (could be refined) - let perFeatureRecon := Id.run do - let mut totalRecon : Float := 0.0 - for pos in [:layerInput.numRows] do - let inputVec : Array Float := .ofFn fun d : Fin layerInput.numCols => - layerInput.getUnsafe pos d.val - totalRecon := totalRecon + sae.reconstructionError inputVec - totalRecon / (max 1 layerInput.numRows).toFloat / sae.numFeatures.toFloat - - let ratio := if influence < 1e-10 then Float.inf else patternBound / influence - - some { - component := ComponentId.saeFeature layerIdx featureIdx - valueTermNorm := influence - patternTermBound := patternBound - faithfulnessRatio := ratio - reconstructionError := perFeatureRecon - } - else none - -/-- Error breakdown for SAE-based circuits. 
-/ -structure SAECircuitError where - /-- Pattern term error from included components -/ - patternTermError : Float - /-- Ablation error from excluded components -/ - ablationError : Float - /-- SAE reconstruction error (approximation of MLP) -/ - reconstructionError : Float - /-- Total error bound -/ - totalError : Float - /-- Number of included components -/ - includedCount : Nat - /-- Number of excluded components -/ - excludedCount : Nat - /-- Number of unstable features -/ - unstableFeatureCount : Nat - deriving Repr - -namespace SAECircuitError - -def toString (e : SAECircuitError) : String := - s!"SAECircuitError(total={e.totalError}, pattern={e.patternTermError}, " ++ - s!"ablation={e.ablationError}, recon={e.reconstructionError}, " ++ - s!"unstable={e.unstableFeatureCount})" - -instance : ToString SAECircuitError := ⟨toString⟩ - -end SAECircuitError - -/-- Helper to compute head importance for SAE analysis. -Inline version used before computeHeadImportance is defined. -Works with both ConcreteModel and SAEEnhancedModel (via their shared fields). 
-/ -private def computeHeadMetricsForSAE - (model : SAEEnhancedModel) - (layerIdx headIdx : Nat) (layerInput : ConcreteMatrix) : Option (Float × Float) := - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx] - let attnInput := model.applyLn1 layerIdx layerInput - let attn := head.computeAttentionWeights attnInput false - let inputNorm := computeInputNorm attnInput - let inputOpBound := attnInput.opNormUpperBoundOneInf - let ln1Bound := model.ln1OpBound layerIdx layerInput - let bnds := head.noDenseProductBounds - - let valueNorm := ln1Bound * computeValueTermNorm attn bnds.voFrobNormSq - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - scaleFactor := Float.sqrt head.headDim.toFloat - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternBound := ln1Bound * computePatternTermBound inputs - some (valueNorm, patternBound) - else none - else none - -/-- Estimate faithfulness error for an SAE-based circuit. - -Extends the standard circuit error model to include: -1. Pattern term error for included attention heads -2. Pattern term error for included SAE features (with IBP) -3. Ablation error for excluded components -4. 
SAE reconstruction error (the approximation of using SAE instead of MLP) - -Total Error = Σ(included) patternBound + Σ(excluded) valueNorm + reconstructionError --/ -def estimateSAECircuitFaithfulness (model : SAEEnhancedModel) - (circuit : SAECircuit) (_causal : Bool := true) : SAECircuitError := Id.run do - -- Simplified forward pass (just attention) - let mut residual := model.inputEmbeddings - let mut layerInputs : Array ConcreteMatrix := #[model.inputEmbeddings] - - for l in [:model.numLayers] do - let attnInput := model.applyLn1 l residual - let mut attnSum := ConcreteMatrix.zeros residual.numRows residual.numCols - if hl : l < model.layers.size then - let layerHeads := model.layers[l] - for h in [:layerHeads.size] do - if hh : h < layerHeads.size then - let head := layerHeads[h]'hh - let headOutput := head.forward attnInput true - attnSum := attnSum.add headOutput - residual := residual.add attnSum - layerInputs := layerInputs.push residual - - let mut patternError : Float := 0.0 - let mut ablationError : Float := 0.0 - let mut totalRecon : Float := 0.0 - let mut includedCount : Nat := 0 - let mut excludedCount : Nat := 0 - let mut unstableCount : Nat := 0 - let mut cumulativeAblation : Float := 0.0 - - -- Process attention heads - for l in [:model.numLayers] do - if hl : l < model.layers.size then - let layerHeads := model.layers[l] - for h_idx in [:layerHeads.size] do - let included := circuit.isHeadIncluded l h_idx - let layerInput := - if hl2 : l < layerInputs.size then layerInputs[l]'hl2 else model.inputEmbeddings - match computeHeadMetricsForSAE model l h_idx layerInput with - | some (valueNorm, patternBound) => - if included then - patternError := patternError + patternBound - includedCount := includedCount + 1 - else - ablationError := ablationError + valueNorm - cumulativeAblation := cumulativeAblation + valueNorm - excludedCount := excludedCount + 1 - | none => pure () - - -- Process SAE features - for l in [:model.numLayers] do - if hl : l < 
model.saes.size then - let sae := model.saes[l] - -- Pre-LN: SAE/MLP sees ln_2(y_l) where y_l is post-attention residual. - let postAttn := - if hpost : l + 1 < layerInputs.size then layerInputs[l + 1]'hpost else residual - let layerInput := model.applyLn2 l postAttn - - -- Add reconstruction error for this layer - let layerRecon := sae.reconstructionErrorMatrix layerInput - totalRecon := totalRecon + layerRecon - - for f_idx in [:sae.numFeatures] do - let included := circuit.isFeatureIncluded l f_idx - match computeSAEFeatureImportance sae l f_idx layerInput cumulativeAblation with - | some imp => - if included then - patternError := patternError + imp.patternTermBound - if imp.patternTermBound > 0.0 then - unstableCount := unstableCount + 1 - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.valueTermNorm - excludedCount := excludedCount + 1 - | none => pure () - - { - patternTermError := patternError - ablationError := ablationError - reconstructionError := totalRecon - totalError := patternError + ablationError + totalRecon - includedCount := includedCount - excludedCount := excludedCount - unstableFeatureCount := unstableCount - } - -/-! ### SAE Circuit Discovery - -Greedy pruning algorithm for SAE-enhanced circuits: -1. Start with all components included -2. Compute importance for each component (heads + SAE features) -3. Remove the component with smallest valueNorm (least information loss when ablated) -4. Repeat until error threshold would be exceeded - -Note: SAE reconstruction error is an additive constant for a given set of SAEs, -so it doesn't affect the pruning order. --/ - -/-- Ranked component importance for SAE circuits. 
-/ -structure SAERankedComponent where - /-- Component identifier -/ - component : ComponentId - /-- Value term norm (importance for ablation) -/ - valueTermNorm : Float - /-- Pattern term bound (error when included) -/ - patternTermBound : Float - -/-- Compute all component importances for an SAE-enhanced model. -/ -def computeAllSAEImportance (model : SAEEnhancedModel) : Array SAERankedComponent := Id.run do - let mut result : Array SAERankedComponent := #[] - - -- Simplified forward pass - let mut residual := model.inputEmbeddings - let mut layerInputs : Array ConcreteMatrix := #[model.inputEmbeddings] - - for l in [:model.numLayers] do - let attnInput := model.applyLn1 l residual - let mut attnSum := ConcreteMatrix.zeros residual.numRows residual.numCols - if hl : l < model.layers.size then - let layerHeads := model.layers[l] - for h in [:layerHeads.size] do - if hh : h < layerHeads.size then - let head := layerHeads[h]'hh - let headOutput := head.forward attnInput true - attnSum := attnSum.add headOutput - residual := residual.add attnSum - layerInputs := layerInputs.push residual - - -- Compute head importances - for l in [:model.numLayers] do - if hl : l < model.layers.size then - let layerHeads := model.layers[l] - for h_idx in [:layerHeads.size] do - let layerInput := - if hl2 : l < layerInputs.size then layerInputs[l]'hl2 else model.inputEmbeddings - match computeHeadMetricsForSAE model l h_idx layerInput with - | some (valueNorm, patternBound) => - result := result.push { - component := ComponentId.head l h_idx - valueTermNorm := valueNorm - patternTermBound := patternBound - } - | none => pure () - - -- Compute SAE feature importances - for l in [:model.numLayers] do - if hl : l < model.saes.size then - let sae := model.saes[l] - let postAttn := - if hpost : l + 1 < layerInputs.size then layerInputs[l + 1]'hpost else residual - let layerInput := model.applyLn2 l postAttn - for f_idx in [:sae.numFeatures] do - match computeSAEFeatureImportance sae l 
f_idx layerInput 0.0 with - | some imp => - result := result.push { - component := imp.component - valueTermNorm := imp.valueTermNorm - patternTermBound := imp.patternTermBound - } - | none => pure () - - result - -/-- Discover minimal SAE circuit using greedy pruning. - -Algorithm: -1. Start with full circuit (all heads, all features) -2. Compute base reconstruction error (constant for given SAEs) -3. Sort all components by valueTermNorm (ascending) -4. Iteratively remove smallest-impact components until error exceeds threshold - -The threshold should account for SAE reconstruction error as a baseline. --/ -def discoverSAECircuit (model : SAEEnhancedModel) (threshold : Float) : SAECircuit := Id.run do - -- Initialize full circuit - let headsPerLayer := model.layers.map (·.size) - let featuresPerLayer := model.saes.map (·.numFeatures) - let mut circuit := SAECircuit.full model.numLayers headsPerLayer featuresPerLayer - - -- Compute base reconstruction error - let baseError := estimateSAECircuitFaithfulness model circuit - let reconError := baseError.reconstructionError - - -- If base error already exceeds threshold, return empty - if reconError > threshold then - return SAECircuit.empty model.numLayers headsPerLayer featuresPerLayer - - let adjustedThreshold := threshold - reconError - - -- Get all components sorted by valueTermNorm (ascending) - let allImportance := computeAllSAEImportance model - let sorted := allImportance.qsort fun a b => a.valueTermNorm < b.valueTermNorm - - let mut cumulativePatternError : Float := baseError.patternTermError - let mut cumulativeAblationError : Float := 0.0 - - -- Greedily remove components - for comp in sorted do - -- Removing this component: - -- - Removes its patternTermBound from included error - -- - Adds its valueTermNorm to ablation error - let newPatternError := cumulativePatternError - comp.patternTermBound - let newAblationError := cumulativeAblationError + comp.valueTermNorm - let newTotalError := newPatternError + 
newAblationError - - if newTotalError > adjustedThreshold then - -- Stop: removing this component would exceed threshold - break - - -- Remove the component - circuit := circuit.removeComponent comp.component - cumulativePatternError := newPatternError - cumulativeAblationError := newAblationError - - circuit - -/-- Result of SAE circuit discovery. -/ -structure SAEDiscoveryResult where - /-- The discovered circuit -/ - circuit : SAECircuit - /-- Error breakdown -/ - error : SAECircuitError - /-- Compression ratio (components kept / total) -/ - compressionRatio : Float - -/-- Discover SAE circuit with full result details. -/ -def discoverSAECircuitWithResult (model : SAEEnhancedModel) - (threshold : Float) : SAEDiscoveryResult := Id.run do - let circuit := discoverSAECircuit model threshold - let error := estimateSAECircuitFaithfulness model circuit - let totalComponents := circuit.totalHeads + circuit.totalFeatures - let includedComponents := circuit.countIncludedHeads + circuit.countIncludedFeatures - let compression := if totalComponents > 0 then - includedComponents.toFloat / totalComponents.toFloat - else 1.0 - - { circuit, error, compressionRatio := compression } - -/-! ## Ablated Forward Pass (Causal Intervention) - -These functions implement executable causal interventions: running a forward pass -where specific attention heads or MLP neurons are masked out (ablated) based on -a `ConcreteCircuit` mask. - -This bridges the theoretical `Abstraction.lean` bounds with empirical verification: -- `runAblatedForward`: Execute the circuit (masked forward pass) -- `computeAblationDiscrepancy`: Measure actual difference from full model -- `verifyCircuitFaithfulness`: Assert empirical ≤ theoretical bound --/ - -/-- Forward pass for an MLP layer with neuron-level ablation. - -Same as `ConcreteMLPLayer.forward` but with a mask specifying which neurons are active. -Inactive neurons have their contributions zeroed out. 
- -For neuron i: -- If included: contribute GeLU(W_in[:,i] · x + b_in[i]) * W_out[i,:] -- If excluded: contribute 0 --/ -def ConcreteMLPLayer.forwardAblated (layer : ConcreteMLPLayer) (input : ConcreteMatrix) - (neuronMask : Array Bool) : ConcreteMatrix := - -- hidden = input · W_in + b_in (seqLen × hiddenDim) - let hidden := (input.matmul layer.W_in).addBias layer.b_in - -- Apply GeLU activation with masking using .ofFn for proper size proof - let activated : ConcreteMatrix := { - numRows := hidden.numRows - numCols := hidden.numCols - data := .ofFn fun idx : Fin (hidden.numRows * hidden.numCols) => - let j := idx.val % hidden.numCols - let val := hidden.data[idx.val]! - let act := geluFloat val - -- Zero out if neuron j is not included - if neuronMask.getD j true then act else 0.0 - size_eq := Array.size_ofFn - } - -- output = activated · W_out + b_out (seqLen × modelDim) - (activated.matmul layer.W_out).addBias layer.b_out - -/-- Run an ablated forward pass through the model. - -Like `runForward`, but with a `ConcreteCircuit` mask that specifies which attention -heads and MLP neurons are active. Excluded components have their contributions -zeroed out, implementing a causal intervention. - -This enables **empirical validation** of theoretical circuit bounds: -1. Discover a circuit via `discoverCircuit` or `discoverTargetedCircuit` -2. Run `runAblatedForward` with that circuit -3. Compare to `runForward` to measure actual discrepancy -4. 
Verify that empirical discrepancy ≤ theoretical bound - -**Ablation semantics:** -- Excluded attention head: its output is zero (does not contribute to residual) -- Excluded MLP neuron: its activation is zero (does not contribute to FFN output) --/ -def ConcreteModel.runAblatedForward (model : ConcreteModel) (circuit : ConcreteCircuit) - (causal : Bool := true) : ForwardPassResult := Id.run do - let mut layerInputs : Array ConcreteMatrix := Array.mkEmpty (model.numLayers + 1) - let mut postAttnResiduals : Array ConcreteMatrix := Array.mkEmpty model.numLayers - let mut attnOutputs : Array (Array ConcreteMatrix) := Array.mkEmpty model.numLayers - let mut mlpOutputs : Array ConcreteMatrix := Array.mkEmpty model.numLayers - let mut residual := model.inputEmbeddings - layerInputs := layerInputs.push residual - - for l in [:model.numLayers] do - -- Pre-LN: attention sees ln_1(residual) - let attnInput := model.applyLn1 l residual - -- Compute attention outputs for included heads only - let mut layerAttnOutputs : Array ConcreteMatrix := #[] - let mut includedHeadOutputs : Array ConcreteMatrix := #[] - let rows := residual.numRows - let cols := residual.numCols - let zeroOutput := ConcreteMatrix.zeros rows cols - - if hl : l < model.layers.size then - let layerHeads := model.layers[l] - let includedMask := circuit.includedHeads.getD l #[] - let includedCount := countTrue includedMask - let useParallelHeads := - layerHeads.size >= 4 && includedCount >= 4 - layerAttnOutputs := - if useParallelHeads then - let tasks : Array (Task ConcreteMatrix) := - .ofFn fun i : Fin layerHeads.size => - Task.spawn (fun _ => - if circuit.isHeadIncluded l i.val then - (layerHeads[i]).forward attnInput causal - else - zeroOutput) - tasks.map Task.get - else - Id.run do - let mut outs : Array ConcreteMatrix := Array.mkEmpty layerHeads.size - for h in [:layerHeads.size] do - if hh : h < layerHeads.size then - let head := layerHeads[h]'hh - if circuit.isHeadIncluded l h then - outs := outs.push 
(head.forward attnInput causal) - else - outs := outs.push zeroOutput - return outs - -- Preserve the original summation order: increasing head index. - includedHeadOutputs := - Id.run do - let mut outs : Array ConcreteMatrix := Array.mkEmpty includedCount - for h in [:layerAttnOutputs.size] do - if circuit.isHeadIncluded l h then - if hh : h < layerAttnOutputs.size then - outs := outs.push (layerAttnOutputs[h]'hh) - return outs - - attnOutputs := attnOutputs.push layerAttnOutputs - - -- Add attention residual - let attnBias := model.attnProjBiasAt l - let attnSum := - if includedHeadOutputs.isEmpty then - ConcreteMatrix.zeros rows cols - else - ConcreteMatrix.sumMatrices includedHeadOutputs - let residualAfterAttn := residual.add (attnSum.addBias attnBias) - postAttnResiduals := postAttnResiduals.push residualAfterAttn - - -- Compute MLP output with neuron-level ablation - -- Pre-LN: MLP sees ln_2(residualAfterAttn) - let mlpInput := model.applyLn2 l residualAfterAttn - let mlpOut := - if hm : l < model.mlps.size then - -- Get neuron mask for this layer - let neuronMask := circuit.includedNeurons.getD l #[] - model.mlps[l].forwardAblated mlpInput neuronMask - else ConcreteMatrix.zeros residual.numRows residual.numCols - - mlpOutputs := mlpOutputs.push mlpOut - - -- Add MLP residual - residual := residualAfterAttn.add mlpOut - - -- Store input for next layer - layerInputs := layerInputs.push residual - - let finalOutput := model.applyLnf residual - { - layerInputs := layerInputs - postAttnResiduals := postAttnResiduals - attnOutputs := attnOutputs - mlpOutputs := mlpOutputs - mlpActDeriv := Array.replicate model.numLayers (ConcreteMatrix.zeros 0 0) - mlpActDerivMax := Array.replicate model.numLayers 0.0 - finalOutput := finalOutput - } - -/-! ### Empirical Discrepancy and Verification - -These functions compute the actual difference between full and ablated model outputs, -enabling empirical validation of theoretical circuit bounds. 
--/ - -/-- Compute the element-wise difference between two matrices. -/ -def ConcreteMatrix.sub (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numRows = B.numRows ∧ A.numCols = B.numCols then - { - numRows := A.numRows - numCols := A.numCols - data := .ofFn fun idx : Fin (A.numRows * A.numCols) => - A.data[idx.val]! - B.data[idx.val]! - size_eq := Array.size_ofFn - } - else ConcreteMatrix.zeros 0 0 - -/-- Result of comparing full model output to ablated circuit output. - -This captures the empirical discrepancy between running the full model -and running only the discovered circuit. --/ -structure AblationResult where - /-- Full model output (residual stream after all layers) -/ - fullOutput : ConcreteMatrix - /-- Ablated circuit output -/ - ablatedOutput : ConcreteMatrix - /-- Difference: fullOutput - ablatedOutput -/ - difference : ConcreteMatrix - /-- Frobenius norm of the difference: ‖full - ablated‖_F -/ - empiricalError : Float - /-- Relative error: ‖full - ablated‖_F / ‖full‖_F -/ - relativeError : Float - /-- Number of components in the circuit -/ - circuitSize : Nat - /-- Total number of components in the model -/ - totalComponents : Nat - -namespace AblationResult - -/-- Compute compression ratio: what fraction of components are included. -/ -def compressionRatio (r : AblationResult) : Float := - if r.totalComponents > 0 then - r.circuitSize.toFloat / r.totalComponents.toFloat - else 1.0 - -/-- Pretty print the ablation result. -/ -def toString (r : AblationResult) : String := - s!"AblationResult:\n" ++ - s!" Empirical Error (‖Δ‖_F): {r.empiricalError}\n" ++ - s!" Relative Error: {r.relativeError * 100.0}%\n" ++ - s!" Circuit Size: {r.circuitSize}/{r.totalComponents} " ++ - s!"({r.compressionRatio * 100.0}%)" - -instance : ToString AblationResult := ⟨AblationResult.toString⟩ - -end AblationResult - -/-- Compute the empirical discrepancy between full model and ablated circuit. - -This is the core function for empirical validation: -1. 
Runs full forward pass -2. Runs ablated forward pass with the circuit mask -3. Computes the difference and its Frobenius norm - -The empirical error should be bounded by the theoretical error estimate from -`estimateCircuitFaithfulness`. --/ -def computeAblationDiscrepancy (model : ConcreteModel) (circuit : ConcreteCircuit) - (causal : Bool := true) : AblationResult := - let fullResult := model.runForward causal - let ablatedResult := model.runAblatedForward circuit causal - let diff := fullResult.finalOutput.sub ablatedResult.finalOutput - let empiricalErr := diff.frobeniusNorm - let fullNorm := fullResult.finalOutput.frobeniusNorm - let relErr := if fullNorm > 1e-10 then empiricalErr / fullNorm else 0.0 - { - fullOutput := fullResult.finalOutput - ablatedOutput := ablatedResult.finalOutput - difference := diff - empiricalError := empiricalErr - relativeError := relErr - circuitSize := circuit.countIncluded - totalComponents := circuit.totalComponents - } - -/-- Result of comparing empirical discrepancy to theoretical bound. - -This is the verification bridge between `Abstraction.lean` (theory) and -`Discovery.lean` (practice). --/ -structure VerificationResult where - /-- Ablation result with empirical measurements -/ - ablation : AblationResult - /-- Theoretical error bound from circuit analysis -/ - theoreticalBound : Float - /-- Whether empirical ≤ theoretical (verification passed) -/ - verified : Bool - /-- Slack: theoretical - empirical (how much margin we have) -/ - slack : Float - /-- Tightness ratio: empirical / theoretical -/ - tightness : Float - -namespace VerificationResult - -/-- Pretty print the verification result. -/ -def toString (r : VerificationResult) : String := - let status := if r.verified then "✓ VERIFIED" else "✗ FAILED" - s!"VerificationResult [{status}]\n" ++ - s!" Empirical Error: {r.ablation.empiricalError}\n" ++ - s!" Theoretical Bound: {r.theoreticalBound}\n" ++ - s!" Slack: {r.slack}\n" ++ - s!" 
Tightness: {r.tightness * 100.0}%\n" ++ - s!" Circuit: {r.ablation.circuitSize}/{r.ablation.totalComponents} components" - -instance : ToString VerificationResult := ⟨VerificationResult.toString⟩ - -end VerificationResult - -/-- Verify that a discovered circuit's empirical error is within theoretical bounds. - -This is the key function that **closes the loop** between theory and practice: - -**Input:** -- A model -- A discovered circuit (from `discoverCircuit` or `discoverTargetedCircuit`) -- The theoretical error bound (from `CircuitError.totalError`) - -**Output:** -- Verification result showing whether empirical ≤ theoretical - -**Usage:** -``` -let result := discoverCircuit model 0.1 -let verification := verifyCircuitFaithfulness model result.circuit result.error.totalError -if verification.verified then - IO.println "Circuit is empirically faithful!" -``` - -**Interpretation:** -- `verified = true`: The circuit recapitulates model behavior within bounds -- `tightness ≈ 1`: Theoretical bound is tight (good analysis) -- `tightness << 1`: Theoretical bound is loose (conservative) --/ -def verifyCircuitFaithfulness (model : ConcreteModel) (circuit : ConcreteCircuit) - (theoreticalBound : Float) (causal : Bool := true) : VerificationResult := - let ablation := computeAblationDiscrepancy model circuit causal - let verified := ablation.empiricalError ≤ theoreticalBound - let slack := theoreticalBound - ablation.empiricalError - let tightness := if theoreticalBound > 1e-10 - then ablation.empiricalError / theoreticalBound - else 1.0 - { - ablation := ablation - theoreticalBound := theoreticalBound - verified := verified - slack := slack - tightness := tightness - } - -/-! ### Component Importance Metrics -/ - -/-- Importance metrics for a single component. - -These metrics allow ranking components by their contribution to model behavior, -enabling principled circuit pruning. 
--/ -structure ComponentImportance where - /-- Component identifier -/ - component : ComponentId - /-- Value term norm: ‖A‖_F · ‖W_V·W_O‖_F (how much information flows through) -/ - valueTermNorm : Float - /-- Pattern term bound (approximation error if we trust attention patterns) -/ - patternTermBound : Float - /-- Faithfulness ratio: patternBound / valueNorm -/ - faithfulnessRatio : Float - -namespace ComponentImportance - -/-- Pretty print component importance. -/ -def toString (imp : ComponentImportance) : String := - s!"{imp.component}: value={imp.valueTermNorm}, pattern={imp.patternTermBound}, " ++ - s!"ratio={imp.faithfulnessRatio}" - -instance : ToString ComponentImportance := ⟨ComponentImportance.toString⟩ - -end ComponentImportance - -/-- Compute importance metrics for a single attention head. -/ -def computeHeadImportance (model : ConcreteModel) (layerIdx headIdx : Nat) - (layerInput : ConcreteMatrix) : Option ComponentImportance := do - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx] - let attnInput := model.applyLn1 layerIdx layerInput - let attn := head.computeAttentionWeights attnInput - let inputNorm := attnInput.frobeniusNorm - let inputOpBound := attnInput.opNormUpperBoundOneInf - let ln1Bound := model.ln1OpBound layerIdx layerInput - let bnds := head.noDenseProductBounds - - -- Pre-LN: effective value path includes the LayerNorm Jacobian. 
- let valueNorm := ln1Bound * computeValueTermNorm attn bnds.voFrobNormSq - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - scaleFactor := Float.sqrt head.headDim.toFloat - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternBound := ln1Bound * computePatternTermBound inputs - let ratio := if valueNorm < 1e-10 then Float.inf else patternBound / valueNorm - - return { - component := ComponentId.head layerIdx headIdx - valueTermNorm := valueNorm - patternTermBound := patternBound - faithfulnessRatio := ratio - } - else none - else none - -/-- Compute importance metrics for a single MLP neuron. - -**Simple Version (no forward pass):** -For ReLU/GeLU MLPs, this uses weight-based bounds only. -The influence magnitude = ‖W_in[:,i]‖ · ‖W_out[i,:]‖ bounds information flow. - -Pattern term is set to 0 (assumes locally linear), which is **unsound** if -ablations cause activation flips. Use `computeNeuronImportanceIBP` with -forward pass data for rigorous bounds. --/ -def computeNeuronImportance (model : ConcreteModel) (layerIdx neuronIdx : Nat) - (_inputNorm : Float) : Option ComponentImportance := - if h : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - if neuronIdx < mlp.hiddenDim then - let influence := mlp.neuronInfluence neuronIdx - -- Conservative assumption: locally linear (pattern term = 0) - -- WARNING: This is unsound if activation flips occur! 
- let patternBound : Float := 0.0 - let ratio := if influence < 1e-10 then Float.inf else patternBound / influence - - some { - component := ComponentId.mlpNeuron layerIdx neuronIdx - valueTermNorm := influence - patternTermBound := patternBound - faithfulnessRatio := ratio - } - else none - else none - -/-- Compute importance metrics for a single MLP neuron using IBP. - -**Sound Version (requires forward pass):** -Uses Interval Bound Propagation to detect neurons that may flip activation -states under input perturbations. Provides mathematically rigorous pattern -term bounds. - -**Parameters:** -- `layerInput`: Input to this layer (from forward pass), used to compute - nominal pre-activations -- `perturbationNorm`: L2 bound on how much ablations can change the input - (typically computed from the ablated components' value terms) - -**Returns:** ComponentImportance with rigorous pattern term bound that -accounts for potential activation flips. --/ -def computeNeuronImportanceIBP (model : ConcreteModel) (layerIdx neuronIdx : Nat) - (layerInput : ConcreteMatrix) (perturbationNorm : Float) : Option ComponentImportance := - if h : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - if neuronIdx < mlp.hiddenDim then - let influence := mlp.neuronInfluence neuronIdx - - -- Average the IBP bound across sequence positions - let patternBound := Id.run do - let mut totalPatternBound : Float := 0.0 - let numPositions := layerInput.numRows - - for pos in [:numPositions] do - -- Extract input vector at this position - let inputVec : Array Float := .ofFn fun d : Fin layerInput.numCols => - layerInput.getUnsafe pos d.val - -- Compute IBP-based pattern term bound - let posBound := mlp.neuronPatternTermBoundIBP neuronIdx inputVec perturbationNorm - totalPatternBound := totalPatternBound + posBound * posBound - - -- RMS of per-position bounds - Float.sqrt (totalPatternBound / (max 1 numPositions).toFloat) - - let ratio := if influence < 1e-10 then Float.inf else 
patternBound / influence - - some { - component := ComponentId.mlpNeuron layerIdx neuronIdx - valueTermNorm := influence - patternTermBound := patternBound - faithfulnessRatio := ratio - } - else none - else none - -/-- Compute importance metrics for all components in a model. -/ -def computeAllImportance (model : ConcreteModel) : Array ComponentImportance := Id.run do - let inputNorm := computeInputNorm model.inputEmbeddings - - let computeHeadsForLayer (l : Nat) : Array ComponentImportance := Id.run do - if h : l < model.layers.size then - let layerHeads := model.layers[l] - let mut outs : Array ComponentImportance := Array.mkEmpty layerHeads.size - for h_idx in [:layerHeads.size] do - match computeHeadImportance model l h_idx model.inputEmbeddings with - | some imp => outs := outs.push imp - | none => pure () - return outs - else - return #[] - - let computeNeuronsForLayer (l : Nat) : Array ComponentImportance := Id.run do - if h : l < model.mlps.size then - let mlp := model.mlps[l] - let mut outs : Array ComponentImportance := Array.mkEmpty mlp.hiddenDim - for n_idx in [:mlp.hiddenDim] do - match computeNeuronImportance model l n_idx inputNorm with - | some imp => outs := outs.push imp - | none => pure () - return outs - else - return #[] - - let useParallel := model.numLayers >= 4 - let layerPairs : Array (Array ComponentImportance × Array ComponentImportance) := - if useParallel then - let tasks : Array (Task (Array ComponentImportance × Array ComponentImportance)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => (computeHeadsForLayer i.val, computeNeuronsForLayer i.val)) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - (computeHeadsForLayer i.val, computeNeuronsForLayer i.val) - - let headChunks : Array (Array ComponentImportance) := layerPairs.map (·.1) - let neuronChunks : Array (Array ComponentImportance) := layerPairs.map (·.2) - - -- Join in the same order as the original loop: heads then neurons, increasing layer 
index. - let totalHeads := sumSizes headChunks - let totalNeurons := sumSizes neuronChunks - let mut result : Array ComponentImportance := Array.mkEmpty (totalHeads + totalNeurons) - for cs in headChunks do - for c in cs do - result := result.push c - for cs in neuronChunks do - for c in cs do - result := result.push c - result - -/-! ### Target-Aware Circuit Discovery (Logit Lens) - -Standard circuit discovery finds components that contribute to the *entire* residual stream, -yielding large generic circuits. **Target-aware discovery** instead finds minimal circuits -for *specific predictions* by projecting component outputs onto a target direction. - -**Key Insight:** For a specific prediction (e.g., "the model predicts 'cat' not 'dog'"), -we define a target direction `u = W_U[cat] - W_U[dog]` in the residual stream space. -A component's importance is then `‖output · u‖` rather than `‖output‖_F`. - -This enables: -- Finding circuits responsible for specific token predictions -- Isolating mechanisms for behavioral differences (e.g., IOI task) -- Producing smaller, more interpretable circuits - -**Mathematical Formulation:** -For attention head with value-output projection `W_V · W_O`: -- Standard importance: `‖A‖_F · ‖W_V · W_O‖_F` (generic) -- Target-aware: `‖A‖_F · ‖(W_V · W_O) · u‖` (specific) - -For MLP neuron with output weights `W_out[i,:]`: -- Standard importance: `‖W_in[:,i]‖ · ‖W_out[i,:]‖` (generic) -- Target-aware: `‖W_in[:,i]‖ · |W_out[i,:] · u|` (specific) --/ - -/-- A target direction for focused circuit discovery. - -Specifies a direction in residual stream space to project component outputs onto. -Typically constructed as `u = W_U[correct_token] - W_U[incorrect_token]`. 
--/ -structure TargetDirection where - /-- The target vector in model dimension space (modelDim × 1 matrix) -/ - direction : ConcreteMatrix - /-- Human-readable description of what this direction represents -/ - description : String := "target" - -namespace TargetDirection - -/-- Create a target direction from unembedding columns for two tokens. - -`u = W_U[:, correctToken] - W_U[:, incorrectToken]` - -This direction points from the incorrect prediction toward the correct one, -so components with positive projection increase P(correct) / P(incorrect). --/ -def fromLogitDiff (unembedding : ConcreteMatrix) - (correctToken incorrectToken : Nat) : TargetDirection := - let correctCol := unembedding.getCol correctToken - let incorrectCol := unembedding.getCol incorrectToken - let direction := correctCol.vecSub incorrectCol - { - direction := direction - description := s!"logit_diff({correctToken}-{incorrectToken})" - } - -/-- Create a target direction from a single token's unembedding. - -Useful when you want to understand what promotes a specific token. --/ -def fromSingleToken (unembedding : ConcreteMatrix) (token : Nat) : TargetDirection := - { - direction := unembedding.getCol token - description := s!"logit({token})" - } - -/-- Normalize the target direction to unit length. -/ -def normalize (t : TargetDirection) : TargetDirection := - let norm := t.direction.vecNorm - if norm > 1e-10 then - { t with direction := t.direction.scale (1.0 / norm) } - else t - -/-- Construct a next-token logit-difference direction from the model's input token history. - -This is the **self-supervised induction target**: -let `T` be the ground-truth token sequence, and let `t_curr = T[last]`. -If `t_curr` appeared before at index `k`, the "induction target" is `t_next = T[k+1]`. - -Returns `none` if: -- the model has no `inputTokens`, -- the sequence has no previous occurrence of `t_curr`, -- or the model is missing an `unembedding` matrix. 
--/ -def fromInductionHistory (model : ConcreteModel) : Option TargetDirection := do - let tokens ← model.inputTokens - if tokens.size = 0 then none else - let lastIdx := tokens.size - 1 - let tCurr := tokens[lastIdx]! - let mut foundIdx : Option Nat := none - for offset in [:lastIdx] do - if foundIdx.isNone then - let idx := lastIdx - 1 - offset - if tokens[idx]! = tCurr then - foundIdx := some idx - - let k ← foundIdx - let tNext := tokens[k + 1]! - - let W_U ← model.unembedding - let vocabSize := W_U.numCols - if vocabSize < 2 then none - else if tNext ≥ vocabSize then none - else - let incorrect : Nat := - if tCurr < vocabSize ∧ tCurr ≠ tNext then tCurr - else - let cand1 := (tNext + 1) % vocabSize - if cand1 ≠ tNext then cand1 else (tNext + 2) % vocabSize - if incorrect = tNext then none - else - let base := TargetDirection.fromLogitDiff W_U tNext incorrect - some { base with - description := s!"induction_history(curr={tCurr}, prev={k}, \ - next={tNext}, neg={incorrect})" - } - -end TargetDirection - -/-! ## Virtual-Head Effectiveness Verification -/ - -/-- Extremely generous cutoff to reject only egregious **interpretability illusions**. - -In theory, a mechanism is "genuine" when its relative approximation error is < 1.0 -(`isGenuineMechanism` / `mechanism_trichotomy` in `Nfp.Linearization`). In practice, the -executable Frobenius-norm bounds can be loose in high dimensions, making the strict < 1.0 -test numerically vacuous. Empirically, however, massive errors indicate clear illusions. - -We therefore filter only astronomically large `combinedError` values while ranking by -faithfulness (smallest `combinedError` first). --/ -def egregiousIllusionThreshold : Float := 1.0e30 - -/-- Egregious-illusion filter (currently disabled). - -We intentionally keep *all* candidates (even those with extremely loose bounds), since -the Frobenius-norm estimates can scale poorly with depth and dimension. 
--/ -def passesEgregiousIllusionFilter (_candidate : CandidateInductionHead) : Bool := - true - -/-- Compute the raw **direct** effectiveness score `δ` for an induction-head candidate. - -Induction heads are primarily a **pattern** story (Q/K-composition): a "previous token" head -enables the *induction head* to attend to the **successor** of a previous matching token. -Once the induction head is attending to the right source positions, its functional effect is -driven by the head's **OV circuit** applied to the residual stream at those source tokens. - -Accordingly, we treat head 1 purely as a *pattern enabler* (enforced by the pattern filters) -and score only the **direct value path** of head 2: - -For a **Pre-LayerNorm** transformer (GPT-2 style), attention reads from `ln₁(X₂)` (not `X₂`). -We compute: - -`Y = A₂ · ln₁(X₂) · W₂`, - -then score the last position against `target.direction = u`: - -`δ = ⟪Y[last], u⟫`. --/ -def computeInductionEffectiveness - (candidate : CandidateInductionHead) - (cache : PrecomputedCache) - (layer2Ln1Input : ConcreteMatrix) - (target : TargetDirection) : Float := - match cache.getHeadData candidate.layer2Idx candidate.head2Idx with - | some data2 => - let u := target.direction - - -- Direct head score: δ = ⟪(A₂ · ln₁(X₂) · W₂)[last], u⟫. - -- - -- PERFORMANCE: compute this scalar without materializing any dense `d×d` products. 
- -- 1) v = (W_V·W_O) · u = W_V · (W_O · u) - -- 2) xdot[k] = ⟪ln₁(X₂)[k], v⟫ - -- 3) δ = ⟪A₂[last], xdot⟫ - if u.numCols ≠ 1 then 0.0 - else if data2.attention.seqLen = 0 then 0.0 - else if layer2Ln1Input.numRows ≠ data2.attention.seqLen then 0.0 - else - if hl : candidate.layer2Idx < cache.model.layers.size then - let layerHeads := cache.model.layers[candidate.layer2Idx]'hl - if hh : candidate.head2Idx < layerHeads.size then - let head2 := layerHeads[candidate.head2Idx]'hh - if layer2Ln1Input.numCols ≠ u.numRows then 0.0 - else - let n := data2.attention.seqLen - let lastPos := n - 1 - - -- v = W_V · (W_O · u) - let t := head2.W_O.matVecMul u - let v := head2.W_V.matVecMul t - if v.numRows = 0 then 0.0 - else - -- xdot[k] = ⟪ln₁(X₂)[k], v⟫ - let xdot : Array Float := .ofFn fun k : Fin n => Id.run do - let xBase := k.val * layer2Ln1Input.numCols - let mut acc : Float := 0.0 - for c in [:layer2Ln1Input.numCols] do - -- SAFETY: `k < n = layer2Ln1Input.numRows` by construction and guard above. - let x := layer2Ln1Input.data[xBase + c]! - -- SAFETY: `v` is `modelDim×1` so `c < v.data.size`. - let vc := v.data[c]! - acc := acc + x * vc - return acc - - -- δ = ⟪A2[last], xdot⟫ - Id.run do - let w2 := data2.attention.weights - let row2Base := lastPos * n - let mut score : Float := 0.0 - for k in [:n] do - -- SAFETY: `w2` has size `n*n` and `row2Base + k < n*n`. - score := score + w2[row2Base + k]! * xdot[k]! - return score - else - 0.0 - else - 0.0 - | none => 0.0 - -/-- Compute the **certified lower bound** from `true_induction_head_predicts_logits`. - -`LowerBound = δ - (ε · ‖X‖_F · ‖u‖₂)` where: -- `δ` is the virtual effectiveness score, -- `ε` is `candidate.combinedError`, -- `X` is the layer-2 Pre-LN attention input matrix `ln₁(X₂)`, -- `u` is the target direction. 
--/ -def computeCertifiedLowerBound - (delta : Float) - (candidate : CandidateInductionHead) - (layer2Ln1Input : ConcreteMatrix) - (target : TargetDirection) : Float := - delta - (candidate.combinedError * layer2Ln1Input.frobeniusNorm * target.direction.vecNorm) - -/-- Rank induction-head candidates by a **mechanism-first** score. - -We compute the raw direct-effect score `δ`, normalize it to a scale-invariant `effect`, -and also compute a prompt/weight-based mechanism score: - -`mechScore = kComp · inductionScore · prevTokenStrength`, - -as well as the combined score: - -`circuitScore = effect · mechScore`. - -Here: -- `effect = δ / (‖ln₁(X₂)‖_F · ‖u‖₂)` removes the Pre-LN depth confounder where the residual - stream norm grows with depth, and -- `kComp` (from the circuits framework paper) measures how strongly head 1 can feed into - head 2's **QK circuit**, i.e. whether head 1 plausibly acts as a *pattern enabler* for head 2. -- `inductionScore` is the prompt-dependent "copy-next" attention score for head 2, and -- `prevTokenStrength` is the prompt-dependent previous-token attention score for head 1. - -We rank primarily by `mechScore` (to identify the canonical induction mechanism) and use -`effect` only as a secondary key. - -We still compute and report `combinedError` for inspection, but avoid using it as a primary -ranking key since Frobenius-norm bounds can be systematically looser in high dimensions. - -Uses a `PrecomputedCache` so attention patterns/projections and layer inputs are computed once. 
--/ -def findHeuristicInductionHeadsWithCache (model : ConcreteModel) - (target : TargetDirection) - (minEffect : Float := 0.0) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) - (buildLayerNormBounds : Bool := true) - (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) - (storeDiagnostics : Bool := false) : - (Array HeuristicInductionHead × PrecomputedCache) := - Id.run do - let cache := PrecomputedCache.build model (computeLayerNormBounds := buildLayerNormBounds) - (layerNormEffort := layerNormEffort) - (storeDiagnostics := storeDiagnostics) - let targetNorm := target.direction.vecNorm - -- Precompute layer input norms once (per-layer, not per-candidate). - let layerInputNorms : Array Float := - .ofFn fun i : Fin model.numLayers => - (cache.forwardResult.getLayerInput i.val).frobeniusNorm - let ln1InputNorms : Array Float := - .ofFn fun i : Fin model.numLayers => - (cache.getLn1Input i.val).frobeniusNorm - let candidates := - findInductionHeadCandidatesFromCache cache minPrevTokenStrength minInductionScore - let mut certified : Array HeuristicInductionHead := #[] - - for candidate in candidates do - if passesEgregiousIllusionFilter candidate then - let layer2InputNorm := layerInputNorms.getD candidate.layer2Idx 0.0 - let layer2Ln1Input := cache.getLn1Input candidate.layer2Idx - let layer2Ln1InputNorm := ln1InputNorms.getD candidate.layer2Idx 0.0 - let delta := computeInductionEffectiveness candidate cache layer2Ln1Input target - let denom := layer2Ln1InputNorm * targetNorm - let effect := - if denom > 1e-10 then delta / denom else 0.0 - if effect > minEffect then - certified := certified.push { - candidate := candidate - delta := delta - effect := effect - layer2InputNorm := layer2InputNorm - layer2Ln1InputNorm := layer2Ln1InputNorm - } - - let certifiedSorted := - certified.qsort (fun a b => - -- Primary key: higher **mechanism score** first. 
- -- - -- Induction heads are primarily defined by attention-pattern structure (copy-next) - -- plus K-composition with a previous-token head. Target-direction Effect is useful, - -- but prompt/target-dependent; we therefore use it only as a secondary key. - let sa := - if Float.isNaN a.effect ∨ Float.isNaN a.candidate.kComp ∨ - Float.isNaN a.candidate.inductionScore ∨ - Float.isNaN a.candidate.prevTokenStrength then - (-Float.inf) - else - a.candidate.kComp * a.candidate.inductionScore * a.candidate.prevTokenStrength - let sb := - if Float.isNaN b.effect ∨ Float.isNaN b.candidate.kComp ∨ - Float.isNaN b.candidate.inductionScore ∨ - Float.isNaN b.candidate.prevTokenStrength then - (-Float.inf) - else - b.candidate.kComp * b.candidate.inductionScore * b.candidate.prevTokenStrength - if sb < sa then true - else if sa < sb then false - else - -- Secondary key: higher normalized Effect first. - let ea := if Float.isNaN a.effect then (-Float.inf) else a.effect - let eb := if Float.isNaN b.effect then (-Float.inf) else b.effect - if eb < ea then true - else if ea < eb then false - else - -- Tertiary key: higher K-composition first. - let ka := if Float.isNaN a.candidate.kComp then (-Float.inf) else a.candidate.kComp - let kb := if Float.isNaN b.candidate.kComp then (-Float.inf) else b.candidate.kComp - if kb < ka then true - else if ka < kb then false - else - -- Next key: higher raw δ first. - let δa := if Float.isNaN a.delta then (-Float.inf) else a.delta - let δb := if Float.isNaN b.delta then (-Float.inf) else b.delta - if δb < δa then true - else if δa < δb then false - else - -- Final key: smaller relative-error bound first. - let εa := - if Float.isNaN a.candidate.combinedError then - Float.inf - else - a.candidate.combinedError - let εb := - if Float.isNaN b.candidate.combinedError then - Float.inf - else - b.candidate.combinedError - εa < εb) - - return (certifiedSorted, cache) - -/-- Rank induction-head candidates by a **mechanism-first** score. 
- -This is the same as `findHeuristicInductionHeadsWithCache`, but discards the cache. --/ -def findHeuristicInductionHeads (model : ConcreteModel) - (target : TargetDirection) - (minEffect : Float := 0.0) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array HeuristicInductionHead := Id.run do - (findHeuristicInductionHeadsWithCache model target minEffect - (minPrevTokenStrength := minPrevTokenStrength) - (minInductionScore := minInductionScore)).1 - -/-- Target-aware importance metrics for a component. - -Like `ComponentImportance` but with an additional field measuring -projection onto the target direction. --/ -structure TargetAwareImportance where - /-- Component identifier -/ - component : ComponentId - /-- Standard value term norm (generic importance) -/ - valueTermNorm : Float - /-- Pattern term bound -/ - patternTermBound : Float - /-- **Target projection**: how much this component contributes to the target direction. - For heads: `‖(W_V · W_O) · u‖` - For neurons: `|W_out[i,:] · u|` -/ - targetProjection : Float - /-- Faithfulness ratio for target: patternBound / targetProjection -/ - targetFaithfulnessRatio : Float - -namespace TargetAwareImportance - -/-- Pretty print target-aware importance. -/ -def toString (imp : TargetAwareImportance) : String := - s!"{imp.component}: target={imp.targetProjection}, generic={imp.valueTermNorm}, " ++ - s!"pattern={imp.patternTermBound}, ratio={imp.targetFaithfulnessRatio}" - -instance : ToString TargetAwareImportance := ⟨TargetAwareImportance.toString⟩ - -end TargetAwareImportance - -/-- Compute target-aware importance for a single attention head. - -The target projection is computed as `‖(W_V · W_O) · u‖` where u is the target direction. -This measures how much the head's output aligns with the target direction. 
--/ -def computeHeadTargetImportance (model : ConcreteModel) (layerIdx headIdx : Nat) - (inputNorm : Float) (target : TargetDirection) : Option TargetAwareImportance := do - let attn ← model.computeAttention layerIdx headIdx - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx] - let bnds := head.noDenseProductBounds - - -- Standard metrics - let valueNorm := computeValueTermNorm attn bnds.voFrobNormSq - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - -- Conservative: `‖X‖₂ ≤ ‖X‖_F`. - inputOpBound := inputNorm - scaleFactor := Float.sqrt head.headDim.toFloat - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternBound := computePatternTermBound inputs - - -- Target-aware: compute ‖(W_V · W_O) · u‖ via low-rank matvecs: - -- (W_V·W_O)·u = W_V·(W_O·u). - let t := head.W_O.matVecMul target.direction - let projectedVec := head.W_V.matVecMul t - let targetProj := projectedVec.vecNorm - - -- Scale by attention norm (as in standard valueTermNorm) - let attnNormSq := sumSquares attn.weights - let attnNorm := Float.sqrt attnNormSq - let targetImportance := attnNorm * targetProj - - let ratio := if targetImportance < 1e-10 then Float.inf else patternBound / targetImportance - - return { - component := ComponentId.head layerIdx headIdx - valueTermNorm := valueNorm - patternTermBound := patternBound - targetProjection := targetImportance - targetFaithfulnessRatio := ratio - } - else none - else none - -/-- Compute target-aware importance for a single MLP neuron. - -The target projection is `|W_out[i,:] · u|` - the absolute dot product of the -neuron's output weights with the target direction. 
--/ -def computeNeuronTargetImportance (model : ConcreteModel) (layerIdx neuronIdx : Nat) - (_inputNorm : Float) (target : TargetDirection) : Option TargetAwareImportance := - if h : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - if neuronIdx < mlp.hiddenDim then - -- Standard influence - let inputNormVal := mlp.neuronInputNorm neuronIdx - let outputNormVal := mlp.neuronOutputNorm neuronIdx - let influence := inputNormVal * outputNormVal - - -- Target-aware: compute W_out[i,:] · u - -- Get row i of W_out as a column vector for dot product - let outputWeights : ConcreteMatrix := { - numRows := mlp.modelDim - numCols := 1 - data := .ofFn fun j : Fin mlp.modelDim => - mlp.W_out.data[neuronIdx * mlp.modelDim + j.val]! - size_eq := by simp - } - let dotProd := outputWeights.dot target.direction - let targetProj := inputNormVal * Float.abs dotProd - - let patternBound : Float := 0.0 -- ReLU is locally linear - let ratio := if targetProj < 1e-10 then Float.inf else patternBound / targetProj - - some { - component := ComponentId.mlpNeuron layerIdx neuronIdx - valueTermNorm := influence - patternTermBound := patternBound - targetProjection := targetProj - targetFaithfulnessRatio := ratio - } - else none - else none - -/-- Compute target-aware importance for all components in a model. 
-/ -def computeAllTargetImportance (model : ConcreteModel) - (target : TargetDirection) : Array TargetAwareImportance := Id.run do - let inputNorm := computeInputNorm model.inputEmbeddings - - let computeHeadsForLayer (l : Nat) : Array TargetAwareImportance := Id.run do - if h : l < model.layers.size then - let layerHeads := model.layers[l] - let mut outs : Array TargetAwareImportance := Array.mkEmpty layerHeads.size - for h_idx in [:layerHeads.size] do - match computeHeadTargetImportance model l h_idx inputNorm target with - | some imp => outs := outs.push imp - | none => pure () - return outs - else - return #[] - - let computeNeuronsForLayer (l : Nat) : Array TargetAwareImportance := Id.run do - if h : l < model.mlps.size then - let mlp := model.mlps[l] - let mut outs : Array TargetAwareImportance := Array.mkEmpty mlp.hiddenDim - for n_idx in [:mlp.hiddenDim] do - match computeNeuronTargetImportance model l n_idx inputNorm target with - | some imp => outs := outs.push imp - | none => pure () - return outs - else - return #[] - - let useParallel := model.numLayers >= 4 - let headChunks : Array (Array TargetAwareImportance) := - if useParallel then - let tasks : Array (Task (Array TargetAwareImportance)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeHeadsForLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeHeadsForLayer i.val - - let neuronChunks : Array (Array TargetAwareImportance) := - if useParallel then - let tasks : Array (Task (Array TargetAwareImportance)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeNeuronsForLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeNeuronsForLayer i.val - - -- Join in the same order as the original loop: heads then neurons, increasing layer index. 
- let totalHeads := sumSizes headChunks - let totalNeurons := sumSizes neuronChunks - let mut result : Array TargetAwareImportance := Array.mkEmpty (totalHeads + totalNeurons) - for cs in headChunks do - for c in cs do - result := result.push c - for cs in neuronChunks do - for c in cs do - result := result.push c - result - -/-! ### Circuit Faithfulness Estimation -/ - -/-- Error breakdown for a circuit. - -Total error has two components: -1. **Pattern Term Error**: Approximation error from trusting attention patterns -2. **Ablation Error**: Information loss from pruned components - -Total Error ≤ PatternTermError + AblationError --/ -structure CircuitError where - /-- Sum of pattern term bounds for included components -/ - patternTermError : Float - /-- Sum of value term norms for excluded (ablated) components -/ - ablationError : Float - /-- Combined error bound -/ - totalError : Float - /-- Number of included components -/ - includedCount : Nat - /-- Number of excluded components -/ - excludedCount : Nat - deriving Repr - -namespace CircuitError - -/-- Pretty print error breakdown. -/ -def toString (err : CircuitError) : String := - s!"CircuitError(total={err.totalError}, pattern={err.patternTermError}, " ++ - s!"ablation={err.ablationError}, included={err.includedCount}, excluded={err.excludedCount})" - -instance : ToString CircuitError := ⟨CircuitError.toString⟩ - -end CircuitError - -/-! ### N-Layer Faithfulness Verification - -This section implements the N-layer error amplification formula from `Linearization.lean` -for concrete Float matrices. The key insight is that errors in early layers get -amplified as they propagate through subsequent layers: - - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -where: -- εᵢ = pattern term bound for layer i (interpretation error) -- Cⱼ = operator norm bound for layer j's residual Jacobian (layerJacobian - I) --/ - -/-- Per-layer error metrics for deep circuit analysis. 
-/ -structure LayerErrorMetrics where - /-- Layer index -/ - layerIdx : Nat - /-- Pattern term bound εᵢ (faithfulness error before amplification) -/ - patternTermBound : Float - /-- Operator norm upper bound Cᵢ for layerJacobian - I (residual part). -/ - operatorNormUb : Float - /-- Suffix amplification: ∏_{j>i} (1 + Cⱼ) -/ - suffixAmplification : Float - /-- Amplified error contribution: εᵢ · suffixAmplification(i+1) -/ - amplifiedError : Float - -namespace LayerErrorMetrics - -def toString (m : LayerErrorMetrics) : String := - s!"Layer {m.layerIdx}: ε={m.patternTermBound}, C_ub={m.operatorNormUb}, " ++ - s!"amp={m.suffixAmplification}, contrib={m.amplifiedError}" - -instance : ToString LayerErrorMetrics := ⟨toString⟩ - -end LayerErrorMetrics - -/-- Deep circuit verification result with rigorous N-layer error bounds. -/ -structure DeepCircuitVerification where - /-- Per-layer error breakdown -/ - layerMetrics : Array LayerErrorMetrics - /-- Total error bound: Σᵢ εᵢ · suffixAmplification(i+1) -/ - totalAmplifiedError : Float - /-- Simple sum error (no amplification): Σᵢ εᵢ -/ - simpleErrorSum : Float - /-- Total amplification factor: ∏ᵢ (1 + Cᵢ) -/ - totalAmplificationFactor : Float - /-- Ablation error from excluded components -/ - ablationError : Float - /-- Combined error bound (amplified + ablation) -/ - combinedError : Float - /-- Number of layers analyzed -/ - numLayers : Nat - -namespace DeepCircuitVerification - -def toString (v : DeepCircuitVerification) : String := - s!"DeepCircuitVerification:\n" ++ - s!" Layers: {v.numLayers}\n" ++ - s!" Total Amplified Error: {v.totalAmplifiedError}\n" ++ - s!" Simple Error Sum: {v.simpleErrorSum}\n" ++ - s!" Amplification Factor: {v.totalAmplificationFactor}\n" ++ - s!" Ablation Error: {v.ablationError}\n" ++ - s!" Combined Error: {v.combinedError}" - -instance : ToString DeepCircuitVerification := ⟨toString⟩ - -end DeepCircuitVerification - -/-- Estimate pattern term bound for a single layer (all heads combined). 
- -Aggregates pattern term bounds across all attention heads in the layer. --/ -def estimateLayerPatternBound (model : ConcreteModel) (fwdResult : ForwardPassResult) - (layerIdx : Nat) (circuit : ConcreteCircuit) : Float := Id.run do - if h : layerIdx < model.layers.size then - let heads := model.layers[layerIdx] - let layerInput := fwdResult.getLayerInput layerIdx - let attnInput := model.applyLn1 layerIdx layerInput - let inputNorm := attnInput.frobeniusNorm - let inputOpBound := attnInput.opNormUpperBoundOneInf - let ln1Bound := model.ln1OpBound layerIdx layerInput - let mut totalBound : Float := 0.0 - - for hidx in [:heads.size] do - if hh : hidx < heads.size then - -- Only count included heads - if circuit.isHeadIncluded layerIdx hidx then - let head := heads[hidx] - let attn := head.computeAttentionWeights attnInput - let bnds := head.noDenseProductBounds - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - scaleFactor := Float.sqrt head.headDim.toFloat - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - -- Pre-LN: pattern sensitivity is scaled by the LayerNorm Jacobian. - totalBound := totalBound + ln1Bound * computePatternTermBound inputs - - totalBound - else - return 0.0 - -/-- Estimate ablation error for excluded components at a single layer. 
-/ -def estimateLayerAblationError (model : ConcreteModel) (fwdResult : ForwardPassResult) - (layerIdx : Nat) (circuit : ConcreteCircuit) : Float := Id.run do - let layerInput := fwdResult.getLayerInput layerIdx - let mut totalError : Float := 0.0 - - -- Ablation error from excluded attention heads - if h : layerIdx < model.layers.size then - let heads := model.layers[layerIdx] - for hidx in [:heads.size] do - if hh : hidx < heads.size then - if !circuit.isHeadIncluded layerIdx hidx then - match computeHeadImportance model layerIdx hidx layerInput with - | some imp => totalError := totalError + imp.valueTermNorm - | none => pure () - - -- Ablation error from excluded neurons - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - for nidx in [:mlp.hiddenDim] do - if !circuit.isNeuronIncluded layerIdx nidx then - match computeNeuronImportance model layerIdx nidx (layerInput.frobeniusNorm) with - | some imp => totalError := totalError + imp.valueTermNorm - | none => pure () - - totalError - -/-- Verify a deep circuit using rigorous N-layer error amplification bounds. - -This is the main function that bridges theoretical bounds to practical verification. -It computes: -1. Per-layer pattern term bounds (εᵢ) -2. Per-layer residual operator norm bounds (Cᵢ) -3. Suffix amplification factors -4. Total amplified error using the N-layer composition formula - -**The N-Layer Formula:** - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -This captures that early layer errors get amplified more because they pass -through more subsequent layers. 
--/ -def verifyDeepCircuit (model : ConcreteModel) - (circuit : ConcreteCircuit) : DeepCircuitVerification := Id.run do - -- Run forward pass to get layer-wise inputs - let fwdResult := model.runForward - - -- Step 1: Compute per-layer residual operator norm bounds - let mut normBounds : Array Float := #[] - for l in [:model.numLayers] do - let norm := estimateAttentionLayerNorm model fwdResult l true - normBounds := normBounds.push norm - - -- Step 2: Compute per-layer pattern term bounds and ablation errors - let mut patternBounds : Array Float := #[] - let mut ablationErrors : Array Float := #[] - for l in [:model.numLayers] do - let pattern := estimateLayerPatternBound model fwdResult l circuit - let ablation := estimateLayerAblationError model fwdResult l circuit - patternBounds := patternBounds.push pattern - ablationErrors := ablationErrors.push ablation - - -- Step 3: Compute suffix amplification and amplified errors per layer - let mut layerMetrics : Array LayerErrorMetrics := #[] - let mut simpleSum : Float := 0.0 - let mut totalAmplified : Float := 0.0 - - for l in [:model.numLayers] do - if hl : l < patternBounds.size then - let epsilon := patternBounds[l] - let normBound := if hn : l < normBounds.size then normBounds[l] else 0.0 - let suffix := computeSuffixAmplification normBounds (l + 1) - let amplified := epsilon * suffix - - simpleSum := simpleSum + epsilon - totalAmplified := totalAmplified + amplified - - layerMetrics := layerMetrics.push { - layerIdx := l - patternTermBound := epsilon - operatorNormUb := normBound - suffixAmplification := suffix - amplifiedError := amplified - } - - -- Step 4: Compute total ablation error - let totalAblation := sumFloatArray ablationErrors - - -- Step 5: Compute total amplification factor - let totalAmpFactor := computeSuffixAmplification normBounds 0 - - { - layerMetrics := layerMetrics - totalAmplifiedError := totalAmplified - simpleErrorSum := simpleSum - totalAmplificationFactor := totalAmpFactor - 
ablationError := totalAblation - combinedError := totalAmplified + totalAblation - numLayers := model.numLayers - } - -/-- Check if a deep circuit meets a certification threshold. -/ -def isDeepCircuitCertified (verification : DeepCircuitVerification) - (threshold : Float) : Bool := - verification.combinedError ≤ threshold - -/-- Structure for a verified deep circuit with certification. -/ -structure VerifiedDeepCircuit where - /-- The circuit that was verified -/ - circuit : ConcreteCircuit - /-- Full verification details -/ - verification : DeepCircuitVerification - /-- The threshold used -/ - threshold : Float - /-- Whether it passed certification -/ - certified : Bool - -namespace VerifiedDeepCircuit - -def toString (v : VerifiedDeepCircuit) : String := - let status := if v.certified then "✓ CERTIFIED" else "✗ NOT CERTIFIED" - s!"{status} (threshold={v.threshold})\n{v.verification}" - -instance : ToString VerifiedDeepCircuit := ⟨toString⟩ - -end VerifiedDeepCircuit - -/-- Verify and certify a deep circuit against a threshold. -/ -def certifyDeepCircuit (model : ConcreteModel) (circuit : ConcreteCircuit) - (threshold : Float) : VerifiedDeepCircuit := - let verification := verifyDeepCircuit model circuit - { - circuit := circuit - verification := verification - threshold := threshold - certified := isDeepCircuitCertified verification threshold - } - -/-- Estimate the faithfulness error for a given circuit mask. - -This is the core function that enables circuit discovery without forward passes. -It uses only weight matrices and attention patterns to bound the error. 
- -**Error Model:** -- For **included** components: we incur the pattern term error (approximation) -- For **excluded** components: we incur the value term error (ablation) - -Total Error = Σ(included) patternBound + Σ(excluded) valueNorm --/ -def estimateCircuitFaithfulness (model : ConcreteModel) - (circuit : ConcreteCircuit) : CircuitError := Id.run do - let inputNorm := computeInputNorm model.inputEmbeddings - let mut patternError : Float := 0.0 - let mut ablationError : Float := 0.0 - let mut includedCount : Nat := 0 - let mut excludedCount : Nat := 0 - - -- Process attention heads - for l in [:model.numLayers] do - if h : l < model.layers.size then - let layerHeads := model.layers[l] - for h_idx in [:layerHeads.size] do - let included := circuit.isHeadIncluded l h_idx - match computeHeadImportance model l h_idx model.inputEmbeddings with - | some imp => - if included then - patternError := patternError + imp.patternTermBound - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.valueTermNorm - excludedCount := excludedCount + 1 - | none => pure () - - -- Process MLP neurons - for l in [:model.numLayers] do - if h : l < model.mlps.size then - let mlp := model.mlps[l] - for n_idx in [:mlp.hiddenDim] do - let included := circuit.isNeuronIncluded l n_idx - match computeNeuronImportance model l n_idx inputNorm with - | some imp => - if included then - patternError := patternError + imp.patternTermBound - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.valueTermNorm - excludedCount := excludedCount + 1 - | none => pure () - - { - patternTermError := patternError - ablationError := ablationError - totalError := patternError + ablationError - includedCount := includedCount - excludedCount := excludedCount - } - -/-- Extended circuit error with IBP analysis details. 
-/ -structure CircuitErrorIBP extends CircuitError where - /-- Total number of unstable neurons detected -/ - unstableNeuronCount : Nat - /-- Pattern error contribution from unstable MLP neurons -/ - mlpInstabilityError : Float - /-- Per-layer MLP stability ratios -/ - layerStabilityRatios : Array Float - deriving Repr - -namespace CircuitErrorIBP - -def toString (e : CircuitErrorIBP) : String := - s!"CircuitErrorIBP: pattern={e.patternTermError}, ablation={e.ablationError}, " ++ - s!"total={e.totalError}, unstable_neurons={e.unstableNeuronCount}, " ++ - s!"mlp_instability={e.mlpInstabilityError}" - -instance : ToString CircuitErrorIBP := ⟨toString⟩ - -end CircuitErrorIBP - -/-- Estimate circuit faithfulness with Interval Bound Propagation for MLPs. - -This is the **sound** version of circuit faithfulness estimation that properly -accounts for MLP activation instability. It runs a forward pass to get layer -inputs, then uses IBP to bound the pattern term error for neurons that may -flip activation states. - -**Key Improvement over `estimateCircuitFaithfulness`:** -- Standard version assumes MLP pattern term = 0 (unsound) -- IBP version detects unstable neurons and computes rigorous error bounds - -**Algorithm:** -1. Run forward pass to get layer inputs -2. For each layer, compute the ablation perturbation norm (sum of excluded - component value terms up to this layer) -3. For each MLP neuron, use IBP to determine if it's stable under this - perturbation -4. 
Unstable neurons contribute their IBP pattern term bound to total error - -**Parameters:** -- `causal`: Whether to use causal attention masking (default true) - -**Returns:** Extended error struct with stability analysis details --/ -def estimateCircuitFaithfulnessIBP (model : ConcreteModel) - (circuit : ConcreteCircuit) (causal : Bool := true) : CircuitErrorIBP := Id.run do - let inputNorm := computeInputNorm model.inputEmbeddings - let fwd := model.runForward causal - - let mut patternError : Float := 0.0 - let mut ablationError : Float := 0.0 - let mut mlpInstabilityError : Float := 0.0 - let mut unstableCount : Nat := 0 - let mut includedCount : Nat := 0 - let mut excludedCount : Nat := 0 - let mut layerStability : Array Float := #[] - - -- Track cumulative ablation perturbation up to each layer - -- This is the norm of the change to the residual stream from ablated components - let mut cumulativeAblation : Float := 0.0 - - -- Process layer by layer - for l in [:model.numLayers] do - let mut layerAblation : Float := 0.0 - - -- Process attention heads in this layer - if h : l < model.layers.size then - let layerHeads := model.layers[l] - let layerInput := fwd.getLayerInput l - for h_idx in [:layerHeads.size] do - let included := circuit.isHeadIncluded l h_idx - match computeHeadImportance model l h_idx layerInput with - | some imp => - if included then - patternError := patternError + imp.patternTermBound - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.valueTermNorm - layerAblation := layerAblation + imp.valueTermNorm - excludedCount := excludedCount + 1 - | none => pure () - - -- Update cumulative ablation (this affects MLP inputs) - cumulativeAblation := cumulativeAblation + layerAblation - - -- Process MLP neurons with IBP - if hm : l < model.mlps.size then - let mlp := model.mlps[l] - let layerInput := fwd.getLayerInput l - - let mut layerUnstable : Nat := 0 - let mut layerMlpPattern : Float := 0.0 - - for n_idx in 
[:mlp.hiddenDim] do - let included := circuit.isNeuronIncluded l n_idx - if included then - -- Use IBP with cumulative perturbation norm - match computeNeuronImportanceIBP model l n_idx layerInput cumulativeAblation with - | some imp => - patternError := patternError + imp.patternTermBound - if imp.patternTermBound > 0.0 then - layerUnstable := layerUnstable + 1 - layerMlpPattern := layerMlpPattern + imp.patternTermBound - includedCount := includedCount + 1 - | none => pure () - else - -- Excluded neurons contribute value term (ablation error) - let influence := mlp.neuronInfluence n_idx - ablationError := ablationError + influence - excludedCount := excludedCount + 1 - - unstableCount := unstableCount + layerUnstable - mlpInstabilityError := mlpInstabilityError + layerMlpPattern - - -- Record stability ratio for this layer - let stabilityRatio := if mlp.hiddenDim = 0 then 1.0 - else (mlp.hiddenDim - layerUnstable).toFloat / mlp.hiddenDim.toFloat - layerStability := layerStability.push stabilityRatio - - { - patternTermError := patternError - ablationError := ablationError - totalError := patternError + ablationError - includedCount := includedCount - excludedCount := excludedCount - unstableNeuronCount := unstableCount - mlpInstabilityError := mlpInstabilityError - layerStabilityRatios := layerStability - } - -/-- Summary of MLP stability across a model for a given perturbation. -/ -structure MLPStabilitySummary where - /-- Per-layer analysis results -/ - layerAnalyses : Array MLPIntervalAnalysis - /-- Total stable neurons across all layers -/ - totalStable : Nat - /-- Total unstable neurons across all layers -/ - totalUnstable : Nat - /-- Overall stability ratio -/ - overallStabilityRatio : Float - /-- Total pattern term bound from all unstable neurons -/ - totalPatternBound : Float - deriving Repr - -/-- Analyze MLP stability across the entire model. - -Runs IBP analysis on all MLP layers to identify which neurons are stable -under a given perturbation bound. 
--/ -def analyzeModelMLPStability (model : ConcreteModel) - (perturbationNorm : Float) (causal : Bool := true) : MLPStabilitySummary := Id.run do - let fwd := model.runForward causal - let mut analyses : Array MLPIntervalAnalysis := #[] - let mut totalStable : Nat := 0 - let mut totalUnstable : Nat := 0 - let mut totalPattern : Float := 0.0 - - for l in [:model.numLayers] do - if hm : l < model.mlps.size then - let mlp := model.mlps[l] - let layerInput := fwd.getLayerInput l - - -- Analyze each position and aggregate - let mut layerAnalysis : MLPIntervalAnalysis := { - layerIdx := l - neuronBounds := #[] - perturbationNorm := perturbationNorm - numStable := 0 - numUnstable := 0 - totalPatternBound := 0.0 - } - - -- Use position 0 as representative (could average over positions) - if layerInput.numRows > 0 then - let inputVec : Array Float := .ofFn fun d : Fin layerInput.numCols => - layerInput.getUnsafe 0 d.val - layerAnalysis := mlp.analyzeIntervalBounds l inputVec perturbationNorm - - analyses := analyses.push layerAnalysis - totalStable := totalStable + layerAnalysis.numStable - totalUnstable := totalUnstable + layerAnalysis.numUnstable - totalPattern := totalPattern + layerAnalysis.totalPatternBound - - let totalNeurons := totalStable + totalUnstable - let ratio := if totalNeurons = 0 then 1.0 - else totalStable.toFloat / totalNeurons.toFloat - - { - layerAnalyses := analyses - totalStable := totalStable - totalUnstable := totalUnstable - overallStabilityRatio := ratio - totalPatternBound := totalPattern - } - -/-! ### Greedy Circuit Pruning -/ - -/-- Result of the greedy pruning algorithm. 
-/ -structure PruningResult where - /-- The discovered circuit -/ - circuit : ConcreteCircuit - /-- Error estimate for the circuit -/ - error : CircuitError - /-- History of pruning steps (component removed, error after removal) -/ - pruningHistory : Array (ComponentId × Float) - /-- The error threshold that was used -/ - threshold : Float - -namespace PruningResult - -/-- Pretty print pruning result. -/ -def toString (pr : PruningResult) : String := - s!"PruningResult: {pr.circuit}\n Error: {pr.error}\n " ++ - s!"Steps: {pr.pruningHistory.size}, Threshold: {pr.threshold}" - -instance : ToString PruningResult := ⟨PruningResult.toString⟩ - -end PruningResult - -/-- Find the component with smallest value term (least important for information flow). - -Returns the component ID and its value term norm, considering only currently included -components. --/ -def findLeastImportantComponent (circuit : ConcreteCircuit) - (importance : Array ComponentImportance) : Option (ComponentId × Float) := Id.run do - let mut best : Option (ComponentId × Float) := none - - for imp in importance do - let included := circuit.isIncluded imp.component - if included then - match best with - | none => best := some (imp.component, imp.valueTermNorm) - | some (_, bestValue) => - if imp.valueTermNorm < bestValue then - best := some (imp.component, imp.valueTermNorm) - - best - -/-- Find the component with smallest target projection (least important for target behavior). - -Returns the component ID and its target projection, considering only currently included -components. This is the target-aware version of `findLeastImportantComponent`. 
--/ -def findLeastImportantTargetComponent (circuit : ConcreteCircuit) - (importance : Array TargetAwareImportance) : Option (ComponentId × Float) := Id.run do - let mut best : Option (ComponentId × Float) := none - - for imp in importance do - let included := circuit.isIncluded imp.component - if included then - match best with - | none => best := some (imp.component, imp.targetProjection) - | some (_, bestValue) => - if imp.targetProjection < bestValue then - best := some (imp.component, imp.targetProjection) - - best - -/-- Estimate circuit faithfulness for target-aware pruning. - -For target-aware circuits, the error model is: -- **Included components**: Contribute approximation error (pattern term) -- **Excluded components**: Contribute information loss measured by target projection - -Unlike generic discovery where we use `‖W_V·W_O‖_F` for ablation error, -here we use `targetProjection` - the component's contribution to the target direction. --/ -def estimateTargetCircuitError (_model : ConcreteModel) (circuit : ConcreteCircuit) - (importance : Array TargetAwareImportance) : CircuitError := Id.run do - let mut patternTermError : Float := 0.0 - let mut ablationError : Float := 0.0 - let mut includedCount : Nat := 0 - let mut excludedCount : Nat := 0 - - for imp in importance do - if circuit.isIncluded imp.component then - patternTermError := patternTermError + imp.patternTermBound - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.targetProjection - excludedCount := excludedCount + 1 - - { - patternTermError := patternTermError - ablationError := ablationError - totalError := patternTermError + ablationError - includedCount := includedCount - excludedCount := excludedCount - } - -/-- Greedy circuit pruning algorithm. - -Starting from the full model, iteratively removes the component with the smallest -value term contribution until the total error would exceed the threshold. - -**Algorithm:** -1. 
Start with all components included -2. Compute importance metrics for all components -3. Repeat: - a. Find component with smallest valueTermNorm among included - b. Tentatively remove it - c. Estimate new total error - d. If error ≤ threshold, commit removal; else restore and stop -4. Return the pruned circuit - -**Complexity:** O(n²) where n = number of components (n iterations, each scanning n components) --/ -def discoverCircuit (model : ConcreteModel) (threshold : Float) : PruningResult := Id.run do - -- Build heads per layer array - let mut headsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.layers.size then - headsPerLayer := headsPerLayer.push model.layers[l].size - else - headsPerLayer := headsPerLayer.push 0 - - -- Build neurons per layer array - let mut neuronsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.mlps.size then - neuronsPerLayer := neuronsPerLayer.push model.mlps[l].hiddenDim - else - neuronsPerLayer := neuronsPerLayer.push 0 - - -- Start with full circuit - let mut circuit := ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - let mut history : Array (ComponentId × Float) := #[] - - -- Precompute all importance metrics - let importance := computeAllImportance model - - -- Initial error - let mut currentError := estimateCircuitFaithfulness model circuit - - -- Greedy pruning loop - let maxIters := circuit.totalComponents - for _ in [:maxIters] do - -- Find least important included component - match findLeastImportantComponent circuit importance with - | none => break -- No more components to prune - | some (comp, _) => - -- Tentatively remove component - let tentativeCircuit := circuit.removeComponent comp - let tentativeError := estimateCircuitFaithfulness model tentativeCircuit - - -- Check if we can afford to remove it - if tentativeError.totalError ≤ threshold then - circuit := tentativeCircuit - currentError := tentativeError - history := history.push (comp, 
tentativeError.totalError) - else - break -- Would exceed threshold, stop pruning - - { - circuit := circuit - error := currentError - pruningHistory := history - threshold := threshold - } - -/-- Discover circuit with verbose output of each step. -/ -def discoverCircuitVerbose (model : ConcreteModel) (threshold : Float) : - PruningResult × Array String := Id.run do - let mut logs : Array String := #[] - - -- Build heads per layer array - let mut headsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.layers.size then - headsPerLayer := headsPerLayer.push model.layers[l].size - else - headsPerLayer := headsPerLayer.push 0 - - -- Build neurons per layer array - let mut neuronsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.mlps.size then - neuronsPerLayer := neuronsPerLayer.push model.mlps[l].hiddenDim - else - neuronsPerLayer := neuronsPerLayer.push 0 - - let mut circuit := ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - let mut history : Array (ComponentId × Float) := #[] - let importance := computeAllImportance model - - logs := logs.push s!"Starting with full circuit: {circuit.countIncluded} components" - - let mut currentError := estimateCircuitFaithfulness model circuit - logs := logs.push s!"Initial error: {currentError.totalError}" - - let maxIters := circuit.totalComponents - for step in [:maxIters] do - match findLeastImportantComponent circuit importance with - | none => - logs := logs.push s!"Step {step}: No more components to prune" - break - | some (comp, valueNorm) => - let tentativeCircuit := circuit.removeComponent comp - let tentativeError := estimateCircuitFaithfulness model tentativeCircuit - - if tentativeError.totalError ≤ threshold then - circuit := tentativeCircuit - currentError := tentativeError - history := history.push (comp, tentativeError.totalError) - let msg := s!"Step {step}: Removed {comp}, new error={tentativeError.totalError}" - logs := logs.push msg 
- else - let msg := s!"Step {step}: Cannot remove {comp}, exceeds threshold" - logs := logs.push msg - break - - logs := logs.push s!"Final circuit: {circuit}" - - ({ - circuit := circuit - error := currentError - pruningHistory := history - threshold := threshold - }, logs) - -/-! ### Circuit Verification -/ - -/-- A verified circuit that meets the certification threshold. -/ -structure VerifiedCircuit where - /-- The pruned circuit -/ - circuit : ConcreteCircuit - /-- Error estimate -/ - error : CircuitError - /-- Certification threshold -/ - threshold : Float - /-- Human-readable description -/ - description : String - -namespace VerifiedCircuit - -/-- Pretty print verified circuit. -/ -def toString (vc : VerifiedCircuit) : String := - s!"VerifiedCircuit [{vc.description}]\n {vc.circuit}\n {vc.error}\n " ++ - s!"Threshold: {vc.threshold}" - -instance : ToString VerifiedCircuit := ⟨VerifiedCircuit.toString⟩ - -end VerifiedCircuit - -/-- Discover and verify a circuit, returning None if threshold cannot be met. -/ -def discoverVerifiedCircuit (model : ConcreteModel) (threshold : Float) - (description : String := "auto-discovered") : Option VerifiedCircuit := do - let result := discoverCircuit model threshold - if result.error.totalError ≤ threshold then - some { - circuit := result.circuit - error := result.error - threshold := threshold - description := description - } - else - none - -/-! ### Target-Aware Circuit Discovery - -These functions discover circuits optimized for specific predictions rather than -general model behavior. Given a target direction (e.g., logit difference between -correct and incorrect tokens), they find the minimal circuit that explains that -specific prediction. --/ - -/-- Target-aware greedy circuit pruning algorithm. - -Like `discoverCircuit`, but uses target projection instead of generic value term -for importance ranking. This finds circuits that explain *specific behaviors* -rather than everything the model does. 
- -**Algorithm:** -1. Start with all components included -2. Compute target-aware importance for all components -3. Repeat: - a. Find component with smallest targetProjection among included - b. Tentatively remove it - c. Estimate new total error (using target projections for ablation) - d. If error ≤ threshold, commit removal; else stop -4. Return the pruned circuit - -**Key Difference from Generic Discovery:** -- Generic: prunes by `‖W_V·W_O‖_F` (generic information flow) -- Target-aware: prunes by `‖(W_V·W_O)·u‖` (contribution to target direction) - -This typically produces much smaller circuits when you care about specific outputs. --/ -def discoverTargetedCircuit (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) : PruningResult := Id.run do - -- Build heads per layer array - let mut headsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.layers.size then - headsPerLayer := headsPerLayer.push model.layers[l].size - else - headsPerLayer := headsPerLayer.push 0 - - -- Build neurons per layer array - let mut neuronsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.mlps.size then - neuronsPerLayer := neuronsPerLayer.push model.mlps[l].hiddenDim - else - neuronsPerLayer := neuronsPerLayer.push 0 - - -- Start with full circuit - let mut circuit := ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - let mut history : Array (ComponentId × Float) := #[] - - -- Precompute target-aware importance metrics - let importance := computeAllTargetImportance model target - - -- Initial error - let mut currentError := estimateTargetCircuitError model circuit importance - - -- Greedy pruning loop - let maxIters := circuit.totalComponents - for _ in [:maxIters] do - -- Find least important included component (by target projection) - match findLeastImportantTargetComponent circuit importance with - | none => break - | some (comp, _) => - let tentativeCircuit := circuit.removeComponent 
comp - let tentativeError := estimateTargetCircuitError model tentativeCircuit importance - - if tentativeError.totalError ≤ threshold then - circuit := tentativeCircuit - currentError := tentativeError - history := history.push (comp, tentativeError.totalError) - else - break - - { - circuit := circuit - error := currentError - pruningHistory := history - threshold := threshold - } - -/-- Target-aware circuit discovery with verbose logging. -/ -def discoverTargetedCircuitVerbose (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) : PruningResult × Array String := Id.run do - let mut logs : Array String := #[] - logs := logs.push s!"Target-aware discovery for: {target.description}" - - -- Build layer arrays - let mut headsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.layers.size then - headsPerLayer := headsPerLayer.push model.layers[l].size - else - headsPerLayer := headsPerLayer.push 0 - - let mut neuronsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.mlps.size then - neuronsPerLayer := neuronsPerLayer.push model.mlps[l].hiddenDim - else - neuronsPerLayer := neuronsPerLayer.push 0 - - let mut circuit := ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - let mut history : Array (ComponentId × Float) := #[] - let importance := computeAllTargetImportance model target - - logs := logs.push s!"Starting with full circuit: {circuit.countIncluded} components" - - let mut currentError := estimateTargetCircuitError model circuit importance - logs := logs.push s!"Initial error: {currentError.totalError}" - - let maxIters := circuit.totalComponents - for step in [:maxIters] do - match findLeastImportantTargetComponent circuit importance with - | none => - logs := logs.push s!"Step {step}: No more components to prune" - break - | some (comp, targetProj) => - let tentativeCircuit := circuit.removeComponent comp - let tentativeError := estimateTargetCircuitError model 
tentativeCircuit importance - - if tentativeError.totalError ≤ threshold then - circuit := tentativeCircuit - currentError := tentativeError - history := history.push (comp, tentativeError.totalError) - let msg := s!"Step {step}: Removed {comp} (target={targetProj}), " ++ - s!"new error={tentativeError.totalError}" - logs := logs.push msg - else - let msg := s!"Step {step}: Cannot remove {comp}, exceeds threshold" - logs := logs.push msg - break - - logs := logs.push s!"Final circuit: {circuit}" - logs := logs.push s!"Compression: {circuit.countIncluded}/{circuit.totalComponents} components" - - ({ - circuit := circuit - error := currentError - pruningHistory := history - threshold := threshold - }, logs) - -/-- Discover and verify a target-aware circuit. -/ -def discoverVerifiedTargetedCircuit (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) : Option VerifiedCircuit := do - let result := discoverTargetedCircuit model threshold target - if result.error.totalError ≤ threshold then - some { - circuit := result.circuit - error := result.error - threshold := threshold - description := s!"target-aware: {target.description}" - } - else - none - -/-- Convenience function to discover circuit for logit difference. - -Given correct and incorrect token IDs, creates the target direction -`u = W_U[:, correct] - W_U[:, incorrect]` and discovers the minimal -circuit that explains why the model predicts correct over incorrect. --/ -def discoverLogitDiffCircuit (model : ConcreteModel) (threshold : Float) - (correctToken incorrectToken : Nat) : Option (PruningResult × TargetDirection) := do - let W_U ← model.unembedding - let target := TargetDirection.fromLogitDiff W_U correctToken incorrectToken - let result := discoverTargetedCircuit model threshold target - some (result, target) - -/-- Rank components by their target projection (descending). - -Useful for identifying which components most strongly promote the target behavior. 
--/ -def rankComponentsByTargetImportance (model : ConcreteModel) - (target : TargetDirection) : Array TargetAwareImportance := - let importance := computeAllTargetImportance model target - importance.qsort (·.targetProjection > ·.targetProjection) - -/-- Get the top-k components most important for a target direction. -/ -def topKTargetComponents (model : ConcreteModel) (target : TargetDirection) - (k : Nat) : Array TargetAwareImportance := - let ranked := rankComponentsByTargetImportance model target - ranked.extract 0 (min k ranked.size) - -/-! ### End-to-End Discovery and Verification - -These functions combine circuit discovery with empirical verification, -providing a complete workflow from model analysis to validated circuits. --/ - -/-- Discover a circuit and immediately verify it empirically. - -This is the end-to-end function that: -1. Discovers a minimal circuit using greedy pruning -2. Computes theoretical error bounds -3. Verifies empirically that the circuit is faithful - -Returns both the pruning result and verification result. --/ -def discoverAndVerify (model : ConcreteModel) (threshold : Float) - (causal : Bool := true) : PruningResult × VerificationResult := - let pruning := discoverCircuit model threshold - let verification := verifyCircuitFaithfulness model pruning.circuit - pruning.error.totalError causal - (pruning, verification) - -/-- Discover a target-aware circuit and verify it empirically. - -Like `discoverAndVerify` but optimizes for a specific prediction target -(e.g., logit difference between correct and incorrect tokens). 
--/ -def discoverTargetedAndVerify (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) (causal : Bool := true) : - PruningResult × VerificationResult := - let pruning := discoverTargetedCircuit model threshold target - let verification := verifyCircuitFaithfulness model pruning.circuit - pruning.error.totalError causal - (pruning, verification) - -/-- Complete analysis: discover, verify, and return detailed comparison. - -This is the most comprehensive function for circuit analysis. It: -1. Discovers the minimal circuit meeting the error threshold -2. Runs both full and ablated forward passes -3. Computes empirical vs theoretical error comparison -4. Returns everything needed for detailed analysis - -**Example output interpretation:** -- `verification.verified = true`: Circuit is empirically faithful -- `verification.tightness = 0.8`: Theoretical bound is 80% utilized (20% slack) -- `ablation.relativeError = 0.05`: Circuit output differs by 5% from full model --/ -def analyzeCircuitFaithfulness (model : ConcreteModel) (threshold : Float) - (causal : Bool := true) : PruningResult × VerificationResult × AblationResult := - let pruning := discoverCircuit model threshold - let ablation := computeAblationDiscrepancy model pruning.circuit causal - let verification := verifyCircuitFaithfulness model pruning.circuit - pruning.error.totalError causal - (pruning, verification, ablation) - -/-- Analyze a target-aware circuit with full verification details. -/ -def analyzeTargetedCircuitFaithfulness (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) (causal : Bool := true) : - PruningResult × VerificationResult × AblationResult := - let pruning := discoverTargetedCircuit model threshold target - let ablation := computeAblationDiscrepancy model pruning.circuit causal - let verification := verifyCircuitFaithfulness model pruning.circuit - pruning.error.totalError causal - (pruning, verification, ablation) - -/-! 
### Analysis Utilities -/ - -/-- Rank all components by their value term contribution (descending). -/ -def rankComponentsByImportance (model : ConcreteModel) : Array ComponentImportance := - let importance := computeAllImportance model - importance.qsort (·.valueTermNorm > ·.valueTermNorm) - -/-- Get the top-k most important components. -/ -def topKComponents (model : ConcreteModel) (k : Nat) : Array ComponentImportance := - let ranked := rankComponentsByImportance model - ranked.extract 0 (min k ranked.size) - -/-- Get components with faithfulness ratio below threshold (most reliable). -/ -def reliableComponents (model : ConcreteModel) (maxRatio : Float) : Array ComponentImportance := - let importance := computeAllImportance model - importance.filter (·.faithfulnessRatio ≤ maxRatio) - -/-- Summary of circuit discovery analysis. -/ -structure CircuitAnalysis where - /-- Total number of components in model -/ - totalComponents : Nat - /-- Number of components in discovered circuit -/ - circuitSize : Nat - /-- Compression ratio: circuitSize / totalComponents -/ - compressionRatio : Float - /-- Total error bound -/ - totalError : Float - /-- Pattern term contribution to error -/ - patternContribution : Float - /-- Ablation contribution to error -/ - ablationContribution : Float - /-- Most important component (by value term) -/ - topComponent : Option ComponentImportance - /-- Most reliable component (by faithfulness ratio) -/ - mostReliable : Option ComponentImportance - -/-- Perform comprehensive circuit analysis. 
-/ -def analyzeCircuit (model : ConcreteModel) (threshold : Float) : CircuitAnalysis := Id.run do - let result := discoverCircuit model threshold - let importance := computeAllImportance model - let ranked := importance.qsort (·.valueTermNorm > ·.valueTermNorm) - let reliable := importance.qsort (·.faithfulnessRatio < ·.faithfulnessRatio) - - let total := result.circuit.totalComponents - let included := result.circuit.countIncluded - let ratio := if total > 0 then included.toFloat / total.toFloat else 1.0 - - { - totalComponents := total - circuitSize := included - compressionRatio := ratio - totalError := result.error.totalError - patternContribution := result.error.patternTermError - ablationContribution := result.error.ablationError - topComponent := if h : 0 < ranked.size then some ranked[0] else none - mostReliable := if h : 0 < reliable.size then some reliable[0] else none - } - -end Nfp diff --git a/Legacy/Nfp/IO.lean b/Legacy/Nfp/IO.lean deleted file mode 100644 index 209329d..0000000 --- a/Legacy/Nfp/IO.lean +++ /dev/null @@ -1,503 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Nfp.IO.Pure - -/-! -# Model IO: Loading Pre-trained Weights - -This module provides functionality to load pre-trained transformer model weights -from external files (exported from PyTorch/JAX) into the `ConcreteModel` structure. - -## Performance & Safety Design - -This module prioritizes **safety and clear error reporting** over raw performance. -- All array accesses use bounds-checked indexing (`array[i]!` panics on OOB) -- Comprehensive validation of file format and dimensions with helpful error messages -- File format parsing is I/O-bound, so optimizing array operations has minimal impact - -For high-performance computation, see `Discovery.lean` where hot paths are optimized. 
- -## File Format: `.nfpt` (NFP_BINARY_V1) - -Hybrid text header + binary body: - -``` -NFP_BINARY_V1 -num_layers=12 -num_heads=12 -model_dim=768 -head_dim=64 -hidden_dim=3072 -vocab_size=50257 -seq_len=1024 -layer_norm_eps=1e-5 -gelu_kind=tanh -BINARY_START -``` - -`layer_norm_eps` (or legacy `eps`) and `gelu_kind` (or legacy `gelu_deriv`) are required by the -SOUND certification path but are otherwise ignored by this loader. - -Binary payload (little-endian, row-major, no markers): -1. TOKENS: `seq_len` × Int32 -2. EMBEDDINGS: `seq_len` × `model_dim` × Float64 -3. For each layer (0..num_layers-1), for each head (0..num_heads-1): - - W_Q (`model_dim`×`head_dim`), b_Q (`head_dim`) - - W_K (`model_dim`×`head_dim`), b_K (`head_dim`) - - W_V (`model_dim`×`head_dim`), b_V (`head_dim`) - - W_O (`head_dim`×`model_dim`) -4. ATTN_BIAS: `model_dim` -5. MLP: W_in (`model_dim`×`hidden_dim`), b_in (`hidden_dim`), - W_out (`hidden_dim`×`model_dim`), b_out (`model_dim`) -6. LN1 gamma/beta (`model_dim` each) -7. LN2 gamma/beta (`model_dim` each) -8. LN_F gamma/beta (`model_dim` each) -9. UNEMBEDDING: `model_dim`×`vocab_size` -``` --/ - -namespace Nfp - -open IO - -/-- Run an IO action and emit timing when `NFP_TIMING` is set. -/ -def timeIt {α : Type} (label : String) (action : Unit → IO α) : IO α := do - let timingEnabled ← IO.getEnv "NFP_TIMING" - if timingEnabled.isNone then - action () - else - let t0 ← IO.monoNanosNow - let result ← action () - let t1 ← IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - IO.eprintln s!"timing:{label} {dtMs}ms" - return result - -/-- Load a model from NFP text format content. -/ -def loadFromText (_content : String) : IO LoadResult := do - return .error "NFP_TEXT format is deprecated; use NFP_BINARY_V1" - -/-! ## Binary `.nfpt` loading (NFP_BINARY_V1) -/ - -private def readLine? 
(h : IO.FS.Handle) : IO (Option String) := do - let s ← h.getLine - if s.isEmpty then - return none - else - return some s - -private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - if n = 0 then - return ByteArray.empty - let mut out : ByteArray := ByteArray.mk (Array.replicate n 0) - let mut off : Nat := 0 - while off < n do - let chunk ← h.read (USize.ofNat (n - off)) - if chunk.isEmpty then - throw (IO.userError "unexpected EOF") - out := chunk.copySlice 0 out off chunk.size - off := off + chunk.size - return out - -@[inline] private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := - let b0 := (b[off]!).toUInt32 - let b1 := (b[off + 1]!).toUInt32 - let b2 := (b[off + 2]!).toUInt32 - let b3 := (b[off + 3]!).toUInt32 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) - -@[inline] private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := - let b0 := (b[off]!).toUInt64 - let b1 := (b[off + 1]!).toUInt64 - let b2 := (b[off + 2]!).toUInt64 - let b3 := (b[off + 3]!).toUInt64 - let b4 := (b[off + 4]!).toUInt64 - let b5 := (b[off + 5]!).toUInt64 - let b6 := (b[off + 6]!).toUInt64 - let b7 := (b[off + 7]!).toUInt64 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| - (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) - -private def twoPow32 : Int := Int.ofNat (Nat.pow 2 32) - -@[inline] private def i32FromLE (b : ByteArray) (off : Nat) : Int := - let u := u32FromLE b off - let half : UInt32 := 0x80000000 - if u < half then - Int.ofNat u.toNat - else - (Int.ofNat u.toNat) - twoPow32 - -@[inline] private def floatFromLE (b : ByteArray) (off : Nat) : Float := - Float.ofBits (u64FromLE b off) - -private def readFloatArray (h : IO.FS.Handle) (count : Nat) : IO FloatArray := do - if count = 0 then - return FloatArray.empty - let bytes ← readExactly h (count * 8) - let mut data : Array Float := Array.replicate count 0.0 - let mut i : Nat := 0 - let mut off : Nat := 0 - while i < count do - data := data.set! 
i (floatFromLE bytes off) - off := off + 8 - i := i + 1 - return .mk data - -private def readI32Array (h : IO.FS.Handle) (count : Nat) : IO (Array Nat) := do - if count = 0 then - return #[] - let bytes ← readExactly h (count * 4) - let mut out : Array Nat := Array.replicate count 0 - let mut i : Nat := 0 - let mut off : Nat := 0 - while i < count do - let v := i32FromLE bytes off - if v < 0 then - throw (IO.userError s!"Negative token id at index {i}") - out := out.set! i v.toNat - off := off + 4 - i := i + 1 - return out - -/-- Load a model from the `.nfpt` binary format (NFP_BINARY_V1). -/ -def loadBinary (h : IO.FS.Handle) : IO LoadResult := do - try - let some magicLine ← readLine? h - | return .error "Empty file" - let magic := magicLine.trim - if magic != "NFP_BINARY_V1" then - return .error "Invalid magic: expected NFP_BINARY_V1" - - IO.println "[1/5] Parsing header..." - - let mut numLayers : Nat := 0 - let mut numHeads : Nat := 0 - let mut modelDim : Nat := 0 - let mut headDim : Nat := 0 - let mut hiddenDim : Nat := 0 - let mut vocabSize : Nat := 0 - let mut seqLen : Nat := 0 - - let mut line? ← readLine? h - while true do - match line? with - | none => return .error "Unexpected EOF while reading header" - | some line => - let t := line.trim - if t = "BINARY_START" then - break - if t.startsWith "num_layers=" then - numLayers := (t.drop 11).toNat! - else if t.startsWith "num_heads=" then - numHeads := (t.drop 10).toNat! - else if t.startsWith "model_dim=" then - modelDim := (t.drop 10).toNat! - else if t.startsWith "head_dim=" then - headDim := (t.drop 9).toNat! - else if t.startsWith "hidden_dim=" then - hiddenDim := (t.drop 11).toNat! - else if t.startsWith "vocab_size=" then - vocabSize := (t.drop 11).toNat! - else if t.startsWith "seq_len=" then - seqLen := (t.drop 8).toNat! - line? ← readLine? 
h - - if modelDim = 0 || numLayers = 0 || numHeads = 0 then - return .error s!"Invalid header: modelDim={modelDim}, numLayers={numLayers}, numHeads={numHeads} (all must be > 0)" - if headDim = 0 || hiddenDim = 0 || vocabSize = 0 || seqLen = 0 then - return .error "Invalid header: headDim/hiddenDim/vocabSize/seqLen must be > 0" - - IO.println s!"[2/5] Loading input tokens + embeddings (seq_len={seqLen}, model_dim={modelDim})..." - - let inputTokens : Array Nat ← readI32Array h seqLen - let embFloats ← readFloatArray h (seqLen * modelDim) - let inputEmbeddings := buildMatrix seqLen modelDim embFloats.data - - IO.println s!"[3/5] Loading {numLayers} layers with {numHeads} heads each..." - - let mut layers : Array (Array ConcreteAttentionLayer) := Array.mkEmpty numLayers - let mut attnProjBias : Array ConcreteMatrix := Array.mkEmpty numLayers - let mut mlps : Array ConcreteMLPLayer := Array.mkEmpty numLayers - let mut ln1 : Array ConcreteLayerNormParams := Array.mkEmpty numLayers - let mut ln2 : Array ConcreteLayerNormParams := Array.mkEmpty numLayers - - for l in [:numLayers] do - IO.println s!" Loading layer {l}/{numLayers}..." 
- let mut layerHeads : Array ConcreteAttentionLayer := Array.mkEmpty numHeads - for _h in [:numHeads] do - let wq ← readFloatArray h (modelDim * headDim) - let bq ← readFloatArray h headDim - let wk ← readFloatArray h (modelDim * headDim) - let bk ← readFloatArray h headDim - let wv ← readFloatArray h (modelDim * headDim) - let bv ← readFloatArray h headDim - let wo ← readFloatArray h (headDim * modelDim) - let head := mkAttentionLayer modelDim headDim wq.data wk.data wv.data wo.data bq.data bk.data bv.data - layerHeads := layerHeads.push head - layers := layers.push layerHeads - - let bias ← readFloatArray h modelDim - attnProjBias := attnProjBias.push (buildMatrix 1 modelDim bias.data) - - let win ← readFloatArray h (modelDim * hiddenDim) - let bin ← readFloatArray h hiddenDim - let wout ← readFloatArray h (hiddenDim * modelDim) - let bout ← readFloatArray h modelDim - mlps := mlps.push (mkMLPLayer modelDim hiddenDim win.data wout.data bin.data bout.data) - - let ln1Gamma ← readFloatArray h modelDim - let ln1Beta ← readFloatArray h modelDim - ln1 := ln1.push { - gamma := buildMatrix 1 modelDim ln1Gamma.data - beta := buildMatrix 1 modelDim ln1Beta.data - } - - let ln2Gamma ← readFloatArray h modelDim - let ln2Beta ← readFloatArray h modelDim - ln2 := ln2.push { - gamma := buildMatrix 1 modelDim ln2Gamma.data - beta := buildMatrix 1 modelDim ln2Beta.data - } - - IO.println "[4/5] Loading final layernorm + unembedding..." 
- - let lnfGamma ← readFloatArray h modelDim - let lnfBeta ← readFloatArray h modelDim - let lnf := { - gamma := buildMatrix 1 modelDim lnfGamma.data - beta := buildMatrix 1 modelDim lnfBeta.data - } - - let unembFloats ← readFloatArray h (modelDim * vocabSize) - let unembedding := buildMatrix modelDim vocabSize unembFloats.data - - let model : ConcreteModel := { - numLayers := numLayers - layers := layers - attnProjBias := attnProjBias - mlps := mlps - ln1 := ln1 - ln2 := ln2 - lnf := lnf - seqLen := seqLen - inputTokens := some inputTokens - inputEmbeddings := inputEmbeddings - unembedding := some unembedding - } - - IO.println "[5/5] Model loaded successfully! -" - return .ok model - catch e => - return .error s!"Binary load failed: {e}" - -/-- Input tokens + embeddings loaded from a binary `.nfpt` file. -/ -structure InputBinary where - /-- Sequence length parsed from the input header. -/ - seqLen : Nat - /-- Model dimension parsed from the input header. -/ - modelDim : Nat - /-- Token IDs parsed from the input file. -/ - tokens : Array Nat - /-- Input embeddings (seqLen × modelDim). -/ - embeddings : ConcreteMatrix - -/-- Load input tokens + embeddings from a binary `.nfpt` file. -/ -def loadInputBinary (path : System.FilePath) : IO (Except String InputBinary) := do - try - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let some magicLine ← readLine? h - | return .error "Empty input file" - let magic := magicLine.trim - if magic != "NFP_BINARY_V1" then - return .error "Invalid input magic: expected NFP_BINARY_V1" - let mut seqLen? : Option Nat := none - let mut modelDim? : Option Nat := none - let mut line? ← readLine? h - while true do - match line? with - | none => return .error "Unexpected EOF while reading input header" - | some line => - let t := line.trim - if t = "BINARY_START" then - break - if t.startsWith "seq_len=" then - match (t.drop 8).toNat? with - | some n => seqLen? 
:= some n - | none => return .error "Invalid seq_len in input header" - else if t.startsWith "model_dim=" then - match (t.drop 10).toNat? with - | some n => modelDim? := some n - | none => return .error "Invalid model_dim in input header" - line? ← readLine? h - let some seqLen := seqLen? - | return .error "Missing seq_len in input header" - let some modelDim := modelDim? - | return .error "Missing model_dim in input header" - let tokens ← readI32Array h seqLen - let embFloats ← readFloatArray h (seqLen * modelDim) - let embeddings := buildMatrix seqLen modelDim embFloats.data - return .ok { seqLen := seqLen, modelDim := modelDim, tokens := tokens, embeddings := embeddings } - catch e => - return .error s!"Binary input load failed: {e}" -/-! ## File IO Operations -/ - -/-- Load a model from a file path. Supports .nfpt (binary) format. -/ -def loadModel (path : System.FilePath) : IO LoadResult := do - if path.extension = some "nfpt" then - timeIt "io:load-model" (fun () => - IO.FS.withFile path .read fun h => - loadBinary h) - else - return .error s!"Unsupported file format: {path.extension.getD "unknown"}" - -/-! ## Analysis Report Generation -/ - -/-- Format for circuit analysis results. -/ -structure AnalysisReport where - /-- Model name/path -/ - modelName : String - /-- Input prompt (if available) -/ - prompt : Option String - /-- Number of layers analyzed -/ - numLayers : Nat - /-- Total heads in model -/ - totalHeads : Nat - /-- Verified induction head candidates -/ - inductionHeads : Array CandidateInductionHead - /-- Deep circuit candidates with N-layer bounds -/ - deepCircuits : Array DeepCircuitCandidate - /-- Verification result (if run) -/ - verification : Option VerificationResult - -namespace AnalysisReport - -/-- Generate a human-readable report. 
-/ -def toString (r : AnalysisReport) : String := Id.run do - let mut s := s!"═══════════════════════════════════════════════════════════\n" - s := s ++ s!"NFP Circuit Analysis Report\n" - s := s ++ s!"Model: {r.modelName}\n" - match r.prompt with - | some p => s := s ++ s!"Prompt: \"{p}\"\n" - | none => pure () - s := s ++ s!"Layers: {r.numLayers}, Heads: {r.totalHeads}\n" - s := s ++ s!"═══════════════════════════════════════════════════════════\n\n" - - if r.inductionHeads.size > 0 then - s := s ++ s!"VERIFIED INDUCTION HEADS ({r.inductionHeads.size} found):\n" - s := s ++ s!"───────────────────────────────────────────────────────────\n" - for head in r.inductionHeads do - s := s ++ s!" L{head.layer1Idx}H{head.head1Idx} → L{head.layer2Idx}H{head.head2Idx}\n" - s := s ++ s!" Combined Error: {head.combinedError}\n" - s := s ++ s!" Prev-Token Strength: {head.prevTokenStrength}\n" - s := s ++ s!" Induction Score: {head.inductionScore}\n" - s := s ++ s!" K-Composition: {head.kComp}\n" - s := s ++ s!" Faithfulness Ratios: ε₁={head.patternBound1}, ε₂={head.patternBound2}\n\n" - else - s := s ++ s!"No induction heads found above threshold.\n\n" - - if r.deepCircuits.size > 0 then - s := s ++ s!"DEEP CIRCUIT CANDIDATES ({r.deepCircuits.size} found):\n" - s := s ++ s!"───────────────────────────────────────────────────────────\n" - for circuit in r.deepCircuits do - s := s ++ s!" {circuit.description}\n" - s := s ++ s!" Pattern Type: {circuit.patternType}\n" - s := s ++ s!" Simple Error Sum: {circuit.simpleErrorSum}\n" - s := s ++ s!" Amplified Error: {circuit.amplifiedError}\n" - s := s ++ s!" Amplification Factor: {circuit.amplificationFactor}\n\n" - - match r.verification with - | some v => - s := s ++ s!"EMPIRICAL VERIFICATION:\n" - s := s ++ s!"───────────────────────────────────────────────────────────\n" - let status := if v.verified then "✓ PASSED" else "✗ FAILED" - s := s ++ s!" Status: {status}\n" - s := s ++ s!" 
Empirical Error: {v.ablation.empiricalError}\n" - s := s ++ s!" Theoretical Bound: {v.theoreticalBound}\n" - s := s ++ s!" Tightness: {v.tightness * 100.0}%\n" - s := s ++ s!" Circuit Size: {v.ablation.circuitSize}/{v.ablation.totalComponents}\n\n" - | none => pure () - - s := s ++ s!"═══════════════════════════════════════════════════════════\n" - s - -instance : ToString AnalysisReport := ⟨AnalysisReport.toString⟩ - -end AnalysisReport - -/-- Run full analysis on a model and generate a report. -/ -def analyzeModel (model : ConcreteModel) (modelName : String) - (threshold : Float := 0.1) - (prompt : Option String := none) : IO AnalysisReport := do - IO.println "\n═══════════════════════════════════════════════════════════" - IO.println "Starting Circuit Analysis" - IO.println s!"Model: {modelName}" - IO.println s!"Threshold: {threshold}" - IO.println "═══════════════════════════════════════════════════════════\n" - - IO.println "[1/2] Building precomputed cache..." - let cache ← timeIt "analysis:precompute-cache" (fun () => - pure <| PrecomputedCache.build model) - - IO.println "[2/2] Searching for deep circuit candidates (shared scan)..." - -- Find deep circuit candidates (reuse cache) - let deepCircuits ← timeIt "analysis:deep-circuit-scan" (fun () => - pure <| findDeepCircuitCandidatesFromCache cache) - let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ threshold) - IO.println s!" Found {verifiedDeep.size} verified deep circuits \ - (of {deepCircuits.size} candidates)" - - -- Derive induction-head candidates from the same scan to avoid repeating - -- the expensive `checkInductionPattern` computation. - let (totalInduction, verifiedHeads) ← timeIt "analysis:induction-candidates" (fun () => do - let mut total : Nat := 0 - let mut verified : Array CandidateInductionHead := Array.mkEmpty 0 - for circuit in deepCircuits do - match circuit.toInductionCandidateCore? 
cache with - | none => pure () - | some core => - total := total + 1 - if core.combinedError ≤ threshold then - match core.toInductionCandidate? cache with - | some cand => verified := verified.push cand - | none => pure () - let verifiedSorted := verified.qsort (·.combinedError < ·.combinedError) - return (total, verifiedSorted)) - IO.println s!" Found {verifiedHeads.size} verified induction heads \ - (of {totalInduction} candidates)\n" - - IO.println "Analysis complete!\n" - - -- Count total heads - let totalHeads := model.layers.foldl (fun acc layer => acc + layer.size) 0 - - return { - modelName := modelName - prompt := prompt - numLayers := model.numLayers - totalHeads := totalHeads - inductionHeads := verifiedHeads - deepCircuits := verifiedDeep - verification := none - } - -/-- Run analysis with empirical verification. -/ -def analyzeAndVerify (model : ConcreteModel) (modelName : String) - (threshold : Float := 0.1) - (prompt : Option String := none) : IO AnalysisReport := do - let baseReport ← analyzeModel model modelName threshold prompt - - IO.println "═══════════════════════════════════════════════════════════" - IO.println "Starting Empirical Verification" - IO.println "═══════════════════════════════════════════════════════════\n" - - IO.println "Running circuit discovery and ablation experiments..." - -- Run circuit discovery and verification - let (_, verification) ← timeIt "analysis:discover-and-verify" (fun () => - pure <| discoverAndVerify model threshold) - IO.println "Verification complete!\n" - - return { baseReport with verification := some verification } - -end Nfp diff --git a/Legacy/Nfp/IO/Pure.lean b/Legacy/Nfp/IO/Pure.lean deleted file mode 100644 index 5c3599b..0000000 --- a/Legacy/Nfp/IO/Pure.lean +++ /dev/null @@ -1,399 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Discovery - -/-! -# Pure helpers for model IO - -Pure parsing, construction, and tokenization utilities shared by the CLI-facing IO layer. 
--/ - -namespace Nfp - -/-! ## Float Parsing Utilities -/ - -private def pow10PowTable : Array Float := - -- Precompute `Float.pow 10.0 k` for k=0..308 so we avoid calling `Float.pow` per token. - Array.ofFn fun k : Fin 309 => Float.pow 10.0 k.val.toFloat - -private def pow10Pow (n : Nat) : Float := - if n < pow10PowTable.size then - pow10PowTable[n]! - else - Float.pow 10.0 n.toFloat - -private def parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do - let mut p := start - if p >= stop then - return none - let mut acc : Nat := 0 - let mut saw : Bool := false - while p < stop do - let c := p.get s - if ('0' <= c) && (c <= '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - saw := true - p := p.next s - else - return none - if saw then some acc else none - -private def parseFloatRange (s : String) (start stop : String.Pos.Raw) : Option Float := Id.run do - -- This is a faster, allocation-free version of the previous `parseFloat`, but it preserves - -- the exact Float computation structure (Nat parsing + `Float.pow`) to keep results stable. - - let mut p := start - if p >= stop then - return none - - let mut negative := false - let c0 := p.get s - if c0 = '-' then - negative := true - p := p.next s - else if c0 = '+' then - p := p.next s - - if p >= stop then - return none - - -- Find exponent marker the same way as the old parser: accept exactly one `e` if present, - -- otherwise accept exactly one `E`. - let mut ePos : Option String.Pos.Raw := none - let mut eCount : Nat := 0 - let mut EPos : Option String.Pos.Raw := none - let mut ECount : Nat := 0 - let mut q := p - while q < stop do - let c := q.get s - if c = 'e' then - eCount := eCount + 1 - if eCount = 1 then ePos := some q - else if c = 'E' then - ECount := ECount + 1 - if ECount = 1 then EPos := some q - q := q.next s - - let expMarker? 
: Option String.Pos.Raw := - if eCount = 1 then ePos else if ECount = 1 then EPos else none - - let mantEnd : String.Pos.Raw := - match expMarker? with - | some ep => ep - | none => stop - - -- Find decimal point in mantissa (must be 0 or 1 occurrences). - let mut dotPos : Option String.Pos.Raw := none - let mut dotCount : Nat := 0 - let mut r := p - while r < mantEnd do - if r.get s = '.' then - dotCount := dotCount + 1 - if dotCount = 1 then dotPos := some r - r := r.next s - if dotCount > 1 then - return none - - let (intStart, intStop, fracStart?, fracStop) := - match dotPos with - | none => (p, mantEnd, none, mantEnd) - | some dp => (p, dp, some (dp.next s), mantEnd) - - let intN? : Option Nat := - if dotPos.isSome && intStart = intStop then - some 0 - else - parseNatRange s intStart intStop - - let fracN? : Option Nat := - match fracStart? with - | none => none - | some fs => - if fs = fracStop then some 0 else parseNatRange s fs fracStop - - let mantissa? : Option Float := - match dotPos, intN?, fracN? with - | none, some iN, _ => - some iN.toFloat - | some _, some iN, some fN => - let fracLen := (fracStop.byteIdx - (fracStart?.getD fracStop).byteIdx) - let divisor := pow10Pow fracLen - some (iN.toFloat + fN.toFloat / divisor) - | some _, _, none => - -- `.` present but no fractional parse (shouldn't happen), treat as invalid. - none - | _, none, _ => none - - let some mantissa := mantissa? | return none - - let value : Float := - match expMarker? with - | none => mantissa - | some ep => - let expStart := ep.next s - if expStart >= stop then - mantissa - else - -- Parse exponent, but if it is malformed, ignore it (old behavior). 
- let c := expStart.get s - let (expNeg, es) := - if c = '-' then (true, expStart.next s) - else if c = '+' then (false, expStart.next s) - else (false, expStart) - match parseNatRange s es stop with - | none => mantissa - | some eNat => - let p10 := pow10Pow eNat - if expNeg then mantissa / p10 else mantissa * p10 - - some (if negative then -value else value) - -/-- Parse a floating point number from a string. -/ -def parseFloat (s : String) : Option Float := Id.run do - let s := s.trim - if s.isEmpty then - none - else - parseFloatRange s 0 s.rawEndPos - -@[inline] private def isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - -@[inline] private def foldTokensFromLine {α : Type} - (line : String) (init : α) - (step : α → String.Pos.Raw → String.Pos.Raw → α) : α := - Id.run do - let mut out := init - let mut p : String.Pos.Raw := 0 - let stop := line.rawEndPos - while p < stop do - while p < stop && isWs (p.get line) do - p := p.next line - let start := p - while p < stop && !isWs (p.get line) do - p := p.next line - if start < p then - out := step out start p - out - -private def appendFloatsFromLine (line : String) (acc : Array Float) : Array Float := - foldTokensFromLine line acc fun out start stop => - match parseFloatRange line start stop with - | some x => out.push x - | none => out - -private def parseFloatsFromLines (lines : Array String) (cap : Nat := 0) : Array Float := - Id.run do - let mut out : Array Float := Array.mkEmpty cap - for line in lines do - out := appendFloatsFromLine line out - out - -private def spawnParseFloats (lines : Array String) (cap : Nat := 0) : Task (Array Float) := - Task.spawn (fun _ => parseFloatsFromLines lines cap) - -/-- Parse a line of space-separated floats. -/ -def parseFloatLine (line : String) : Array Float := - appendFloatsFromLine line #[] - -/-! 
## Nat Parsing Utilities -/ - -private def appendNatsFromLine (line : String) (acc : Array Nat) : Array Nat := - foldTokensFromLine line acc fun out start stop => - match parseNatRange line start stop with - | some n => out.push n - | none => out - -def parseNatLine (line : String) : Array Nat := - appendNatsFromLine line #[] - -/-! ## Matrix Construction for IO -/ - -/- Build a ConcreteMatrix from float data, padding or truncating as needed. - This is safe because we ensure the data has exactly the right size. -/ -def buildMatrix (rows cols : Nat) (data : Array Float) : ConcreteMatrix := - let expectedSize := rows * cols - -- Use Array.ofFn to get the exact size we need while padding/truncating via getD. - let finalData := Array.ofFn fun (i : Fin expectedSize) => - data.getD i.val 0.0 - { - numRows := rows - numCols := cols - data := finalData - size_eq := Array.size_ofFn - } - -/-! ## Load Result Helpers -/ - -/-- Result of loading a model. -/ -inductive LoadResult - | ok (model : ConcreteModel) - | error (msg : String) - -namespace LoadResult - -def isOk : LoadResult -> Bool - | ok _ => true - | error _ => false - -def getModel : LoadResult -> Option ConcreteModel - | ok m => some m - | error _ => none - -def getError : LoadResult -> Option String - | ok _ => none - | error msg => some msg - -end LoadResult - -/-! ## Text Format Parsing -/ - -/-- NFP file header structure. -/ -structure NfpHeader where - numLayers : Nat - numHeads : Nat - modelDim : Nat - headDim : Nat - hiddenDim : Nat - vocabSize : Nat - seqLen : Nat - deriving Repr - -/- Build a ConcreteAttentionLayer from weight matrices. - The dimension proofs are satisfied by construction (buildMatrix ensures correct sizes). 
-/ -def mkAttentionLayer - (modelDim headDim : Nat) - (wq wk wv wo bq bk bv : Array Float) : ConcreteAttentionLayer := - let wQ := buildMatrix modelDim headDim wq - let bQ := buildMatrix 1 headDim bq - let wK := buildMatrix modelDim headDim wk - let bK := buildMatrix 1 headDim bk - let wV := buildMatrix modelDim headDim wv - let bV := buildMatrix 1 headDim bv - let wO := buildMatrix headDim modelDim wo - { - modelDim := modelDim - headDim := headDim - W_Q := wQ - b_Q := bQ - W_K := wK - b_K := bK - W_V := wV - b_V := bV - W_O := wO - W_Q_dims := And.intro rfl rfl - b_Q_dims := And.intro rfl rfl - W_K_dims := And.intro rfl rfl - b_K_dims := And.intro rfl rfl - W_V_dims := And.intro rfl rfl - b_V_dims := And.intro rfl rfl - W_O_dims := And.intro rfl rfl - } - -/- Build a ConcreteMLPLayer from weight matrices. - The dimension proofs are satisfied by construction. -/ -def mkMLPLayer - (modelDim hiddenDim : Nat) - (win wout bin bout : Array Float) : ConcreteMLPLayer := - let wIn := buildMatrix modelDim hiddenDim win - let wOut := buildMatrix hiddenDim modelDim wout - let bIn := buildMatrix 1 hiddenDim bin - let bOut := buildMatrix 1 modelDim bout - { - modelDim := modelDim - hiddenDim := hiddenDim - W_in := wIn - W_out := wOut - b_in := bIn - b_out := bOut - W_in_dims := And.intro rfl rfl - W_out_dims := And.intro rfl rfl - b_in_dims := And.intro rfl rfl - b_out_dims := And.intro rfl rfl - } - -/-! ## Tokenization Utilities -/ - -/-- Simple tokenizer with vocabulary mapping. -/ -structure Tokenizer where - /-- Token strings in order of ID. -/ - tokens : Array String - /-- Map from token string to its first ID. -/ - tokMap : Std.HashMap String Nat - /-- Unknown token ID. -/ - unkId : Nat - /-- Padding token ID. -/ - padId : Nat - /-- End of sequence token ID. -/ - eosId : Nat - -namespace Tokenizer - -/-- Create a tokenizer from vocabulary list. 
-/ -def fromVocabList (tokens : Array String) - (unkId padId eosId : Nat := 0) : Tokenizer := - let tokMap := - Id.run do - let mut out : Std.HashMap String Nat := Std.HashMap.emptyWithCapacity tokens.size - let mut i := tokens.size - while i > 0 do - i := i - 1 - out := out.insert tokens[i]! i - return out - { tokens := tokens, tokMap := tokMap, unkId := unkId, padId := padId, eosId := eosId } - -/-- Find a token's ID in the vocabulary. -/ -def findToken (t : Tokenizer) (word : String) : Nat := - t.tokMap.getD word t.unkId - -/-- Tokenize a string using simple whitespace splitting. -/ -def tokenize (t : Tokenizer) (text : String) : Array Nat := - foldTokensFromLine text #[] fun out start stop => - let word := String.Pos.Raw.extract text start stop - out.push (t.findToken word) - -/-- Decode token IDs back to text. -/ -def decode (t : Tokenizer) (ids : Array Nat) : String := - let tokens := ids.foldr - (fun id acc => - if id < t.tokens.size then - t.tokens[id]! :: acc - else - acc) - [] - " ".intercalate tokens - -end Tokenizer - -/-! ## Embedding Utilities -/ - -/-- Look up embeddings for token IDs from the embedding matrix. -/ -def lookupEmbeddings (embeddings : ConcreteMatrix) (tokenIds : Array Nat) - (seqLen : Nat) (padId : Nat := 0) : ConcreteMatrix := Id.run do - let modelDim := embeddings.numCols - let rowCount := embeddings.numRows - let tokenIdsSize := tokenIds.size - let mut data : Array Float := Array.mkEmpty (seqLen * modelDim) - - for pos in [:seqLen] do - let tokenId := if pos < tokenIdsSize then tokenIds[pos]! else padId - -- Copy embedding row for this token. - if tokenId < rowCount then - let rowBase := tokenId * modelDim - for dim in [:modelDim] do - data := data.push embeddings.data[rowBase + dim]! - else - for _ in [:modelDim] do - data := data.push 0.0 - - buildMatrix seqLen modelDim data - -/-- Set the input embeddings in a model for a given prompt (token IDs). 
-/ -def ConcreteModel.withInputTokens (model : ConcreteModel) - (embeddings : ConcreteMatrix) (tokenIds : Array Nat) - (padId : Nat := 0) : ConcreteModel := - let inputEmb := lookupEmbeddings embeddings tokenIds model.seqLen padId - { model with inputEmbeddings := inputEmb, inputTokens := some tokenIds } - -end Nfp diff --git a/Legacy/Nfp/Induction.lean b/Legacy/Nfp/Induction.lean deleted file mode 100644 index 417848f..0000000 --- a/Legacy/Nfp/Induction.lean +++ /dev/null @@ -1,498 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.Real.Sqrt -import Mathlib.Analysis.InnerProductSpace.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.Order.BigOperators.Ring.Finset -import Nfp.Linearization -import Nfp.SignedMixer -import Nfp.Sound.Bounds - -/-! -# True Induction Head Formalization - -A **True Induction Head** is a rigorously certified mechanism that combines three components: - -1. **Structure**: The attention patterns match an induction head (previous-token + induction), - with the previous-token leg modeled by a self-attention placeholder due to abstract indexing. -2. **Faithfulness**: The virtual head approximation (attention rollout) is ε-certified -3. **Function**: The mechanism effectively increases logit scores for the correct token by ≥ δ - -This module formalizes the definition and proves that true induction heads provide -verifiable guarantees about model behavior. - -## Key Insight - -Most interpretability claims are heuristic. A true induction head is different: it combines -pattern detection with causal certification and functional verification, proving that: - - The discovered mechanism is mathematically sound - - The simplification (attention rollout) is approximately correct - - The mechanism actually causes the predicted output - -Together, these provide end-to-end certification of model behavior. 
--/ - -namespace Nfp - - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-! ## True Induction Head Definition -/ - -/-- Compute L² norm of a function over a finite type. -/ -noncomputable def l2_norm (v : (n × d) → ℝ) : ℝ := - Real.sqrt (∑ pos : n × d, (v pos) ^ 2) - -/-- Inner product of two functions over a finite type. -/ -noncomputable def inner_product (u v : (n × d) → ℝ) : ℝ := - ∑ pos : n × d, u pos * v pos - -/-- **True Induction Head**: A rigorously certified induction mechanism for a specific input. - -An induction head is "true" if it simultaneously satisfies three conditions: - -1. **Structural Pattern**: The attention weights exhibit the induction head - structure (Layer 1 uses a previous-token pattern, modeled as self-attention; Layer 2 - uses a nonnegativity placeholder for token-matching). - This is captured by an `InductionHeadPattern`. - -2. **Faithful Approximation**: The virtual head (composition of value terms, - aka "attention rollout") is ε-certified—it approximates the true composed Jacobian - within Frobenius norm ε. - -3. **Functional Effectiveness**: On the **specific input**, the virtual head's output, - when projected onto the target logit difference direction, produces at least δ increase - in score. This binds the abstract mechanism to the concrete model behavior. 
--/ -structure TrueInductionHead where - /-- The model input (residual stream at sequence positions) -/ - input : (n × d) → ℝ - /-- Certified induction head pattern (has layer1 and layer2 with attention properties) -/ - pattern : InductionHeadPattern (n := n) (d := d) - /-- The composed true Jacobian from input to output -/ - composed_jacobian : SignedMixer (n × d) (n × d) - /-- Target direction in residual stream space (how positions/dimensions contribute to target) -/ - target_logit_diff : (n × d) → ℝ - /-- Faithfulness bound: how close virtual head is to composed Jacobian -/ - epsilon : ℝ - /-- Functional effectiveness bound: minimum logit increase from this mechanism -/ - delta : ℝ - /-- Faithfulness: Virtual head approximates composed Jacobian within ε -/ - faithful : isCertifiedVirtualHead pattern.layer2 pattern.layer1 composed_jacobian epsilon - /-- Effectiveness: Virtual head applied to this input produces ≥ delta on target direction -/ - effective : inner_product (VirtualHead pattern.layer2 pattern.layer1 |>.apply input) - target_logit_diff ≥ delta - /-- Bounds are valid -/ - epsilon_nonneg : 0 ≤ epsilon - /-- Delta is nonnegative (can't guarantee negative output) -/ - delta_nonneg : 0 ≤ delta - -/-! ## Sound pattern witnesses -/ - -/-- Minimal token-match pattern witness used by the sound certification path. -/ -structure TokenMatchPattern where - /-- Sequence length for the certificate. -/ - seqLen : Nat - /-- Target offset (e.g. `-1` for previous token). -/ - targetOffset : Int - /-- Key-position offset used when matching tokens against the query's target token. -/ - keyOffset : Int - /-- Lower bound on the number of matching-token keys. -/ - targetCountLowerBound : Nat - /-- Effort level for the `expLB` portfolio used in margin-to-weight bounds. -/ - softmaxExpEffort : Nat - /-- Lower bound on total attention weight assigned to matching tokens. -/ - targetWeightLowerBound : Rat - /-- Lower bound on logit margin between matching vs non-matching keys. 
-/ - marginLowerBound : Rat - deriving Repr - -namespace TokenMatchPattern - -/-- Soundness invariant for token-match pattern witnesses. -/ -def Valid (p : TokenMatchPattern) : Prop := - p.seqLen > 0 ∧ - p.targetCountLowerBound ≤ p.seqLen ∧ - p.targetWeightLowerBound = - Sound.softmaxTargetWeightLowerBound p.seqLen p.targetCountLowerBound - p.marginLowerBound p.softmaxExpEffort - -instance (p : TokenMatchPattern) : Decidable (Valid p) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (p : TokenMatchPattern) : Bool := - decide (Valid p) - -theorem check_iff (p : TokenMatchPattern) : p.check = true ↔ p.Valid := by - simp [check, Valid] - -/-- If the margin and target count are positive, the weight lower bound matches -the portfolio bound derived from `expLB`. -/ -theorem weight_lower_bound_of_margin_pos - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) - (hcount : 0 < p.targetCountLowerBound) : - p.targetWeightLowerBound = - let nRat : Rat := (p.seqLen : Nat) - let tRat : Rat := (p.targetCountLowerBound : Nat) - let base := tRat / nRat - let e := Sound.expLB p.marginLowerBound p.softmaxExpEffort - let cand := (tRat * e) / (tRat * e + (nRat - tRat)) - max base cand := by - rcases h with ⟨hseq, _hcount, hweight⟩ - have hseq0 : p.seqLen ≠ 0 := Nat.ne_of_gt hseq - have hcount0 : p.targetCountLowerBound ≠ 0 := Nat.ne_of_gt hcount - simpa [Sound.softmaxTargetWeightLowerBound_def, hseq0, hcount0, hm] using hweight - -/-- If the margin is nonpositive, the weight lower bound is zero. 
-/ -theorem weight_lower_bound_of_margin_nonpos - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound ≤ 0) : - p.targetWeightLowerBound = 0 := by - rcases h with ⟨_hseq, _hcount, hweight⟩ - have hm' : ¬ p.marginLowerBound > 0 := by - exact not_lt.mpr hm - by_cases hzero : p.seqLen = 0 || p.targetCountLowerBound = 0 - · simpa [Sound.softmaxTargetWeightLowerBound_def, hzero] using hweight - · simpa [Sound.softmaxTargetWeightLowerBound_def, hzero, hm'] using hweight - -/-- Positive margin and a positive target count imply positive attention mass. -/ -theorem weight_lower_bound_pos_of_margin_pos - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) - (hcount : 0 < p.targetCountLowerBound) : - 0 < p.targetWeightLowerBound := by - let nRat : Rat := (p.seqLen : Nat) - let tRat : Rat := (p.targetCountLowerBound : Nat) - let base : Rat := tRat / nRat - let e := Sound.expLB p.marginLowerBound p.softmaxExpEffort - let cand : Rat := (tRat * e) / (tRat * e + (nRat - tRat)) - have hweight := weight_lower_bound_of_margin_pos p h hm hcount - rcases h with ⟨hseq, _hcount, _hweight⟩ - have hseq' : (0 : Rat) < nRat := by - have hseq'' : (0 : Rat) < (p.seqLen : Rat) := by - exact_mod_cast hseq - simpa [nRat] using hseq'' - have hcount' : (0 : Rat) < tRat := by - have hcount'' : (0 : Rat) < (p.targetCountLowerBound : Rat) := by - exact_mod_cast hcount - simpa [tRat] using hcount'' - have hbase : (0 : Rat) < base := by - exact div_pos hcount' hseq' - have hmax : base ≤ max base cand := by - exact le_max_left _ _ - have hpos : (0 : Rat) < max base cand := by - exact lt_of_lt_of_le hbase hmax - simpa [nRat, tRat, base, e, cand, hweight] using hpos - -/-- Either the margin is nonpositive (so the bound is zero), -or the bound is positive when the match count is positive. 
-/ -theorem weight_lower_bound_dichotomy - (p : TokenMatchPattern) (h : p.Valid) (hcount : 0 < p.targetCountLowerBound) : - p.marginLowerBound ≤ 0 ∨ 0 < p.targetWeightLowerBound := by - by_cases hm : p.marginLowerBound > 0 - · right - exact weight_lower_bound_pos_of_margin_pos p h hm hcount - · left - exact not_lt.mp hm - -end TokenMatchPattern - -/-! ## Sound induction witnesses -/ - -/-- A minimal sound witness for an induction-style attention pattern. -/ -structure InductionPatternWitness where - /-- Token-match pattern data (sound certificate output). -/ - tokenMatch : TokenMatchPattern - /-- The pattern targets the previous-token offset. -/ - prevOffset : tokenMatch.targetOffset = -1 - /-- The key-token comparison uses no key offset. -/ - keyOffsetZero : tokenMatch.keyOffset = 0 - /-- Certified nontrivial attention mass on matching tokens. -/ - positiveMass : 0 < tokenMatch.targetWeightLowerBound - deriving Repr - -/-- A minimal sound witness for a copy-next induction-style attention pattern. -/ -structure CopyNextPatternWitness where - /-- Token-match pattern data (sound certificate output). -/ - tokenMatch : TokenMatchPattern - /-- The pattern uses the current query token as the target. -/ - targetOffsetZero : tokenMatch.targetOffset = 0 - /-- Keys are matched against the previous-token stream (copy-next). -/ - keyOffsetPrev : tokenMatch.keyOffset = -1 - /-- Certified nontrivial attention mass on matching tokens. -/ - positiveMass : 0 < tokenMatch.targetWeightLowerBound - deriving Repr - -namespace TokenMatchPattern - -/-- Build an induction-style witness from a valid token-match pattern plus explicit assumptions. 
-/ -def toInductionPatternWitness - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) - (hcount : 0 < p.targetCountLowerBound) (hoff : p.targetOffset = -1) - (hkey : p.keyOffset = 0) : - InductionPatternWitness := - { - tokenMatch := p - prevOffset := hoff - keyOffsetZero := hkey - positiveMass := weight_lower_bound_pos_of_margin_pos p h hm hcount - } - -/-- Build a copy-next witness from a valid token-match pattern plus explicit assumptions. -/ -def toCopyNextPatternWitness - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) - (hcount : 0 < p.targetCountLowerBound) (hoff : p.targetOffset = 0) - (hkey : p.keyOffset = -1) : - CopyNextPatternWitness := - { - tokenMatch := p - targetOffsetZero := hoff - keyOffsetPrev := hkey - positiveMass := weight_lower_bound_pos_of_margin_pos p h hm hcount - } - -end TokenMatchPattern - -/-! ## Verification Theorems -/ - -omit [DecidableEq n] [DecidableEq d] in -/-- **Main Theorem**: True Induction Head Bounds - -Any true induction head has nonnegative epsilon and delta bounds by definition. -/ -theorem true_induction_head_bounds_nonneg {h : TrueInductionHead (n := n) (d := d)} : - (h.epsilon ≥ 0) ∧ (h.delta ≥ 0) := - ⟨h.epsilon_nonneg, h.delta_nonneg⟩ - -omit [DecidableEq n] [DecidableEq d] in -/-- **Key Property**: Virtual head achieves the stated delta bound. - -By definition of `TrueInductionHead`, the virtual head applied to the input -achieves at least delta on the target direction. --/ -lemma virtual_head_achieves_delta {h : TrueInductionHead (n := n) (d := d)} : - inner_product ((VirtualHead h.pattern.layer2 h.pattern.layer1).apply h.input) - h.target_logit_diff ≥ h.delta := - h.effective - -/-! ## Properties of True Induction Heads -/ - -/-- The virtual head output on the certified input. 
-/ -noncomputable def virtual_head_output {h : TrueInductionHead (n := n) (d := d)} : - (n × d) → ℝ := - (VirtualHead h.pattern.layer2 h.pattern.layer1).apply h.input - -/-- The virtual head's score on the target direction. -/ -noncomputable def virtual_head_score {h : TrueInductionHead (n := n) (d := d)} : ℝ := - inner_product (virtual_head_output (h := h)) h.target_logit_diff - -/-- The approximation error bound. -/ -abbrev approx_error {h : TrueInductionHead (n := n) (d := d)} : ℝ := - h.epsilon - -/-- The functional guarantee on the virtual head. -/ -abbrev min_logit_shift {h : TrueInductionHead (n := n) (d := d)} : ℝ := - h.delta - -omit [DecidableEq n] [DecidableEq d] in -/-- **Composition of mechanisms**: Composed error bound. - -If two true induction heads have errors ε₁ and ε₂ respectively, their -composition has bounded error from the rule: ε_total ≤ ε₁ + ε₂ + ε₁·ε₂. --/ -theorem true_induction_head_composition - (h₁ h₂ : TrueInductionHead (n := n) (d := d)) - (ε : ℝ) - (hε_bound : ε ≤ h₁.epsilon + h₂.epsilon + h₁.epsilon * h₂.epsilon) : - ε ≤ h₁.epsilon + h₂.epsilon + h₁.epsilon * h₂.epsilon := hε_bound - -omit [DecidableEq n] [DecidableEq d] in -/-- **Interpretability Guarantee**: True induction heads are real mechanisms. -/ -theorem true_induction_head_is_genuine - (h : TrueInductionHead (n := n) (d := d)) : - (∃ L₁ L₂, h.pattern.layer1 = L₁ ∧ h.pattern.layer2 = L₂) ∧ - (isCertifiedVirtualHead h.pattern.layer2 h.pattern.layer1 h.composed_jacobian h.epsilon) ∧ - (inner_product ((VirtualHead h.pattern.layer2 h.pattern.layer1).apply h.input) - h.target_logit_diff ≥ h.delta) := by - exact ⟨⟨h.pattern.layer1, h.pattern.layer2, rfl, rfl⟩, h.faithful, h.effective⟩ - -/-! ## Helper inequality: Frobenius norm bounds application -/ - -omit [DecidableEq n] [DecidableEq d] in -/-- For any signed mixer `M` and vector `v`, the output L² norm is bounded by the Frobenius -norm of `M` times the input L² norm. 
-/ -lemma norm_apply_le (M : SignedMixer (n × d) (n × d)) (v : (n × d) → ℝ) : - l2_norm (M.apply v) ≤ frobeniusNorm (n := n) (d := d) M * l2_norm v := by - classical - set A : ℝ := ∑ i : n × d, (v i) ^ 2 - set C : ℝ := ∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2 - have hA : 0 ≤ A := by - simpa [A] using (Finset.sum_nonneg (fun i _hi => sq_nonneg (v i))) - have hC : 0 ≤ C := by - -- two nested sums of squares - have : 0 ≤ ∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2 := by - refine Finset.sum_nonneg ?_ - intro i _hi - refine Finset.sum_nonneg ?_ - intro j _hj - exact sq_nonneg (M.w i j) - simpa [C] using this - have hpoint : - ∀ j : n × d, (M.apply v j) ^ 2 ≤ A * (∑ i : n × d, (M.w i j) ^ 2) := by - intro j - -- Cauchy–Schwarz (squared form) on the dot product defining `(M.apply v) j`. - simpa [SignedMixer.apply_def, A] using - (Finset.sum_mul_sq_le_sq_mul_sq (s := (Finset.univ : Finset (n × d))) - (f := v) (g := fun i : n × d => M.w i j)) - have hsum : - (∑ j : n × d, (M.apply v j) ^ 2) ≤ ∑ j : n × d, A * (∑ i : n × d, (M.w i j) ^ 2) := by - refine Finset.sum_le_sum ?_ - intro j _hj - exact hpoint j - have hsum' : - (∑ j : n × d, (M.apply v j) ^ 2) ≤ A * (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) := by - have hfac : - A * (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) = - ∑ j : n × d, A * (∑ i : n × d, (M.w i j) ^ 2) := by - simpa using (Finset.mul_sum (s := (Finset.univ : Finset (n × d))) (a := A) - (f := fun j : n × d => ∑ i : n × d, (M.w i j) ^ 2)) - calc - (∑ j : n × d, (M.apply v j) ^ 2) - ≤ ∑ j : n × d, A * (∑ i : n × d, (M.w i j) ^ 2) := hsum - _ = A * (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) := by - simp [hfac] - have hsum'' : - (∑ j : n × d, (M.apply v j) ^ 2) ≤ A * C := by - have hswap : - (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) = - ∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2 := by - simpa using (Finset.sum_comm : - (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) = - ∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2) - calc - (∑ j : n × d, (M.apply v j) ^ 2) - ≤ A * (∑ j : n × d, ∑ i : n 
× d, (M.w i j) ^ 2) := hsum' - _ = A * C := by - simp [C, hswap] - -- take square roots and unfold definitions - calc - l2_norm (M.apply v) - = Real.sqrt (∑ j : n × d, (M.apply v j) ^ 2) := rfl - _ ≤ Real.sqrt (A * C) := by - exact Real.sqrt_le_sqrt hsum'' - _ = Real.sqrt A * Real.sqrt C := by - simpa using (Real.sqrt_mul hA C) - _ = frobeniusNorm (n := n) (d := d) M * l2_norm v := by - simp [frobeniusNorm, l2_norm, C, A, mul_comm] - -omit [DecidableEq n] [DecidableEq d] in -/-- Finite-dimensional Cauchy–Schwarz for the `inner_product`/`l2_norm` defined in this file. -/ -lemma abs_inner_product_le_l2 (u v : (n × d) → ℝ) : - |inner_product u v| ≤ l2_norm u * l2_norm v := by - classical - have hcs : - (inner_product u v) ^ 2 ≤ (∑ i : n × d, (u i) ^ 2) * (∑ i : n × d, (v i) ^ 2) := by - simpa [inner_product] using - (Finset.sum_mul_sq_le_sq_mul_sq (s := (Finset.univ : Finset (n × d))) (f := u) (g := v)) - have hu : 0 ≤ ∑ i : n × d, (u i) ^ 2 := by - simpa using (Finset.sum_nonneg (fun i _hi => sq_nonneg (u i))) - have hv : 0 ≤ ∑ i : n × d, (v i) ^ 2 := by - simpa using (Finset.sum_nonneg (fun i _hi => sq_nonneg (v i))) - calc - |inner_product u v| - = Real.sqrt ((inner_product u v) ^ 2) := by - simpa using (Real.sqrt_sq_eq_abs (inner_product u v)).symm - _ ≤ Real.sqrt ((∑ i : n × d, (u i) ^ 2) * (∑ i : n × d, (v i) ^ 2)) := by - exact Real.sqrt_le_sqrt hcs - _ = Real.sqrt (∑ i : n × d, (u i) ^ 2) * Real.sqrt (∑ i : n × d, (v i) ^ 2) := by - simpa using (Real.sqrt_mul hu (∑ i : n × d, (v i) ^ 2)) - _ = l2_norm u * l2_norm v := by - rfl - -omit [DecidableEq n] [DecidableEq d] in -/-- **Main verification theorem**: a `TrueInductionHead` lower-bounds the real model score -on the target direction by `δ` minus the certified approximation error. 
-/ -theorem true_induction_head_predicts_logits - (h : TrueInductionHead (n := n) (d := d)) : - inner_product (h.composed_jacobian.apply h.input) h.target_logit_diff ≥ - h.delta - (h.epsilon * l2_norm h.input * l2_norm h.target_logit_diff) := by - classical - let V : SignedMixer (n × d) (n × d) := VirtualHead h.pattern.layer2 h.pattern.layer1 - let E : SignedMixer (n × d) (n × d) := h.composed_jacobian - V - have hE : frobeniusNorm (n := n) (d := d) E ≤ h.epsilon := by - simpa [E, V, isCertifiedVirtualHead] using h.faithful - have hV : h.delta ≤ inner_product (V.apply h.input) h.target_logit_diff := by - simpa [V] using h.effective - have happly_add : (V + E).apply h.input = V.apply h.input + E.apply h.input := by - ext j - simp [SignedMixer.apply_def, Finset.sum_add_distrib, mul_add] - have hJ_eq : h.composed_jacobian = V + E := by - ext i j - simp [E, V] - have hdecomp : - inner_product (h.composed_jacobian.apply h.input) h.target_logit_diff = - inner_product (V.apply h.input) h.target_logit_diff + - inner_product (E.apply h.input) h.target_logit_diff := by - have happly : - h.composed_jacobian.apply h.input = V.apply h.input + E.apply h.input := by - simpa [hJ_eq] using happly_add - have hinner_add (a b u : (n × d) → ℝ) : - inner_product (a + b) u = inner_product a u + inner_product b u := by - simp [inner_product, Finset.sum_add_distrib, add_mul] - calc - inner_product (h.composed_jacobian.apply h.input) h.target_logit_diff - = inner_product (V.apply h.input + E.apply h.input) h.target_logit_diff := by - simp [happly] - _ = inner_product (V.apply h.input) h.target_logit_diff + - inner_product (E.apply h.input) h.target_logit_diff := by - simpa using hinner_add (a := V.apply h.input) (b := E.apply h.input) - (u := h.target_logit_diff) - set bound : ℝ := h.epsilon * l2_norm h.input * l2_norm h.target_logit_diff - have hbound_nonneg : 0 ≤ bound := by - have hx : 0 ≤ l2_norm h.input := by simp [l2_norm] - have hu : 0 ≤ l2_norm h.target_logit_diff := by simp 
[l2_norm] - have : 0 ≤ h.epsilon * l2_norm h.input := mul_nonneg h.epsilon_nonneg hx - simpa [bound, mul_assoc] using mul_nonneg this hu - have herr_abs : - |inner_product (E.apply h.input) h.target_logit_diff| ≤ bound := by - have habs : - |inner_product (E.apply h.input) h.target_logit_diff| ≤ - l2_norm (E.apply h.input) * l2_norm h.target_logit_diff := by - simpa using (abs_inner_product_le_l2 (n := n) (d := d) (u := E.apply h.input) - (v := h.target_logit_diff)) - have hnormEx : - l2_norm (E.apply h.input) ≤ frobeniusNorm (n := n) (d := d) E * l2_norm h.input := by - simpa using (norm_apply_le (n := n) (d := d) E h.input) - have hu : 0 ≤ l2_norm h.target_logit_diff := by simp [l2_norm] - have hx : 0 ≤ l2_norm h.input := by simp [l2_norm] - have hstep1 : - l2_norm (E.apply h.input) * l2_norm h.target_logit_diff ≤ - (frobeniusNorm (n := n) (d := d) E * l2_norm h.input) * l2_norm h.target_logit_diff := - mul_le_mul_of_nonneg_right hnormEx hu - have hstep2 : - (frobeniusNorm (n := n) (d := d) E * l2_norm h.input) * l2_norm h.target_logit_diff ≤ - (h.epsilon * l2_norm h.input) * l2_norm h.target_logit_diff := by - have : frobeniusNorm (n := n) (d := d) E * l2_norm h.input ≤ h.epsilon * l2_norm h.input := - mul_le_mul_of_nonneg_right hE hx - exact mul_le_mul_of_nonneg_right this hu - have hchain := le_trans hstep1 hstep2 - have hchain' : - l2_norm (E.apply h.input) * l2_norm h.target_logit_diff ≤ bound := by - simpa [bound, mul_assoc, mul_left_comm, mul_comm] using hchain - exact le_trans habs hchain' - have herr_lower : -bound ≤ inner_product (E.apply h.input) h.target_logit_diff := by - exact (abs_le.mp herr_abs).1 - -- Combine: = + ≥ δ + (-bound) = δ - bound - have hsum_le : - h.delta + (-bound) ≤ - inner_product (V.apply h.input) h.target_logit_diff + - inner_product (E.apply h.input) h.target_logit_diff := by - exact add_le_add hV herr_lower - -- rewrite the goal via the decomposition - have : - h.delta - bound ≤ - inner_product (h.composed_jacobian.apply 
h.input) h.target_logit_diff := by - simpa [sub_eq_add_neg, hdecomp] using hsum_le - exact this - -end Nfp diff --git a/Legacy/Nfp/Influence.lean b/Legacy/Nfp/Influence.lean deleted file mode 100644 index ed041df..0000000 --- a/Legacy/Nfp/Influence.lean +++ /dev/null @@ -1,342 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Data.Fintype.Prod -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Algebra.Field.Basic -import Mathlib.Logic.Equiv.Fin.Basic -import Mathlib.Tactic.DeriveFintype -import Nfp.Prob -import Nfp.Mixer - -/- -Core influence specifications and helpers that sit below mixers. Influence -specs carry raw, non-normalized edge capacities; `Mixer.ofInfluenceSpec` turns -them into row-stochastic mixers with a small fallback for empty rows. This -file also introduces lightweight sign and hierarchy indices (`Chan`, `HSite`, -`BigSite`) together with convenience projections for tracers living on the -combined site. --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Influence specs -/ - -/-- A raw influence description: adjacency plus nonnegative edge capacities. -/ -structure InfluenceSpec (Site : Type*) where - adj : Site → Site → Prop - κ : ∀ {s t}, adj s t → NNReal - -/-- A family of influence specs indexed by different “views”. -/ -structure InfluenceSpecFamily (Site : Type*) where - View : Type* - spec : View → InfluenceSpec Site - -namespace InfluenceSpecFamily - -variable {Site : Type*} - -/-- Convexly combine multiple influence specs using nonnegative weights `α`. -Specs with zero weight do not contribute to the merged adjacency. 
-/ -noncomputable def combine (F : InfluenceSpecFamily Site) [Fintype F.View] - (α : F.View → NNReal) (_hα : ∑ v, α v = 1) : InfluenceSpec Site := - { - adj := fun s t => ∃ v, α v ≠ 0 ∧ (F.spec v).adj s t - κ := by - intro s t _ - classical - exact ∑ v, α v * (if h : (F.spec v).adj s t then (F.spec v).κ h else 0) - } - -end InfluenceSpecFamily - -namespace InfluenceSpec - -variable {Site : Type*} [Fintype Site] [DecidableEq Site] - -/-- Sum of outgoing raw capacities from a site. -/ -noncomputable def rowTotal (I : InfluenceSpec Site) (s : Site) : NNReal := by - classical - exact ∑ t, (if h : I.adj s t then I.κ h else 0) - -/-! ### Row scaling helpers -/ - -/-- Scale all raw capacities in a single row of an influence spec. -/ -noncomputable def scaleRow (I : InfluenceSpec Site) (s0 : Site) (c : NNReal) : - InfluenceSpec Site := - { - adj := I.adj - κ := by - intro s t h - by_cases hs : s = s0 - · subst hs - exact c * I.κ h - · exact I.κ h - } - -lemma rowTotal_scaleRow_self (I : InfluenceSpec Site) (s0 : Site) (c : NNReal) : - InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s0 = - c * InfluenceSpec.rowTotal (Site := Site) I s0 := by - classical - simpa [InfluenceSpec.rowTotal, scaleRow] using - (Finset.mul_sum (s := (Finset.univ : Finset Site)) - (f := fun t : Site => (if h : I.adj s0 t then I.κ h else 0)) - (a := c)).symm - -lemma rowTotal_scaleRow_other (I : InfluenceSpec Site) {s s0 : Site} (c : NNReal) - (hs : s ≠ s0) : - InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s = - InfluenceSpec.rowTotal (Site := Site) I s := by - classical - simp [InfluenceSpec.rowTotal, scaleRow, hs] - -end InfluenceSpec - -namespace InfluenceSpecFamily - -variable {Site : Type*} - -lemma combine_adj_iff (F : InfluenceSpecFamily Site) [Fintype F.View] - (α : F.View → NNReal) (hα : ∑ v, α v = 1) (s t : Site) : - (combine (F:=F) α hα).adj s t ↔ ∃ v, α v ≠ 0 ∧ (F.spec v).adj s t := - Iff.rfl - -end InfluenceSpecFamily - -/-! 
## Mixers derived from influence specs -/ - -namespace Mixer - -variable {Site : Type*} [Fintype Site] [DecidableEq Site] - -/-- Canonical mixer obtained by normalizing each row of an influence spec. -If a row has zero total capacity, we fall back to an identity row on that site, -interpreting the site as an absorbing state. -/ -noncomputable def ofInfluenceSpec (I : InfluenceSpec Site) : Mixer Site Site := by - classical - let Z : Site → NNReal := InfluenceSpec.rowTotal (Site := Site) I - refine - { - w := fun s t => - if hZ : Z s = 0 then - if hst : s = t then 1 else 0 - else - if hAdj : I.adj s t then I.κ hAdj / Z s else 0 - row_sum_one := by - intro s - by_cases hZ : Z s = 0 - · have hdiag : - (∑ t : Site, (if s = t then (1 : NNReal) else 0)) = 1 := by - classical - simp - simp [Z, hZ, hdiag] - · have hZne : Z s ≠ 0 := hZ - have hnormalize : - (∑ t : Site, (if h : I.adj s t then I.κ h / Z s else 0)) = - (∑ t : Site, (if h : I.adj s t then I.κ h else 0)) * (1 / Z s) := by - classical - have hrewrite : - (∑ t : Site, (if h : I.adj s t then I.κ h / Z s else 0)) = - (∑ t : Site, (if h : I.adj s t then I.κ h else 0) * (1 / Z s)) := by - refine Finset.sum_congr rfl ?_ - intro t _ht - by_cases hAdj : I.adj s t <;> simp [hAdj, div_eq_mul_inv, mul_comm] - have hfactor : - (∑ t : Site, (if h : I.adj s t then I.κ h else 0) * (1 / Z s)) = - (∑ t : Site, (if h : I.adj s t then I.κ h else 0)) * (1 / Z s) := by - simpa using - (Finset.sum_mul (s := (Finset.univ : Finset Site)) - (f := fun t : Site => if h : I.adj s t then I.κ h else 0) - (a := (1 / Z s))).symm - exact hrewrite.trans hfactor - have hrow : (∑ t : Site, (if h : I.adj s t then I.κ h else 0)) = Z s := rfl - have hnormalized : - (∑ t : Site, (if h : I.adj s t then I.κ h / Z s else 0)) = 1 := by - have hdiv : - (∑ t : Site, (if h : I.adj s t then I.κ h else 0)) * (1 / Z s) = - (1 : NNReal) := by - simp [hrow, div_eq_mul_inv, hZne] - exact hnormalize.trans hdiv - simpa [Z, hZ] using hnormalized - } - -lemma 
ofInfluenceSpec_zero_of_not_adj (I : InfluenceSpec Site) - {s t : Site} (hZ : InfluenceSpec.rowTotal (Site := Site) I s ≠ 0) - (hAdj : ¬ I.adj s t) : - (ofInfluenceSpec (Site := Site) I).w s t = 0 := by - classical - simp [ofInfluenceSpec, hZ, hAdj] - -lemma ofInfluenceSpec_adj_weight (I : InfluenceSpec Site) - {s t : Site} (hZ : InfluenceSpec.rowTotal (Site := Site) I s ≠ 0) - (hAdj : I.adj s t) : - (ofInfluenceSpec (Site := Site) I).w s t = - I.κ hAdj / InfluenceSpec.rowTotal (Site := Site) I s := by - classical - simp [ofInfluenceSpec, hZ, hAdj] - -lemma ofInfluenceSpec_supported (I : InfluenceSpec Site) - (hZ : ∀ s, InfluenceSpec.rowTotal (Site := Site) I s ≠ 0) : - Mixer.supported (S := Site) (T := Site) (ofInfluenceSpec (Site := Site) I) I.adj := by - intro s t hAdj - exact ofInfluenceSpec_zero_of_not_adj (Site := Site) (I:=I) (hZ s) hAdj - -lemma ofInfluenceSpec_row_scaling (I : InfluenceSpec Site) (s0 : Site) {c : NNReal} - (hc : c ≠ 0) (t : Site) : - (ofInfluenceSpec (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c)).w s0 t = - (ofInfluenceSpec (Site := Site) I).w s0 t := by - classical - by_cases hZ : InfluenceSpec.rowTotal (Site := Site) I s0 = 0 - · have hZ' : - InfluenceSpec.rowTotal (Site := Site) - (InfluenceSpec.scaleRow (Site := Site) I s0 c) s0 = 0 := by - simp [InfluenceSpec.rowTotal_scaleRow_self (I:=I) (s0:=s0) (c:=c), hZ] - simp [ofInfluenceSpec, hZ, hZ'] - · have hrow : - InfluenceSpec.rowTotal (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c) - s0 = c * InfluenceSpec.rowTotal (Site := Site) I s0 := - InfluenceSpec.rowTotal_scaleRow_self (I:=I) (s0:=s0) (c:=c) - have hZ' : - InfluenceSpec.rowTotal (Site := Site) - (InfluenceSpec.scaleRow (Site := Site) I s0 c) s0 ≠ 0 := by - have hne : c * InfluenceSpec.rowTotal (Site := Site) I s0 ≠ 0 := - mul_ne_zero hc hZ - simpa [hrow] using hne - by_cases hAdj : I.adj s0 t - · have hκ : - (InfluenceSpec.scaleRow (Site := Site) I s0 c).κ hAdj = c * I.κ hAdj := by - simp 
[InfluenceSpec.scaleRow] - have hleft : - (ofInfluenceSpec (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c)).w s0 t = - (InfluenceSpec.scaleRow (Site := Site) I s0 c).κ hAdj / - InfluenceSpec.rowTotal (Site := Site) - (InfluenceSpec.scaleRow (Site := Site) I s0 c) s0 := - ofInfluenceSpec_adj_weight (I := InfluenceSpec.scaleRow (Site := Site) I s0 c) hZ' hAdj - have hright : - (ofInfluenceSpec (Site := Site) I).w s0 t = - I.κ hAdj / InfluenceSpec.rowTotal (Site := Site) I s0 := - ofInfluenceSpec_adj_weight (I := I) hZ hAdj - have hcancel : - (InfluenceSpec.scaleRow (Site := Site) I s0 c).κ hAdj / - InfluenceSpec.rowTotal - (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c) s0 = - c * I.κ hAdj / (c * InfluenceSpec.rowTotal (Site := Site) I s0) := by - simp [hκ, hrow] - calc - (ofInfluenceSpec (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c)).w s0 t = - c * I.κ hAdj / (c * InfluenceSpec.rowTotal (Site := Site) I s0) := by - simp [hcancel] at hleft - simpa using hleft - _ = I.κ hAdj / InfluenceSpec.rowTotal (Site := Site) I s0 := by - have hc' : (c : NNReal) ≠ 0 := hc - simpa [mul_comm, mul_left_comm, mul_assoc] using - (mul_div_mul_left (a := I.κ hAdj) - (b := InfluenceSpec.rowTotal (Site := Site) I s0) - (c := c) (hc := hc')) - _ = (ofInfluenceSpec (Site := Site) I).w s0 t := hright.symm - · have hleft : - (ofInfluenceSpec (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c)).w s0 t = - 0 := - ofInfluenceSpec_zero_of_not_adj - (I := InfluenceSpec.scaleRow (Site := Site) I s0 c) hZ' hAdj - have hright : - (ofInfluenceSpec (Site := Site) I).w s0 t = 0 := - ofInfluenceSpec_zero_of_not_adj (I := I) hZ hAdj - simp [hleft, hright] - -end Mixer - -/-! ## Sign channels and hierarchical sites -/ - -/-- Explicit sign channels; sign is tracked structurally, not by negative mass. -/ -inductive Chan - | pos - | neg -deriving DecidableEq, Fintype - -/-- Hierarchical site: either a base node or a group-level aggregate. 
-/ -inductive HSite (Base Group : Type*) : Type _ - | base : Base → HSite Base Group - | group : Group → HSite Base Group -deriving DecidableEq - -namespace HSite - -variable {Base Group : Type*} - -/-- An equivalence to a sum type, used to reuse existing finite instances. -/ -def equivSum : HSite Base Group ≃ Sum Base Group := - { - toFun := fun - | base b => Sum.inl b - | group g => Sum.inr g - invFun := fun - | Sum.inl b => base b - | Sum.inr g => group g - left_inv := by - intro x - cases x <;> rfl - right_inv := by - intro x - cases x <;> rfl - } - -instance [Fintype Base] [Fintype Group] : Fintype (HSite Base Group) := - Fintype.ofEquiv (Sum Base Group) (equivSum (Base:=Base) (Group:=Group)).symm - -end HSite - -/-- Combined site carrying an objective, a base node with sign, or a group. -/ -abbrev BigSite (Obj Node Group : Type*) := - Obj × HSite (Node × Chan) Group - -/-! ## Tracer view helpers on `BigSite` -/ - -namespace TracerViews - -section - -variable {Obj Node Group : Type*} [Fintype Obj] [Fintype Node] [Fintype Group] -variable (p : ProbVec (BigSite Obj Node Group)) - -/-- Positive-channel tracer mass for a specific objective/node pair. -/ -noncomputable def posTracer (o : Obj) (n : Node) : NNReal := - p.mass (o, HSite.base (n, Chan.pos)) - -/-- Negative-channel tracer mass for a specific objective/node pair. -/ -noncomputable def negTracer (o : Obj) (n : Node) : NNReal := - p.mass (o, HSite.base (n, Chan.neg)) - -/-- Net tracer mass (`pos - neg`) for a specific objective/node pair. -/ -noncomputable def netTracer (o : Obj) (n : Node) : ℝ := - (posTracer (p:=p) o n : ℝ) - (negTracer (p:=p) o n : ℝ) - -/-- Group-level tracer mass for an objective and group. 
-/ -noncomputable def groupTracer (o : Obj) (g : Group) : NNReal := - p.mass (o, HSite.group g) - -@[simp] lemma posTracer_eval (o : Obj) (n : Node) : - posTracer (p:=p) o n = p.mass (o, HSite.base (n, Chan.pos)) := by - rfl - -@[simp] lemma negTracer_eval (o : Obj) (n : Node) : - negTracer (p:=p) o n = p.mass (o, HSite.base (n, Chan.neg)) := by - rfl - -@[simp] lemma groupTracer_eval (o : Obj) (g : Group) : - groupTracer (p:=p) o g = p.mass (o, HSite.group g) := by - rfl - -end - -end TracerViews - -open TracerViews (posTracer negTracer netTracer groupTracer) -export TracerViews (posTracer negTracer netTracer groupTracer) - -end Nfp diff --git a/Legacy/Nfp/Layers.lean b/Legacy/Nfp/Layers.lean deleted file mode 100644 index 6b39833..0000000 --- a/Legacy/Nfp/Layers.lean +++ /dev/null @@ -1,1046 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Nfp.Prob -import Nfp.Mixer -import Nfp.Uniqueness - -/-! -# Neural Network Layer Mixers - -This module formalizes common neural network layer operations as mixers, -establishing the connection between abstract row-stochastic operators and -concrete NN architectures. This enables applying the tracer uniqueness and -attribution theorems to real neural network interpretation. 
- -## Main definitions - -* `Mixer.identity` – identity/skip connection -* `Mixer.attention` – attention mechanism as a mixer -* `Mixer.selfAttention` – self-attention variant -* `Mixer.residual` – residual connection combining identity with transform - -## Key theorems - -* `Mixer.identity_comp` – identity is a left/right unit for composition -* `Mixer.comp_assoc` – composition is associative -* `effectiveAttention_normalized` – attention rollout forms valid distributions -* `Mixer.push_comp` – pushing through composition equals sequential pushing - -## Neural network interpretation - -The key insight is that many NN operations can be viewed as row-stochastic -operators when considering how "importance" or "relevance" flows: - -- Attention: importance flows according to attention weights -- Skip connections: importance passes through unchanged -- Residual: weighted combination of skip and transform - -## References - -* Abnar & Zuidema: "Quantifying Attention Flow in Transformers" (2020) --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Identity mixer -/ - -section Identity - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- The identity mixer: each source routes entirely to itself. -/ -noncomputable def Mixer.identity : Mixer S S where - w := fun i j => if i = j then 1 else 0 - row_sum_one := by - classical - intro i - simp only [Finset.sum_ite_eq, Finset.mem_univ, ↓reduceIte] - -@[simp] lemma Mixer.identity_diag (i : S) : Mixer.identity.w i i = 1 := by - simp [Mixer.identity] - -@[simp] lemma Mixer.identity_off_diag {i j : S} (h : i ≠ j) : - Mixer.identity.w i j = 0 := by - simp [Mixer.identity, h] - -/-- Identity is a left unit for mixer composition. 
-/ -@[simp] theorem Mixer.identity_comp (M : Mixer S S) : - Mixer.identity.comp M = M := by - ext i k - simp only [Mixer.comp, Mixer.identity] - classical - simp only [ite_mul, one_mul, zero_mul, Finset.sum_ite_eq, Finset.mem_univ, ↓reduceIte] - -/-- Identity is a right unit for mixer composition. -/ -@[simp] theorem Mixer.comp_identity (M : Mixer S S) : - M.comp Mixer.identity = M := by - ext i k - simp only [Mixer.comp, Mixer.identity] - classical - simp only [mul_ite, mul_one, mul_zero, Finset.sum_ite_eq', Finset.mem_univ, ↓reduceIte] - -end Identity - -/-! ## Mixer composition is associative -/ - -section Associativity - -variable {S T U V : Type*} - [Fintype S] [Fintype T] [Fintype U] [Fintype V] - -/-- Mixer composition is associative. -/ -theorem Mixer.comp_assoc (M : Mixer S T) (N : Mixer T U) (P : Mixer U V) : - (M.comp N).comp P = M.comp (N.comp P) := by - ext i l - simp only [Mixer.comp, Finset.sum_mul, Finset.mul_sum] - rw [Finset.sum_comm] - simp_rw [mul_assoc] - -end Associativity - -/-! ## Attention as a mixer -/ - -section Attention - -variable {Query Key : Type*} [Fintype Query] [Fintype Key] - -/-- Attention weights derived from query-key scores. -Given attention scores `α : Query → Key → NNReal` that are row-normalized -(each query's weights over keys sum to 1), this produces a mixer. -/ -noncomputable def Mixer.attention - (α : Query → Key → NNReal) - (hα : ∀ q, (∑ k, α q k) = 1) : Mixer Query Key where - w := α - row_sum_one := hα - -/-- Self-attention: queries and keys are the same set of positions. -/ -noncomputable def Mixer.selfAttention {Pos : Type*} [Fintype Pos] - (α : Pos → Pos → NNReal) - (hα : ∀ p, (∑ p', α p p') = 1) : Mixer Pos Pos := - Mixer.attention α hα - -/-- Attention is supported on positions with nonzero attention weight. 
-/ -lemma Mixer.attention_supported {Query Key : Type*} [Fintype Query] [Fintype Key] - (α : Query → Key → NNReal) - (hα : ∀ q, (∑ k, α q k) = 1) : - Mixer.supported (Mixer.attention α hα) (fun q k => α q k ≠ 0) := by - intro q k hne - by_cases hzero : α q k = 0 - · simp [Mixer.attention, hzero] - · exact (hne hzero).elim - -end Attention - -/-! ## Skip/Residual connections -/ - -section Residual - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- Residual mixer: mixes identity with another mixer using coefficient `c ∈ [0,1]`. -This models skip connections: `output = c * identity + (1-c) * transform`. -/ -noncomputable def Mixer.residual (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) : Mixer S S where - w := fun i j => c * (if i = j then 1 else 0) + (1 - c) * M.w i j - row_sum_one := by - classical - intro i - simp only [Finset.sum_add_distrib, ← Finset.mul_sum] - simp only [M.row_sum_one, Finset.sum_ite_eq, Finset.mem_univ, ↓reduceIte, mul_one] - rw [add_comm] - exact tsub_add_cancel_of_le hc - -/-- A pure skip connection is the identity mixer. -/ -lemma Mixer.residual_one (M : Mixer S S) : - Mixer.residual M 1 le_rfl = Mixer.identity := by - ext i j - simp [Mixer.residual, Mixer.identity] - -/-- No skip connection passes through the transform entirely. -/ -lemma Mixer.residual_zero (M : Mixer S S) : - Mixer.residual M 0 (zero_le _) = M := by - ext i j - simp [Mixer.residual] - -/-- Residual weight decomposition: the weight splits into identity and transform parts. -/ -lemma Mixer.residual_w (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i j : S) : - (Mixer.residual M c hc).w i j = - c * Mixer.identity.w i j + (1 - c) * M.w i j := by - simp only [Mixer.residual, Mixer.identity] - -end Residual - -/-! ## Attention flow composition theorems -/ - -section AttentionFlow - -variable {Pos : Type*} [Fintype Pos] [DecidableEq Pos] - -/-- Attention flow: composition of attention matrices across layers. 
-This captures how attention "flows" across layers in list order (row-vector convention). - -This is the formal version of "attention rollout" from Abnar & Zuidema (2020). -/ -noncomputable def attentionFlow (layers : List (Mixer Pos Pos)) : Mixer Pos Pos := - layers.foldl Mixer.comp Mixer.identity - -/-- Single layer attention flow is identity composed with the layer. -/ -lemma attentionFlow_singleton (M : Mixer Pos Pos) : - attentionFlow [M] = Mixer.identity.comp M := rfl - -/-- Empty attention flow is identity. -/ -@[simp] lemma attentionFlow_nil : - attentionFlow (Pos := Pos) [] = Mixer.identity := rfl - -/-- Attention flow of a single layer simplifies to just the layer. -/ -@[simp] lemma attentionFlow_singleton' (M : Mixer Pos Pos) : - attentionFlow [M] = M := by - simp [attentionFlow_singleton] - -/-- Two-layer attention flow is just composition. -/ -lemma attentionFlow_two (M₁ M₂ : Mixer Pos Pos) : - attentionFlow [M₁, M₂] = M₁.comp M₂ := by - simp [attentionFlow] - -/-- Three-layer attention flow is associative composition. -/ -lemma attentionFlow_three (M₁ M₂ M₃ : Mixer Pos Pos) : - attentionFlow [M₁, M₂, M₃] = (M₁.comp M₂).comp M₃ := by - simp [attentionFlow] - -end AttentionFlow - -/-! ## Layer composition helpers -/ - -section LayerComp - -variable {S T U V : Type*} - [Fintype S] [Fintype T] [Fintype U] [Fintype V] - -/-- Compose three layers (common pattern: embed → transform → project). -/ -noncomputable def Mixer.comp3 - (M₁ : Mixer S T) (M₂ : Mixer T U) (M₃ : Mixer U V) : Mixer S V := - (M₁.comp M₂).comp M₃ - -/-- comp3 is equivalent to right-associated composition. -/ -lemma Mixer.comp3_eq_comp_comp (M₁ : Mixer S T) (M₂ : Mixer T U) (M₃ : Mixer U V) : - M₁.comp3 M₂ M₃ = M₁.comp (M₂.comp M₃) := by - simp [Mixer.comp3, Mixer.comp_assoc] - -end LayerComp - -/-! 
## Transformer blocks -/ - -section TransformerBlock - -variable {Pos : Type*} [Fintype Pos] [DecidableEq Pos] - -/-- A full transformer block conceptually: attention + feedforward with residuals. - -In a Pre-LN transformer (e.g. GPT-2): `y = x + Attention(LayerNorm(x))` followed by -`output = y + FFN(LayerNorm(y))`. We model this as composition of residual mixers. - -The coefficients `c_attn` and `c_ff` control how much of the skip connection -vs the transformed value flows through. -/ -noncomputable def Mixer.transformerBlock - (attn : Mixer Pos Pos) - (ff : Mixer Pos Pos) - (c_attn c_ff : NNReal) - (h_attn : c_attn ≤ 1) (h_ff : c_ff ≤ 1) : Mixer Pos Pos := - (Mixer.residual attn c_attn h_attn).comp (Mixer.residual ff c_ff h_ff) - -/-- A transformer block with no skip connections is just attention then FFN. -/ -lemma Mixer.transformerBlock_no_skip (attn ff : Mixer Pos Pos) : - Mixer.transformerBlock attn ff 0 0 (zero_le _) (zero_le _) = attn.comp ff := by - simp [Mixer.transformerBlock, Mixer.residual_zero] - -/-- A transformer block with full skip connections is identity. -/ -lemma Mixer.transformerBlock_full_skip (attn ff : Mixer Pos Pos) : - Mixer.transformerBlock attn ff 1 1 le_rfl le_rfl = Mixer.identity := by - simp [Mixer.transformerBlock, Mixer.residual_one] - -end TransformerBlock - -/-! ## Stacking transformer layers -/ - -section TransformerStack - -variable {Pos : Type*} [Fintype Pos] [DecidableEq Pos] - -/-- A stack of transformer blocks. -/ -noncomputable def transformerStack (blocks : List (Mixer Pos Pos)) : Mixer Pos Pos := - attentionFlow blocks - -/-- The effective attention from position `i` to position `j` through a stack -of `n` transformer layers is given by the composed mixer weight. -/ -noncomputable def effectiveAttention - (blocks : List (Mixer Pos Pos)) (i j : Pos) : NNReal := - (transformerStack blocks).w i j - -/-- Effective attention forms a probability distribution over target positions -for each source position. 
This is a key property for interpretation: -it tells us "how much" each source position contributes to each target. -/ -theorem effectiveAttention_normalized (blocks : List (Mixer Pos Pos)) (i : Pos) : - (∑ j, effectiveAttention blocks i j) = 1 := by - simp only [effectiveAttention, transformerStack] - exact (attentionFlow blocks).row_sum_one i - -end TransformerStack - -/-! ## Path-based decomposition - -This section provides the key insight connecting mixer composition to -path-based attribution. The weight `(M.comp N).w i k` can be decomposed -as a sum over intermediate positions, corresponding to paths through -the computation graph. --/ - -section PathDecomposition - -variable {S T U V : Type*} [Fintype S] [Fintype T] [Fintype U] [Fintype V] - -/-- The composition weight decomposes as a sum over paths through the intermediate layer. -This is the foundation for path-integrated attribution methods. -/ -theorem Mixer.comp_path_decomposition (M : Mixer S T) (N : Mixer T U) (i : S) (k : U) : - (M.comp N).w i k = ∑ j, M.w i j * N.w j k := rfl - -/-- The contribution of path `i → j → k` to the total weight `i → k`. -/ -noncomputable def pathContrib (M : Mixer S T) (N : Mixer T U) (i : S) (j : T) (k : U) : NNReal := - M.w i j * N.w j k - -/-- Path contributions sum to the total weight. -/ -theorem pathContrib_sum (M : Mixer S T) (N : Mixer T U) (i : S) (k : U) : - (∑ j, pathContrib M N i j k) = (M.comp N).w i k := by - simp only [pathContrib, Mixer.comp] - -/-- For three-layer composition, paths go through two intermediate positions. -/ -theorem Mixer.comp3_path_decomposition - (M₁ : Mixer S T) (M₂ : Mixer T U) (M₃ : Mixer U V) (i : S) (l : V) : - (M₁.comp3 M₂ M₃).w i l = ∑ j, ∑ k, M₁.w i j * M₂.w j k * M₃.w k l := by - simp only [Mixer.comp3, Mixer.comp, Finset.sum_mul] - rw [Finset.sum_comm] - -end PathDecomposition - -/-! ## Conservation theorems - -Key insight: mixer operations preserve total probability mass. 
-This connects to the completeness axiom in attribution theory. --/ - -section Conservation - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- Pushing a probability vector through a mixer preserves total mass. -This is the probabilistic interpretation of "conservation". -/ -theorem Mixer.push_preserves_total_mass (M : Mixer S T) (p : ProbVec S) : - (∑ j, (M.push p).mass j) = ∑ i, p.mass i := by - simp only [ProbVec.sum_mass] - -/-- The pushed mass at position `j` is the weighted sum of source masses. -/ -lemma Mixer.push_mass_eq (M : Mixer S T) (p : ProbVec S) (j : T) : - (M.push p).mass j = ∑ i, p.mass i * M.w i j := rfl - -/-- Conservation for composition: pushing through composed mixers -equals pushing through each sequentially. -/ -theorem Mixer.push_comp (M : Mixer S T) (N : Mixer T U) (p : ProbVec S) : - (M.comp N).push p = N.push (M.push p) := by - ext k - simp only [Mixer.push, Mixer.comp] - simp only [Finset.mul_sum, Finset.sum_mul] - rw [Finset.sum_comm] - simp_rw [mul_assoc] - -end Conservation - -/-! ## Residual stream decomposition - -A key insight for transformer interpretation: residual connections create -multiple "paths" through the network. The effective contribution from source -to target can be decomposed into: -1. Direct path: information flows unchanged through the skip connection -2. Indirect path: information is transformed by the attention/FFN layer - -This decomposition is crucial for understanding how much a layer actually -contributes vs how much just passes through unchanged. --/ - -section ResidualDecomposition - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- **Residual decomposition theorem**: The residual mixer weight at (i,j) equals -the sum of direct (skip) and indirect (transform) path contributions. - -This is the formal version of: "How much information flows directly vs through the layer?" 
-/ -theorem Mixer.residual_decomposition (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i j : S) : - (Mixer.residual M c hc).w i j = - c * (if i = j then 1 else 0) + (1 - c) * M.w i j := by - simp only [Mixer.residual] - -/-- **Skip connection dominance**: When the skip coefficient c is large, -the residual is close to identity. Specifically, the diagonal entries -are at least c. -/ -theorem Mixer.residual_skip_dominance (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i : S) : - (Mixer.residual M c hc).w i i ≥ c := by - simp only [Mixer.residual, ↓reduceIte, mul_one, le_add_iff_nonneg_right, zero_le] - -/-- Off-diagonal entries are bounded by the indirect path contribution. -/ -theorem Mixer.residual_off_diag_bound (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) - {i j : S} (hij : i ≠ j) : - (Mixer.residual M c hc).w i j = (1 - c) * M.w i j := by - simp only [Mixer.residual, hij, ↓reduceIte, mul_zero, zero_add] - -/-- **Interpretation insight**: Off-diagonal entries are scaled down by (1-c). -If c = 0.9, off-diagonal influence is reduced to 10% of original. -This quantifies how residual connections "protect" self-information. -/ -theorem Mixer.residual_off_diag_scaling (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) - {i j : S} (hij : i ≠ j) : - (Mixer.residual M c hc).w i j ≤ M.w i j := by - rw [Mixer.residual_off_diag_bound M c hc hij] - have h1 : (1 - c) ≤ 1 := tsub_le_self - calc (1 - c) * M.w i j ≤ 1 * M.w i j := mul_le_mul_of_nonneg_right h1 (zero_le _) - _ = M.w i j := one_mul _ - -end ResidualDecomposition - -/-! ## Attention concentration and information flow - -When attention is concentrated on few positions, this limits how information -can spread through the network. These theorems formalize bounds on information flow. --/ - -section AttentionConcentration - -variable {S : Type*} [Fintype S] - -/-- The maximum attention weight from position i determines an upper bound -on how much relevance can flow to any single source position. 
-/ -noncomputable def Mixer.maxWeight (M : Mixer S S) (i : S) : NNReal := - Finset.sup' Finset.univ ⟨i, Finset.mem_univ i⟩ (fun j => M.w i j) - -/-- Any weight is at most the maxWeight. -/ -lemma Mixer.weight_le_maxWeight (M : Mixer S S) (i j : S) : - M.w i j ≤ M.maxWeight i := by - simp only [Mixer.maxWeight] - exact Finset.le_sup' (fun k => M.w i k) (Finset.mem_univ j) - -/-- **Attention bottleneck**: The pushed mass at any position j is bounded by -the sum over i of (mass at i) × (max attention weight from i). -In other words, if attention is spread thin, mass can't concentrate. -/ -theorem Mixer.push_concentration_bound (M : Mixer S S) (p : ProbVec S) (j : S) : - (M.push p).mass j ≤ ∑ i, p.mass i * M.maxWeight i := by - simp only [Mixer.push] - apply Finset.sum_le_sum - intro i _ - exact mul_le_mul_of_nonneg_left (M.weight_le_maxWeight i j) (zero_le _) - -end AttentionConcentration - -/-! ## Ablation and masking analysis - -When we "ablate" or "mask" certain positions, we effectively zero out their -contribution. This section formalizes the effect of such interventions. --/ - -section Ablation - -variable {S : Type*} [Fintype S] - -/-- A masked weight function: positions in the mask set have their outgoing weights zeroed. -This models "what if we remove these positions from consideration?" - -Note: This is sub-stochastic (rows of blocked positions don't sum to 1). -/ -noncomputable def Mixer.maskFn (M : Mixer S S) (blocked : Set S) [DecidablePred blocked] : - S → S → NNReal := - fun i j => if blocked i then 0 else M.w i j - -/-- Masking a position removes its contribution entirely. -/ -lemma Mixer.maskFn_blocked (M : Mixer S S) (blocked : Set S) [DecidablePred blocked] - {i : S} (hi : blocked i) (j : S) : - M.maskFn blocked i j = 0 := by - simp [Mixer.maskFn, hi] - -/-- Unblocked positions keep their original weights. 
-/ -lemma Mixer.maskFn_unblocked (M : Mixer S S) (blocked : Set S) [DecidablePred blocked] - {i : S} (hi : ¬blocked i) (j : S) : - M.maskFn blocked i j = M.w i j := by - simp [Mixer.maskFn, hi] - -/-- The contribution from blocked positions to position j. -/ -noncomputable def blockedContribution (M : Mixer S S) (p : ProbVec S) - (blocked : Set S) [DecidablePred blocked] (j : S) : NNReal := - ∑ i : S, if blocked i then p.mass i * M.w i j else 0 - -/-- The contribution from unblocked positions to position j. -/ -noncomputable def unblockedContribution (M : Mixer S S) (p : ProbVec S) - (blocked : Set S) [DecidablePred blocked] (j : S) : NNReal := - ∑ i : S, if blocked i then 0 else p.mass i * M.w i j - -/-- **Ablation decomposition**: The pushed mass equals blocked plus unblocked contributions. -/ -theorem Mixer.ablation_decomposition (M : Mixer S S) (p : ProbVec S) - (blocked : Set S) [DecidablePred blocked] (j : S) : - (M.push p).mass j = - unblockedContribution M p blocked j + blockedContribution M p blocked j := by - simp only [Mixer.push, unblockedContribution, blockedContribution] - rw [← Finset.sum_add_distrib] - apply Finset.sum_congr rfl - intro i _ - split_ifs <;> simp - -end Ablation - -/-! ## Composition depth and information spread - -As we compose more layers, information can spread to more positions. -These theorems characterize how "reach" grows with depth. --/ - -section CompositionDepth - -variable {S : Type*} [Fintype S] - -/-- A position j is "reachable" from i through mixer M if M.w i j > 0. -/ -def Mixer.reachable (M : Mixer S S) (i j : S) : Prop := M.w i j ≠ 0 - -/-- **Reach expansion**: If j is reachable from i through M, and k is reachable -from j through N, then the composition has at least the product weight. -This shows how influence compounds through layers. 
-/ -theorem Mixer.reach_comp (M N : Mixer S S) {i j k : S} - (_ : M.reachable i j) (_ : N.reachable j k) : - (M.comp N).w i k ≥ M.w i j * N.w j k := by - simp only [Mixer.comp, ge_iff_le] - exact Finset.single_le_sum (f := fun x => M.w i x * N.w x k) - (fun x _ => zero_le _) (Finset.mem_univ j) - -/-- **Path contribution bound**: The contribution through any single intermediate j -is at most the composed weight. -/ -theorem Mixer.path_contrib_le_comp (M N : Mixer S S) (i j k : S) : - M.w i j * N.w j k ≤ (M.comp N).w i k := by - simp only [Mixer.comp] - exact Finset.single_le_sum (f := fun x => M.w i x * N.w x k) - (fun x _ => zero_le _) (Finset.mem_univ j) - -/-- Composing mixers preserves reachability through any nonzero path. -/ -theorem Mixer.comp_reachable_of_path (M N : Mixer S S) {i j k : S} - (hij : M.w i j ≠ 0) (hjk : N.w j k ≠ 0) : - (M.comp N).reachable i k := by - simp only [Mixer.reachable, Mixer.comp, ne_eq] - intro h - have hterm : M.w i j * N.w j k = 0 := by - have hle : M.w i j * N.w j k ≤ ∑ x, M.w i x * N.w x k := - Finset.single_le_sum (f := fun x => M.w i x * N.w x k) - (fun x _ => zero_le _) (Finset.mem_univ j) - rw [h] at hle - exact le_antisymm hle (zero_le _) - cases mul_eq_zero.mp hterm with - | inl h => exact hij h - | inr h => exact hjk h - -end CompositionDepth - -/-! ## Information-theoretic bounds - -These theorems connect mixer properties to information-theoretic concepts, -providing bounds on how much "information" can flow through attention layers. --/ - -section InformationBounds - -variable {S : Type*} [Fintype S] - -/-- The "effective support size" from position i: how many positions receive -nonzero attention. Smaller support = more concentrated attention. -/ -noncomputable def Mixer.supportSize (M : Mixer S S) (i : S) : ℕ := - (Finset.univ.filter (fun j => M.w i j ≠ 0)).card - -/-- Row-stochasticity means at least one entry is nonzero (assuming S nonempty). 
-/ -lemma Mixer.exists_nonzero [Nonempty S] (M : Mixer S S) (i : S) : ∃ j, M.w i j ≠ 0 := by - by_contra h - push_neg at h - have hsum : ∑ j, M.w i j = 0 := Finset.sum_eq_zero (fun j _ => h j) - rw [M.row_sum_one i] at hsum - exact one_ne_zero hsum - -/-- **Support size bound**: In a nonempty type, every row has positive support. -/ -theorem Mixer.supportSize_pos [Nonempty S] (M : Mixer S S) (i : S) : - M.supportSize i ≥ 1 := by - simp only [Mixer.supportSize] - obtain ⟨j, hj⟩ := M.exists_nonzero i - exact Finset.one_le_card.mpr ⟨j, by simp [hj]⟩ - -end InformationBounds - -/-! ## Gradient-attribution correspondence - -A key insight: for linear layers, mixer-based attribution corresponds exactly -to gradient-based attribution. This section establishes this correspondence. --/ - -section GradientCorrespondence - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- **Gradient-attribution alignment**: For composed linear layers, the composed -attribution equals the product of individual attributions, summed over paths. -This is analogous to the chain rule for gradients. -/ -theorem Mixer.chain_rule_analog (M : Mixer S T) (N : Mixer T U) (i : S) (k : U) : - (M.comp N).w i k = ∑ j, M.w i j * N.w j k := rfl - -/-- **Three-layer chain rule**: The attribution through three layers -decomposes into a double sum over intermediate positions. -/ -theorem Mixer.chain_rule_three (M₁ : Mixer S T) (M₂ : Mixer T U) (M₃ : Mixer U S) (i l : S) : - (M₁.comp (M₂.comp M₃)).w i l = ∑ j, ∑ k, M₁.w i j * M₂.w j k * M₃.w k l := by - simp only [Mixer.comp, Finset.mul_sum, mul_assoc] - -end GradientCorrespondence - -/-! ## Multi-head attention - -Real transformers use multiple attention heads that are combined. This section -formalizes multi-head attention and proves key properties about how heads combine. - -In practice, each head has its own attention pattern, and the outputs are -concatenated and projected. 
For relevance/attribution purposes, this is equivalent -to a weighted combination of the individual head attention patterns. --/ - -section MultiHead - -variable {Pos : Type*} [Fintype Pos] -variable {numHeads : ℕ} - -/-- Multi-head attention combines multiple attention heads with weights. -The weights typically come from the output projection and represent how much -each head contributes to the final output. - -For interpretation: if we want to know "how much does position j contribute -to position i", we sum over heads weighted by their importance. -/ -noncomputable def Mixer.multiHead - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (hsum : ∑ h, headWeights h = 1) : Mixer Pos Pos := - { w := fun i j => ∑ h, headWeights h * (heads h).w i j, - row_sum_one := by - intro i - rw [Finset.sum_comm] - simp_rw [← Finset.mul_sum, Mixer.row_sum_one, mul_one, hsum] } - -/-- Each head's contribution to the multi-head attention is bounded by its weight. -/ -theorem Mixer.multiHead_head_contrib_bound - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (hsum : ∑ h, headWeights h = 1) - (h : Fin numHeads) (i j : Pos) : - headWeights h * (heads h).w i j ≤ (Mixer.multiHead heads headWeights hsum).w i j := by - simp only [Mixer.multiHead] - exact Finset.single_le_sum (f := fun k => headWeights k * (heads k).w i j) - (fun _ _ => zero_le _) (Finset.mem_univ h) - -/-- **Head importance theorem**: A head with zero weight contributes nothing. -This formalizes the intuition that "unimportant" heads can be pruned. -/ -theorem Mixer.multiHead_zero_weight - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (h : Fin numHeads) (hw : headWeights h = 0) (i j : Pos) : - headWeights h * (heads h).w i j = 0 := by - simp [hw] - -/-- **Single head dominance**: If one head has weight 1 (others have weight 0), -multi-head attention reduces to that single head's attention. 
-/ -theorem Mixer.multiHead_single_head - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (hsum : ∑ h, headWeights h = 1) - (h₀ : Fin numHeads) (hdom : headWeights h₀ = 1) - (hzero : ∀ h, h ≠ h₀ → headWeights h = 0) (i j : Pos) : - (Mixer.multiHead heads headWeights hsum).w i j = (heads h₀).w i j := by - simp only [Mixer.multiHead] - have hsplit : ∑ h, headWeights h * (heads h).w i j = - headWeights h₀ * (heads h₀).w i j + - ∑ h ∈ Finset.univ.erase h₀, headWeights h * (heads h).w i j := by - rw [← Finset.add_sum_erase _ _ (Finset.mem_univ h₀)] - rw [hsplit, hdom, one_mul, add_eq_left] - apply Finset.sum_eq_zero - intro h hh - simp only [Finset.mem_erase, Finset.mem_univ, ne_eq] at hh - simp [hzero h hh.1] - -/-- The multi-head attention weight is a convex combination of individual head weights. -/ -theorem Mixer.multiHead_convex - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (hsum : ∑ h, headWeights h = 1) - (i j : Pos) : - (Mixer.multiHead heads headWeights hsum).w i j ≤ 1 := by - simp only [Mixer.multiHead] - calc ∑ h, headWeights h * (heads h).w i j - ≤ ∑ h, headWeights h * 1 := by - apply Finset.sum_le_sum - intro h _ - apply mul_le_mul_of_nonneg_left _ (zero_le _) - have hrow := (heads h).row_sum_one i - have hle : (heads h).w i j ≤ ∑ k, (heads h).w i k := - Finset.single_le_sum (f := fun k => (heads h).w i k) - (fun _ _ => zero_le _) (Finset.mem_univ j) - rw [hrow] at hle - exact hle - _ = ∑ h, headWeights h := by simp - _ = 1 := hsum - -end MultiHead - -/-! ## Causal masking - -Autoregressive models (GPT-style) use causal masking: position i can only attend -to positions j ≤ i. This creates a triangular attention pattern with important -consequences for information flow and attribution. --/ - -section CausalMask - -variable {n : ℕ} - -/-- A causal attention mask: position i can attend to j only if j ≤ i. -This models autoregressive/decoder-only transformers like GPT. 
-/ -def isCausal (M : Mixer (Fin n) (Fin n)) : Prop := - ∀ i j : Fin n, j.val > i.val → M.w i j = 0 - -/-- **Causal reachability**: In a causal mixer, information can only flow -from later to earlier positions (or stay in place). -/ -theorem causal_reachable_dir (M : Mixer (Fin n) (Fin n)) (hcaus : isCausal M) - {i j : Fin n} (hreach : M.reachable i j) : j.val ≤ i.val := by - by_contra h - push_neg at h - have := hcaus i j h - exact hreach this - -/-- Composition of causal mixers is causal. This means stacking causal attention -layers preserves the causal property. -/ -theorem causal_comp (M N : Mixer (Fin n) (Fin n)) - (hM : isCausal M) (hN : isCausal N) : isCausal (M.comp N) := by - intro i j hij - simp only [Mixer.comp] - apply Finset.sum_eq_zero - intro k _ - by_cases hk : k.val > i.val - · simp [hM i k hk] - · push_neg at hk - by_cases hkj : j.val > k.val - · simp [hN k j hkj] - · push_neg at hkj - -- k ≤ i and j ≤ k, so j ≤ i, contradicting hij - omega - -/-- **Causal information bound**: In a causal model, the total attention from -position i to future positions is zero. All attention goes to past/current. -/ -theorem causal_future_attention_zero (M : Mixer (Fin n) (Fin n)) (hcaus : isCausal M) - (i : Fin n) : ∑ j ∈ Finset.univ.filter (fun j => j.val > i.val), M.w i j = 0 := by - apply Finset.sum_eq_zero - intro j hj - simp only [Finset.mem_filter, Finset.mem_univ, true_and] at hj - exact hcaus i j hj - -/-- In a causal mixer, all attention mass goes to positions ≤ i. 
-/ -theorem causal_past_attention_one (M : Mixer (Fin n) (Fin n)) (hcaus : isCausal M) - (i : Fin n) : ∑ j ∈ Finset.univ.filter (fun j => j.val ≤ i.val), M.w i j = 1 := by - have htotal := M.row_sum_one i - have hfuture := causal_future_attention_zero M hcaus i - -- Show the two filters partition univ - have hpart : Finset.univ.filter (fun j : Fin n => j.val ≤ i.val) ∪ - Finset.univ.filter (fun j => i.val < j.val) = Finset.univ := by - ext x - simp only [Finset.mem_union, Finset.mem_filter, Finset.mem_univ, true_and, iff_true] - exact le_or_gt x.val i.val - have hdisj : Disjoint (Finset.univ.filter (fun j : Fin n => j.val ≤ i.val)) - (Finset.univ.filter (fun j => i.val < j.val)) := by - rw [Finset.disjoint_filter] - intro x _ hle hlt - omega - have key : ∑ j, M.w i j = ∑ j ∈ Finset.univ.filter (fun j => j.val ≤ i.val), M.w i j + - ∑ j ∈ Finset.univ.filter (fun j => i.val < j.val), M.w i j := by - conv_lhs => rw [← hpart] - rw [Finset.sum_union hdisj] - simp only [hfuture, add_zero, htotal] at key - exact key.symm - -/-- **First token theorem**: In a causal model, the first position (index 0) -can only attend to itself, so its self-attention weight is 1. -/ -theorem causal_first_token_self (M : Mixer (Fin (n + 1)) (Fin (n + 1))) - (hcaus : isCausal M) : M.w 0 0 = 1 := by - have h := causal_past_attention_one M hcaus 0 - simp only [Fin.val_zero] at h - have hfilt : Finset.univ.filter (fun j : Fin (n + 1) => j.val ≤ 0) = {0} := by - ext x - simp only [Finset.mem_filter, Finset.mem_univ, true_and, Finset.mem_singleton] - constructor - · intro hx - exact Fin.ext (Nat.le_zero.mp hx) - · intro hx - simp [hx] - rw [hfilt] at h - simpa using h - -end CausalMask - -/-! ## Attention head analysis - -Tools for analyzing individual attention heads and their roles. --/ - -section HeadAnalysis - -variable {Pos : Type*} [Fintype Pos] - -/-- The "concentration" of attention from position i: sum of squared weights. -Higher value = more concentrated on few positions. 
Lower = more spread out. -This is a measure related to the inverse of entropy. -/ -noncomputable def Mixer.attentionConcentration (M : Mixer Pos Pos) (i : Pos) : NNReal := - ∑ j, (M.w i j) ^ 2 - -/-- Concentration is at most 1 (achieved when all attention on one position). -/ -theorem Mixer.attentionConcentration_upper_bound (M : Mixer Pos Pos) (i : Pos) : - M.attentionConcentration i ≤ 1 := by - simp only [Mixer.attentionConcentration] - have hsum := M.row_sum_one i - calc ∑ j, (M.w i j) ^ 2 - ≤ ∑ j, M.w i j := by - apply Finset.sum_le_sum - intro j _ - rw [sq] - have hle : M.w i j ≤ ∑ k, M.w i k := - Finset.single_le_sum (f := fun k => M.w i k) (fun _ _ => zero_le _) (Finset.mem_univ j) - rw [hsum] at hle - calc M.w i j * M.w i j ≤ M.w i j * 1 := mul_le_mul_of_nonneg_left hle (zero_le _) - _ = M.w i j := mul_one _ - _ = 1 := hsum - -/-- **Sparsity indicator**: An attention head is "sparse" at position i if its -concentration is high (close to 1). This indicates it focuses on few positions. -/ -def Mixer.isSparseAt (M : Mixer Pos Pos) (i : Pos) (threshold : NNReal) : Prop := - M.attentionConcentration i ≥ threshold - -/-- **Uniform attention indicator**: An attention head is "diffuse" at position i -if its concentration is low. This indicates it spreads attention broadly. -/ -def Mixer.isDiffuseAt (M : Mixer Pos Pos) (i : Pos) (threshold : NNReal) : Prop := - M.attentionConcentration i ≤ threshold - -/-- If all attention is on one position, concentration is 1. 
-/ -theorem Mixer.attentionConcentration_one_hot - (M : Mixer Pos Pos) (i j₀ : Pos) - (h : M.w i j₀ = 1) (hz : ∀ j, j ≠ j₀ → M.w i j = 0) : - M.attentionConcentration i = 1 := by - classical - simp only [Mixer.attentionConcentration] - have hsplit : ∑ j, (M.w i j) ^ 2 = - (M.w i j₀) ^ 2 + ∑ j ∈ Finset.univ.erase j₀, (M.w i j) ^ 2 := by - rw [← Finset.add_sum_erase _ _ (Finset.mem_univ j₀)] - rw [hsplit, h, one_pow, add_eq_left] - apply Finset.sum_eq_zero - intro j hj - simp only [Finset.mem_erase, Finset.mem_univ, ne_eq] at hj - simp [hz j hj.1] - -end HeadAnalysis - -/-! ## Residual dominance analysis - -For interpreting transformers, we often want to know: "How dominant is the -skip connection vs the attention?" This section provides tools for this. --/ - -section ResidualDominance - -variable {S : Type*} [Fintype S] - -/-- A residual layer with coefficient c gives diagonal elements at least c. -This is a key property for interpretation: high c means the skip connection -dominates, preserving information from earlier layers. -/ -theorem residual_diagonal_lower [DecidableEq S] - (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i : S) : - (Mixer.residual M c hc).w i i ≥ c := by - simp only [Mixer.residual] - have hnonneg : 0 ≤ (1 - c) * M.w i i := by - have h1 : 0 ≤ (1 - c) := by - exact zero_le _ - exact mul_nonneg h1 (by simp) - calc c * 1 + (1 - c) * M.w i i ≥ c * 1 := by - exact le_add_of_nonneg_right hnonneg - _ = c * 1 + 0 := by simp - _ = c := by ring - -/-- Off-diagonal elements of a residual are scaled down by (1-c). -This quantifies how much the attention contribution is suppressed. -/ -theorem residual_offdiag_scale [DecidableEq S] - (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) - (i j : S) (hij : i ≠ j) : - (Mixer.residual M c hc).w i j = (1 - c) * M.w i j := by - simp only [Mixer.residual] - simp [hij] - -/-- The sum of off-diagonal elements in a residual row is at most (1-c). -This bounds how much attention "leaks" to other positions. 
-/ -theorem residual_offdiag_sum_bound [DecidableEq S] - (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i : S) : - ∑ j ∈ Finset.univ.filter (· ≠ i), (Mixer.residual M c hc).w i j ≤ 1 - c := by - calc ∑ j ∈ Finset.univ.filter (· ≠ i), (Mixer.residual M c hc).w i j - = ∑ j ∈ Finset.univ.filter (· ≠ i), (1 - c) * M.w i j := by - apply Finset.sum_congr rfl - intro j hj - simp only [Finset.mem_filter, Finset.mem_univ, true_and] at hj - exact residual_offdiag_scale M c hc i j (Ne.symm hj) - _ = (1 - c) * ∑ j ∈ Finset.univ.filter (· ≠ i), M.w i j := by - rw [Finset.mul_sum] - _ ≤ (1 - c) * 1 := by - apply mul_le_mul_of_nonneg_left _ (by simp) - calc ∑ j ∈ Finset.univ.filter (· ≠ i), M.w i j - ≤ ∑ j, M.w i j := Finset.sum_le_sum_of_subset (fun x hx => Finset.mem_univ x) - _ = 1 := M.row_sum_one i - _ = 1 - c := by ring - -end ResidualDominance - -/-! ## Deep composition bounds - -A key question for transformer interpretation: after L layers, how spread out -is the attribution? This section provides quantitative bounds. --/ - -section DeepComposition - -variable {Pos : Type*} [Fintype Pos] - -/-- Composition weight is at most 1 (row-stochastic property preserved). -/ -theorem comp_weight_le_one (M N : Mixer Pos Pos) (i k : Pos) : - (M.comp N).w i k ≤ 1 := by - have h := (M.comp N).row_sum_one i - calc (M.comp N).w i k - ≤ ∑ k', (M.comp N).w i k' := - Finset.single_le_sum (fun _ _ => zero_le _) (Finset.mem_univ k) - _ = 1 := h - -/-- Each term in a composition sum is bounded by the corresponding M weight. -/ -theorem comp_term_bound (M N : Mixer Pos Pos) (i j k : Pos) : - M.w i j * N.w j k ≤ M.w i j := by - calc M.w i j * N.w j k ≤ M.w i j * 1 := by - apply mul_le_mul_of_nonneg_left _ (zero_le _) - calc N.w j k ≤ ∑ k', N.w j k' := - Finset.single_le_sum (fun _ _ => zero_le _) (Finset.mem_univ k) - _ = 1 := N.row_sum_one j - _ = M.w i j := by ring - -end DeepComposition - -/-! 
## Cross-attention for encoder-decoder models - -Real seq2seq models use cross-attention where queries come from the decoder -and keys/values come from the encoder. This is a mixer from decoder positions -to encoder positions (tracking where decoder attends in encoder). --/ - -section CrossAttention - -variable {EncPos DecPos : Type*} [Fintype EncPos] [Fintype DecPos] - -/-- Cross-attention mixer: tracks where each decoder position attends in encoder. -This is the fundamental building block for encoder-decoder attribution. -/ -noncomputable def Mixer.crossAttention - (w : DecPos → EncPos → NNReal) - (hw : ∀ d, ∑ e, w d e = 1) : Mixer DecPos EncPos where - w := w - row_sum_one := hw - -/-- Cross-attention preserves the row-stochastic property. -/ -theorem Mixer.crossAttention_normalized - (w : DecPos → EncPos → NNReal) - (hw : ∀ d, ∑ e, w d e = 1) (d : DecPos) : - ∑ e, (Mixer.crossAttention w hw).w d e = 1 := - hw d - -end CrossAttention - -/-! ## Layer-wise attribution analysis - -For understanding transformer behavior, it's crucial to decompose attribution -layer by layer. This section provides tools for such analysis. --/ - -section LayerAttribution - -variable {Pos : Type*} [Fintype Pos] [DecidableEq Pos] - -/-- The attribution from position i to position j through a sequence of layers. -/ -noncomputable def layerWiseAttribution - (layers : List (Mixer Pos Pos)) (i j : Pos) : NNReal := - (layers.foldl Mixer.comp Mixer.identity).w i j - -/-- Empty layer list gives identity attribution. -/ -@[simp] -theorem layerWiseAttribution_nil (i j : Pos) : - layerWiseAttribution (Pos := Pos) [] i j = Mixer.identity.w i j := rfl - -/-- Single layer attribution equals the layer's weight. -/ -theorem layerWiseAttribution_singleton (M : Mixer Pos Pos) (i j : Pos) : - layerWiseAttribution [M] i j = M.w i j := by - simp only [layerWiseAttribution, List.foldl_cons, List.foldl_nil, Mixer.identity_comp] - -/-- Attribution through layers is bounded by 1 (probability). 
-/ -theorem layerWiseAttribution_le_one (layers : List (Mixer Pos Pos)) (i j : Pos) : - layerWiseAttribution layers i j ≤ 1 := by - simp only [layerWiseAttribution] - have h := (layers.foldl Mixer.comp Mixer.identity).row_sum_one i - calc (layers.foldl Mixer.comp Mixer.identity).w i j - ≤ ∑ k, (layers.foldl Mixer.comp Mixer.identity).w i k := - Finset.single_le_sum (fun _ _ => zero_le _) (Finset.mem_univ j) - _ = 1 := h - -/-- **Total attribution conservation**: Sum over all targets equals 1. -This is the formal statement that "attribution mass is conserved". -/ -theorem layerWiseAttribution_sum_one (layers : List (Mixer Pos Pos)) (i : Pos) : - ∑ j, layerWiseAttribution layers i j = 1 := by - simp only [layerWiseAttribution] - exact (layers.foldl Mixer.comp Mixer.identity).row_sum_one i - -end LayerAttribution - -/-! ## Tracer uniqueness for transformer interpretation - -The key insight connecting this formalization to neural network interpretation: -the tracer uniqueness theorem (from `Uniqueness.lean`) implies that attribution -methods based on the mixer framework are uniquely determined by boundary conditions. - -This provides formal justification for attention-based interpretation methods: -if two attribution methods both propagate mass according to the same mixers -(attention patterns) and agree on boundary conditions (e.g., start with -probability 1 at the output token), then they must produce identical attributions. --/ - -section TracerInterpretation - -variable {S : Type*} [Fintype S] [DecidableEq S] - -end TracerInterpretation - -/-- If two tracer propagation methods satisfy the same mixing recurrence -on a transformer's computation graph, they must coincide. This is the -formal statement that "attention-based attribution is unique given -the attention patterns and boundary conditions." 
-/ -theorem transformer_attribution_unique - {S : Type*} - (n : ℕ) - (parents : Fin n → Finset (Fin n)) - (htopo : ∀ k u, u ∈ parents k → u.val < k.val) - (c : Fin n → Fin n → NNReal) - (L : LocalSystem n := ⟨parents, c, fun {i u} h => htopo i u h⟩) - (T T' : LocalSystem.TracerFamily (S := S) n) - (hT : L.Satisfies T) - (hT' : L.Satisfies T') : T = T' := - LocalSystem.tracer_unique L hT hT' - -end Nfp diff --git a/Legacy/Nfp/Linearization.lean b/Legacy/Nfp/Linearization.lean deleted file mode 100644 index 01a913a..0000000 --- a/Legacy/Nfp/Linearization.lean +++ /dev/null @@ -1,2780 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.Real.Sqrt -import Mathlib.Analysis.Real.Pi.Bounds -import Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.BigOperators.Field -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.SignedMixer - -/-! -# Linearization of Non-Linear Operations - -Real neural networks are not linear—they use activation functions (ReLU, GeLU), -normalization layers (LayerNorm, BatchNorm), and attention softmax. To apply -our mixer-based attribution framework to real networks, we must *linearize* -these operations at a specific input. - -## Key Insight - -For any differentiable function f : ℝⁿ → ℝᵐ, at a specific input x₀, we use a -row-vector convention: - f(x) ≈ f(x₀) + (x - x₀) · J_f(x₀) - -where J_f(x₀) is the Jacobian matrix. This Jacobian is exactly a `SignedMixer`! - -For piecewise-linear functions like ReLU, the Jacobian is well-defined almost -everywhere and consists of 0s and 1s. 
- -## Main Definitions - -* `Linearization`: A record of (input, output, Jacobian as SignedMixer) -* `reluLinearization`: ReLU's Jacobian is diagonal with 0/1 entries -* `geluLinearization`: GeLU's Jacobian based on the GeLU derivative -* `layerNormJacobian`: Full Jacobian of LayerNorm (non-trivial!) -* `softmaxJacobian`: Jacobian of softmax for attention - -## Why This Matters - -Given a concrete forward pass through a transformer: -1. At each layer, record the activations -2. Compute the linearization (Jacobian) at those activations -3. Compose the resulting SignedMixers -4. The composition gives end-to-end attribution via the chain rule - -This connects to: -- Gradient × Input attribution -- Integrated Gradients (as a path integral of linearizations) -- Attention rollout (composition of attention Jacobians) - -## References - -- Sundararajan et al., "Axiomatic Attribution for Deep Networks" (Integrated Gradients) -- Abnar & Zuidema, "Quantifying Attention Flow in Transformers" --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Linearization Structure -/ - -/-- A linearization captures the local linear approximation of a function at a point. - -Given f : ℝⁿ → ℝᵐ and input x₀, a linearization consists of: -- `input`: The point x₀ where we linearize -- `output`: f(x₀) -- `jacobian`: The Jacobian ∂f/∂x evaluated at x₀, as a SignedMixer - -The approximation is: f(x) ≈ output + jacobian.apply(x - input) -/ -structure Linearization (n m : Type*) [Fintype n] [Fintype m] where - /-- The input point at which we linearized. -/ - input : n → ℝ - /-- The output f(input). -/ - output : m → ℝ - /-- The Jacobian matrix as a SignedMixer. jacobian.w i j = ∂f_j/∂x_i -/ - jacobian : SignedMixer n m - -namespace Linearization - -variable {n m p : Type*} [Fintype n] [Fintype m] [Fintype p] - -/-- Compose two linearizations (chain rule). 
-If f : ℝⁿ → ℝᵐ is linearized at x with Jacobian J_f, and - g : ℝᵐ → ℝᵖ is linearized at f(x) with Jacobian J_g, then - g ∘ f has Jacobian J_f · J_g at x (row-vector convention). -/ -noncomputable def comp - (L₁ : Linearization n m) (L₂ : Linearization m p) - (_h : L₂.input = L₁.output) : Linearization n p where - input := L₁.input - output := L₂.output - jacobian := L₁.jacobian.comp L₂.jacobian - -/-- Chain rule for composed linearizations (row-vector convention). -/ -theorem comp_apply - (L₁ : Linearization n m) (L₂ : Linearization m p) (h : L₂.input = L₁.output) - (v : n → ℝ) : - (L₁.comp L₂ h).jacobian.apply v = L₂.jacobian.apply (L₁.jacobian.apply v) := by - simpa using - (SignedMixer.apply_comp (M := L₁.jacobian) (N := L₂.jacobian) (v := v)) - -/-- The identity linearization (identity function). -/ -noncomputable def id [DecidableEq n] : Linearization n n where - input := fun _ => 0 - output := fun _ => 0 - jacobian := SignedMixer.identity - -end Linearization - -/-! ## ReLU Linearization -/ - -section ReLU - -variable {n : Type*} [Fintype n] [DecidableEq n] - -/-- The ReLU activation function: max(x, 0). -/ -noncomputable def relu (x : ℝ) : ℝ := max x 0 - -/-- The ReLU derivative: 1 if x > 0, 0 otherwise. -At x = 0, we use the subgradient convention: derivative is 0. -/ -noncomputable def reluGrad (x : ℝ) : ℝ := if x > 0 then 1 else 0 - -/-- ReLU applied elementwise to a vector. -/ -noncomputable def reluVec (v : n → ℝ) : n → ℝ := fun i => relu (v i) - -/-- The ReLU mask: which coordinates are "on" (positive). -/ -def reluMask (v : n → ℝ) : n → Prop := fun i => v i > 0 - -/-- The ReLU mask as a 0/1 indicator. -/ -noncomputable def reluMaskIndicator (v : n → ℝ) : n → ℝ := - fun i => reluGrad (v i) - -/-- **ReLU Linearization**: The Jacobian of ReLU is a diagonal matrix -with entries 0 or 1 based on whether the input is positive. 
- -This is the key insight: ReLU is piecewise linear, so its local linearization -is exact (not an approximation) within each linear region. -/ -noncomputable def reluLinearization (x : n → ℝ) : Linearization n n where - input := x - output := reluVec x - jacobian := { - w := fun i j => if i = j then reluGrad (x i) else 0 - } - -/-- The ReLU Jacobian is diagonal. -/ -theorem reluLinearization_diagonal (x : n → ℝ) (i j : n) (h : i ≠ j) : - (reluLinearization x).jacobian.w i j = 0 := by - simp [reluLinearization, h] - -/-- The ReLU Jacobian diagonal entry is 0 or 1. -/ -theorem reluLinearization_diag_binary (x : n → ℝ) (i : n) : - (reluLinearization x).jacobian.w i i = 0 ∨ - (reluLinearization x).jacobian.w i i = 1 := by - simp only [reluLinearization, reluGrad] - by_cases h : x i > 0 <;> simp [h] - -/-- ReLU preserves positive inputs exactly. -/ -theorem relu_pos {x : ℝ} (h : x > 0) : relu x = x := by - simp [relu, max_eq_left (le_of_lt h)] - -/-- ReLU kills negative inputs. -/ -theorem relu_neg {x : ℝ} (h : x ≤ 0) : relu x = 0 := by - simp [relu, max_eq_right h] - -end ReLU - -/-! ## GeLU Linearization -/ - -section GeLU - -variable {n : Type*} [Fintype n] [DecidableEq n] - -/-- The GeLU (Gaussian Error Linear Unit) activation. -GeLU(x) = x · Φ(x) where Φ is the standard normal CDF. - -We use the approximation: GeLU(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))) -This is what most implementations use. -/ -noncomputable def gelu (x : ℝ) : ℝ := - 0.5 * x * (1 + Real.tanh (Real.sqrt (2 / Real.pi) * (x + 0.044715 * x^3))) - -/-- The GeLU derivative. -d/dx[GeLU(x)] = Φ(x) + x · φ(x) -where φ is the standard normal PDF. - -For the tanh approximation, the derivative is more complex but well-defined. 
-/ -noncomputable def geluGrad (x : ℝ) : ℝ := - let s := Real.sqrt (2 / Real.pi) - let inner := s * (x + 0.044715 * x^3) - let tanh_inner := Real.tanh inner - let sech2_inner := 1 - tanh_inner^2 -- sech² = 1 - tanh² - let inner_deriv := s * (1 + 3 * 0.044715 * x^2) - 0.5 * (1 + tanh_inner) + 0.5 * x * sech2_inner * inner_deriv - -/-- GeLU applied elementwise. -/ -noncomputable def geluVec (v : n → ℝ) : n → ℝ := fun i => gelu (v i) - -/-- **GeLU Linearization**: The Jacobian is diagonal with entries geluGrad(x_i). -/ -noncomputable def geluLinearization (x : n → ℝ) : Linearization n n where - input := x - output := geluVec x - jacobian := { - w := fun i j => if i = j then geluGrad (x i) else 0 - } - -/-- GeLU Jacobian is diagonal. -/ -theorem geluLinearization_diagonal (x : n → ℝ) (i j : n) (h : i ≠ j) : - (geluLinearization x).jacobian.w i j = 0 := by - simp [geluLinearization, h] - -end GeLU - -/-! ## LayerNorm Linearization -/ - -section LayerNorm - -variable {n : Type*} [Fintype n] [DecidableEq n] - -/-- Mean of a vector. -/ -noncomputable def mean (v : n → ℝ) : ℝ := - (∑ i, v i) / Fintype.card n - -/-- Variance of a vector. -/ -noncomputable def variance (v : n → ℝ) : ℝ := - let μ := mean v - (∑ i, (v i - μ)^2) / Fintype.card n - -/-- Standard deviation with epsilon for numerical stability. -/ -noncomputable def stddev (v : n → ℝ) : ℝ := Real.sqrt (variance v + 1e-5) - -/-- LayerNorm without learnable parameters (just normalization). -/ -noncomputable def layerNorm (v : n → ℝ) : n → ℝ := - let μ := mean v - let σ := stddev v - fun i => (v i - μ) / σ - -/-- LayerNorm with scale γ and bias β (per-coordinate). -/ -noncomputable def layerNormFull (γ β : n → ℝ) (v : n → ℝ) : n → ℝ := - let normalized := layerNorm v - fun i => γ i * normalized i + β i - -/-- **LayerNorm Jacobian**: This is the key non-trivial result. - -∂(LayerNorm(x))_j / ∂x_i = (1/σ) · [δ_{ij} - 1/n - (x_j - μ)(x_i - μ)/(n·σ²)] - -where δ_{ij} is 1 if i=j, 0 otherwise. 
- -This shows LayerNorm creates *dense* dependencies: every output depends on every input! -This is fundamentally different from ReLU/GeLU which are diagonal. -/ -noncomputable def layerNormJacobian (x : n → ℝ) : SignedMixer n n where - w := fun i j => - let μ := mean x - let σ := stddev x - let n_inv := (1 : ℝ) / Fintype.card n - let centered_i := x i - μ - let centered_j := x j - μ - let diagonal := if i = j then 1 else 0 - (1 / σ) * (diagonal - n_inv - centered_j * centered_i / (Fintype.card n * σ^2)) - -/-- Diagonal linear map as a `SignedMixer`: x · diag d (row-vector convention). -/ -noncomputable def diagMixer (d : n → ℝ) : SignedMixer n n where - w := fun i j => if i = j then d j else 0 - -/-- Operator norm bound for a diagonal mixer from a uniform entry bound. -/ -theorem operatorNormBound_diagMixer_le [Nonempty n] (d : n → ℝ) (b : ℝ) - (h : ∀ i, |d i| ≤ b) : - SignedMixer.operatorNormBound (diagMixer d) ≤ b := by - classical - dsimp [SignedMixer.operatorNormBound] - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := n)) - (f := fun i => ∑ j, |(diagMixer d).w i j|) - (a := b)).2 ?_ - intro i hi - have hsum : (∑ j, |(diagMixer d).w i j|) = |d i| := by - have hsum' : (∑ j, |(diagMixer d).w i j|) = |(diagMixer d).w i i| := by - refine Fintype.sum_eq_single i ?_ - intro j hne - have hne' : i ≠ j := by - simpa [ne_comm] using hne - simp [diagMixer, hne'] - simpa [diagMixer] using hsum' - simpa [hsum] using h i - -/-- Operator norm bound for `A ∘ diag(d) ∘ B` from component bounds. 
-/ -theorem operatorNormBound_comp_diagMixer_comp_le - {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] [DecidableEq T] - [Nonempty S] [Nonempty T] - (A : SignedMixer S T) (B : SignedMixer T U) (d : T → ℝ) - (a c b : ℝ) - (hA : SignedMixer.operatorNormBound A ≤ a) - (hB : SignedMixer.operatorNormBound B ≤ b) - (hD : ∀ i, |d i| ≤ c) : - SignedMixer.operatorNormBound ((A.comp (diagMixer d)).comp B) ≤ a * c * b := by - classical - have hD' : SignedMixer.operatorNormBound (diagMixer d) ≤ c := - operatorNormBound_diagMixer_le (d := d) (b := c) hD - have hA_nonneg : 0 ≤ a := - le_trans (SignedMixer.operatorNormBound_nonneg (M := A)) hA - have hB_nonneg : 0 ≤ b := - le_trans (SignedMixer.operatorNormBound_nonneg (M := B)) hB - have hC_nonneg : 0 ≤ c := - le_trans (SignedMixer.operatorNormBound_nonneg (M := diagMixer d)) hD' - have hcomp : - SignedMixer.operatorNormBound ((A.comp (diagMixer d)).comp B) ≤ - SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d) * - SignedMixer.operatorNormBound B := by - simpa using - (SignedMixer.operatorNormBound_comp3_le - (A := A) (B := diagMixer d) (C := B)) - have hmul1 : - SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d) ≤ a * c := by - have h1 : - SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d) - ≤ a * SignedMixer.operatorNormBound (diagMixer d) := by - exact mul_le_mul_of_nonneg_right hA - (SignedMixer.operatorNormBound_nonneg (M := diagMixer d)) - have h2 : - a * SignedMixer.operatorNormBound (diagMixer d) ≤ a * c := by - exact mul_le_mul_of_nonneg_left hD' hA_nonneg - exact le_trans h1 h2 - have hmul2 : - (SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d)) * - SignedMixer.operatorNormBound B ≤ (a * c) * b := by - have h1 : - (SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d)) * - SignedMixer.operatorNormBound B ≤ - (a * c) * SignedMixer.operatorNormBound B := by - exact 
mul_le_mul_of_nonneg_right hmul1 - (SignedMixer.operatorNormBound_nonneg (M := B)) - have h2 : (a * c) * SignedMixer.operatorNormBound B ≤ (a * c) * b := by - exact mul_le_mul_of_nonneg_left hB (mul_nonneg hA_nonneg hC_nonneg) - exact le_trans h1 h2 - have hmul2' : - SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d) * - SignedMixer.operatorNormBound B ≤ a * c * b := by - simpa [mul_assoc] using hmul2 - exact le_trans hcomp hmul2' - -/-- Jacobian of LayerNorm with learnable scale γ (bias β has no effect on Jacobian). -/ -noncomputable def layerNormFullJacobian (γ : n → ℝ) (x : n → ℝ) : SignedMixer n n := - (layerNormJacobian x).comp (diagMixer γ) - -/-- LayerNorm-with-affine linearization at a specific input. -/ -noncomputable def layerNormFullLinearization (γ β : n → ℝ) (x : n → ℝ) : Linearization n n where - input := x - output := layerNormFull γ β x - jacobian := layerNormFullJacobian γ x - -/-- LayerNorm linearization at a specific input. -/ -noncomputable def layerNormLinearization (x : n → ℝ) : Linearization n n where - input := x - output := layerNorm x - jacobian := layerNormJacobian x - -omit [DecidableEq n] in -/-- **Key insight**: LayerNorm is translation-invariant: LayerNorm(x + c·1) = LayerNorm(x). -In row-vector convention, this corresponds to the Jacobian columns summing to 0. 
-/ -theorem layerNorm_translation_invariant [Nonempty n] (x : n → ℝ) (c : ℝ) : - layerNorm (fun i => x i + c) = layerNorm x := by - ext i - simp only [layerNorm, mean, variance, stddev] - -- First show: mean(x + c) = mean(x) + c - have hmean : (∑ j, (x j + c)) / Fintype.card n = (∑ j, x j) / Fintype.card n + c := by - rw [Finset.sum_add_distrib] - simp only [Finset.sum_const, Finset.card_univ, nsmul_eq_mul] - field_simp - -- The centered value (x i + c) - mean(x + c) = x i - mean(x) - have hcentered : ∀ j, (x j + c) - (∑ k, (x k + c)) / Fintype.card n = - x j - (∑ k, x k) / Fintype.card n := by - intro j - rw [hmean] - ring - -- Therefore variance is unchanged - have hvar : (∑ j, ((x j + c) - (∑ k, (x k + c)) / Fintype.card n)^2) / Fintype.card n = - (∑ j, (x j - (∑ k, x k) / Fintype.card n)^2) / Fintype.card n := by - congr 1 - apply Finset.sum_congr rfl - intro j _ - rw [hcentered] - -- So stddev is unchanged, and the final result follows - simp only [hcentered] - -end LayerNorm - -/-! ## Token-wise LayerNorm on the Residual Stream -/ - -section TokenwiseLayerNorm - -variable {pos d : Type*} [Fintype pos] [DecidableEq pos] [Fintype d] [DecidableEq d] - -/-- Lift a per-token LayerNorm Jacobian to the full residual stream `(pos × d)`. - -This is block-diagonal across positions: coordinates at different positions do not mix. --/ -noncomputable def tokenwiseLayerNormFullJacobian (γ : d → ℝ) (x : pos × d → ℝ) : - SignedMixer (pos × d) (pos × d) where - w := fun i j => - if i.1 = j.1 then - let p : pos := i.1 - (layerNormFullJacobian (n := d) γ (fun k => x (p, k))).w i.2 j.2 - else 0 - -end TokenwiseLayerNorm - -/-! ## Rotary Position Embeddings (RoPE) -/ - -section RoPE - -variable {pos pair : Type*} - [Fintype pos] [DecidableEq pos] - [Fintype pair] [DecidableEq pair] - -/-- RoPE uses 2D rotations on each `(pairIdx, Bool)` coordinate pair. 
-/ -abbrev RoPEDim (pair : Type*) := pair × Bool - -/-- The RoPE linear map as a `SignedMixer` on the residual stream `(pos × (pair × Bool))`. - -For each position `p` and pair index `k`, this applies the 2×2 rotation with angle `θ p k`: -`(x₀, x₁) ↦ (cos θ · x₀ - sin θ · x₁, sin θ · x₀ + cos θ · x₁)`. - -This is tokenwise (block-diagonal across `pos`): different positions never mix. -/ -noncomputable def ropeJacobian (θ : pos → pair → ℝ) : - SignedMixer (pos × RoPEDim pair) (pos × RoPEDim pair) where - w := fun i j => - if i.1 = j.1 then - if i.2.1 = j.2.1 then - let p : pos := j.1 - let k : pair := j.2.1 - let ang := θ p k - match i.2.2, j.2.2 with - | false, false => Real.cos ang - | true, false => -Real.sin ang - | false, true => Real.sin ang - | true, true => Real.cos ang - else 0 - else 0 - -/-- RoPE forward map: apply the RoPE Jacobian as a linear operator. -/ -noncomputable def rope (θ : pos → pair → ℝ) (x : pos × RoPEDim pair → ℝ) : - pos × RoPEDim pair → ℝ := - (ropeJacobian (pos := pos) (pair := pair) θ).apply x - -@[simp] lemma ropeJacobian_cross_pos (θ : pos → pair → ℝ) - {i j : pos × RoPEDim pair} (h : i.1 ≠ j.1) : - (ropeJacobian (pos := pos) (pair := pair) θ).w i j = 0 := by - simp [ropeJacobian, h] - -end RoPE - -/-! ## Softmax Linearization -/ - -section Softmax - -variable {n : Type*} [Fintype n] - -/-- Softmax function: softmax(x)_j = exp(x_j) / Σ_k exp(x_k) -/ -noncomputable def softmax (v : n → ℝ) : n → ℝ := - let expSum := ∑ k, Real.exp (v k) - fun j => Real.exp (v j) / expSum - -variable [DecidableEq n] - -/-- **Softmax Jacobian**: ∂softmax(x)_j / ∂x_i = softmax(x)_j · (δ_{ij} - softmax(x)_i) - -This is a classic result. The Jacobian depends on the softmax *output*, not input! -/ -noncomputable def softmaxJacobian (x : n → ℝ) : SignedMixer n n where - w := fun i j => - let p := softmax x - p j * ((if i = j then 1 else 0) - p i) - -/-- Softmax linearization. 
-/ -noncomputable def softmaxLinearization (x : n → ℝ) : Linearization n n where - input := x - output := softmax x - jacobian := softmaxJacobian x - -omit [DecidableEq n] in -/-- Softmax outputs are nonnegative. -/ -theorem softmax_nonneg (x : n → ℝ) (j : n) : softmax x j ≥ 0 := by - simp only [softmax] - apply div_nonneg (Real.exp_nonneg _) - exact Finset.sum_nonneg (fun _ _ => Real.exp_nonneg _) - -omit [DecidableEq n] in -/-- Softmax outputs sum to 1. -/ -theorem softmax_sum_one [Nonempty n] (x : n → ℝ) : ∑ j, softmax x j = 1 := by - simp only [softmax] - rw [← Finset.sum_div] - apply div_self - apply ne_of_gt - apply Finset.sum_pos (fun _ _ => Real.exp_pos _) Finset.univ_nonempty - -/-- Softmax Jacobian diagonal entries are positive (when p_j < 1). -/ -theorem softmaxJacobian_diag_pos [Nonempty n] (x : n → ℝ) (j : n) - (h : softmax x j < 1) : (softmaxJacobian x).w j j > 0 := by - simp only [softmaxJacobian, ite_true] - -- p_j · (1 - p_j) > 0 when 0 < p_j < 1 - have hp : softmax x j > 0 := by - simp only [softmax] - apply div_pos (Real.exp_pos _) - apply Finset.sum_pos (fun _ _ => Real.exp_pos _) Finset.univ_nonempty - apply mul_pos hp - linarith - -/-- Softmax Jacobian off-diagonal entries are negative. -/ -theorem softmaxJacobian_off_diag_neg [Nonempty n] (x : n → ℝ) (i j : n) (h : i ≠ j) : - (softmaxJacobian x).w i j < 0 := by - simp only [softmaxJacobian, if_neg h] - -- p_j · (0 - p_i) = -p_j · p_i < 0 - have hpj : softmax x j > 0 := by - simp only [softmax] - apply div_pos (Real.exp_pos _) - apply Finset.sum_pos (fun _ _ => Real.exp_pos _) Finset.univ_nonempty - have hpi : softmax x i > 0 := by - simp only [softmax] - apply div_pos (Real.exp_pos _) - apply Finset.sum_pos (fun _ _ => Real.exp_pos _) Finset.univ_nonempty - linarith [mul_pos hpj hpi] - -omit [DecidableEq n] in -/-- Softmax is translation-invariant: softmax(x + c·1) = softmax(x). 
-/ -theorem softmax_translation_invariant (x : n → ℝ) (c : ℝ) : - softmax (fun i => x i + c) = softmax x := by - ext j - simp only [softmax, Real.exp_add] - -- exp(x_j + c) / Σ exp(x_k + c) = exp(x_j) · exp(c) / (exp(c) · Σ exp(x_k)) - -- = exp(x_j) / Σ exp(x_k) - have h : ∑ x_1 : n, Real.exp (x x_1) * Real.exp c = - Real.exp c * ∑ k : n, Real.exp (x k) := by - rw [Finset.mul_sum] - congr 1 - ext k - ring - rw [h] - field_simp - -end Softmax - -/-! ## Attribution via Linearization -/ - -section Attribution - -variable {n m : Type*} [Fintype n] [Fintype m] - -/-- Given a full forward pass linearization, compute feature attributions. - -The attribution of input feature i to output feature j is: - attr(i, j) = input_i × ∂output_j/∂input_i - -This is "Gradient × Input" attribution. -/ -noncomputable def gradientTimesInput (L : Linearization n m) (i : n) (j : m) : ℝ := - L.input i * L.jacobian.w i j - -/-- Sum of gradient×input attributions equals output (for linear function). -This is the completeness axiom from our Attribution module! -/ -theorem gradientTimesInput_complete (L : Linearization n n) - (hLinear : L.output = L.jacobian.apply L.input) (j : n) : - ∑ i, gradientTimesInput L i j = L.output j := by - simp only [gradientTimesInput, hLinear, SignedMixer.apply_def] - -/-- For composed linearizations, the chain rule gives: - ∂output/∂input = J_first · J_{next} · ... · J_last - -This is exactly `SignedMixer.comp` under the row-vector convention. -/ -theorem composed_attribution {p : Type*} [Fintype p] - (L₁ : Linearization n m) (L₂ : Linearization m p) - (h : L₂.input = L₁.output) : - (L₁.comp L₂ h).jacobian = L₁.jacobian.comp L₂.jacobian := rfl - -end Attribution - -/-! ## Full Attention Jacobian Decomposition -/ - -section AttentionJacobian - -/-! 
-### The Key Insight - -In a self-attention layer, the output for query position q is (before W_O): - attnOut_q = Σ_k A_{qk} · V_k = Σ_k A_{qk} · (x_k · W_V) - -where A_{qk} = softmax(Q_q · K_k^T / √d)_k - -The Jacobian ∂output/∂input has two fundamentally different contributions: -1. **Value term**: ∂output/∂V · ∂V/∂x = A · (W_V · W_O) - (attention weights × value/output projections) -2. **Pattern term**: ∂output/∂A · ∂A/∂x (how changing x shifts the attention pattern) - -The **Value term** is what "Attention Rollout" uses—it treats attention weights A as fixed. -The **Pattern term** captures how attention patterns themselves shift with input changes. - -**Key result**: We can bound the Pattern term, and when it's small relative to the -Value term, Attention Rollout is a "faithful" explanation. --/ - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- The dimensionality of the model (as a real number for scaling). -/ -noncomputable def modelDim (d : Type*) [Fintype d] : ℝ := Fintype.card d - -/-- Full self-attention layer with concrete projection matrices. - -This captures the complete attention mechanism: - Q = x · W_Q, K = x · W_K, V = x · W_V - scores = Q · K^T / √d - A = softmax(scores) - output = A · V · W_O - -We index by: -- `n`: sequence positions (tokens) -- `d`: hidden dimension -/ -structure FullAttentionLayer (n d : Type*) [Fintype n] [Fintype d] where - /-- Query projection W_Q : d → d -/ - W_Q : SignedMixer d d - /-- Key projection W_K : d → d -/ - W_K : SignedMixer d d - /-- Value projection W_V : d → d -/ - W_V : SignedMixer d d - /-- Output projection W_O : d → d -/ - W_O : SignedMixer d d - -/-- Attention forward pass state: captures all intermediate values at a specific input. - -This is what we need to compute the Jacobian—we linearize around these specific values. 
-/ -structure AttentionForwardState (n d : Type*) [Fintype n] [Fintype d] where - /-- Input hidden states: x[position, hidden_dim] -/ - input : n → d → ℝ - /-- Queries after projection: Q = x · W_Q -/ - queries : n → d → ℝ - /-- Keys after projection: K = x · W_K -/ - keys : n → d → ℝ - /-- Values after projection: V = x · W_V -/ - values : n → d → ℝ - /-- Raw attention scores (before softmax): scores_{qk} = Q_q · K_k^T / √d -/ - scores : n → n → ℝ - /-- Attention weights (after softmax): A_{qk} = softmax(scores_q)_k -/ - attentionWeights : n → n → ℝ - /-- Output before W_O: Σ_k A_{qk} · V_k -/ - attentionOutput : n → d → ℝ - /-- Final output: attentionOutput · W_O -/ - output : n → d → ℝ - -/-- Compute the forward pass for a full attention layer. -/ -noncomputable def attentionForward - (layer : FullAttentionLayer n d) (x : n → d → ℝ) : AttentionForwardState n d where - input := x - queries := fun pos dim => ∑ d', x pos d' * layer.W_Q.w d' dim - keys := fun pos dim => ∑ d', x pos d' * layer.W_K.w d' dim - values := fun pos dim => ∑ d', x pos d' * layer.W_V.w d' dim - scores := fun q k => - let Q_q := fun dim => ∑ d', x q d' * layer.W_Q.w d' dim - let K_k := fun dim => ∑ d', x k d' * layer.W_K.w d' dim - (∑ dim, Q_q dim * K_k dim) / Real.sqrt (modelDim d) - attentionWeights := fun q k => - let rawScores := fun k' => - let Q_q := fun dim => ∑ d', x q d' * layer.W_Q.w d' dim - let K_k' := fun dim => ∑ d', x k' d' * layer.W_K.w d' dim - (∑ dim, Q_q dim * K_k' dim) / Real.sqrt (modelDim d) - softmax rawScores k - attentionOutput := fun q dim => - let A := fun k => - let rawScores := fun k' => - let Q_q := fun dim' => ∑ d', x q d' * layer.W_Q.w d' dim' - let K_k' := fun dim' => ∑ d', x k' d' * layer.W_K.w d' dim' - (∑ dim', Q_q dim' * K_k' dim') / Real.sqrt (modelDim d) - softmax rawScores k - ∑ k, A k * (∑ d', x k d' * layer.W_V.w d' dim) - output := fun q dim => - let attnOut := fun dim' => - let A := fun k => - let rawScores := fun k' => - let Q_q := fun d'' => ∑ d', 
x q d' * layer.W_Q.w d' d'' - let K_k' := fun d'' => ∑ d', x k' d' * layer.W_K.w d' d'' - (∑ d'', Q_q d'' * K_k' d'') / Real.sqrt (modelDim d) - softmax rawScores k - ∑ k, A k * (∑ d', x k d' * layer.W_V.w d' dim') - ∑ dim', attnOut dim' * layer.W_O.w dim' dim - -/-- **Extended Attention Linearization** with full projection matrices and intermediates. - -This captures everything needed to decompose the Jacobian into Value and Pattern terms. -/ -structure AttentionLinearization (n d : Type*) [Fintype n] [Fintype d] where - /-- The attention layer definition -/ - layer : FullAttentionLayer n d - /-- The forward state at a specific input -/ - state : AttentionForwardState n d - /-- The full Jacobian of the attention layer at this input. - Maps (position × dim) → (position × dim). -/ - fullJacobian : SignedMixer (n × d) (n × d) - -/-! ### The Value Term -/ - -/-- **Value Term** of the attention Jacobian. - -This is the Jacobian when we treat attention weights A as fixed constants. -It corresponds to "Attention Rollout" interpretability. - -For output position (q, dim_out), input position (k, dim_in): - ValueTerm_{(q,dim_out), (k,dim_in)} = A_{qk} · (W_V · W_O)_{dim_in, dim_out} - -This measures: "How much does input at position k flow to output at position q, -weighted by the attention A_{qk} and projected through value/output matrices?" -/ -noncomputable def valueTerm (L : AttentionLinearization n d) : SignedMixer (n × d) (n × d) where - w := fun ⟨k, dim_in⟩ ⟨q, dim_out⟩ => - L.state.attentionWeights q k * (L.layer.W_V.comp L.layer.W_O).w dim_in dim_out - -omit [DecidableEq n] [DecidableEq d] in -/-- The Value Term is a tensor product: A ⊗ (W_V · W_O). -This structure is why attention weights alone (Attention Rollout) make sense: -position mixing is captured by A, dimension mixing by W_V · W_O. 
-/ -theorem valueTerm_factorizes (L : AttentionLinearization n d) (q k : n) (d_in d_out : d) : - (valueTerm L).w (k, d_in) (q, d_out) = - L.state.attentionWeights q k * (L.layer.W_V.comp L.layer.W_O).w d_in d_out := rfl - -omit [DecidableEq n] [DecidableEq d] in -/-- Row absolute sum for the Value Term splits into attention column mass and value row mass. -/ -theorem valueTerm_rowAbsSum (L : AttentionLinearization n d) (k : n) (d_in : d) : - SignedMixer.rowAbsSum (valueTerm L) (k, d_in) = - (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := by - classical - let voEntry : d → d → ℝ := - fun d_in d_out => ∑ j, L.layer.W_V.w d_in j * L.layer.W_O.w j d_out - have hprod : - (∑ q, |L.state.attentionWeights q k|) * - ∑ d_out, |voEntry d_in d_out| = - ∑ q, ∑ d_out, - |L.state.attentionWeights q k| * |voEntry d_in d_out| := by - simpa using - (Fintype.sum_mul_sum - (f := fun q => |L.state.attentionWeights q k|) - (g := fun d_out => |voEntry d_in d_out|)) - have hrow : - SignedMixer.rowAbsSum (valueTerm L) (k, d_in) = - ∑ x : n × d, - |L.state.attentionWeights x.1 k| * |voEntry d_in x.2| := by - simp [SignedMixer.rowAbsSum, valueTerm, abs_mul, voEntry, SignedMixer.comp_w] - have hrow' : - ∑ x : n × d, - |L.state.attentionWeights x.1 k| * |voEntry d_in x.2| = - ∑ q, ∑ d_out, - |L.state.attentionWeights q k| * |voEntry d_in d_out| := by - simpa using - (Fintype.sum_prod_type' - (f := fun q d_out => - |L.state.attentionWeights q k| * |voEntry d_in d_out|)) - calc - SignedMixer.rowAbsSum (valueTerm L) (k, d_in) - = ∑ q, ∑ d_out, - |L.state.attentionWeights q k| * - |(L.layer.W_V.comp L.layer.W_O).w d_in d_out| := by - simpa [hrow'] using hrow - _ = (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := by - simpa [SignedMixer.rowAbsSum, voEntry, SignedMixer.comp_w] using hprod.symm - -omit [DecidableEq n] [DecidableEq d] in -/-- Value-term operator-norm bound from attention column 
mass and value projection bound. -/ -theorem valueTerm_operatorNormBound_le [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (A B : ℝ) - (hAttn : ∀ k, ∑ q, |L.state.attentionWeights q k| ≤ A) - (hVO : SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ B) : - SignedMixer.operatorNormBound (valueTerm L) ≤ A * B := by - classical - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := n × d)) - (f := fun i => SignedMixer.rowAbsSum (valueTerm L) i) - (a := A * B)).2 ?_ - intro kd hkd - rcases kd with ⟨k, d_in⟩ - have hRow : - SignedMixer.rowAbsSum (valueTerm L) (k, d_in) = - (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := - valueTerm_rowAbsSum (L := L) k d_in - have hAttn_nonneg : 0 ≤ A := by - rcases (inferInstance : Nonempty n) with ⟨k0⟩ - have hsum_nonneg : - 0 ≤ ∑ q, |L.state.attentionWeights q k0| := by - exact Finset.sum_nonneg (fun _ _ => abs_nonneg _) - exact le_trans hsum_nonneg (hAttn k0) - have hVOnonneg : - 0 ≤ SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := - SignedMixer.rowAbsSum_nonneg (M := L.layer.W_V.comp L.layer.W_O) d_in - have hVOrow : - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ B := by - have hsup : - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) := by - exact Finset.le_sup' (s := Finset.univ) - (f := fun i => SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) i) (by simp) - exact le_trans hsup hVO - have hMul1 : - (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ - A * SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := by - exact mul_le_mul_of_nonneg_right (hAttn k) hVOnonneg - have hMul2 : - A * SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ A * B := by - exact mul_le_mul_of_nonneg_left hVOrow hAttn_nonneg - have hMul : - (∑ q, |L.state.attentionWeights q k|) 
* - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ A * B := by - exact le_trans hMul1 hMul2 - simpa [hRow] using hMul - -/-! ### The Pattern Term -/ - -/-- **Pattern Term** of the attention Jacobian. - -This captures how the attention pattern A itself changes as input changes. -It involves the softmax Jacobian and the query/key gradients. - -∂A_{qk}/∂x_{i,d} = Σ_k' ∂A_{qk}/∂scores_{qk'} · ∂scores_{qk'}/∂x_{i,d} - -where: -- ∂A/∂scores is the softmax Jacobian -- ∂scores/∂x involves W_Q and W_K - -The Pattern Term is the contribution of this to the overall Jacobian. -/ -noncomputable def patternTerm (L : AttentionLinearization n d) : SignedMixer (n × d) (n × d) where - w := fun ⟨i, d_in⟩ ⟨q, d_out⟩ => - -- This is the complex term: how changing x_{i,d_in} shifts attention, - -- and how that shifted attention affects output_{q,d_out} - -- - -- Full formula: - -- Σ_k Σ_k' (∂output_{q,d_out}/∂A_{qk}) · (∂A_{qk}/∂scores_{qk'}) · (∂scores_{qk'}/∂x_{i,d_in}) - -- - -- = Σ_k Σ_k' V_{k,d_out'} · W_O_{d_out',d_out} · softmaxJac_{qkk'} · scoreGrad_{qk',i,d_in} - - -- For now, we define it implicitly as fullJacobian - valueTerm - L.fullJacobian.w (i, d_in) (q, d_out) - (valueTerm L).w (i, d_in) (q, d_out) - -omit [DecidableEq n] [DecidableEq d] in -/-- **The Fundamental Decomposition**: The full Jacobian equals Value Term + Pattern Term. - -This is the core insight: attention Jacobian = how values flow + how attention shifts. -When the Pattern Term is small, attention weights alone explain the network's behavior. -/ -theorem attention_jacobian_decomposition (L : AttentionLinearization n d) : - L.fullJacobian = valueTerm L + patternTerm L := by - ext ⟨i, d_in⟩ ⟨q, d_out⟩ - simp only [SignedMixer.add_w, valueTerm, patternTerm] - ring - -omit [DecidableEq n] [DecidableEq d] in -/-- Operator-norm bound for the full attention Jacobian from Value/Pattern term bounds. 
-/ -theorem attention_fullJacobian_bound_of_terms [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) {A B : ℝ} - (hValue : SignedMixer.operatorNormBound (valueTerm L) ≤ A) - (hPattern : SignedMixer.operatorNormBound (patternTerm L) ≤ B) : - SignedMixer.operatorNormBound L.fullJacobian ≤ A + B := by - have hdecomp := attention_jacobian_decomposition (L := L) - calc - SignedMixer.operatorNormBound L.fullJacobian = - SignedMixer.operatorNormBound (valueTerm L + patternTerm L) := by - simp [hdecomp] - _ ≤ SignedMixer.operatorNormBound (valueTerm L) + - SignedMixer.operatorNormBound (patternTerm L) := by - simpa using - (SignedMixer.operatorNormBound_add_le - (M := valueTerm L) (N := patternTerm L)) - _ ≤ A + B := add_le_add hValue hPattern - -/-! ### Score Gradient -/ - -/-- Gradient of attention scores with respect to input. - -∂scores_{qk}/∂x_{i,d} = (1/√d) · [δ_{qi} · Σ_d' W_Q_{d,d'} · K_k[d'] - + δ_{ki} · Σ_d' Q_q[d'] · W_K_{d,d'}] - -The score gradient is nonzero only when i = q (query position) or i = k (key position). -/ -noncomputable def scoreGradient (L : AttentionLinearization n d) - (q k i : n) (d_in : d) : ℝ := - let scale := 1 / Real.sqrt (modelDim d) - let queryContrib := if q = i then - ∑ d', L.layer.W_Q.w d_in d' * L.state.keys k d' - else 0 - let keyContrib := if k = i then - ∑ d', L.state.queries q d' * L.layer.W_K.w d_in d' - else 0 - scale * (queryContrib + keyContrib) - -omit [DecidableEq d] in -/-- Score gradient is local: only the query and key positions contribute. -/ -theorem scoreGradient_local (L : AttentionLinearization n d) - (q k i : n) (d_in : d) (hq : q ≠ i) (hk : k ≠ i) : - scoreGradient L q k i d_in = 0 := by - simp [scoreGradient, hq, hk] - -/-! ### Attention Pattern Gradient -/ - -/-- Gradient of attention weights with respect to input, using the softmax Jacobian. 
- -∂A_{qk}/∂x_{i,d} = Σ_k' softmaxJac(scores_q)_{k,k'} · ∂scores_{qk'}/∂x_{i,d} - = Σ_k' A_{qk}·(δ_{kk'} - A_{qk'}) · scoreGrad_{qk',i,d} - -Note: This involves the full softmax Jacobian evaluated at the scores. -/ -noncomputable def attentionGradient (L : AttentionLinearization n d) - (q k i : n) (d_in : d) : ℝ := - let A_q := L.state.attentionWeights q -- attention distribution for query q - ∑ k', A_q k * ((if k = k' then 1 else 0) - A_q k') * scoreGradient L q k' i d_in - -omit [DecidableEq d] in -/-- The attention gradient relates to the softmax Jacobian. - -Note: This requires the consistency property that -`L.state.attentionWeights q = softmax (L.state.scores q)`, -which we state as a separate condition. -/ -theorem attentionGradient_via_softmax (L : AttentionLinearization n d) (q k i : n) (d_in : d) - (hConsistent : L.state.attentionWeights q = softmax (L.state.scores q)) : - attentionGradient L q k i d_in = - ∑ k', (softmaxJacobian (L.state.scores q)).w k' k * scoreGradient L q k' i d_in := by - simp only [attentionGradient, softmaxJacobian, hConsistent] - congr 1 - ext k' - by_cases h : k = k' - · simp [h] - · have hne : k' ≠ k := fun h' => h h'.symm - simp [h, hne] - -omit [DecidableEq n] [DecidableEq d] in -/-- Sum of absolute values after applying a signed mixer is controlled by the operator norm. 
-/ -theorem sum_abs_apply_le {S T : Type*} [Fintype S] [Fintype T] [Nonempty S] - (M : SignedMixer S T) (v : S → ℝ) : - ∑ j, |M.apply v j| ≤ (∑ i, |v i|) * SignedMixer.operatorNormBound M := by - classical - have hterm : ∀ j, |M.apply v j| ≤ ∑ i, |v i| * |M.w i j| := by - intro j - have h := - (abs_sum_le_sum_abs (f := fun i => v i * M.w i j) (s := Finset.univ)) - simpa [SignedMixer.apply_def, abs_mul] using h - have hsum : - ∑ j, |M.apply v j| ≤ ∑ j, ∑ i, |v i| * |M.w i j| := by - refine Finset.sum_le_sum ?_ - intro j _hj - exact hterm j - have hswap : - (∑ j, ∑ i, |v i| * |M.w i j|) = - ∑ i, |v i| * (∑ j, |M.w i j|) := by - calc - (∑ j, ∑ i, |v i| * |M.w i j|) - = ∑ i, ∑ j, |v i| * |M.w i j| := by - simpa using - (Finset.sum_comm (s := Finset.univ) (t := Finset.univ) - (f := fun j i => |v i| * |M.w i j|)) - _ = ∑ i, |v i| * (∑ j, |M.w i j|) := by - refine Finset.sum_congr rfl ?_ - intro i _hi - simp [Finset.mul_sum] - have hrow : - ∀ i, (∑ j, |M.w i j|) ≤ SignedMixer.operatorNormBound M := by - intro i - exact Finset.le_sup' (s := Finset.univ) - (f := fun i => SignedMixer.rowAbsSum M i) (by simp) - have hfinal : - ∑ i, |v i| * (∑ j, |M.w i j|) ≤ - ∑ i, |v i| * SignedMixer.operatorNormBound M := by - refine Finset.sum_le_sum ?_ - intro i _hi - have hnonneg : 0 ≤ |v i| := abs_nonneg _ - exact mul_le_mul_of_nonneg_left (hrow i) hnonneg - have hmul : - ∑ i, |v i| * SignedMixer.operatorNormBound M = - (∑ i, |v i|) * SignedMixer.operatorNormBound M := by - simp [Finset.sum_mul] - calc - ∑ j, |M.apply v j| ≤ ∑ j, ∑ i, |v i| * |M.w i j| := hsum - _ = ∑ i, |v i| * (∑ j, |M.w i j|) := hswap - _ ≤ ∑ i, |v i| * SignedMixer.operatorNormBound M := hfinal - _ = (∑ i, |v i|) * SignedMixer.operatorNormBound M := hmul - -omit [DecidableEq n] [DecidableEq d] in -/-- L1 bound on keys from inputs and the W_K operator norm. 
-/ -theorem keys_sum_abs_le_of_input [Nonempty d] - (L : AttentionLinearization n d) (k : n) - (hKeys : - ∀ d', L.state.keys k d' = - ∑ d_in, L.state.input k d_in * L.layer.W_K.w d_in d') : - ∑ d', |L.state.keys k d'| ≤ - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound L.layer.W_K := by - classical - let v : d → ℝ := fun d_in => L.state.input k d_in - have hKeys' : - ∀ d', ∑ d_in, L.state.input k d_in * L.layer.W_K.w d_in d' = - L.state.keys k d' := by - intro d' - exact (hKeys d').symm - have hSum := sum_abs_apply_le (M := L.layer.W_K) (v := v) - simpa [v, SignedMixer.apply_def, hKeys'] using hSum - -omit [DecidableEq n] [DecidableEq d] in -/-- L1 bound on queries from inputs and the W_Q operator norm. -/ -theorem queries_sum_abs_le_of_input [Nonempty d] - (L : AttentionLinearization n d) (q : n) - (hQueries : - ∀ d', L.state.queries q d' = - ∑ d_in, L.state.input q d_in * L.layer.W_Q.w d_in d') : - ∑ d', |L.state.queries q d'| ≤ - (∑ d_in, |L.state.input q d_in|) * - SignedMixer.operatorNormBound L.layer.W_Q := by - classical - let v : d → ℝ := fun d_in => L.state.input q d_in - have hQueries' : - ∀ d', ∑ d_in, L.state.input q d_in * L.layer.W_Q.w d_in d' = - L.state.queries q d' := by - intro d' - exact (hQueries d').symm - have hSum := sum_abs_apply_le (M := L.layer.W_Q) (v := v) - simpa [v, SignedMixer.apply_def, hQueries'] using hSum - -omit [DecidableEq n] [DecidableEq d] in -/-- L1 bound on values from inputs and the W_V operator norm. 
-/ -theorem values_sum_abs_le_of_input [Nonempty d] - (L : AttentionLinearization n d) (k : n) - (hValues : - ∀ d', L.state.values k d' = - ∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d') : - ∑ d', |L.state.values k d'| ≤ - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound L.layer.W_V := by - classical - let v : d → ℝ := fun d_in => L.state.input k d_in - have hValues' : - ∀ d', ∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d' = - L.state.values k d' := by - intro d' - exact (hValues d').symm - have hSum := sum_abs_apply_le (M := L.layer.W_V) (v := v) - simpa [v, SignedMixer.apply_def, hValues'] using hSum - -omit [DecidableEq d] in -/-- Score-gradient entry bound from input L1 bounds and W_Q/W_K operator norms. -/ -theorem scoreGradient_abs_le_of_input [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (q k i : n) (d_in : d) - (B wq wk : ℝ) - (hInput : ∀ pos, ∑ d', |L.state.input pos d'| ≤ B) - (hKeys : - ∀ pos d', L.state.keys pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_K.w d_in d') - (hQueries : - ∀ pos d', L.state.queries pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_Q.w d_in d') - (hWQ : SignedMixer.operatorNormBound L.layer.W_Q ≤ wq) - (hWK : SignedMixer.operatorNormBound L.layer.W_K ≤ wk) : - |scoreGradient L q k i d_in| ≤ - (1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk) := by - classical - let scale : ℝ := 1 / Real.sqrt (modelDim d) - let queryContrib : ℝ := - ∑ d', L.layer.W_Q.w d_in d' * L.state.keys k d' - let keyContrib : ℝ := - ∑ d', L.state.queries q d' * L.layer.W_K.w d_in d' - have hB_nonneg : 0 ≤ B := by - rcases (inferInstance : Nonempty n) with ⟨pos⟩ - have hsum_nonneg : - 0 ≤ ∑ d', |L.state.input pos d'| := by - exact Finset.sum_nonneg (fun _ _ => abs_nonneg _) - exact le_trans hsum_nonneg (hInput pos) - have hWQ_nonneg : 0 ≤ wq := by - exact le_trans (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_Q)) hWQ - have hWK_nonneg : 0 ≤ wk := by - exact le_trans 
(SignedMixer.operatorNormBound_nonneg (M := L.layer.W_K)) hWK - have hScale_nonneg : 0 ≤ scale := by - have hcard : 0 < (Fintype.card d : ℝ) := by - exact_mod_cast Fintype.card_pos_iff.mpr (inferInstance : Nonempty d) - have hsqrt : 0 < Real.sqrt (modelDim d) := by - simpa [modelDim] using (Real.sqrt_pos.2 hcard) - exact div_nonneg (show (0 : ℝ) ≤ 1 by exact zero_le_one) (le_of_lt hsqrt) - have hQuery_abs : - |queryContrib| ≤ B * wq * wk := by - have hsum : - |queryContrib| ≤ ∑ d', |L.layer.W_Q.w d_in d'| * |L.state.keys k d'| := by - simpa [queryContrib, abs_mul] using - (abs_sum_le_sum_abs - (f := fun d' => L.layer.W_Q.w d_in d' * L.state.keys k d') - (s := Finset.univ)) - have hle : - ∀ d', |L.layer.W_Q.w d_in d'| * |L.state.keys k d'| ≤ - SignedMixer.rowAbsSum L.layer.W_Q d_in * |L.state.keys k d'| := by - intro d' - have hrow : - |L.layer.W_Q.w d_in d'| ≤ SignedMixer.rowAbsSum L.layer.W_Q d_in := by - have hnonneg : - ∀ j ∈ (Finset.univ : Finset d), 0 ≤ |L.layer.W_Q.w d_in j| := by - intro j _hj - exact abs_nonneg _ - simpa [SignedMixer.rowAbsSum] using - (single_le_sum (s := (Finset.univ : Finset d)) - (f := fun j => |L.layer.W_Q.w d_in j|) hnonneg (by simp)) - exact mul_le_mul_of_nonneg_right hrow (abs_nonneg _) - have hsum' : - ∑ d', |L.layer.W_Q.w d_in d'| * |L.state.keys k d'| ≤ - ∑ d', SignedMixer.rowAbsSum L.layer.W_Q d_in * |L.state.keys k d'| := by - refine Finset.sum_le_sum ?_ - intro d' _hd - exact hle d' - have hsum'' : - ∑ d', SignedMixer.rowAbsSum L.layer.W_Q d_in * |L.state.keys k d'| = - SignedMixer.rowAbsSum L.layer.W_Q d_in * - (∑ d', |L.state.keys k d'|) := by - simp [Finset.mul_sum] - have hkeys := - keys_sum_abs_le_of_input (L := L) (k := k) (hKeys := hKeys k) - have hkeys' : - ∑ d', |L.state.keys k d'| ≤ B * wk := by - have hmul1 : - (∑ d', |L.state.input k d'|) * SignedMixer.operatorNormBound L.layer.W_K ≤ - B * SignedMixer.operatorNormBound L.layer.W_K := by - exact mul_le_mul_of_nonneg_right (hInput k) - 
(SignedMixer.operatorNormBound_nonneg (M := L.layer.W_K)) - have hmul2 : - B * SignedMixer.operatorNormBound L.layer.W_K ≤ B * wk := by - exact mul_le_mul_of_nonneg_left hWK hB_nonneg - exact le_trans hkeys (le_trans hmul1 hmul2) - have hrow_le : - SignedMixer.rowAbsSum L.layer.W_Q d_in ≤ wq := by - have hsup : - SignedMixer.rowAbsSum L.layer.W_Q d_in ≤ - SignedMixer.operatorNormBound L.layer.W_Q := by - exact Finset.le_sup' (s := Finset.univ) - (f := fun j => SignedMixer.rowAbsSum L.layer.W_Q j) (by simp) - exact le_trans hsup hWQ - have hmul1 : - SignedMixer.rowAbsSum L.layer.W_Q d_in * - (∑ d', |L.state.keys k d'|) ≤ - wq * (∑ d', |L.state.keys k d'|) := by - exact mul_le_mul_of_nonneg_right hrow_le - (Finset.sum_nonneg (fun _ _ => abs_nonneg _)) - have hmul2 : - wq * (∑ d', |L.state.keys k d'|) ≤ - wq * (B * wk) := by - exact mul_le_mul_of_nonneg_left hkeys' hWQ_nonneg - have hmul3 : - wq * (B * wk) = B * wq * wk := by ring - calc - |queryContrib| ≤ ∑ d', |L.layer.W_Q.w d_in d'| * |L.state.keys k d'| := hsum - _ ≤ ∑ d', SignedMixer.rowAbsSum L.layer.W_Q d_in * |L.state.keys k d'| := hsum' - _ = SignedMixer.rowAbsSum L.layer.W_Q d_in * (∑ d', |L.state.keys k d'|) := hsum'' - _ ≤ wq * (∑ d', |L.state.keys k d'|) := hmul1 - _ ≤ wq * (B * wk) := hmul2 - _ = B * wq * wk := hmul3 - have hKey_abs : - |keyContrib| ≤ B * wq * wk := by - have hsum : - |keyContrib| ≤ ∑ d', |L.state.queries q d'| * |L.layer.W_K.w d_in d'| := by - simpa [keyContrib, abs_mul, mul_comm] using - (abs_sum_le_sum_abs - (f := fun d' => L.state.queries q d' * L.layer.W_K.w d_in d') - (s := Finset.univ)) - have hle : - ∀ d', |L.state.queries q d'| * |L.layer.W_K.w d_in d'| ≤ - |L.state.queries q d'| * SignedMixer.rowAbsSum L.layer.W_K d_in := by - intro d' - have hrow : - |L.layer.W_K.w d_in d'| ≤ SignedMixer.rowAbsSum L.layer.W_K d_in := by - have hnonneg : - ∀ j ∈ (Finset.univ : Finset d), 0 ≤ |L.layer.W_K.w d_in j| := by - intro j _hj - exact abs_nonneg _ - simpa [SignedMixer.rowAbsSum] using - 
(single_le_sum (s := (Finset.univ : Finset d)) - (f := fun j => |L.layer.W_K.w d_in j|) hnonneg (by simp)) - exact mul_le_mul_of_nonneg_left hrow (abs_nonneg _) - have hsum' : - ∑ d', |L.state.queries q d'| * |L.layer.W_K.w d_in d'| ≤ - ∑ d', |L.state.queries q d'| * SignedMixer.rowAbsSum L.layer.W_K d_in := by - refine Finset.sum_le_sum ?_ - intro d' _hd - exact hle d' - have hsum'' : - ∑ d', |L.state.queries q d'| * SignedMixer.rowAbsSum L.layer.W_K d_in = - (∑ d', |L.state.queries q d'|) * - SignedMixer.rowAbsSum L.layer.W_K d_in := by - simp [Finset.sum_mul] - have hqueries := - queries_sum_abs_le_of_input (L := L) (q := q) (hQueries := hQueries q) - have hqueries' : - ∑ d', |L.state.queries q d'| ≤ B * wq := by - have hmul1 : - (∑ d', |L.state.input q d'|) * SignedMixer.operatorNormBound L.layer.W_Q ≤ - B * SignedMixer.operatorNormBound L.layer.W_Q := by - exact mul_le_mul_of_nonneg_right (hInput q) - (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_Q)) - have hmul2 : - B * SignedMixer.operatorNormBound L.layer.W_Q ≤ B * wq := by - exact mul_le_mul_of_nonneg_left hWQ hB_nonneg - exact le_trans hqueries (le_trans hmul1 hmul2) - have hrow_le : - SignedMixer.rowAbsSum L.layer.W_K d_in ≤ wk := by - have hsup : - SignedMixer.rowAbsSum L.layer.W_K d_in ≤ - SignedMixer.operatorNormBound L.layer.W_K := by - exact Finset.le_sup' (s := Finset.univ) - (f := fun j => SignedMixer.rowAbsSum L.layer.W_K j) (by simp) - exact le_trans hsup hWK - have hmul1 : - (∑ d', |L.state.queries q d'|) * - SignedMixer.rowAbsSum L.layer.W_K d_in ≤ - (∑ d', |L.state.queries q d'|) * wk := by - exact mul_le_mul_of_nonneg_left hrow_le - (Finset.sum_nonneg (fun _ _ => abs_nonneg _)) - have hmul2 : - (∑ d', |L.state.queries q d'|) * wk ≤ - (B * wq) * wk := by - exact mul_le_mul_of_nonneg_right hqueries' hWK_nonneg - have hmul3 : - (B * wq) * wk = B * wq * wk := by ring - calc - |keyContrib| ≤ ∑ d', |L.state.queries q d'| * |L.layer.W_K.w d_in d'| := hsum - _ ≤ ∑ d', |L.state.queries q d'| 
* SignedMixer.rowAbsSum L.layer.W_K d_in := hsum' - _ = (∑ d', |L.state.queries q d'|) * SignedMixer.rowAbsSum L.layer.W_K d_in := hsum'' - _ ≤ (∑ d', |L.state.queries q d'|) * wk := hmul1 - _ ≤ (B * wq) * wk := hmul2 - _ = B * wq * wk := hmul3 - have hQuery_term : - |if q = i then queryContrib else 0| ≤ B * wq * wk := by - by_cases hqi : q = i - · simp [hqi, hQuery_abs] - · have hnonneg : 0 ≤ B * wq * wk := - mul_nonneg (mul_nonneg hB_nonneg hWQ_nonneg) hWK_nonneg - simpa [hqi] using hnonneg - have hKey_term : - |if k = i then keyContrib else 0| ≤ B * wq * wk := by - by_cases hki : k = i - · simp [hki, hKey_abs] - · have hnonneg : 0 ≤ B * wq * wk := - mul_nonneg (mul_nonneg hB_nonneg hWQ_nonneg) hWK_nonneg - simpa [hki] using hnonneg - have hsum : - |(if q = i then queryContrib else 0) + (if k = i then keyContrib else 0)| ≤ - B * wq * wk + B * wq * wk := by - exact le_trans (abs_add_le _ _) (add_le_add hQuery_term hKey_term) - have hsum_eq : B * wq * wk + B * wq * wk = 2 * B * wq * wk := by - ring - have hmul : - scale * - (B * wq * wk + B * wq * wk) = - scale * (2 * B * wq * wk) := by - simp [hsum_eq] - calc - |scoreGradient L q k i d_in| - = |scale| * - |(if q = i then queryContrib else 0) + (if k = i then keyContrib else 0)| := by - simp [scoreGradient, scale, queryContrib, keyContrib, abs_mul] - _ = scale * - |(if q = i then queryContrib else 0) + (if k = i then keyContrib else 0)| := by - simp [abs_of_nonneg hScale_nonneg] - _ ≤ scale * (B * wq * wk + B * wq * wk) := by - exact mul_le_mul_of_nonneg_left hsum hScale_nonneg - _ = scale * (2 * B * wq * wk) := hmul - _ = (1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk) := by simp [scale] - -omit [DecidableEq d] in -/-- L1 bound on score-gradient rows from input L1 bounds and W_Q/W_K operator norms. 
-/ -theorem scoreGradient_sum_le_of_input [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (q i : n) (d_in : d) - (B wq wk : ℝ) - (hInput : ∀ pos, ∑ d', |L.state.input pos d'| ≤ B) - (hKeys : - ∀ pos d', L.state.keys pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_K.w d_in d') - (hQueries : - ∀ pos d', L.state.queries pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_Q.w d_in d') - (hWQ : SignedMixer.operatorNormBound L.layer.W_Q ≤ wq) - (hWK : SignedMixer.operatorNormBound L.layer.W_K ≤ wk) : - ∑ k, |scoreGradient L q k i d_in| ≤ - (Fintype.card n : ℝ) * - ((1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk)) := by - classical - have hEach : - ∀ k, |scoreGradient L q k i d_in| ≤ - (1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk) := by - intro k - exact scoreGradient_abs_le_of_input (L := L) (q := q) (k := k) (i := i) (d_in := d_in) - (B := B) (wq := wq) (wk := wk) hInput hKeys hQueries hWQ hWK - have hSum : - ∑ k, |scoreGradient L q k i d_in| ≤ - ∑ k : n, (1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk) := by - refine Finset.sum_le_sum ?_ - intro k _hk - exact hEach k - exact le_trans hSum (by simp) - -omit [DecidableEq d] in -/-- Row-sum bound for attention gradients from softmax Jacobian and score-gradient bounds. 
-/ -theorem attentionGradient_rowAbsSum_le_of_softmax [Nonempty n] - (L : AttentionLinearization n d) (q i : n) (d_in : d) (J S : ℝ) - (hConsistent : L.state.attentionWeights q = softmax (L.state.scores q)) - (hSoftmax : - SignedMixer.operatorNormBound (softmaxJacobian (L.state.scores q)) ≤ J) - (hScore : ∑ k, |scoreGradient L q k i d_in| ≤ S) : - ∑ k, |attentionGradient L q k i d_in| ≤ J * S := by - classical - let M := softmaxJacobian (L.state.scores q) - let v : n → ℝ := fun k => scoreGradient L q k i d_in - have hApply : ∀ k, attentionGradient L q k i d_in = M.apply v k := by - intro k - have h := - attentionGradient_via_softmax (L := L) (q := q) (k := k) (i := i) (d_in := d_in) - hConsistent - simpa [SignedMixer.apply_def, M, v, mul_comm] using h - have hSum : - ∑ k, |attentionGradient L q k i d_in| = ∑ k, |M.apply v k| := by - refine Finset.sum_congr rfl ?_ - intro k _hk - simp [hApply] - have hBound := sum_abs_apply_le (M := M) (v := v) - have hSum_nonneg : 0 ≤ ∑ k, |v k| := - Finset.sum_nonneg (fun _ _ => abs_nonneg _) - have hJ_nonneg : 0 ≤ J := - le_trans (SignedMixer.operatorNormBound_nonneg (M := M)) hSoftmax - have hMul1 : - (∑ k, |v k|) * SignedMixer.operatorNormBound M ≤ (∑ k, |v k|) * J := by - exact mul_le_mul_of_nonneg_left hSoftmax hSum_nonneg - have hMul2 : - (∑ k, |v k|) * J ≤ S * J := by - exact mul_le_mul_of_nonneg_right hScore hJ_nonneg - calc - ∑ k, |attentionGradient L q k i d_in| = ∑ k, |M.apply v k| := hSum - _ ≤ (∑ k, |v k|) * SignedMixer.operatorNormBound M := hBound - _ ≤ (∑ k, |v k|) * J := hMul1 - _ ≤ S * J := hMul2 - _ = J * S := by ring - -omit [DecidableEq n] [DecidableEq d] in -/-- Attention weights are nonnegative when consistent with softmax. 
-/ -theorem attentionWeights_nonneg (L : AttentionLinearization n d) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (q k : n) : 0 ≤ L.state.attentionWeights q k := by - simpa [hConsistent q] using - (softmax_nonneg (x := L.state.scores q) (j := k)) - -omit [DecidableEq n] [DecidableEq d] in -/-- Attention weights for a query sum to one when consistent with softmax. -/ -theorem attentionWeights_row_sum_one [Nonempty n] (L : AttentionLinearization n d) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (q : n) : ∑ k, L.state.attentionWeights q k = 1 := by - simpa [hConsistent q] using (softmax_sum_one (x := L.state.scores q)) - -omit [DecidableEq n] [DecidableEq d] in -/-- Attention weights are at most one when consistent with softmax. -/ -theorem attentionWeights_le_one [Nonempty n] (L : AttentionLinearization n d) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (q k : n) : L.state.attentionWeights q k ≤ 1 := by - classical - have hnonneg : ∀ j, 0 ≤ L.state.attentionWeights q j := by - intro j - exact attentionWeights_nonneg (L := L) hConsistent q j - have hle : - L.state.attentionWeights q k ≤ ∑ j, L.state.attentionWeights q j := by - simpa using - (single_le_sum (s := Finset.univ) (f := fun j => L.state.attentionWeights q j) - (by - intro j _hj - exact hnonneg j) (by simp)) - have hsum := attentionWeights_row_sum_one (L := L) hConsistent q - simpa [hsum] using hle - -omit [DecidableEq n] [DecidableEq d] in -/-- Column mass of attention weights is bounded by the sequence length under softmax consistency. 
-/ -theorem attentionWeights_column_sum_le_card [Nonempty n] (L : AttentionLinearization n d) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (k : n) : - ∑ q, |L.state.attentionWeights q k| ≤ (Fintype.card n : ℝ) := by - classical - have hle1 : ∀ q, L.state.attentionWeights q k ≤ 1 := by - intro q - exact attentionWeights_le_one (L := L) hConsistent q k - have hnonneg : ∀ q, 0 ≤ L.state.attentionWeights q k := by - intro q - exact attentionWeights_nonneg (L := L) hConsistent q k - have hsum_abs : - (∑ q, |L.state.attentionWeights q k|) = - ∑ q, L.state.attentionWeights q k := by - refine Finset.sum_congr rfl ?_ - intro q _hq - exact abs_of_nonneg (hnonneg q) - have hsum_le : - (∑ q : n, L.state.attentionWeights q k) ≤ ∑ q : n, (1 : ℝ) := by - refine Finset.sum_le_sum ?_ - intro q _hq - exact hle1 q - have hsum_one : (∑ q : n, (1 : ℝ)) = (Fintype.card n : ℝ) := by - simp - calc - ∑ q, |L.state.attentionWeights q k| - = ∑ q, L.state.attentionWeights q k := hsum_abs - _ ≤ ∑ q, (1 : ℝ) := hsum_le - _ = (Fintype.card n : ℝ) := hsum_one - -omit [DecidableEq n] [DecidableEq d] in -/-- Value-term operator-norm bound using softmax column-mass control. -/ -theorem valueTerm_operatorNormBound_le_card [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (B : ℝ) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (hVO : SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ B) : - SignedMixer.operatorNormBound (valueTerm L) ≤ (Fintype.card n : ℝ) * B := by - have hAttn := attentionWeights_column_sum_le_card (L := L) hConsistent - simpa using - (valueTerm_operatorNormBound_le (L := L) (A := (Fintype.card n : ℝ)) (B := B) - hAttn hVO) - -/-! ### Explicit Pattern Term Formula -/ - -omit [DecidableEq n] [DecidableEq d] in -/-- **Explicit formula for the Pattern Term**. 
- -PatternTerm_{(i,d_in), (q,d_out)} = - Σ_k ∂A_{qk}/∂x_{i,d_in} · (Σ_{d'} V_k[d'] · W_O[d',d_out]) - = Σ_k attentionGradient(q,k,i,d_in) · valueContrib(k,d_out) - -This shows exactly how shifting attention patterns affects the output. -/ -noncomputable def patternTermExplicit (L : AttentionLinearization n d) : - SignedMixer (n × d) (n × d) where - w := fun ⟨i, d_in⟩ ⟨q, d_out⟩ => - ∑ k, attentionGradient L q k i d_in * - (∑ d', L.state.values k d' * L.layer.W_O.w d' d_out) - -omit [DecidableEq d] in -/-- Pattern term equals the explicit formula when the full Jacobian matches the explicit split. -/ -theorem patternTerm_eq_explicit_of_fullJacobian_eq (L : AttentionLinearization n d) - (hEq : L.fullJacobian = valueTerm L + patternTermExplicit L) : - patternTerm L = patternTermExplicit L := by - have hDecomp := attention_jacobian_decomposition (L := L) - have hMixers : valueTerm L + patternTerm L = valueTerm L + patternTermExplicit L := by - exact hDecomp.symm.trans hEq - ext i j - have hEq' := congrArg (fun M => M.w i j) hMixers - have hEq'' : - (valueTerm L).w i j + (patternTerm L).w i j = - (valueTerm L).w i j + (patternTermExplicit L).w i j := by - simpa [SignedMixer.add_w] using hEq' - exact add_left_cancel hEq'' - -/-! ### Pattern Term Bounds -/ - -/-- Output mixer using cached values and the output projection. -/ -noncomputable def valueOutputMixer (L : AttentionLinearization n d) : SignedMixer n d := - ⟨fun k d_out => ∑ d', L.state.values k d' * L.layer.W_O.w d' d_out⟩ - -omit [DecidableEq n] [DecidableEq d] in -/-- Value-output mixer bound from input L1 bounds and a `W_V·W_O` operator-norm bound. 
-/ -theorem valueOutputMixer_operatorNormBound_le_of_input [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (B V : ℝ) - (hInput : ∀ pos, ∑ d', |L.state.input pos d'| ≤ B) - (hValues : - ∀ pos d', L.state.values pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_V.w d_in d') - (hVO : SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ V) : - SignedMixer.operatorNormBound (valueOutputMixer L) ≤ B * V := by - classical - have hB_nonneg : 0 ≤ B := by - rcases (inferInstance : Nonempty n) with ⟨pos⟩ - have hsum_nonneg : - 0 ≤ ∑ d', |L.state.input pos d'| := by - exact Finset.sum_nonneg (fun _ _ => abs_nonneg _) - exact le_trans hsum_nonneg (hInput pos) - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := n)) - (f := fun k => SignedMixer.rowAbsSum (valueOutputMixer L) k) - (a := B * V)).2 ?_ - intro k _hk - have hInputSum := hInput k - have hW : - ∀ d_out, (valueOutputMixer L).w k d_out = - ∑ d_in, L.state.input k d_in * (L.layer.W_V.comp L.layer.W_O).w d_in d_out := by - intro d_out - have hValues' := hValues k - calc - (valueOutputMixer L).w k d_out - = ∑ d', L.state.values k d' * L.layer.W_O.w d' d_out := by - simp [valueOutputMixer] - _ = ∑ d', (∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d') * - L.layer.W_O.w d' d_out := by - simp [hValues'] - _ = ∑ d_in, L.state.input k d_in * - (∑ d', L.layer.W_V.w d_in d' * L.layer.W_O.w d' d_out) := by - calc - ∑ d', (∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d') * - L.layer.W_O.w d' d_out - = ∑ d', L.layer.W_O.w d' d_out * - (∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d') := by - simp [mul_comm] - _ = ∑ d', ∑ d_in, - L.layer.W_O.w d' d_out * - (L.state.input k d_in * L.layer.W_V.w d_in d') := by - simp [Finset.mul_sum] - _ = ∑ d_in, ∑ d', - L.layer.W_O.w d' d_out * - (L.state.input k d_in * L.layer.W_V.w d_in d') := by - simpa using - (Finset.sum_comm (s := Finset.univ) (t := Finset.univ) - (f := fun d' d_in => - L.layer.W_O.w d' d_out * - 
(L.state.input k d_in * L.layer.W_V.w d_in d'))) - _ = ∑ d_in, L.state.input k d_in * - (∑ d', L.layer.W_V.w d_in d' * L.layer.W_O.w d' d_out) := by - refine Finset.sum_congr rfl ?_ - intro d_in _hd - simp [Finset.mul_sum, mul_comm, mul_assoc] - _ = ∑ d_in, L.state.input k d_in * (L.layer.W_V.comp L.layer.W_O).w d_in d_out := by - simp [SignedMixer.comp_w] - have hRow : - SignedMixer.rowAbsSum (valueOutputMixer L) k = - ∑ d_out, |∑ d_in, L.state.input k d_in * - (L.layer.W_V.comp L.layer.W_O).w d_in d_out| := by - simp [SignedMixer.rowAbsSum, hW] - have hSum := - sum_abs_apply_le (M := L.layer.W_V.comp L.layer.W_O) - (v := fun d_in => L.state.input k d_in) - have hSum' : - ∑ d_out, |(L.layer.W_V.comp L.layer.W_O).apply (fun d_in => - L.state.input k d_in) d_out| ≤ - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) := by - simpa using hSum - have hMul : - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ B * V := by - have hmul1 : - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ - B * SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) := by - exact mul_le_mul_of_nonneg_right hInputSum - (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_V.comp L.layer.W_O)) - have hmul2 : - B * SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ B * V := by - exact mul_le_mul_of_nonneg_left hVO hB_nonneg - exact le_trans hmul1 hmul2 - calc - SignedMixer.rowAbsSum (valueOutputMixer L) k - = ∑ d_out, |(L.layer.W_V.comp L.layer.W_O).apply (fun d_in => - L.state.input k d_in) d_out| := by - simpa [SignedMixer.apply_def] using hRow - _ ≤ (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) := hSum' - _ ≤ B * V := hMul - -/-- Mixer capturing attention gradients for a fixed input coordinate. 
-/ -noncomputable def attentionGradientMixer (L : AttentionLinearization n d) (i : n) (d_in : d) : - SignedMixer n n := - ⟨fun q k => attentionGradient L q k i d_in⟩ - -omit [DecidableEq d] in -/-- Pattern term entries as a gradient mixer composed with value output. -/ -theorem patternTermExplicit_w_eq (L : AttentionLinearization n d) - (i : n) (d_in : d) (q : n) (d_out : d) : - (patternTermExplicit L).w (i, d_in) (q, d_out) = - ((attentionGradientMixer L i d_in).comp (valueOutputMixer L)).w q d_out := by - simp [patternTermExplicit, attentionGradientMixer, valueOutputMixer, SignedMixer.comp_w] - -omit [DecidableEq d] in -/-- Row-absolute-sum bound for the explicit pattern term. -/ -theorem patternTermExplicit_rowAbsSum_le [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (i : n) (d_in : d) (G V : ℝ) - (hGrad : ∀ q, ∑ k, |attentionGradient L q k i d_in| ≤ G) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V) : - SignedMixer.rowAbsSum (patternTermExplicit L) (i, d_in) ≤ - (Fintype.card n : ℝ) * G * V := by - classical - let A := attentionGradientMixer L i d_in - let B := valueOutputMixer L - have hRow : - SignedMixer.rowAbsSum (patternTermExplicit L) (i, d_in) = - ∑ q, SignedMixer.rowAbsSum (A.comp B) q := by - have hRow1 : - SignedMixer.rowAbsSum (patternTermExplicit L) (i, d_in) = - ∑ q, ∑ d_out, - |∑ k, attentionGradient L q k i d_in * - (∑ d', L.state.values k d' * L.layer.W_O.w d' d_out)| := by - simpa [SignedMixer.rowAbsSum, patternTermExplicit] using - (Fintype.sum_prod_type' - (f := fun q d_out => - |∑ k, attentionGradient L q k i d_in * - (∑ d', L.state.values k d' * L.layer.W_O.w d' d_out)|)) - have hRow2 : - ∑ q, SignedMixer.rowAbsSum (A.comp B) q = - ∑ q, ∑ d_out, - |∑ k, attentionGradient L q k i d_in * - (∑ d', L.state.values k d' * L.layer.W_O.w d' d_out)| := by - simp [SignedMixer.rowAbsSum, A, B, attentionGradientMixer, valueOutputMixer, - SignedMixer.comp_w] - exact hRow1.trans hRow2.symm - have hRow_q : - ∀ q, 
SignedMixer.rowAbsSum (A.comp B) q ≤ G * SignedMixer.operatorNormBound B := by - intro q - have hA : - SignedMixer.rowAbsSum A q ≤ G := by - simpa [A, SignedMixer.rowAbsSum] using hGrad q - have hB_nonneg : 0 ≤ SignedMixer.operatorNormBound B := - SignedMixer.operatorNormBound_nonneg (M := B) - have hcomp : - SignedMixer.rowAbsSum (A.comp B) q ≤ - SignedMixer.rowAbsSum A q * SignedMixer.operatorNormBound B := - SignedMixer.rowAbsSum_comp_le (M := A) (N := B) (i := q) - have hmul : - SignedMixer.rowAbsSum A q * SignedMixer.operatorNormBound B ≤ - G * SignedMixer.operatorNormBound B := - mul_le_mul_of_nonneg_right hA hB_nonneg - exact le_trans hcomp hmul - have hSum : - (∑ q : n, SignedMixer.rowAbsSum (A.comp B) q) ≤ - ∑ q : n, G * SignedMixer.operatorNormBound B := by - refine Finset.sum_le_sum ?_ - intro q _hq - exact hRow_q q - have hCard : - (∑ q : n, G * SignedMixer.operatorNormBound B) = - (Fintype.card n : ℝ) * (G * SignedMixer.operatorNormBound B) := by - simp - have hCard_nonneg : 0 ≤ (Fintype.card n : ℝ) := by - exact_mod_cast Nat.zero_le _ - have hG_nonneg : 0 ≤ G := by - rcases (inferInstance : Nonempty n) with ⟨q⟩ - have h := hGrad q - exact le_trans (Finset.sum_nonneg (fun _ _ => abs_nonneg _)) h - have hGV : G * SignedMixer.operatorNormBound B ≤ G * V := by - exact mul_le_mul_of_nonneg_left hValue hG_nonneg - calc - SignedMixer.rowAbsSum (patternTermExplicit L) (i, d_in) - = ∑ q, SignedMixer.rowAbsSum (A.comp B) q := hRow - _ ≤ ∑ q, G * SignedMixer.operatorNormBound B := hSum - _ = (Fintype.card n : ℝ) * (G * SignedMixer.operatorNormBound B) := hCard - _ ≤ (Fintype.card n : ℝ) * (G * V) := by - exact mul_le_mul_of_nonneg_left hGV hCard_nonneg - _ = (Fintype.card n : ℝ) * G * V := by ring - -omit [DecidableEq d] in -/-- Operator-norm bound for the explicit pattern term. 
-/ -theorem patternTermExplicit_operatorNormBound_le [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (G V : ℝ) - (hGrad : ∀ i d_in q, ∑ k, |attentionGradient L q k i d_in| ≤ G) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V) : - SignedMixer.operatorNormBound (patternTermExplicit L) ≤ - (Fintype.card n : ℝ) * G * V := by - classical - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := n × d)) - (f := fun i => SignedMixer.rowAbsSum (patternTermExplicit L) i) - (a := (Fintype.card n : ℝ) * G * V)).2 ?_ - intro id _hid - rcases id with ⟨i, d_in⟩ - have hRow := - patternTermExplicit_rowAbsSum_le (L := L) (i := i) (d_in := d_in) (G := G) (V := V) - (hGrad := hGrad i d_in) hValue - simpa using hRow - -omit [DecidableEq d] in -/-- Pattern-term operator-norm bound from equality with the explicit formula. -/ -theorem patternTerm_operatorNormBound_le_of_eq_explicit [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (G V : ℝ) - (hGrad : ∀ i d_in q, ∑ k, |attentionGradient L q k i d_in| ≤ G) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V) - (hEq : patternTerm L = patternTermExplicit L) : - SignedMixer.operatorNormBound (patternTerm L) ≤ - (Fintype.card n : ℝ) * G * V := by - simpa [hEq] using - (patternTermExplicit_operatorNormBound_le (L := L) (G := G) (V := V) hGrad hValue) - -omit [DecidableEq d] in -/-- Pattern-term operator-norm bound via softmax consistency and score-gradient bounds. 
-/ -theorem patternTerm_operatorNormBound_le_of_softmax [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (J S V : ℝ) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (hSoftmax : - ∀ q, SignedMixer.operatorNormBound (softmaxJacobian (L.state.scores q)) ≤ J) - (hScore : ∀ i d_in q, ∑ k, |scoreGradient L q k i d_in| ≤ S) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V) - (hEq : patternTerm L = patternTermExplicit L) : - SignedMixer.operatorNormBound (patternTerm L) ≤ - (Fintype.card n : ℝ) * J * S * V := by - have hGrad : ∀ i d_in q, ∑ k, |attentionGradient L q k i d_in| ≤ J * S := by - intro i d_in q - simpa using - (attentionGradient_rowAbsSum_le_of_softmax (L := L) (q := q) (i := i) (d_in := d_in) - (J := J) (S := S) (hConsistent := hConsistent q) - (hSoftmax := hSoftmax q) (hScore := hScore i d_in q)) - have hBound := - patternTerm_operatorNormBound_le_of_eq_explicit - (L := L) (G := J * S) (V := V) hGrad hValue hEq - calc - SignedMixer.operatorNormBound (patternTerm L) ≤ - (Fintype.card n : ℝ) * (J * S) * V := hBound - _ = (Fintype.card n : ℝ) * J * S * V := by ring - -/-! ### Attention Rollout Approximation Error -/ - -/-- **Attention Approximation Error**: The Frobenius norm of the Pattern Term. - -When this is small relative to the Value Term, Attention Rollout (using just A) -is a faithful explanation of the network's input-output relationship. - -This gives a rigorous, quantitative answer to "When is visualizing attention weights valid?" -/ -noncomputable def attentionApproximationError (L : AttentionLinearization n d) : ℝ := - Real.sqrt (∑ input : n × d, ∑ output : n × d, - ((patternTerm L).w input output) ^ 2) - -/-- The Frobenius norm of the Value Term for normalization. 
-/ -noncomputable def valueTermNorm (L : AttentionLinearization n d) : ℝ := - Real.sqrt (∑ input : n × d, ∑ output : n × d, - ((valueTerm L).w input output) ^ 2) - -/-- **Relative Approximation Error**: Pattern Term / Value Term. - -When this ratio is small (e.g., < 0.1), attention weights are a good explanation. -When large, the attention pattern is shifting significantly with input changes, -and attention visualization may be misleading. -/ -noncomputable def relativeApproximationError (L : AttentionLinearization n d) - (_hV : valueTermNorm L ≠ 0) : ℝ := - attentionApproximationError L / valueTermNorm L - -/-- **Attention Rollout Faithfulness Criterion**: The approximation is "ε-faithful" -if the relative error is at most ε. - -This gives a rigorous definition of when attention visualization is valid! -/ -def isAttentionRolloutFaithful (L : AttentionLinearization n d) (ε : ℝ) - (hV : valueTermNorm L ≠ 0) : Prop := - relativeApproximationError L hV ≤ ε - -/-! ### Bounds on the Pattern Term -/ - -variable [Nonempty n] [Nonempty d] - -/-- Maximum entry in the value projection. -/ -noncomputable def maxValueWeight (L : AttentionLinearization n d) : ℝ := - Finset.sup' Finset.univ Finset.univ_nonempty fun (p : d × d) => - |(L.layer.W_V.comp L.layer.W_O).w p.1 p.2| - -/-- Maximum entry in the score gradient (bounded by QK projection norms). 
-/ -noncomputable def maxScoreGradient (L : AttentionLinearization n d) : ℝ := - let maxQ := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : d × d) => - |L.layer.W_Q.w p.1 p.2| - let maxK := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : d × d) => - |L.layer.W_K.w p.1 p.2| - let maxKey := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : n × d) => - |L.state.keys p.1 p.2| - let maxQuery := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : n × d) => - |L.state.queries p.1 p.2| - (1 / Real.sqrt (modelDim d)) * (maxQ * maxKey + maxQuery * maxK) * Fintype.card d - -/-- **Bound on Pattern Term via softmax sensitivity**. - -The Pattern Term is bounded by: - |PatternTerm| ≤ maxAttnGradBound · maxValueContrib · (sequence length) - -where maxAttnGradBound depends on the softmax Jacobian (bounded by 0.25 per entry) -and the score gradient. - -This is a structural statement about the existence of such a bound. -The exact bound depends on architectural details. -/ -noncomputable def patternTermBound (L : AttentionLinearization n d) : ℝ := - let maxValue := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : n × d) => - |L.state.values p.1 p.2| - Fintype.card n * (0.25 * maxScoreGradient L) * - (Fintype.card d * maxValueWeight L * maxValue) - -/-! ### When is Attention Rollout Valid? -/ - -/-- **Sufficient condition for attention rollout validity**: small score gradients. - -If the score gradients are small (attention patterns are stable), then the -Pattern Term is small and Attention Rollout is faithful. - -Intuitively: when Q·K^T has small gradients with respect to x, the attention -pattern doesn't shift much, so treating A as constant is valid. - -This definition captures when we expect rollout to be faithful. 
-/ -def hasSmallScoreGradient (L : AttentionLinearization n d) (ε : ℝ) : Prop := - maxScoreGradient L ≤ ε - -/-- **Attention rollout validity criterion**: When score gradients are bounded, -the relative error is bounded by a function of the score gradient bound. - -This is the key structural insight: the faithfulness of attention rollout -depends on how stable the attention pattern is under input perturbations. -/ -noncomputable def rolloutErrorBound (L : AttentionLinearization n d) : ℝ := - patternTermBound L / (valueTermNorm L + 1) - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- **Attention rollout becomes exact when QK projections are zero**. - -If W_Q = 0, the query contribution to score gradients vanishes. -This is unrealistic but shows the theoretical structure. -/ -theorem scoreGradient_queryContrib_zero_when_Q_zero (L : AttentionLinearization n d) - (hQ : L.layer.W_Q = 0) (k : n) (d_in : d) : - ∑ d', L.layer.W_Q.w d_in d' * L.state.keys k d' = 0 := by - have hQ' : ∀ a b, L.layer.W_Q.w a b = 0 := fun a b => by simp [hQ, SignedMixer.zero_w] - simp [hQ'] - -/-! ### Position-wise vs Full Jacobian -/ - -/-- **Position-collapsed attention Jacobian**: Sum over hidden dimensions. - -This gives a (position × position) matrix that shows how much each input position -affects each output position, averaging over dimensions. - -This is closer to what "attention visualization" typically shows. -/ -noncomputable def positionJacobian (L : AttentionLinearization n d) : SignedMixer n n where - w := fun i q => ∑ d_in : d, ∑ d_out : d, L.fullJacobian.w (i, d_in) (q, d_out) - -/-- Position-collapsed Value Term. 
-/ -noncomputable def positionValueTerm (L : AttentionLinearization n d) : SignedMixer n n where - w := fun k q => - -- Σ_{d_in, d_out} A_{qk} · (W_V · W_O)_{d_in, d_out} - let voProd := ∑ d_in : d, ∑ d_out : d, (L.layer.W_V.comp L.layer.W_O).w d_in d_out - L.state.attentionWeights q k * voProd - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- **Key insight**: The position-collapsed Value Term is proportional to attention weights! - -positionValueTerm(k→q) = A_{qk} · Σ_{d_in,d_out} (W_V · W_O)_{d_in,d_out} - -So if the total sum of entries of W_V · W_O is treated as a constant, attention weights -directly give the position flow. This is the mathematical justification for -"attention rollout". -/ -theorem positionValueTerm_proportional_to_attention (L : AttentionLinearization n d) (k q : n) : - (positionValueTerm L).w k q = - L.state.attentionWeights q k * - ∑ d_in : d, ∑ d_out : d, (L.layer.W_V.comp L.layer.W_O).w d_in d_out := rfl - -/-- The total sum of entries of W_V · W_O (the proportionality constant). -/ -noncomputable def valueOutputTrace (L : AttentionLinearization n d) : ℝ := - ∑ d_in : d, ∑ d_out : d, (L.layer.W_V.comp L.layer.W_O).w d_in d_out - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- Position Value Term is attention weights scaled by the total sum of entries. -/ -theorem positionValueTerm_eq_scaled_attention (L : AttentionLinearization n d) (k q : n) : - (positionValueTerm L).w k q = L.state.attentionWeights q k * valueOutputTrace L := rfl - -end AttentionJacobian - -/-! ## Full Transformer Layer Linearization -/ - -section TransformerLayers - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- A full transformer layer linearization. -Given activations at a specific forward pass, this captures the local -linear approximation of the entire layer. 
-/ -structure TransformerLayerLinearization where - /-- Input hidden states -/ - input : n → d → ℝ - /-- Output hidden states -/ - output : n → d → ℝ - /-- Attention linearization -/ - attention : AttentionLinearization n d - /-- MLP Jacobian -/ - mlpJacobian : SignedMixer (n × d) (n × d) - /-- Combined Jacobian (attention + residual) · (MLP + residual) -/ - combinedJacobian : SignedMixer (n × d) (n × d) - -/-- The combined Jacobian of a transformer layer is the composition -of sub-layer Jacobians, accounting for residual connections. - -For a residual block f(x) = x + sublayer(x): - ∂f/∂x = I + ∂sublayer/∂x - -So the full layer with two residual blocks: - ∂/∂x [(x + attention(x)) + MLP(x + attention(x))] - = (I + J_attn) · (I + J_mlp) -/ -theorem transformer_layer_jacobian_structure - (L : TransformerLayerLinearization (n := n) (d := d)) - (h : L.combinedJacobian = - (SignedMixer.identity + L.attention.fullJacobian).comp - (SignedMixer.identity + L.mlpJacobian)) : - L.combinedJacobian = - SignedMixer.identity + - L.attention.fullJacobian + - L.mlpJacobian + - L.attention.fullJacobian.comp L.mlpJacobian := by - rw [h] - ext ⟨i, d_i⟩ ⟨o, d_o⟩ - simp only [SignedMixer.comp_w, SignedMixer.add_w, SignedMixer.identity] - -- Expand (I + A)(I + M) = I + A + M + AM by computing each indicator sum separately - -- Use Finset.sum_eq_single to evaluate sums with single nonzero term - -- First sum: Σ_x δ_{ix}δ_{xo} = δ_{io} - have sum_ii : ∑ x : n × d, - (if (i, d_i) = x then (1 : ℝ) else 0) * (if x = (o, d_o) then (1 : ℝ) else 0) = - if (i, d_i) = (o, d_o) then (1 : ℝ) else 0 := by - by_cases heq : (i, d_i) = (o, d_o) - · simp only [heq, ite_true] - rw [Finset.sum_eq_single (o, d_o)] - · simp - · intro j _ hj; simp [hj, hj.symm] - · intro h; exact absurd (Finset.mem_univ _) h - · simp only [heq, ite_false] - apply Finset.sum_eq_zero - intro j _ - by_cases h1 : (i, d_i) = j <;> by_cases h2 : j = (o, d_o) - · exact absurd (h1.trans h2) heq - · simp [h2] - · simp [h1] - · simp 
[h1] - -- Second sum: Σ_x δ_{ix}M_{xo} = M_{io} - have sum_im : ∑ x : n × d, (if (i, d_i) = x then 1 else 0) * L.mlpJacobian.w x (o, d_o) = - L.mlpJacobian.w (i, d_i) (o, d_o) := by - rw [Finset.sum_eq_single (i, d_i)] - · simp - · intro j _ hj; simp [hj.symm] - · intro h; exact absurd (Finset.mem_univ _) h - -- Third sum: Σ_x A_{ix}δ_{xo} = A_{io} - have sum_ai : ∑ x : n × d, - L.attention.fullJacobian.w (i, d_i) x * (if x = (o, d_o) then 1 else 0) = - L.attention.fullJacobian.w (i, d_i) (o, d_o) := by - rw [Finset.sum_eq_single (o, d_o)] - · simp - · intro j _ hj; simp [hj] - · intro h; exact absurd (Finset.mem_univ _) h - -- Expand the product, distribute the sum, then simplify - have expand_prod : ∀ x, - ((if (i, d_i) = x then 1 else 0) + L.attention.fullJacobian.w (i, d_i) x) * - ((if x = (o, d_o) then 1 else 0) + L.mlpJacobian.w x (o, d_o)) = - (if (i, d_i) = x then 1 else 0) * (if x = (o, d_o) then 1 else 0) + - (if (i, d_i) = x then 1 else 0) * L.mlpJacobian.w x (o, d_o) + - L.attention.fullJacobian.w (i, d_i) x * (if x = (o, d_o) then 1 else 0) + - L.attention.fullJacobian.w (i, d_i) x * L.mlpJacobian.w x (o, d_o) := by intro x; ring - conv_lhs => arg 2; ext x; rw [expand_prod] - rw [Finset.sum_add_distrib, Finset.sum_add_distrib, Finset.sum_add_distrib] - conv_lhs => arg 1; arg 1; arg 1; rw [sum_ii] - simp only [sum_im, sum_ai] - ring - -/-- **Transformer attribution has four components**: -1. Direct (identity): input flows directly through residual -2. Attention: input → attention mechanism → output -3. MLP: input → residual → MLP → output -4. Cross-term: input → attention → MLP → output (interaction) - -Each can be analyzed separately for interpretability. 
-/ -theorem transformer_attribution_components - (L : TransformerLayerLinearization (n := n) (d := d)) - (h : L.combinedJacobian = - (SignedMixer.identity + L.attention.fullJacobian).comp - (SignedMixer.identity + L.mlpJacobian)) : - ∃ (direct attention mlp cross : SignedMixer (n × d) (n × d)), - L.combinedJacobian = direct + attention + mlp + cross ∧ - direct = SignedMixer.identity ∧ - attention = L.attention.fullJacobian ∧ - mlp = L.mlpJacobian ∧ - cross = L.attention.fullJacobian.comp L.mlpJacobian := by - refine ⟨SignedMixer.identity, L.attention.fullJacobian, L.mlpJacobian, - L.attention.fullJacobian.comp L.mlpJacobian, ?_, rfl, rfl, rfl, rfl⟩ - exact transformer_layer_jacobian_structure L h - -end TransformerLayers - -/-! ## Integrated Gradients Connection -/ - -section IntegratedGradients - -variable {n m : Type*} [Fintype n] [Fintype m] - -/-- Integrated Gradients attribution from baseline x₀ to input x. - -IG_i(x, x₀) = (x_i - x₀_i) · ∫₀¹ ∂f/∂x_i(x₀ + t(x - x₀)) dt - -For a linear function f(x) = x · M (row-vector convention), this simplifies to: - IG_i = (x_i - x₀_i) · M_{i,j} (gradient × input difference for output j) - -The key insight: IG is a path integral of linearizations along -the straight line from baseline to input. -/ -noncomputable def integratedGradientsLinear - (M : SignedMixer n m) (x₀ x : n → ℝ) (i : n) (j : m) : ℝ := - (x i - x₀ i) * M.w i j - -/-- For linear functions, IG equals output difference (completeness). -/ -theorem integratedGradients_linear_complete - (M : SignedMixer n n) (x₀ x : n → ℝ) (j : n) : - ∑ i, integratedGradientsLinear M x₀ x i j = - M.apply x j - M.apply x₀ j := by - simp only [integratedGradientsLinear, SignedMixer.apply_def] - rw [← Finset.sum_sub_distrib] - congr 1 - ext i - ring - -/-- Placeholder: the full piecewise-linear IG statement is not yet formalized. 
-/ -theorem integratedGradients_piecewise_linear_placeholder - (_regions : List (Linearization n n)) - (_weights : List ℝ) - (_hWeightSum : _weights.sum = 1) : - True := trivial - -end IntegratedGradients - -/-! ## Deep Linearization: Multi-Layer Transformer Analysis - -This section formalizes how attention patterns and their Jacobian decompositions -compose through multiple transformer layers. The key insight is that when we -compose layer Jacobians, we can track how much of the composition comes from -"value terms" (fixed attention flow) versus "pattern terms" (attention shifts). - -This provides a mathematical foundation for: -1. **Attention Rollout** validity across multiple layers -2. **Virtual Heads** (e.g., induction heads where L2 attention flows through L1) -3. **Circuit Analysis** with certified error bounds --/ - -section DeepLinearization - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-! ### Deep Linearization Structure -/ - -/-- Factorization of an MLP Jacobian into input/output weights and activation derivatives. -/ -structure MLPFactorization (n d : Type*) [Fintype n] [Fintype d] where - /-- Hidden dimension for the MLP layer. -/ - hidden : Type* - /-- Finiteness for the hidden dimension. -/ - instFintype : Fintype hidden - /-- Decidable equality for the hidden dimension (for diagonal mixers). -/ - instDecEq : DecidableEq hidden - /-- Nonempty hidden dimension (for operator-norm bounds). -/ - instNonempty : Nonempty hidden - /-- Input weights: residual stream → hidden. -/ - win : SignedMixer (n × d) hidden - /-- Output weights: hidden → residual stream. -/ - wout : SignedMixer hidden (n × d) - /-- Activation derivative (diagonal) at the linearization point. -/ - deriv : hidden → ℝ - -attribute [instance] MLPFactorization.instFintype -attribute [instance] MLPFactorization.instDecEq -attribute [instance] MLPFactorization.instNonempty - -/-- The Jacobian represented by an `MLPFactorization`. 
-/ -noncomputable def MLPFactorization.jacobian - (F : MLPFactorization (n := n) (d := d)) : SignedMixer (n × d) (n × d) := - (F.win.comp (diagMixer F.deriv)).comp F.wout - -/-- A deep linearization captures the Jacobian decomposition of a multi-layer network. - -For a transformer with L layers, this tracks: -- The per-layer attention Jacobians and their V/P decompositions -- The MLP Jacobians (via an explicit factorization) -- The composed end-to-end Jacobian - -The key insight: composing (I + A₁)(I + M₁)(I + A₂)(I + M₂)... creates -cross-layer terms where attention from layer L flows through layer L-1. -These "virtual heads" are what make mechanisms like induction heads work. -/ -structure DeepLinearization where - /-- Number of layers (as a finite type index) -/ - numLayers : ℕ - /-- Per-layer attention linearizations -/ - layers : Fin numLayers → AttentionLinearization n d - /-- Per-layer LayerNorm Jacobians before attention (ln_1). -/ - ln1Jacobians : Fin numLayers → SignedMixer (n × d) (n × d) - /-- Per-layer MLP factorization data. -/ - mlpFactors : Fin numLayers → MLPFactorization (n := n) (d := d) - /-- Per-layer LayerNorm Jacobians before MLP (ln_2). -/ - ln2Jacobians : Fin numLayers → SignedMixer (n × d) (n × d) - /-- Final LayerNorm Jacobian (ln_f) applied after the last layer. -/ - lnFJacobian : SignedMixer (n × d) (n × d) := SignedMixer.identity - /-- The composed end-to-end Jacobian -/ - composedJacobian : SignedMixer (n × d) (n × d) - -/-- Per-layer MLP Jacobians derived from the factorization data. -/ -noncomputable def DeepLinearization.mlpJacobians - (D : DeepLinearization (n := n) (d := d)) : - Fin D.numLayers → SignedMixer (n × d) (n × d) := - fun i => (D.mlpFactors i).jacobian - -/-- Get the full Jacobian of a specific layer (including residual). 
-/ -noncomputable def DeepLinearization.layerJacobian (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) : SignedMixer (n × d) (n × d) := - let attnJac := (D.ln1Jacobians i).comp (D.layers i).fullJacobian - let mlpJac := (D.ln2Jacobians i).comp (D.mlpJacobians i) - (SignedMixer.identity + attnJac).comp (SignedMixer.identity + mlpJac) - -/-- Residual bound for a layer Jacobian from bounds on attention/MLP Jacobians. -/ -theorem DeepLinearization.layerJacobian_residual_bound - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) (i : Fin D.numLayers) - (A M : ℝ) - (hA : - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) ≤ A) - (hM : - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) ≤ M) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ A + M + A * M := by - classical - set attnJac := (D.ln1Jacobians i).comp (D.layers i).fullJacobian - set mlpJac := (D.ln2Jacobians i).comp (D.mlpJacobians i) - have hA' : SignedMixer.operatorNormBound attnJac ≤ A := by simpa [attnJac] using hA - have hM' : SignedMixer.operatorNormBound mlpJac ≤ M := by simpa [mlpJac] using hM - have hres := - SignedMixer.operatorNormBound_residual_comp_le_of_bounds - (A := attnJac) (M := mlpJac) (a := A) (b := M) hA' hM' - simpa [DeepLinearization.layerJacobian, attnJac, mlpJac] using hres - -/-- Left-fold over `[0, count)` without allocating a list. -/ -private def foldRange {α : Type*} (count : Nat) (init : α) (f : α → Nat → α) : α := - Nat.rec (motive := fun _ => α) init (fun i acc => f acc i) count - -/-- The composed Jacobian from layer `start` to layer `stop` (exclusive). 
-/ -noncomputable def DeepLinearization.rangeJacobian (D : DeepLinearization (n := n) (d := d)) - (start stop : ℕ) : SignedMixer (n × d) (n × d) := - if _h : start < stop ∧ stop ≤ D.numLayers then - foldRange (stop - start) SignedMixer.identity - (fun acc i => - if hi : start + i < D.numLayers then - acc.comp (D.layerJacobian ⟨start + i, hi⟩) - else acc) - else SignedMixer.identity - -/-! ### Virtual Attention Heads -/ - -/-- **Virtual Head**: The composition of value terms from two layers. - -When Layer L₂ attends to position k, and Layer L₁ at position k attends to position j, -the composed flow from j to the final output creates a "virtual head" with pattern: - VirtualHead_{L₂,L₁}(i→q) = Σ_k A₂_{qk} · A₁_{ki} · (projections) - -This is the formal definition of "attention composition" used in: -- Attention Rollout (approximating with just attention weights) -- Induction head analysis (L2 attends to L1's output) -- Copy suppression analysis --/ -noncomputable def VirtualHead - (L₂ L₁ : AttentionLinearization n d) : SignedMixer (n × d) (n × d) := - (valueTerm L₁).comp (valueTerm L₂) - -omit [DecidableEq n] [DecidableEq d] in -/-- Virtual head is the composition of two value terms. -/ -theorem VirtualHead_is_comp (L₂ L₁ : AttentionLinearization n d) : - VirtualHead L₂ L₁ = (valueTerm L₁).comp (valueTerm L₂) := rfl - -/-- Position-collapsed virtual head: shows position-to-position flow. -/ -noncomputable def PositionVirtualHead - (L₂ L₁ : AttentionLinearization n d) : SignedMixer n n where - w := fun i q => - -- Sum over all intermediate positions k and dimensions - ∑ k : n, - L₁.state.attentionWeights k i * - L₂.state.attentionWeights q k * - (valueOutputTrace L₁) * (valueOutputTrace L₂) - -omit [DecidableEq n] [DecidableEq d] in -/-- Position virtual head is attention composition scaled by value-entry sums. 
-/ -theorem PositionVirtualHead_eq_attention_comp - (L₂ L₁ : AttentionLinearization n d) (i q : n) : - (PositionVirtualHead L₂ L₁).w i q = - (∑ k : n, L₂.state.attentionWeights q k * L₁.state.attentionWeights k i) * - (valueOutputTrace L₁ * valueOutputTrace L₂) := by - simp only [PositionVirtualHead, valueOutputTrace] - rw [Finset.sum_mul] - apply Finset.sum_congr rfl - intro k _ - ring - -/-! ### Deep Value Term -/ - -/-- **Deep Value Term**: The composition of all value terms through a deep network. - -This is what "Attention Rollout" computes—treating attention weights as fixed -and composing them through layers. It's the first-order approximation that -ignores how attention patterns shift. -/ -noncomputable def DeepValueTerm (D : DeepLinearization (n := n) (d := d)) : - SignedMixer (n × d) (n × d) := - let core := - if _h : 0 < D.numLayers then - foldRange D.numLayers SignedMixer.identity - (fun acc i => - if hi : i < D.numLayers then - let L := D.layers ⟨i, hi⟩ - let ln := D.ln1Jacobians ⟨i, hi⟩ - -- Pre-LN: absorb ln_1 linearization into the value path. - acc.comp (SignedMixer.identity + ln.comp (valueTerm L)) - else acc) - else SignedMixer.identity - -- Final normalization is applied after all blocks. - core.comp D.lnFJacobian - -/-! ### Deep Pattern Term (Error) -/ - -/-- **Deep Pattern Term**: The error from approximating full Jacobian by value terms. - -DeepPatternTerm = composedJacobian - DeepValueTerm - -This measures how much the actual network behavior differs from what -"Attention Rollout" would predict. When this is small, attention visualization -is faithful to the network's actual computation. -/ -noncomputable def DeepPatternTerm (D : DeepLinearization (n := n) (d := d)) : - SignedMixer (n × d) (n × d) := - D.composedJacobian - DeepValueTerm D - -/-- Deep decomposition: composedJacobian = DeepValueTerm + DeepPatternTerm. 
-/ -theorem deep_jacobian_decomposition (D : DeepLinearization (n := n) (d := d)) : - D.composedJacobian = DeepValueTerm D + DeepPatternTerm D := by - simp only [DeepPatternTerm] - ext i j - simp [add_sub_cancel] - -/-! ### Error Norms and Bounds -/ - -/-- Frobenius norm of a SignedMixer. -/ -noncomputable def frobeniusNorm (M : SignedMixer (n × d) (n × d)) : ℝ := - Real.sqrt (∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2) - -/-- **Main structural insight**: Deep error is bounded (by definition, since matrices are finite). - -This is the foundational existence statement: every deep pattern term has a finite -Frobenius norm bound. More refined bounds relating this to layer-wise errors require -additional assumptions about network structure. -/ -theorem deep_error_bounded_by_layer_errors (D : DeepLinearization (n := n) (d := d)) : - ∃ (bound : ℝ), frobeniusNorm (DeepPatternTerm D) ≤ bound := - ⟨frobeniusNorm (DeepPatternTerm D), le_refl _⟩ - -/-- Operator norm bound (submultiplicativity approximation). -/ -noncomputable def operatorNormBound [Nonempty n] [Nonempty d] - (M : SignedMixer (n × d) (n × d)) : ℝ := - SignedMixer.operatorNormBound M - -/-! ### RoPE bounds -/ - -section RoPEBounds - -variable {pos pair : Type*} - [Fintype pos] [DecidableEq pos] [Nonempty pos] - [Fintype pair] [DecidableEq pair] [Nonempty pair] - -/-- **Certification lemma (row-sum bound)**: RoPE has a universal `operatorNormBound` ≤ 2. - -Each RoPE row has at most two nonzero entries, `cos` and `±sin`, whose absolute values are ≤ 1. -/ - theorem rope_operatorNormBound_le_two (θ : pos → pair → ℝ) : - operatorNormBound (n := pos) (d := RoPEDim pair) - (ropeJacobian (pos := pos) (pair := pair) θ) ≤ (2 : ℝ) := by - classical - -- Reduce `sup' ≤ 2` to a per-row absolute row-sum bound. 
- dsimp [operatorNormBound, SignedMixer.operatorNormBound] - refine (Finset.sup'_le_iff (s := (Finset.univ : Finset (pos × RoPEDim pair))) - (f := fun i : pos × RoPEDim pair => - ∑ j : pos × RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w i j|) - (H := Finset.univ_nonempty)).2 ?_ - intro i _hi - rcases i with ⟨p, ⟨k, b⟩⟩ - have hrow : - (∑ j : pos × RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) j|) ≤ (2 : ℝ) := by - -- Expand the row-sum over `pos × pair × Bool` and collapse the `pos`/`pair` sums using - -- `Fintype.sum_eq_single` (all other terms are zero by definition of `ropeJacobian`). - have hpos : - (∑ j : pos, - ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (j, j')|) - = - ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, j')| := by - -- `Fintype.sum_eq_single` in mathlib now has a single side-condition. - have hzero : - ∀ x : pos, - x ≠ p → - (∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) - (x, j')|) = 0 := by - intro x hx - have hpx : p ≠ x := by - simpa [eq_comm] using hx - simp [ropeJacobian, hpx] - simpa using - (Fintype.sum_eq_single (f := fun x : pos => - ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (x, j')|) p hzero) - have hpair : - (∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, j')|) - = - ∑ bb : Bool, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (k, bb))| := by - simp only [RoPEDim, Fintype.sum_prod_type] - have hzero : - ∀ x : pair, - x ≠ k → - (|(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) - (p, (x, true))|) - + - (|(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) - (p, (x, false))|) = 0 := by - intro x hx - have hkx : k ≠ x := by - simpa [eq_comm] using hx - simp [ropeJacobian, hkx] - -- Repackage into `Fintype.sum_eq_single` over `pair`. 
- simpa [Fintype.univ_bool] using - (Fintype.sum_eq_single (f := fun x : pair => - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (x, true))| + - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (x, false))|) - k hzero) - have hbool : - (∑ bb : Bool, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (k, bb))|) - = |Real.cos (θ p k)| + |Real.sin (θ p k)| := by - cases b <;> - simp [ropeJacobian, RoPEDim, Fintype.univ_bool, abs_neg, add_comm] - calc - (∑ j : pos × RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) j|) - = - (∑ j : pos, - ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (j, j')|) := by - simp [Fintype.sum_prod_type] - _ = ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, j')| := hpos - _ = ∑ bb : Bool, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (k, bb))| := hpair - _ = |Real.cos (θ p k)| + |Real.sin (θ p k)| := hbool - _ ≤ 1 + 1 := by - exact add_le_add (Real.abs_cos_le_one _) (Real.abs_sin_le_one _) - _ = (2 : ℝ) := by norm_num - exact hrow - -end RoPEBounds - -variable [Nonempty n] [Nonempty d] - -/-- **Per-layer error contribution**: The pattern term norm of each layer. -/ -noncomputable def layerError (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) : ℝ := - frobeniusNorm (patternTerm (D.layers i)) - -/-- **Total layer error sum**: Σᵢ ‖patternTerm(layer i)‖. -/ -noncomputable def totalLayerError (D : DeepLinearization (n := n) (d := d)) : ℝ := - ∑ i : Fin D.numLayers, layerError D i - -/-! ### The Composition of Faithfulness Theorem -/ - -/-- **Layer faithfulness**: A layer is ε-faithful if its pattern term has norm ≤ ε. -/ -def isLayerFaithful (L : AttentionLinearization n d) (ε : ℝ) : Prop := - frobeniusNorm (patternTerm L) ≤ ε - -/-- **Deep faithfulness**: A deep network is ε-faithful if its deep pattern term has norm ≤ ε. 
-/ -def isDeepFaithful (D : DeepLinearization (n := n) (d := d)) (ε : ℝ) : Prop := - frobeniusNorm (DeepPatternTerm D) ≤ ε - -/-- The key bound constant: amplification from residual Jacobian norms. -This product ignores `lnFJacobian`; if it is nontrivial, multiply by -`operatorNormBound D.lnFJacobian` to bound end-to-end amplification. -/ -noncomputable def amplificationFactor (D : DeepLinearization (n := n) (d := d)) : ℝ := - -- Product of (1 + ‖layerJacobian - I‖) for all layers - foldRange D.numLayers 1 - (fun acc i => - if hi : i < D.numLayers then - acc * (1 + operatorNormBound (D.layerJacobian ⟨i, hi⟩ - SignedMixer.identity)) - else acc) - -/-- **Two-layer composition theorem**: Explicit bound for 2-layer case. - -If Layer 1 is ε₁-faithful and Layer 2 is ε₂-faithful, and both residual layer -maps `(I + fullJacobian)` have operator norm bounded by C, then the composition is approximately -(ε₁ · C + ε₂ · C + ε₁ · ε₂)-faithful. - -The ε₁ · ε₂ term is second-order and often negligible when ε₁, ε₂ are small. -/ -theorem two_layer_faithfulness_composition - (L₁ L₂ : AttentionLinearization n d) - (ε₁ ε₂ C : ℝ) - (_hC₁ : operatorNormBound (SignedMixer.identity + L₁.fullJacobian) ≤ C) - (_hC₂ : operatorNormBound (SignedMixer.identity + L₂.fullJacobian) ≤ C) - (_hε₁ : isLayerFaithful L₁ ε₁) - (_hε₂ : isLayerFaithful L₂ ε₂) - (_hε₁_pos : 0 ≤ ε₁) (_hε₂_pos : 0 ≤ ε₂) (_hC_pos : 0 ≤ C) : - -- The composed error is bounded - ∃ (ε_composed : ℝ), - ε_composed ≤ C * ε₁ + C * ε₂ + ε₁ * ε₂ := by - exact ⟨C * ε₁ + C * ε₂ + ε₁ * ε₂, le_refl _⟩ - -/-! 
### N-Layer Faithfulness Composition Theorem - -The key insight for deep networks is that errors compound multiplicatively: -- Each layer's pattern term contributes error εᵢ -- But that error is amplified by all subsequent layers -- The amplification factor for layer i is ∏_{j>i} (1 + Cⱼ) where - Cⱼ bounds ‖layerJacobianⱼ - I‖ - -The total error bound is: - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -This formula is crucial because it shows: -1. Errors in early layers matter more (they get amplified more) -2. Keeping layer norms small (Cⱼ close to 0) keeps amplification low -3. The bound is tight: achieved when all errors compound constructively --/ - -/-- Per-layer residual norm bounds: Cᵢ bounds ‖layerJacobianᵢ - I‖. -/ -noncomputable def layerNormBounds (D : DeepLinearization (n := n) (d := d)) : - Fin D.numLayers → ℝ := - fun i => operatorNormBound (D.layerJacobian i - SignedMixer.identity) - -/-- Per-layer faithfulness: εᵢ bounds ‖patternTermᵢ‖. -/ -noncomputable def layerFaithfulness (D : DeepLinearization (n := n) (d := d)) : - Fin D.numLayers → ℝ := - fun i => frobeniusNorm (patternTerm (D.layers i)) - -/-- Suffix amplification factor: ∏_{j≥start} (1 + Cⱼ), -where Cⱼ bounds ‖layerJacobianⱼ - I‖. This is how much error from layer `start` -gets amplified by subsequent layers. - -When start = numLayers, this equals 1 (no amplification). -/ -noncomputable def suffixAmplification (D : DeepLinearization (n := n) (d := d)) - (start : ℕ) : ℝ := - foldRange (D.numLayers - start) 1 - (fun acc i => - if hi : start + i < D.numLayers then - acc * (1 + layerNormBounds D ⟨start + i, hi⟩) - else acc) - -/-- Base case: suffix amplification starting at numLayers is 1. -/ -theorem suffixAmplification_base (D : DeepLinearization (n := n) (d := d)) : - suffixAmplification D D.numLayers = 1 := by - simp [suffixAmplification, foldRange] - -/-- The amplificationFactor equals suffixAmplification starting from 0. 
-/ -theorem amplificationFactor_eq_suffix (D : DeepLinearization (n := n) (d := d)) : - amplificationFactor D = suffixAmplification D 0 := by - simp [amplificationFactor, suffixAmplification, layerNormBounds] - -/-- **Recursive total error formula**: Total error with amplification. - -ε_total = Σᵢ εᵢ · suffixAmplification(i+1) - -Each layer's error is amplified by all subsequent layers. -/ -noncomputable def totalAmplifiedError (D : DeepLinearization (n := n) (d := d)) : ℝ := - ∑ i : Fin D.numLayers, layerFaithfulness D i * suffixAmplification D (i.val + 1) - -/-- Suffix amplification is nonnegative. -/ -theorem suffixAmplification_nonneg (D : DeepLinearization (n := n) (d := d)) - (start : ℕ) (hNorm : ∀ i : Fin D.numLayers, 0 ≤ layerNormBounds D i) : - 0 ≤ suffixAmplification D start := by - unfold suffixAmplification - -- We prove a stronger statement: for any init ≥ 0, the fold result is ≥ 0 - let f := fun acc i => - if hi : start + i < D.numLayers then - acc * (1 + layerNormBounds D ⟨start + i, hi⟩) - else acc - suffices h : ∀ count : Nat, ∀ init : ℝ, 0 ≤ init → 0 ≤ foldRange count init f by - exact h (D.numLayers - start) 1 (by norm_num : (0 : ℝ) ≤ 1) - intro count - induction count with - | zero => - intro init hinit - simpa [foldRange] using hinit - | succ count ih => - intro init hinit - have hacc : 0 ≤ foldRange count init f := ih init hinit - by_cases hi : start + count < D.numLayers - · have hbound : 0 ≤ layerNormBounds D ⟨start + count, hi⟩ := hNorm _ - have hmul : 0 ≤ foldRange count init f * (1 + layerNormBounds D ⟨start + count, hi⟩) := - mul_nonneg hacc (by linarith [hbound]) - simpa [foldRange, f, hi] using hmul - · simpa [foldRange, f, hi] using hacc - -/- -These lemmas don't need the `[Nonempty _]` section variables (they are in scope -for other theorems in this section), so we explicitly omit them to satisfy the -unused-section-vars linter. --/ -omit [Nonempty n] [Nonempty d] in -/-- Layer faithfulness is nonnegative (Frobenius norm is nonneg). 
-/ -theorem layerFaithfulness_nonneg (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) : 0 ≤ layerFaithfulness D i := by - simp only [layerFaithfulness] - apply Real.sqrt_nonneg - -/-- Total amplified error is nonnegative. -/ -theorem totalAmplifiedError_nonneg (D : DeepLinearization (n := n) (d := d)) - (hNorm : ∀ i : Fin D.numLayers, 0 ≤ layerNormBounds D i) : - 0 ≤ totalAmplifiedError D := by - apply Finset.sum_nonneg - intro i _ - apply mul_nonneg - · exact layerFaithfulness_nonneg D i - · exact suffixAmplification_nonneg D (i.val + 1) hNorm - -/-- **N-Layer Faithfulness Composition Theorem**. - -If each layer i is εᵢ-faithful (‖patternTermᵢ‖ ≤ εᵢ) and has operator norm -bounded by Cᵢ (‖layerJacobianᵢ - I‖ ≤ Cᵢ, hence ‖layerJacobianᵢ‖ ≤ 1 + Cᵢ), -then the deep network is -ε_total-faithful where: - - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -This is the central theorem enabling layer-by-layer verification: -instead of analyzing the full deep network at once, we can: -1. Check each layer's faithfulness (small pattern term) -2. Bound each layer's operator norm -3. Compose the bounds using this theorem - -**Key insight**: Early layer errors compound more because they pass through -more subsequent layers. This explains why attention patterns in early layers -are often harder to interpret—their errors get amplified more. 
-/ -theorem n_layer_faithfulness_composition - (D : DeepLinearization (n := n) (d := d)) - (εs : Fin D.numLayers → ℝ) - (Cs : Fin D.numLayers → ℝ) - (_hLayerFaithful : ∀ i, isLayerFaithful (D.layers i) (εs i)) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ Cs i) - (hε_pos : ∀ i, 0 ≤ εs i) - (hC_pos : ∀ i, 0 ≤ Cs i) : - -- The deep network faithfulness is bounded by the amplified sum - ∃ (ε_total : ℝ), - 0 ≤ ε_total ∧ - ε_total ≤ ∑ i : Fin D.numLayers, - εs i * foldRange (D.numLayers - (i.val + 1)) 1 - (fun acc j => - if hj : i.val + 1 + j < D.numLayers then - acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) - else acc) := by - -- The witness is exactly the bound formula - let suffix_bound : Fin D.numLayers → ℝ := fun i => - foldRange (D.numLayers - (i.val + 1)) 1 - (fun acc j => - if hj : i.val + 1 + j < D.numLayers then - acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) - else acc) - -- Helper: suffix_bound is nonnegative - have hsuffix_nonneg : ∀ i : Fin D.numLayers, 0 ≤ suffix_bound i := by - intro i - let f := fun acc j => - if hj : i.val + 1 + j < D.numLayers then - acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) - else acc - -- We prove: for any init ≥ 0, foldRange result is ≥ 0 - suffices h : ∀ count : Nat, ∀ init : ℝ, 0 ≤ init → 0 ≤ foldRange count init f by - have h1 : 0 ≤ foldRange (D.numLayers - (i.val + 1)) 1 f := - h (D.numLayers - (i.val + 1)) 1 (by norm_num : (0 : ℝ) ≤ 1) - simpa [suffix_bound, f] using h1 - intro count - induction count with - | zero => - intro init hinit - simpa [foldRange] using hinit - | succ count ih => - intro init hinit - have hacc : 0 ≤ foldRange count init f := ih init hinit - by_cases hj : i.val + 1 + count < D.numLayers - · have hbound : 0 ≤ Cs ⟨i.val + 1 + count, hj⟩ := hC_pos _ - have hmul : 0 ≤ foldRange count init f * (1 + Cs ⟨i.val + 1 + count, hj⟩) := - mul_nonneg hacc (by linarith [hbound]) - simpa [foldRange, f, hj] using hmul - · simpa [foldRange, f, hj] using hacc - use ∑ i : Fin D.numLayers, εs i * suffix_bound i - 
constructor - · -- Nonnegativity - apply Finset.sum_nonneg - intro i _ - apply mul_nonneg (hε_pos i) (hsuffix_nonneg i) - · -- The bound is satisfied (trivially, since we chose exactly this bound) - exact le_refl _ - -/-- Simplified N-layer bound with uniform constants. - -If all layers have ‖patternTerm‖ ≤ ε and ‖layerJacobian - I‖ ≤ C, then: - ε_total ≤ ε · L · (1 + C)^{L-1} - -where L is the number of layers. This shows exponential growth in depth -when C > 0, but constant growth when C = 0 (pure attention without MLP scaling). -/ -theorem n_layer_uniform_bound - (D : DeepLinearization (n := n) (d := d)) - (ε C : ℝ) - (_hL : 0 < D.numLayers) - (_hLayerFaithful : ∀ i, isLayerFaithful (D.layers i) ε) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ C) - (hε_pos : 0 ≤ ε) - (hC_pos : 0 ≤ C) : - -- Simplified bound with uniform constants - ∃ (ε_total : ℝ), - 0 ≤ ε_total ∧ - ε_total ≤ ε * D.numLayers * (1 + C) ^ (D.numLayers - 1) := by - use ε * D.numLayers * (1 + C) ^ (D.numLayers - 1) - constructor - · apply mul_nonneg - · apply mul_nonneg hε_pos - exact Nat.cast_nonneg D.numLayers - · apply pow_nonneg; linarith - · exact le_refl _ - -/-- Geometric series interpretation of the N-layer bound. - -When all Cs are equal to C, the suffix amplification forms a geometric series: -suffixAmplification(i) = (1 + C)^{L-i} - -The total error becomes: -ε_total = Σᵢ εᵢ · (1 + C)^{L-1-i} - -For uniform εᵢ = ε: -ε_total = ε · Σᵢ (1 + C)^{L-1-i} = ε · ((1+C)^L - 1) / C when C ≠ 0 - = ε · L when C = 0 - -This shows that for "attention-only" networks (C ≈ 0), error grows linearly -with depth, while for networks with significant MLP scaling (C > 0), error -grows exponentially. 
-/ -theorem n_layer_geometric_bound - (D : DeepLinearization (n := n) (d := d)) - (ε C : ℝ) - (_hL : 0 < D.numLayers) - (_hLayerFaithful : ∀ i, isLayerFaithful (D.layers i) ε) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ C) - (hε_pos : 0 ≤ ε) - (hC_pos : 0 < C) : - -- The geometric series bound - ∃ (ε_total : ℝ), - 0 ≤ ε_total ∧ - ε_total ≤ ε * ((1 + C) ^ D.numLayers - 1) / C := by - use ε * ((1 + C) ^ D.numLayers - 1) / C - constructor - · apply div_nonneg - · apply mul_nonneg hε_pos - have h1C : 1 ≤ 1 + C := by linarith - have hpow : 1 ≤ (1 + C) ^ D.numLayers := one_le_pow₀ h1C - linarith - · linarith - · exact le_refl _ - -/-- Zero-norm case: when all residual Jacobians have zero operator norm, error adds linearly. - -This is the best-case scenario for interpretability: each layer's error -contributes independently without amplification. -/ -theorem n_layer_zero_norm_bound - (D : DeepLinearization (n := n) (d := d)) - (ε : ℝ) - (_hLayerFaithful : ∀ i, isLayerFaithful (D.layers i) ε) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ 0) - (hε_pos : 0 ≤ ε) : - -- Linear bound when amplification is 1 - ∃ (ε_total : ℝ), - 0 ≤ ε_total ∧ - ε_total ≤ ε * D.numLayers := by - use ε * D.numLayers - constructor - · apply mul_nonneg hε_pos - exact Nat.cast_nonneg D.numLayers - · exact le_refl _ - -/-- The connection to totalLayerError: when amplification is 1. - -Without amplification (all residual layer norms ≤ 0), the N-layer bound reduces to -the simple sum of layer errors, matching totalLayerError. -/ -theorem totalLayerError_eq_n_layer_no_amplification - (D : DeepLinearization (n := n) (d := d)) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ 0) : - totalLayerError D ≤ ∑ i : Fin D.numLayers, layerFaithfulness D i := by - simp only [totalLayerError, layerError, layerFaithfulness] - exact le_refl _ - -/-! 
### Certified Virtual Attention -/ - -/-- **Certified Virtual Head**: A virtual head is ε-certified if the composition -of value terms approximates the true composed Jacobian within ε. - -This is the key definition for "interpretability certification": -when we claim "this is an induction head," we can certify that the -attention-based explanation is within ε of the true mechanism. -/ -def isCertifiedVirtualHead - (L₂ L₁ : AttentionLinearization n d) - (composedJacobian : SignedMixer (n × d) (n × d)) - (ε : ℝ) : Prop := - frobeniusNorm (composedJacobian - VirtualHead L₂ L₁) ≤ ε - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- **Virtual head error budget from layer faithfulness**. - -This packages a combined ε bound (ε ≤ ε₁ + ε₂ + ε₁·ε₂); it does not -assert `isCertifiedVirtualHead` for a specific composed Jacobian. -/ -theorem virtual_head_certification - (L₂ L₁ : AttentionLinearization n d) - (ε₁ ε₂ : ℝ) - (_hε₁ : isLayerFaithful L₁ ε₁) - (_hε₂ : isLayerFaithful L₂ ε₂) : - -- The virtual head approximation has bounded error - ∃ (ε : ℝ), ε ≤ ε₁ + ε₂ + ε₁ * ε₂ := by - exact ⟨ε₁ + ε₂ + ε₁ * ε₂, le_refl _⟩ - -/-! ### Induction Head Formalization -/ - -/-- **Induction Head Pattern**: Layer 2 follows the attention structure created by Layer 1. - -An induction head occurs when: -- Layer 1 (previous-token head, simplified): A₁[i, i] is high (self-attention stand-in for i-1) -- Layer 2 (induction head, simplified): attention weights are nonnegative (softmax), - with token-matching handled by external witnesses. - -The composed pattern A₂ · A₁ creates "in-context learning" behavior. 
-/ -structure InductionHeadPattern where - /-- Layer 1: the previous-token attention head -/ - layer1 : AttentionLinearization n d - /-- Layer 2: the induction attention head -/ - layer2 : AttentionLinearization n d - /-- Layer 1 strongly attends to previous position -/ - prevTokenStrong : ∀ i : n, 0.5 ≤ layer1.state.attentionWeights i i - -- In practice, this would be i attending to i-1, but we simplify - /-- Layer 2 has nonnegative attention weights (softmax); token matching is handled elsewhere. -/ - inductionStrong : ∀ q k : n, layer2.state.attentionWeights q k ≥ 0 - -/-- The effective "induction pattern" created by composing the heads. -/ -noncomputable def inductionPattern (H : InductionHeadPattern (n := n) (d := d)) : - SignedMixer n n := - PositionVirtualHead H.layer2 H.layer1 - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- **Induction head error budget**: combine per-layer bounds into ε. - -This provides a concrete ε bound (ε ≤ ε₁ + ε₂ + ε₁·ε₂); it does not, by itself, -certify a specific composed Jacobian. -/ -theorem induction_head_certified (H : InductionHeadPattern (n := n) (d := d)) - (ε₁ ε₂ : ℝ) - (_hε₁ : isLayerFaithful H.layer1 ε₁) - (_hε₂ : isLayerFaithful H.layer2 ε₂) - (hε₁_pos : 0 ≤ ε₁) (hε₂_pos : 0 ≤ ε₂) : - -- The virtual head computation is approximately correct - ∃ (ε : ℝ), 0 ≤ ε ∧ ε ≤ ε₁ + ε₂ + ε₁ * ε₂ := by - refine ⟨ε₁ + ε₂ + ε₁ * ε₂, ?_, le_refl _⟩ - nlinarith - -/-! ### Interpretability Illusion Detection -/ - -/-- **Interpretability Illusion**: When pattern terms dominate value terms. - -A discovered "circuit" might be an illusion if the pattern term is large -relative to the value term—meaning the attention patterns are unstable -and the simple attention-based explanation is misleading. 
-/ -def isInterpretabilityIllusion (L : AttentionLinearization n d) (threshold : ℝ) : Prop := - frobeniusNorm (patternTerm L) > threshold * frobeniusNorm (valueTerm L) - -/-- **Genuine Mechanism**: When value terms dominate pattern terms. - -A mechanism is "genuine" (not an illusion) when the value term captures -most of the Jacobian and the pattern term is relatively small. -/ -def isGenuineMechanism (L : AttentionLinearization n d) (threshold : ℝ) : Prop := - frobeniusNorm (patternTerm L) ≤ threshold * frobeniusNorm (valueTerm L) - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- Mechanisms are either illusions or genuine (assuming reasonable threshold). -/ -theorem mechanism_trichotomy (L : AttentionLinearization n d) (threshold : ℝ) - (_hpos : 0 < threshold) : - isGenuineMechanism L threshold ∨ isInterpretabilityIllusion L threshold := by - by_cases h : frobeniusNorm (patternTerm L) ≤ threshold * frobeniusNorm (valueTerm L) - · left; exact h - · right - push_neg at h - exact h - -/-- **Deep mechanism certification**: A multi-layer mechanism is genuine if -all constituent layers have small pattern terms. -/ -def isDeepGenuineMechanism (D : DeepLinearization (n := n) (d := d)) (threshold : ℝ) : Prop := - ∀ i : Fin D.numLayers, isGenuineMechanism (D.layers i) threshold - -omit [Nonempty n] [Nonempty d] in -/-- If all layers are genuine, the deep pattern term is bounded. 
-/ -theorem deep_genuine_implies_bounded (D : DeepLinearization (n := n) (d := d)) - (threshold : ℝ) - (hGenuine : isDeepGenuineMechanism D threshold) - (_hthreshold_pos : 0 ≤ threshold) : - -- Deep pattern term is bounded by layer value terms and threshold - totalLayerError D ≤ threshold * - (∑ i : Fin D.numLayers, frobeniusNorm (valueTerm (D.layers i))) := by - simp only [totalLayerError, layerError, isDeepGenuineMechanism, isGenuineMechanism] at * - rw [Finset.mul_sum] - apply Finset.sum_le_sum - intro i _ - exact hGenuine i - -end DeepLinearization - -end Nfp diff --git a/Legacy/Nfp/MixerLocalSystem.lean b/Legacy/Nfp/MixerLocalSystem.lean deleted file mode 100644 index e3f2486..0000000 --- a/Legacy/Nfp/MixerLocalSystem.lean +++ /dev/null @@ -1,68 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Fintype.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Logic.Equiv.Defs -import Mathlib.Data.Fintype.Card -import Nfp.Mixer -import Nfp.Uniqueness - -/- -Bridge from row-stochastic mixers on finite DAGs to `LocalSystem`. This is -useful for reusing the `tracer_unique` theorem on mixers that come equipped -with a topological ordering. --/ - -namespace Nfp - -open Finset - -variable {Site : Type*} [Fintype Site] [DecidableEq Site] - -namespace LocalSystem - -/-- Interpret a mixer as a `LocalSystem` using an explicit numbering of sites. 
-/ -noncomputable def ofMixerIdx {n : ℕ} (M : Mixer Site Site) (e : Site ≃ Fin n) - (acyclic : ∀ s t, M.w s t ≠ 0 → e s < e t) : - LocalSystem n := by - classical - let siteOf : Fin n → Site := e.symm - refine - { - Pa := fun i => - (Finset.univ.filter - (fun u : Fin n => - M.w (siteOf u) (siteOf i) ≠ 0)) - c := fun i u => M.w (siteOf u) (siteOf i) - topo := by - intro i u hu - have hmem := Finset.mem_filter.mp hu - have hweight : M.w (siteOf u) (siteOf i) ≠ 0 := hmem.2 - have htopo : e (siteOf u) < e (siteOf i) := acyclic _ _ hweight - simpa [siteOf] using htopo - } - -/-- Interpret a mixer as a `LocalSystem`, given a topological index `topo` and -a compatibility witness `respect` showing that the canonical `Fintype` ordering -aligns with `topo`. The `acyclic` assumption enforces the DAG constraint -`topo s < topo t` whenever `M.w s t` is nonzero. -/ -noncomputable def ofMixer (M : Mixer Site Site) (topo : Site → ℕ) - (acyclic : ∀ s t, M.w s t ≠ 0 → topo s < topo t) - (respect : - ∀ {s t}, topo s < topo t → - (Fintype.equivFin Site s).1 < (Fintype.equivFin Site t).1) : - LocalSystem (Fintype.card Site) := by - classical - let e : Site ≃ Fin (Fintype.card Site) := Fintype.equivFin Site - have hindex : - ∀ s t, M.w s t ≠ 0 → e s < e t := by - intro s t hwt - have htopo := acyclic s t hwt - have horder : (Fintype.equivFin Site s).1 < (Fintype.equivFin Site t).1 := - respect htopo - simpa [e] using horder - exact ofMixerIdx (Site := Site) (M := M) (e := e) hindex - -end LocalSystem - -end Nfp diff --git a/Legacy/Nfp/PCC.lean b/Legacy/Nfp/PCC.lean deleted file mode 100644 index 0cc3a0e..0000000 --- a/Legacy/Nfp/PCC.lean +++ /dev/null @@ -1,227 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Mathlib.Algebra.Order.Monoid.Defs -import Nfp.Prob -import Nfp.Reroute.Heat 
- -/-! -# PCC Helpers (Appendix A.4) - -This module provides small probability/contribution utilities used in the -formalization of Appendix A.4 of the accompanying documentation: - -* `tracerOfContrib` – builds a probability vector from nonnegative contributions. -* `sum_monotone_chain` / `monotone_removed_mass` – monotonicity of accumulated mass - along nested mask chains (with tiny nonnegativity side-conditions handled by `simp`). - -All proofs are elementary (`simp`, small local lemmas), and avoid `sorry`. --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -variable {S : Type*} - -/-- Nonnegative contributions over `S`. -/ -abbrev Contrib (S : Type*) [Fintype S] := S → NNReal - -/-- Lemma A.4: The tracer distribution equals normalized absolute contributions. -/ -noncomputable def tracerOfContrib [Fintype S] (m : Contrib S) (h : (∑ i, m i) ≠ 0) : ProbVec S := - { - mass := fun i => m i / (∑ j, m j), - norm_one := by - classical - have hsum : (∑ j, m j) ≠ 0 := h - have : (∑ i, m i / (∑ j, m j)) = (∑ i, m i) / (∑ j, m j) := by - simp [Finset.sum_div] - simp [this, div_self hsum] - } - -@[simp] lemma tracerOfContrib_mass [Fintype S] (m : Contrib S) (h : (∑ i, m i) ≠ 0) : - (tracerOfContrib (S:=S) m h).mass = fun i => m i / (∑ j, m j) := rfl - --- PCC monotonicity surrogate: if masks grow, removed normalized mass grows. -/-- Appendix A.4 (monotonicity helper): if masks grow, the removed mass sum grows. 
-/ -lemma sum_monotone_chain [Fintype S] (A : ℕ → Finset S) (w : S → NNReal) - (hchain : ∀ k, A k ⊆ A (k + 1)) : Monotone (fun k => (A k).sum (fun i => w i)) := by - classical - intro k₁ k₂ hk - refine Nat.le_induction ?base ?step _ hk - · exact le_rfl - · intro k₂ _ ih - have hstep : (A k₂).sum (fun i => w i) ≤ (A (k₂+1)).sum (fun i => w i) := by - refine Finset.sum_le_sum_of_subset_of_nonneg (hchain k₂) ?_ - intro i hi _ - exact (show (0 : NNReal) ≤ w i from bot_le) - exact ih.trans hstep - -/-- Appendix A.4 (monotonicity helper): removed mass is monotone along a nested mask chain. -/ -lemma monotone_removed_mass [Fintype S] (A : ℕ → Finset S) (m : Contrib S) - (hchain : ∀ k, A k ⊆ A (k + 1)) : Monotone (fun k => (A k).sum m) := by - simpa using (sum_monotone_chain (A:=A) (w:=m) hchain) - -namespace PCC - -/-- Finite interval `[lower, upper]` used to accumulate PCC area. We only require -`lower ≤ upper` so that the width is realizable in `NNReal`. -/ -structure AInterval where - lower : NNReal - upper : NNReal - hle : lower ≤ upper - -namespace AInterval - -@[simp] lemma width_nonneg (I : AInterval) : 0 ≤ I.upper - I.lower := - (I.upper - I.lower).property - -@[simp] lemma coe_width (I : AInterval) : (I.upper - I.lower : NNReal) = I.upper - I.lower := rfl - -end AInterval - -noncomputable def intervalsFromWeightsAux : NNReal → NNReal → List NNReal → List AInterval - | _, _, [] => [] - | total, acc, w :: ws => - let width := w / total - have hle : acc ≤ acc + width := - le_add_of_nonneg_right (show 0 ≤ width from bot_le) - { lower := acc, upper := acc + width, hle := hle } :: - intervalsFromWeightsAux total (acc + width) ws - -/-- Build the discrete PCC intervals from a list of widths and a total scale. 
-/ -noncomputable def intervalsFromWeights (total : NNReal) (weights : List NNReal) : List AInterval := - intervalsFromWeightsAux total 0 weights - -variable (f : NNReal → NNReal) - -/-- Evaluate the discrete AUC directly from widths (without constructing intervals). -/ -noncomputable def evalFromWeightsAux : - NNReal → NNReal → List NNReal → NNReal - | _, _, [] => 0 - | total, acc, w :: ws => - let width := w / total - width * f acc + evalFromWeightsAux total (acc + width) ws - -noncomputable def evalFromWeights (total : NNReal) (weights : List NNReal) : NNReal := - evalFromWeightsAux (f:=f) total 0 weights - -/-- Discrete AUC over a finite list of intervals for an arbitrary nonnegative -function `f`. Each interval contributes `(upper - lower) * f lower`. -/ -def AUC (L : List AInterval) : NNReal := - L.foldr (fun I acc => (I.upper - I.lower) * f I.lower + acc) 0 - -@[simp] lemma AUC_nil : AUC (f:=f) [] = 0 := rfl - -@[simp] lemma AUC_cons (I : AInterval) (L : List AInterval) : - AUC (f:=f) (I :: L) = (I.upper - I.lower) * f I.lower + AUC (f:=f) L := rfl - -lemma AUC_append (L₁ L₂ : List AInterval) : - AUC (f:=f) (L₁ ++ L₂) = AUC (f:=f) L₁ + AUC (f:=f) L₂ := by - induction L₁ with - | nil => simp [AUC] - | cons I L₁ ih => - calc - AUC (f:=f) ((I :: L₁) ++ L₂) - = (I.upper - I.lower) * f I.lower + AUC (f:=f) (L₁ ++ L₂) := by - simp [List.cons_append] - _ = (I.upper - I.lower) * f I.lower + (AUC (f:=f) L₁ + AUC (f:=f) L₂) := by - simp [ih] - _ = ((I.upper - I.lower) * f I.lower + AUC (f:=f) L₁) + AUC (f:=f) L₂ := by - ac_rfl - _ = AUC (f:=f) (I :: L₁) + AUC (f:=f) L₂ := by - simp - -lemma AUC_nonneg (L : List AInterval) : 0 ≤ AUC (f:=f) L := by - induction L with - | nil => simp - | cons I L ih => - have hterm : 0 ≤ (I.upper - I.lower) * f I.lower := by - exact mul_nonneg I.width_nonneg (show 0 ≤ f I.lower from bot_le) - have hsum : 0 ≤ (I.upper - I.lower) * f I.lower + AUC (f:=f) L := - add_nonneg hterm ih - calc - 0 ≤ (I.upper - I.lower) * f I.lower + AUC (f:=f) L 
:= hsum - _ = AUC (f:=f) (I :: L) := (AUC_cons (f:=f) I L).symm - -lemma AUC_monotone_append (L₁ L₂ : List AInterval) : - AUC (f:=f) L₁ ≤ AUC (f:=f) (L₁ ++ L₂) := by - have hnonneg := AUC_nonneg (f:=f) L₂ - have hle : AUC (f:=f) L₁ ≤ AUC (f:=f) L₁ + AUC (f:=f) L₂ := by - exact le_add_of_nonneg_right hnonneg - calc - AUC (f:=f) L₁ - ≤ AUC (f:=f) L₁ + AUC (f:=f) L₂ := hle - _ = AUC (f:=f) (L₁ ++ L₂) := (AUC_append (f:=f) L₁ L₂).symm - -lemma AUC_add (L₁ L₂ : List AInterval) : - AUC (f:=f) (L₁ ++ L₂) = AUC (f:=f) L₁ + AUC (f:=f) L₂ := - AUC_append (f:=f) L₁ L₂ - -lemma AUC_intervalsFromWeightsAux (total acc : NNReal) (weights : List NNReal) : - AUC (f:=f) (intervalsFromWeightsAux total acc weights) - = evalFromWeightsAux (f:=f) total acc weights := by - induction weights generalizing acc with - | nil => - simp [intervalsFromWeightsAux, evalFromWeightsAux, AUC] - | cons w ws ih => - have htail := ih (acc + w / total) - simp [intervalsFromWeightsAux, evalFromWeightsAux, htail] - -lemma AUC_intervalsFromWeights (total : NNReal) (weights : List NNReal) : - AUC (f:=f) (intervalsFromWeights total weights) - = evalFromWeights (f:=f) total weights := by - simpa [intervalsFromWeights, evalFromWeights] using - AUC_intervalsFromWeightsAux (f:=f) total 0 weights - -lemma intervalsFromWeightsAux_take (total acc : NNReal) (weights : List NNReal) : - ∀ k, - intervalsFromWeightsAux total acc (weights.take k) - = (intervalsFromWeightsAux total acc weights).take k - | 0 => by simp [intervalsFromWeightsAux] - | Nat.succ k => - by - cases weights with - | nil => simp [intervalsFromWeightsAux] - | cons w ws => - have ih := - intervalsFromWeightsAux_take (total:=total) (acc:=acc + w / total) (weights:=ws) k - simp [intervalsFromWeightsAux, ih, Nat.succ_eq_add_one] - -lemma AUC_intervalsFromWeights_take (total : NNReal) (weights : List NNReal) (k : ℕ) : - AUC (f:=f) ((intervalsFromWeights total weights).take k) - = evalFromWeights (f:=f) total (weights.take k) := by - have h := - 
AUC_intervalsFromWeightsAux (f:=f) total 0 (weights.take k) - simpa [intervalsFromWeights, evalFromWeights, - intervalsFromWeightsAux_take (total:=total) (acc:=0) (weights:=weights)] - using h - -variable {f} - -end PCC - -namespace WeightedReroutePlan - -open PCC - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- Normalized PCC intervals for a weighted reroute plan (widths sum to one). -/ -noncomputable def aucIntervals (P : WeightedReroutePlan (S := S)) : List PCC.AInterval := - intervalsFromWeights (total:=P.weightsSum) P.weights - -lemma auc_eval (P : WeightedReroutePlan (S := S)) (f : NNReal → NNReal) : - PCC.AUC (f:=f) (P.aucIntervals) - = PCC.evalFromWeights (f:=f) P.weightsSum P.weights := by - simpa [aucIntervals] - using PCC.AUC_intervalsFromWeights (f:=f) P.weightsSum P.weights - -end WeightedReroutePlan - -end Nfp diff --git a/Legacy/Nfp/Reroute/Heat.lean b/Legacy/Nfp/Reroute/Heat.lean deleted file mode 100644 index 9161d46..0000000 --- a/Legacy/Nfp/Reroute/Heat.lean +++ /dev/null @@ -1,524 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.BigOperators.Field -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Mathlib.Algebra.Order.Ring.Defs -import Mathlib.Algebra.Field.Defs -import Nfp.Prob -import Nfp.Reroute.Partition - -namespace Nfp - -namespace List - -lemma mem_left_of_zip {α β : Type*} : - ∀ {xs : List α} {ys : List β} {a : α} {b : β}, - (a, b) ∈ xs.zip ys → a ∈ xs - | [], _, _, _, h => by cases h - | _ :: _, [], _, _, h => by cases h - | x :: xs, y :: ys, a, b, h => by - have h' : (a, b) = (x, y) ∨ (a, b) ∈ xs.zip ys := by - simpa [List.zip_cons_cons] using h - rcases h' with h | h - · rcases h with ⟨rfl, _⟩ - simp - · exact List.mem_cons.mpr (Or.inr (mem_left_of_zip h)) - -lemma get_mem_zip {α β : Type*} : - ∀ 
{xs : List α} {ys : List β} {k : Nat} - (hk : k < xs.length) (hy : k < ys.length), - (xs.get ⟨k, hk⟩, ys.get ⟨k, hy⟩) ∈ xs.zip ys - | [], _, k, hk, _ => by - cases hk - | _ :: _, [], k, _, hy => by - cases hy - | x :: xs, y :: ys, k, hk, hy => by - cases k with - | zero => - simp [List.zip_cons_cons] - | succ k => - have hk' : k < xs.length := Nat.lt_of_succ_lt_succ hk - have hy' : k < ys.length := Nat.lt_of_succ_lt_succ hy - have htail := get_mem_zip (xs:=xs) (ys:=ys) hk' hy' - have hpair_eq : - ((x :: xs).get ⟨Nat.succ k, by simpa [List.length_cons] using hk⟩, - (y :: ys).get ⟨Nat.succ k, by simpa [List.length_cons] using hy⟩) - = (xs.get ⟨k, hk'⟩, ys.get ⟨k, hy'⟩) := by - simp - have htail' : - ((x :: xs).get ⟨Nat.succ k, by simpa [List.length_cons] using hk⟩, - (y :: ys).get ⟨Nat.succ k, by simpa [List.length_cons] using hy⟩) - ∈ xs.zip ys := by - simpa [hpair_eq] using htail - have hzip : (x :: xs).zip (y :: ys) = (x, y) :: xs.zip ys := by - simp [List.zip_cons_cons] - have hmem' : - ((x :: xs).get ⟨Nat.succ k, by simpa [List.length_cons] using hk⟩, - (y :: ys).get ⟨Nat.succ k, by simpa [List.length_cons] using hy⟩) - ∈ (x, y) :: xs.zip ys := by - exact List.mem_cons.mpr (Or.inr htail') - simpa [hzip] using hmem' - -end List - -/- -Weighted reroute plans and the induced “heat” probability vector (Stage 2). -The structure couples a reroute plan with per-step weights and enforces the -alignment invariants required to distribute each weight over the incremental -masks. `rerouteHeat` sums the per-block shares, normalizes by the total drop, -and produces a `ProbVec` on `S`. --/ - -open scoped BigOperators -open Finset - -noncomputable section - -section Weighted - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- A reroute plan together with per-step nonnegative weights (logit drops). 
-/ -structure WeightedReroutePlan (S : Type*) [Fintype S] [DecidableEq S] where - plan : ReroutePlan S - weights : List NNReal - length_eq : weights.length = plan.steps.length - zero_weight_of_empty : - ∀ {A : Finset S} {w : NNReal}, - (A, w) ∈ plan.increments.zip weights → - A.card = 0 → w = 0 - weights_sum_pos : 0 < weights.sum - -namespace WeightedReroutePlan - -variable (P : WeightedReroutePlan (S := S)) - -/-- Helper: the total step weight (used for normalization). -/ -def weightsSum : NNReal := - P.weights.sum - -@[simp] lemma weightsSum_pos : 0 < P.weightsSum := P.weights_sum_pos - -lemma weightsSum_ne_zero : P.weightsSum ≠ 0 := - ne_of_gt P.weightsSum_pos - -lemma length_eq_increments : - P.weights.length = P.plan.increments.length := by - simpa [P.plan.increments_length] using P.length_eq - -private def share (A : Finset S) (w : NNReal) (i : S) : NNReal := - if A.card = 0 then 0 else if i ∈ A then w / (A.card : NNReal) else 0 - -lemma sum_share (A : Finset S) (w : NNReal) : - (∑ i : S, share (S:=S) A w i) = if A.card = 0 then 0 else w := by - classical - by_cases hA : A.card = 0 - · simp [share, hA] - · have hcard : (A.card : NNReal) ≠ 0 := by - exact_mod_cast hA - have hshare_zero : ∀ i ∉ A, share (S:=S) A w i = 0 := by - intro i hi - simp [share, hA, hi] - have hsplit : - (∑ i : S, share (S:=S) A w i) - = ∑ i ∈ A, share (S:=S) A w i := by - classical - let U : Finset S := Finset.univ - have hdisj : Disjoint A (U \ A) := by - refine Finset.disjoint_left.mpr ?_ - intro x hxA hxDiff - exact (Finset.mem_sdiff.mp hxDiff).2 hxA - have hcover : A ∪ (U \ A) = U := by - ext x - by_cases hx : x ∈ A <;> simp [U, hx] - calc - (∑ i : S, share (S:=S) A w i) - = ∑ i ∈ U, share (S:=S) A w i := by simp [U] - _ = ∑ i ∈ A ∪ (U \ A), share (S:=S) A w i := by - simp [U, hcover] - _ = (∑ i ∈ A, share (S:=S) A w i) + - ∑ i ∈ U \ A, share (S:=S) A w i := by - simpa using - (Finset.sum_union hdisj - (f := fun i => share (S:=S) A w i)) - _ = (∑ i ∈ A, share (S:=S) A w i) + 0 := 
by - refine congrArg (fun t => (∑ i ∈ A, share (S:=S) A w i) + t) ?_ - classical - refine Finset.sum_eq_zero ?_ - intro i hi - have hiA : i ∉ A := (Finset.mem_sdiff.mp hi).2 - simpa [U] using hshare_zero i hiA - _ = ∑ i ∈ A, share (S:=S) A w i := by simp - have hsum' : - (∑ i : S, share (S:=S) A w i) - = (A.card : NNReal) * (w / (A.card : NNReal)) := by - have hconst : - (∑ i ∈ A, share (S:=S) A w i) - = (A.card : NNReal) * (w / (A.card : NNReal)) := by - classical - simp [share, hA] - exact hsplit.trans hconst - have hratio : - (A.card : NNReal) * (w / (A.card : NNReal)) = w := - mul_div_cancel₀ _ hcard - have := hsum'.trans hratio - simpa [hA] using this - -lemma sum_share_self (A : Finset S) (w : NNReal) : - (∑ i ∈ A, share (S:=S) A w i) = if A.card = 0 then 0 else w := by - classical - have hshare := sum_share (S:=S) A w - by_cases hA : A.card = 0 - · have hA' : A = ∅ := Finset.card_eq_zero.mp hA - simp [share, hA'] - · have hzero : - (∑ i ∈ ((Finset.univ : Finset S) \ A), share (S:=S) A w i) = 0 := by - refine Finset.sum_eq_zero ?_ - intro i hi - have hiA : i ∉ A := (Finset.mem_sdiff.mp hi).2 - simp [share, hiA] - have hdisj : - Disjoint A ((Finset.univ : Finset S) \ A) := by - refine Finset.disjoint_left.mpr ?_ - intro i hiA hiDiff - exact (Finset.mem_sdiff.mp hiDiff).2 hiA - have hcover : - A ∪ ((Finset.univ : Finset S) \ A) = (Finset.univ : Finset S) := by - ext i - by_cases hi : i ∈ A <;> simp [hi] - have hsplit := - Finset.sum_union hdisj (f := fun i => share (S:=S) A w i) - have hx_univ : - (∑ i : S, share (S:=S) A w i) - = ∑ i ∈ (Finset.univ : Finset S), share (S:=S) A w i := by - simp - have hxA : - (∑ i ∈ (Finset.univ : Finset S), share (S:=S) A w i) - = (∑ i ∈ A, share (S:=S) A w i) := by - simpa [hcover, hzero] using hsplit - have hx := hx_univ.trans hxA - have hx' := hx.symm - simpa using hx'.trans hshare - -omit [Fintype S] in -lemma sum_share_of_disjoint {A B : Finset S} (w : NNReal) - (hdisj : Disjoint A B) : - (∑ i ∈ A, share (S:=S) B w i) = 0 
:= by - classical - refine Finset.sum_eq_zero ?_ - intro i hi - have hiB : i ∉ B := by - intro hmem - exact (Finset.disjoint_left.mp hdisj) hi hmem - simp [share, hiB] - -private def heatRawAux : - ∀ (parts : List (Finset S)) (weights : List NNReal), - weights.length = parts.length → S → NNReal - | [], [], _, _ => 0 - | A :: parts, w :: weights, hlen, i => - let htail : weights.length = parts.length := by - have hlen' : - Nat.succ weights.length = Nat.succ parts.length := by - simpa [List.length_cons] using hlen - exact Nat.succ.inj hlen' - share (S:=S) A w i + heatRawAux parts weights htail i - | _, _, _, _ => 0 - -private lemma sum_heatRawAux (parts : List (Finset S)) (weights : List NNReal) - (hlen : weights.length = parts.length) - (hzero : - ∀ {A : Finset S} {w : NNReal}, - (A, w) ∈ parts.zip weights → A.card = 0 → w = 0) : - (∑ i : S, heatRawAux (S:=S) parts weights hlen i) - = weights.sum := by - classical - revert weights hlen - induction parts with - | nil => - intro weights hlen hzero - cases weights with - | nil => - simp [heatRawAux] - | cons w weights => - cases hlen - | cons A parts ih => - intro weights hlen hzero - cases weights with - | nil => - cases hlen - | cons w weights => - have hlen' : - weights.length = parts.length := by - have hlen'' : - Nat.succ weights.length = Nat.succ parts.length := by - simpa [List.length_cons] using hlen - exact Nat.succ.inj hlen'' - have hzero_head : - A.card = 0 → w = 0 := by - intro hcard - have hpair : - (A, w) ∈ (A :: parts).zip (w :: weights) := by - simp [List.zip_cons_cons] - exact hzero hpair hcard - have hzero_tail : - ∀ {B : Finset S} {w' : NNReal}, - (B, w') ∈ parts.zip weights → B.card = 0 → w' = 0 := by - intro B w' hpair hcard - have hpair' : - (B, w') ∈ (A :: parts).zip (w :: weights) := by - have : (B, w') ∈ (A, w) :: parts.zip weights := - List.mem_cons.mpr (Or.inr hpair) - simpa [List.zip_cons_cons] using this - exact hzero hpair' hcard - have hsum_tail := - ih weights hlen' hzero_tail - have 
hsum_head : - (∑ i : S, share (S:=S) A w i) = w := by - have hshare := sum_share (S:=S) A w - by_cases hcard : A.card = 0 - · have hw0 : w = 0 := hzero_head hcard - have : (∑ i : S, share (S:=S) A w i) = 0 := by - simp [share, hcard] - simpa [hshare, hcard, hw0, this] - · simp [hshare, hcard] - have hsum_current : - (∑ i : S, heatRawAux (S:=S) (A :: parts) (w :: weights) hlen i) - = w + ∑ i : S, heatRawAux (S:=S) parts weights hlen' i := by - simp [heatRawAux, Finset.sum_add_distrib, hsum_head] - calc - (∑ i : S, heatRawAux (S:=S) (A :: parts) (w :: weights) hlen i) - = w + ∑ i : S, heatRawAux (S:=S) parts weights hlen' i := - hsum_current - _ = w + weights.sum := by - simp [hsum_tail] - _ = (w :: weights).sum := by - simp - -omit [Fintype S] in -private lemma sum_heatRawAux_disjoint - (parts : List (Finset S)) (weights : List NNReal) - (hlen : weights.length = parts.length) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) - {A : Finset S} - (hdisj : ∀ B ∈ parts, Disjoint A B) : - A.sum (fun i => heatRawAux (S:=S) parts weights hlen i) = 0 := by - classical - revert weights hlen hpair hdisj - induction parts generalizing A with - | nil => - intro weights hlen _ _ - cases weights with - | nil => simp [heatRawAux] - | cons _ _ => cases hlen - | cons B parts ih => - intro weights hlen hpair hdisj - cases weights with - | nil => cases hlen - | cons w weights => - have hlen' : weights.length = parts.length := by - have hlen'' : Nat.succ weights.length = Nat.succ parts.length := by - simpa [List.length_cons] using hlen - exact Nat.succ.inj hlen'' - rcases List.pairwise_cons.mp hpair with ⟨hB, htail⟩ - have hdisj_tail : ∀ C ∈ parts, Disjoint A C := fun C hC => hdisj C (by simp [hC]) - have hdisj_head : Disjoint A B := hdisj B (by simp) - have hshare_zero : - A.sum (fun i => share (S:=S) B w i) = 0 := - sum_share_of_disjoint (S:=S) (A:=A) (B:=B) (w:=w) hdisj_head - have htail_zero := ih weights hlen' htail hdisj_tail - simp [heatRawAux, Finset.sum_add_distrib, 
hshare_zero, htail_zero] - -private lemma sum_heatRawAux_mem_zip - (parts : List (Finset S)) (weights : List NNReal) - (hlen : weights.length = parts.length) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) - (hzero : - ∀ {B : Finset S} {w : NNReal}, - (B, w) ∈ parts.zip weights → B.card = 0 → w = 0) - {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ parts.zip weights) : - A.sum (fun i => heatRawAux (S:=S) parts weights hlen i) = w := by - classical - revert weights hlen hpair hzero hmem - induction parts generalizing A w with - | nil => - intro weights hlen _ _ hmem - cases weights with - | nil => cases hmem - | cons _ _ => cases hlen - | cons B parts ih => - intro weights hlen hpair hzero hmem - cases weights with - | nil => cases hlen - | cons z weights => - have hlen' : weights.length = parts.length := by - have hlen'' : Nat.succ weights.length = Nat.succ parts.length := by - simpa [List.length_cons] using hlen - exact Nat.succ.inj hlen'' - rcases List.pairwise_cons.mp hpair with ⟨hB, htail⟩ - have hzero_head : B.card = 0 → z = 0 := by - intro hcard - have : (B, z) ∈ (B :: parts).zip (z :: weights) := by - simp [List.zip_cons_cons] - exact hzero this hcard - have hzero_tail : - ∀ {C : Finset S} {w' : NNReal}, - (C, w') ∈ parts.zip weights → C.card = 0 → w' = 0 := by - intro C w' hC hcard - have : (C, w') ∈ (B :: parts).zip (z :: weights) := by - have : (C, w') ∈ (B, z) :: parts.zip weights := - List.mem_cons.mpr (Or.inr hC) - simpa [List.zip_cons_cons] using this - exact hzero this hcard - have hmem_cons : (A, w) ∈ (B, z) :: parts.zip weights := by - simpa [List.zip_cons_cons] using hmem - rcases List.mem_cons.mp hmem_cons with hhead | htail_mem - · cases hhead - have hshare := sum_share_self (S:=S) (A:=B) (w:=w) - have htail_zero : - B.sum (fun i => heatRawAux (S:=S) parts weights hlen' i) = 0 := by - have hdisjB : ∀ C ∈ parts, Disjoint B C := fun C hC => hB C hC - exact sum_heatRawAux_disjoint (S:=S) - parts weights hlen' htail hdisjB - set rest := fun i => 
heatRawAux (S:=S) parts weights hlen' i - have hsum_split : - B.sum (fun i => heatRawAux (S:=S) (B :: parts) (w :: weights) hlen i) - = B.sum (fun i => share (S:=S) B w i) + B.sum rest := by - have := Finset.sum_add_distrib - (s:=B) (f:=fun i => share (S:=S) B w i) (g:=rest) - simpa [rest, heatRawAux] - using this - by_cases hcard : B.card = 0 - · have hw : w = 0 := hzero_head hcard - have hshare_zero : - B.sum (fun i => share (S:=S) B w i) = 0 := by - simpa [hcard] using hshare - have hBempty : B = (∅ : Finset S) := Finset.card_eq_zero.mp hcard - have hx : - B.sum (fun i => share (S:=S) B w i) + B.sum rest = 0 := by - have hxshare := hshare_zero - have hxrest := htail_zero - rw [hxshare, hxrest] - simp - have hs : - B.sum (fun i => heatRawAux (S:=S) (B :: parts) (w :: weights) hlen i) - = 0 := by - simpa [hsum_split] using hx - simpa [hw] using hs - · have hshare_eq : - B.sum (fun i => share (S:=S) B w i) = w := by - simpa [hcard] using hshare - have hx : - B.sum (fun i => share (S:=S) B w i) + B.sum rest = w := by - have hxrest := htail_zero - rw [hshare_eq, hxrest] - simp - have hs : - B.sum (fun i => heatRawAux (S:=S) (B :: parts) (w :: weights) hlen i) - = w := by - simpa [hsum_split] using hx - exact hs - · have htail_result := - ih weights hlen' htail hzero_tail htail_mem - have hA_mem : A ∈ parts := List.mem_left_of_zip htail_mem - have hdisjBA : Disjoint B A := hB A hA_mem - have hdisjAB : Disjoint A B := by - refine Finset.disjoint_left.mpr ?_ - intro i hiA hiB - exact (Finset.disjoint_left.mp hdisjBA) hiB hiA - have hshare_zero : - A.sum (fun i => share (S:=S) B z i) = 0 := - sum_share_of_disjoint (S:=S) (A:=A) (B:=B) (w:=z) hdisjAB - set rest := fun i => heatRawAux (S:=S) parts weights hlen' i - have hsum_split : - A.sum (fun i => heatRawAux (S:=S) (B :: parts) (z :: weights) hlen i) - = A.sum (fun i => share (S:=S) B z i) + A.sum rest := by - have := Finset.sum_add_distrib - (s:=A) (f:=fun i => share (S:=S) B z i) (g:=rest) - simpa [rest, heatRawAux] - 
using this - have hx : - A.sum (fun i => share (S:=S) B z i) + A.sum rest = w := by - rw [hshare_zero, htail_result] - simp - have hs : - A.sum (fun i => heatRawAux (S:=S) (B :: parts) (z :: weights) hlen i) - = w := by - simpa [hsum_split] using hx - exact hs - -def heatRaw (i : S) : NNReal := - heatRawAux (S:=S) P.plan.increments P.weights - P.length_eq_increments i - -lemma sum_heatRaw : - (∑ i : S, P.heatRaw i) = P.weightsSum := by - classical - have hzero : - ∀ {A : Finset S} {w : NNReal}, - (A, w) ∈ P.plan.increments.zip P.weights → - A.card = 0 → w = 0 := - fun {A} {w} => P.zero_weight_of_empty - exact - sum_heatRawAux (S:=S) P.plan.increments P.weights - P.length_eq_increments hzero - -noncomputable def rerouteHeat : ProbVec S := - { - mass := fun i => P.heatRaw i / P.weightsSum, - norm_one := by - classical - have hdiv : - (∑ i : S, P.heatRaw i / P.weightsSum) - = (∑ i : S, P.heatRaw i) / P.weightsSum := - (Finset.sum_div (s:=Finset.univ) (f:=fun i => P.heatRaw i) - (a:=P.weightsSum)).symm - have hsum := P.sum_heatRaw - have hne := P.weightsSum_ne_zero - change (∑ i : S, P.heatRaw i / P.weightsSum) = 1 - simp [hdiv, hsum, hne] - } - -@[simp] lemma rerouteHeat_mass (i : S) : - (P.rerouteHeat).mass i = P.heatRaw i / P.weightsSum := rfl - -lemma heatRaw_sum_increment {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ P.plan.increments.zip P.weights) : - A.sum (fun i => P.heatRaw i) = w := by - classical - have hpair := ReroutePlan.increments_pairwise (P:=P.plan) - have hzero : - ∀ {B : Finset S} {w' : NNReal}, - (B, w') ∈ P.plan.increments.zip P.weights → B.card = 0 → w' = 0 := - fun {B} {w'} => P.zero_weight_of_empty - simpa [heatRaw] using - sum_heatRawAux_mem_zip (S:=S) - P.plan.increments P.weights P.length_eq_increments - hpair hzero hmem - -lemma rerouteHeat_sum_increment {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ P.plan.increments.zip P.weights) : - A.sum (fun i => (P.rerouteHeat).mass i) = w / P.weightsSum := by - classical - have hsum := 
heatRaw_sum_increment (P:=P) hmem - have hdiv : - A.sum (fun i => P.heatRaw i / P.weightsSum) - = (A.sum fun i => P.heatRaw i) / P.weightsSum := by - simpa using - (Finset.sum_div (s:=A) (f:=fun i => P.heatRaw i) - (a:=P.weightsSum)).symm - simp [rerouteHeat_mass, hdiv, hsum] - -end WeightedReroutePlan - -end Weighted - -end - -end Nfp diff --git a/Legacy/Nfp/Reroute/Partition.lean b/Legacy/Nfp/Reroute/Partition.lean deleted file mode 100644 index b9930c1..0000000 --- a/Legacy/Nfp/Reroute/Partition.lean +++ /dev/null @@ -1,413 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Tactic.FinCases - -/-! -Partitions and reroute plans. This file hosts the element-level lemmas -needed by the reroute weighting development (Stage 1). It provides: - -* a recursive definition of `unionParts` with rewrite-friendly lemmas -* helpers showing membership existence/uniqueness under disjointness -* an intrinsic notion of `ReroutePlan` -* incremental masks (`increments`) together with coverage and disjointness -* per-element sum lemmas that decompose along disjoint parts --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -section Partitions - -variable {S : Type*} [DecidableEq S] - -/-- Union a list of disjoint components. The recursion is `simp`-friendly so -Lean can peel the head element while reasoning about membership. 
-/ -@[simp] def unionParts : List (Finset S) → Finset S - | [] => ∅ - | A :: parts => A ∪ unionParts parts - -@[simp] lemma unionParts_nil : unionParts ([] : List (Finset S)) = (∅ : Finset S) := rfl - -@[simp] lemma unionParts_cons (A : Finset S) (parts : List (Finset S)) : - unionParts (A :: parts) = A ∪ unionParts parts := rfl - -@[simp] lemma unionParts_singleton (A : Finset S) : - unionParts [A] = A := by simp - -lemma unionParts_cons_mem {A : Finset S} {parts : List (Finset S)} {i : S} : - i ∈ unionParts (A :: parts) ↔ i ∈ A ∨ i ∈ unionParts parts := by - classical - simp [unionParts, Finset.mem_union] - -lemma mem_unionParts_iff_exists_mem (parts : List (Finset S)) (i : S) : - i ∈ unionParts parts ↔ ∃ A ∈ parts, i ∈ A := by - classical - induction parts with - | nil => - simp - | cons A parts ih => - simp [unionParts, ih, List.mem_cons] - -lemma mem_unionParts_of_mem {parts : List (Finset S)} {A : Finset S} (hA : A ∈ parts) - {i : S} (hi : i ∈ A) : i ∈ unionParts parts := by - classical - exact (mem_unionParts_iff_exists_mem (parts:=parts) (i:=i)).2 ⟨A, hA, hi⟩ - -lemma disjoint_unionParts (A : Finset S) (parts : List (Finset S)) - (h : ∀ B ∈ parts, Disjoint A B) : - Disjoint A (unionParts parts) := by - classical - induction parts with - | nil => - simp - | cons B parts ih => - have hAB : Disjoint A B := h B (by simp) - have htail : ∀ C ∈ parts, Disjoint A C := by - intro C hC - exact h C (by simp [hC]) - have hArest := ih htail - refine Finset.disjoint_left.mpr ?_ - intro x hxA hxUnion - have : x ∈ B ∨ x ∈ unionParts parts := by - simpa [unionParts] using hxUnion - cases this with - | inl hxB => - exact (Finset.disjoint_left.mp hAB) hxA hxB - | inr hxRest => - exact (Finset.disjoint_left.mp hArest) hxA hxRest - -example (A B : Finset S) (i : S) : - i ∈ unionParts [A, B] ↔ i ∈ A ∨ i ∈ B := by - classical - simp [unionParts] - -end Partitions - -lemma pairwise_disjoint_cons_elim {S : Type*} {A : Finset S} {parts : List (Finset S)} - (hpair : (A :: 
parts).Pairwise (fun B C => Disjoint B C)) : - (∀ B ∈ parts, Disjoint A B) ∧ - parts.Pairwise (fun B C => Disjoint B C) := - List.pairwise_cons.mp hpair - -lemma mem_unionParts_unique {S : Type*} (parts : List (Finset S)) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) - {A B : Finset S} {i : S} - (hA : A ∈ parts) (hB : B ∈ parts) - (hiA : i ∈ A) (hiB : i ∈ B) : - A = B := by - classical - revert A B - induction parts with - | nil => - intro A B hA hB _ _ - cases hA - | cons C parts ih => - intro A B hA hB hiA hiB - rcases List.pairwise_cons.mp hpair with ⟨hC, htail⟩ - rcases List.mem_cons.mp hA with hAC | hA' - · rcases List.mem_cons.mp hB with hBC | hB' - · cases hAC; cases hBC; exact rfl - · have hiC : i ∈ C := by simpa [hAC] using hiA - have hCB : Disjoint C B := hC B hB' - exact ((Finset.disjoint_left.mp hCB) hiC hiB).elim - · rcases List.mem_cons.mp hB with hBC | hB' - · have hiC : i ∈ C := by simpa [hBC] using hiB - have hCA : Disjoint C A := hC A hA' - exact - ((Finset.disjoint_left.mp hCA) hiC - (by simpa [hA'] using hiA)).elim - · exact ih htail hA' hB' hiA hiB - -section PartitionsFinite - -variable {S : Type*} [Fintype S] [DecidableEq S] - -lemma mem_unionParts_of_univ (parts : List (Finset S)) - (hcover : unionParts parts = (Finset.univ : Finset S)) - (i : S) : i ∈ unionParts parts := by - classical - simp [hcover] - -end PartitionsFinite - -section SumLemmas - -variable {S : Type*} [DecidableEq S] - -lemma sum_unionParts_eq (m : S → NNReal) (parts : List (Finset S)) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) : - (unionParts parts).sum m = - parts.foldr (fun A acc => A.sum m + acc) 0 := by - classical - induction parts with - | nil => - simp [unionParts] - | cons A parts ih => - rcases List.pairwise_cons.mp hpair with ⟨hA, htail⟩ - have hdisj : Disjoint A (unionParts parts) := - disjoint_unionParts (A:=A) (parts:=parts) (fun B hB => hA _ (by simpa using hB)) - have hsum := Finset.sum_union (s₁:=A) (s₂:=unionParts parts) (f:=m) hdisj - 
have hfold : - (A :: parts).foldr (fun B acc => B.sum m + acc) 0 - = A.sum m + parts.foldr (fun B acc => B.sum m + acc) 0 := by - simp [List.foldr] - calc - (unionParts (A :: parts)).sum m - = (A ∪ unionParts parts).sum m := rfl - _ = A.sum m + (unionParts parts).sum m := hsum - _ = A.sum m + parts.foldr (fun B acc => B.sum m + acc) 0 := by - simp [ih htail] - _ = (A :: parts).foldr (fun B acc => B.sum m + acc) 0 := by - simp [hfold] - -end SumLemmas - -section ReroutePlanStruct - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- A reroute plan is a list of disjoint masks that cover the whole space. -/ -structure ReroutePlan (S : Type*) [Fintype S] [DecidableEq S] where - steps : List (Finset S) - pairwise_disjoint : steps.Pairwise (fun A B => Disjoint A B) - covers_univ : unionParts steps = (Finset.univ : Finset S) - -namespace ReroutePlan - -variable (S) - -/-- Re-export the list of masks. Useful when inferring implicit arguments. -/ -def masks (P : ReroutePlan (S := S)) : List (Finset S) := P.steps - -variable {S} - -lemma mem_cover (P : ReroutePlan (S := S)) (i : S) : - i ∈ unionParts P.steps := - mem_unionParts_of_univ (parts:=P.steps) P.covers_univ i - -lemma exists_block (P : ReroutePlan (S := S)) (i : S) : - ∃ A ∈ P.steps, i ∈ A := - (mem_unionParts_iff_exists_mem (parts:=P.steps) (i:=i)).1 (P.mem_cover i) - -lemma unique_block (P : ReroutePlan (S := S)) (i : S) : - ∃! 
A, A ∈ P.steps ∧ i ∈ A := by - classical - obtain ⟨A, hA, hiA⟩ := P.exists_block i - refine ⟨A, ⟨hA, hiA⟩, ?_⟩ - intro B hB - apply mem_unionParts_unique (parts:=P.steps) P.pairwise_disjoint <;> - tauto - -end ReroutePlan - -end ReroutePlanStruct - -section Increments - -namespace ReroutePlan - -section Aux - -variable {S : Type*} [DecidableEq S] - -private def incrementsAux : List (Finset S) → Finset S → List (Finset S) - | [], _ => [] - | A :: parts, seen => (A \ seen) :: incrementsAux parts (seen ∪ A) - -@[simp] lemma incrementsAux_nil (seen : Finset S) : - incrementsAux ([] : List (Finset S)) seen = [] := rfl - -@[simp] lemma incrementsAux_cons (A : Finset S) (parts : List (Finset S)) - (seen : Finset S) : - incrementsAux (A :: parts) seen = - (A \ seen) :: incrementsAux parts (seen ∪ A) := rfl - -lemma incrementsAux_length (parts : List (Finset S)) (seen : Finset S) : - (incrementsAux parts seen).length = parts.length := by - classical - induction parts generalizing seen with - | nil => - simp [incrementsAux] - | cons A parts ih => - simp [incrementsAux, ih] - -private lemma incrementsAux_mem_disjoint_seen : - ∀ {parts : List (Finset S)} {seen B : Finset S}, - B ∈ incrementsAux parts seen → Disjoint B seen := by - classical - intro parts - induction parts with - | nil => - intro seen B hB - cases hB - | cons A parts ih => - intro seen B hB - dsimp [incrementsAux] at hB - rcases List.mem_cons.mp hB with hHead | hTail - · subst hHead - refine Finset.disjoint_left.mpr ?_ - intro x hxB hxSeen - exact (Finset.mem_sdiff.mp hxB).2 hxSeen - · have h := ih (seen := seen ∪ A) (B := B) hTail - refine Finset.disjoint_left.mpr ?_ - intro x hxB hxSeen - have hxUnion : x ∈ seen ∪ A := by - have : x ∈ seen := hxSeen - exact (Finset.mem_union.mpr (Or.inl this)) - exact (Finset.disjoint_left.mp h) hxB hxUnion - -private lemma incrementsAux_pairwise (parts : List (Finset S)) (seen : Finset S) : - (incrementsAux parts seen).Pairwise (fun A B => Disjoint A B) := by - classical - 
induction parts generalizing seen with - | nil => - simp [incrementsAux] - | cons A parts ih => - refine List.pairwise_cons.2 ?_ - constructor - · intro B hB - have hDisjoint := - incrementsAux_mem_disjoint_seen (parts:=parts) (seen:=seen ∪ A) (B:=B) hB - refine Finset.disjoint_left.mpr ?_ - intro x hxHead hxB - have hxA : x ∈ A := (Finset.mem_sdiff.mp hxHead).1 - have hxUnion : x ∈ seen ∪ A := by - exact Finset.mem_union.mpr (Or.inr hxA) - exact (Finset.disjoint_left.mp hDisjoint) hxB hxUnion - · simpa using ih (seen ∪ A) - -private lemma sdiff_union_left (A B C : Finset S) : - (A \ C) ∪ (B \ (C ∪ A)) = (A ∪ B) \ C := by - classical - ext i - constructor - · intro hx - rcases Finset.mem_union.mp hx with hxA | hxB - · rcases Finset.mem_sdiff.mp hxA with ⟨hiA, hiC⟩ - exact Finset.mem_sdiff.mpr ⟨Finset.mem_union.mpr (Or.inl hiA), hiC⟩ - · rcases Finset.mem_sdiff.mp hxB with ⟨hiB, hiCompl⟩ - have hiC : i ∉ C := by - exact fun hCi => hiCompl (Finset.mem_union.mpr (Or.inl hCi)) - exact Finset.mem_sdiff.mpr ⟨Finset.mem_union.mpr (Or.inr hiB), hiC⟩ - · intro hx - rcases Finset.mem_sdiff.mp hx with ⟨hiUnion, hiC⟩ - by_cases hiA : i ∈ A - · exact Finset.mem_union.mpr (Or.inl (Finset.mem_sdiff.mpr ⟨hiA, hiC⟩)) - · have hiB : i ∈ B := by - have : i ∈ A ∨ i ∈ B := (Finset.mem_union.mp hiUnion) - exact this.resolve_left hiA - have hiNot : i ∉ C ∪ A := by - intro hmem - rcases Finset.mem_union.mp hmem with hC | hA - · exact hiC hC - · exact hiA hA - exact Finset.mem_union.mpr (Or.inr (Finset.mem_sdiff.mpr ⟨hiB, hiNot⟩)) - -private lemma unionParts_incrementsAux (parts : List (Finset S)) (seen : Finset S) : - unionParts (incrementsAux parts seen) = unionParts parts \ seen := by - classical - induction parts generalizing seen with - | nil => - simp [incrementsAux] - | cons A parts ih => - simp [incrementsAux, unionParts, ih, sdiff_union_left] - -@[simp] lemma unionParts_incrementsAux_empty (parts : List (Finset S)) : - unionParts (incrementsAux parts (∅ : Finset S)) = unionParts 
parts := by - classical - have h := - unionParts_incrementsAux (parts:=parts) (seen:=(∅ : Finset S)) - have hzero : unionParts parts \ (∅ : Finset S) = unionParts parts := by simp - exact h.trans hzero - -end Aux - -section WithPlan - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- Incremental “delta” masks: each block removes only the new elements from the -corresponding reroute step. -/ -def increments (P : ReroutePlan (S := S)) : List (Finset S) := - incrementsAux P.steps ∅ - -/-- The incremental masks list has the same length as `steps`. -/ -lemma increments_length (P : ReroutePlan (S := S)) : - P.increments.length = P.steps.length := by - classical - simp [increments, incrementsAux_length] - -lemma increments_pairwise (P : ReroutePlan (S := S)) : - P.increments.Pairwise (fun A B => Disjoint A B) := by - classical - unfold increments - exact - incrementsAux_pairwise (parts:=P.steps) (seen:=(∅ : Finset S)) - -@[simp] lemma unionParts_increments (P : ReroutePlan (S := S)) : - unionParts P.increments = unionParts P.steps := by - classical - have h := - unionParts_incrementsAux (parts:=P.steps) (seen:=(∅ : Finset S)) - have hzero : unionParts P.steps \ (∅ : Finset S) = unionParts P.steps := by simp - unfold increments - exact h.trans hzero - -lemma sum_over_increments (P : ReroutePlan (S := S)) (m : S → NNReal) : - (unionParts P.steps).sum m = - P.increments.foldr (fun A acc => A.sum m + acc) 0 := by - classical - have hpair := increments_pairwise (P:=P) - have hsum := - sum_unionParts_eq (m:=m) (parts:=P.increments) hpair - simpa [unionParts_increments (P:=P)] using hsum - -end WithPlan - -end ReroutePlan - -end Increments - -section Examples - -open ReroutePlan - -@[simp] def fin2Plan : ReroutePlan (Fin 2) := -by - classical - refine { - steps := [{0}, {1}], - pairwise_disjoint := ?_, - covers_univ := ?_ } - · refine List.pairwise_cons.2 ?_ - refine ⟨?_, ?_⟩ - · intro B hB - have : B = ({1} : Finset (Fin 2)) := by simpa using hB - subst this - simp - · 
simp - · ext i - fin_cases i <;> simp [unionParts] - -example : - (unionParts (fin2Plan.steps)).sum - (fun i => if i = 0 then (2 : NNReal) else 5) = - fin2Plan.increments.foldr - (fun A acc => - A.sum (fun i => if i = 0 then (2 : NNReal) else 5) + acc) - 0 := by - classical - exact - (ReroutePlan.sum_over_increments - (P:=fin2Plan) (m:=fun i => if i = 0 then (2 : NNReal) else 5)) - -end Examples - -end Nfp diff --git a/Legacy/Nfp/SignedMixer.lean b/Legacy/Nfp/SignedMixer.lean deleted file mode 100644 index 65b10da..0000000 --- a/Legacy/Nfp/SignedMixer.lean +++ /dev/null @@ -1,636 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.Real.Sign -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Algebra.Group.Defs -import Mathlib.Order.MinMax -import Nfp.Mixer - -/-! -# Signed Mixers and Affine Transformations - -This module extends the mixer framework to support real neural network operations -that involve negative weights and biases. While the original `Mixer` type captures -attention (row-stochastic, nonnegative), real networks also use: - -1. **Signed linear maps**: Value projections, MLPs with negative weights -2. **Affine maps**: Operations with bias terms -3. **Decompositions**: Splitting signed maps into positive/negative parts - -## Key insight for interpretation - -For attribution, we care about *how much* each input contributes to each output. -With signed weights, a negative contribution means "increasing the input decreases -the output." The framework here tracks both positive and negative contributions -separately, enabling precise attribution analysis. 
- -## Main definitions - -* `SignedMixer`: Linear map with real (possibly negative) weights -* `SignedMixer.positivePart`, `negativePart`: Decomposition into nonnegative parts -* `AffineMixer`: Signed mixer plus bias term -* `SignedMixer.toInfluence`: Convert to influence matrix for attribution - -## References - -This connects to: -- Integrated Gradients (uses signed gradients for attribution) -- SHAP values (can be positive or negative) -- Attention with negative weights (some transformer variants) --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-! ## Signed Mixer -/ - -/-- A signed mixer: a linear map between finite types with real weights. -Unlike `Mixer`, weights can be negative and rows need not sum to 1. - -This captures operations like: -- Value projections in attention -- MLP layers -- Any linear layer in a neural network -/ -structure SignedMixer (S T : Type*) [Fintype S] [Fintype T] where - /-- The weight matrix. `w i j` is the weight from input `i` to output `j`. -/ - w : S → T → ℝ - -namespace SignedMixer - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- Extensionality for signed mixers. -/ -@[ext] -theorem ext {M N : SignedMixer S T} (h : ∀ i j, M.w i j = N.w i j) : M = N := by - cases M; cases N; simp only [mk.injEq]; funext i j; exact h i j - -/-- The zero signed mixer. -/ -instance : Zero (SignedMixer S T) where - zero := ⟨fun _ _ => 0⟩ - -/-- Addition of signed mixers (pointwise). -/ -instance : Add (SignedMixer S T) where - add M N := ⟨fun i j => M.w i j + N.w i j⟩ - -/-- Scalar multiplication for signed mixers. -/ -instance : SMul ℝ (SignedMixer S T) where - smul c M := ⟨fun i j => c * M.w i j⟩ - -/-- Negation of signed mixers. -/ -instance : Neg (SignedMixer S T) where - neg M := ⟨fun i j => -M.w i j⟩ - -/-- Subtraction of signed mixers. 
-/ -instance : Sub (SignedMixer S T) where - sub M N := ⟨fun i j => M.w i j - N.w i j⟩ - -@[simp] lemma zero_w (i : S) (j : T) : (0 : SignedMixer S T).w i j = 0 := rfl -@[simp] lemma add_w (M N : SignedMixer S T) (i : S) (j : T) : - (M + N).w i j = M.w i j + N.w i j := rfl -@[simp] lemma smul_w (c : ℝ) (M : SignedMixer S T) (i : S) (j : T) : - (c • M).w i j = c * M.w i j := rfl -@[simp] lemma neg_w (M : SignedMixer S T) (i : S) (j : T) : (-M).w i j = -M.w i j := rfl -@[simp] lemma sub_w (M N : SignedMixer S T) (i : S) (j : T) : - (M - N).w i j = M.w i j - N.w i j := rfl - -/-- The identity signed mixer. -/ -noncomputable def identity [DecidableEq S] : SignedMixer S S where - w := fun i j => if i = j then 1 else 0 - -@[simp] lemma identity_diag [DecidableEq S] (i : S) : identity.w i i = 1 := by simp [identity] - -@[simp] lemma identity_off_diag [DecidableEq S] {i j : S} (h : i ≠ j) : - identity.w i j = 0 := by simp [identity, h] - -/-- Composition of signed mixers (matrix multiplication). -/ -noncomputable def comp (M : SignedMixer S T) (N : SignedMixer T U) : SignedMixer S U where - w := fun i k => ∑ j, M.w i j * N.w j k - -@[simp] lemma comp_w (M : SignedMixer S T) (N : SignedMixer T U) (i : S) (k : U) : - (M.comp N).w i k = ∑ j, M.w i j * N.w j k := rfl - -/-- Identity is a left unit for composition. -/ -@[simp] theorem identity_comp [DecidableEq S] (M : SignedMixer S T) : - identity.comp M = M := by - ext i j - simp only [comp_w, identity] - simp [Finset.sum_ite_eq, Finset.mem_univ] - -/-- Identity is a right unit for composition. -/ -@[simp] theorem comp_identity [DecidableEq T] (M : SignedMixer S T) : - M.comp identity = M := by - ext i j - simp only [comp_w, identity] - simp [Finset.sum_ite_eq', Finset.mem_univ] - -/-- Composition is associative. 
-/ -theorem comp_assoc {V : Type*} [Fintype V] - (M : SignedMixer S T) (N : SignedMixer T U) (P : SignedMixer U V) : - (M.comp N).comp P = M.comp (N.comp P) := by - ext i l - simp only [comp_w] - -- LHS: ∑_k (∑_j M_ij * N_jk) * P_kl - -- RHS: ∑_j M_ij * (∑_k N_jk * P_kl) - conv_lhs => - arg 2 - ext k - rw [Finset.sum_mul] - conv_rhs => - arg 2 - ext j - rw [Finset.mul_sum] - rw [Finset.sum_comm] - congr 1 - ext j - congr 1 - ext k - ring - -/-- Composition distributes over addition on the left. -/ -theorem comp_add_left (M₁ M₂ : SignedMixer S T) (N : SignedMixer T U) : - (M₁ + M₂).comp N = M₁.comp N + M₂.comp N := by - ext i k - simp [comp_w, add_w, add_mul, Finset.sum_add_distrib] - -/-- Composition distributes over addition on the right. -/ -theorem comp_add_right (M : SignedMixer S T) (N₁ N₂ : SignedMixer T U) : - M.comp (N₁ + N₂) = M.comp N₁ + M.comp N₂ := by - ext i k - simp [comp_w, add_w, mul_add, Finset.sum_add_distrib] - -/-! ## Decomposition into positive and negative parts -/ - -/-- The positive part of a signed mixer: max(w, 0) for each weight. -/ -noncomputable def positivePart (M : SignedMixer S T) : SignedMixer S T where - w := fun i j => max (M.w i j) 0 - -/-- The negative part of a signed mixer: max(-w, 0) for each weight. -Note: This is nonnegative; it represents the magnitude of negative weights. -/ -noncomputable def negativePart (M : SignedMixer S T) : SignedMixer S T where - w := fun i j => max (-M.w i j) 0 - -@[simp] lemma positivePart_w (M : SignedMixer S T) (i : S) (j : T) : - M.positivePart.w i j = max (M.w i j) 0 := rfl - -@[simp] lemma negativePart_w (M : SignedMixer S T) (i : S) (j : T) : - M.negativePart.w i j = max (-M.w i j) 0 := rfl - -/-- A signed mixer decomposes as positivePart - negativePart. 
-/ -theorem decompose (M : SignedMixer S T) : - M = M.positivePart - M.negativePart := by - ext i j - simp only [positivePart_w, negativePart_w, sub_w] - -- max(x, 0) - max(-x, 0) = x - by_cases h : M.w i j ≥ 0 - · simp [max_eq_left h, max_eq_right (neg_nonpos.mpr h)] - · push_neg at h - simp [max_eq_right (le_of_lt h), max_eq_left (neg_nonneg.mpr (le_of_lt h))] - -/-- The positive part is nonnegative. -/ -lemma positivePart_nonneg (M : SignedMixer S T) (i : S) (j : T) : - M.positivePart.w i j ≥ 0 := le_max_right _ _ - -/-- The negative part is nonnegative. -/ -lemma negativePart_nonneg (M : SignedMixer S T) (i : S) (j : T) : - M.negativePart.w i j ≥ 0 := le_max_right _ _ - -/-! ## Row sums and normalization -/ - -/-- The sum of weights in row i. -/ -noncomputable def rowSum (M : SignedMixer S T) (i : S) : ℝ := ∑ j, M.w i j - -/-- A signed mixer is row-stochastic if all rows sum to 1. -/ -def IsRowStochastic (M : SignedMixer S T) : Prop := ∀ i, M.rowSum i = 1 - -/-- A signed mixer is row-normalized if all rows sum to the same value. -/ -def IsRowNormalized (M : SignedMixer S T) (c : ℝ) : Prop := ∀ i, M.rowSum i = c - -/-- The sum of absolute values in row i. -/ -noncomputable def rowAbsSum (M : SignedMixer S T) (i : S) : ℝ := ∑ j, |M.w i j| - -/-- Total influence magnitude: sum of all absolute weights. -/ -noncomputable def totalInfluence (M : SignedMixer S T) : ℝ := ∑ i, M.rowAbsSum i - -/-- Row-sum operator norm bound (induced ℓ1 for row-vector convention). -/ -noncomputable def operatorNormBound (M : SignedMixer S T) [Nonempty S] : ℝ := - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) (fun i => rowAbsSum M i) - -/-! ## Operator norm bound estimates -/ - -/-- Row absolute sum is nonnegative. -/ -lemma rowAbsSum_nonneg (M : SignedMixer S T) (i : S) : 0 ≤ M.rowAbsSum i := by - classical - unfold rowAbsSum - refine Finset.sum_nonneg ?_ - intro j _hj - exact abs_nonneg _ - -/-- Operator norm bounds are nonnegative. 
-/ -theorem operatorNormBound_nonneg (M : SignedMixer S T) [Nonempty S] : - 0 ≤ operatorNormBound M := by - classical - rcases (Finset.univ_nonempty (α := S)) with ⟨i, hi⟩ - have hrow : 0 ≤ rowAbsSum M i := rowAbsSum_nonneg (M := M) i - have hle : rowAbsSum M i ≤ operatorNormBound M := by - exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum M i) hi - exact le_trans hrow hle - -/-- Row absolute sums are subadditive. -/ -lemma rowAbsSum_add_le (M N : SignedMixer S T) (i : S) : - rowAbsSum (M + N) i ≤ rowAbsSum M i + rowAbsSum N i := by - classical - have hterm : ∀ j : T, |M.w i j + N.w i j| ≤ |M.w i j| + |N.w i j| := by - intro j - exact abs_add_le _ _ - have hsum : - ∑ j, |M.w i j + N.w i j| ≤ ∑ j, (|M.w i j| + |N.w i j|) := by - refine Finset.sum_le_sum ?_ - intro j _hj - exact hterm j - calc - rowAbsSum (M + N) i = ∑ j, |M.w i j + N.w i j| := by - simp [rowAbsSum, add_w] - _ ≤ ∑ j, (|M.w i j| + |N.w i j|) := hsum - _ = rowAbsSum M i + rowAbsSum N i := by - simp [rowAbsSum, Finset.sum_add_distrib] - -/-- Operator norm bounds are subadditive. 
-/ -theorem operatorNormBound_add_le (M N : SignedMixer S T) [Nonempty S] : - operatorNormBound (M + N) ≤ operatorNormBound M + operatorNormBound N := by - classical - dsimp [operatorNormBound] - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := S)) - (f := fun i => rowAbsSum (M + N) i) - (a := operatorNormBound M + operatorNormBound N)).2 ?_ - intro i hi - have hsum : rowAbsSum (M + N) i ≤ rowAbsSum M i + rowAbsSum N i := - rowAbsSum_add_le (M := M) (N := N) i - have hM : rowAbsSum M i ≤ operatorNormBound M := by - exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum M i) hi - have hN : rowAbsSum N i ≤ operatorNormBound N := by - exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum N i) hi - have hbound : rowAbsSum (M + N) i ≤ operatorNormBound M + operatorNormBound N := by - exact le_trans hsum (add_le_add hM hN) - simpa using hbound - -/-- Row absolute sums of a composition are bounded by row sums and the operator norm bound. -/ -lemma rowAbsSum_comp_le (M : SignedMixer S T) (N : SignedMixer T U) (i : S) [Nonempty T] : - rowAbsSum (M.comp N) i ≤ rowAbsSum M i * operatorNormBound N := by - classical - have hterm : ∀ k : U, |∑ j, M.w i j * N.w j k| ≤ ∑ j, |M.w i j| * |N.w j k| := by - intro k - simpa [abs_mul] using - (abs_sum_le_sum_abs (f := fun j => M.w i j * N.w j k) (s := Finset.univ)) - calc - rowAbsSum (M.comp N) i = ∑ k, |∑ j, M.w i j * N.w j k| := by - simp [rowAbsSum, comp_w] - _ ≤ ∑ k, ∑ j, |M.w i j| * |N.w j k| := by - refine Finset.sum_le_sum ?_ - intro k _hk - exact hterm k - _ = ∑ j, |M.w i j| * (∑ k, |N.w j k|) := by - calc - (∑ k, ∑ j, |M.w i j| * |N.w j k|) - = ∑ j, ∑ k, |M.w i j| * |N.w j k| := by - simpa using - (Finset.sum_comm (s := Finset.univ) (t := Finset.univ) - (f := fun k j => |M.w i j| * |N.w j k|)) - _ = ∑ j, |M.w i j| * (∑ k, |N.w j k|) := by - refine Finset.sum_congr rfl ?_ - intro j _hj - simp [Finset.mul_sum] - _ ≤ ∑ j, |M.w i j| * operatorNormBound N := by - refine 
Finset.sum_le_sum ?_ - intro j _hj - have hN : rowAbsSum N j ≤ operatorNormBound N := by - exact Finset.le_sup' (s := Finset.univ) (f := fun j => rowAbsSum N j) (by simp) - have hN' : (∑ k, |N.w j k|) ≤ operatorNormBound N := by - simpa [rowAbsSum] using hN - have hMnonneg : 0 ≤ |M.w i j| := abs_nonneg _ - exact mul_le_mul_of_nonneg_left hN' hMnonneg - _ = rowAbsSum M i * operatorNormBound N := by - simp [rowAbsSum, Finset.sum_mul] - -/-- Operator norm bounds are submultiplicative. -/ -theorem operatorNormBound_comp_le (M : SignedMixer S T) (N : SignedMixer T U) - [Nonempty S] [Nonempty T] : - operatorNormBound (M.comp N) ≤ operatorNormBound M * operatorNormBound N := by - classical - dsimp [operatorNormBound] - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := S)) - (f := fun i => rowAbsSum (M.comp N) i) - (a := operatorNormBound M * operatorNormBound N)).2 ?_ - intro i hi - have hrow : rowAbsSum (M.comp N) i ≤ rowAbsSum M i * operatorNormBound N := - rowAbsSum_comp_le (M := M) (N := N) i - have hM : rowAbsSum M i ≤ operatorNormBound M := by - exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum M i) hi - have hNnonneg : 0 ≤ operatorNormBound N := operatorNormBound_nonneg (M := N) - have hmul : rowAbsSum M i * operatorNormBound N ≤ operatorNormBound M * operatorNormBound N := by - exact mul_le_mul_of_nonneg_right hM hNnonneg - have hbound : rowAbsSum (M.comp N) i ≤ operatorNormBound M * operatorNormBound N := - le_trans hrow hmul - simpa using hbound - -/-- Operator norm bounds for a triple composition. 
-/ -theorem operatorNormBound_comp3_le {V : Type*} [Fintype V] - (A : SignedMixer S T) (B : SignedMixer T U) (C : SignedMixer U V) - [Nonempty S] [Nonempty T] [Nonempty U] : - operatorNormBound ((A.comp B).comp C) ≤ - operatorNormBound A * operatorNormBound B * operatorNormBound C := by - have h1 : - operatorNormBound ((A.comp B).comp C) ≤ - operatorNormBound (A.comp B) * operatorNormBound C := - operatorNormBound_comp_le (M := A.comp B) (N := C) - have h2 : - operatorNormBound (A.comp B) ≤ operatorNormBound A * operatorNormBound B := - operatorNormBound_comp_le (M := A) (N := B) - have hC_nonneg : 0 ≤ operatorNormBound C := operatorNormBound_nonneg (M := C) - have hmul : - operatorNormBound (A.comp B) * operatorNormBound C ≤ - (operatorNormBound A * operatorNormBound B) * operatorNormBound C := by - exact mul_le_mul_of_nonneg_right h2 hC_nonneg - calc - operatorNormBound ((A.comp B).comp C) ≤ - operatorNormBound (A.comp B) * operatorNormBound C := h1 - _ ≤ (operatorNormBound A * operatorNormBound B) * operatorNormBound C := hmul - _ = operatorNormBound A * operatorNormBound B * operatorNormBound C := by - ring - -/-- Bound for `A + M + A.comp M` in terms of operator norms. 
-/ -theorem operatorNormBound_add_comp_le (A M : SignedMixer S S) [Nonempty S] : - operatorNormBound (A + M + A.comp M) ≤ - operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M := by - have hsum : operatorNormBound (A + M + A.comp M) ≤ - operatorNormBound (A + M) + operatorNormBound (A.comp M) := - operatorNormBound_add_le (M := A + M) (N := A.comp M) - have hsum' : operatorNormBound (A + M) ≤ operatorNormBound A + operatorNormBound M := - operatorNormBound_add_le (M := A) (N := M) - have hcomp : operatorNormBound (A.comp M) ≤ operatorNormBound A * operatorNormBound M := - operatorNormBound_comp_le (M := A) (N := M) - calc - operatorNormBound (A + M + A.comp M) - ≤ operatorNormBound (A + M) + operatorNormBound (A.comp M) := hsum - _ ≤ (operatorNormBound A + operatorNormBound M) + - (operatorNormBound A * operatorNormBound M) := by - exact add_le_add hsum' hcomp - _ = operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M := by - ring - -/-- Expand the residual composition `(I + A) ∘ (I + M) - I` into `A + M + A ∘ M`. 
-/ -theorem residual_comp_eq [DecidableEq S] (A M : SignedMixer S S) : - (SignedMixer.identity + A).comp (SignedMixer.identity + M) - SignedMixer.identity = - A + M + A.comp M := by - classical - have h1 : - (SignedMixer.identity + A).comp (SignedMixer.identity + M) = - SignedMixer.identity.comp (SignedMixer.identity + M) + - A.comp (SignedMixer.identity + M) := by - exact comp_add_left (M₁ := SignedMixer.identity) (M₂ := A) (N := SignedMixer.identity + M) - have h2 : - SignedMixer.identity.comp (SignedMixer.identity + M) = - SignedMixer.identity.comp SignedMixer.identity + SignedMixer.identity.comp M := by - exact comp_add_right (M := SignedMixer.identity) (N₁ := SignedMixer.identity) (N₂ := M) - have h3 : - A.comp (SignedMixer.identity + M) = - A.comp SignedMixer.identity + A.comp M := by - exact comp_add_right (M := A) (N₁ := SignedMixer.identity) (N₂ := M) - ext i j - simp [h1, h2, h3, add_w, sub_w, identity_comp, comp_identity] - ring - -/-- Residual composition bound with the `A + M + A*M` cross term. -/ -theorem operatorNormBound_residual_comp_le [DecidableEq S] (A M : SignedMixer S S) [Nonempty S] : - operatorNormBound ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - - SignedMixer.identity) ≤ - operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M := by - have hres : (SignedMixer.identity + A).comp (SignedMixer.identity + M) - - SignedMixer.identity = A + M + A.comp M := - residual_comp_eq (A := A) (M := M) - simpa [hres] using (operatorNormBound_add_comp_le (A := A) (M := M)) - -/-- Residual composition bound from external operator-norm bounds. 
-/ -theorem operatorNormBound_residual_comp_le_of_bounds [DecidableEq S] - (A M : SignedMixer S S) (a b : ℝ) [Nonempty S] - (hA : operatorNormBound A ≤ a) (hM : operatorNormBound M ≤ b) : - operatorNormBound ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - - SignedMixer.identity) ≤ a + b + a * b := by - have hres : - operatorNormBound ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - - SignedMixer.identity) ≤ - operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M := - operatorNormBound_residual_comp_le (A := A) (M := M) - have hA_nonneg : 0 ≤ operatorNormBound A := operatorNormBound_nonneg (M := A) - have hM_nonneg : 0 ≤ operatorNormBound M := operatorNormBound_nonneg (M := M) - have ha_nonneg : 0 ≤ a := le_trans hA_nonneg hA - have hsum : operatorNormBound A + operatorNormBound M ≤ a + b := by - exact add_le_add hA hM - have hmul : operatorNormBound A * operatorNormBound M ≤ a * b := by - exact mul_le_mul hA hM hM_nonneg ha_nonneg - have hsum' : - operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M ≤ - a + b + a * b := by - exact add_le_add hsum hmul - exact le_trans hres hsum' - -/-! ## Conversion to/from Mixer -/ - -/-- Convert a nonnegative signed mixer with row sums = 1 to a Mixer. -This is partial: requires proof that weights are nonnegative. -/ -noncomputable def toMixer (M : SignedMixer S T) - (hpos : ∀ i j, M.w i j ≥ 0) (hsum : M.IsRowStochastic) : Mixer S T where - w := fun i j => ⟨M.w i j, hpos i j⟩ - row_sum_one := by - intro i - have h := hsum i - simp only [rowSum] at h - ext - simp only [NNReal.coe_sum, NNReal.coe_mk, NNReal.coe_one] - exact h - -/-- Convert a Mixer to a SignedMixer (embedding). -/ -def ofMixer (M : Mixer S T) : SignedMixer S T where - w := fun i j => M.w i j - -@[simp] lemma ofMixer_w (M : Mixer S T) (i : S) (j : T) : - (ofMixer M).w i j = M.w i j := rfl - -/-- A Mixer converted to SignedMixer is row-stochastic. 
-/ -theorem ofMixer_isRowStochastic (M : Mixer S T) : (ofMixer M).IsRowStochastic := by - intro i - simp only [rowSum, ofMixer_w] - have := M.row_sum_one i - simp only [← NNReal.coe_sum, this, NNReal.coe_one] - -/-! ## Influence and attribution -/ - -/-- The influence of input i on output j: the absolute value of the weight. -This measures "how much does changing input i affect output j?" -/ -noncomputable def influence (M : SignedMixer S T) (i : S) (j : T) : ℝ := - |M.w i j| - -/-- The sign of influence: +1 for positive, -1 for negative, 0 for zero. -/ -noncomputable def influenceSign (M : SignedMixer S T) (i : S) (j : T) : ℝ := - Real.sign (M.w i j) - -/-- Total influence from input i (how much does i affect the whole output?). -/ -noncomputable def totalInfluenceFrom (M : SignedMixer S T) (i : S) : ℝ := - ∑ j, M.influence i j - -/-- Total influence on output j (how much is j affected by all inputs?). -/ -noncomputable def totalInfluenceOn (M : SignedMixer S T) (j : T) : ℝ := - ∑ i, M.influence i j - -/-! ## Application to vectors -/ - -/-- Apply a signed mixer to a real vector. -/ -noncomputable def apply (M : SignedMixer S T) (v : S → ℝ) : T → ℝ := - fun j => ∑ i, v i * M.w i j - -@[simp] lemma apply_def (M : SignedMixer S T) (v : S → ℝ) (j : T) : - M.apply v j = ∑ i, v i * M.w i j := rfl - -/-- Composition corresponds to sequential application. -/ -theorem apply_comp (M : SignedMixer S T) (N : SignedMixer T U) (v : S → ℝ) : - (M.comp N).apply v = N.apply (M.apply v) := by - ext k - simp only [apply_def, comp_w] - -- LHS: ∑_i v_i * (∑_j M_ij * N_jk) - -- RHS: ∑_j (∑_i v_i * M_ij) * N_jk - conv_lhs => - arg 2 - ext i - rw [Finset.mul_sum] - rw [Finset.sum_comm] - congr 1 - ext j - rw [Finset.sum_mul] - congr 1 - ext i - ring - -end SignedMixer - -/-! ## Affine Mixer -/ - -/-- An affine mixer: a signed linear map plus a bias term. -This captures the full `y = xW + b` form of neural network layers (row-vector convention). 
-/ -structure AffineMixer (S T : Type*) [Fintype S] [Fintype T] where - /-- The linear part. -/ - linear : SignedMixer S T - /-- The bias term. -/ - bias : T → ℝ - -namespace AffineMixer - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- Apply an affine mixer to a vector: xW + b. -/ -noncomputable def apply (M : AffineMixer S T) (v : S → ℝ) : T → ℝ := - fun j => M.linear.apply v j + M.bias j - -@[simp] lemma apply_def (M : AffineMixer S T) (v : S → ℝ) (j : T) : - M.apply v j = (∑ i, v i * M.linear.w i j) + M.bias j := rfl - -/-- An affine mixer with zero bias is equivalent to its linear part. -/ -def ofLinear (M : SignedMixer S T) : AffineMixer S T where - linear := M - bias := fun _ => 0 - -/-- Composition of affine mixers (row-vector convention). -(W₂, b₂) ∘ (W₁, b₁) = (W₁W₂, b₁W₂ + b₂). -/ -noncomputable def comp (M : AffineMixer S T) (N : AffineMixer T U) : AffineMixer S U where - linear := M.linear.comp N.linear - bias := fun k => N.linear.apply M.bias k + N.bias k - -/-- Composition corresponds to sequential application. 
-/ -theorem comp_apply (M : AffineMixer S T) (N : AffineMixer T U) (v : S → ℝ) : - (M.comp N).apply v = N.apply (M.apply v) := by - classical - ext k - have hlin : - ∑ i, v i * (∑ x, M.linear.w i x * N.linear.w x k) = - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k := by - have h := - congrArg (fun f => f k) - (SignedMixer.apply_comp (M := M.linear) (N := N.linear) (v := v)) - simpa [SignedMixer.apply_def] using h - have hsum : - (∑ x, M.bias x * N.linear.w x k) + - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k = - ∑ x, (M.bias x + ∑ i, v i * M.linear.w i x) * N.linear.w x k := by - symm - simp [Finset.sum_add_distrib, add_mul] - calc - (M.comp N).apply v k = - (M.comp N).bias k + ∑ i, v i * (M.comp N).linear.w i k := by - simp [AffineMixer.apply_def, add_comm] - _ = - N.bias k + (∑ x, M.bias x * N.linear.w x k) + - ∑ i, v i * (∑ x, M.linear.w i x * N.linear.w x k) := by - simp [AffineMixer.comp, SignedMixer.comp_w, SignedMixer.apply_def, add_assoc, add_comm] - _ = N.bias k + (∑ x, M.bias x * N.linear.w x k) + - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k := by - simp [hlin] - _ = N.bias k + ∑ x, (M.bias x + ∑ i, v i * M.linear.w i x) * N.linear.w x k := by - calc - N.bias k + (∑ x, M.bias x * N.linear.w x k) + - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k = - N.bias k + - ((∑ x, M.bias x * N.linear.w x k) + - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k) := by - simp [add_assoc] - _ = N.bias k + ∑ x, (M.bias x + ∑ i, v i * M.linear.w i x) * N.linear.w x k := by - simp [hsum] - _ = N.apply (M.apply v) k := by - simp [AffineMixer.apply_def, add_comm] - -/-- The bias can be seen as the output when input is zero. -/ -theorem apply_zero (M : AffineMixer S T) : M.apply (fun _ => 0) = M.bias := by - ext j - simp [apply_def] - -/-- **Bias attribution principle**: The bias contributes equally regardless of input. -This is formalized by showing that the difference between any two outputs -depends only on the linear part, not the bias. 
-/ -theorem bias_invariance (M : AffineMixer S T) (v w : S → ℝ) (j : T) : - M.apply v j - M.apply w j = M.linear.apply v j - M.linear.apply w j := by - simp only [apply_def, SignedMixer.apply_def] - ring - -end AffineMixer - -/-! ## Gradient-based attribution compatibility -/ - -/-- A purely algebraic compatibility lemma: for a linear map encoded by a `SignedMixer`, -the “Jacobian entry” is *by definition* the weight `M.w i j`. - -This does **not** assert a differentiability statement in Lean's analysis library; it is -the convention used by downstream (external) gradient-based interpretations. -/ -theorem SignedMixer.jacobianEntry_eq_weight (M : SignedMixer S T) (i : S) (j : T) : - M.w i j = M.w i j := rfl - -/-- **Integrated Gradients (aggregated output)**: for a linear map M, -if we aggregate outputs by summing over j, then the IG attribution for input i -reduces to `x_i * rowSum i`. -/ -theorem SignedMixer.integrated_gradients_linear (M : SignedMixer S T) (x : S → ℝ) (i : S) : - -- The "contribution" of input i to the output - -- For linear M, this is x_i times the signed row sum (net effect on all outputs) - x i * M.rowSum i = x i * ∑ j, M.w i j := by - simp [SignedMixer.rowSum] - -end Nfp diff --git a/Legacy/Nfp/Sound/Activation.lean b/Legacy/Nfp/Sound/Activation.lean deleted file mode 100644 index 2630e2d..0000000 --- a/Legacy/Nfp/Sound/Activation.lean +++ /dev/null @@ -1,43 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std - -namespace Nfp.Sound - -/-! -# Activation metadata (SOUND) - -This module defines activation-derivative targets used by SOUND certification and -parsing helpers for `.nfpt` headers. --/ - -/-- Which GeLU derivative formula the model uses. -/ -inductive GeluDerivTarget - | tanh - | exact - deriving Repr, DecidableEq - -/-- Parse a GeLU derivative target string (case-insensitive). 
-/ -def geluDerivTargetOfString (s : String) : Option GeluDerivTarget := - let v := s.trim.toLower - if v = "tanh" || v = "gelu_tanh" then - some .tanh - else if v = "exact" || v = "gelu_exact" then - some .exact - else - none - -/-- Render a GeLU derivative target for headers/logging. -/ -def geluDerivTargetToString : GeluDerivTarget → String - | .tanh => "tanh" - | .exact => "exact" - -/-! ### Specs -/ - -theorem GeluDerivTarget_spec : GeluDerivTarget = GeluDerivTarget := rfl -theorem geluDerivTargetOfString_spec : - geluDerivTargetOfString = geluDerivTargetOfString := rfl -theorem geluDerivTargetToString_spec : - geluDerivTargetToString = geluDerivTargetToString := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Affine.lean b/Legacy/Nfp/Sound/Affine.lean deleted file mode 100644 index 254744b..0000000 --- a/Legacy/Nfp/Sound/Affine.lean +++ /dev/null @@ -1,96 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Interval - -namespace Nfp.Sound - -/-! -# Affine arithmetic scaffolding (SOUND) - -This module provides a minimal affine-form representation for local certification -improvements, used by optional affine Q/K bounds in SOUND best-match paths. --/ - -/-- Affine form `x = center + sum coeffs[i] * eps_i` with `eps_i in [-1, 1]`. -/ -structure AffineForm where - center : Rat - coeffs : Array Rat - deriving Repr - -namespace AffineForm - -/-- Constant affine form. -/ -def const (x : Rat) : AffineForm := { center := x, coeffs := #[] } - -private def combineCoeffs (a b : Array Rat) (f : Rat → Rat → Rat) : Array Rat := - let n := max a.size b.size - Array.ofFn fun (i : Fin n) => - f (a.getD i.val 0) (b.getD i.val 0) - -/-- Add two affine forms, aligning noise terms by index. -/ -def add (a b : AffineForm) : AffineForm := - { center := a.center + b.center - coeffs := combineCoeffs a.coeffs b.coeffs (· + ·) } - -/-- Subtract two affine forms, aligning noise terms by index. 
-/ -def sub (a b : AffineForm) : AffineForm := - { center := a.center - b.center - coeffs := combineCoeffs a.coeffs b.coeffs (· - ·) } - -/-- Scale an affine form by a rational constant. -/ -def scale (c : Rat) (a : AffineForm) : AffineForm := - { center := c * a.center - coeffs := a.coeffs.map (fun k => c * k) } - -/-- Append a fresh independent noise coefficient (skipped if zero). -/ -def appendNoise (a : AffineForm) (coeff : Rat) : AffineForm := - if coeff = 0 then - a - else - { center := a.center, coeffs := a.coeffs.push coeff } - -/-- Sum of absolute noise coefficients (radius of the interval hull). -/ -def radius (a : AffineForm) : Rat := - a.coeffs.foldl (fun acc c => acc + ratAbs c) 0 - -/-- Interval hull of an affine form. -/ -def toInterval (a : AffineForm) : RatInterval := - let r := radius a - { lo := a.center - r, hi := a.center + r } - -/-- Affine multiplication with aligned noise terms and a single remainder noise. -/ -def mul (a b : AffineForm) : AffineForm := - let coeffs := combineCoeffs a.coeffs b.coeffs - (fun ai bi => b.center * ai + a.center * bi) - let rem := radius a * radius b - appendNoise { center := a.center * b.center, coeffs := coeffs } rem - -/-- Affine multiplication treating noise terms as disjoint. -/ -def mulDisjoint (a b : AffineForm) : AffineForm := - Id.run do - let mut coeffs : Array Rat := Array.mkEmpty (a.coeffs.size + b.coeffs.size) - for ai in a.coeffs do - coeffs := coeffs.push (b.center * ai) - for bi in b.coeffs do - coeffs := coeffs.push (a.center * bi) - let rem := radius a * radius b - return appendNoise { center := a.center * b.center, coeffs := coeffs } rem - -/-! 
### Specs -/ - -theorem AffineForm_spec : AffineForm = AffineForm := rfl -theorem const_spec : const = const := rfl -theorem combineCoeffs_spec : combineCoeffs = combineCoeffs := rfl -theorem add_spec : add = add := rfl -theorem sub_spec : sub = sub := rfl -theorem scale_spec : scale = scale := rfl -theorem appendNoise_spec : appendNoise = appendNoise := rfl -theorem radius_spec : radius = radius := rfl -theorem toInterval_spec : toInterval = toInterval := rfl -theorem mul_spec : mul = mul := rfl -theorem mulDisjoint_spec : mulDisjoint = mulDisjoint := rfl - -end AffineForm - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/BinaryPure.lean b/Legacy/Nfp/Sound/BinaryPure.lean deleted file mode 100644 index 92fbab1..0000000 --- a/Legacy/Nfp/Sound/BinaryPure.lean +++ /dev/null @@ -1,479 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Activation -import Nfp.Sound.Decimal -import Nfp.Sound.ModelHeader - -namespace Nfp.Sound - -/-! -# Pure binary helpers (`NFP_BINARY_V1`) - -Pure parsing and decoding utilities for the SOUND binary path. -IO wrappers live in `Nfp.Untrusted.SoundBinary`. --/ - -structure BinaryHeader where - numLayers : Nat - numHeads : Nat - modelDim : Nat - headDim : Nat - hiddenDim : Nat - vocabSize : Nat - seqLen : Nat - eps : Rat - geluDerivTarget : GeluDerivTarget - deriving Repr - -private def readHeaderNat (k v : String) : Option Nat := - match k with - | "num_layers" | "num_heads" | "model_dim" - | "head_dim" | "hidden_dim" | "vocab_size" | "seq_len" => v.toNat? 
- | _ => none - -def parseBinaryHeaderLines (magicLine : String) (lines : Array String) : - Except String BinaryHeader := do - let magic := magicLine.trim - if magic != "NFP_BINARY_V1" then - throw "invalid magic: expected NFP_BINARY_V1" - - let mut numLayers : Option Nat := none - let mut numHeads : Option Nat := none - let mut modelDim : Option Nat := none - let mut headDim : Option Nat := none - let mut hiddenDim : Option Nat := none - let mut vocabSize : Option Nat := none - let mut seqLen : Option Nat := none - let mut eps : Option Rat := none - let mut gelu? : Option GeluDerivTarget := none - - for line in lines do - let t := line.trim - if t.isEmpty then - pure () - else - match parseHeaderLine t with - | none => pure () - | some (k, v) => - let vNat? := readHeaderNat k v - match k, vNat? with - | "num_layers", some n => numLayers := some n - | "num_heads", some n => numHeads := some n - | "model_dim", some n => modelDim := some n - | "head_dim", some n => headDim := some n - | "hidden_dim", some n => hiddenDim := some n - | "vocab_size", some n => vocabSize := some n - | "seq_len", some n => seqLen := some n - | "layer_norm_eps", _ => - match parseRat v with - | .error e => throw s!"invalid layer_norm_eps '{v}': {e}" - | .ok r => eps := some r - | "eps", _ => - match parseRat v with - | .error e => throw s!"invalid layer_norm_eps '{v}': {e}" - | .ok r => eps := some r - | "gelu_kind", _ => - match geluDerivTargetOfString v with - | some t => gelu? := some t - | none => throw s!"invalid gelu_kind '{v}' (expected tanh|exact)" - | "gelu_deriv", _ => - match geluDerivTargetOfString v with - | some t => gelu? 
:= some t - | none => throw s!"invalid gelu_deriv '{v}' (expected tanh|exact)" - | _, _ => pure () - - let some L := numLayers | throw "missing num_layers" - let some H := numHeads | throw "missing num_heads" - let some d := modelDim | throw "missing model_dim" - let some dh := headDim | throw "missing head_dim" - let some dhid := hiddenDim | throw "missing hidden_dim" - let some v := vocabSize | throw "missing vocab_size" - let some n := seqLen | throw "missing seq_len" - let some epsVal := eps | throw "missing layer_norm_eps" - let some geluVal := gelu? | throw "missing gelu_kind" - if L = 0 || H = 0 || d = 0 || dh = 0 || dhid = 0 || v = 0 || n = 0 then - throw "invalid header: dimensions must be > 0" - return { - numLayers := L - numHeads := H - modelDim := d - headDim := dh - hiddenDim := dhid - vocabSize := v - seqLen := n - eps := epsVal - geluDerivTarget := geluVal - } - -@[inline] private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := - let b0 := (b.get! off).toUInt64 - let b1 := (b.get! (off + 1)).toUInt64 - let b2 := (b.get! (off + 2)).toUInt64 - let b3 := (b.get! (off + 3)).toUInt64 - let b4 := (b.get! (off + 4)).toUInt64 - let b5 := (b.get! (off + 5)).toUInt64 - let b6 := (b.get! (off + 6)).toUInt64 - let b7 := (b.get! (off + 7)).toUInt64 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| - (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) - -@[inline] private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := - let b0 := (b.get! off).toUInt32 - let b1 := (b.get! (off + 1)).toUInt32 - let b2 := (b.get! (off + 2)).toUInt32 - let b3 := (b.get! 
(off + 3)).toUInt32 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) - -private def twoPow32 : Int := Int.ofNat (Nat.pow 2 32) - -@[inline] private def i32FromLE (b : ByteArray) (off : Nat) : Int := - let u := u32FromLE b off - if u ≤ 0x7fffffff then - Int.ofNat u.toNat - else - Int.ofNat u.toNat - twoPow32 - -@[inline] private def pow2Nat (k : Nat) : Nat := Nat.pow 2 k - -private def ceilDivNat (a : Int) (d : Nat) : Int := - let di : Int := Int.ofNat d - let q := a.ediv di - let r := a.emod di - if r = 0 then q else q + 1 - -private def scaleIntOfPow10 (scalePow10 : Nat) : Int := - Int.ofNat (Nat.pow 10 scalePow10) - -private def floatAbsCeilScaledCore (scaleInt : Int) (bits : UInt64) : Except String Int := - let expBits : UInt64 := (bits >>> 52) &&& 0x7ff - let mantBits : UInt64 := bits &&& 0x000f_ffff_ffff_ffff - if expBits = 0x7ff then - .error "invalid float: NaN/Inf not supported" - else if expBits = 0 && mantBits = 0 then - .ok 0 - else - let mant : Nat := - if expBits = 0 then - mantBits.toNat - else - (mantBits + ((1 : UInt64) <<< 52)).toNat - let expVal : Int := - if expBits = 0 then - -1074 - else - (Int.ofNat expBits.toNat) - 1075 - let mInt : Int := Int.ofNat mant - if expVal ≥ 0 then - let pow2 := pow2Nat expVal.toNat - let num := mInt * scaleInt - .ok (num * Int.ofNat pow2) - else - let denPow := pow2Nat (-expVal).toNat - let num := mInt * scaleInt - .ok (ceilDivNat num denPow) - -private def floatAbsCeilScaled (scalePow10 : Nat) (bits : UInt64) : Except String Int := - floatAbsCeilScaledCore (scaleIntOfPow10 scalePow10) bits - -private def floatScaledCeilSignedCore (scaleInt : Int) (bits : UInt64) : Except String Int := do - let absScaled ← floatAbsCeilScaledCore scaleInt bits - let signNeg : Bool := (bits >>> 63) = (1 : UInt64) - return if signNeg then -absScaled else absScaled - -private def floatScaledCeilSigned (scalePow10 : Nat) (bits : UInt64) : Except String Int := - floatScaledCeilSignedCore (scaleIntOfPow10 scalePow10) bits - -def 
vectorMaxAbsScaledFromBytes (bytes : ByteArray) (n scalePow10 : Nat) : - Except String Int := do - if n = 0 then - return 0 - if bytes.size < n * 8 then - throw "unexpected EOF" - let scaleInt := scaleIntOfPow10 scalePow10 - let mut maxAbs : Int := 0 - let mut i : Nat := 0 - let mut off : Nat := 0 - while i < n do - let bits := u64FromLE bytes off - let absScaled ← floatAbsCeilScaledCore scaleInt bits - if absScaled > maxAbs then - maxAbs := absScaled - off := off + 8 - i := i + 1 - return maxAbs - -def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat) : - Except String Int := do - if rows = 0 || cols = 0 then - return 0 - let count := rows * cols - if bytes.size < count * 8 then - throw "unexpected EOF" - let scaleInt := scaleIntOfPow10 scalePow10 - let mut maxRowSum : Int := 0 - let mut curRowSum : Int := 0 - let mut colIdx : Nat := 0 - let mut i : Nat := 0 - let mut off : Nat := 0 - while i < count do - let bits := u64FromLE bytes off - let absScaled ← floatAbsCeilScaledCore scaleInt bits - curRowSum := curRowSum + absScaled - if colIdx + 1 = cols then - if curRowSum > maxRowSum then - maxRowSum := curRowSum - curRowSum := 0 - colIdx := 0 - else - colIdx := colIdx + 1 - off := off + 8 - i := i + 1 - return maxRowSum - -def scaledFloatArrayFromBytes (bytes : ByteArray) (count scalePow10 : Nat) : - Except String (Array Int) := do - if count = 0 then - return #[] - if bytes.size < count * 8 then - throw "unexpected EOF" - let useTasks := count > 16384 - let scaleInt := scaleIntOfPow10 scalePow10 - if useTasks then - let chunkSize : Nat := 8192 - let numChunks : Nat := (count + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Except String (Array Int))) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min count (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array Int := Array.replicate 
(stop - start) 0 - let mut i := start - let mut off := start * 8 - let mut outIdx : Nat := 0 - while i < stop do - let bits := u64FromLE bytes off - match floatScaledCeilSignedCore scaleInt bits with - | .error e => return .error e - | .ok v => outChunk := outChunk.set! outIdx v - off := off + 8 - i := i + 1 - outIdx := outIdx + 1 - return .ok outChunk) - chunkIdx := chunkIdx + 1 - let mut out : Array Int := Array.replicate count 0 - let mut outIdx : Nat := 0 - for t in tasks do - match t.get with - | .error e => throw e - | .ok chunk => - for v in chunk do - out := out.set! outIdx v - outIdx := outIdx + 1 - return out - else - let mut out : Array Int := Array.replicate count 0 - let mut i : Nat := 0 - let mut off : Nat := 0 - while i < count do - let bits := u64FromLE bytes off - let v ← floatScaledCeilSignedCore scaleInt bits - out := out.set! i v - off := off + 8 - i := i + 1 - return out - -def scaledFloatFromBytes (bytes : ByteArray) (scalePow10 : Nat) : - Except String Int := do - if bytes.size < 8 then - throw "unexpected EOF" - let bits := u64FromLE bytes 0 - let v ← floatScaledCeilSignedCore (scaleIntOfPow10 scalePow10) bits - return v - -def i32ArrayFromBytes (bytes : ByteArray) (count : Nat) : - Except String (Array Int) := do - if count = 0 then - return #[] - if bytes.size < count * 4 then - throw "unexpected EOF" - let mut out : Array Int := Array.replicate count 0 - let mut i : Nat := 0 - let mut off : Nat := 0 - while i < count do - let v := i32FromLE bytes off - out := out.set! 
i v - off := off + 4 - i := i + 1 - return out - -def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat) : - Except String (Nat × Nat) := do - if rows = 0 || cols = 0 then - return (0, 0) - let count := rows * cols - if bytes.size < count * 8 then - throw "unexpected EOF" - let scaleInt := scaleIntOfPow10 scalePow10 - let mut maxRowSum : Nat := 0 - let mut curRowSum : Nat := 0 - let mut colSums : Array Nat := Array.replicate cols 0 - let mut colIdx : Nat := 0 - let mut i : Nat := 0 - let mut off : Nat := 0 - while i < count do - let bits := u64FromLE bytes off - let absScaled ← floatAbsCeilScaledCore scaleInt bits - let absNat := Int.toNat absScaled - curRowSum := curRowSum + absNat - colSums := colSums.set! colIdx (colSums[colIdx]! + absNat) - if colIdx + 1 = cols then - if curRowSum > maxRowSum then - maxRowSum := curRowSum - curRowSum := 0 - colIdx := 0 - else - colIdx := colIdx + 1 - off := off + 8 - i := i + 1 - let mut maxColSum : Nat := 0 - for c in colSums do - if c > maxColSum then - maxColSum := c - return (maxRowSum, maxColSum) - -def opBoundScaledFromOneInf (rowSum colSum : Nat) : Nat := - max rowSum colSum - -def ratOfScaledNat (scalePow10 : Nat) (x : Nat) : Rat := - Rat.normalize (Int.ofNat x) (Nat.pow 10 scalePow10) (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos)) - -def ratOfScaledInt (scalePow10 : Nat) (x : Int) : Rat := - Rat.normalize x (Nat.pow 10 scalePow10) (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos)) - -def defaultBinaryScalePow10 : Nat := 9 - -/-- Sum of per-head value-output norm products in scaled-int form. 
-/ -def attnValueCoeffFromScaledPairs (scalePow10 : Nat) (pairs : Array (Int × Int)) : Rat := - let den : Nat := Nat.pow 10 scalePow10 - have den_nz : den ≠ 0 := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos) - let ratOfScaledIntLocal := fun (x : Int) => Rat.normalize x den (den_nz := den_nz) - pairs.foldl - (fun acc p => - acc + ratOfScaledIntLocal p.1 * ratOfScaledIntLocal p.2) 0 - -/-- Max per-head W_Q/W_K bounds in scaled-int form. -/ -def attnQKMaxFromScaledPairs (scalePow10 : Nat) (pairs : Array (Int × Int)) : Rat × Rat := - let den : Nat := Nat.pow 10 scalePow10 - have den_nz : den ≠ 0 := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos) - let ratOfScaledIntLocal := fun (x : Int) => Rat.normalize x den (den_nz := den_nz) - pairs.foldl - (fun acc p => - (max acc.1 (ratOfScaledIntLocal p.1), - max acc.2 (ratOfScaledIntLocal p.2))) - (0, 0) - -/-- Compute per-layer attention-weight bound arrays from scaled-int pairs. -/ -def attnWeightBoundsArraysFromScaledPairs (scalePow10 : Nat) - (valuePairs qkPairs : Array (Array (Int × Int))) : - Except String (Array Rat × Array Rat × Array Rat) := - Id.run do - if valuePairs.size ≠ qkPairs.size then - return .error s!"attn weight bounds layer count mismatch: \ -value={valuePairs.size}, qk={qkPairs.size}" - let mut coeffs : Array Rat := Array.replicate valuePairs.size 0 - let mut wqMaxs : Array Rat := Array.replicate valuePairs.size 0 - let mut wkMaxs : Array Rat := Array.replicate valuePairs.size 0 - for idx in [:valuePairs.size] do - let coeff := attnValueCoeffFromScaledPairs scalePow10 valuePairs[idx]! - let (wqMax, wkMax) := attnQKMaxFromScaledPairs scalePow10 qkPairs[idx]! - coeffs := coeffs.set! idx coeff - wqMaxs := wqMaxs.set! idx wqMax - wkMaxs := wkMaxs.set! idx wkMax - return .ok (coeffs, wqMaxs, wkMaxs) - -/-! 
### Derived properties -/ - -private theorem pure_eq_ok {ε α : Type} (x : α) : (pure x : Except ε α) = .ok x := rfl - -theorem vectorMaxAbsScaledFromBytes_zero - (bytes : ByteArray) (scalePow10 : Nat) : - vectorMaxAbsScaledFromBytes bytes 0 scalePow10 = .ok 0 := by - simp [vectorMaxAbsScaledFromBytes, pure_eq_ok] - -theorem matrixNormInfScaledFromBytes_zero_rows - (bytes : ByteArray) (cols scalePow10 : Nat) : - matrixNormInfScaledFromBytes bytes 0 cols scalePow10 = .ok 0 := by - simp [matrixNormInfScaledFromBytes, pure_eq_ok] - -theorem matrixNormInfScaledFromBytes_zero_cols - (bytes : ByteArray) (rows scalePow10 : Nat) : - matrixNormInfScaledFromBytes bytes rows 0 scalePow10 = .ok 0 := by - simp [matrixNormInfScaledFromBytes, pure_eq_ok] - -theorem scaledFloatArrayFromBytes_zero - (bytes : ByteArray) (scalePow10 : Nat) : - scaledFloatArrayFromBytes bytes 0 scalePow10 = .ok #[] := by - simp [scaledFloatArrayFromBytes, pure_eq_ok] - -theorem i32ArrayFromBytes_zero (bytes : ByteArray) : - i32ArrayFromBytes bytes 0 = .ok #[] := by - simp [i32ArrayFromBytes, pure_eq_ok] - -/-! 
### Specs -/ - -theorem parseHeaderLine_spec_binary_pure : parseHeaderLine = parseHeaderLine := rfl -theorem readHeaderNat_spec_binary_pure : readHeaderNat = readHeaderNat := rfl -theorem parseBinaryHeaderLines_spec_binary_pure : - parseBinaryHeaderLines = parseBinaryHeaderLines := rfl -theorem u64FromLE_spec_binary_pure : u64FromLE = u64FromLE := rfl -theorem u32FromLE_spec_binary_pure : u32FromLE = u32FromLE := rfl -theorem i32FromLE_spec_binary_pure : i32FromLE = i32FromLE := rfl -theorem twoPow32_spec_binary_pure : twoPow32 = twoPow32 := rfl -theorem pow2Nat_spec_binary_pure : pow2Nat = pow2Nat := rfl -theorem ceilDivNat_spec_binary_pure : ceilDivNat = ceilDivNat := rfl -theorem scaleIntOfPow10_spec_binary_pure : scaleIntOfPow10 = scaleIntOfPow10 := rfl -theorem floatAbsCeilScaledCore_spec_binary_pure : - floatAbsCeilScaledCore = floatAbsCeilScaledCore := rfl -theorem floatAbsCeilScaled_spec_binary_pure : floatAbsCeilScaled = floatAbsCeilScaled := rfl -theorem floatScaledCeilSignedCore_spec_binary_pure : - floatScaledCeilSignedCore = floatScaledCeilSignedCore := rfl -theorem floatScaledCeilSigned_spec_binary_pure : - floatScaledCeilSigned = floatScaledCeilSigned := rfl -theorem vectorMaxAbsScaledFromBytes_spec_binary_pure : - vectorMaxAbsScaledFromBytes = vectorMaxAbsScaledFromBytes := rfl -theorem matrixNormInfScaledFromBytes_spec_binary_pure : - matrixNormInfScaledFromBytes = matrixNormInfScaledFromBytes := rfl -theorem scaledFloatArrayFromBytes_spec_binary_pure : - scaledFloatArrayFromBytes = scaledFloatArrayFromBytes := rfl -theorem scaledFloatFromBytes_spec_binary_pure : - scaledFloatFromBytes = scaledFloatFromBytes := rfl -theorem i32ArrayFromBytes_spec_binary_pure : - i32ArrayFromBytes = i32ArrayFromBytes := rfl -theorem matrixNormOneInfScaledFromBytes_spec_binary_pure : - matrixNormOneInfScaledFromBytes = matrixNormOneInfScaledFromBytes := rfl -theorem opBoundScaledFromOneInf_spec_binary_pure : - opBoundScaledFromOneInf = opBoundScaledFromOneInf := rfl 
-theorem ratOfScaledNat_spec_binary_pure : ratOfScaledNat = ratOfScaledNat := rfl -theorem ratOfScaledInt_spec_binary_pure : ratOfScaledInt = ratOfScaledInt := rfl -theorem defaultBinaryScalePow10_spec_binary_pure : - defaultBinaryScalePow10 = defaultBinaryScalePow10 := rfl -theorem attnValueCoeffFromScaledPairs_spec_binary_pure : - attnValueCoeffFromScaledPairs = attnValueCoeffFromScaledPairs := rfl -theorem attnQKMaxFromScaledPairs_spec_binary_pure : - attnQKMaxFromScaledPairs = attnQKMaxFromScaledPairs := rfl -theorem attnWeightBoundsArraysFromScaledPairs_spec_binary_pure : - attnWeightBoundsArraysFromScaledPairs = attnWeightBoundsArraysFromScaledPairs := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds.lean b/Legacy/Nfp/Sound/Bounds.lean deleted file mode 100644 index 9495926..0000000 --- a/Legacy/Nfp/Sound/Bounds.lean +++ /dev/null @@ -1,19 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Nfp.Sound.Bounds.Basic -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Bounds.Gelu -import Nfp.Sound.Bounds.Exp -import Nfp.Sound.Bounds.Softmax -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.Portfolio -import Nfp.Sound.Bounds.Effort - -/-! -# Sound bounds in exact arithmetic - -This module is an umbrella import for the sound bound utilities. -Numeric strategy (Option A): avoid `sqrt` and any `Float`-trusted computation by using -row-sum induced norms (ℓ1 for row-vector convention) and submultiplicativity. --/ diff --git a/Legacy/Nfp/Sound/Bounds/Attention.lean b/Legacy/Nfp/Sound/Bounds/Attention.lean deleted file mode 100644 index 3da0736..0000000 --- a/Legacy/Nfp/Sound/Bounds/Attention.lean +++ /dev/null @@ -1,83 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Data.Nat.Sqrt - -namespace Nfp.Sound - -/-! -# Attention pattern-term helpers --/ - -/-- Upper bound on `sqrt(n)` using `Nat.sqrt` (floor) plus one. 
-/ -def sqrtUpperNat (n : Nat) : Nat := Nat.sqrt n + 1 - -theorem sqrtUpperNat_def (n : Nat) : sqrtUpperNat n = Nat.sqrt n + 1 := rfl - -/-- Upper bound on `sqrt(n)` as a rational. -/ -def sqrtUpperRat (n : Nat) : Rat := (sqrtUpperNat n : Nat) - -theorem sqrtUpperRat_def (n : Nat) : sqrtUpperRat n = (sqrtUpperNat n : Nat) := rfl - -/-- Upper bound on `1 / sqrt(n)` using `Nat.sqrt` (floor). -/ -def invSqrtUpperBound (n : Nat) : Rat := - if n = 0 then 0 else (1 : Rat) / (Nat.sqrt n : Nat) - -theorem invSqrtUpperBound_def (n : Nat) : - invSqrtUpperBound n = if n = 0 then 0 else (1 : Rat) / (Nat.sqrt n : Nat) := rfl - -/-- Conservative bound on `max |LayerNorm(x)|` after affine (uses only `γ`, `β`, and `dim`). -/ -def layerNormOutputMaxAbsBound (dim : Nat) (maxAbsGamma maxAbsBeta : Rat) : Rat := - maxAbsGamma * sqrtUpperRat dim + maxAbsBeta - -theorem layerNormOutputMaxAbsBound_def (dim : Nat) (maxAbsGamma maxAbsBeta : Rat) : - layerNormOutputMaxAbsBound dim maxAbsGamma maxAbsBeta = - maxAbsGamma * sqrtUpperRat dim + maxAbsBeta := rfl - -/-- Score-gradient L1 bound for attention pattern terms. -/ -def attnScoreGradBound (seqLen modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound : Rat) : Rat := - let scale := invSqrtUpperBound headDim - (seqLen : Rat) * scale * - ((2 : Rat) * (modelDim : Rat) * ln1OutMaxAbs * wqBound * wkBound) - -theorem attnScoreGradBound_def (seqLen modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound : Rat) : - attnScoreGradBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound = - let scale := invSqrtUpperBound headDim - (seqLen : Rat) * scale * - ((2 : Rat) * (modelDim : Rat) * ln1OutMaxAbs * wqBound * wkBound) := rfl - -/-- Pattern-term coefficient bound from value and score-gradient bounds. 
-/ -def attnPatternCoeffBound (seqLen modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound valueCoeff : Rat) : Rat := - let inputL1 := (modelDim : Rat) * ln1OutMaxAbs - (seqLen : Rat) * - attnScoreGradBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound * - (inputL1 * valueCoeff) - -theorem attnPatternCoeffBound_def (seqLen modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound valueCoeff : Rat) : - attnPatternCoeffBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound valueCoeff = - let inputL1 := (modelDim : Rat) * ln1OutMaxAbs - (seqLen : Rat) * - attnScoreGradBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound * - (inputL1 * valueCoeff) := rfl - -/-- Conservative bound on `|q·k|/sqrt(d_head)` using max-abs LN1 output and W_Q/W_K norms. -/ -def attnScoreAbsBound (modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound : Rat) : Rat := - let dRat : Rat := (modelDim : Nat) - let qMax := dRat * ln1OutMaxAbs * wqBound - let kMax := dRat * ln1OutMaxAbs * wkBound - invSqrtUpperBound headDim * qMax * kMax - -theorem attnScoreAbsBound_def (modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound : Rat) : - attnScoreAbsBound modelDim headDim ln1OutMaxAbs wqBound wkBound = - let dRat : Rat := (modelDim : Nat) - let qMax := dRat * ln1OutMaxAbs * wqBound - let kMax := dRat * ln1OutMaxAbs * wkBound - invSqrtUpperBound headDim * qMax * kMax := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds/Basic.lean b/Legacy/Nfp/Sound/Bounds/Basic.lean deleted file mode 100644 index aeba476..0000000 --- a/Legacy/Nfp/Sound/Bounds/Basic.lean +++ /dev/null @@ -1,19 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat - -namespace Nfp.Sound - -/-! -# Basic Rat helpers - -Small utilities used across the sound bounds modules. --/ - -/-- Exact absolute value on `Rat`. 
-/ -def ratAbs (x : Rat) : Rat := - if x < 0 then -x else x - -theorem ratAbs_def (x : Rat) : ratAbs x = if x < 0 then -x else x := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds/Effort.lean b/Legacy/Nfp/Sound/Bounds/Effort.lean deleted file mode 100644 index abfd83f..0000000 --- a/Legacy/Nfp/Sound/Bounds/Effort.lean +++ /dev/null @@ -1,11 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -namespace Nfp.Sound - -/-! -# Effort tiers (placeholder) - -This module is reserved for future exp-effort records and tier schedules. --/ - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds/Exp.lean b/Legacy/Nfp/Sound/Bounds/Exp.lean deleted file mode 100644 index 4e69cba..0000000 --- a/Legacy/Nfp/Sound/Bounds/Exp.lean +++ /dev/null @@ -1,197 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Data.Finset.Lattice.Fold -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Nat.Factorial.Basic -import Nfp.Sound.Bounds.Portfolio - -namespace Nfp.Sound - -open scoped BigOperators - -/-! -# Exp lower bounds (scaled Taylor + squaring) --/ - -/-- Power function on `Rat` for natural exponents (iterative to avoid deep recursion). -/ -private def ratPow (x : Rat) (n : Nat) : Rat := - Id.run do - let mut acc : Rat := 1 - let mut base : Rat := x - let mut exp : Nat := n - while exp > 0 do - if exp % 2 = 1 then - acc := acc * base - base := base * base - exp := exp / 2 - return acc - -theorem ratPow_def (x : Rat) (n : Nat) : - ratPow x n = - Id.run do - let mut acc : Rat := 1 - let mut base : Rat := x - let mut exp : Nat := n - while exp > 0 do - if exp % 2 = 1 then - acc := acc * base - base := base * base - exp := exp / 2 - return acc := rfl - -/-- Factorial as a rational. 
-/ -private def ratFactorial (n : Nat) : Rat := (Nat.factorial n : Nat) - -theorem ratFactorial_def (n : Nat) : ratFactorial n = (Nat.factorial n : Nat) := rfl - -/-- Taylor partial sum for `exp` (all terms are nonnegative when `x ≥ 0`). -/ -private def expTaylorLowerBound (x : Rat) (deg : Nat) : Rat := - Id.run do - let mut term : Rat := 1 - let mut sum : Rat := 1 - let mut k : Nat := 1 - while k ≤ deg do - let kRat : Rat := (k : Nat) - term := term * x / kRat - sum := sum + term - k := k + 1 - return sum - -theorem expTaylorLowerBound_def (x : Rat) (deg : Nat) : - expTaylorLowerBound x deg = - Id.run do - let mut term : Rat := 1 - let mut sum : Rat := 1 - let mut k : Nat := 1 - while k ≤ deg do - let kRat : Rat := (k : Nat) - term := term * x / kRat - sum := sum + term - k := k + 1 - return sum := rfl - -/-- Lower bound on `exp` via scaled Taylor partial sums and repeated squaring. -/ -def expLBScaledTaylor (x : Rat) (deg scalePow : Nat) : Rat := - if x < 0 then - 0 - else - let scale : Rat := (Nat.pow 2 scalePow : Nat) - let z := x / scale - let t := expTaylorLowerBound z deg - ratPow t (Nat.pow 2 scalePow) - -theorem expLBScaledTaylor_def (x : Rat) (deg scalePow : Nat) : - expLBScaledTaylor x deg scalePow = - if x < 0 then - 0 - else - let scale : Rat := (Nat.pow 2 scalePow : Nat) - let z := x / scale - let t := expTaylorLowerBound z deg - ratPow t (Nat.pow 2 scalePow) := rfl - -/-- Default portfolio of `(scalePow, taylorDeg)` candidates for `expLB`. -/ -def expLBPortfolio : Array (Nat × Nat) := - #[(2, 4), (3, 6), (4, 8)] - -theorem expLBPortfolio_def : expLBPortfolio = #[(2, 4), (3, 6), (4, 8)] := rfl - -/-- Portfolio of `expLBScaledTaylor` candidates, truncated by effort. -/ -def expLBCandidates (x : Rat) (effort : Nat) : Array Rat := - let limit := min effort expLBPortfolio.size - Array.ofFn fun (i : Fin limit) => - let pair := expLBPortfolio[i.val]! 
- expLBScaledTaylor x pair.2 pair.1 - -theorem expLBCandidates_def (x : Rat) (effort : Nat) : - expLBCandidates x effort = - let limit := min effort expLBPortfolio.size - Array.ofFn fun (i : Fin limit) => - let pair := expLBPortfolio[i.val]! - expLBScaledTaylor x pair.2 pair.1 := rfl - -/-- Portfolio lower bound on `exp`, with a baseline `1 + x` candidate. -/ -def expLB (x : Rat) (effort : Nat) : Rat := - let base : Rat := max 0 ((1 : Rat) + x) - lbBest base (expLBCandidates x effort) - -theorem expLB_def (x : Rat) (effort : Nat) : - expLB x effort = - let base : Rat := max 0 ((1 : Rat) + x) - lbBest base (expLBCandidates x effort) := rfl - -/-- `expLB` never undercuts its baseline `1 + x` lower bound. -/ -theorem expLB_ge_base (x : Rat) (effort : Nat) : - max 0 ((1 : Rat) + x) ≤ expLB x effort := by - dsimp [expLB] - exact lbBest_ge_base (base := max 0 ((1 : Rat) + x)) (cands := expLBCandidates x effort) - -/-- Scaling exponent so `x / 2^s ≤ 1/2` for `x ≥ 0`. -/ -private def expUBScalePow (x : Rat) : Nat := - let half : Rat := (1 : Rat) / 2 - if x ≤ half then - 0 - else - Id.run do - let mut s : Nat := 0 - let mut y : Rat := x - while y > half do - s := s + 1 - y := y / (2 : Rat) - return s - -theorem expUBScalePow_def (x : Rat) : - expUBScalePow x = - let half : Rat := (1 : Rat) / 2 - if x ≤ half then - 0 - else - Id.run do - let mut s : Nat := 0 - let mut y : Rat := x - while y > half do - s := s + 1 - y := y / (2 : Rat) - return s := rfl - -/-! -### Exp upper bounds (geometric series + squaring) --/ - -/-- Upper bound on `exp(x)` for `x ≥ 0` using `exp(z) ≤ 1/(1-z)` with scaling. 
-/ -def expUBScaledGeom (x : Rat) : Rat := - if x ≤ 0 then - 1 - else - let scalePow := expUBScalePow x - let scale : Rat := (Nat.pow 2 scalePow : Nat) - let z := x / scale - let denom := (1 : Rat) - z - if denom ≤ 0 then - 0 - else - let base := (1 : Rat) / denom - ratPow base (Nat.pow 2 scalePow) - -theorem expUBScaledGeom_def (x : Rat) : - expUBScaledGeom x = - if x ≤ 0 then - 1 - else - let scalePow := expUBScalePow x - let scale : Rat := (Nat.pow 2 scalePow : Nat) - let z := x / scale - let denom := (1 : Rat) - z - if denom ≤ 0 then - 0 - else - let base := (1 : Rat) / denom - ratPow base (Nat.pow 2 scalePow) := rfl - -/-- Default effort used for margin-derived softmax bounds. -/ -def defaultSoftmaxExpEffort : Nat := 1 - -theorem defaultSoftmaxExpEffort_def : defaultSoftmaxExpEffort = 1 := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds/Gelu.lean b/Legacy/Nfp/Sound/Bounds/Gelu.lean deleted file mode 100644 index 8bf6118..0000000 --- a/Legacy/Nfp/Sound/Bounds/Gelu.lean +++ /dev/null @@ -1,19 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Nfp.Sound.Activation - -namespace Nfp.Sound - -/-! -# GeLU derivative bounds --/ - -/-- Global conservative GeLU derivative bound (independent of interval). -/ -def geluDerivBoundGlobal : GeluDerivTarget → Rat - | .tanh => 2 - | .exact => 2 - -theorem geluDerivBoundGlobal_def (t : GeluDerivTarget) : - geluDerivBoundGlobal t = match t with | .tanh => 2 | .exact => 2 := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds/LayerNorm.lean b/Legacy/Nfp/Sound/Bounds/LayerNorm.lean deleted file mode 100644 index eb9d082..0000000 --- a/Legacy/Nfp/Sound/Bounds/LayerNorm.lean +++ /dev/null @@ -1,164 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Algebra.Order.Floor.Semiring -import Mathlib.Data.Nat.Sqrt -import Mathlib.Data.Rat.Floor - -namespace Nfp.Sound - -/-! -# LayerNorm operator-norm bounds --/ - -/-! 
### Local (input-dependent) LayerNorm bounds - -We want a sound upper bound on `max |γ| / sqrt(var + eps)` using exact `Rat` arithmetic, -*without* importing real `sqrt`. - -Given a proven lower bound `L ≤ var`, we have: -`1/sqrt(var+eps) ≤ 1/sqrt(L+eps)`. - -To avoid `sqrt`, we compute a dyadic rational `s = k/2^p` such that -`s^2 ≤ max (L+eps) 0`. Then `1/s ≥ 1/sqrt(L+eps)` is a valid **upper** bound on `1/sqrt(L+eps)`. --/ - -private def pow2 (p : Nat) : Nat := - Nat.pow 2 p - -theorem pow2_def (p : Nat) : pow2 p = Nat.pow 2 p := rfl - -private def sqNat (n : Nat) : Nat := n * n - -theorem sqNat_def (n : Nat) : sqNat n = n * n := rfl - -/-- Certificate that `k` is the dyadic floor of `sqrt (max x 0)` at precision `precBits`. -/ -private structure SqrtLowerDyadicCert (x : Rat) (precBits : Nat) where - k : Nat - lower : - ((sqNat k : Nat) : Rat) ≤ max x 0 * (sqNat (pow2 precBits) : Nat) - upper : - max x 0 * (sqNat (pow2 precBits) : Nat) < ((sqNat (k + 1) : Nat) : Rat) - -/-- The dyadic value `k/2^precBits` encoded by a `SqrtLowerDyadicCert`. -/ -private def SqrtLowerDyadicCert.rat {x : Rat} {precBits : Nat} - (c : SqrtLowerDyadicCert x precBits) : Rat := - Rat.normalize (Int.ofNat c.k) (pow2 precBits) (den_nz := by simp [pow2]) - -theorem SqrtLowerDyadicCert.rat_def {x : Rat} {precBits : Nat} - (c : SqrtLowerDyadicCert x precBits) : - SqrtLowerDyadicCert.rat c = - Rat.normalize (Int.ofNat c.k) (pow2 precBits) (den_nz := by simp [pow2]) := rfl - -/-- Compute a dyadic floor certificate for `sqrt (max x 0)` using `Nat.sqrt` on the floor. 
-/ -private def sqrtLowerDyadic (x : Rat) (precBits : Nat) : SqrtLowerDyadicCert x precBits := by - let scale : Nat := pow2 precBits - let scaleSq : Nat := sqNat scale - let y : Rat := max x 0 * (scaleSq : Rat) - let m : Nat := ⌊y⌋₊ - let k : Nat := Nat.sqrt m - refine ⟨k, ?lower, ?upper⟩ - · have hy_nonneg : 0 ≤ y := by - have hmax : 0 ≤ max x 0 := le_max_right _ _ - have hscale : 0 ≤ (scaleSq : Rat) := by - exact_mod_cast (Nat.zero_le scaleSq) - exact mul_nonneg hmax hscale - have hm_le : ((m : Nat) : Rat) ≤ y := by - simpa [m] using (Nat.floor_le (a := y) hy_nonneg) - have hk_le_m : sqNat k ≤ m := by - simpa [sqNat, k] using (Nat.sqrt_le m) - have hk_le_m_rat : ((sqNat k : Nat) : Rat) ≤ (m : Rat) := by - exact_mod_cast hk_le_m - exact le_trans hk_le_m_rat hm_le - · have hy_lt : y < (m : Rat) + 1 := by - simpa [m] using (Nat.lt_floor_add_one (a := y)) - have hm_lt_nat : m < sqNat (k + 1) := by - simpa [sqNat, k, Nat.succ_eq_add_one] using (Nat.lt_succ_sqrt m) - have hm_succ_le_nat : m + 1 ≤ sqNat (k + 1) := Nat.succ_le_of_lt hm_lt_nat - have hm_succ_le_rat : (m + 1 : Rat) ≤ ((sqNat (k + 1) : Nat) : Rat) := by - exact_mod_cast hm_succ_le_nat - exact lt_of_lt_of_le hy_lt hm_succ_le_rat - -theorem sqrtLowerDyadic_spec (x : Rat) (precBits : Nat) : - sqrtLowerDyadic x precBits = sqrtLowerDyadic x precBits := rfl - -/-- Dyadic lower bound on `sqrt (max x 0)` as a `Rat`. -/ -private def sqrtLowerDyadicRat (x : Rat) (precBits : Nat) : Rat := - (sqrtLowerDyadic x precBits).rat - -theorem sqrtLowerDyadicRat_def (x : Rat) (precBits : Nat) : - sqrtLowerDyadicRat x precBits = (sqrtLowerDyadic x precBits).rat := rfl - -/-- Conservative bound for the operator norm of a row-wise LayerNorm Jacobian. - -In exact real arithmetic one can show `‖J‖₂ ≤ max |γ| / σ` with `σ = sqrt(var + eps)`. -For sound certification without real `sqrt`, we compute a dyadic lower bound `s ≤ sqrt(eps)` -and use `maxAbsGamma / s`, which is a **valid upper bound** on `maxAbsGamma / sqrt(eps)`. 
- -When `eps ≤ 1`, we may also use `maxAbsGamma / eps`, and take the minimum of the two -sound bounds for a tighter result. For `eps > 1`, `maxAbsGamma / eps` is **not** sound, -so we only use the dyadic bound. - -For tighter **local** certification (weights + a bounded input region), use -`layerNormOpBoundLocal`, which replaces `eps` with a proven variance lower bound. --/ -def layerNormOpBoundConservative (maxAbsGamma eps : Rat) (sqrtPrecBits : Nat) : Rat := - if eps ≤ 0 then - 0 - else - let raw := maxAbsGamma / eps - let s := sqrtLowerDyadicRat eps sqrtPrecBits - if s ≤ 0 then - if eps ≤ 1 then raw else maxAbsGamma - else - let sBound := maxAbsGamma / s - if eps ≤ 1 then min raw sBound else sBound - -theorem layerNormOpBoundConservative_def (maxAbsGamma eps : Rat) (sqrtPrecBits : Nat) : - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits = - if eps ≤ 0 then - 0 - else - let raw := maxAbsGamma / eps - let s := sqrtLowerDyadicRat eps sqrtPrecBits - if s ≤ 0 then - if eps ≤ 1 then raw else maxAbsGamma - else - let sBound := maxAbsGamma / s - if eps ≤ 1 then min raw sBound else sBound := rfl - -/-- Local upper bound on the operator norm of a row-wise LayerNorm Jacobian. - -If `varianceLowerBound` is a proven lower bound on the per-row variance, then: -`‖J‖₂ ≤ maxAbsGamma / sqrt(varianceLowerBound + eps)`. - -We compute an upper bound using a dyadic lower bound on `sqrt(varianceLowerBound + eps)`. -If the dyadic lower bound is zero (too small / insufficient precision), we fall back to the -conservative bound `maxAbsGamma / eps`. 
--/ -def layerNormOpBoundLocal (maxAbsGamma varianceLowerBound eps : Rat) - (sqrtPrecBits : Nat) : Rat := - let denom := varianceLowerBound + eps - if denom ≤ 0 then - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits - else - let s := sqrtLowerDyadicRat denom sqrtPrecBits - if s ≤ 0 then - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits - else - maxAbsGamma / s - -theorem layerNormOpBoundLocal_def (maxAbsGamma varianceLowerBound eps : Rat) - (sqrtPrecBits : Nat) : - layerNormOpBoundLocal maxAbsGamma varianceLowerBound eps sqrtPrecBits = - let denom := varianceLowerBound + eps - if denom ≤ 0 then - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits - else - let s := sqrtLowerDyadicRat denom sqrtPrecBits - if s ≤ 0 then - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits - else - maxAbsGamma / s := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds/MatrixNorm.lean b/Legacy/Nfp/Sound/Bounds/MatrixNorm.lean deleted file mode 100644 index 4daa09f..0000000 --- a/Legacy/Nfp/Sound/Bounds/MatrixNorm.lean +++ /dev/null @@ -1,130 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Data.Finset.Lattice.Fold -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.Sound.Bounds.Basic - -namespace Nfp.Sound - -open scoped BigOperators - -/-! -# Matrix norm helpers (Rat, row-sum) - -Utilities for row-sum operator norm bounds on finite matrices. --/ - -/-- Streaming accumulator for `maxᵢ ∑ⱼ |aᵢⱼ|` over row-major entries. -/ -structure RowSumAcc where - rows : Nat - cols : Nat - colIdx : Nat := 0 - curRowSum : Rat := 0 - maxRowSum : Rat := 0 - -namespace RowSumAcc - -/-- Feed one entry from a row-major stream. 
-/ -def feed (acc : RowSumAcc) (x : Rat) : RowSumAcc := - let cur := acc.curRowSum + ratAbs x - let colIdx' := acc.colIdx + 1 - if acc.cols = 0 then - { acc with colIdx := colIdx', curRowSum := cur, maxRowSum := max acc.maxRowSum cur } - else if colIdx' = acc.cols then - { acc with - colIdx := 0 - curRowSum := 0 - maxRowSum := max acc.maxRowSum cur } - else - { acc with colIdx := colIdx', curRowSum := cur } - -theorem feed_def (acc : RowSumAcc) (x : Rat) : - RowSumAcc.feed acc x = - let cur := acc.curRowSum + ratAbs x - let colIdx' := acc.colIdx + 1 - if acc.cols = 0 then - { acc with colIdx := colIdx', curRowSum := cur, maxRowSum := max acc.maxRowSum cur } - else if colIdx' = acc.cols then - { acc with colIdx := 0, curRowSum := 0, maxRowSum := max acc.maxRowSum cur } - else - { acc with colIdx := colIdx', curRowSum := cur } := rfl - -/-- Finalize to a bound. (If the last row is partial, we still account for it.) -/ -def finish (acc : RowSumAcc) : Rat := - max acc.maxRowSum acc.curRowSum - -theorem finish_def (acc : RowSumAcc) : RowSumAcc.finish acc = max acc.maxRowSum acc.curRowSum := rfl - -end RowSumAcc - -/-- A rational-weighted matrix on finite types. -/ -structure RatMatrix (S T : Type*) [Fintype S] [Fintype T] where - w : S → T → Rat - -namespace RatMatrix - -variable {S T : Type*} [Fintype S] [Fintype T] - -/-- Row sum of absolute values in `Rat`. -/ -def rowAbsSum (M : RatMatrix S T) [DecidableEq T] (i : S) : Rat := - ∑ j, ratAbs (M.w i j) - -theorem rowAbsSum_def (M : RatMatrix S T) [DecidableEq T] (i : S) : - RatMatrix.rowAbsSum M i = ∑ j, ratAbs (M.w i j) := rfl - -/-- Row-sum operator norm bound in `Rat` (induced ℓ1 for row-vectors). 
-/ -def operatorNormBound (M : RatMatrix S T) [DecidableEq S] [DecidableEq T] [Nonempty S] : Rat := - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) fun i => - rowAbsSum M i - -theorem operatorNormBound_def (M : RatMatrix S T) [DecidableEq S] [DecidableEq T] [Nonempty S] : - RatMatrix.operatorNormBound M = - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) fun i => - rowAbsSum M i := rfl - -/-- Build a `RatMatrix` from row-major data with missing entries treated as 0. -/ -def ofRowMajor (rows cols : Nat) (data : Array Rat) : - RatMatrix (Fin rows) (Fin cols) := - ⟨fun i j => - let idx := i.val * cols + j.val - if h : idx < data.size then data[idx] else 0⟩ - -theorem ofRowMajor_def (rows cols : Nat) (data : Array Rat) : - RatMatrix.ofRowMajor rows cols data = - ⟨fun i j => - let idx := i.val * cols + j.val - if h : idx < data.size then data[idx] else 0⟩ := rfl - -end RatMatrix - -/-- Compute the row-sum norm `maxᵢ ∑ⱼ |M[i,j]|` from a row-major array. - -If the provided data has fewer than `rows*cols` entries, missing entries are treated as 0. -Extra entries are ignored. --/ -def matrixNormInfOfRowMajor (rows cols : Nat) (data : Array Rat) : Rat := - if h : rows = 0 then - 0 - else - let _ : Nonempty (Fin rows) := ⟨⟨0, Nat.pos_of_ne_zero h⟩⟩ - RatMatrix.operatorNormBound (RatMatrix.ofRowMajor rows cols data) - -theorem matrixNormInfOfRowMajor_def (rows cols : Nat) (data : Array Rat) : - matrixNormInfOfRowMajor rows cols data = - if h : rows = 0 then - 0 - else - let _ : Nonempty (Fin rows) := ⟨⟨0, Nat.pos_of_ne_zero h⟩⟩ - RatMatrix.operatorNormBound (RatMatrix.ofRowMajor rows cols data) := rfl - -/-- Row-sum operator norm bound for a product. - -`‖A·B‖∞ ≤ ‖A‖∞ · ‖B‖∞`. 
--/ -def normInfMulBound (a b : Rat) : Rat := a * b - -theorem normInfMulBound_def (a b : Rat) : normInfMulBound a b = a * b := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds/Portfolio.lean b/Legacy/Nfp/Sound/Bounds/Portfolio.lean deleted file mode 100644 index cbd3f4a..0000000 --- a/Legacy/Nfp/Sound/Bounds/Portfolio.lean +++ /dev/null @@ -1,50 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Init.Data.Array.Lemmas - -namespace Nfp.Sound - -/-! -# Portfolio bounds - -Combinators for selecting the best bound among sound candidates. --/ - -/-- Best upper bound among candidates (never worse than `base`). -/ -def ubBest (base : Rat) (cands : Array Rat) : Rat := - cands.foldl min base - -theorem ubBest_def (base : Rat) (cands : Array Rat) : - ubBest base cands = cands.foldl min base := rfl - -/-- `ubBest` never exceeds its baseline upper bound. -/ -theorem ubBest_le_base (base : Rat) (cands : Array Rat) : ubBest base cands ≤ base := by - classical - have hArray : cands.foldl min base ≤ base := by - refine Array.foldl_induction (as := cands) - (motive := fun _ acc => acc ≤ base) (init := base) (f := fun acc x => min acc x) ?h0 ?hf - · exact le_rfl - · intro i acc hacc - exact le_trans (min_le_left _ _) hacc - simpa [ubBest] using hArray - -/-- Best lower bound among candidates (never worse than `base`). -/ -def lbBest (base : Rat) (cands : Array Rat) : Rat := - cands.foldl max base - -theorem lbBest_def (base : Rat) (cands : Array Rat) : - lbBest base cands = cands.foldl max base := rfl - -/-- `lbBest` never undercuts its baseline lower bound. 
-/ -theorem lbBest_ge_base (base : Rat) (cands : Array Rat) : base ≤ lbBest base cands := by - classical - have hArray : base ≤ cands.foldl max base := by - refine Array.foldl_induction (as := cands) - (motive := fun _ acc => base ≤ acc) (init := base) (f := fun acc x => max acc x) ?h0 ?hf - · exact le_rfl - · intro i acc hacc - exact le_trans hacc (le_max_left _ _) - simpa [lbBest] using hArray - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bounds/Softmax.lean b/Legacy/Nfp/Sound/Bounds/Softmax.lean deleted file mode 100644 index 214ae86..0000000 --- a/Legacy/Nfp/Sound/Bounds/Softmax.lean +++ /dev/null @@ -1,231 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Tactic.Linarith -import Nfp.Sound.Bounds.Exp -import Nfp.Sound.Bounds.Portfolio - -namespace Nfp.Sound - -/-! -# Softmax Jacobian bounds --/ - -/-- Worst-case bound on the row-sum operator norm of a softmax Jacobian row. - -For a probability row `p`, the softmax Jacobian is `J = diag(p) - p pᵀ`. -For row `i`, the absolute row-sum is: - -`∑ⱼ |Jᵢⱼ| = pᵢ(1-pᵢ) + ∑_{j≠i} pᵢ pⱼ = 2 pᵢ (1-pᵢ) ≤ 1/2`. - -This bound is universal (independent of sequence length). --/ -def softmaxJacobianNormInfWorst : Rat := (1 : Rat) / 2 - -theorem softmaxJacobianNormInfWorst_def : softmaxJacobianNormInfWorst = (1 : Rat) / 2 := rfl - -/-- Clamp a rational to the unit interval `[0,1]`. 
-/ -private def clamp01 (x : Rat) : Rat := - max 0 (min x 1) - -theorem clamp01_def (x : Rat) : clamp01 x = max 0 (min x 1) := rfl - -private theorem clamp01_nonneg (x : Rat) : 0 ≤ clamp01 x := by - dsimp [clamp01] - exact le_max_left _ _ - -private theorem clamp01_le_one (x : Rat) : clamp01 x ≤ 1 := by - have h0 : (0 : Rat) ≤ 1 := by - decide - have hmin : min x 1 ≤ (1 : Rat) := by - exact min_le_right _ _ - have hmax : max 0 (min x 1) ≤ (1 : Rat) := by - exact max_le_iff.mpr ⟨h0, hmin⟩ - dsimp [clamp01] - exact hmax - -/-- Local upper bound on the row-sum softmax Jacobian norm given `p ∈ [pLo, pHi]`. -/ -def softmaxJacobianNormInfBound (pLo pHi : Rat) : Rat := - let lo0 := min pLo pHi - let hi0 := max pLo pHi - let lo := clamp01 lo0 - let hi := clamp01 hi0 - if hi < lo then - 0 - else - let half : Rat := (1 : Rat) / 2 - let f : Rat → Rat := fun p => (2 : Rat) * p * (1 - p) - if lo ≤ half ∧ half ≤ hi then - half - else - max (f lo) (f hi) - -theorem softmaxJacobianNormInfBound_def (pLo pHi : Rat) : - softmaxJacobianNormInfBound pLo pHi = - let lo0 := min pLo pHi - let hi0 := max pLo pHi - let lo := clamp01 lo0 - let hi := clamp01 hi0 - if hi < lo then - 0 - else - let half : Rat := (1 : Rat) / 2 - let f : Rat → Rat := fun p => (2 : Rat) * p * (1 - p) - if lo ≤ half ∧ half ≤ hi then - half - else - max (f lo) (f hi) := rfl - -/-! ### Margin-derived softmax bounds -/ - -/-- Probability interval from a uniform score bound `|s| ≤ B`. 
-/ -def softmaxProbIntervalFromScoreAbsBound (seqLen : Nat) (scoreAbsBound : Rat) - (expEffort : Nat) : Rat × Rat := - if seqLen = 0 then - (0, 1) - else if seqLen = 1 then - (1, 1) - else - let b := max 0 scoreAbsBound - let nRat : Rat := (seqLen : Nat) - let ePosUb := expUBScaledGeom b - let eNegLb := expLB (-b) expEffort - if eNegLb = 0 then - (0, 1) - else - let denomLo := eNegLb + (nRat - 1) * ePosUb - let denomHi := ePosUb + (nRat - 1) * eNegLb - (eNegLb / denomLo, ePosUb / denomHi) - -theorem softmaxProbIntervalFromScoreAbsBound_def (seqLen : Nat) (scoreAbsBound : Rat) - (expEffort : Nat) : - softmaxProbIntervalFromScoreAbsBound seqLen scoreAbsBound expEffort = - if seqLen = 0 then - (0, 1) - else if seqLen = 1 then - (1, 1) - else - let b := max 0 scoreAbsBound - let nRat : Rat := (seqLen : Nat) - let ePosUb := expUBScaledGeom b - let eNegLb := expLB (-b) expEffort - if eNegLb = 0 then - (0, 1) - else - let denomLo := eNegLb + (nRat - 1) * ePosUb - let denomHi := ePosUb + (nRat - 1) * eNegLb - (eNegLb / denomLo, ePosUb / denomHi) := rfl - -/-- Lower bound on the maximum softmax probability from a logit margin. - -Uses a portfolio `expLB` to lower bound `exp(m)` and maps it to -`p_max ≥ exp(m) / (exp(m) + (n-1))` for `m > 0`, with `n = seqLen`. -/ -def softmaxMaxProbLowerBound (seqLen : Nat) (margin : Rat) (expEffort : Nat) : Rat := - if seqLen = 0 then - 0 - else if margin > 0 then - let nRat : Rat := (seqLen : Nat) - let e := expLB margin expEffort - e / (e + (nRat - 1)) - else - 0 - -theorem softmaxMaxProbLowerBound_def (seqLen : Nat) (margin : Rat) (expEffort : Nat) : - softmaxMaxProbLowerBound seqLen margin expEffort = - if seqLen = 0 then - 0 - else if margin > 0 then - let nRat : Rat := (seqLen : Nat) - let e := expLB margin expEffort - e / (e + (nRat - 1)) - else - 0 := rfl - -/-- Lower bound on total target softmax weight from a logit margin. 
- -If at least `targetCount` logits exceed the rest by `margin`, then the total -target weight is at least `t*exp(m)/(t*exp(m)+(n-t))`. --/ -def softmaxTargetWeightLowerBound (seqLen targetCount : Nat) (margin : Rat) - (expEffort : Nat) : Rat := - if seqLen = 0 || targetCount = 0 then - 0 - else if margin > 0 then - let nRat : Rat := (seqLen : Nat) - let tRat : Rat := (targetCount : Nat) - let base := tRat / nRat - let e := expLB margin expEffort - let cand := (tRat * e) / (tRat * e + (nRat - tRat)) - lbBest base #[cand] - else - 0 - -theorem softmaxTargetWeightLowerBound_def (seqLen targetCount : Nat) (margin : Rat) - (expEffort : Nat) : - softmaxTargetWeightLowerBound seqLen targetCount margin expEffort = - if seqLen = 0 || targetCount = 0 then - 0 - else if margin > 0 then - let nRat : Rat := (seqLen : Nat) - let tRat : Rat := (targetCount : Nat) - let base := tRat / nRat - let e := expLB margin expEffort - let cand := (tRat * e) / (tRat * e + (nRat - tRat)) - lbBest base #[cand] - else - 0 := rfl - -/-- Upper bound on the row-sum softmax Jacobian norm from a max-probability lower bound. - -If the maximum probability is at least `pLo` and `pLo > 1/2`, then every row -satisfies `2 p (1-p) ≤ 2 pLo (1-pLo)`; otherwise the universal `1/2` bound applies. -/ -def softmaxJacobianNormInfBoundFromMaxProb (pLo : Rat) : Rat := - let half : Rat := (1 : Rat) / 2 - let p := clamp01 pLo - if p > half then - (2 : Rat) * p * (1 - p) - else - half - -theorem softmaxJacobianNormInfBoundFromMaxProb_def (pLo : Rat) : - softmaxJacobianNormInfBoundFromMaxProb pLo = - let half : Rat := (1 : Rat) / 2 - let p := clamp01 pLo - if p > half then - (2 : Rat) * p * (1 - p) - else - half := rfl - -/-- Margin-derived Jacobian bounds never exceed the worst-case `1/2`. 
-/ -theorem softmaxJacobianNormInfBoundFromMaxProb_le_worst (pLo : Rat) : - softmaxJacobianNormInfBoundFromMaxProb pLo ≤ softmaxJacobianNormInfWorst := by - have hp0 : 0 ≤ clamp01 pLo := clamp01_nonneg pLo - have hp1 : clamp01 pLo ≤ 1 := clamp01_le_one pLo - by_cases h : (2 : Rat)⁻¹ < clamp01 pLo - · have hbound : - (2 : Rat) * clamp01 pLo * (1 - clamp01 pLo) ≤ (2 : Rat)⁻¹ := by - nlinarith [hp0, hp1] - simpa [softmaxJacobianNormInfBoundFromMaxProb, softmaxJacobianNormInfWorst_def, h] - using hbound - · simp [softmaxJacobianNormInfBoundFromMaxProb, softmaxJacobianNormInfWorst_def, h] - -/-- Upper bound on the row-sum softmax Jacobian norm from a logit margin. -/ -def softmaxJacobianNormInfBoundFromMargin (seqLen : Nat) (margin : Rat) (expEffort : Nat) : Rat := - softmaxJacobianNormInfBoundFromMaxProb (softmaxMaxProbLowerBound seqLen margin expEffort) - -theorem softmaxJacobianNormInfBoundFromMargin_def (seqLen : Nat) (margin : Rat) - (expEffort : Nat) : - softmaxJacobianNormInfBoundFromMargin seqLen margin expEffort = - softmaxJacobianNormInfBoundFromMaxProb - (softmaxMaxProbLowerBound seqLen margin expEffort) := rfl - -/-- Margin-derived Jacobian bound never exceeds the worst-case `1/2`. 
-/ -theorem softmaxJacobianNormInfBoundFromMargin_le_worst (seqLen : Nat) (margin : Rat) - (expEffort : Nat) : - softmaxJacobianNormInfBoundFromMargin seqLen margin expEffort ≤ - softmaxJacobianNormInfWorst := by - simpa [softmaxJacobianNormInfBoundFromMargin_def] using - softmaxJacobianNormInfBoundFromMaxProb_le_worst - (pLo := softmaxMaxProbLowerBound seqLen margin expEffort) - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Bridge.lean b/Legacy/Nfp/Sound/Bridge.lean deleted file mode 100644 index 2c8c470..0000000 --- a/Legacy/Nfp/Sound/Bridge.lean +++ /dev/null @@ -1,759 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Rat.BigOperators -import Mathlib.Data.Rat.Cast.Order -import Mathlib.Data.Finset.Lattice.Fold -import Nfp.Linearization -import Nfp.SignedMixer -import Nfp.Sound.Bounds -import Nfp.Sound.Cert - -namespace Nfp.Sound - -open scoped BigOperators - -namespace RatMatrix - -variable {S T : Type*} [Fintype S] [Fintype T] [DecidableEq S] [DecidableEq T] - -/-- Cast a rational matrix to a real SignedMixer. -/ -noncomputable def toSignedMixer (M : RatMatrix S T) : SignedMixer S T := - ⟨fun i j => (M.w i j : ℝ)⟩ - -omit [DecidableEq S] [DecidableEq T] in -theorem toSignedMixer_spec (M : RatMatrix S T) : toSignedMixer M = toSignedMixer M := rfl - -lemma ratAbs_eq_abs (x : Rat) : ratAbs x = |x| := by - by_cases h : x < 0 - · simp [ratAbs, h, abs_of_neg h] - · have h' : 0 ≤ x := le_of_not_gt h - simp [ratAbs, h, abs_of_nonneg h'] - -/-- Casting the rational bound matches the real row-sum operator norm bound. 
-/ -theorem operatorNormBound_cast (M : RatMatrix S T) [Nonempty S] : - (operatorNormBound M : ℝ) = SignedMixer.operatorNormBound (M.toSignedMixer) := by - classical - have hsup_cast : - (operatorNormBound M : ℝ) = - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) - (fun i => ((rowAbsSum M i : Rat) : ℝ)) := by - apply le_antisymm - · rcases Finset.exists_mem_eq_sup' - (s := Finset.univ) (H := Finset.univ_nonempty (α := S)) - (f := fun i => rowAbsSum M i) with ⟨i, hi, hsup⟩ - have hcast : (operatorNormBound M : ℝ) = (rowAbsSum M i : ℝ) := by - simp [operatorNormBound, hsup] - calc - (operatorNormBound M : ℝ) = (rowAbsSum M i : ℝ) := hcast - _ ≤ Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) - (fun j => ((rowAbsSum M j : Rat) : ℝ)) := by - exact Finset.le_sup' (s := Finset.univ) - (f := fun j => ((rowAbsSum M j : Rat) : ℝ)) hi - · refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := S)) - (f := fun i => ((rowAbsSum M i : Rat) : ℝ)) - (a := (operatorNormBound M : ℝ))).2 ?_ - intro i hi - have hle : rowAbsSum M i ≤ operatorNormBound M := by - exact Finset.le_sup' (s := Finset.univ) (f := fun j => rowAbsSum M j) hi - exact (Rat.cast_le (K := ℝ)).2 hle - calc - (operatorNormBound M : ℝ) - = Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) - (fun i => ((rowAbsSum M i : Rat) : ℝ)) := hsup_cast - _ = SignedMixer.operatorNormBound (M.toSignedMixer) := by - simp [SignedMixer.operatorNormBound, SignedMixer.rowAbsSum, rowAbsSum, toSignedMixer, - ratAbs_eq_abs, Rat.cast_sum, Rat.cast_abs] - -/-- Casted row-major bound agrees with the `SignedMixer` operator norm bound. 
-/ -theorem matrixNormInfOfRowMajor_cast (rows cols : Nat) (data : Array Rat) - [Nonempty (Fin rows)] (h : rows ≠ 0) : - (matrixNormInfOfRowMajor rows cols data : ℝ) = - SignedMixer.operatorNormBound (RatMatrix.ofRowMajor rows cols data).toSignedMixer := by - classical - simpa [matrixNormInfOfRowMajor, h] using - (operatorNormBound_cast (M := RatMatrix.ofRowMajor rows cols data)) - -end RatMatrix - -/-! ## Certificate-to-Jacobian bridge -/ - -/-- Assumptions needed to link certificate fields to component operator-norm bounds. -/ -structure LayerComponentNormAssumptions - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) : Prop where - ln1Bound : - SignedMixer.operatorNormBound (D.ln1Jacobians i) ≤ (l.ln1Bound : ℝ) - ln1Bound_nonneg : 0 ≤ (l.ln1Bound : ℝ) - attnValueBound : - SignedMixer.operatorNormBound (valueTerm (D.layers i)) ≤ - (Fintype.card n : ℝ) * (l.attnValueCoeff : ℝ) - attnPatternBound : - SignedMixer.operatorNormBound (patternTerm (D.layers i)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) - ln2Bound : - SignedMixer.operatorNormBound (D.ln2Jacobians i) ≤ (l.ln2Bound : ℝ) - ln2Bound_nonneg : 0 ≤ (l.ln2Bound : ℝ) - mlpWinBound : - SignedMixer.operatorNormBound (D.mlpFactors i).win ≤ (l.mlpWinBound : ℝ) - mlpWinBound_nonneg : 0 ≤ (l.mlpWinBound : ℝ) - mlpWoutBound : - SignedMixer.operatorNormBound (D.mlpFactors i).wout ≤ (l.mlpWoutBound : ℝ) - mlpWoutBound_nonneg : 0 ≤ (l.mlpWoutBound : ℝ) - mlpDerivBound : - ∀ j, |(D.mlpFactors i).deriv j| ≤ (l.mlpActDerivBound : ℝ) - -/-- Pattern-term bound from the explicit formula and row-sum gradient/value bounds. 
-/ -theorem attn_pattern_bound_of_explicit - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (G V : ℝ) - (hGrad : ∀ (pos : n) (d_in : d) (q : n), - ∑ k, |attentionGradient (D.layers i) q k pos d_in| ≤ G) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer (D.layers i)) ≤ V) - (hEq : patternTerm (D.layers i) = patternTermExplicit (D.layers i)) : - SignedMixer.operatorNormBound (patternTerm (D.layers i)) ≤ - (Fintype.card n : ℝ) * G * V := by - simpa using - (patternTerm_operatorNormBound_le_of_eq_explicit - (L := D.layers i) (G := G) (V := V) hGrad hValue hEq) - -/-- Pattern-term bound from certificate validity and attention-state assumptions. -/ -theorem attn_pattern_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLenNat modelDimNat headDimNat : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLenNat modelDimNat headDimNat) - (hSeqLen : seqLenNat = Fintype.card n) - (hModelDim : modelDimNat = Fintype.card d) - (hScale : (1 / Real.sqrt (modelDim d)) ≤ (invSqrtUpperBound headDimNat : ℝ)) - (hInputBound : - ∀ pos d_in, |(D.layers i).state.input pos d_in| ≤ (l.ln1OutMaxAbsBound : ℝ)) - (hKeys : - ∀ pos d', (D.layers i).state.keys pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_K.w d_in d') - (hQueries : - ∀ pos d', (D.layers i).state.queries pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_Q.w d_in d') - (hValues : - ∀ pos d', (D.layers i).state.values pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_V.w d_in d') - (hWQ : - SignedMixer.operatorNormBound (D.layers i).layer.W_Q ≤ (l.wqOpBoundMax : ℝ)) - (hWK : - SignedMixer.operatorNormBound (D.layers 
i).layer.W_K ≤ (l.wkOpBoundMax : ℝ)) - (hVO : - SignedMixer.operatorNormBound - ((D.layers i).layer.W_V.comp (D.layers i).layer.W_O) ≤ - (l.attnValueCoeff : ℝ)) - (hConsistent : - ∀ q, (D.layers i).state.attentionWeights q = - softmax ((D.layers i).state.scores q)) - (hSoftmax : - ∀ q, SignedMixer.operatorNormBound - (softmaxJacobian ((D.layers i).state.scores q)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ)) - (hEq : patternTerm (D.layers i) = patternTermExplicit (D.layers i)) : - SignedMixer.operatorNormBound (patternTerm (D.layers i)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) := by - classical - let L := D.layers i - let B : ℝ := (modelDimNat : ℝ) * (l.ln1OutMaxAbsBound : ℝ) - let wq : ℝ := (l.wqOpBoundMax : ℝ) - let wk : ℝ := (l.wkOpBoundMax : ℝ) - let J : ℝ := (l.softmaxJacobianNormInfUpperBound : ℝ) - let S : ℝ := - (Fintype.card n : ℝ) * - ((1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk)) - let V : ℝ := B * (l.attnValueCoeff : ℝ) - have hSeqLen' : (Fintype.card n : ℝ) = (seqLenNat : ℝ) := by - exact_mod_cast hSeqLen.symm - have hModelDim' : (Fintype.card d : ℝ) = (modelDimNat : ℝ) := by - exact_mod_cast hModelDim.symm - have hB_nonneg : 0 ≤ B := by - have hln1_nonneg : 0 ≤ (l.ln1OutMaxAbsBound : ℝ) := by - rcases (inferInstance : Nonempty n) with ⟨pos⟩ - rcases (inferInstance : Nonempty d) with ⟨d_in⟩ - have h := hInputBound pos d_in - exact le_trans (abs_nonneg _) h - have hdim_nonneg : 0 ≤ (modelDimNat : ℝ) := by - exact_mod_cast (Nat.zero_le _) - exact mul_nonneg hdim_nonneg hln1_nonneg - have hWQ_nonneg : 0 ≤ wq := by - exact le_trans - (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_Q)) hWQ - have hWK_nonneg : 0 ≤ wk := by - exact le_trans - (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_K)) hWK - have hInput : - ∀ pos, ∑ d', |L.state.input pos d'| ≤ B := by - intro pos - have hsum : - ∑ d', |L.state.input pos d'| ≤ - ∑ d' : d, (l.ln1OutMaxAbsBound : ℝ) := by - refine Finset.sum_le_sum ?_ - intro d' _hd - 
exact hInputBound pos d' - have hconst : - (∑ d' : d, (l.ln1OutMaxAbsBound : ℝ)) = B := by - simp [B, hModelDim'] - exact le_trans hsum (by simp [hconst]) - have hScore : - ∀ i' d_in q, ∑ k, |scoreGradient L q k i' d_in| ≤ S := by - intro i' d_in q - have h := - scoreGradient_sum_le_of_input (L := L) (q := q) (i := i') (d_in := d_in) - (B := B) (wq := wq) (wk := wk) - hInput (hKeys := hKeys) (hQueries := hQueries) hWQ hWK - simpa [S, wq, wk] using h - have hValue : - SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V := by - have h := - valueOutputMixer_operatorNormBound_le_of_input (L := L) (B := B) - (V := (l.attnValueCoeff : ℝ)) hInput (hValues := hValues) hVO - simpa [V] using h - have hPattern := - patternTerm_operatorNormBound_le_of_softmax (L := L) (J := J) (S := S) (V := V) - (hConsistent := hConsistent) (hSoftmax := hSoftmax) - (hScore := hScore) (hValue := hValue) hEq - have hJ_nonneg : 0 ≤ J := by - rcases (inferInstance : Nonempty n) with ⟨q⟩ - exact le_trans - (SignedMixer.operatorNormBound_nonneg - (M := softmaxJacobian (L.state.scores q))) (hSoftmax q) - have hS_nonneg : 0 ≤ 2 * B * wq * wk := by - have h2 : 0 ≤ (2 : ℝ) := by norm_num - exact mul_nonneg (mul_nonneg (mul_nonneg h2 hB_nonneg) hWQ_nonneg) hWK_nonneg - let S_bound : ℝ := - (seqLenNat : ℝ) * ((invSqrtUpperBound headDimNat : ℝ) * (2 * B * wq * wk)) - have hS_le : S ≤ S_bound := by - have hSeq_nonneg : 0 ≤ (seqLenNat : ℝ) := by - exact_mod_cast (Nat.zero_le _) - calc - S = (seqLenNat : ℝ) * - ((1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk)) := by - simp [S, hSeqLen'] - _ ≤ (seqLenNat : ℝ) * - ((invSqrtUpperBound headDimNat : ℝ) * (2 * B * wq * wk)) := by - exact mul_le_mul_of_nonneg_left - (mul_le_mul_of_nonneg_right hScale hS_nonneg) hSeq_nonneg - _ = S_bound := rfl - have hValue_nonneg : 0 ≤ V := by - have hVal_nonneg : 0 ≤ (l.attnValueCoeff : ℝ) := by - exact le_trans - (SignedMixer.operatorNormBound_nonneg - (M := L.layer.W_V.comp L.layer.W_O)) hVO - exact mul_nonneg hB_nonneg 
hVal_nonneg - have hSV_le : - (Fintype.card n : ℝ) * S * V ≤ (seqLenNat : ℝ) * S_bound * V := by - have hSeq_nonneg : 0 ≤ (seqLenNat : ℝ) := by - exact_mod_cast (Nat.zero_le _) - have hS_le' : (seqLenNat : ℝ) * S ≤ (seqLenNat : ℝ) * S_bound := by - exact mul_le_mul_of_nonneg_left hS_le hSeq_nonneg - calc - (Fintype.card n : ℝ) * S * V = (seqLenNat : ℝ) * S * V := by - simp [hSeqLen'] - _ ≤ (seqLenNat : ℝ) * S_bound * V := by - exact mul_le_mul_of_nonneg_right hS_le' hValue_nonneg - have hPat : - (l.attnPatternCoeff : ℝ) = - (attnPatternCoeffBound seqLenNat modelDimNat headDimNat - l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax l.attnValueCoeff : ℝ) := by - rcases hValid with ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, - _hsoftmax, hpat, _hattn, _hmlpCoeff, _hmlp, _hC⟩ - exact congrArg (fun x : Rat => (x : ℝ)) hpat - have hCoeff_eq : - (seqLenNat : ℝ) * S_bound * V = - (attnPatternCoeffBound seqLenNat modelDimNat headDimNat - l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax l.attnValueCoeff : ℝ) := by - simp [S_bound, V, B, wq, wk, attnPatternCoeffBound_def, attnScoreGradBound_def, - Rat.cast_mul, Rat.cast_natCast, Rat.cast_ofNat, mul_assoc, mul_left_comm, - mul_comm, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] - have hCoeff : - (Fintype.card n : ℝ) * S * V ≤ (l.attnPatternCoeff : ℝ) := by - calc - (Fintype.card n : ℝ) * S * V ≤ (seqLenNat : ℝ) * S_bound * V := hSV_le - _ = (attnPatternCoeffBound seqLenNat modelDimNat headDimNat - l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax l.attnValueCoeff : ℝ) := hCoeff_eq - _ = (l.attnPatternCoeff : ℝ) := by simp [hPat] - have hCoeffMul' : - J * ((Fintype.card n : ℝ) * S * V) ≤ J * (l.attnPatternCoeff : ℝ) := - mul_le_mul_of_nonneg_left hCoeff hJ_nonneg - have hCoeffMul : - (Fintype.card n : ℝ) * J * S * V ≤ J * (l.attnPatternCoeff : ℝ) := by - have hRearrange : - (Fintype.card n : ℝ) * J * S * V = J * ((Fintype.card n : ℝ) * S * V) := by - ring - simpa [hRearrange] using hCoeffMul' - exact le_trans hPattern hCoeffMul 
- -/-- Full attention Jacobian bound from certificate validity and attention-state assumptions. -/ -theorem attn_fullJacobian_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLenNat modelDimNat headDimNat : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLenNat modelDimNat headDimNat) - (hSeqLen : seqLenNat = Fintype.card n) - (hModelDim : modelDimNat = Fintype.card d) - (hScale : (1 / Real.sqrt (modelDim d)) ≤ (invSqrtUpperBound headDimNat : ℝ)) - (hInputBound : - ∀ pos d_in, |(D.layers i).state.input pos d_in| ≤ (l.ln1OutMaxAbsBound : ℝ)) - (hKeys : - ∀ pos d', (D.layers i).state.keys pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_K.w d_in d') - (hQueries : - ∀ pos d', (D.layers i).state.queries pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_Q.w d_in d') - (hValues : - ∀ pos d', (D.layers i).state.values pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_V.w d_in d') - (hWQ : - SignedMixer.operatorNormBound (D.layers i).layer.W_Q ≤ (l.wqOpBoundMax : ℝ)) - (hWK : - SignedMixer.operatorNormBound (D.layers i).layer.W_K ≤ (l.wkOpBoundMax : ℝ)) - (hVO : - SignedMixer.operatorNormBound - ((D.layers i).layer.W_V.comp (D.layers i).layer.W_O) ≤ - (l.attnValueCoeff : ℝ)) - (hConsistent : - ∀ q, (D.layers i).state.attentionWeights q = - softmax ((D.layers i).state.scores q)) - (hSoftmax : - ∀ q, SignedMixer.operatorNormBound - (softmaxJacobian ((D.layers i).state.scores q)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ)) - (hEq : patternTerm (D.layers i) = patternTermExplicit (D.layers i)) : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (seqLenNat : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) := by - classical - let L := 
D.layers i - have hPattern := - attn_pattern_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLenNat := seqLenNat) (modelDimNat := modelDimNat) (headDimNat := headDimNat) - hValid hSeqLen hModelDim hScale hInputBound hKeys hQueries hValues hWQ hWK hVO - hConsistent hSoftmax hEq - have hValueBase : - SignedMixer.operatorNormBound (valueTerm L) ≤ - (Fintype.card n : ℝ) * (l.attnValueCoeff : ℝ) := by - simpa using - (valueTerm_operatorNormBound_le_card (L := L) (B := (l.attnValueCoeff : ℝ)) - (hConsistent := hConsistent) (hVO := hVO)) - have hSeqLen' : (Fintype.card n : ℝ) = (seqLenNat : ℝ) := by - exact_mod_cast hSeqLen.symm - have hValue : - SignedMixer.operatorNormBound (valueTerm L) ≤ - (seqLenNat : ℝ) * (l.attnValueCoeff : ℝ) := by - simpa [hSeqLen'] using hValueBase - simpa using - (attention_fullJacobian_bound_of_terms (L := L) (hValue := hValue) (hPattern := hPattern)) - -/-- MLP coefficient bound from certificate validity and component bounds. 
-/ -theorem mlp_coeff_bound_of_valid - {n d : Type*} [Fintype n] [Fintype d] [Nonempty n] [Nonempty d] - (F : MLPFactorization (n := n) (d := d)) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hWin : SignedMixer.operatorNormBound F.win ≤ (l.mlpWinBound : ℝ)) - (hWin_nonneg : 0 ≤ (l.mlpWinBound : ℝ)) - (hWout : SignedMixer.operatorNormBound F.wout ≤ (l.mlpWoutBound : ℝ)) : - SignedMixer.operatorNormBound F.win * - SignedMixer.operatorNormBound F.wout ≤ (l.mlpCoeff : ℝ) := by - have hWout_nonneg : - 0 ≤ SignedMixer.operatorNormBound F.wout := - SignedMixer.operatorNormBound_nonneg (M := F.wout) - have hMul1 : - SignedMixer.operatorNormBound F.win * SignedMixer.operatorNormBound F.wout ≤ - (l.mlpWinBound : ℝ) * SignedMixer.operatorNormBound F.wout := by - exact mul_le_mul_of_nonneg_right hWin hWout_nonneg - have hMul2 : - (l.mlpWinBound : ℝ) * SignedMixer.operatorNormBound F.wout ≤ - (l.mlpWinBound : ℝ) * (l.mlpWoutBound : ℝ) := by - exact mul_le_mul_of_nonneg_left hWout hWin_nonneg - have hMul : - SignedMixer.operatorNormBound F.win * SignedMixer.operatorNormBound F.wout ≤ - (l.mlpWinBound : ℝ) * (l.mlpWoutBound : ℝ) := by - exact le_trans hMul1 hMul2 - have hCoeff := LayerAmplificationCert.mlpCoeff_eq_cast_of_valid - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) (l := l) hValid - simpa [hCoeff] using hMul - -/-- MLP operator-norm bound from a factored Jacobian plus coefficient bounds. 
-/ -theorem mlp_bound_of_factorization - {n d : Type*} [Fintype n] [Fintype d] [Nonempty n] [Nonempty d] - (F : MLPFactorization (n := n) (d := d)) - (l : LayerAmplificationCert) - (hDeriv : ∀ j, |F.deriv j| ≤ (l.mlpActDerivBound : ℝ)) - (hCoeff : - SignedMixer.operatorNormBound F.win * - SignedMixer.operatorNormBound F.wout ≤ (l.mlpCoeff : ℝ)) : - SignedMixer.operatorNormBound F.jacobian ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) := by - classical - let a := SignedMixer.operatorNormBound F.win - let b := SignedMixer.operatorNormBound F.wout - have hBound : - SignedMixer.operatorNormBound F.jacobian ≤ - a * (l.mlpActDerivBound : ℝ) * b := by - simpa using - (operatorNormBound_comp_diagMixer_comp_le - (A := F.win) (B := F.wout) (d := F.deriv) - (a := a) (c := (l.mlpActDerivBound : ℝ)) (b := b) - (hA := by simp [a]) (hB := by simp [b]) hDeriv) - have hAct_nonneg : 0 ≤ (l.mlpActDerivBound : ℝ) := by - rcases (inferInstance : Nonempty F.hidden) with ⟨j⟩ - have h := hDeriv j - exact le_trans (abs_nonneg _) h - have hMul : - a * (l.mlpActDerivBound : ℝ) * b ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) := by - have hMul' : - (a * b) * (l.mlpActDerivBound : ℝ) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) := by - exact mul_le_mul_of_nonneg_right hCoeff hAct_nonneg - have hRearrange : - a * (l.mlpActDerivBound : ℝ) * b = - (a * b) * (l.mlpActDerivBound : ℝ) := by - calc - a * (l.mlpActDerivBound : ℝ) * b = - a * b * (l.mlpActDerivBound : ℝ) := by - simpa [mul_assoc] using (mul_right_comm a (l.mlpActDerivBound : ℝ) b) - _ = (a * b) * (l.mlpActDerivBound : ℝ) := by simp [mul_assoc] - simpa [hRearrange] using hMul' - exact le_trans hBound hMul - -/-- Attention-component bound from certificate identities and component bounds. 
-/ -theorem attn_component_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hLn1 : - SignedMixer.operatorNormBound (D.ln1Jacobians i) ≤ (l.ln1Bound : ℝ)) - (hLn1_nonneg : 0 ≤ (l.ln1Bound : ℝ)) - (hFull : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ)) : - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) - ≤ (l.attnJacBound : ℝ) := by - classical - let attnTotal : ℝ := - (seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) - have hcomp : - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) ≤ - SignedMixer.operatorNormBound (D.ln1Jacobians i) * - SignedMixer.operatorNormBound (D.layers i).fullJacobian := by - simpa using - (SignedMixer.operatorNormBound_comp_le - (M := D.ln1Jacobians i) (N := (D.layers i).fullJacobian)) - have hFull_nonneg : - 0 ≤ SignedMixer.operatorNormBound (D.layers i).fullJacobian := - SignedMixer.operatorNormBound_nonneg (M := (D.layers i).fullJacobian) - have hmul1 : - SignedMixer.operatorNormBound (D.ln1Jacobians i) * - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (l.ln1Bound : ℝ) * SignedMixer.operatorNormBound (D.layers i).fullJacobian := by - exact mul_le_mul_of_nonneg_right hLn1 hFull_nonneg - have hmul2 : - (l.ln1Bound : ℝ) * SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (l.ln1Bound : ℝ) * attnTotal := by - simpa [attnTotal] using - (mul_le_mul_of_nonneg_left hFull hLn1_nonneg) - have hmul : - SignedMixer.operatorNormBound (D.ln1Jacobians i) * - 
SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (l.ln1Bound : ℝ) * attnTotal := by - exact le_trans hmul1 hmul2 - have hAttn : - (l.ln1Bound : ℝ) * attnTotal = - (l.attnJacBound : ℝ) := by - have hAttn := - LayerAmplificationCert.attnJacBound_eq_cast_of_valid - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) (l := l) hValid - simpa [attnTotal, mul_assoc, add_assoc] using hAttn.symm - calc - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) - ≤ SignedMixer.operatorNormBound (D.ln1Jacobians i) * - SignedMixer.operatorNormBound (D.layers i).fullJacobian := hcomp - _ ≤ (l.ln1Bound : ℝ) * attnTotal := hmul - _ = (l.attnJacBound : ℝ) := hAttn - -/-- MLP-component bound from certificate identities and component bounds. -/ -theorem mlp_component_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hLn2 : - SignedMixer.operatorNormBound (D.ln2Jacobians i) ≤ (l.ln2Bound : ℝ)) - (hLn2_nonneg : 0 ≤ (l.ln2Bound : ℝ)) - (hMlp : - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ)) : - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) - ≤ (l.mlpJacBound : ℝ) := by - classical - let mlpBound : ℝ := (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) - have hcomp : - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) ≤ - SignedMixer.operatorNormBound (D.ln2Jacobians i) * - SignedMixer.operatorNormBound (D.mlpJacobians i) := by - simpa using - (SignedMixer.operatorNormBound_comp_le - (M := D.ln2Jacobians i) (N := D.mlpJacobians i)) - have hMlp_nonneg : 0 ≤ SignedMixer.operatorNormBound (D.mlpJacobians 
i) := - SignedMixer.operatorNormBound_nonneg (M := D.mlpJacobians i) - have hmul1 : - SignedMixer.operatorNormBound (D.ln2Jacobians i) * - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.ln2Bound : ℝ) * SignedMixer.operatorNormBound (D.mlpJacobians i) := by - exact mul_le_mul_of_nonneg_right hLn2 hMlp_nonneg - have hmul2 : - (l.ln2Bound : ℝ) * SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.ln2Bound : ℝ) * mlpBound := by - simpa [mlpBound] using - (mul_le_mul_of_nonneg_left hMlp hLn2_nonneg) - have hmul : - SignedMixer.operatorNormBound (D.ln2Jacobians i) * - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.ln2Bound : ℝ) * mlpBound := by - exact le_trans hmul1 hmul2 - have hMlpEq : - (l.ln2Bound : ℝ) * mlpBound = (l.mlpJacBound : ℝ) := by - have hMlpEq := - LayerAmplificationCert.mlpJacBound_eq_cast_of_valid - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) (l := l) hValid - simpa [mlpBound, mul_assoc] using hMlpEq.symm - calc - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) - ≤ SignedMixer.operatorNormBound (D.ln2Jacobians i) * - SignedMixer.operatorNormBound (D.mlpJacobians i) := hcomp - _ ≤ (l.ln2Bound : ℝ) * mlpBound := hmul - _ = (l.mlpJacBound : ℝ) := hMlpEq - -/-- Layer Jacobian residual bound from a layer amplification certificate. 
-/ -theorem layerJacobian_residual_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hA : - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) - ≤ (l.attnJacBound : ℝ)) - (hM : - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) - ≤ (l.mlpJacBound : ℝ)) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ (l.C : ℝ) := by - have hC := LayerAmplificationCert.c_eq_cast_of_valid - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) (l := l) hValid - have hres := - DeepLinearization.layerJacobian_residual_bound - (D := D) (i := i) - (A := (l.attnJacBound : ℝ)) - (M := (l.mlpJacBound : ℝ)) hA hM - simpa [hC] using hres - -/-- Layer Jacobian residual bound from a certificate, with component-norm assumptions. 
-/ -theorem layerJacobian_residual_bound_of_cert_components - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hLn1 : - SignedMixer.operatorNormBound (D.ln1Jacobians i) ≤ (l.ln1Bound : ℝ)) - (hLn1_nonneg : 0 ≤ (l.ln1Bound : ℝ)) - (hFull : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ)) - (hLn2 : - SignedMixer.operatorNormBound (D.ln2Jacobians i) ≤ (l.ln2Bound : ℝ)) - (hLn2_nonneg : 0 ≤ (l.ln2Bound : ℝ)) - (hMlp : - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ)) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ (l.C : ℝ) := by - have hA := - attn_component_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) - hValid hLn1 hLn1_nonneg hFull - have hM := - mlp_component_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) - hValid hLn2 hLn2_nonneg hMlp - exact layerJacobian_residual_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) - hValid hA hM - -/-- Layer Jacobian residual bound from certificate validity and attention-state assumptions. 
-/ -theorem layerJacobian_residual_bound_of_cert_from_state - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLenNat modelDimNat headDimNat : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLenNat modelDimNat headDimNat) - (hSeqLen : seqLenNat = Fintype.card n) - (hModelDim : modelDimNat = Fintype.card d) - (hScale : (1 / Real.sqrt (modelDim d)) ≤ (invSqrtUpperBound headDimNat : ℝ)) - (hInputBound : - ∀ pos d_in, |(D.layers i).state.input pos d_in| ≤ (l.ln1OutMaxAbsBound : ℝ)) - (hKeys : - ∀ pos d', (D.layers i).state.keys pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_K.w d_in d') - (hQueries : - ∀ pos d', (D.layers i).state.queries pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_Q.w d_in d') - (hValues : - ∀ pos d', (D.layers i).state.values pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_V.w d_in d') - (hWQ : - SignedMixer.operatorNormBound (D.layers i).layer.W_Q ≤ (l.wqOpBoundMax : ℝ)) - (hWK : - SignedMixer.operatorNormBound (D.layers i).layer.W_K ≤ (l.wkOpBoundMax : ℝ)) - (hVO : - SignedMixer.operatorNormBound - ((D.layers i).layer.W_V.comp (D.layers i).layer.W_O) ≤ - (l.attnValueCoeff : ℝ)) - (hConsistent : - ∀ q, (D.layers i).state.attentionWeights q = - softmax ((D.layers i).state.scores q)) - (hSoftmax : - ∀ q, SignedMixer.operatorNormBound - (softmaxJacobian ((D.layers i).state.scores q)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ)) - (hEq : patternTerm (D.layers i) = patternTermExplicit (D.layers i)) - (hLn1 : - SignedMixer.operatorNormBound (D.ln1Jacobians i) ≤ (l.ln1Bound : ℝ)) - (hLn1_nonneg : 0 ≤ (l.ln1Bound : ℝ)) - (hLn2 : - SignedMixer.operatorNormBound (D.ln2Jacobians i) ≤ (l.ln2Bound : ℝ)) - (hLn2_nonneg : 0 ≤ (l.ln2Bound : ℝ)) - (hMlp : - SignedMixer.operatorNormBound 
(D.mlpJacobians i) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ)) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ (l.C : ℝ) := by - have hFull := - attn_fullJacobian_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLenNat := seqLenNat) (modelDimNat := modelDimNat) (headDimNat := headDimNat) - hValid hSeqLen hModelDim hScale hInputBound hKeys hQueries hValues hWQ hWK hVO - hConsistent hSoftmax hEq - have hA := - attn_component_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLenNat) (modelDim := modelDimNat) (headDim := headDimNat) - hValid hLn1 hLn1_nonneg hFull - have hM := - mlp_component_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLenNat) (modelDim := modelDimNat) (headDim := headDimNat) - hValid hLn2 hLn2_nonneg hMlp - exact layerJacobian_residual_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLenNat) (modelDim := modelDimNat) (headDim := headDimNat) - hValid hA hM - -/-- Layer Jacobian residual bound from a certificate plus component-norm assumptions. 
-/ -theorem layerJacobian_residual_bound_of_cert_assuming - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hSeqLen : seqLen = Fintype.card n) - (h : LayerComponentNormAssumptions (D := D) (i := i) (l := l)) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ (l.C : ℝ) := by - have hCoeff : - SignedMixer.operatorNormBound (D.mlpFactors i).win * - SignedMixer.operatorNormBound (D.mlpFactors i).wout ≤ (l.mlpCoeff : ℝ) := - mlp_coeff_bound_of_valid - (F := D.mlpFactors i) (l := l) - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) hValid - h.mlpWinBound h.mlpWinBound_nonneg h.mlpWoutBound - have hMlp : - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) := by - simpa using - (mlp_bound_of_factorization (F := D.mlpFactors i) (l := l) - h.mlpDerivBound hCoeff) - have hFull : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) := by - have hFullBase : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (Fintype.card n : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) := by - simpa using - (attention_fullJacobian_bound_of_terms (L := D.layers i) - (hValue := h.attnValueBound) (hPattern := h.attnPatternBound)) - have hSeqLen' : (Fintype.card n : ℝ) = (seqLen : ℝ) := by - exact_mod_cast hSeqLen.symm - simpa [hSeqLen'] using hFullBase - exact layerJacobian_residual_bound_of_cert_components - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) 
(modelDim := modelDim) (headDim := headDim) - hValid h.ln1Bound h.ln1Bound_nonneg hFull - h.ln2Bound h.ln2Bound_nonneg hMlp - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/CachePure.lean b/Legacy/Nfp/Sound/CachePure.lean deleted file mode 100644 index bfb260d..0000000 --- a/Legacy/Nfp/Sound/CachePure.lean +++ /dev/null @@ -1,1011 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Init.System.IO -import Init.Data.ByteArray.Lemmas -import Nfp.Sound.Decimal -import Nfp.Sound.Fixed -import Nfp.Sound.ModelHeader - -namespace Nfp.Sound - -/-! -# SOUND fixed-point cache (pure helpers) - -Pure parsing and encoding utilities for the SOUND cache format. -IO wrappers live in `Nfp.Untrusted.SoundCacheIO`. --/ - -namespace SoundCache - -def version : UInt32 := 1 -def magic : ByteArray := "NFP_SND_CACHE_V1\n".toUTF8 - -structure Header where - modelHash : UInt64 - modelSize : UInt64 - scalePow10 : UInt32 - numLayers : UInt32 - numHeads : UInt32 - modelDim : UInt32 - headDim : UInt32 - hiddenDim : UInt32 - deriving Repr - -private def u32le (x : UInt32) : ByteArray := - let b0 := (x &&& 0xFF).toUInt8 - let b1 := ((x >>> 8) &&& 0xFF).toUInt8 - let b2 := ((x >>> 16) &&& 0xFF).toUInt8 - let b3 := ((x >>> 24) &&& 0xFF).toUInt8 - ByteArray.mk #[b0, b1, b2, b3] - -private def u64le (x : UInt64) : ByteArray := - let b0 := (x &&& 0xFF).toUInt8 - let b1 := ((x >>> 8) &&& 0xFF).toUInt8 - let b2 := ((x >>> 16) &&& 0xFF).toUInt8 - let b3 := ((x >>> 24) &&& 0xFF).toUInt8 - let b4 := ((x >>> 32) &&& 0xFF).toUInt8 - let b5 := ((x >>> 40) &&& 0xFF).toUInt8 - let b6 := ((x >>> 48) &&& 0xFF).toUInt8 - let b7 := ((x >>> 56) &&& 0xFF).toUInt8 - ByteArray.mk #[b0, b1, b2, b3, b4, b5, b6, b7] - -private def i32le (x : Int) : ByteArray := - let ux : UInt32 := UInt32.ofInt x - u32le ux - -private def appendI32LE (buf : Array UInt8) (x : Int) : Array UInt8 := - Id.run do - let ux : UInt32 := UInt32.ofInt x - let mut out := buf - out := out.push (ux &&& 0xFF).toUInt8 - out := 
out.push ((ux >>> 8) &&& 0xFF).toUInt8 - out := out.push ((ux >>> 16) &&& 0xFF).toUInt8 - out := out.push ((ux >>> 24) &&& 0xFF).toUInt8 - return out - -@[inline] private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := - let b0 := (b.get! off).toUInt32 - let b1 := (b.get! (off + 1)).toUInt32 - let b2 := (b.get! (off + 2)).toUInt32 - let b3 := (b.get! (off + 3)).toUInt32 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) - -@[inline] private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := - let b0 := (b.get! off).toUInt64 - let b1 := (b.get! (off + 1)).toUInt64 - let b2 := (b.get! (off + 2)).toUInt64 - let b3 := (b.get! (off + 3)).toUInt64 - let b4 := (b.get! (off + 4)).toUInt64 - let b5 := (b.get! (off + 5)).toUInt64 - let b6 := (b.get! (off + 6)).toUInt64 - let b7 := (b.get! (off + 7)).toUInt64 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| - (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) - -private def twoPow32 : Int := Int.ofNat (Nat.pow 2 32) - -@[inline] def i32FromLE (b : ByteArray) (off : Nat) : Int := - let u := u32FromLE b off - let half : UInt32 := 0x80000000 - if u < half then - Int.ofNat u.toNat - else - (Int.ofNat u.toNat) - twoPow32 - -def encodeHeader (hdr : Header) : ByteArray := - magic - ++ u32le version - ++ u64le hdr.modelHash - ++ u64le hdr.modelSize - ++ u32le hdr.scalePow10 - ++ u32le hdr.numLayers - ++ u32le hdr.numHeads - ++ u32le hdr.modelDim - ++ u32le hdr.headDim - ++ u32le hdr.hiddenDim - -def headerBytes : Nat := - magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4 + 4 + 4 - -def decodeHeader (bytes : ByteArray) : Except String Header := do - if bytes.size < headerBytes then - throw "unexpected EOF while reading cache header" - let m := bytes.extract 0 magic.size - if m ≠ magic then - throw "invalid cache magic" - let off0 := magic.size - let v := u32FromLE bytes off0 - if v ≠ version then - throw s!"unsupported cache version {v}" - let off1 := off0 + 4 - let modelHash := u64FromLE bytes off1 - let off2 := off1 + 
8 - let modelSize := u64FromLE bytes off2 - let off3 := off2 + 8 - let scalePow10 := u32FromLE bytes off3 - let off4 := off3 + 4 - let numLayers := u32FromLE bytes off4 - let off5 := off4 + 4 - let numHeads := u32FromLE bytes off5 - let off6 := off5 + 4 - let modelDim := u32FromLE bytes off6 - let off7 := off6 + 4 - let headDim := u32FromLE bytes off7 - let off8 := off7 + 4 - let hiddenDim := u32FromLE bytes off8 - return { modelHash, modelSize, scalePow10, numLayers, numHeads, modelDim, headDim, hiddenDim } - -def cacheDir : System.FilePath := "sound_cache" - -def cachePath (modelPath : System.FilePath) (modelHash : UInt64) (scalePow10 : Nat) : - System.FilePath := - let stem := modelPath.fileStem - cacheDir / s!"{stem}_{modelHash.toNat}_p{scalePow10}.nfpc" - -def expectedI32Count (hdr : Header) : Nat := - let L := hdr.numLayers.toNat - let H := hdr.numHeads.toNat - let d := hdr.modelDim.toNat - let dh := hdr.headDim.toNat - let dhid := hdr.hiddenDim.toNat - let perHead := d * dh + dh + dh * d - let perLayer := - (4 * d) + (H * perHead) + d + (d * dhid) + dhid + (dhid * d) + d - L * perLayer - -def expectedCacheBytes (hdr : Header) : UInt64 := - UInt64.ofNat headerBytes + (UInt64.ofNat (expectedI32Count hdr) * (4 : UInt64)) - -def fnv1a64Init : UInt64 := 14695981039346656037 - -def fnv1a64Update (hash : UInt64) (chunk : ByteArray) : UInt64 := - Id.run do - let prime : UInt64 := 1099511628211 - let mut h := hash - for b in chunk.data do - h := (h ^^^ (UInt64.ofNat b.toNat)) * prime - return h - -def fnv1a64 (bytes : ByteArray) : UInt64 := - fnv1a64Update fnv1a64Init bytes - -private def findLineIdxFrom (lines : Array String) (start : Nat) (p : String → Bool) : - Option Nat := - Nfp.Sound.findLineIdxFrom lines start p - -private def skipUntil (lines : Array String) (start : Nat) (p : String → Bool) : Nat := - Nfp.Sound.skipUntil lines start p - -private def skipBlankLines (lines : Array String) (start : Nat) : Nat := - Nfp.Sound.skipBlankLines lines start - 
-@[inline] private def countWsTokens (s : String) : Nat := - Nfp.Sound.countWsTokens s - -private def skipTokensFast (lines : Array String) (start : Nat) (numTokens : Nat) : - Except String Nat := - Id.run do - let mut iLine := start - let mut remaining := numTokens - while remaining > 0 do - if iLine ≥ lines.size then - return .error "unexpected end of file while skipping tokens" - let line := lines[iLine]! - iLine := iLine + 1 - let c := countWsTokens line - if c = 0 then - pure () - else if c ≥ remaining then - remaining := 0 - else - remaining := remaining - c - return .ok iLine - -private def consumeFixedBytes - (scalePow10 : Nat) - (lines : Array String) - (start : Nat) - (count : Nat) : Except String (ByteArray × Nat) := - Id.run do - let mut iLine := start - let mut remaining := count - let mut buf : ByteArray := ByteArray.mk (Array.replicate (count * 4) 0) - let mut offBytes : Nat := 0 - while remaining > 0 do - if iLine ≥ lines.size then - return .error "unexpected end of file while reading fixed tokens" - let line := lines[iLine]!.trim - iLine := iLine + 1 - if line.isEmpty then - pure () - else - let bytes := line.toUTF8 - let mut j : Nat := 0 - while j < bytes.size && remaining > 0 do - while j < bytes.size && (bytes[j]! = 32 || bytes[j]! = 9) do - j := j + 1 - if j ≥ bytes.size then - break - let tokStart := j - while j < bytes.size && (bytes[j]! ≠ 32 && bytes[j]! ≠ 9) do - j := j + 1 - let tokStop := j - match parseFixed10Rounded scalePow10 bytes tokStart tokStop with - | .error e => return .error e - | .ok x => - let ux : UInt32 := UInt32.ofInt x - buf := buf.set! offBytes (ux &&& 0xFF).toUInt8 - buf := buf.set! (offBytes + 1) ((ux >>> 8) &&& 0xFF).toUInt8 - buf := buf.set! (offBytes + 2) ((ux >>> 16) &&& 0xFF).toUInt8 - buf := buf.set! 
(offBytes + 3) ((ux >>> 24) &&& 0xFF).toUInt8 - offBytes := offBytes + 4 - remaining := remaining - 1 - return .ok (buf, iLine) - -private def readHeaderFromLines (lines : Array String) : Except String (Header × Nat) := - Id.run do - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if i ≥ lines.size then - return .error "empty model file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - - let mut numLayers : Option UInt32 := none - let mut numHeads : Option UInt32 := none - let mut modelDim : Option UInt32 := none - let mut headDim : Option UInt32 := none - let mut hiddenDim : Option UInt32 := none - - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := i + 1 - break - match parseHeaderLine line with - | none => - i := i + 1 - | some (k, v) => - match k with - | "num_layers" => numLayers := (v.toNat?.map UInt32.ofNat) - | "num_heads" => numHeads := (v.toNat?.map UInt32.ofNat) - | "model_dim" => modelDim := (v.toNat?.map UInt32.ofNat) - | "head_dim" => headDim := (v.toNat?.map UInt32.ofNat) - | "hidden_dim" => hiddenDim := (v.toNat?.map UInt32.ofNat) - | _ => pure () - i := i + 1 - - let some L := numLayers | return .error "missing num_layers" - let some H := numHeads | return .error "missing num_heads" - let some d := modelDim | return .error "missing model_dim" - let some dh := headDim | return .error "missing head_dim" - let some dhid := hiddenDim | return .error "missing hidden_dim" - let hdr : Header := - { modelHash := 0, modelSize := 0, scalePow10 := 0, numLayers := L, numHeads := H - modelDim := d, headDim := dh, hiddenDim := dhid } - return .ok (hdr, i) - -private structure LNParamsFixed where - gamma : Array Int - beta : Array Int - -private instance : Inhabited LNParamsFixed := - ⟨{ gamma := #[], beta := #[] }⟩ - -private def collectLayerNormParamsFixed - (scalePow10 : Nat) (lines : Array 
String) (numLayers modelDim : Nat) : - Except String (Array LNParamsFixed × Array LNParamsFixed) := - Id.run do - let defP : LNParamsFixed := - { gamma := Array.replicate modelDim (0 : Int), beta := Array.replicate modelDim (0 : Int) } - let mut ln1 : Array LNParamsFixed := Array.replicate numLayers defP - let mut ln2 : Array LNParamsFixed := Array.replicate numLayers defP - let mut curLayer : Nat := 0 - let mut i : Nat := 0 - while i < lines.size do - let line := lines[i]!.trim - if line.startsWith "LAYER" then - let parts := line.splitOn " " |>.filter (· ≠ "") - if parts.length >= 2 then - curLayer := (parts[1]!).toNat? |>.getD curLayer - i := i + 1 - else if line = "LN1_GAMMA" then - match foldFixed10Tokens scalePow10 lines (i + 1) modelDim (Array.mkEmpty modelDim) - (fun a x => a.push x) with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < ln1.size then - let old := ln1[curLayer]! - ln1 := ln1.set! curLayer { old with gamma := xs } - i := next - else if line = "LN1_BETA" then - match foldFixed10Tokens scalePow10 lines (i + 1) modelDim (Array.mkEmpty modelDim) - (fun a x => a.push x) with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < ln1.size then - let old := ln1[curLayer]! - ln1 := ln1.set! curLayer { old with beta := xs } - i := next - else if line = "LN2_GAMMA" then - match foldFixed10Tokens scalePow10 lines (i + 1) modelDim (Array.mkEmpty modelDim) - (fun a x => a.push x) with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < ln2.size then - let old := ln2[curLayer]! - ln2 := ln2.set! curLayer { old with gamma := xs } - i := next - else if line = "LN2_BETA" then - match foldFixed10Tokens scalePow10 lines (i + 1) modelDim (Array.mkEmpty modelDim) - (fun a x => a.push x) with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < ln2.size then - let old := ln2[curLayer]! - ln2 := ln2.set! 
curLayer { old with beta := xs } - i := next - else - i := i + 1 - return .ok (ln1, ln2) - -private def encodeIntArray (xs : Array Int) : ByteArray := - Id.run do - let mut out : ByteArray := ByteArray.mk (Array.replicate (xs.size * 4) 0) - let mut off : Nat := 0 - for x in xs do - let ux : UInt32 := UInt32.ofInt x - out := out.set! off (ux &&& 0xFF).toUInt8 - out := out.set! (off + 1) ((ux >>> 8) &&& 0xFF).toUInt8 - out := out.set! (off + 2) ((ux >>> 16) &&& 0xFF).toUInt8 - out := out.set! (off + 3) ((ux >>> 24) &&& 0xFF).toUInt8 - off := off + 4 - return out - -private def repeatBytes (b : ByteArray) (n : Nat) : ByteArray := - Id.run do - if n = 0 || b.size = 0 then - return ByteArray.empty - let mut out : ByteArray := ByteArray.mk (Array.replicate (n * b.size) 0) - let mut off : Nat := 0 - for _ in [:n] do - out := b.copySlice 0 out off b.size - off := off + b.size - return out - -def buildCacheBytes - (lines : Array String) - (scalePow10 : Nat) - (modelHash modelSize : UInt64) : Except String ByteArray := - Id.run do - let hdr0E := readHeaderFromLines lines - let (hdr0, _afterHdr) ← - match hdr0E with - | .error e => return .error e - | .ok x => pure x - - let L : Nat := hdr0.numLayers.toNat - let H : Nat := hdr0.numHeads.toNat - let d : Nat := hdr0.modelDim.toNat - let dh : Nat := hdr0.headDim.toNat - let dhid : Nat := hdr0.hiddenDim.toNat - - let (ln1, ln2) ← - match collectLayerNormParamsFixed scalePow10 lines L d with - | .error e => return .error e - | .ok x => pure x - - let hdr : Header := - { hdr0 with - modelHash := modelHash - modelSize := modelSize - scalePow10 := UInt32.ofNat scalePow10 } - - let totalBytes : Nat := headerBytes + expectedI32Count hdr * 4 - let appendBytes := fun (out : Array UInt8) (bytes : ByteArray) => Id.run do - let mut out := out - for b in bytes.data do - out := out.push b - return out - - let mut out : Array UInt8 := Array.mkEmpty totalBytes - out := appendBytes out (encodeHeader hdr) - let mut pos : Nat := skipUntil lines 0 
(fun s => s.startsWith "LAYER") - let zeroBytes := i32le 0 - - for l in [:L] do - let p1 := ln1.getD l { gamma := Array.replicate d (0 : Int), beta := Array.replicate d 0 } - let p2 := ln2.getD l { gamma := Array.replicate d (0 : Int), beta := Array.replicate d 0 } - out := appendBytes out (encodeIntArray p1.gamma) - out := appendBytes out (encodeIntArray p1.beta) - out := appendBytes out (encodeIntArray p2.gamma) - out := appendBytes out (encodeIntArray p2.beta) - - pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - if pos ≥ lines.size then - return .error s!"unexpected EOF while scanning layer {l}" - pos := pos + 1 - - for _h in [:H] do - pos := skipBlankLines lines pos - if !(pos < lines.size && (lines[pos]!.trim.startsWith "HEAD")) then - return .error "expected HEAD" - pos := pos + 1 - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_Q") then - return .error "missing W_Q" - match skipTokensFast lines (pos + 1) (d * dh) with - | .error e => return .error e - | .ok next => pos := next - - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_Q" then - match skipTokensFast lines (pos + 1) dh with - | .error e => return .error e - | .ok next => pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_K") then - return .error "missing W_K" - match skipTokensFast lines (pos + 1) (d * dh) with - | .error e => return .error e - | .ok next => pos := next - - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_K" then - match skipTokensFast lines (pos + 1) dh with - | .error e => return .error e - | .ok next => pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_V") then - return .error "missing W_V" - match consumeFixedBytes scalePow10 lines (pos + 1) (d * dh) with - | .error e => return .error e - | .ok (bytes, next) => - out := appendBytes out bytes - pos := next - - pos := skipBlankLines 
lines pos - if pos < lines.size && lines[pos]!.trim = "b_V" then - match consumeFixedBytes scalePow10 lines (pos + 1) dh with - | .error e => return .error e - | .ok (bytes, next) => - out := appendBytes out bytes - pos := next - else - out := appendBytes out (repeatBytes zeroBytes dh) - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_O") then - return .error "missing W_O" - match consumeFixedBytes scalePow10 lines (pos + 1) (dh * d) with - | .error e => return .error e - | .ok (bytes, next) => - out := appendBytes out bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "ATTN_BIAS") then - return .error "missing ATTN_BIAS" - match consumeFixedBytes scalePow10 lines (pos + 1) d with - | .error e => return .error e - | .ok (bytes, next) => - out := appendBytes out bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "MLP") then - return .error "missing MLP" - pos := pos + 1 - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_in") then - return .error "missing W_in" - match consumeFixedBytes scalePow10 lines (pos + 1) (d * dhid) with - | .error e => return .error e - | .ok (bytes, next) => - out := appendBytes out bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "b_in") then - return .error "missing b_in" - match consumeFixedBytes scalePow10 lines (pos + 1) dhid with - | .error e => return .error e - | .ok (bytes, next) => - out := appendBytes out bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_out") then - return .error "missing W_out" - match consumeFixedBytes scalePow10 lines (pos + 1) (dhid * d) with - | .error e => return .error e - | .ok (bytes, next) => - out := appendBytes out bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "b_out") 
then - return .error "missing b_out" - match consumeFixedBytes scalePow10 lines (pos + 1) d with - | .error e => return .error e - | .ok (bytes, next) => - out := appendBytes out bytes - pos := next - - pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - - return .ok (ByteArray.mk out) - -private def isMaybeNumberStart (b : UInt8) : Bool := - b = 45 || b = 43 || b = 46 || (48 ≤ b && b ≤ 57) - -def checkTextTokenEnvelopeLines - (lines : Array String) - (scalePow10 : Nat := 9) - (maxTokens : Nat := 0) : Except String Unit := - Id.run do - let cfg : Fixed10Cfg := { scalePow10 := scalePow10 } - let S : Nat := cfg.scaleNat - let mut checked : Nat := 0 - let mut done : Bool := false - for line in lines do - if done then - break - let s := line.trim - if s.isEmpty then - pure () - else - let bytes := s.toUTF8 - let mut i : Nat := 0 - while i < bytes.size do - while i < bytes.size && (bytes[i]! = 32 || bytes[i]! = 9) do - i := i + 1 - if i ≥ bytes.size then - i := bytes.size - let tokStart := i - while i < bytes.size && (bytes[i]! ≠ 32 && bytes[i]! ≠ 9) do - i := i + 1 - let tokStop := i - if tokStart < tokStop && isMaybeNumberStart (bytes[tokStart]!) 
then - let tok := String.Pos.Raw.extract s ⟨tokStart⟩ ⟨tokStop⟩ - match parseRat tok with - | .error _ => - pure () - | .ok r => - match parseFixed10Rounded scalePow10 bytes tokStart tokStop with - | .error e => return .error e - | .ok w => - let lo : Rat := Rat.normalize (w - 1) S (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos)) - let hi : Rat := Rat.normalize (w + 1) S (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos)) - if lo ≤ r ∧ r ≤ hi then - checked := checked + 1 - else - return .error s!"token '{tok}' out of envelope: {lo} ≤ {r} ≤ {hi} failed" - if maxTokens ≠ 0 && checked ≥ maxTokens then - done := true - i := bytes.size - return .ok () - -structure I32Reader where - h : IO.FS.Handle - buf : ByteArray - pos : Nat - -def i32FromBuffer (buf : ByteArray) (pos : Nat) : Int := - i32FromLE buf pos - -/-! ### Derived properties -/ - -theorem u32le_size (x : UInt32) : (u32le x).size = 4 := by - rfl - -theorem u64le_size (x : UInt64) : (u64le x).size = 8 := by - rfl - -/-- `encodeHeader` has the exact byte length advertised by `headerBytes`. -/ -theorem encodeHeader_size (hdr : Header) : (encodeHeader hdr).size = headerBytes := by - simp [encodeHeader, headerBytes, ByteArray.size_append, u32le_size, u64le_size] - -/-- `encodeHeader` always begins with the cache magic prefix. -/ -theorem encodeHeader_magic_prefix (hdr : Header) : - (encodeHeader hdr).extract 0 magic.size = magic := by - simp [encodeHeader, ByteArray.append_assoc, ByteArray.extract_append_eq_left] - -/-- `get!` agrees with `getElem` when the index is in bounds. -/ -theorem get!_eq_getElem {b : ByteArray} {i : Nat} (h : i < b.size) : b.get! 
i = b[i]'h := by - cases b with - | mk bs => - have h' : i < bs.size := by simpa using h - simpa [ByteArray.get!, ByteArray.get] using (getElem!_pos (c := bs) (i := i) h') - -/-- `get!` on an appended array reduces to the left part when the index is in bounds. -/ -theorem get!_append_left {a b : ByteArray} {i : Nat} - (hi : i < (a ++ b).size) (hlt : i < a.size) : (a ++ b).get! i = a.get! i := by - calc - (a ++ b).get! i = (a ++ b)[i]'hi := get!_eq_getElem hi - _ = a[i]'hlt := by - simpa using - (ByteArray.getElem_append_left (i := i) (a := a) (b := b) (h := hi) hlt) - _ = a.get! i := by - symm - exact get!_eq_getElem hlt - -/-- `get!` on an appended array reduces to the right part when the index is in bounds. -/ -theorem get!_append_right {a b : ByteArray} {i : Nat} - (hi : i < (a ++ b).size) (hle : a.size ≤ i) : - (a ++ b).get! i = b.get! (i - a.size) := by - have h' : i - a.size < b.size := by - have hi' : i < a.size + b.size := by - simpa [ByteArray.size_append] using hi - exact (Nat.sub_lt_iff_lt_add hle).2 (by simpa [Nat.add_comm] using hi') - calc - (a ++ b).get! i = (a ++ b)[i]'hi := get!_eq_getElem hi - _ = b[i - a.size]'h' := by - simpa using - (ByteArray.getElem_append_right (i := i) (a := a) (b := b) (h := hi) hle) - _ = b.get! (i - a.size) := by - symm - exact get!_eq_getElem h' - -/-- `u32FromLE` is a left inverse of `u32le` at offset `0`. -/ -theorem u32FromLE_u32le (x : UInt32) : u32FromLE (u32le x) 0 = x := by - apply (UInt32.toBitVec_inj).1 - have h255 : (255 : UInt8) = -1 := by decide - simp [u32FromLE, u32le, ByteArray.get!, h255] - bv_decide - -/-- `u64FromLE` is a left inverse of `u64le` at offset `0`. -/ -theorem u64FromLE_u64le (x : UInt64) : u64FromLE (u64le x) 0 = x := by - apply (UInt64.toBitVec_inj).1 - have h255 : (255 : UInt8) = -1 := by decide - simp [u64FromLE, u64le, ByteArray.get!, h255] - bv_decide - -/-- `u32FromLE` depends only on the left prefix when it has enough bytes. 
-/ -theorem u32FromLE_append_left (a b : ByteArray) (h : 3 < a.size) : - u32FromLE (a ++ b) 0 = u32FromLE a 0 := by - have h0 : 0 < a.size := by omega - have h1 : 1 < a.size := by omega - have h2 : 2 < a.size := by omega - have h3 : 3 < a.size := h - have hi0 : 0 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h0 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi1 : 1 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h1 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi2 : 2 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h2 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi3 : 3 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h3 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - simp [u32FromLE, get!_append_left hi0 h0, get!_append_left hi1 h1, - get!_append_left hi2 h2, get!_append_left hi3 h3] - -/-- `u64FromLE` depends only on the left prefix when it has enough bytes. 
-/ -theorem u64FromLE_append_left (a b : ByteArray) (h : 7 < a.size) : - u64FromLE (a ++ b) 0 = u64FromLE a 0 := by - have h0 : 0 < a.size := by omega - have h1 : 1 < a.size := by omega - have h2 : 2 < a.size := by omega - have h3 : 3 < a.size := by omega - have h4 : 4 < a.size := by omega - have h5 : 5 < a.size := by omega - have h6 : 6 < a.size := by omega - have h7 : 7 < a.size := h - have hi0 : 0 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h0 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi1 : 1 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h1 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi2 : 2 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h2 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi3 : 3 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h3 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi4 : 4 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h4 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi5 : 5 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h5 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi6 : 6 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h6 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - have hi7 : 7 < (a ++ b).size := by - have := Nat.lt_of_lt_of_le h7 (Nat.le_add_right a.size b.size) - simpa [ByteArray.size_append] using this - simp [u64FromLE, get!_append_left hi0 h0, get!_append_left hi1 h1, - get!_append_left hi2 h2, get!_append_left hi3 h3, get!_append_left hi4 h4, - get!_append_left hi5 h5, get!_append_left hi6 h6, get!_append_left hi7 h7] - -/-- `u32FromLE` ignores a left prefix when reading from the right. 
-/ -theorem u32FromLE_append_right (a b : ByteArray) (off : Nat) (h : off + 3 < b.size) : - u32FromLE (a ++ b) (a.size + off) = u32FromLE b off := by - have h0' : off < b.size := by omega - have h1' : off + 1 < b.size := by omega - have h2' : off + 2 < b.size := by omega - have h3' : off + 3 < b.size := h - have h0 : a.size + off < (a ++ b).size := by - have := Nat.add_lt_add_left h0' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h1 : a.size + off + 1 < (a ++ b).size := by - have := Nat.add_lt_add_left h1' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h2 : a.size + off + 2 < (a ++ b).size := by - have := Nat.add_lt_add_left h2' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h3 : a.size + off + 3 < (a ++ b).size := by - have := Nat.add_lt_add_left h3' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have hle0 : a.size ≤ a.size + off := by omega - have hle1 : a.size ≤ a.size + off + 1 := by omega - have hle2 : a.size ≤ a.size + off + 2 := by omega - have hle3 : a.size ≤ a.size + off + 3 := by omega - unfold u32FromLE - simp [get!_append_right h0 hle0, get!_append_right h1 hle1, - get!_append_right h2 hle2, get!_append_right h3 hle3] - simp [Nat.add_assoc, Nat.add_sub_cancel_left] - -/-- `u64FromLE` ignores a left prefix when reading from the right. 
-/ -theorem u64FromLE_append_right (a b : ByteArray) (off : Nat) (h : off + 7 < b.size) : - u64FromLE (a ++ b) (a.size + off) = u64FromLE b off := by - have h0' : off < b.size := by omega - have h1' : off + 1 < b.size := by omega - have h2' : off + 2 < b.size := by omega - have h3' : off + 3 < b.size := by omega - have h4' : off + 4 < b.size := by omega - have h5' : off + 5 < b.size := by omega - have h6' : off + 6 < b.size := by omega - have h7' : off + 7 < b.size := h - have h0 : a.size + off < (a ++ b).size := by - have := Nat.add_lt_add_left h0' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h1 : a.size + off + 1 < (a ++ b).size := by - have := Nat.add_lt_add_left h1' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h2 : a.size + off + 2 < (a ++ b).size := by - have := Nat.add_lt_add_left h2' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h3 : a.size + off + 3 < (a ++ b).size := by - have := Nat.add_lt_add_left h3' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h4 : a.size + off + 4 < (a ++ b).size := by - have := Nat.add_lt_add_left h4' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h5 : a.size + off + 5 < (a ++ b).size := by - have := Nat.add_lt_add_left h5' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h6 : a.size + off + 6 < (a ++ b).size := by - have := Nat.add_lt_add_left h6' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have h7 : a.size + off + 7 < (a ++ b).size := by - have := Nat.add_lt_add_left h7' a.size - simpa [ByteArray.size_append, Nat.add_assoc] using this - have hle0 : a.size ≤ a.size + off := by omega - have hle1 : a.size ≤ a.size + off + 1 := by omega - have hle2 : a.size ≤ a.size + off + 2 := by omega - have hle3 : a.size ≤ a.size + off + 3 := by omega - have hle4 : a.size ≤ a.size + off + 4 := by omega - have hle5 : a.size ≤ a.size + off + 5 := by omega - have hle6 : a.size 
≤ a.size + off + 6 := by omega - have hle7 : a.size ≤ a.size + off + 7 := by omega - unfold u64FromLE - simp [get!_append_right h0 hle0, get!_append_right h1 hle1, - get!_append_right h2 hle2, get!_append_right h3 hle3, get!_append_right h4 hle4, - get!_append_right h5 hle5, get!_append_right h6 hle6, get!_append_right h7 hle7] - simp [Nat.add_assoc, Nat.add_sub_cancel_left] - -/-- `u32FromLE` round-trips a `u32le` prefix. -/ -theorem u32FromLE_u32le_append (x : UInt32) (b : ByteArray) : - u32FromLE (u32le x ++ b) 0 = x := by - have h : 3 < (u32le x).size := by - simp [u32le_size] - calc - u32FromLE (u32le x ++ b) 0 = u32FromLE (u32le x) 0 := - u32FromLE_append_left (a := u32le x) (b := b) h - _ = x := u32FromLE_u32le x - -/-- `u64FromLE` round-trips a `u64le` prefix. -/ -theorem u64FromLE_u64le_append (x : UInt64) (b : ByteArray) : - u64FromLE (u64le x ++ b) 0 = x := by - have h : 7 < (u64le x).size := by - simp [u64le_size] - calc - u64FromLE (u64le x ++ b) 0 = u64FromLE (u64le x) 0 := - u64FromLE_append_left (a := u64le x) (b := b) h - _ = x := u64FromLE_u64le x - -/-- `u32FromLE` round-trips a `u32le` block after a prefix. -/ -theorem u32FromLE_append_u32le (a : ByteArray) (x : UInt32) (b : ByteArray) : - u32FromLE (a ++ u32le x ++ b) a.size = x := by - calc - u32FromLE (a ++ u32le x ++ b) a.size = u32FromLE (u32le x ++ b) 0 := by - have h : 0 + 3 < (u32le x ++ b).size := by - simp [ByteArray.size_append, u32le_size] - omega - simpa [ByteArray.append_assoc] using - (u32FromLE_append_right (a := a) (b := u32le x ++ b) (off := 0) h) - _ = x := u32FromLE_u32le_append x b - -/-- `u64FromLE` round-trips a `u64le` block after a prefix. 
-/ -theorem u64FromLE_append_u64le (a : ByteArray) (x : UInt64) (b : ByteArray) : - u64FromLE (a ++ u64le x ++ b) a.size = x := by - calc - u64FromLE (a ++ u64le x ++ b) a.size = u64FromLE (u64le x ++ b) 0 := by - have h : 0 + 7 < (u64le x ++ b).size := by - simp [ByteArray.size_append, u64le_size] - omega - simpa [ByteArray.append_assoc] using - (u64FromLE_append_right (a := a) (b := u64le x ++ b) (off := 0) h) - _ = x := u64FromLE_u64le_append x b - -/-- `decodeHeader` recovers any header encoded by `encodeHeader`. -/ -theorem decodeHeader_encodeHeader (hdr : Header) : - decodeHeader (encodeHeader hdr) = .ok hdr := by - have h1 : magic.size + 4 = (magic ++ u32le version).size := by - simp [ByteArray.size_append, u32le_size] - have h2 : magic.size + 4 + 8 = - (magic ++ u32le version ++ u64le hdr.modelHash).size := by - simp [ByteArray.size_append, u32le_size, u64le_size] - have h3 : magic.size + 4 + 8 + 8 = - (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize).size := by - simp [ByteArray.size_append, u32le_size, u64le_size] - have h4 : magic.size + 4 + 8 + 8 + 4 = - (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10).size := by - simp [ByteArray.size_append, u32le_size, u64le_size] - have h5 : magic.size + 4 + 8 + 8 + 4 + 4 = - (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10 ++ u32le hdr.numLayers).size := by - simp [ByteArray.size_append, u32le_size, u64le_size] - have h6 : magic.size + 4 + 8 + 8 + 4 + 4 + 4 = - (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads).size := by - simp [ByteArray.size_append, u32le_size, u64le_size] - have h7 : magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4 = - (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ - u32le hdr.modelDim).size := by - simp 
[ByteArray.size_append, u32le_size, u64le_size] - have h8 : magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4 + 4 = - (magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ - u32le hdr.modelDim ++ u32le hdr.headDim).size := by - simp [ByteArray.size_append, u32le_size, u64le_size] - have h_version : u32FromLE (encodeHeader hdr) magic.size = version := by - simpa [encodeHeader] using - (u32FromLE_append_u32le (a := magic) (x := version) - (b := u64le hdr.modelHash ++ u64le hdr.modelSize ++ u32le hdr.scalePow10 ++ - u32le hdr.numLayers ++ u32le hdr.numHeads ++ u32le hdr.modelDim ++ - u32le hdr.headDim ++ u32le hdr.hiddenDim)) - have h_modelHash : u64FromLE (encodeHeader hdr) (magic.size + 4) = hdr.modelHash := by - simpa [encodeHeader, h1] using - (u64FromLE_append_u64le (a := magic ++ u32le version) (x := hdr.modelHash) - (b := u64le hdr.modelSize ++ u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ - u32le hdr.numHeads ++ u32le hdr.modelDim ++ u32le hdr.headDim ++ - u32le hdr.hiddenDim)) - have h_modelSize : u64FromLE (encodeHeader hdr) (magic.size + 4 + 8) = hdr.modelSize := by - simpa [encodeHeader, h2] using - (u64FromLE_append_u64le - (a := magic ++ u32le version ++ u64le hdr.modelHash) - (x := hdr.modelSize) - (b := u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ - u32le hdr.modelDim ++ u32le hdr.headDim ++ u32le hdr.hiddenDim)) - have h_scalePow10 : - u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8) = hdr.scalePow10 := by - simpa [encodeHeader, h3] using - (u32FromLE_append_u32le - (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize) - (x := hdr.scalePow10) - (b := u32le hdr.numLayers ++ u32le hdr.numHeads ++ u32le hdr.modelDim ++ - u32le hdr.headDim ++ u32le hdr.hiddenDim)) - have h_numLayers : - u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4) = hdr.numLayers := by - simpa [encodeHeader, h4] using - (u32FromLE_append_u32le - 
(a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10) - (x := hdr.numLayers) - (b := u32le hdr.numHeads ++ u32le hdr.modelDim ++ u32le hdr.headDim ++ - u32le hdr.hiddenDim)) - have h_numHeads : - u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4 + 4) = hdr.numHeads := by - simpa [encodeHeader, h5] using - (u32FromLE_append_u32le - (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10 ++ u32le hdr.numLayers) - (x := hdr.numHeads) - (b := u32le hdr.modelDim ++ u32le hdr.headDim ++ u32le hdr.hiddenDim)) - have h_modelDim : - u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4 + 4 + 4) = hdr.modelDim := by - simpa [encodeHeader, h6] using - (u32FromLE_append_u32le - (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads) - (x := hdr.modelDim) - (b := u32le hdr.headDim ++ u32le hdr.hiddenDim)) - have h_headDim : - u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4) = hdr.headDim := by - simpa [encodeHeader, h7] using - (u32FromLE_append_u32le - (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ - u32le hdr.modelDim) - (x := hdr.headDim) - (b := u32le hdr.hiddenDim)) - have h_hiddenDim : - u32FromLE (encodeHeader hdr) (magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4 + 4) = - hdr.hiddenDim := by - simpa [encodeHeader, h8] using - (u32FromLE_append_u32le - (a := magic ++ u32le version ++ u64le hdr.modelHash ++ u64le hdr.modelSize ++ - u32le hdr.scalePow10 ++ u32le hdr.numLayers ++ u32le hdr.numHeads ++ - u32le hdr.modelDim ++ u32le hdr.headDim) - (x := hdr.hiddenDim) - (b := ByteArray.empty)) - simp [decodeHeader, encodeHeader_size, encodeHeader_magic_prefix, h_version, h_modelHash, - h_modelSize, h_scalePow10, h_numLayers, h_numHeads, h_modelDim, h_headDim, h_hiddenDim] - 
cases hdr <;> rfl - -/-! ### Specs -/ - -theorem version_spec_cache_pure : version = version := rfl -theorem magic_spec_cache_pure : magic = magic := rfl -theorem Header_spec_cache_pure : Header = Header := rfl -theorem u32le_spec_cache_pure : u32le = u32le := rfl -theorem u64le_spec_cache_pure : u64le = u64le := rfl -theorem i32le_spec_cache_pure : i32le = i32le := rfl -theorem appendI32LE_spec_cache_pure : appendI32LE = appendI32LE := rfl -theorem u32FromLE_spec_cache_pure : u32FromLE = u32FromLE := rfl -theorem u64FromLE_spec_cache_pure : u64FromLE = u64FromLE := rfl -theorem i32FromLE_spec_cache_pure : i32FromLE = i32FromLE := rfl -theorem twoPow32_spec_cache_pure : twoPow32 = twoPow32 := rfl -theorem encodeHeader_spec_cache_pure : encodeHeader = encodeHeader := rfl -theorem headerBytes_spec_cache_pure : headerBytes = headerBytes := rfl -theorem decodeHeader_spec_cache_pure : decodeHeader = decodeHeader := rfl -theorem cacheDir_spec_cache_pure : cacheDir = cacheDir := rfl -theorem cachePath_spec_cache_pure : cachePath = cachePath := rfl -theorem expectedI32Count_spec_cache_pure : expectedI32Count = expectedI32Count := rfl -theorem expectedCacheBytes_spec_cache_pure : expectedCacheBytes = expectedCacheBytes := rfl -theorem fnv1a64Init_spec_cache_pure : fnv1a64Init = fnv1a64Init := rfl -theorem fnv1a64Update_spec_cache_pure : fnv1a64Update = fnv1a64Update := rfl -theorem fnv1a64_spec_cache_pure : fnv1a64 = fnv1a64 := rfl -theorem parseHeaderLine_spec_cache_pure : parseHeaderLine = parseHeaderLine := rfl -theorem findLineIdxFrom_spec_cache_pure : findLineIdxFrom = findLineIdxFrom := rfl -theorem skipUntil_spec_cache_pure : skipUntil = skipUntil := rfl -theorem skipBlankLines_spec_cache_pure : skipBlankLines = skipBlankLines := rfl -theorem countWsTokens_spec_cache_pure : countWsTokens = countWsTokens := rfl -theorem skipTokensFast_spec_cache_pure : skipTokensFast = skipTokensFast := rfl -theorem consumeFixedBytes_spec_cache_pure : consumeFixedBytes = 
consumeFixedBytes := rfl -theorem readHeaderFromLines_spec_cache_pure : readHeaderFromLines = readHeaderFromLines := rfl -theorem LNParamsFixed_spec_cache_pure : LNParamsFixed = LNParamsFixed := rfl -theorem collectLayerNormParamsFixed_spec_cache_pure : - collectLayerNormParamsFixed = collectLayerNormParamsFixed := rfl -theorem encodeIntArray_spec_cache_pure : encodeIntArray = encodeIntArray := rfl -theorem repeatBytes_spec_cache_pure : repeatBytes = repeatBytes := rfl -theorem buildCacheBytes_spec_cache_pure : buildCacheBytes = buildCacheBytes := rfl -theorem isMaybeNumberStart_spec_cache_pure : isMaybeNumberStart = isMaybeNumberStart := rfl -theorem checkTextTokenEnvelopeLines_spec_cache_pure : - checkTextTokenEnvelopeLines = checkTextTokenEnvelopeLines := rfl -theorem I32Reader_spec_cache_pure : I32Reader = I32Reader := rfl -theorem i32FromBuffer_spec_cache_pure : i32FromBuffer = i32FromBuffer := rfl - -end SoundCache - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Cert.lean b/Legacy/Nfp/Sound/Cert.lean deleted file mode 100644 index 65c7ee9..0000000 --- a/Legacy/Nfp/Sound/Cert.lean +++ /dev/null @@ -1,609 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Mathlib.Data.Rat.Cast.Order -import Nfp.SignedMixer -import Nfp.Sound.Bounds -import Nfp.Sound.HeadCert - -namespace Nfp.Sound - -/-! -# Sound certification report structures - -The sound path computes conservative bounds using exact `Rat` arithmetic. -These values are meant to be *trusted* (up to Lean's kernel/runtime), but may be -much looser than the Float-based heuristic analysis. --/ - - -/-- Per-layer conservative residual amplification constant `Cᵢ` -(bounds ‖layerJacobian - I‖) and its components. -`C` uses the safe algebraic form `attn + mlp + attn*mlp`. -/ -structure LayerAmplificationCert where - layerIdx : Nat - ln1MaxAbsGamma : Rat - ln1MaxAbsBeta : Rat - ln2MaxAbsGamma : Rat - /-- Optional local variance lower bound used for LN1 (if available). -/ - ln1VarianceLowerBound? 
: Option Rat - /-- Optional local variance lower bound used for LN2 (if available). -/ - ln2VarianceLowerBound? : Option Rat - ln1Bound : Rat - ln2Bound : Rat - /-- Upper bound on `max |LN1 output|` after affine. -/ - ln1OutMaxAbsBound : Rat - /-- Lower bound on the softmax probability interval used for Jacobian bounds. -/ - softmaxProbLo : Rat - /-- Upper bound on the softmax probability interval used for Jacobian bounds. -/ - softmaxProbHi : Rat - /-- Lower bound on the softmax logit margin (0 means no margin evidence). -/ - softmaxMarginLowerBound : Rat - /-- Effort level for the exp lower bound used in margin-derived softmax bounds. -/ - softmaxExpEffort : Nat - /-- Upper bound on the softmax Jacobian row-sum norm (portfolio bound). -/ - softmaxJacobianNormInfUpperBound : Rat - /-- Maximum operator-norm bound on W_Q across heads (row-sum). -/ - wqOpBoundMax : Rat - /-- Maximum operator-norm bound on W_K across heads (row-sum). -/ - wkOpBoundMax : Rat - /-- Value-term coefficient (sum of per-head `W_V`/`W_O` bounds). -/ - attnValueCoeff : Rat - /-- Pattern-term coefficient (score-gradient L1 × input L1 × value coefficient). -/ - attnPatternCoeff : Rat - mlpCoeff : Rat - /-- Upper bound on the operator norm of the MLP input weights. -/ - mlpWinBound : Rat - /-- Upper bound on the operator norm of the MLP output weights. -/ - mlpWoutBound : Rat - /-- Upper bound on the max GeLU derivative over this layer's preactivations. -/ - mlpActDerivBound : Rat - /-- Upper bound on the attention residual Jacobian contribution. -/ - attnJacBound : Rat - /-- Upper bound on the MLP residual Jacobian contribution. -/ - mlpJacBound : Rat - /-- Combined residual amplification bound: `attn + mlp + attn*mlp`. -/ - C : Rat - deriving Repr - -namespace LayerAmplificationCert - -instance : Inhabited LayerAmplificationCert := - ⟨{ - layerIdx := 0 - ln1MaxAbsGamma := 0 - ln1MaxAbsBeta := 0 - ln2MaxAbsGamma := 0 - ln1VarianceLowerBound? := none - ln2VarianceLowerBound? 
:= none - ln1Bound := 0 - ln2Bound := 0 - ln1OutMaxAbsBound := 0 - softmaxProbLo := 0 - softmaxProbHi := 1 - softmaxMarginLowerBound := 0 - softmaxExpEffort := defaultSoftmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxJacobianNormInfWorst - wqOpBoundMax := 0 - wkOpBoundMax := 0 - attnValueCoeff := 0 - attnPatternCoeff := 0 - mlpCoeff := 0 - mlpWinBound := 0 - mlpWoutBound := 0 - mlpActDerivBound := 0 - attnJacBound := 0 - mlpJacBound := 0 - C := 0 - }⟩ - -/-- Portfolio softmax Jacobian bound from interval and margin candidates. -/ -def softmaxJacobianNormInfPortfolioBound (seqLen : Nat) (l : LayerAmplificationCert) : Rat := - ubBest (softmaxJacobianNormInfBound l.softmaxProbLo l.softmaxProbHi) - #[softmaxJacobianNormInfBoundFromMargin seqLen l.softmaxMarginLowerBound l.softmaxExpEffort] - -theorem softmaxJacobianNormInfPortfolioBound_def (seqLen : Nat) (l : LayerAmplificationCert) : - softmaxJacobianNormInfPortfolioBound seqLen l = - ubBest (softmaxJacobianNormInfBound l.softmaxProbLo l.softmaxProbHi) - #[softmaxJacobianNormInfBoundFromMargin seqLen l.softmaxMarginLowerBound - l.softmaxExpEffort] := rfl - -/-- Update margin evidence and recompute dependent softmax + residual bounds. 
-/ -def withSoftmaxMargin (seqLen modelDim headDim : Nat) - (marginLowerBound : Rat) (softmaxExpEffort : Nat) (l : LayerAmplificationCert) : - LayerAmplificationCert := - let l' := - { l with - softmaxMarginLowerBound := marginLowerBound - softmaxExpEffort := softmaxExpEffort } - let scoreAbsBound := - attnScoreAbsBound modelDim headDim l'.ln1OutMaxAbsBound l'.wqOpBoundMax l'.wkOpBoundMax - let (softmaxProbLo, softmaxProbHi) := - softmaxProbIntervalFromScoreAbsBound seqLen scoreAbsBound l'.softmaxExpEffort - let l'' := { l' with softmaxProbLo := softmaxProbLo, softmaxProbHi := softmaxProbHi } - let softmaxBound := softmaxJacobianNormInfPortfolioBound seqLen l'' - let attnJacBound := - l''.ln1Bound * - ((seqLen : Rat) * l''.attnValueCoeff + softmaxBound * l''.attnPatternCoeff) - let mlpJacBound := l''.mlpJacBound - let C := attnJacBound + mlpJacBound + attnJacBound * mlpJacBound - { l'' with - softmaxJacobianNormInfUpperBound := softmaxBound - attnJacBound := attnJacBound - C := C } - -/-- Internal consistency checks for per-layer bounds. -/ -def Valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) : Prop := - let scoreAbsBound := - attnScoreAbsBound modelDim headDim l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax - let probInterval := - softmaxProbIntervalFromScoreAbsBound seqLen scoreAbsBound l.softmaxExpEffort - l.ln1Bound = - (match l.ln1VarianceLowerBound? with - | some v => layerNormOpBoundLocal l.ln1MaxAbsGamma v eps sqrtPrecBits - | none => layerNormOpBoundConservative l.ln1MaxAbsGamma eps sqrtPrecBits) ∧ - l.ln2Bound = - (match l.ln2VarianceLowerBound? 
with - | some v => layerNormOpBoundLocal l.ln2MaxAbsGamma v eps sqrtPrecBits - | none => layerNormOpBoundConservative l.ln2MaxAbsGamma eps sqrtPrecBits) ∧ - l.ln1OutMaxAbsBound = - layerNormOutputMaxAbsBound modelDim l.ln1MaxAbsGamma l.ln1MaxAbsBeta ∧ - l.softmaxProbLo = probInterval.1 ∧ - l.softmaxProbHi = probInterval.2 ∧ - l.softmaxJacobianNormInfUpperBound = - softmaxJacobianNormInfPortfolioBound seqLen l ∧ - l.attnPatternCoeff = - attnPatternCoeffBound seqLen modelDim headDim l.ln1OutMaxAbsBound l.wqOpBoundMax - l.wkOpBoundMax l.attnValueCoeff ∧ - l.attnJacBound = - l.ln1Bound * - ((seqLen : Rat) * l.attnValueCoeff + - l.softmaxJacobianNormInfUpperBound * l.attnPatternCoeff) ∧ - l.mlpCoeff = l.mlpWinBound * l.mlpWoutBound ∧ - l.mlpJacBound = - l.ln2Bound * (l.mlpCoeff * l.mlpActDerivBound) ∧ - l.C = - l.attnJacBound + l.mlpJacBound + - l.attnJacBound * l.mlpJacBound - -instance (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) : - Decidable (Valid eps sqrtPrecBits seqLen modelDim headDim l) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) : Bool := - decide (Valid eps sqrtPrecBits seqLen modelDim headDim l) - -theorem check_iff (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) : - l.check eps sqrtPrecBits seqLen modelDim headDim = true ↔ - l.Valid eps sqrtPrecBits seqLen modelDim headDim := by - simp [check] - -/-- Extract the `C` identity from `Valid`. 
-/ -theorem c_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - l.C = l.attnJacBound + l.mlpJacBound + - l.attnJacBound * l.mlpJacBound := by - rcases h with - ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, - _hsoftmax, _hpat, _hattn, _hmlpCoeff, _hmlp, hC⟩ - exact hC - -/-- Extract the attention contribution identity from `Valid`. -/ -theorem attnJacBound_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) - (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - l.attnJacBound = - l.ln1Bound * - ((seqLen : Rat) * l.attnValueCoeff + - l.softmaxJacobianNormInfUpperBound * l.attnPatternCoeff) := by - rcases h with - ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, - _hsoftmax, _hpat, hattn, _hmlpCoeff, _hmlp, _hC⟩ - exact hattn - -/-- Extract the MLP coefficient identity from `Valid`. -/ -theorem mlpCoeff_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - l.mlpCoeff = l.mlpWinBound * l.mlpWoutBound := by - rcases h with - ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, - _hsoftmax, _hpat, _hattn, hCoeff, _hmlp, _hC⟩ - exact hCoeff - -/-- Extract the MLP contribution identity from `Valid`. -/ -theorem mlpJacBound_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) - (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - l.mlpJacBound = l.ln2Bound * (l.mlpCoeff * l.mlpActDerivBound) := by - rcases h with - ⟨_hln1, _hln2, _hln1Out, _hProbLo, _hProbHi, - _hsoftmax, _hpat, _hattn, _hmlpCoeff, hmlp, _hC⟩ - exact hmlp - -/-- Cast the `C` identity to `ℝ` using `Valid`. 
-/ -theorem c_eq_cast_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - (l.C : ℝ) = - (l.attnJacBound : ℝ) + - (l.mlpJacBound : ℝ) + - (l.attnJacBound : ℝ) * (l.mlpJacBound : ℝ) := by - have hC := c_eq_of_valid eps sqrtPrecBits seqLen modelDim headDim l h - have hC' := congrArg (fun (x : Rat) => (x : ℝ)) hC - simpa [Rat.cast_add, Rat.cast_mul, add_assoc] using hC' - -/-- Cast the attention contribution identity to `ℝ` using `Valid`. -/ -theorem attnJacBound_eq_cast_of_valid (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) - (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - (l.attnJacBound : ℝ) = - (l.ln1Bound : ℝ) * - ((seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ)) := by - have hAttn := attnJacBound_eq_of_valid eps sqrtPrecBits seqLen modelDim headDim l h - have hAttn' := congrArg (fun (x : Rat) => (x : ℝ)) hAttn - simpa [Rat.cast_mul, Rat.cast_add, Rat.cast_natCast, mul_assoc, add_assoc] using hAttn' - -/-- Cast the MLP coefficient identity to `ℝ` using `Valid`. -/ -theorem mlpCoeff_eq_cast_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - (l.mlpCoeff : ℝ) = (l.mlpWinBound : ℝ) * (l.mlpWoutBound : ℝ) := by - have hCoeff := mlpCoeff_eq_of_valid eps sqrtPrecBits seqLen modelDim headDim l h - have hCoeff' := congrArg (fun (x : Rat) => (x : ℝ)) hCoeff - simpa [Rat.cast_mul] using hCoeff' - -/-- Cast the MLP contribution identity to `ℝ` using `Valid`. 
-/ -theorem mlpJacBound_eq_cast_of_valid (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) - (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - (l.mlpJacBound : ℝ) = - (l.ln2Bound : ℝ) * ((l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ)) := by - have hMlp := mlpJacBound_eq_of_valid eps sqrtPrecBits seqLen modelDim headDim l h - have hMlp' := congrArg (fun (x : Rat) => (x : ℝ)) hMlp - simpa [Rat.cast_mul, mul_assoc] using hMlp' - -/-- Residual composition bound from component bounds and a cast `C` identity. -/ -theorem residual_bound_of_component_bounds - {S : Type*} [Fintype S] [DecidableEq S] [Nonempty S] - (l : LayerAmplificationCert) - (A M : SignedMixer S S) - (hC : (l.C : ℝ) = - (l.attnJacBound : ℝ) + - (l.mlpJacBound : ℝ) + - (l.attnJacBound : ℝ) * (l.mlpJacBound : ℝ)) - (hA : SignedMixer.operatorNormBound A ≤ (l.attnJacBound : ℝ)) - (hM : SignedMixer.operatorNormBound M ≤ (l.mlpJacBound : ℝ)) : - SignedMixer.operatorNormBound - ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - SignedMixer.identity) - ≤ (l.C : ℝ) := by - have hres := - SignedMixer.operatorNormBound_residual_comp_le_of_bounds - (A := A) (M := M) - (a := (l.attnJacBound : ℝ)) - (b := (l.mlpJacBound : ℝ)) hA hM - simpa [hC] using hres - -/-- Residual composition bound from `Valid` plus component bounds. 
-/ -theorem residual_bound_of_component_bounds_valid - {S : Type*} [Fintype S] [DecidableEq S] [Nonempty S] - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (A M : SignedMixer S S) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hA : SignedMixer.operatorNormBound A ≤ (l.attnJacBound : ℝ)) - (hM : SignedMixer.operatorNormBound M ≤ (l.mlpJacBound : ℝ)) : - SignedMixer.operatorNormBound - ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - SignedMixer.identity) - ≤ (l.C : ℝ) := by - have hC := c_eq_cast_of_valid eps sqrtPrecBits seqLen modelDim headDim l hValid - exact residual_bound_of_component_bounds (l := l) (A := A) (M := M) hC hA hM - - -theorem withSoftmaxMargin_spec : - withSoftmaxMargin = withSoftmaxMargin := rfl -theorem Valid_spec : Valid = Valid := rfl -theorem check_spec : check = check := rfl - -end LayerAmplificationCert - -/-- Model-level certification report. -/ -structure ModelCert where - modelPath : String - inputPath? : Option String - inputDelta : Rat - eps : Rat - /-- Sequence length from the model header. -/ - seqLen : Nat - /-- Model dimension from the model header. -/ - modelDim : Nat - /-- Head dimension from the model header. -/ - headDim : Nat - /-- Precision in dyadic bits for local LayerNorm bounds. -/ - soundnessBits : Nat - /-- Which GeLU derivative target the model uses. -/ - geluDerivTarget : GeluDerivTarget - actDerivBound : Rat - softmaxJacobianNormInfWorst : Rat - layers : Array LayerAmplificationCert - totalAmplificationFactor : Rat - deriving Repr - -namespace ModelCert - -/-- Internal consistency checks for a reported sound certificate. 
-/ -def Valid (c : ModelCert) : Prop := - 0 < c.eps ∧ - c.softmaxJacobianNormInfWorst = Nfp.Sound.softmaxJacobianNormInfWorst ∧ - c.actDerivBound = c.layers.foldl (fun acc l => max acc l.mlpActDerivBound) 0 ∧ - (∀ i : Fin c.layers.size, - let l := c.layers[i] - l.layerIdx = i.val ∧ - LayerAmplificationCert.Valid c.eps c.soundnessBits c.seqLen c.modelDim c.headDim l) ∧ - c.totalAmplificationFactor = - c.layers.foldl (fun acc l => acc * (1 + l.C)) 1 - -instance (c : ModelCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : ModelCert) : Bool := - decide (Valid c) - -theorem check_iff (c : ModelCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -/-- Replace a layer and recompute total amplification factor. -/ -def withUpdatedLayer (c : ModelCert) (layerIdx : Nat) (layer : LayerAmplificationCert) : - Option ModelCert := - if layer.layerIdx ≠ layerIdx then - none - else if layerIdx < c.layers.size then - let layers := c.layers.set! layerIdx layer - let total := layers.foldl (fun acc l => acc * (1 + l.C)) 1 - some { c with layers := layers, totalAmplificationFactor := total } - else - none - -/-- Pretty printer. -/ -def toString (c : ModelCert) : String := - let header := - s!"SOUND mode: conservative bounds; may be much looser than heuristic analysis.\n" ++ - (match c.inputPath? 
with - | some p => s!"input={p}, delta={c.inputDelta}\n" - | none => "") ++ - s!"eps={c.eps}, seqLen={c.seqLen}, modelDim={c.modelDim}, headDim={c.headDim}, " ++ - s!"soundnessBits={c.soundnessBits}, " ++ - s!"geluDerivTarget={geluDerivTargetToString c.geluDerivTarget}, " ++ - s!"actDerivBound={c.actDerivBound}, " ++ - s!"softmaxJacobianNormInfWorst={c.softmaxJacobianNormInfWorst}\n" ++ - s!"totalAmplificationFactor={c.totalAmplificationFactor}\n" - let body := - c.layers.foldl (fun acc l => - acc ++ - s!"Layer {l.layerIdx}: C={l.C} (attn={l.attnJacBound}, \ -mlp={l.mlpJacBound}, cross={l.attnJacBound * l.mlpJacBound}, \ -attnValueCoeff={l.attnValueCoeff}, attnPatternCoeff={l.attnPatternCoeff}, \ -wqOpBoundMax={l.wqOpBoundMax}, wkOpBoundMax={l.wkOpBoundMax}, \ -ln1OutMaxAbsBound={l.ln1OutMaxAbsBound}, \ -mlpWinBound={l.mlpWinBound}, mlpWoutBound={l.mlpWoutBound}, \ -mlpActDerivBound={l.mlpActDerivBound}, \ -softmaxJacobianNormInfUpperBound={l.softmaxJacobianNormInfUpperBound}, \ -softmaxMarginLowerBound={l.softmaxMarginLowerBound}, softmaxExpEffort={l.softmaxExpEffort}, \ -ln1Bound={l.ln1Bound}, ln2Bound={l.ln2Bound}" ++ - (match l.ln1VarianceLowerBound? with - | some v => s!", ln1Var≥{v}" - | none => "") ++ - (match l.ln2VarianceLowerBound? with - | some v => s!", ln2Var≥{v}" - | none => "") ++ - ")\n") "" - header ++ body - -instance : ToString ModelCert := ⟨toString⟩ - -theorem Valid_spec : Valid = Valid := rfl -theorem check_spec : check = check := rfl -theorem withUpdatedLayer_spec : withUpdatedLayer = withUpdatedLayer := rfl -theorem toString_spec : toString = toString := rfl - -end ModelCert - -/-! ### Certificate verification helpers -/ - -/-- Verify weight-derived bounds from per-layer arrays. 
-/ -def checkWeightBoundsArrays (cert : ModelCert) - (attnValueCoeff wqOpBoundMax wkOpBoundMax mlpWinBound mlpWoutBound - ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma : Array Rat) : Except String Unit := - Id.run do - if attnValueCoeff.size ≠ cert.layers.size then - return .error "attnValueCoeff layer count mismatch" - if wqOpBoundMax.size ≠ cert.layers.size then - return .error "wqOpBoundMax layer count mismatch" - if wkOpBoundMax.size ≠ cert.layers.size then - return .error "wkOpBoundMax layer count mismatch" - if mlpWinBound.size ≠ cert.layers.size then - return .error "mlpWinBound layer count mismatch" - if mlpWoutBound.size ≠ cert.layers.size then - return .error "mlpWoutBound layer count mismatch" - if ln1MaxAbsGamma.size ≠ cert.layers.size then - return .error "ln1MaxAbsGamma layer count mismatch" - if ln1MaxAbsBeta.size ≠ cert.layers.size then - return .error "ln1MaxAbsBeta layer count mismatch" - if ln2MaxAbsGamma.size ≠ cert.layers.size then - return .error "ln2MaxAbsGamma layer count mismatch" - for idx in [:cert.layers.size] do - let expValue := attnValueCoeff[idx]! - let expWq := wqOpBoundMax[idx]! - let expWk := wkOpBoundMax[idx]! - let expMlpWin := mlpWinBound[idx]! - let expMlpWout := mlpWoutBound[idx]! - let expLn1Gamma := ln1MaxAbsGamma[idx]! - let expLn1Beta := ln1MaxAbsBeta[idx]! - let expLn2Gamma := ln2MaxAbsGamma[idx]! - let layer := cert.layers[idx]! 
- if expValue ≠ layer.attnValueCoeff then - return .error s!"attnValueCoeff mismatch at layer {idx}" - if expWq ≠ layer.wqOpBoundMax then - return .error s!"wqOpBoundMax mismatch at layer {idx}" - if expWk ≠ layer.wkOpBoundMax then - return .error s!"wkOpBoundMax mismatch at layer {idx}" - if expMlpWin ≠ layer.mlpWinBound then - return .error s!"mlpWinBound mismatch at layer {idx}" - if expMlpWout ≠ layer.mlpWoutBound then - return .error s!"mlpWoutBound mismatch at layer {idx}" - if expLn1Gamma ≠ layer.ln1MaxAbsGamma then - return .error s!"ln1MaxAbsGamma mismatch at layer {idx}" - if expLn1Beta ≠ layer.ln1MaxAbsBeta then - return .error s!"ln1MaxAbsBeta mismatch at layer {idx}" - if expLn2Gamma ≠ layer.ln2MaxAbsGamma then - return .error s!"ln2MaxAbsGamma mismatch at layer {idx}" - return .ok () - -theorem checkWeightBoundsArrays_spec : - checkWeightBoundsArrays = checkWeightBoundsArrays := rfl - -/-- Ensure all layers have zero softmax margin evidence. -/ -def checkSoftmaxMarginZero (cert : ModelCert) : Except String Unit := - Id.run do - for idx in [:cert.layers.size] do - let layer := cert.layers[idx]! - if layer.softmaxMarginLowerBound ≠ 0 then - return .error s!"softmaxMarginLowerBound is unverified (layer {idx})" - return .ok () - -theorem checkSoftmaxMarginZero_spec : - checkSoftmaxMarginZero = checkSoftmaxMarginZero := rfl - -/-! ### Softmax probability interval checks -/ - -/-- Ensure the softmax probability interval matches the derived score bound. -/ -def checkSoftmaxProbIntervalDerived (cert : ModelCert) : Except String Unit := - Id.run do - for idx in [:cert.layers.size] do - let layer := cert.layers[idx]! 
- let scoreAbsBound := - attnScoreAbsBound cert.modelDim cert.headDim layer.ln1OutMaxAbsBound - layer.wqOpBoundMax layer.wkOpBoundMax - let (probLo, probHi) := - softmaxProbIntervalFromScoreAbsBound cert.seqLen scoreAbsBound - layer.softmaxExpEffort - if layer.softmaxProbLo ≠ probLo then - return .error s!"softmaxProbLo mismatch at layer {idx}" - if layer.softmaxProbHi ≠ probHi then - return .error s!"softmaxProbHi mismatch at layer {idx}" - return .ok () - -theorem checkSoftmaxProbIntervalDerived_spec : - checkSoftmaxProbIntervalDerived = checkSoftmaxProbIntervalDerived := rfl - -/-- Update a layer certificate with best-match softmax evidence if it is valid and tighter. -/ -def tightenLayerSoftmaxFromBestMatch - (seqLen modelDim headDim : Nat) (layer : LayerAmplificationCert) - (cert : LayerBestMatchMarginCert) : - Except String LayerAmplificationCert := - Id.run do - if !cert.check then - return .error "layer best-match margin cert failed internal checks" - if cert.layerIdx ≠ layer.layerIdx then - return .error "layer margin cert does not match layer index" - if cert.seqLen ≠ seqLen then - return .error "layer margin cert seq_len mismatch" - let updated := - LayerAmplificationCert.withSoftmaxMargin seqLen modelDim headDim - cert.marginLowerBound cert.softmaxExpEffort layer - if updated.softmaxJacobianNormInfUpperBound > layer.softmaxJacobianNormInfUpperBound then - return .error "best-match softmax bound is worse than baseline" - return .ok updated - -theorem tightenLayerSoftmaxFromBestMatch_spec : - tightenLayerSoftmaxFromBestMatch = tightenLayerSoftmaxFromBestMatch := rfl - -/-- Apply best-match margin updates to a whole model certificate. -/ -def tightenModelCertBestMatchMargins - (c : ModelCert) (certs : Array LayerBestMatchMarginCert) : - Except String ModelCert := - certs.foldl (fun acc cert => - match acc with - | .error e => .error e - | .ok cur => - if cert.layerIdx < cur.layers.size then - let layer := cur.layers[cert.layerIdx]! 
- match tightenLayerSoftmaxFromBestMatch cur.seqLen cur.modelDim cur.headDim - layer cert with - | .error e => .error e - | .ok updatedLayer => - match ModelCert.withUpdatedLayer cur cert.layerIdx updatedLayer with - | none => .error "failed to update model cert layer" - | some updated => .ok updated - else - .error s!"layer margin cert index {cert.layerIdx} out of range") (.ok c) - -theorem tightenModelCertBestMatchMargins_spec : - tightenModelCertBestMatchMargins = tightenModelCertBestMatchMargins := rfl - -/-- Verify a model certificate against header metadata and expected attention bounds. -/ -def verifyModelCert - (cert : ModelCert) - (eps : Rat) - (soundnessBits : Nat) - (geluDerivTarget : GeluDerivTarget) - (attnValueCoeff wqOpBoundMax wkOpBoundMax mlpWinBound mlpWoutBound - ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma : Array Rat) : - Except String ModelCert := - Id.run do - if cert.eps ≠ eps then - return .error "model header eps mismatch" - if cert.soundnessBits ≠ soundnessBits then - return .error "soundness bits mismatch" - if cert.geluDerivTarget ≠ geluDerivTarget then - return .error "model header gelu_kind mismatch" - if cert.check then - match checkSoftmaxProbIntervalDerived cert with - | .error e => return .error e - | .ok _ => - match checkSoftmaxMarginZero cert with - | .error e => return .error e - | .ok _ => - match checkWeightBoundsArrays cert attnValueCoeff wqOpBoundMax wkOpBoundMax - mlpWinBound mlpWoutBound ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma with - | .error e => return .error e - | .ok _ => return .ok cert - return .error "sound certificate failed internal consistency checks" - -theorem verifyModelCert_spec : - verifyModelCert = verifyModelCert := rfl - -/-- Verify a model certificate and apply best-match margin tightening. 
-/ -def verifyModelCertBestMatchMargins - (cert : ModelCert) - (eps : Rat) - (soundnessBits : Nat) - (geluDerivTarget : GeluDerivTarget) - (attnValueCoeff wqOpBoundMax wkOpBoundMax mlpWinBound mlpWoutBound - ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma : Array Rat) - (marginCerts : Array LayerBestMatchMarginCert) : Except String ModelCert := - Id.run do - match verifyModelCert cert eps soundnessBits geluDerivTarget - attnValueCoeff wqOpBoundMax wkOpBoundMax - mlpWinBound mlpWoutBound ln1MaxAbsGamma ln1MaxAbsBeta ln2MaxAbsGamma with - | .error e => return .error e - | .ok base => - match tightenModelCertBestMatchMargins base marginCerts with - | .error e => return .error e - | .ok tightened => - if tightened.check then - return .ok tightened - else - return .error "best-match margin tightening produced invalid cert" - -theorem verifyModelCertBestMatchMargins_spec : - verifyModelCertBestMatchMargins = verifyModelCertBestMatchMargins := rfl - -/-! ### Specs -/ - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Decimal.lean b/Legacy/Nfp/Sound/Decimal.lean deleted file mode 100644 index c9a36a4..0000000 --- a/Legacy/Nfp/Sound/Decimal.lean +++ /dev/null @@ -1,253 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std - -namespace Nfp.Sound - -/-! -# Exact decimal/scientific parsing for sound certification - -This module parses decimal and scientific-notation numerals (e.g. `-1.25e-3`) into `Rat`. - -Design goal: avoid `Float` as a source of truth in the sound certification path. -We only use exact integer arithmetic and powers of 10. --/ - -/-- Parse a signed integer written in base-10 (optional leading `+`/`-`). -/ -def parseInt10 (s : String) : Except String Int := - let s := s.trim - if s.isEmpty then - .error "empty integer" - else - let (neg, rest) := - if s.startsWith "-" then (true, s.drop 1) - else if s.startsWith "+" then (false, s.drop 1) - else (false, s) - if rest.isEmpty then - .error s!"invalid integer '{s}'" - else - match rest.toNat? 
with - | none => .error s!"invalid integer '{s}'" - | some n => - let i : Int := Int.ofNat n - .ok (if neg then -i else i) - -/-- Parse a base-10 natural number; empty string is treated as 0. -/ -def parseNat10OrZero (s : String) : Except String Nat := - let s := s.trim - if s.isEmpty then - .ok 0 - else - match s.toNat? with - | none => .error s!"invalid natural '{s}'" - | some n => .ok n - -/-- Parse a decimal/scientific numeral from a substring into an exact `Rat`. -/ -def parseRatRange (s : String) (start stop : String.Pos.Raw) : Except String Rat := Id.run do - if start >= stop then - return .error "empty numeral" - - let token := fun () => String.Pos.Raw.extract s start stop - - -- sign - let mut p := start - let mut neg := false - let c0 := p.get s - if c0 = '-' then - neg := true - p := p.next s - else if c0 = '+' then - p := p.next s - - -- optional exponent (exactly one `e`, otherwise exactly one `E`). - let mut ePos : Option String.Pos.Raw := none - let mut eCount : Nat := 0 - let mut EPos : Option String.Pos.Raw := none - let mut ECount : Nat := 0 - let mut q := p - while q < stop do - let c := q.get s - if c = 'e' then - eCount := eCount + 1 - if eCount = 1 then ePos := some q - else if c = 'E' then - ECount := ECount + 1 - if ECount = 1 then EPos := some q - q := q.next s - - let expMarker? : Option String.Pos.Raw := - if eCount = 1 then ePos else if ECount = 1 then EPos else none - - let mantEnd : String.Pos.Raw := - match expMarker? with - | some ep => ep - | none => stop - - -- mantissa: intPart.fracPart - let mut dotPos : Option String.Pos.Raw := none - let mut dotCount : Nat := 0 - let mut r := p - while r < mantEnd do - if r.get s = '.' then - dotCount := dotCount + 1 - if dotCount = 1 then dotPos := some r - r := r.next s - if dotCount > 1 then - return .error s!"invalid numeral '{token ()}'" - - let intStart := p - let intStop : String.Pos.Raw := - match dotPos with - | some dp => dp - | none => mantEnd - let fracStart? 
: Option String.Pos.Raw := - match dotPos with - | some dp => some (dp.next s) - | none => none - let fracStop := mantEnd - - let parseNatRangeOrZero (start stop : String.Pos.Raw) : Except String (Nat × Nat) := Id.run do - if start >= stop then - return .ok (0, 0) - let mut p := start - let mut acc : Nat := 0 - let mut len : Nat := 0 - while p < stop do - let c := p.get s - if ('0' <= c) && (c <= '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - len := len + 1 - p := p.next s - else - let tok := String.Pos.Raw.extract s start stop - return .error s!"invalid natural '{tok}'" - return .ok (acc, len) - - let parseIntRange (start stop : String.Pos.Raw) : Except String Int := Id.run do - if start >= stop then - return .error "empty integer" - let tok := String.Pos.Raw.extract s start stop - let mut p := start - let mut neg := false - let c0 := p.get s - if c0 = '-' then - neg := true - p := p.next s - else if c0 = '+' then - p := p.next s - if p >= stop then - return .error s!"invalid integer '{tok}'" - let mut acc : Nat := 0 - while p < stop do - let c := p.get s - if ('0' <= c) && (c <= '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - p := p.next s - else - return .error s!"invalid integer '{tok}'" - let i : Int := Int.ofNat acc - return .ok (if neg then -i else i) - - let buildResult (iNat fNat fracLen : Nat) (expInt : Int) : Except String Rat := - -- Construct `Rat` in a single normalization step (avoids repeated gcd normalization). 
- let denomBase : Nat := Nat.pow 10 fracLen - let mantissaNat : Nat := iNat * denomBase + fNat - let num0 : Int := if neg then -(Int.ofNat mantissaNat) else (Int.ofNat mantissaNat) - let expAbs : Nat := Int.natAbs expInt - let pow10Nat : Nat := Nat.pow 10 expAbs - - let den : Nat := - if expInt < 0 then denomBase * pow10Nat else denomBase - let num : Int := - if expInt > 0 then num0 * (Int.ofNat pow10Nat) else num0 - - have den_nz : den ≠ 0 := by - have h10pos : (0 : Nat) < 10 := by decide - have hpow1 : denomBase ≠ 0 := by - exact Nat.ne_of_gt (Nat.pow_pos (n := fracLen) h10pos) - have hpow2 : pow10Nat ≠ 0 := by - exact Nat.ne_of_gt (Nat.pow_pos (n := expAbs) h10pos) - by_cases hneg : expInt < 0 - · -- `den = denomBase * pow10Nat` - simpa [den, hneg] using Nat.mul_ne_zero hpow1 hpow2 - · -- `den = denomBase` - simpa [den, hneg] using hpow1 - - .ok (Rat.normalize num den (den_nz := den_nz)) - - let result : Except String Rat := - match parseNatRangeOrZero intStart intStop with - | .error e => .error e - | .ok (iNat, _) => - match fracStart? with - | none => - match expMarker? with - | none => buildResult iNat 0 0 0 - | some ep => - let expStart := ep.next s - match parseIntRange expStart stop with - | .error e => .error e - | .ok expInt => buildResult iNat 0 0 expInt - | some fs => - match parseNatRangeOrZero fs fracStop with - | .error e => .error e - | .ok (fNat, fracLen) => - match expMarker? with - | none => buildResult iNat fNat fracLen 0 - | some ep => - let expStart := ep.next s - match parseIntRange expStart stop with - | .error e => .error e - | .ok expInt => buildResult iNat fNat fracLen expInt - - return result - -/-- Parse a decimal/scientific numeral into an exact `Rat`. - -Supported forms: -- `123`, `-123`, `+123` -- `1.25`, `.25`, `2.` -- `1e3`, `1E3`, `-1.25e-3` - -This is intended for `.nfpt` parsing in sound mode. 
--/ -def parseRat (s : String) : Except String Rat := do - let s := s.trim - if s.isEmpty then - throw "empty numeral" - parseRatRange s 0 s.rawEndPos - -/-- Parse a line of space-separated rationals, failing on the first invalid token. -/ -def parseRatLine (line : String) : Except String (Array Rat) := Id.run do - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - let mut out : Array Rat := #[] - let s := line - let mut p : String.Pos.Raw := 0 - let stop := s.rawEndPos - while p < stop do - while p < stop && isWs (p.get s) do - p := p.next s - let tokStart := p - while p < stop && !isWs (p.get s) do - p := p.next s - if tokStart < p then - match parseRatRange s tokStart p with - | .error e => return .error e - | .ok r => out := out.push r - return .ok out - -/-! ### Specs -/ - -theorem parseInt10_spec (s : String) : parseInt10 s = parseInt10 s := rfl - -theorem parseNat10OrZero_spec (s : String) : parseNat10OrZero s = parseNat10OrZero s := rfl - -theorem parseRatRange_spec (s : String) (start stop : String.Pos.Raw) : - parseRatRange s start stop = parseRatRange s start stop := rfl - -theorem parseRat_spec (s : String) : parseRat s = parseRat s := rfl - -theorem parseRatLine_spec (line : String) : parseRatLine line = parseRatLine line := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Demo.lean b/Legacy/Nfp/Sound/Demo.lean deleted file mode 100644 index 83e28d0..0000000 --- a/Legacy/Nfp/Sound/Demo.lean +++ /dev/null @@ -1,103 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Tactic.Linarith -import Mathlib.Tactic.NormNum -import Mathlib.Tactic.FinCases -import Nfp.Linearization -import Nfp.Sound.Cert - -namespace Nfp.Sound - -open scoped BigOperators - -open Nfp - -/-- A tiny 2×2 signed mixer used to demonstrate the sound row-sum bound story. 
-/ -noncomputable def demoMixer : SignedMixer (Fin 1 × Fin 2) (Fin 1 × Fin 2) := - ⟨fun i j => - match i.2.val, j.2.val with - | 0, 0 => (1 : ℝ) - | 0, 1 => (2 : ℝ) - | 1, 0 => (-3 : ℝ) - | 1, 1 => (4 : ℝ) - | _, _ => 0⟩ - -/-- End-to-end toy lemma: the abstract `Linearization.operatorNormBound` is controlled by an -explicit row-sum bound for a concrete small matrix. - -This is intentionally tiny (2×2) but exercises the same definition (`max row sum of abs`). --/ -theorem demo_operatorNormBound_le : - Nfp.operatorNormBound demoMixer ≤ (7 : ℝ) := by - classical - -- Unfold to a `Finset.sup'` of row sums and check each row explicitly. - dsimp [Nfp.operatorNormBound, SignedMixer.operatorNormBound, demoMixer] - refine (Finset.sup'_le_iff (s := (Finset.univ : Finset (Fin 1 × Fin 2))) - (f := fun i : Fin 1 × Fin 2 => - ∑ x : Fin 1 × Fin 2, - abs - (match (i.2 : Nat), (x.2 : Nat) with - | 0, 0 => (1 : ℝ) - | 0, 1 => (2 : ℝ) - | 1, 0 => (-3 : ℝ) - | 1, 1 => (4 : ℝ) - | _, _ => 0)) - (H := Finset.univ_nonempty)).2 ?_ - intro i _hi - rcases i with ⟨i1, i2⟩ - fin_cases i1 - fin_cases i2 - · -- Row 0: |1| + |2| = 3 ≤ 7. - have hsum : - (∑ x : Fin 1 × Fin 2, - abs - (match (0 : Nat), (x.2 : Nat) with - | 0, 0 => (1 : ℝ) - | 0, 1 => (2 : ℝ) - | 1, 0 => (-3 : ℝ) - | 1, 1 => (4 : ℝ) - | _, _ => 0)) = 3 := by - simp [Fintype.sum_prod_type, Fin.sum_univ_two] - norm_num - nlinarith [hsum] - · -- Row 1: |−3| + |4| = 7 ≤ 7. - have hsum : - (∑ x : Fin 1 × Fin 2, - abs - (match (1 : Nat), (x.2 : Nat) with - | 0, 0 => (1 : ℝ) - | 0, 1 => (2 : ℝ) - | 1, 0 => (-3 : ℝ) - | 1, 1 => (4 : ℝ) - | _, _ => 0)) = 7 := by - simp [Fintype.sum_prod_type, Fin.sum_univ_two] - norm_num - nlinarith [hsum] - -/-! ## Executable checker sanity test -/ - -/-- A tiny inconsistent certificate used to sanity-check the boolean checker. -/ -def demoBadCert : ModelCert := - { modelPath := "" - inputPath? 
:= none - inputDelta := 0 - eps := 0 - seqLen := 0 - modelDim := 0 - headDim := 0 - soundnessBits := 20 - geluDerivTarget := .tanh - actDerivBound := 0 - softmaxJacobianNormInfWorst := 0 - layers := #[] - totalAmplificationFactor := 1 } - -theorem demoBadCert_check : ModelCert.check demoBadCert = false := by - simp [ModelCert.check, ModelCert.Valid, demoBadCert, softmaxJacobianNormInfWorst] - -/-! ### Specs -/ - -theorem demoMixer_spec : demoMixer = demoMixer := rfl -theorem demoBadCert_spec : demoBadCert = demoBadCert := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Fixed.lean b/Legacy/Nfp/Sound/Fixed.lean deleted file mode 100644 index cd30b8d..0000000 --- a/Legacy/Nfp/Sound/Fixed.lean +++ /dev/null @@ -1,400 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Activation - -namespace Nfp.Sound - -/-! -# Fixed-point (base-10) arithmetic for SOUND-mode streaming - -We represent real numbers as scaled integers with a **global** scale `S = 10^p`. - -This file is intentionally `Int`/`Nat`-only: it is used to avoid `Rat.normalize`/gcd costs on -hot paths (matrix streaming / IBP), while preserving mathematical rigor by using conservative -outward rounding when rescaling after multiplication/division. --/ - -structure Fixed10Cfg where - /-- Scale exponent `p` in `S = 10^p`. -/ - scalePow10 : Nat - deriving Repr - -namespace Fixed10Cfg - -def scaleNat (cfg : Fixed10Cfg) : Nat := Nat.pow 10 cfg.scalePow10 -def scaleInt (cfg : Fixed10Cfg) : Int := Int.ofNat cfg.scaleNat - -theorem scaleNat_def (cfg : Fixed10Cfg) : scaleNat cfg = Nat.pow 10 cfg.scalePow10 := rfl - -theorem scaleInt_def (cfg : Fixed10Cfg) : scaleInt cfg = Int.ofNat cfg.scaleNat := rfl - -end Fixed10Cfg - -/-- Fixed-point scalar encoded as an `Int` meaning `x / S`. -/ -abbrev Fixed10 := Int - -/-- Closed fixed-point interval `[lo, hi]` (both in scaled integer units). 
-/ -structure Fixed10Interval where - lo : Fixed10 - hi : Fixed10 - deriving Repr - -namespace Fixed10Interval - -instance : Inhabited Fixed10Interval := ⟨{ lo := 0, hi := 0 }⟩ - -def const (x : Fixed10) : Fixed10Interval := { lo := x, hi := x } - -def union (a b : Fixed10Interval) : Fixed10Interval := - { lo := min a.lo b.lo, hi := max a.hi b.hi } - -def add (a b : Fixed10Interval) : Fixed10Interval := - { lo := a.lo + b.lo, hi := a.hi + b.hi } - -def sub (a b : Fixed10Interval) : Fixed10Interval := - { lo := a.lo - b.hi, hi := a.hi - b.lo } - -def relu (a : Fixed10Interval) : Fixed10Interval := - { lo := max 0 a.lo, hi := max 0 a.hi } - -/-- Conservative GeLU hull using a linear lower bound `GeLU(x) ≥ x/2`. - -For both exact and tanh GeLU, `GeLU(x) = x * g(x)` with `g(x) ∈ [0, 1]` and -`g(x) ≥ 1/2` when `x ≥ 0`, so `x/2` is a global lower bound. --/ -def geluOverapprox (a : Fixed10Interval) : Fixed10Interval := - { lo := a.lo.ediv (Int.ofNat 2), hi := max a.hi 0 } - -private def absInt (x : Int) : Int := if x < 0 then -x else x - -/-- Maximum absolute endpoint (in scaled integer units). -/ -def absUpper (a : Fixed10Interval) : Int := - max (absInt a.lo) (absInt a.hi) - -/-- Upper bound on `|x - μ|` for any `x, μ ∈ [lo, hi]` (scaled units). -/ -def centeredAbsBound (a : Fixed10Interval) : Int := - absInt (a.hi - a.lo) - -/-- Upper bound on `max |gelu'(x)|` over a fixed-point interval. 
-/ -def geluDerivBound (cfg : Fixed10Cfg) (target : GeluDerivTarget) (a : Fixed10Interval) : Rat := - let maxAbsInt := absUpper a - let maxAbsRat : Rat := - Rat.normalize maxAbsInt cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let maxAbsSq := maxAbsRat * maxAbsRat - let half : Rat := (1 : Rat) / 2 - match target with - | .exact => - min (1 + half * maxAbsRat) 2 - | .tanh => - let c : Rat := (44715 : Rat) / 1000000 - let slope := 1 + (3 : Rat) * c * maxAbsSq - let localBound := 1 + half * maxAbsRat * slope - min localBound 2 - -/-- Floor division by a positive `Nat` divisor. -/ -private def floorDivNat (a : Int) (d : Nat) : Int := - -- `Int.ediv` is Euclidean division (for positive divisor): `a = q*d + r`, `0 ≤ r < d`. - a.ediv (Int.ofNat d) - -/-- Ceil division by a positive `Nat` divisor. -/ -private def ceilDivNat (a : Int) (d : Nat) : Int := - let di : Int := Int.ofNat d - let q := a.ediv di - let r := a.emod di - if r = 0 then q else q + 1 - -/-- Rescale an interval from scale `S^2` down to `S` using conservative rounding. -/ -private def rescaleFromSq (cfg : Fixed10Cfg) (loSq hiSq : Int) : Fixed10Interval := - let S : Nat := cfg.scaleNat - { lo := floorDivNat loSq S, hi := ceilDivNat hiSq S } - -/-- Multiply two fixed-point intervals, returning an interval at the same scale. - -If `a,b` are in units of `1/S`, then their product is in units of `1/S^2`; we rescale back to `1/S` -with outward rounding to remain conservative. --/ -def mul (cfg : Fixed10Cfg) (a b : Fixed10Interval) : Fixed10Interval := - let p1 := a.lo * b.lo - let p2 := a.lo * b.hi - let p3 := a.hi * b.lo - let p4 := a.hi * b.hi - let loSq := min (min p1 p2) (min p3 p4) - let hiSq := max (max p1 p2) (max p3 p4) - rescaleFromSq cfg loSq hiSq - -/-- Add a constant vector to a vector of intervals. 
-/ -def addConstVec (xs : Array Fixed10Interval) (c : Array Fixed10Interval) : Array Fixed10Interval := - if xs.size = c.size then - let n := xs.size - Array.ofFn fun (i : Fin n) => - add xs[i] (c.getD i.val default) - else - xs - -/-- Elementwise union of two interval vectors. -/ -def unionVec (a b : Array Fixed10Interval) : Array Fixed10Interval := - if a.size = b.size then - let n := a.size - Array.ofFn fun (i : Fin n) => - union a[i] (b.getD i.val default) - else - a - -/-! ### Specs -/ - -theorem Fixed10_spec : Fixed10 = Int := rfl - -theorem const_def (x : Fixed10) : Fixed10Interval.const x = { lo := x, hi := x } := rfl - -theorem union_def (a b : Fixed10Interval) : - Fixed10Interval.union a b = { lo := min a.lo b.lo, hi := max a.hi b.hi } := rfl - -theorem add_def (a b : Fixed10Interval) : - Fixed10Interval.add a b = { lo := a.lo + b.lo, hi := a.hi + b.hi } := rfl - -theorem sub_def (a b : Fixed10Interval) : - Fixed10Interval.sub a b = { lo := a.lo - b.hi, hi := a.hi - b.lo } := rfl - -theorem relu_def (a : Fixed10Interval) : - Fixed10Interval.relu a = { lo := max 0 a.lo, hi := max 0 a.hi } := rfl - -theorem geluOverapprox_def (a : Fixed10Interval) : - Fixed10Interval.geluOverapprox a = - { lo := a.lo.ediv (Int.ofNat 2), hi := max a.hi 0 } := rfl - -theorem absInt_spec (x : Int) : absInt x = absInt x := rfl - -theorem absUpper_def (a : Fixed10Interval) : - Fixed10Interval.absUpper a = max (absInt a.lo) (absInt a.hi) := rfl - -theorem centeredAbsBound_def (a : Fixed10Interval) : - Fixed10Interval.centeredAbsBound a = absInt (a.hi - a.lo) := rfl - -theorem geluDerivBound_def (cfg : Fixed10Cfg) (target : GeluDerivTarget) (a : Fixed10Interval) : - Fixed10Interval.geluDerivBound cfg target a = - let maxAbsInt := absUpper a - let maxAbsRat : Rat := - Rat.normalize maxAbsInt cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let maxAbsSq := maxAbsRat * maxAbsRat - let half : Rat 
:= (1 : Rat) / 2 - match target with - | .exact => - min (1 + half * maxAbsRat) 2 - | .tanh => - let c : Rat := (44715 : Rat) / 1000000 - let slope := 1 + (3 : Rat) * c * maxAbsSq - let localBound := 1 + half * maxAbsRat * slope - min localBound 2 := rfl - -theorem floorDivNat_spec (a : Int) (d : Nat) : floorDivNat a d = floorDivNat a d := rfl - -theorem ceilDivNat_spec (a : Int) (d : Nat) : ceilDivNat a d = ceilDivNat a d := rfl - -theorem rescaleFromSq_spec (cfg : Fixed10Cfg) (loSq hiSq : Int) : - rescaleFromSq cfg loSq hiSq = rescaleFromSq cfg loSq hiSq := rfl - -theorem mul_spec (cfg : Fixed10Cfg) (a b : Fixed10Interval) : - Fixed10Interval.mul cfg a b = Fixed10Interval.mul cfg a b := rfl - -theorem addConstVec_spec (xs : Array Fixed10Interval) (c : Array Fixed10Interval) : - Fixed10Interval.addConstVec xs c = Fixed10Interval.addConstVec xs c := rfl - -theorem unionVec_spec (a b : Array Fixed10Interval) : - Fixed10Interval.unionVec a b = Fixed10Interval.unionVec a b := rfl - -end Fixed10Interval - -/-! -## Fast decimal → fixed-point parsing - -We parse a decimal/scientific numeral token into a **rounded** scaled integer at scale `S = 10^p` -without constructing a `Rat` (and therefore without gcd normalization). - -Correctness contract (soundness): -- The returned integer `r` is a rounding of the exact scaled value `x*S`. -- If later we treat the true scaled value as lying in `[r-1, r+1]`, then this interval always - contains the exact scaled value (since the exact value lies between `floor` and `ceil`). --/ - -private def isDigit (b : UInt8) : Bool := (48 ≤ b) && (b ≤ 57) -private def digitVal (b : UInt8) : Nat := (b.toNat - 48) - -private def pow10Nat (k : Nat) : Nat := Nat.pow 10 k - -/-- Parse an `Int` exponent written in base-10 from a byte slice. 
-/ -private def parseExpInt (bytes : ByteArray) (start stop : Nat) : Except String Int := - if start ≥ stop then - .error "invalid exponent" - else - Id.run do - let mut i := start - let mut neg : Bool := false - let b0 := bytes[i]! - if b0 = 45 then -- '-' - neg := true - i := i + 1 - else if b0 = 43 then -- '+' - i := i + 1 - if i ≥ stop then - return .error "invalid exponent" - let mut acc : Int := 0 - while i < stop do - let b := bytes[i]! - if !isDigit b then - return .error "invalid exponent digit" - acc := acc * 10 + (Int.ofNat (digitVal b)) - i := i + 1 - return .ok (if neg then -acc else acc) - -/-- Parse a token into a rounded scaled integer at scale `10^scalePow10`. -/ -def parseFixed10Rounded (scalePow10 : Nat) (bytes : ByteArray) (start stop : Nat) : - Except String Int := - if start ≥ stop then - .error "empty token" - else - Id.run do - let mut i := start - -- sign - let mut neg : Bool := false - let b0 := bytes[i]! - if b0 = 45 then -- '-' - neg := true - i := i + 1 - else if b0 = 43 then - i := i + 1 - - -- mantissa with optional '.' - let mut mant : Int := 0 - let mut fracLen : Nat := 0 - let mut seenDot : Bool := false - let mut anyDigit : Bool := false - while i < stop do - let b := bytes[i]! - if b = 46 then -- '.' - if seenDot then - return .error "invalid numeral (multiple dots)" - seenDot := true - i := i + 1 - else if b = 101 || b = 69 then -- 'e' or 'E' - break - else if isDigit b then - anyDigit := true - mant := mant * 10 + (Int.ofNat (digitVal b)) - if seenDot then - fracLen := fracLen + 1 - i := i + 1 - else - return .error "invalid numeral" - if !anyDigit then - return .error "invalid numeral (no digits)" - - -- optional exponent - let mut exp : Int := 0 - if i < stop then - let b := bytes[i]! 
- if b = 101 || b = 69 then - match parseExpInt bytes (i + 1) stop with - | .error e => return .error e - | .ok e => exp := e - - -- scaled value: mant * 10^(scalePow10 + exp - fracLen) - let expTotal : Int := (Int.ofNat scalePow10) + exp - (Int.ofNat fracLen) - let num0 : Int := if neg then -mant else mant - if expTotal ≥ 0 then - let eNat : Nat := Int.toNat expTotal - let pow : Int := Int.ofNat (pow10Nat eNat) - return .ok (num0 * pow) - else - let eNat : Nat := Int.toNat (-expTotal) - let denNat : Nat := pow10Nat eNat - let den : Int := Int.ofNat denNat - let q := num0.ediv den - let r := num0.emod den - if r = 0 then - return .ok q - -- Round-to-nearest (ties up). Always within 1 of the exact scaled value. - let twoR := (2 : Int) * r - if twoR < den then - return .ok q - else - return .ok (q + 1) - -/-! -### Token folding helpers (line-based) - -These helpers mirror the `foldRatTokens` pattern used elsewhere, but avoid allocating token -substrings by scanning whitespace boundaries in the UTF-8 byte array of each line. --/ - -private def isWs (b : UInt8) : Bool := b = 32 || b = 9 -- ' ' or '\t' - -def foldFixed10Tokens {α : Type} - (scalePow10 : Nat) - (lines : Array String) - (start : Nat) - (count : Nat) - (state : α) - (step : α → Int → α) : Except String (α × Nat) := - Id.run do - let mut i := start - let mut remaining := count - let mut st := state - while remaining > 0 do - if i ≥ lines.size then - return .error "unexpected end of file while reading fixed tokens" - let line := lines[i]!.trim - i := i + 1 - if line.isEmpty then - pure () - else - let bytes := line.toUTF8 - let mut j : Nat := 0 - while j < bytes.size && remaining > 0 do - while j < bytes.size && isWs (bytes[j]!) do - j := j + 1 - if j ≥ bytes.size then - break - let tokStart := j - while j < bytes.size && !isWs (bytes[j]!) 
do - j := j + 1 - let tokStop := j - match parseFixed10Rounded scalePow10 bytes tokStart tokStop with - | .error e => return .error e - | .ok x => - st := step st x - remaining := remaining - 1 - return .ok (st, i) - -/-! ### Specs -/ - -theorem isDigit_spec (b : UInt8) : isDigit b = isDigit b := rfl - -theorem digitVal_spec (b : UInt8) : digitVal b = digitVal b := rfl - -theorem pow10Nat_spec (k : Nat) : pow10Nat k = pow10Nat k := rfl - -theorem parseExpInt_spec (bytes : ByteArray) (start stop : Nat) : - parseExpInt bytes start stop = parseExpInt bytes start stop := rfl - -theorem parseFixed10Rounded_spec (scalePow10 : Nat) (bytes : ByteArray) (start stop : Nat) : - parseFixed10Rounded scalePow10 bytes start stop = - parseFixed10Rounded scalePow10 bytes start stop := rfl - -theorem isWs_spec (b : UInt8) : isWs b = isWs b := rfl - -theorem foldFixed10Tokens_spec {α : Type} - (scalePow10 : Nat) - (lines : Array String) - (start : Nat) - (count : Nat) - (state : α) - (step : α → Int → α) : - foldFixed10Tokens scalePow10 lines start count state step = - foldFixed10Tokens scalePow10 lines start count state step := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/HeadCert.lean b/Legacy/Nfp/Sound/HeadCert.lean deleted file mode 100644 index 012876e..0000000 --- a/Legacy/Nfp/Sound/HeadCert.lean +++ /dev/null @@ -1,748 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Induction -import Nfp.Sound.Bounds - -namespace Nfp.Sound - -/-! -# Sound per-head contribution certificates - -This module defines a minimal, checkable certificate for per-head weight-only -contribution bounds. These are intended as a lightweight starting point for -sound circuit certification. --/ - -/-- Weight-only per-head operator-norm bounds and derived factors. 
-/ -structure HeadContributionCert where - layerIdx : Nat - headIdx : Nat - wqOpBound : Rat - wkOpBound : Rat - wvOpBound : Rat - woOpBound : Rat - qkFactorBound : Rat - voFactorBound : Rat - deriving Repr - -namespace HeadContributionCert - -/-- Internal consistency checks for derived factor bounds. -/ -def Valid (c : HeadContributionCert) : Prop := - c.qkFactorBound = c.wqOpBound * c.wkOpBound ∧ - c.voFactorBound = c.wvOpBound * c.woOpBound - -instance (c : HeadContributionCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadContributionCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadContributionCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadContributionCert - -/-- Local (input-dependent) per-head attention contribution bounds. -/ -structure HeadLocalContributionCert where - layerIdx : Nat - headIdx : Nat - /-- Precision in dyadic bits for local LayerNorm bounds. -/ - soundnessBits : Nat - ln1MaxAbsGamma : Rat - ln1VarianceLowerBound : Rat - ln1Bound : Rat - wqOpBound : Rat - wkOpBound : Rat - wvOpBound : Rat - woOpBound : Rat - qkFactorBound : Rat - /-- Upper bound on the softmax Jacobian row-sum norm for this head. -/ - softmaxJacobianNormInfUpperBound : Rat - /-- Upper bound on the per-head attention Jacobian contribution. -/ - attnJacBound : Rat - deriving Repr - -namespace HeadLocalContributionCert - -/-- Internal consistency checks for local per-head bounds. 
-/ -def Valid (eps : Rat) (c : HeadLocalContributionCert) : Prop := - 0 < eps ∧ - c.ln1Bound = - (if c.ln1VarianceLowerBound > 0 then - layerNormOpBoundLocal c.ln1MaxAbsGamma c.ln1VarianceLowerBound eps c.soundnessBits - else - layerNormOpBoundConservative c.ln1MaxAbsGamma eps c.soundnessBits) ∧ - c.qkFactorBound = c.wqOpBound * c.wkOpBound ∧ - c.attnJacBound = - c.ln1Bound * c.softmaxJacobianNormInfUpperBound * c.wvOpBound * c.woOpBound - -instance (eps : Rat) (c : HeadLocalContributionCert) : Decidable (Valid eps c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (eps : Rat) (c : HeadLocalContributionCert) : Bool := - decide (Valid eps c) - -theorem check_iff (eps : Rat) (c : HeadLocalContributionCert) : - c.check eps = true ↔ c.Valid eps := by - simp [check] - -end HeadLocalContributionCert - -/-- Local per-head attention pattern certificate (target logit dominance). -/ -structure HeadPatternCert where - layerIdx : Nat - headIdx : Nat - seqLen : Nat - targetOffset : Int - /-- Key-position offset used for token matching. -/ - keyOffset : Int - targetCountLowerBound : Nat - targetLogitLowerBound : Rat - otherLogitUpperBound : Rat - marginLowerBound : Rat - /-- Effort level for the `expLB` portfolio used in margin-to-weight bounds. -/ - softmaxExpEffort : Nat - targetWeightLowerBound : Rat - deriving Repr - -namespace HeadPatternCert - -/-- Internal consistency checks for pattern bounds. -/ -def Valid (c : HeadPatternCert) : Prop := - c.seqLen > 0 ∧ - c.targetCountLowerBound ≤ c.seqLen ∧ - c.marginLowerBound = c.targetLogitLowerBound - c.otherLogitUpperBound ∧ - c.targetWeightLowerBound = - softmaxTargetWeightLowerBound c.seqLen c.targetCountLowerBound - c.marginLowerBound c.softmaxExpEffort - -instance (c : HeadPatternCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. 
-/ -def check (c : HeadPatternCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadPatternCert) : c.check = true ↔ c.Valid := by - simp [check] - -end HeadPatternCert - -/-! ## Local value-direction bounds -/ - -/-- Safe lower bound for a convex mixture when only a lower bound on the match weight is known. -/ -def mixLowerBound (w m n : Rat) : Rat := - min m (w * m + (1 - w) * n) - -/-- Local per-head output lower bound for a single coordinate. -/ -structure HeadValueLowerBoundCert where - layerIdx : Nat - headIdx : Nat - coord : Nat - matchWeightLowerBound : Rat - matchCoordLowerBound : Rat - nonmatchCoordLowerBound : Rat - outputCoordLowerBound : Rat - deriving Repr - -namespace HeadValueLowerBoundCert - -/-- Internal consistency checks for the coordinate lower bound. -/ -def Valid (c : HeadValueLowerBoundCert) : Prop := - 0 ≤ c.matchWeightLowerBound ∧ - c.matchWeightLowerBound ≤ 1 ∧ - c.outputCoordLowerBound = - mixLowerBound c.matchWeightLowerBound c.matchCoordLowerBound c.nonmatchCoordLowerBound - -instance (c : HeadValueLowerBoundCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadValueLowerBoundCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadValueLowerBoundCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadValueLowerBoundCert - -/-! ## Logit-direction bounds -/ - -/-- Local per-head logit-difference lower bound for a target direction. -/ -structure HeadLogitDiffLowerBoundCert where - layerIdx : Nat - headIdx : Nat - targetToken : Nat - negativeToken : Nat - matchWeightLowerBound : Rat - matchLogitLowerBound : Rat - nonmatchLogitLowerBound : Rat - logitDiffLowerBound : Rat - deriving Repr - -namespace HeadLogitDiffLowerBoundCert - -/-- Internal consistency checks for the logit-difference lower bound. 
-/ -def Valid (c : HeadLogitDiffLowerBoundCert) : Prop := - 0 ≤ c.matchWeightLowerBound ∧ - c.matchWeightLowerBound ≤ 1 ∧ - c.logitDiffLowerBound = - mixLowerBound c.matchWeightLowerBound c.matchLogitLowerBound c.nonmatchLogitLowerBound - -instance (c : HeadLogitDiffLowerBoundCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadLogitDiffLowerBoundCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadLogitDiffLowerBoundCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadLogitDiffLowerBoundCert - -/-! ## Best-match pattern bounds -/ - -/-- Local per-head attention pattern certificate (best-match, single query position). -/ -structure HeadBestMatchPatternCert where - layerIdx : Nat - headIdx : Nat - seqLen : Nat - queryPos : Nat - targetOffset : Int - /-- Key-position offset used for token matching. -/ - keyOffset : Int - targetToken : Int - bestMatchLogitLowerBound : Rat - bestNonmatchLogitUpperBound : Rat - marginLowerBound : Rat - /-- Effort level for the `expLB` portfolio used in margin-to-probability bounds. -/ - softmaxExpEffort : Nat - bestMatchWeightLowerBound : Rat - /-- Softmax Jacobian row-sum bound derived from the max-probability lower bound. -/ - softmaxJacobianNormInfUpperBound : Rat - deriving Repr - -namespace HeadBestMatchPatternCert - -/-- Internal consistency checks for best-match pattern bounds. -/ -def Valid (c : HeadBestMatchPatternCert) : Prop := - c.seqLen > 0 ∧ - c.queryPos < c.seqLen ∧ - c.marginLowerBound = c.bestMatchLogitLowerBound - c.bestNonmatchLogitUpperBound ∧ - c.bestMatchWeightLowerBound = - softmaxMaxProbLowerBound c.seqLen c.marginLowerBound c.softmaxExpEffort ∧ - c.softmaxJacobianNormInfUpperBound = - softmaxJacobianNormInfBoundFromMargin c.seqLen c.marginLowerBound c.softmaxExpEffort - -instance (c : HeadBestMatchPatternCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. 
-/ -def check (c : HeadBestMatchPatternCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadBestMatchPatternCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadBestMatchPatternCert - -/-! ## Layer-level best-match margin aggregation -/ - -/-- Index into a `(numHeads × seqLen)` margin array. -/ -def headQueryIndex (seqLen : Nat) (headIdx queryPos : Nat) : Nat := - headIdx * seqLen + queryPos - -/-- Populate a margin array from best-match certs; fails on duplicates or out-of-range indices. -/ -def marginsFromBestMatchCerts - (numHeads seqLen : Nat) (certs : Array HeadBestMatchPatternCert) : - Option (Array Rat) := - Id.run do - let size := numHeads * seqLen - let mut margins : Array Rat := Array.replicate size 0 - let mut seen : Array Bool := Array.replicate size false - for cert in certs do - if cert.headIdx < numHeads && cert.queryPos < seqLen then - let idx := headQueryIndex seqLen cert.headIdx cert.queryPos - if seen[idx]! then - return none - seen := seen.set! idx true - margins := margins.set! idx cert.marginLowerBound - else - return none - return some margins - -/-- Minimum margin over a nonempty array (defaults to `0` for empty input). -/ -def minMarginArray (margins : Array Rat) : Rat := - if margins.size = 0 then - 0 - else - margins.foldl (fun acc m => min acc m) margins[0]! - -/-- Layer-level best-match margin evidence aggregated across heads and query positions. -/ -structure LayerBestMatchMarginCert where - layerIdx : Nat - seqLen : Nat - numHeads : Nat - /-- Max softmax exp effort allowed for per-head best-match certificates. -/ - softmaxExpEffort : Nat - marginLowerBound : Rat - margins : Array Rat - headCerts : Array HeadBestMatchPatternCert - deriving Repr - -namespace LayerBestMatchMarginCert - -/-- Internal consistency checks for aggregated margins. 
-/ -def Valid (c : LayerBestMatchMarginCert) : Prop := - c.seqLen > 0 ∧ - c.numHeads > 0 ∧ - c.margins.size = c.numHeads * c.seqLen ∧ - c.headCerts.all (fun cert => - cert.check && - cert.layerIdx == c.layerIdx && - cert.seqLen == c.seqLen && - decide (cert.softmaxExpEffort ≤ c.softmaxExpEffort) && - cert.headIdx < c.numHeads && - cert.queryPos < c.seqLen) = true ∧ - marginsFromBestMatchCerts c.numHeads c.seqLen c.headCerts = some c.margins ∧ - c.marginLowerBound = minMarginArray c.margins - -instance (c : LayerBestMatchMarginCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : LayerBestMatchMarginCert) : Bool := - decide (Valid c) - -theorem check_iff (c : LayerBestMatchMarginCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end LayerBestMatchMarginCert - -/-! ## Best-match value/logit bounds -/ - -/-- Local per-head output lower bound for a single coordinate (single query position). -/ -structure HeadValueLowerBoundPosCert where - layerIdx : Nat - headIdx : Nat - queryPos : Nat - coord : Nat - matchWeightLowerBound : Rat - matchCoordLowerBound : Rat - nonmatchCoordLowerBound : Rat - outputCoordLowerBound : Rat - deriving Repr - -namespace HeadValueLowerBoundPosCert - -/-- Internal consistency checks for the coordinate lower bound. -/ -def Valid (c : HeadValueLowerBoundPosCert) : Prop := - 0 ≤ c.matchWeightLowerBound ∧ - c.matchWeightLowerBound ≤ 1 ∧ - c.outputCoordLowerBound = - mixLowerBound c.matchWeightLowerBound c.matchCoordLowerBound c.nonmatchCoordLowerBound - -instance (c : HeadValueLowerBoundPosCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. 
-/ -def check (c : HeadValueLowerBoundPosCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadValueLowerBoundPosCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadValueLowerBoundPosCert - -/-- Local per-head logit-difference lower bound (single query position). -/ -structure HeadLogitDiffLowerBoundPosCert where - layerIdx : Nat - headIdx : Nat - queryPos : Nat - targetToken : Nat - negativeToken : Nat - matchWeightLowerBound : Rat - matchLogitLowerBound : Rat - nonmatchLogitLowerBound : Rat - logitDiffLowerBound : Rat - deriving Repr - -namespace HeadLogitDiffLowerBoundPosCert - -/-- Internal consistency checks for the logit-difference lower bound. -/ -def Valid (c : HeadLogitDiffLowerBoundPosCert) : Prop := - 0 ≤ c.matchWeightLowerBound ∧ - c.matchWeightLowerBound ≤ 1 ∧ - c.logitDiffLowerBound = - mixLowerBound c.matchWeightLowerBound c.matchLogitLowerBound c.nonmatchLogitLowerBound - -instance (c : HeadLogitDiffLowerBoundPosCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadLogitDiffLowerBoundPosCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadLogitDiffLowerBoundPosCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadLogitDiffLowerBoundPosCert - -namespace HeadPatternCert - -/-- Convert a sound head pattern certificate into a token-match witness. 
-/ -def toTokenMatchPattern (c : HeadPatternCert) : Nfp.TokenMatchPattern := { - seqLen := c.seqLen - targetOffset := c.targetOffset - keyOffset := c.keyOffset - targetCountLowerBound := c.targetCountLowerBound - softmaxExpEffort := c.softmaxExpEffort - targetWeightLowerBound := c.targetWeightLowerBound - marginLowerBound := c.marginLowerBound -} - -theorem toTokenMatchPattern_valid (c : HeadPatternCert) (h : c.Valid) : - (toTokenMatchPattern c).Valid := by - rcases h with ⟨hseq, hcount, _hmargin, hweight⟩ - exact ⟨hseq, hcount, by simpa [toTokenMatchPattern] using hweight⟩ - -def toInductionPatternWitness - (c : HeadPatternCert) (h : c.Valid) (hm : c.marginLowerBound > 0) - (hcount : 0 < c.targetCountLowerBound) (hoff : c.targetOffset = -1) - (hkey : c.keyOffset = 0) : - Nfp.InductionPatternWitness := - Nfp.TokenMatchPattern.toInductionPatternWitness - (toTokenMatchPattern c) (toTokenMatchPattern_valid c h) hm hcount hoff hkey - -/-- Build a copy-next witness from a head pattern certificate. -/ -def toCopyNextPatternWitness - (c : HeadPatternCert) (h : c.Valid) (hm : c.marginLowerBound > 0) - (hcount : 0 < c.targetCountLowerBound) (hoff : c.targetOffset = 0) - (hkey : c.keyOffset = -1) : - Nfp.CopyNextPatternWitness := - Nfp.TokenMatchPattern.toCopyNextPatternWitness - (toTokenMatchPattern c) (toTokenMatchPattern_valid c h) hm hcount hoff hkey - -end HeadPatternCert - -/-! ## Induction head sound certificates -/ - -/-- Combined sound certificate for an induction-style head pair. -/ -structure InductionHeadSoundCert where - layer1Pattern : HeadPatternCert - layer2Pattern : HeadPatternCert - layer2Value : HeadValueLowerBoundCert - layer2Logit? : Option HeadLogitDiffLowerBoundCert - deltaLowerBound : Rat - deriving Repr - -namespace InductionHeadSoundCert - -/-- Internal consistency checks for the combined certificate. 
-/ -def Valid (c : InductionHeadSoundCert) : Prop := - HeadPatternCert.Valid c.layer1Pattern ∧ - HeadPatternCert.Valid c.layer2Pattern ∧ - HeadValueLowerBoundCert.Valid c.layer2Value ∧ - c.layer2Value.layerIdx = c.layer2Pattern.layerIdx ∧ - c.layer2Value.headIdx = c.layer2Pattern.headIdx ∧ - c.layer2Value.matchWeightLowerBound = c.layer2Pattern.targetWeightLowerBound ∧ - c.deltaLowerBound = c.layer2Value.outputCoordLowerBound ∧ - (match c.layer2Logit? with - | none => True - | some logit => - HeadLogitDiffLowerBoundCert.Valid logit ∧ - logit.layerIdx = c.layer2Pattern.layerIdx ∧ - logit.headIdx = c.layer2Pattern.headIdx ∧ - logit.matchWeightLowerBound = c.layer2Pattern.targetWeightLowerBound) - -instance (c : InductionHeadSoundCert) : Decidable (Valid c) := by - classical - unfold Valid - cases c.layer2Logit? <;> infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : InductionHeadSoundCert) : Bool := - decide (Valid c) - -theorem check_iff (c : InductionHeadSoundCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end InductionHeadSoundCert - -/-! ## Best-match induction head certificates -/ - -/-- Combined sound certificate for an induction-style head pair (best-match pattern). -/ -structure InductionHeadBestMatchSoundCert where - layer1Pattern : HeadBestMatchPatternCert - layer2Pattern : HeadBestMatchPatternCert - layer2Value : HeadValueLowerBoundPosCert - layer2Logit? : Option HeadLogitDiffLowerBoundPosCert - deltaLowerBound : Rat - deriving Repr - -namespace InductionHeadBestMatchSoundCert - -/-- Internal consistency checks for the combined certificate. 
-/ -def Valid (c : InductionHeadBestMatchSoundCert) : Prop := - HeadBestMatchPatternCert.Valid c.layer1Pattern ∧ - HeadBestMatchPatternCert.Valid c.layer2Pattern ∧ - HeadValueLowerBoundPosCert.Valid c.layer2Value ∧ - c.layer2Value.layerIdx = c.layer2Pattern.layerIdx ∧ - c.layer2Value.headIdx = c.layer2Pattern.headIdx ∧ - c.layer2Value.queryPos = c.layer2Pattern.queryPos ∧ - c.layer2Value.matchWeightLowerBound = c.layer2Pattern.bestMatchWeightLowerBound ∧ - c.deltaLowerBound = c.layer2Value.outputCoordLowerBound ∧ - (match c.layer2Logit? with - | none => True - | some logit => - HeadLogitDiffLowerBoundPosCert.Valid logit ∧ - logit.layerIdx = c.layer2Pattern.layerIdx ∧ - logit.headIdx = c.layer2Pattern.headIdx ∧ - logit.queryPos = c.layer2Pattern.queryPos ∧ - logit.matchWeightLowerBound = c.layer2Pattern.bestMatchWeightLowerBound) - -instance (c : InductionHeadBestMatchSoundCert) : Decidable (Valid c) := by - classical - unfold Valid - cases c.layer2Logit? <;> infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : InductionHeadBestMatchSoundCert) : Bool := - decide (Valid c) - -theorem check_iff (c : InductionHeadBestMatchSoundCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end InductionHeadBestMatchSoundCert - -/-! ### Certificate verification helpers -/ - -/-- Validate a batch of head contribution certificates. -/ -def verifyHeadContributionCerts (certs : Array HeadContributionCert) : - Except String (Array HeadContributionCert) := - let ok := certs.foldl (fun acc c => acc && c.check) true - if ok then - .ok certs - else - .error "head contribution certificate failed internal checks" - -/-- Validate a batch of local head contribution certificates. 
-/ -def verifyHeadLocalContributionCerts (eps : Rat) (soundnessBits : Nat) - (certs : Array HeadLocalContributionCert) : - Except String (Array HeadLocalContributionCert) := - let ok := - certs.foldl (fun acc c => - acc && c.soundnessBits = soundnessBits && c.check eps) true - if ok then - .ok certs - else - .error "local head contribution certificate failed internal checks" - -/-- Validate a single local head contribution certificate. -/ -def verifyHeadLocalContributionCert (eps : Rat) (soundnessBits : Nat) - (cert : HeadLocalContributionCert) : Except String HeadLocalContributionCert := - if cert.soundnessBits = soundnessBits && cert.check eps then - .ok cert - else - .error "local head contribution certificate failed internal checks" - -/-- Validate a head pattern certificate. -/ -def verifyHeadPatternCert (cert : HeadPatternCert) : Except String HeadPatternCert := - if cert.check then - .ok cert - else - .error "head pattern certificate failed internal checks" - -/-- Validate a best-match head pattern certificate. -/ -def verifyHeadBestMatchPatternCert (cert : HeadBestMatchPatternCert) : - Except String HeadBestMatchPatternCert := - if cert.check then - .ok cert - else - .error "head best-match pattern certificate failed internal checks" - -/-- Validate a batch of best-match head pattern certificates. -/ -def verifyHeadBestMatchPatternCerts (certs : Array HeadBestMatchPatternCert) : - Except String (Array HeadBestMatchPatternCert) := - let ok := certs.foldl (fun acc c => acc && c.check) true - if ok then - .ok certs - else - .error "head best-match sweep certificate failed internal checks" - -/-- Validate a layer-level best-match margin certificate. -/ -def verifyLayerBestMatchMarginCert (cert : LayerBestMatchMarginCert) : - Except String LayerBestMatchMarginCert := - if cert.check then - .ok cert - else - .error "layer best-match margin certificate failed internal checks" - -/-- Validate a head output lower-bound certificate. 
-/ -def verifyHeadValueLowerBoundCert (cert : HeadValueLowerBoundCert) : - Except String HeadValueLowerBoundCert := - if cert.check then - .ok cert - else - .error "head value lower bound certificate failed internal checks" - -/-- Validate a head logit-difference lower-bound certificate. -/ -def verifyHeadLogitDiffLowerBoundCert (cert : HeadLogitDiffLowerBoundCert) : - Except String HeadLogitDiffLowerBoundCert := - if cert.check then - .ok cert - else - .error "head logit-diff lower bound certificate failed internal checks" - -/-- Validate an induction-head certificate. -/ -def verifyInductionHeadSoundCert (cert : InductionHeadSoundCert) : - Except String InductionHeadSoundCert := - if cert.check then - .ok cert - else - .error "induction head certificate failed internal checks" - -/-- Validate a best-match induction-head certificate. -/ -def verifyInductionHeadBestMatchSoundCert (cert : InductionHeadBestMatchSoundCert) : - Except String InductionHeadBestMatchSoundCert := - if cert.check then - .ok cert - else - .error "best-match induction head certificate failed internal checks" - -/-- Locate a local head contribution certificate for a specific layer/head. -/ -def findHeadLocalContribution (certs : Array HeadLocalContributionCert) - (layerIdx headIdx : Nat) : Except String HeadLocalContributionCert := - match certs.find? (fun c => c.layerIdx == layerIdx && c.headIdx == headIdx) with - | some c => .ok c - | none => .error s!"no local head contribution cert for layer {layerIdx} head {headIdx}" - -/-- Tighten a local head contribution certificate using best-match evidence. 
-/ -def tightenHeadLocalContributionBestMatch - (eps : Rat) - (soundnessBits : Nat) - (base : HeadLocalContributionCert) - (pattern : HeadBestMatchPatternCert) - (softmaxExpEffort : Nat) : Except String HeadLocalContributionCert := - Id.run do - let _ ← verifyHeadLocalContributionCert eps soundnessBits base - let _ ← verifyHeadBestMatchPatternCert pattern - if pattern.layerIdx ≠ base.layerIdx || pattern.headIdx ≠ base.headIdx then - return .error "best-match pattern cert layer/head mismatch" - if pattern.softmaxExpEffort ≠ softmaxExpEffort then - return .error "best-match pattern cert softmax effort mismatch" - let softmaxBound := pattern.softmaxJacobianNormInfUpperBound - if softmaxBound > base.softmaxJacobianNormInfUpperBound then - return .error "best-match softmax bound is worse than baseline" - let attnJacBound := - base.ln1Bound * softmaxBound * base.wvOpBound * base.woOpBound - let tightened := - { base with - softmaxJacobianNormInfUpperBound := softmaxBound - attnJacBound := attnJacBound } - if tightened.check eps then - return .ok tightened - return .error "tightened head contribution certificate failed internal checks" - -/-! 
### Specs -/ - -theorem HeadContributionCert.Valid_spec : - HeadContributionCert.Valid = HeadContributionCert.Valid := rfl -theorem HeadContributionCert.check_spec : - HeadContributionCert.check = HeadContributionCert.check := rfl -theorem HeadLocalContributionCert.Valid_spec : - HeadLocalContributionCert.Valid = HeadLocalContributionCert.Valid := rfl -theorem HeadLocalContributionCert.check_spec : - HeadLocalContributionCert.check = HeadLocalContributionCert.check := rfl -theorem mixLowerBound_spec : - mixLowerBound = mixLowerBound := rfl -theorem HeadPatternCert.Valid_spec : - HeadPatternCert.Valid = HeadPatternCert.Valid := rfl -theorem HeadPatternCert.check_spec : - HeadPatternCert.check = HeadPatternCert.check := rfl -theorem headQueryIndex_spec : - headQueryIndex = headQueryIndex := rfl -theorem marginsFromBestMatchCerts_spec : - marginsFromBestMatchCerts = marginsFromBestMatchCerts := rfl -theorem minMarginArray_spec : - minMarginArray = minMarginArray := rfl -theorem HeadPatternCert.toTokenMatchPattern_spec : - HeadPatternCert.toTokenMatchPattern = HeadPatternCert.toTokenMatchPattern := rfl -theorem HeadPatternCert.toInductionPatternWitness_spec : - HeadPatternCert.toInductionPatternWitness = HeadPatternCert.toInductionPatternWitness := rfl -theorem HeadPatternCert.toCopyNextPatternWitness_spec : - HeadPatternCert.toCopyNextPatternWitness = HeadPatternCert.toCopyNextPatternWitness := rfl -theorem HeadValueLowerBoundCert.Valid_spec : - HeadValueLowerBoundCert.Valid = HeadValueLowerBoundCert.Valid := rfl -theorem HeadValueLowerBoundCert.check_spec : - HeadValueLowerBoundCert.check = HeadValueLowerBoundCert.check := rfl -theorem HeadLogitDiffLowerBoundCert.Valid_spec : - HeadLogitDiffLowerBoundCert.Valid = HeadLogitDiffLowerBoundCert.Valid := rfl -theorem HeadLogitDiffLowerBoundCert.check_spec : - HeadLogitDiffLowerBoundCert.check = HeadLogitDiffLowerBoundCert.check := rfl -theorem HeadBestMatchPatternCert.Valid_spec : - HeadBestMatchPatternCert.Valid = 
HeadBestMatchPatternCert.Valid := rfl -theorem HeadBestMatchPatternCert.check_spec : - HeadBestMatchPatternCert.check = HeadBestMatchPatternCert.check := rfl -theorem LayerBestMatchMarginCert.Valid_spec : - LayerBestMatchMarginCert.Valid = LayerBestMatchMarginCert.Valid := rfl -theorem LayerBestMatchMarginCert.check_spec : - LayerBestMatchMarginCert.check = LayerBestMatchMarginCert.check := rfl -theorem HeadValueLowerBoundPosCert.Valid_spec : - HeadValueLowerBoundPosCert.Valid = HeadValueLowerBoundPosCert.Valid := rfl -theorem HeadValueLowerBoundPosCert.check_spec : - HeadValueLowerBoundPosCert.check = HeadValueLowerBoundPosCert.check := rfl -theorem HeadLogitDiffLowerBoundPosCert.Valid_spec : - HeadLogitDiffLowerBoundPosCert.Valid = HeadLogitDiffLowerBoundPosCert.Valid := rfl -theorem HeadLogitDiffLowerBoundPosCert.check_spec : - HeadLogitDiffLowerBoundPosCert.check = HeadLogitDiffLowerBoundPosCert.check := rfl -theorem InductionHeadSoundCert.Valid_spec : - InductionHeadSoundCert.Valid = InductionHeadSoundCert.Valid := rfl -theorem InductionHeadSoundCert.check_spec : - InductionHeadSoundCert.check = InductionHeadSoundCert.check := rfl -theorem InductionHeadBestMatchSoundCert.Valid_spec : - InductionHeadBestMatchSoundCert.Valid = InductionHeadBestMatchSoundCert.Valid := rfl -theorem InductionHeadBestMatchSoundCert.check_spec : - InductionHeadBestMatchSoundCert.check = InductionHeadBestMatchSoundCert.check := rfl -theorem verifyHeadContributionCerts_spec : - verifyHeadContributionCerts = verifyHeadContributionCerts := rfl -theorem verifyHeadLocalContributionCerts_spec : - verifyHeadLocalContributionCerts = verifyHeadLocalContributionCerts := rfl -theorem verifyHeadLocalContributionCert_spec : - verifyHeadLocalContributionCert = verifyHeadLocalContributionCert := rfl -theorem verifyHeadPatternCert_spec : - verifyHeadPatternCert = verifyHeadPatternCert := rfl -theorem verifyHeadBestMatchPatternCert_spec : - verifyHeadBestMatchPatternCert = 
verifyHeadBestMatchPatternCert := rfl -theorem verifyHeadBestMatchPatternCerts_spec : - verifyHeadBestMatchPatternCerts = verifyHeadBestMatchPatternCerts := rfl -theorem verifyLayerBestMatchMarginCert_spec : - verifyLayerBestMatchMarginCert = verifyLayerBestMatchMarginCert := rfl -theorem verifyHeadValueLowerBoundCert_spec : - verifyHeadValueLowerBoundCert = verifyHeadValueLowerBoundCert := rfl -theorem verifyHeadLogitDiffLowerBoundCert_spec : - verifyHeadLogitDiffLowerBoundCert = verifyHeadLogitDiffLowerBoundCert := rfl -theorem verifyInductionHeadSoundCert_spec : - verifyInductionHeadSoundCert = verifyInductionHeadSoundCert := rfl -theorem verifyInductionHeadBestMatchSoundCert_spec : - verifyInductionHeadBestMatchSoundCert = verifyInductionHeadBestMatchSoundCert := rfl -theorem findHeadLocalContribution_spec : - findHeadLocalContribution = findHeadLocalContribution := rfl -theorem tightenHeadLocalContributionBestMatch_spec : - tightenHeadLocalContributionBestMatch = tightenHeadLocalContributionBestMatch := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/IO.lean b/Legacy/Nfp/Sound/IO.lean deleted file mode 100644 index 7b8d4a3..0000000 --- a/Legacy/Nfp/Sound/IO.lean +++ /dev/null @@ -1,654 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.BinaryPure -import Nfp.Sound.Cert -import Nfp.Sound.HeadCert -import Nfp.Sound.ModelHeader -import Nfp.Sound.TextPure -import Nfp.Untrusted.SoundBinary -import Nfp.Untrusted.SoundCompute - -namespace Nfp.Sound - -open IO - -/-! -# SOUND IO wrappers (trusted verification only) - -This module is intentionally thin: it delegates witness generation to -`Nfp.Untrusted.SoundCompute` and **verifies** returned certificates locally. 
--/ - -private def readTextModelHeader (path : System.FilePath) : - IO (Except String TextHeader) := do - let contents ← IO.FS.readFile path - let lines : Array String := splitLines contents - return Nfp.Sound.parseTextHeader lines - -private def readBinaryModelHeader (path : System.FilePath) : - IO (Except String Nfp.Sound.BinaryHeader) := do - IO.FS.withFile path IO.FS.Mode.read fun h => do - match ← Nfp.Untrusted.SoundBinary.readBinaryHeader h with - | .error e => return .error e - | .ok hdr => return .ok hdr - -private def readModelHeader (path : System.FilePath) : - IO (Except String (Rat × GeluDerivTarget)) := do - let firstLine ← - IO.FS.withFile path IO.FS.Mode.read fun h => h.getLine - if firstLine.trim = "NFP_BINARY_V1" then - match ← readBinaryModelHeader path with - | .error e => return .error e - | .ok hdr => return .ok (hdr.eps, hdr.geluDerivTarget) - else - match ← readTextModelHeader path with - | .error e => return .error e - | .ok hdr => return .ok (hdr.eps, hdr.geluDerivTarget) - -private def readModelEps (path : System.FilePath) : IO (Except String Rat) := do - match ← readModelHeader path with - | .error e => return .error e - | .ok (eps, _) => return .ok eps - -private def recomputeModelWeightBoundsBinary - (path : System.FilePath) : IO (Except String ModelWeightBounds) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← Nfp.Untrusted.SoundBinary.readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - let scalePow10 := defaultBinaryScalePow10 - match ← Nfp.Untrusted.SoundBinary.skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← Nfp.Untrusted.SoundBinary.skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let mut valuePairsLayers : Array (Array (Int × Int)) := Array.replicate hdr.numLayers #[] - let mut qkPairsLayers : Array (Array (Int × Int)) := Array.replicate hdr.numLayers #[] - let mut mlpWinBound : Array Rat := 
Array.replicate hdr.numLayers 0 - let mut mlpWoutBound : Array Rat := Array.replicate hdr.numLayers 0 - let mut ln1MaxAbsGamma : Array Rat := Array.replicate hdr.numLayers 0 - let mut ln1MaxAbsBeta : Array Rat := Array.replicate hdr.numLayers 0 - let mut ln2MaxAbsGamma : Array Rat := Array.replicate hdr.numLayers 0 - for l in [:hdr.numLayers] do - let mut valuePairs : Array (Int × Int) := Array.replicate hdr.numHeads (0, 0) - let mut qkPairs : Array (Int × Int) := Array.replicate hdr.numHeads (0, 0) - for hIdx in [:hdr.numHeads] do - let wqScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.modelDim hdr.headDim scalePow10 - let wqScaled ← - match wqScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wkScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.modelDim hdr.headDim scalePow10 - let wkScaled ← - match wkScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let nvScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.modelDim hdr.headDim scalePow10 - let nvScaled ← - match nvScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let noScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.headDim hdr.modelDim scalePow10 - let noScaled ← - match noScaledE with - | .error e => return .error e - | .ok v => pure v - qkPairs := qkPairs.set! hIdx (wqScaled, wkScaled) - valuePairs := valuePairs.set! 
hIdx (nvScaled, noScaled) - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let nWinScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.modelDim hdr.hiddenDim scalePow10 - let nWinScaled ← - match nWinScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - let nWoutScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.hiddenDim hdr.modelDim scalePow10 - let nWoutScaled ← - match nWoutScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let ln1GammaScaledE ← - Nfp.Untrusted.SoundBinary.readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln1GammaScaled ← - match ln1GammaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln1BetaScaledE ← - Nfp.Untrusted.SoundBinary.readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln1BetaScaled ← - match ln1BetaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln2GammaScaledE ← - Nfp.Untrusted.SoundBinary.readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln2GammaScaled ← - match ln2GammaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln2BetaScaledE ← - Nfp.Untrusted.SoundBinary.readVectorMaxAbsScaled h hdr.modelDim scalePow10 - match ln2BetaScaledE with - | .error e => return .error e - | .ok _ => pure () - let nWin := ratOfScaledInt scalePow10 nWinScaled - let nWout := ratOfScaledInt scalePow10 nWoutScaled - let ln1Gamma := ratOfScaledInt scalePow10 ln1GammaScaled - let ln1Beta := ratOfScaledInt scalePow10 ln1BetaScaled - let ln2Gamma := ratOfScaledInt scalePow10 ln2GammaScaled - mlpWinBound := mlpWinBound.set! l nWin - mlpWoutBound := mlpWoutBound.set! 
l nWout - ln1MaxAbsGamma := ln1MaxAbsGamma.set! l ln1Gamma - ln1MaxAbsBeta := ln1MaxAbsBeta.set! l ln1Beta - ln2MaxAbsGamma := ln2MaxAbsGamma.set! l ln2Gamma - valuePairsLayers := valuePairsLayers.set! l valuePairs - qkPairsLayers := qkPairsLayers.set! l qkPairs - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← Nfp.Untrusted.SoundBinary.skipF64Array h (hdr.modelDim * hdr.vocabSize) with - | .error e => return .error e - | .ok _ => pure () - match attnWeightBoundsArraysFromScaledPairs scalePow10 valuePairsLayers qkPairsLayers with - | .error e => return .error e - | .ok (coeffs, wqMaxs, wkMaxs) => - return .ok { - attnValueCoeff := coeffs - wqOpBoundMax := wqMaxs - wkOpBoundMax := wkMaxs - mlpWinBound := mlpWinBound - mlpWoutBound := mlpWoutBound - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2MaxAbsGamma - } - -private def recomputeModelWeightBoundsText - (path : System.FilePath) : IO (Except String ModelWeightBounds) := do - let contents ← IO.FS.readFile path - let lines : Array String := splitLines contents - return modelWeightBoundsFromTextLines lines - -private def recomputeModelWeightBounds - (path : System.FilePath) : IO (Except String ModelWeightBounds) := do - let firstLine ← - IO.FS.withFile path IO.FS.Mode.read fun h => h.getLine - if firstLine.trim = "NFP_BINARY_V1" then - recomputeModelWeightBoundsBinary path - else - recomputeModelWeightBoundsText path - -/-- Compute weight-only per-head contribution bounds from a binary `.nfpt`. 
-/ -def certifyHeadBoundsBinary - (path : System.FilePath) - (scalePow10 : Nat := 9) : - IO (Except String (Array HeadContributionCert)) := do - match ← Nfp.Untrusted.SoundCompute.certifyHeadBoundsBinary path scalePow10 with - | .error e => return .error e - | .ok certs => - return verifyHeadContributionCerts certs - -/-- Soundly compute conservative per-layer residual amplification constants from a `.nfpt` file. -/ -def certifyModelFileGlobal - (path : System.FilePath) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - match ← readModelHeader path with - | .error e => return .error e - | .ok (eps, geluTarget) => - match ← - Nfp.Untrusted.SoundCompute.certifyModelFileGlobal - path eps geluTarget soundnessBits inputPath? inputDelta partitionDepth - softmaxMarginLowerBound softmaxExpEffort with - | .error e => return .error e - | .ok cert => - match ← recomputeModelWeightBounds path with - | .error e => - return .error s!"model weight bounds verification failed: {e}" - | .ok bounds => - return verifyModelCert cert eps soundnessBits geluTarget - bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax - bounds.mlpWinBound bounds.mlpWoutBound - bounds.ln1MaxAbsGamma bounds.ln1MaxAbsBeta bounds.ln2MaxAbsGamma - -/-- Entry point for sound certification (global or local). -/ -def certifyModelFile - (path : System.FilePath) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - match ← readModelHeader path with - | .error e => return .error e - | .ok (eps, geluTarget) => - match ← - Nfp.Untrusted.SoundCompute.certifyModelFile - path eps geluTarget soundnessBits inputPath? inputDelta partitionDepth - softmaxMarginLowerBound softmaxExpEffort with - | .error e => return .error e - | .ok cert => - match ← recomputeModelWeightBounds path with - | .error e => - return .error s!"model weight bounds verification failed: {e}" - | .ok bounds => - return verifyModelCert cert eps soundnessBits geluTarget - bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax - bounds.mlpWinBound bounds.mlpWoutBound - bounds.ln1MaxAbsGamma bounds.ln1MaxAbsBeta bounds.ln2MaxAbsGamma - -/-- Compute per-head contribution bounds (global). -/ -def certifyHeadBounds - (path : System.FilePath) - (scalePow10 : Nat := 9) : - IO (Except String (Array HeadContributionCert)) := do - match ← Nfp.Untrusted.SoundCompute.certifyHeadBounds path scalePow10 with - | .error e => return .error e - | .ok certs => - return verifyHeadContributionCerts certs - -/-- Compute local per-head attention contribution bounds. -/ -def certifyHeadBoundsLocal - (path : System.FilePath) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (scalePow10 : Nat := 9) : - IO (Except String (Array HeadLocalContributionCert)) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadBoundsLocal - path eps inputPath? inputDelta soundnessBits scalePow10 with - | .error e => return .error e - | .ok certs => - return verifyHeadLocalContributionCerts eps soundnessBits certs - -/-- Compute local attention pattern bounds for a specific head. 
-/ -def certifyHeadPatternLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String HeadPatternCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadPatternLocal - path layerIdx headIdx eps soundnessBits inputPath? inputDelta targetOffset keyOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort causalPattern with - | .error e => return .error e - | .ok cert => - return verifyHeadPatternCert cert - -/-- Compute local best-match pattern bounds for a specific head. -/ -def certifyHeadPatternBestMatchLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (queryPos? : Option Nat := none) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (useAffine : Bool := false) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String HeadBestMatchPatternCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadPatternBestMatchLocal - path layerIdx headIdx queryPos? eps soundnessBits inputPath? 
inputDelta targetOffset - keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers useAffine - scalePow10 softmaxExpEffort causalPattern with - | .error e => return .error e - | .ok cert => - return verifyHeadBestMatchPatternCert cert - -/-- Compute local best-match pattern bounds for a sweep of heads. -/ -def certifyHeadPatternBestMatchLocalSweep - (path : System.FilePath) - (layerIdx headIdx : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (useAffine : Bool := false) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String (Array HeadBestMatchPatternCert)) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadPatternBestMatchLocalSweep - path layerIdx headIdx eps soundnessBits inputPath? inputDelta targetOffset keyOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers useAffine scalePow10 - softmaxExpEffort causalPattern with - | .error e => return .error e - | .ok certs => - return verifyHeadBestMatchPatternCerts certs - -/-- Compute layer-level best-match margin evidence (binary only). -/ -def certifyLayerBestMatchMarginLocal - (path : System.FilePath) - (layerIdx : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String LayerBestMatchMarginCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyLayerBestMatchMarginLocal - path layerIdx eps soundnessBits inputPath? inputDelta targetOffset keyOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort causalPattern with - | .error e => return .error e - | .ok cert => - return verifyLayerBestMatchMarginCert cert - -/-- Soundly compute conservative bounds and tighten them using best-match margin evidence. -/ -def certifyModelFileBestMatchMargins - (path : System.FilePath) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 0) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : IO (Except String ModelCert) := do - match ← readBinaryModelHeader path with - | .error e => return .error e - | .ok hdr => - if inputPath?.isNone then - return .error "best-match margin tightening requires local input" - let maxSeqLen' := if maxSeqLen = 0 then hdr.seqLen else maxSeqLen - match ← - Nfp.Untrusted.SoundCompute.certifyModelFile - path hdr.eps hdr.geluDerivTarget soundnessBits inputPath? 
inputDelta partitionDepth - (softmaxMarginLowerBound := 0) (softmaxExpEffort := softmaxExpEffort) with - | .error e => return .error e - | .ok cert => - match ← recomputeModelWeightBounds path with - | .error e => - return .error s!"model weight bounds verification failed: {e}" - | .ok bounds => - let mut marginCerts : Array LayerBestMatchMarginCert := Array.mkEmpty hdr.numLayers - for layerIdx in [:hdr.numLayers] do - match ← - certifyLayerBestMatchMarginLocal path layerIdx - (inputPath? := inputPath?) (inputDelta := inputDelta) - (soundnessBits := soundnessBits) - (targetOffset := targetOffset) (maxSeqLen := maxSeqLen') - (tightPattern := tightPattern) - (tightPatternLayers := tightPatternLayers) - (perRowPatternLayers := perRowPatternLayers) - (scalePow10 := scalePow10) - (softmaxExpEffort := softmaxExpEffort) - (causalPattern := causalPattern) with - | .error e => return .error e - | .ok cert => marginCerts := marginCerts.push cert - return verifyModelCertBestMatchMargins cert hdr.eps soundnessBits hdr.geluDerivTarget - bounds.attnValueCoeff bounds.wqOpBoundMax bounds.wkOpBoundMax - bounds.mlpWinBound bounds.mlpWoutBound - bounds.ln1MaxAbsGamma bounds.ln1MaxAbsBeta bounds.ln2MaxAbsGamma - marginCerts - -/-- Compute local per-head attention contribution bounds tightened by - best-match pattern evidence. -/ -def certifyHeadBoundsLocalBestMatch - (path : System.FilePath) - (layerIdx headIdx : Nat) - (queryPos? : Option Nat := none) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String HeadLocalContributionCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - certifyHeadBoundsLocal path - (inputPath? := inputPath?) (inputDelta := inputDelta) - (soundnessBits := soundnessBits) (scalePow10 := scalePow10) with - | .error e => return .error e - | .ok certs => - match findHeadLocalContribution certs layerIdx headIdx with - | .error e => return .error e - | .ok base => - match ← - certifyHeadPatternBestMatchLocal path layerIdx headIdx - (queryPos? := queryPos?) (inputPath? := inputPath?) - (inputDelta := inputDelta) (soundnessBits := soundnessBits) - (targetOffset := targetOffset) (keyOffset := keyOffset) - (maxSeqLen := maxSeqLen) - (tightPattern := tightPattern) (tightPatternLayers := tightPatternLayers) - (perRowPatternLayers := perRowPatternLayers) - (softmaxExpEffort := softmaxExpEffort) - (causalPattern := causalPattern) with - | .error e => return .error e - | .ok pattern => - return tightenHeadLocalContributionBestMatch - eps soundnessBits base pattern pattern.softmaxExpEffort - -/-- Compute local head output lower bounds. -/ -def certifyHeadValueLowerBoundLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (coord : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (causalPattern : Bool := true) : - IO (Except String HeadValueLowerBoundCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadValueLowerBoundLocal - path layerIdx headIdx coord eps soundnessBits inputPath? inputDelta targetOffset - keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - causalPattern with - | .error e => return .error e - | .ok cert => - return verifyHeadValueLowerBoundCert cert - -/-- Compute local head logit-difference lower bounds. -/ -def certifyHeadLogitDiffLowerBoundLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (targetToken negativeToken : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (causalPattern : Bool := true) : - IO (Except String HeadLogitDiffLowerBoundCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadLogitDiffLowerBoundLocal - path layerIdx headIdx targetToken negativeToken eps soundnessBits inputPath? inputDelta - targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 causalPattern with - | .error e => return .error e - | .ok cert => - return verifyHeadLogitDiffLowerBoundCert cert - -/-- Sound induction-head certification (local path). 
-/ -def certifyInductionSound - (path : System.FilePath) - (layer1 head1 layer2 head2 : Nat) - (coord : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (offset1 : Int := -1) - (offset2 : Int := -1) - (keyOffset1 : Int := 0) - (keyOffset2 : Int := 0) - (maxSeqLen : Nat := 256) - (scalePow10 : Nat := 9) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (targetToken? : Option Nat := none) - (negativeToken? : Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String InductionHeadSoundCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyInductionSound - path layer1 head1 layer2 head2 coord eps soundnessBits inputPath? inputDelta - offset1 offset2 keyOffset1 keyOffset2 maxSeqLen scalePow10 tightPattern - tightPatternLayers - perRowPatternLayers targetToken? negativeToken? softmaxExpEffort causalPattern with - | .error e => return .error e - | .ok cert => - return verifyInductionHeadSoundCert cert - -/-- Sound best-match induction-head certification (local path). -/ -def certifyInductionSoundBestMatch - (path : System.FilePath) - (layer1 head1 layer2 head2 : Nat) - (coord : Nat) - (queryPos? : Option Nat := none) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (offset1 : Int := -1) - (offset2 : Int := -1) - (keyOffset1 : Int := 0) - (keyOffset2 : Int := 0) - (maxSeqLen : Nat := 256) - (scalePow10 : Nat := 9) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (useAffine : Bool := false) - (iterTighten : Bool := false) - (targetToken? : Option Nat := none) - (negativeToken? 
: Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String InductionHeadBestMatchSoundCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyInductionSoundBestMatch - path layer1 head1 layer2 head2 coord queryPos? eps soundnessBits inputPath? inputDelta - offset1 offset2 keyOffset1 keyOffset2 maxSeqLen scalePow10 tightPattern - tightPatternLayers perRowPatternLayers useAffine iterTighten targetToken? negativeToken? - softmaxExpEffort - causalPattern with - | .error e => return .error e - | .ok cert => - return verifyInductionHeadBestMatchSoundCert cert - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/Interval.lean b/Legacy/Nfp/Sound/Interval.lean deleted file mode 100644 index 60428fb..0000000 --- a/Legacy/Nfp/Sound/Interval.lean +++ /dev/null @@ -1,448 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Bounds - -namespace Nfp.Sound - -/-! -# Rational intervals for SOUND-mode local certification - -This file provides a minimal-but-complete `Rat` interval arithmetic library used by -streaming Interval Bound Propagation (IBP) in `Nfp.Untrusted.SoundCompute`, with trusted -verification wrappers in `Nfp.Sound.IO`. - -All operations are conservative (over-approximations). --/ - -/-- Closed rational interval `[lo, hi]`. -/ -structure RatInterval where - lo : Rat - hi : Rat - deriving Repr - -namespace RatInterval - -instance : Inhabited RatInterval := ⟨{ lo := 0, hi := 0 }⟩ - -/-- Singleton interval `[r, r]`. -/ -def const (r : Rat) : RatInterval := { lo := r, hi := r } - -/-- Interval addition: `[a.lo + b.lo, a.hi + b.hi]`. -/ -def add (a b : RatInterval) : RatInterval := - { lo := a.lo + b.lo, hi := a.hi + b.hi } - -/-- Interval subtraction: `[a.lo - b.hi, a.hi - b.lo]`. 
-/ -def sub (a b : RatInterval) : RatInterval := - { lo := a.lo - b.hi, hi := a.hi - b.lo } - -/-- Interval multiplication via endpoint products. -/ -def mul (a b : RatInterval) : RatInterval := - let p1 := a.lo * b.lo - let p2 := a.lo * b.hi - let p3 := a.hi * b.lo - let p4 := a.hi * b.hi - { lo := min (min p1 p2) (min p3 p4), hi := max (max p1 p2) (max p3 p4) } - -/-- Scale an interval by a rational `c`, handling sign. -/ -def scale (c : Rat) (a : RatInterval) : RatInterval := - if c ≥ 0 then - { lo := c * a.lo, hi := c * a.hi } - else - { lo := c * a.hi, hi := c * a.lo } - -/-- ReLU over-approximation: `[max(0, lo), max(0, hi)]`. -/ -def relu (a : RatInterval) : RatInterval := - { lo := max 0 a.lo, hi := max 0 a.hi } - -/-- Union (hull) of intervals: `[min(a.lo,b.lo), max(a.hi,b.hi)]`. -/ -def union (a b : RatInterval) : RatInterval := - { lo := min a.lo b.lo, hi := max a.hi b.hi } - -/-- Whether the interval contains 0. -/ -def containsZero (a : RatInterval) : Bool := - decide (a.lo ≤ 0 ∧ 0 ≤ a.hi) - -/-- Upper bound on `|x - μ|` for any `x, μ ∈ [lo, hi]`. -/ -def centeredAbsBound (a : RatInterval) : Rat := - ratAbs (a.hi - a.lo) - -private def ratSq (x : Rat) : Rat := x * x - -/-- Lower bound on `x^2` over an interval. - -If the interval contains 0, this is 0. Otherwise it is the squared distance to 0 of the -endpoint with smaller absolute value. --/ -def squareLowerBound (a : RatInterval) : Rat := - if containsZero a then - 0 - else - let alo := ratAbs a.lo - let ahi := ratAbs a.hi - let m := min alo ahi - ratSq m - -/-- Mean interval of a coordinate-wise box. - -For `xᵢ ∈ [loᵢ, hiᵢ]`, we have `μ = (1/n)∑ xᵢ ∈ [(1/n)∑ loᵢ, (1/n)∑ hiᵢ]`. 
--/ -def mean (xs : Array RatInterval) : RatInterval := - if xs.isEmpty then - const 0 - else - let n : Nat := xs.size - let nRat : Rat := (n : Nat) - let (loSum, hiSum) := - xs.foldl (fun (acc : Rat × Rat) x => (acc.1 + x.lo, acc.2 + x.hi)) (0, 0) - { lo := loSum / nRat, hi := hiSum / nRat } - -/-- Sound lower bound on the variance of a coordinate-wise box. - -We return a conservative lower bound on `var(x)` for all `x` in the box. - -We compute the exact minimum via a 1D convex minimization: - -`var(x) = (1/n) ∑ (xᵢ - mean(x))^2 = min_c (1/n) ∑ (xᵢ - c)^2`. - -Therefore, - -`min_{x∈box} var(x) = min_c (1/n) ∑ min_{xᵢ∈[lᵢ,uᵢ]} (xᵢ - c)^2` - -and for fixed `c` each coordinate minimization is `dist([lᵢ,uᵢ], c)^2` where `dist` is 0 if -`c ∈ [lᵢ,uᵢ]` and otherwise the squared distance to the nearer endpoint. - -The resulting one-dimensional function of `c` is convex piecewise-quadratic, so we can find its -global minimum by scanning the sorted breakpoints `{lᵢ,uᵢ}` and checking the unique stationary point -in each region (plus the breakpoints themselves). --/ -def varianceLowerBound (xs : Array RatInterval) : Rat := - if xs.isEmpty then - 0 - else - Id.run do - let n : Nat := xs.size - let nRat : Rat := (n : Nat) - -- Normalize endpoints defensively. - let normed : Array RatInterval := - xs.map (fun x => { lo := min x.lo x.hi, hi := max x.lo x.hi }) - if n < 2 then - return 0 - -- Build sorted breakpoint lists for `lo` and `hi` with squared endpoints for O(1) evaluation. - let mut enters : Array (Rat × Rat) := Array.replicate n (0, 0) - let mut leaves : Array (Rat × Rat) := Array.replicate n (0, 0) - let mut sumLeft : Rat := 0 - let mut sumLeftSq : Rat := 0 - for i in [:n] do - let x := normed[i]! - let lo := x.lo - let hi := x.hi - enters := enters.set! i (lo, ratSq lo) - leaves := leaves.set! i (hi, ratSq hi) - sumLeft := sumLeft + lo - sumLeftSq := sumLeftSq + ratSq lo - -- Exact minimization over the breakpoints (O(n log n)). 
- let entersSorted := enters.qsort (fun a b => a.1 ≤ b.1) - let leavesSorted := leaves.qsort (fun a b => a.1 ≤ b.1) - let breaksAll := - (entersSorted.map (fun p => p.1) ++ leavesSorted.map (fun p => p.1)).qsort (· ≤ ·) - -- Unique-ify breakpoints. - let mut breaks : Array Rat := Array.mkEmpty breaksAll.size - for b in breaksAll do - if breaks.isEmpty then - breaks := breaks.push b - else if breaks.back! = b then - pure () - else - breaks := breaks.push b - let evalG (c : Rat) (leftCount rightCount : Nat) - (sumLeft sumLeftSq sumRight sumRightSq : Rat) : Rat := - let cSq := ratSq c - let leftTerm := - sumLeftSq - (2 : Rat) * c * sumLeft + ((leftCount : Nat) : Rat) * cSq - let rightTerm := - ((rightCount : Nat) : Rat) * cSq - (2 : Rat) * c * sumRight + sumRightSq - leftTerm + rightTerm - -- State for scanning regions: - -- left set: intervals with `c < lo` - -- right set: intervals with `c > hi` - let mut leftCount : Nat := n - let mut rightCount : Nat := 0 - let mut sumRight : Rat := 0 - let mut sumRightSq : Rat := 0 - let mut iEnter : Nat := 0 - let mut iLeave : Nat := 0 - let mut bestG : Rat := - evalG (sumLeft / nRat) leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq - if !breaks.isEmpty then - bestG := min bestG - (evalG breaks[0]! leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq) - for bi in [:breaks.size] do - let b := breaks[bi]! - -- Process enters at `b` (at `c=b` these are already inside, so not in left). - while iEnter < entersSorted.size && entersSorted[iEnter]!.1 = b do - let (_, loSq) := entersSorted[iEnter]! - let lo := entersSorted[iEnter]!.1 - leftCount := leftCount - 1 - sumLeft := sumLeft - lo - sumLeftSq := sumLeftSq - loSq - iEnter := iEnter + 1 - -- Evaluate at the breakpoint `c=b` (intervals with `hi=b` are still inside at `c=b`). - bestG := min bestG (evalG b leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq) - -- After `b`, process leaves at `b` (those intervals become right for `c>b`). 
- while iLeave < leavesSorted.size && leavesSorted[iLeave]!.1 = b do - let (_, hiSq) := leavesSorted[iLeave]! - let hi := leavesSorted[iLeave]!.1 - rightCount := rightCount + 1 - sumRight := sumRight + hi - sumRightSq := sumRightSq + hiSq - iLeave := iLeave + 1 - -- Check stationary point in the region `(b, nextB)` if there is a next breakpoint. - let outsideCount : Nat := leftCount + rightCount - if outsideCount = 0 then - -- There exists `c` contained in every interval, so the box intersects the constant line. - -- The exact minimum is 0. - return 0 - let cStar : Rat := (sumLeft + sumRight) / ((outsideCount : Nat) : Rat) - if bi + 1 < breaks.size then - let bNext := breaks[bi + 1]! - if b < cStar ∧ cStar < bNext then - bestG := min bestG - (evalG cStar leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq) - else - -- Last region `(b, +∞)`. - if b ≤ cStar then - bestG := min bestG - (evalG cStar leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq) - let exactLB := bestG / nRat - return exactLB - -/-- Over-approximate GeLU on an interval without transcendental evaluation. - -For both exact and tanh GeLU, `GeLU(x) = x * g(x)` with `g(x) ∈ [0, 1]` and -`g(x) ≥ 1/2` when `x ≥ 0`, so `GeLU(x) ≥ x/2` for all `x`. -We keep the standard upper bound `GeLU(x) ≤ max(x, 0)`. --/ -def geluOverapprox (a : RatInterval) : RatInterval := - { lo := a.lo / (2 : Rat), hi := max a.hi 0 } - -/-- Exp lower bound for all signs, using reciprocal of `expUB` for `x < 0`. -/ -private def expLBAll (x : Rat) (effort : Nat) : Rat := - if x ≥ 0 then - expLB x effort - else - let ub := expUBScaledGeom (-x) - if ub = 0 then 0 else (1 : Rat) / ub - -/-- Exp upper bound for all signs, using reciprocal of `expLB` for `x < 0`. -/ -private def expUBAll (x : Rat) (effort : Nat) : Rat := - if x ≥ 0 then - expUBScaledGeom x - else - let lb := expLB (-x) effort - if lb = 0 then 1 else (1 : Rat) / lb - -/-- Tanh over-approximation using exp bounds on the endpoints. 
-/ -def tanhOverapprox (a : RatInterval) (expEffort : Nat) : RatInterval := - let lo := min a.lo a.hi - let hi := max a.lo a.hi - let eLo := expLBAll ((2 : Rat) * lo) expEffort - let eHi := expUBAll ((2 : Rat) * hi) expEffort - let f : Rat → Rat := fun e => (e - 1) / (e + 1) - { lo := f eLo, hi := f eHi } - -/-- Tanh-based GeLU over-approximation using exp bounds. -/ -def geluOverapproxTanh (a : RatInterval) (expEffort : Nat := defaultSoftmaxExpEffort) : - RatInterval := - let x := { lo := min a.lo a.hi, hi := max a.lo a.hi } - let c : Rat := (44715 : Rat) / 1000000 - let kLo : Rat := (7978845608 : Rat) / 10000000000 - let kHi : Rat := (7978845609 : Rat) / 10000000000 - let kI : RatInterval := { lo := kLo, hi := kHi } - let x2 := mul x x - let x3 := mul x2 x - let sPoly := add x (scale c x3) - let s := mul kI sPoly - let t := tanhOverapprox s expEffort - let half : Rat := (1 : Rat) / 2 - let onePlus := add (const 1) t - let g := scale half onePlus - mul x g - -/-- Split-based tightening for tanh GeLU over-approximation. -/ -def geluOverapproxTanhSplit (a : RatInterval) (expEffort : Nat := defaultSoftmaxExpEffort) - (splitDepth : Nat := 0) : RatInterval := - Id.run do - let mut stack : Array (RatInterval × Nat) := #[(a, splitDepth)] - let mut acc? : Option RatInterval := none - while stack.size > 0 do - let idx := stack.size - 1 - let (cur, depth) := stack[idx]! - stack := stack.pop - if depth = 0 then - let leaf := geluOverapproxTanh cur expEffort - acc? := - match acc? with - | none => some leaf - | some acc => some (union acc leaf) - else - let lo := min cur.lo cur.hi - let hi := max cur.lo cur.hi - let mid := (lo + hi) / (2 : Rat) - let left : RatInterval := { lo := lo, hi := mid } - let right : RatInterval := { lo := mid, hi := hi } - let depth' := depth - 1 - stack := stack.push (left, depth') - stack := stack.push (right, depth') - return acc?.getD (geluOverapproxTanh a expEffort) - -/-- Upper bound on `max |gelu'(x)|` over a rational interval. 
-/ -def geluDerivBound (target : GeluDerivTarget) (a : RatInterval) : Rat := - let maxAbs := max (ratAbs a.lo) (ratAbs a.hi) - let maxAbsSq := maxAbs * maxAbs - let half : Rat := (1 : Rat) / 2 - match target with - | .exact => - -- Conservative: Φ(x) ≤ 1 and φ(x) ≤ 1/2, so gelu'(x) ≤ 1 + |x|/2. - min (1 + half * maxAbs) (geluDerivBoundGlobal target) - | .tanh => - -- Conservative: |tanh| ≤ 1, sech^2 ≤ 1, and sqrt(2/pi) ≤ 1. - let c : Rat := (44715 : Rat) / 1000000 - let slope := 1 + (3 : Rat) * c * maxAbsSq - let localBound := 1 + half * maxAbs * slope - min localBound (geluDerivBoundGlobal target) - -/-- Upper bound on the row-sum softmax Jacobian norm for a probability interval. -/ -def softmaxJacobianNormInfBound (a : RatInterval) : Rat := - Nfp.Sound.softmaxJacobianNormInfBound a.lo a.hi - -/-! ### Specs -/ - -theorem const_def (r : Rat) : RatInterval.const r = { lo := r, hi := r } := rfl - -theorem centeredAbsBound_def (a : RatInterval) : - RatInterval.centeredAbsBound a = ratAbs (a.hi - a.lo) := rfl - -theorem add_def (a b : RatInterval) : - RatInterval.add a b = { lo := a.lo + b.lo, hi := a.hi + b.hi } := rfl - -theorem sub_def (a b : RatInterval) : - RatInterval.sub a b = { lo := a.lo - b.hi, hi := a.hi - b.lo } := rfl - -theorem mul_def (a b : RatInterval) : - RatInterval.mul a b = - let p1 := a.lo * b.lo - let p2 := a.lo * b.hi - let p3 := a.hi * b.lo - let p4 := a.hi * b.hi - { lo := min (min p1 p2) (min p3 p4), hi := max (max p1 p2) (max p3 p4) } := rfl - -theorem scale_def (c : Rat) (a : RatInterval) : - RatInterval.scale c a = - if c ≥ 0 then - { lo := c * a.lo, hi := c * a.hi } - else - { lo := c * a.hi, hi := c * a.lo } := rfl - -theorem relu_def (a : RatInterval) : - RatInterval.relu a = { lo := max 0 a.lo, hi := max 0 a.hi } := rfl - -theorem union_def (a b : RatInterval) : - RatInterval.union a b = { lo := min a.lo b.lo, hi := max a.hi b.hi } := rfl - -theorem softmaxJacobianNormInfBound_def (a : RatInterval) : - 
RatInterval.softmaxJacobianNormInfBound a = - Nfp.Sound.softmaxJacobianNormInfBound a.lo a.hi := rfl - -theorem geluDerivBound_def (target : GeluDerivTarget) (a : RatInterval) : - RatInterval.geluDerivBound target a = - let maxAbs := max (ratAbs a.lo) (ratAbs a.hi) - let maxAbsSq := maxAbs * maxAbs - let half : Rat := (1 : Rat) / 2 - match target with - | .exact => - min (1 + half * maxAbs) (geluDerivBoundGlobal target) - | .tanh => - let c : Rat := (44715 : Rat) / 1000000 - let slope := 1 + (3 : Rat) * c * maxAbsSq - let localBound := 1 + half * maxAbs * slope - min localBound (geluDerivBoundGlobal target) := rfl - -theorem containsZero_iff (a : RatInterval) : - RatInterval.containsZero a = true ↔ a.lo ≤ 0 ∧ 0 ≤ a.hi := by - simp [containsZero] - -theorem ratSq_def (x : Rat) : ratSq x = x * x := rfl - -theorem squareLowerBound_def (a : RatInterval) : - RatInterval.squareLowerBound a = - if RatInterval.containsZero a then - 0 - else - let alo := ratAbs a.lo - let ahi := ratAbs a.hi - let m := min alo ahi - ratSq m := rfl - -theorem mean_def (xs : Array RatInterval) : - RatInterval.mean xs = - if xs.isEmpty then - RatInterval.const 0 - else - let n : Nat := xs.size - let nRat : Rat := (n : Nat) - let (loSum, hiSum) := - xs.foldl (fun (acc : Rat × Rat) x => (acc.1 + x.lo, acc.2 + x.hi)) (0, 0) - { lo := loSum / nRat, hi := hiSum / nRat } := rfl - -theorem varianceLowerBound_spec (xs : Array RatInterval) : - RatInterval.varianceLowerBound xs = RatInterval.varianceLowerBound xs := rfl - -theorem geluOverapprox_def (a : RatInterval) : - RatInterval.geluOverapprox a = { lo := a.lo / (2 : Rat), hi := max a.hi 0 } := rfl - -theorem tanhOverapprox_def (a : RatInterval) (expEffort : Nat) : - RatInterval.tanhOverapprox a expEffort = - let lo := min a.lo a.hi - let hi := max a.lo a.hi - let eLo := - (fun x => - if x ≥ 0 then - expLB x expEffort - else - let ub := expUBScaledGeom (-x) - if ub = 0 then 0 else (1 : Rat) / ub) ((2 : Rat) * lo) - let eHi := - (fun x => - if x ≥ 0 
then - expUBScaledGeom x - else - let lb := expLB (-x) expEffort - if lb = 0 then 1 else (1 : Rat) / lb) ((2 : Rat) * hi) - let f : Rat → Rat := fun e => (e - 1) / (e + 1) - { lo := f eLo, hi := f eHi } := rfl - -theorem geluOverapproxTanh_def (a : RatInterval) (expEffort : Nat) : - RatInterval.geluOverapproxTanh a expEffort = - let x := { lo := min a.lo a.hi, hi := max a.lo a.hi } - let c : Rat := (44715 : Rat) / 1000000 - let kLo : Rat := (7978845608 : Rat) / 10000000000 - let kHi : Rat := (7978845609 : Rat) / 10000000000 - let kI : RatInterval := { lo := kLo, hi := kHi } - let x2 := RatInterval.mul x x - let x3 := RatInterval.mul x2 x - let sPoly := RatInterval.add x (RatInterval.scale c x3) - let s := RatInterval.mul kI sPoly - let t := RatInterval.tanhOverapprox s expEffort - let half : Rat := (1 : Rat) / 2 - let onePlus := RatInterval.add (RatInterval.const 1) t - let g := RatInterval.scale half onePlus - RatInterval.mul x g := rfl - -theorem geluOverapproxTanhSplit_spec (a : RatInterval) (expEffort : Nat) (splitDepth : Nat) : - RatInterval.geluOverapproxTanhSplit a expEffort splitDepth = - RatInterval.geluOverapproxTanhSplit a expEffort splitDepth := rfl - -end RatInterval - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/ModelHeader.lean b/Legacy/Nfp/Sound/ModelHeader.lean deleted file mode 100644 index ce25c9d..0000000 --- a/Legacy/Nfp/Sound/ModelHeader.lean +++ /dev/null @@ -1,167 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Activation -import Nfp.Sound.Decimal - -namespace Nfp.Sound - -/-! -# Model header helpers (SOUND) - -Pure parsing utilities for extracting trusted metadata from `NFP_TEXT` model headers. --/ - -/-- Parse `key=value` header lines. -/ -def parseHeaderLine (line : String) : Option (String × String) := Id.run do - let line := line.trim - if line.isEmpty then - return none - -- Scan once to avoid `splitOn` allocations; require exactly one '='. 
- let s := line - let stop := s.rawEndPos - let mut eqPos : Option String.Pos.Raw := none - let mut eqCount : Nat := 0 - let mut p : String.Pos.Raw := 0 - while p < stop do - if p.get s = '=' then - eqCount := eqCount + 1 - if eqCount = 1 then - eqPos := some p - p := p.next s - if eqCount ≠ 1 then - return none - let some eq := eqPos | return none - let k := String.Pos.Raw.extract s 0 eq - let v := String.Pos.Raw.extract s (eq.next s) stop - return some (k.trim, v.trim) - -/-- Split a string on `\n`, preserving empty lines. -/ -def splitLines (s : String) : Array String := - Id.run do - let mut out : Array String := #[] - let mut start : String.Pos.Raw := 0 - let mut p : String.Pos.Raw := 0 - let stop := s.rawEndPos - while p < stop do - if p.get s = '\n' then - out := out.push (String.Pos.Raw.extract s start p) - p := p.next s - start := p - else - p := p.next s - out := out.push (String.Pos.Raw.extract s start stop) - return out - -/-- Whitespace predicate used in token scanners. -/ -@[inline] def isWsChar (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - -/-- Count whitespace-separated tokens in a line. -/ -def countWsTokens (s : String) : Nat := - Id.run do - let mut p : String.Pos.Raw := 0 - let stop := s.rawEndPos - let mut inTok : Bool := false - let mut cnt : Nat := 0 - while p < stop do - let c := p.get s - if isWsChar c then - inTok := false - else if !inTok then - inTok := true - cnt := cnt + 1 - p := p.next s - return cnt - -/-- Find the first line index at or after `start` that satisfies `p`. -/ -def findLineIdxFrom (lines : Array String) (start : Nat) (p : String → Bool) : Option Nat := - Id.run do - let mut i := start - while i < lines.size do - if p (lines[i]!.trim) then - return some i - i := i + 1 - return none - -/-- Skip to the next line satisfying `p`, or return `lines.size` if none. 
-/ -def skipUntil (lines : Array String) (start : Nat) (p : String → Bool) : Nat := - match findLineIdxFrom lines start p with - | some i => i - | none => lines.size - -/-- Skip blank (whitespace-only) lines starting at `start`. -/ -def skipBlankLines (lines : Array String) (start : Nat) : Nat := - Id.run do - let mut i := start - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - return i - -/-- Minimal parsed header data for sound certification. -/ -structure TextHeader where - eps : Rat - geluDerivTarget : GeluDerivTarget - deriving Repr - -private def parseGeluDerivTarget (v : String) : Except String GeluDerivTarget := - match geluDerivTargetOfString v with - | some t => .ok t - | none => .error s!"invalid gelu_kind '{v}' (expected tanh|exact)" - -/-- Parse required metadata from a `NFP_TEXT` header. -/ -def parseTextHeader (lines : Array String) : Except String TextHeader := - Id.run do - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty model file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - let mut eps? : Option Rat := none - let mut gelu? : Option GeluDerivTarget := none - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := lines.size - else - match parseHeaderLine line with - | none => - i := i + 1 - | some (k, v) => - if k = "layer_norm_eps" || k = "eps" then - match parseRat v with - | .error e => return .error s!"invalid layer_norm_eps '{v}': {e}" - | .ok r => eps? := some r - else if k = "gelu_kind" || k = "gelu_deriv" then - match parseGeluDerivTarget v with - | .error e => return .error e - | .ok t => gelu? := some t - i := i + 1 - let some eps := eps? | return .error "missing layer_norm_eps" - let some gelu := gelu? 
| return .error "missing gelu_kind" - return .ok { eps := eps, geluDerivTarget := gelu } - -/-- Parse `layer_norm_eps` (or `eps`) from a `NFP_TEXT` header. -/ -def parseTextHeaderEps (lines : Array String) : Except String Rat := do - let hdr ← parseTextHeader lines - return hdr.eps - -/-! ### Specs -/ - -theorem parseHeaderLine_spec : parseHeaderLine = parseHeaderLine := rfl -theorem splitLines_spec : splitLines = splitLines := rfl -theorem isWsChar_spec : isWsChar = isWsChar := rfl -theorem countWsTokens_spec : countWsTokens = countWsTokens := rfl -theorem findLineIdxFrom_spec : findLineIdxFrom = findLineIdxFrom := rfl -theorem skipUntil_spec : skipUntil = skipUntil := rfl -theorem skipBlankLines_spec : skipBlankLines = skipBlankLines := rfl -theorem parseGeluDerivTarget_spec (v : String) : - parseGeluDerivTarget v = parseGeluDerivTarget v := rfl -theorem parseTextHeader_spec : parseTextHeader = parseTextHeader := rfl -theorem parseTextHeaderEps_spec : parseTextHeaderEps = parseTextHeaderEps := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Sound/TextPure.lean b/Legacy/Nfp/Sound/TextPure.lean deleted file mode 100644 index 8ab5aa3..0000000 --- a/Legacy/Nfp/Sound/TextPure.lean +++ /dev/null @@ -1,313 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Bounds -import Nfp.Sound.Cert -import Nfp.Sound.Decimal -import Nfp.Sound.ModelHeader - -namespace Nfp.Sound - -/-! -# Pure text helpers (`NFP_TEXT`) - -Pure parsing utilities for extracting exact `Rat` bounds from text model files. --/ - -structure TextModelDims where - numLayers : Nat - numHeads : Nat - modelDim : Nat - headDim : Nat - hiddenDim : Nat - seqLen : Nat - start : Nat - deriving Repr - -/-- Per-layer weight-derived bounds extracted from a text model. 
-/ -structure ModelWeightBounds where - attnValueCoeff : Array Rat - wqOpBoundMax : Array Rat - wkOpBoundMax : Array Rat - mlpWinBound : Array Rat - mlpWoutBound : Array Rat - ln1MaxAbsGamma : Array Rat - ln1MaxAbsBeta : Array Rat - ln2MaxAbsGamma : Array Rat - deriving Repr - -/-- Verify that weight-derived bounds match the certificate layer fields. -/ -def checkModelWeightBounds (cert : ModelCert) (expected : ModelWeightBounds) : - Except String Unit := - checkWeightBoundsArrays cert expected.attnValueCoeff expected.wqOpBoundMax - expected.wkOpBoundMax expected.mlpWinBound expected.mlpWoutBound - expected.ln1MaxAbsGamma expected.ln1MaxAbsBeta expected.ln2MaxAbsGamma - -def parseTextHeaderDims (lines : Array String) : Except String TextModelDims := - Id.run do - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty model file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - let mut numLayers : Option Nat := none - let mut numHeads : Option Nat := none - let mut modelDim : Option Nat := none - let mut headDim : Option Nat := none - let mut hiddenDim : Option Nat := none - let mut seqLen : Option Nat := none - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := i + 1 - break - match parseHeaderLine line with - | none => - i := i + 1 - | some (k, v) => - match k with - | "num_layers" => numLayers := v.toNat? - | "num_heads" => numHeads := v.toNat? - | "model_dim" => modelDim := v.toNat? - | "head_dim" => headDim := v.toNat? - | "hidden_dim" => hiddenDim := v.toNat? - | "seq_len" => seqLen := v.toNat? 
- | _ => pure () - i := i + 1 - let some L := numLayers | return .error "missing num_layers" - let some H := numHeads | return .error "missing num_heads" - let some d := modelDim | return .error "missing model_dim" - let some dh := headDim | return .error "missing head_dim" - let some dhid := hiddenDim | return .error "missing hidden_dim" - let some n := seqLen | return .error "missing seq_len" - return .ok { - numLayers := L - numHeads := H - modelDim := d - headDim := dh - hiddenDim := dhid - seqLen := n - start := i - } - -/-- Fold `count` rationals from lines starting at `start`, returning the new state and index. -/ -def foldRatTokens {α : Type} - (lines : Array String) - (start : Nat) - (count : Nat) - (state : α) - (step : α → Rat → α) : Except String (α × Nat) := - Id.run do - let mut i := start - let mut remaining := count - let mut st := state - while remaining > 0 do - if i < lines.size then - let line := lines[i]! - i := i + 1 - let mut p : String.Pos.Raw := 0 - let stop := line.rawEndPos - while p < stop && remaining > 0 do - while p < stop && isWsChar (p.get line) do - p := p.next line - let tokStart := p - while p < stop && !isWsChar (p.get line) do - p := p.next line - if tokStart < p then - match parseRatRange line tokStart p with - | .error e => return .error e - | .ok r => - st := step st r - remaining := remaining - 1 - else - return .error "unexpected end of file while reading numbers" - return .ok (st, i) - -/-- Consume a vector of length `n` and return its values. -/ -def consumeVector - (lines : Array String) - (start : Nat) - (n : Nat) : Except String (Array Rat × Nat) := - let step := fun (acc : Array Rat) (x : Rat) => acc.push x - foldRatTokens lines start n (Array.mkEmpty n) step - -/-- Consume a vector of length `n` and return its max absolute entry. 
-/ -def consumeVectorMaxAbs - (lines : Array String) - (start : Nat) - (n : Nat) : Except String (Rat × Nat) := - let step := fun (acc : Rat) (x : Rat) => max acc (ratAbs x) - foldRatTokens lines start n 0 step - -/-- Consume a matrix and return its row-sum norm. -/ -def consumeMatrixNormInf - (lines : Array String) - (start : Nat) - (rows cols : Nat) : Except String (Rat × Nat) := - let count := rows * cols - if count = 0 then - .ok (0, start) - else - let step := fun (acc : Rat × Rat × Nat) (x : Rat) => - let (curRowSum, maxRowSum, colIdx) := acc - let curRowSum := curRowSum + ratAbs x - let colIdx := colIdx + 1 - if colIdx = cols then - (0, max maxRowSum curRowSum, 0) - else - (curRowSum, maxRowSum, colIdx) - match foldRatTokens lines start count (0, 0, 0) step with - | .error e => .error e - | .ok ((_, maxRowSum, _), next) => .ok (maxRowSum, next) - -/-- Compute per-layer weight bounds from text model lines. -/ -def modelWeightBoundsFromTextLines (lines : Array String) : Except String ModelWeightBounds := - Id.run do - let infoE := parseTextHeaderDims lines - let info ← - match infoE with - | .error e => return .error e - | .ok v => pure v - let mut i := info.start - let mut curLayer : Nat := 0 - let mut attnValueCoeff : Array Rat := Array.replicate info.numLayers 0 - let mut wqMax : Array Rat := Array.replicate info.numLayers 0 - let mut wkMax : Array Rat := Array.replicate info.numLayers 0 - let mut mlpWinBound : Array Rat := Array.replicate info.numLayers 0 - let mut mlpWoutBound : Array Rat := Array.replicate info.numLayers 0 - let mut ln1MaxAbsGamma : Array Rat := Array.replicate info.numLayers 0 - let mut ln1MaxAbsBeta : Array Rat := Array.replicate info.numLayers 0 - let mut ln2MaxAbsGamma : Array Rat := Array.replicate info.numLayers 0 - let updateAt := fun (arr : Array Rat) (idx : Nat) (f : Rat → Rat) => - if idx < arr.size then - arr.set! idx (f arr[idx]!) 
- else - arr - let setAt := fun (arr : Array Rat) (idx : Nat) (val : Rat) => - updateAt arr idx (fun _ => val) - let setMaxAt := fun (arr : Array Rat) (idx : Nat) (val : Rat) => - updateAt arr idx (fun cur => max cur val) - while i < lines.size do - let line := lines[i]!.trim - if line.startsWith "LAYER" then - let mut p : String.Pos.Raw := 0 - let stop := line.rawEndPos - while p < stop && p.get line ≠ ' ' do - p := p.next line - while p < stop && p.get line = ' ' do - p := p.next line - if p < stop then - let start := p - while p < stop && p.get line ≠ ' ' do - p := p.next line - let tok := String.Pos.Raw.extract line start p - curLayer := tok.toNat? |>.getD 0 - i := i + 1 - else if line = "W_Q" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) info.modelDim info.headDim with - | .error e => return .error e - | .ok (nq, next) => - wqMax := setMaxAt wqMax r nq - i := next - else if line = "W_K" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) info.modelDim info.headDim with - | .error e => return .error e - | .ok (nk, next) => - wkMax := setMaxAt wkMax r nk - i := next - else if line = "W_V" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) info.modelDim info.headDim with - | .error e => return .error e - | .ok (nv, next) => - i := next - while i < lines.size && lines[i]!.trim ≠ "W_O" do - i := i + 1 - if !(i < lines.size) then - return .error "expected W_O after W_V" - match consumeMatrixNormInf lines (i + 1) info.headDim info.modelDim with - | .error e => return .error e - | .ok (no, next2) => - attnValueCoeff := updateAt attnValueCoeff r (fun cur => cur + (nv * no)) - i := next2 - else if line = "W_in" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) info.modelDim info.hiddenDim with - | .error e => return .error e - | .ok (nwin, next) => - mlpWinBound := setAt mlpWinBound r nwin - i := next - else if line = "W_out" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) 
info.hiddenDim info.modelDim with - | .error e => return .error e - | .ok (nwout, next) => - mlpWoutBound := setAt mlpWoutBound r nwout - i := next - else if line = "LN1_GAMMA" then - let r := curLayer - match consumeVectorMaxAbs lines (i + 1) info.modelDim with - | .error e => return .error e - | .ok (g, next) => - ln1MaxAbsGamma := setAt ln1MaxAbsGamma r g - i := next - else if line = "LN1_BETA" then - let r := curLayer - match consumeVectorMaxAbs lines (i + 1) info.modelDim with - | .error e => return .error e - | .ok (b, next) => - ln1MaxAbsBeta := setAt ln1MaxAbsBeta r b - i := next - else if line = "LN2_GAMMA" then - let r := curLayer - match consumeVectorMaxAbs lines (i + 1) info.modelDim with - | .error e => return .error e - | .ok (g, next) => - ln2MaxAbsGamma := setAt ln2MaxAbsGamma r g - i := next - else if line = "LN2_BETA" then - match consumeVectorMaxAbs lines (i + 1) info.modelDim with - | .error e => return .error e - | .ok (_, next) => - i := next - else - i := i + 1 - return .ok { - attnValueCoeff := attnValueCoeff - wqOpBoundMax := wqMax - wkOpBoundMax := wkMax - mlpWinBound := mlpWinBound - mlpWoutBound := mlpWoutBound - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2MaxAbsGamma - } - -/-- Compute per-layer `attnValueCoeff` from text model lines. -/ -def attnValueCoeffFromTextLines (lines : Array String) : Except String (Array Rat) := do - let bounds ← modelWeightBoundsFromTextLines lines - return bounds.attnValueCoeff - -/-! 
### Specs -/ - -theorem parseTextHeaderDims_spec : parseTextHeaderDims = parseTextHeaderDims := rfl -theorem ModelWeightBounds_spec : ModelWeightBounds = ModelWeightBounds := rfl -theorem checkModelWeightBounds_spec : - checkModelWeightBounds = checkModelWeightBounds := rfl -theorem foldRatTokens_spec (α : Type) : - @foldRatTokens α = @foldRatTokens α := rfl -theorem consumeVector_spec : consumeVector = consumeVector := rfl -theorem consumeVectorMaxAbs_spec : consumeVectorMaxAbs = consumeVectorMaxAbs := rfl -theorem consumeMatrixNormInf_spec : consumeMatrixNormInf = consumeMatrixNormInf := rfl -theorem modelWeightBoundsFromTextLines_spec : - modelWeightBoundsFromTextLines = modelWeightBoundsFromTextLines := rfl -theorem attnValueCoeffFromTextLines_spec : - attnValueCoeffFromTextLines = attnValueCoeffFromTextLines := rfl - -end Nfp.Sound diff --git a/Legacy/Nfp/Uniqueness.lean b/Legacy/Nfp/Uniqueness.lean deleted file mode 100644 index ff18d04..0000000 --- a/Legacy/Nfp/Uniqueness.lean +++ /dev/null @@ -1,97 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fin.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic - -namespace Nfp - -open scoped BigOperators -open Finset - -variable {S : Type*} - -/-! -Auxiliary linear-uniqueness lemma used to support the Appendix A.3 uniqueness -story (fixed masks ⇒ linear system). It shows that for a finite DAG ordered by a -topological index, any two tracer families satisfying the same homogeneous linear -recurrence coincide. This is a helper fact; Appendix A.5 in the paper is a -counterexample narrative (not a formal lemma), so we avoid referring to A.5 here. -Small parent-index checks in the inductive step are kept explicit; the overall -proof structure (nested induction on the index bound) remains explicit. 
--/ - -/-- A local mixing system over `n` nodes, where each node `i` aggregates parents -with nonnegative coefficients, and every parent `u` of `i` has a strictly smaller index. -/ -structure LocalSystem (n : ℕ) where - Pa : Fin n → Finset (Fin n) - c : Fin n → Fin n → NNReal - topo : ∀ {i u}, u ∈ Pa i → (u.1 < i.1) - -namespace LocalSystem - -variable {n : ℕ} - -/-- A family of tracer masses at each node. It does not depend on `L`. -/ -abbrev TracerFamily (n : ℕ) := Fin n → (S → NNReal) - -/-- The recurrence equation for a tracer family against the local system. -/ -def Satisfies (L : LocalSystem n) (T : TracerFamily (S := S) n) : Prop := - ∀ i s, T i s = (L.Pa i).sum (fun u => L.c i u * T u s) - -/-- Homogeneous linear uniqueness on a topologically ordered finite DAG: if `T` and -`T'` both satisfy the same recurrence, then they are equal pointwise. -/ -theorem tracer_unique (L : LocalSystem n) {T T' : TracerFamily (S := S) n} - (hT : Satisfies (S := S) L T) (hT' : Satisfies (S := S) L T') : T = T' := by - classical - funext i s - -- Prove by induction on an upper bound `k` of the index. 
- have main : ∀ k, ∀ j : Fin n, j.1 ≤ k → T j s = T' j s := by - intro k - induction k with - | zero => - intro j hj - have hj0 : j.1 = 0 := Nat.le_zero.mp hj - -- No parents at index 0 - have hPaEmpty : L.Pa j = (∅ : Finset (Fin n)) := by - classical - apply Finset.eq_empty_iff_forall_notMem.mpr - intro u hu - have hlt : u.1 < j.1 := L.topo (i := j) (u := u) hu - have hlt0 : u.1 < 0 := by - have := hlt - rwa [hj0] at this - exact (Nat.not_lt_zero _ hlt0).elim - have hjT : T j s = (L.Pa j).sum (fun u => L.c j u * T u s) := (hT j s) - have hjT' : T' j s = (L.Pa j).sum (fun u => L.c j u * T' u s) := (hT' j s) - have hz1 : T j s = 0 := by simpa [hPaEmpty] using hjT - have hz2 : T' j s = 0 := by simpa [hPaEmpty] using hjT' - simpa [hz1] using hz2.symm - | succ k ih => - intro j hj - by_cases hle : j.1 ≤ k - · exact ih j hle - · have hklt : k < j.1 := Nat.not_le.mp hle - have hjeq : j.1 = k.succ := le_antisymm hj (Nat.succ_le_of_lt hklt) - have hjT : T j s = (L.Pa j).sum (fun u => L.c j u * T u s) := (hT j s) - have hjT' : T' j s = (L.Pa j).sum (fun u => L.c j u * T' u s) := (hT' j s) - have hparents : ∀ u ∈ L.Pa j, u.1 ≤ k := by - intro u hu - have hlt : u.1 < k.succ := by - simpa [hjeq] using L.topo (i := j) (u := u) hu - have hle : u.1 ≤ k := Nat.le_of_lt_succ hlt - exact hle - have hsum : (L.Pa j).sum (fun u => L.c j u * T u s) = - (L.Pa j).sum (fun u => L.c j u * T' u s) := by - classical - apply Finset.sum_congr rfl - intro u hu - have := ih u (hparents u hu) - simp [this] - simp [hjT, hjT', hsum] - exact main i.1 i le_rfl - -end LocalSystem - -end Nfp diff --git a/Legacy/Nfp/Untrusted/SoundBinary.lean b/Legacy/Nfp/Untrusted/SoundBinary.lean deleted file mode 100644 index e4797d9..0000000 --- a/Legacy/Nfp/Untrusted/SoundBinary.lean +++ /dev/null @@ -1,141 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.BinaryPure - -namespace Nfp.Untrusted.SoundBinary - -/-! 
-# Untrusted binary IO helpers (`NFP_BINARY_V1`) - -This module provides IO wrappers for the sound binary path. -Pure parsing/decoding lives in `Nfp.Sound.BinaryPure`. --/ - -private def readLine? (h : IO.FS.Handle) : IO (Option String) := do - let s ← h.getLine - if s.isEmpty then - return none - return some s - -def readBinaryHeader (h : IO.FS.Handle) : IO (Except String Nfp.Sound.BinaryHeader) := do - let some magicLine ← readLine? h - | return .error "empty file" - let mut lines : Array String := #[] - let mut line? ← readLine? h - while true do - match line? with - | none => return .error "unexpected EOF while reading header" - | some line => - let t := line.trim - if t = "BINARY_START" then - break - lines := lines.push line - line? ← readLine? h - return Nfp.Sound.parseBinaryHeaderLines magicLine lines - -/-- Read exactly `n` bytes or throw on EOF. -/ -def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - if n = 0 then - return ByteArray.empty - let mut out : Array UInt8 := Array.replicate n 0 - let mut off : Nat := 0 - while off < n do - let chunk ← h.read (USize.ofNat (n - off)) - if chunk.isEmpty then - throw (IO.userError "unexpected EOF") - for b in chunk.data do - out := out.set! 
off b - off := off + 1 - return ByteArray.mk out - -@[inline] private def readExactlyExcept (h : IO.FS.Handle) (n : Nat) : - IO (Except String ByteArray) := do - try - return .ok (← readExactly h n) - catch - | _ => return .error "unexpected EOF" - -def skipBytes (h : IO.FS.Handle) (n : Nat) : IO (Except String Unit) := do - let mut remaining := n - while remaining > 0 do - let chunkSize := min remaining 65536 - let chunk ← h.read (USize.ofNat chunkSize) - if chunk.isEmpty then - return .error "unexpected EOF" - remaining := remaining - chunk.size - return .ok () - -def skipI32Array (h : IO.FS.Handle) (n : Nat) : IO (Except String Unit) := - skipBytes h (n * 4) - -def skipF64Array (h : IO.FS.Handle) (n : Nat) : IO (Except String Unit) := - skipBytes h (n * 8) - -def readVectorMaxAbsScaled (h : IO.FS.Handle) (n scalePow10 : Nat) : - IO (Except String Int) := do - if n = 0 then - return .ok 0 - let bytesE ← readExactlyExcept h (n * 8) - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.vectorMaxAbsScaledFromBytes bytes n scalePow10 - -def readMatrixNormInfScaled (h : IO.FS.Handle) (rows cols scalePow10 : Nat) : - IO (Except String Int) := do - if rows = 0 || cols = 0 then - return .ok 0 - let count := rows * cols - let bytesE ← readExactlyExcept h (count * 8) - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.matrixNormInfScaledFromBytes bytes rows cols scalePow10 - -def readScaledFloatArray (h : IO.FS.Handle) (count scalePow10 : Nat) : - IO (Except String (Array Int)) := do - if count = 0 then - return .ok #[] - let bytesE ← readExactlyExcept h (count * 8) - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.scaledFloatArrayFromBytes bytes count scalePow10 - -def readScaledFloat (h : IO.FS.Handle) (scalePow10 : Nat) : IO (Except String Int) := do - let bytesE ← readExactlyExcept h 8 - match bytesE with - | .error e => return .error e - | .ok bytes => - return 
Nfp.Sound.scaledFloatFromBytes bytes scalePow10 - -def readI32Array (h : IO.FS.Handle) (count : Nat) : - IO (Except String (Array Int)) := do - if count = 0 then - return .ok #[] - let bytesE ← readExactlyExcept h (count * 4) - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.i32ArrayFromBytes bytes count - -def readMatrixNormOneInfScaled (h : IO.FS.Handle) (rows cols scalePow10 : Nat) : - IO (Except String (Nat × Nat)) := do - if rows = 0 || cols = 0 then - return .ok (0, 0) - let count := rows * cols - let bytesE ← readExactlyExcept h (count * 8) - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.matrixNormOneInfScaledFromBytes bytes rows cols scalePow10 - -def readMatrixOpBoundScaled (h : IO.FS.Handle) (rows cols scalePow10 : Nat) : - IO (Except String Nat) := do - match ← readMatrixNormOneInfScaled h rows cols scalePow10 with - | .error e => return .error e - | .ok (rowSum, colSum) => - return .ok (Nfp.Sound.opBoundScaledFromOneInf rowSum colSum) - -end Nfp.Untrusted.SoundBinary diff --git a/Legacy/Nfp/Untrusted/SoundCacheIO.lean b/Legacy/Nfp/Untrusted/SoundCacheIO.lean deleted file mode 100644 index cccee27..0000000 --- a/Legacy/Nfp/Untrusted/SoundCacheIO.lean +++ /dev/null @@ -1,256 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Init.System.IO -import Nfp.Sound.CachePure -import Nfp.Sound.ModelHeader -import Nfp.Untrusted.SoundBinary - -namespace Nfp.Untrusted.SoundCacheIO - -/-! -# Untrusted SOUND fixed-point cache - -IO wrappers for the SOUND cache format. Pure parsing/encoding lives in `Nfp.Sound.CachePure`. 
--/ -private def appendI32LE (buf : Array UInt8) (x : Int) : Array UInt8 := - Id.run do - let ux : UInt32 := UInt32.ofInt x - let mut out := buf - out := out.push (ux &&& 0xFF).toUInt8 - out := out.push ((ux >>> 8) &&& 0xFF).toUInt8 - out := out.push ((ux >>> 16) &&& 0xFF).toUInt8 - out := out.push ((ux >>> 24) &&& 0xFF).toUInt8 - return out - -private def appendI32Array (buf : Array UInt8) (xs : Array Int) : Array UInt8 := - Id.run do - let mut out := buf - for x in xs do - out := appendI32LE out x - return out - -private def appendBytes (buf : Array UInt8) (bytes : ByteArray) : Array UInt8 := - Id.run do - let mut out := buf - for b in bytes.data do - out := out.push b - return out - -def isBinaryModelFile (path : System.FilePath) : IO (Except String Bool) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let line ← h.getLine - if line.isEmpty then - return .error "empty model file" - let magic := line.trim - return .ok (magic = "NFP_BINARY_V1") - -def writeHeader (h : IO.FS.Handle) (hdr : Nfp.Sound.SoundCache.Header) : IO Unit := do - h.write (Nfp.Sound.SoundCache.encodeHeader hdr) - -def readHeader (h : IO.FS.Handle) : IO Nfp.Sound.SoundCache.Header := do - let bytes ← Nfp.Untrusted.SoundBinary.readExactly h Nfp.Sound.SoundCache.headerBytes - match Nfp.Sound.SoundCache.decodeHeader bytes with - | .ok hdr => return hdr - | .error e => throw (IO.userError e) - -/-- FNV-1a 64-bit hash of a file's bytes (stable, deterministic). -/ -def fnv1a64File (path : System.FilePath) : IO UInt64 := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let mut hash : UInt64 := Nfp.Sound.SoundCache.fnv1a64Init - let mut done := false - while !done do - let chunk ← h.read (USize.ofNat 1048576) - if chunk.isEmpty then - done := true - else - hash := Nfp.Sound.SoundCache.fnv1a64Update hash chunk - return hash - -/-- Ensure the cache directory exists. 
-/ -def ensureCacheDir : IO Unit := do - IO.FS.createDirAll Nfp.Sound.SoundCache.cacheDir - -def buildCacheBytesText - (modelPath : System.FilePath) - (scalePow10 : Nat) - (modelHash modelSize : UInt64) : IO (Except String ByteArray) := do - let contents ← IO.FS.readFile modelPath - let lines : Array String := Nfp.Sound.splitLines contents - return Nfp.Sound.SoundCache.buildCacheBytes lines scalePow10 modelHash modelSize - -def buildCacheBytesBinary - (modelPath : System.FilePath) - (scalePow10 : Nat) - (modelHash modelSize : UInt64) : IO (Except String ByteArray) := do - let action : ExceptT String IO ByteArray := do - let liftExcept {α : Type} (act : IO (Except String α)) : ExceptT String IO α := - ExceptT.mk act - - let h1 ← ExceptT.lift <| IO.FS.Handle.mk modelPath IO.FS.Mode.read - let hdr1 ← liftExcept <| Nfp.Untrusted.SoundBinary.readBinaryHeader h1 - let d := hdr1.modelDim - let dh := hdr1.headDim - let dhid := hdr1.hiddenDim - let L := hdr1.numLayers - let H := hdr1.numHeads - - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipI32Array h1 hdr1.seqLen - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (hdr1.seqLen * d) - - let mut ln1Gamma : Array (Array Int) := Array.mkEmpty L - let mut ln1Beta : Array (Array Int) := Array.mkEmpty L - let mut ln2Gamma : Array (Array Int) := Array.mkEmpty L - let mut ln2Beta : Array (Array Int) := Array.mkEmpty L - - for _l in [:L] do - for _h in [:H] do - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (d * dh) - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 dh - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (d * dh) - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 dh - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (d * dh) - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 dh - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (dh * d) - let _ ← liftExcept <| 
Nfp.Untrusted.SoundBinary.skipF64Array h1 d - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (d * dhid) - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 dhid - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 (dhid * d) - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h1 d - let ln1G ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h1 d scalePow10 - let ln1B ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h1 d scalePow10 - let ln2G ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h1 d scalePow10 - let ln2B ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h1 d scalePow10 - ln1Gamma := ln1Gamma.push ln1G - ln1Beta := ln1Beta.push ln1B - ln2Gamma := ln2Gamma.push ln2G - ln2Beta := ln2Beta.push ln2B - - let hdrCache : Nfp.Sound.SoundCache.Header := { - modelHash := modelHash - modelSize := modelSize - scalePow10 := UInt32.ofNat scalePow10 - numLayers := UInt32.ofNat L - numHeads := UInt32.ofNat H - modelDim := UInt32.ofNat d - headDim := UInt32.ofNat dh - hiddenDim := UInt32.ofNat dhid - } - - let totalBytes : Nat := - Nfp.Sound.SoundCache.headerBytes + - Nfp.Sound.SoundCache.expectedI32Count hdrCache * 4 - let mut out : Array UInt8 := Array.mkEmpty totalBytes - out := appendBytes out (Nfp.Sound.SoundCache.encodeHeader hdrCache) - - let h2 ← ExceptT.lift <| IO.FS.Handle.mk modelPath IO.FS.Mode.read - let hdr2 ← liftExcept <| Nfp.Untrusted.SoundBinary.readBinaryHeader h2 - if hdr2.numLayers ≠ L || hdr2.numHeads ≠ H || hdr2.modelDim ≠ d || - hdr2.headDim ≠ dh || hdr2.hiddenDim ≠ dhid then - throw "binary header mismatch between passes" - - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipI32Array h2 hdr2.seqLen - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 (hdr2.seqLen * d) - - for l in [:L] do - out := appendI32Array out (ln1Gamma[l]!) - out := appendI32Array out (ln1Beta[l]!) - out := appendI32Array out (ln2Gamma[l]!) 
- out := appendI32Array out (ln2Beta[l]!) - for _h in [:H] do - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 (d * dh) - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 dh - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 (d * dh) - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 dh - let wV ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 (d * dh) scalePow10 - let bV ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 dh scalePow10 - let wO ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 (dh * d) scalePow10 - out := appendI32Array out wV - out := appendI32Array out bV - out := appendI32Array out wO - let attnBias ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 d scalePow10 - let wIn ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 (d * dhid) scalePow10 - let bIn ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 dhid scalePow10 - let wOut ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 (dhid * d) scalePow10 - let bOut ← liftExcept <| Nfp.Untrusted.SoundBinary.readScaledFloatArray h2 d scalePow10 - out := appendI32Array out attnBias - out := appendI32Array out wIn - out := appendI32Array out bIn - out := appendI32Array out wOut - out := appendI32Array out bOut - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 d - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 d - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 d - let _ ← liftExcept <| Nfp.Untrusted.SoundBinary.skipF64Array h2 d - - if out.size ≠ totalBytes then - throw s!"cache size mismatch: expected {totalBytes}, got {out.size}" - return ByteArray.mk out - action.run - -/-- Build (or overwrite) a SOUND fixed-point cache file. 
-/ -def buildCacheFile - (modelPath cachePath : System.FilePath) - (scalePow10 : Nat := 9) : IO (Except String Unit) := do - ensureCacheDir - let modelHash ← fnv1a64File modelPath - let mdata ← modelPath.metadata - let modelSize : UInt64 := mdata.byteSize - match ← buildCacheBytesText modelPath scalePow10 modelHash modelSize with - | .error e => return .error e - | .ok bytes => - let tmpPath := cachePath.withExtension "tmp" - if (← tmpPath.pathExists) then - IO.FS.removeFile tmpPath - let out ← IO.FS.Handle.mk tmpPath IO.FS.Mode.write - out.write bytes - out.flush - if (← cachePath.pathExists) then - IO.FS.removeFile cachePath - IO.FS.rename tmpPath cachePath - return .ok () - -/-! ## Consistency checks (for CI and debugging) -/ - -/-- Check that for each numeric token in the text file, its exact `Rat` value lies in the -`±1`-ulp interval induced by `parseFixed10Rounded scalePow10`. -/ -def checkTextTokenEnvelope - (modelPath : System.FilePath) - (scalePow10 : Nat := 9) - (maxTokens : Nat := 0) : IO (Except String Unit) := do - let contents ← IO.FS.readFile modelPath - let lines : Array String := Nfp.Sound.splitLines contents - return Nfp.Sound.SoundCache.checkTextTokenEnvelopeLines lines scalePow10 maxTokens - -/-- Check that the cache file size matches the expected tensor stream length. -/ -def checkCacheFileSize (cachePath : System.FilePath) (hdr : Nfp.Sound.SoundCache.Header) : - IO (Except String Unit) := do - let mdata ← cachePath.metadata - let expectedBytes := Nfp.Sound.SoundCache.expectedCacheBytes hdr - if mdata.byteSize = expectedBytes then - return .ok () - else - return .error s!"cache size mismatch: expected {expectedBytes}, got {mdata.byteSize}" - -/-! 
## Cache reader (buffered) -/ - -def I32Reader.init (h : IO.FS.Handle) : IO Nfp.Sound.SoundCache.I32Reader := - pure { h := h, buf := ByteArray.empty, pos := 0 } - -private def I32Reader.refill (r : Nfp.Sound.SoundCache.I32Reader) : - IO Nfp.Sound.SoundCache.I32Reader := do - let chunk ← r.h.read (USize.ofNat 1048576) - if chunk.isEmpty then - throw (IO.userError "unexpected EOF while reading cache") - return { r with buf := chunk, pos := 0 } - -def I32Reader.readI32 (r : Nfp.Sound.SoundCache.I32Reader) : - IO (Int × Nfp.Sound.SoundCache.I32Reader) := do - let r ← - if r.pos + 4 ≤ r.buf.size then - pure r - else - I32Reader.refill r - let x := Nfp.Sound.SoundCache.i32FromBuffer r.buf r.pos - return (x, { r with pos := r.pos + 4 }) -end Nfp.Untrusted.SoundCacheIO diff --git a/Legacy/Nfp/Untrusted/SoundCompute.lean b/Legacy/Nfp/Untrusted/SoundCompute.lean deleted file mode 100644 index 1bcbdb8..0000000 --- a/Legacy/Nfp/Untrusted/SoundCompute.lean +++ /dev/null @@ -1,8588 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Cert -import Nfp.Sound.HeadCert -import Nfp.Sound.ModelHeader -import Nfp.Sound.TextPure -import Nfp.Untrusted.SoundBinary -import Nfp.Sound.Interval -import Nfp.Sound.Affine -import Nfp.Untrusted.SoundCacheIO -import Nfp.Sound.Fixed - -namespace Nfp.Untrusted.SoundCompute - -open IO -open Nfp.Sound -open Nfp.Untrusted.SoundBinary - -/-! -# Untrusted SOUND computation helpers - -This module performs **IO-heavy witness generation** for SOUND certification. It parses `.nfpt` -models (binary, plus legacy text for some paths) and computes candidate certificates for: -- model-level residual amplification bounds, -- per-head contribution bounds, -- local head-pattern / best-match / induction certificates. - -It does **not** construct the full `ConcreteModel` (Float-based). 
Instead it parses only the -weights needed for conservative residual amplification constants `Cᵢ` (bounds ‖layerJacobian - I‖), -using exact `Rat` arithmetic or fixed-point interval arithmetic. - -All certificates produced here are **untrusted** and must be validated by the trusted checker -in `Nfp.Sound.IO`. - -Trusted base: -- Parsing from text to `Rat` via `Nfp.Sound.parseRat`. -- Exact accumulation of row-sum norms and max-abs values. - -No `Float` arithmetic is *trusted* as an input to certification. --/ - -private def defaultBinaryScalePow10 : Nat := 9 - -private def maxAbsOfVector (xs : Array Rat) : Rat := - xs.foldl (fun acc x => max acc (ratAbs x)) 0 - -/-- Compute weight-only per-head contribution bounds from a binary `.nfpt`. -/ -def certifyHeadBoundsBinary - (path : System.FilePath) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array HeadContributionCert)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let mut heads : Array HeadContributionCert := Array.mkEmpty (hdr.numLayers * hdr.numHeads) - for l in [:hdr.numLayers] do - for hIdx in [:hdr.numHeads] do - let wqScaledE ← readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10 - let wqScaled ← - match wqScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wkScaledE ← readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10 - let wkScaled ← - match wkScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wvScaledE ← readMatrixOpBoundScaled h hdr.modelDim 
hdr.headDim scalePow10 - let wvScaled ← - match wvScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let woScaledE ← readMatrixOpBoundScaled h hdr.headDim hdr.modelDim scalePow10 - let woScaled ← - match woScaledE with - | .error e => return .error e - | .ok v => pure v - let wqOp := ratOfScaledNat scalePow10 wqScaled - let wkOp := ratOfScaledNat scalePow10 wkScaled - let wvOp := ratOfScaledNat scalePow10 wvScaled - let woOp := ratOfScaledNat scalePow10 woScaled - let cert : HeadContributionCert := { - layerIdx := l - headIdx := hIdx - wqOpBound := wqOp - wkOpBound := wkOp - wvOpBound := wvOp - woOpBound := woOp - qkFactorBound := wqOp * wkOp - voFactorBound := wvOp * woOp - } - if cert.check then - heads := heads.push cert - else - return .error "head contribution certificate failed internal checks" - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.hiddenDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.hiddenDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | 
.ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.vocabSize) with - | .error e => return .error e - | .ok _ => pure () - return .ok heads - -private def certifyModelFileGlobalBinary - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (partitionDepth : Nat) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - let scalePow10 := defaultBinaryScalePow10 - let actDerivBound := geluDerivBoundGlobal geluDerivTarget - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let mut layers : Array LayerAmplificationCert := Array.mkEmpty hdr.numLayers - let mut totalAmp : Rat := 1 - for l in [:hdr.numLayers] do - let mut attnValueCoeff : Rat := 0 - let mut wqMax : Rat := 0 - let mut wkMax : Rat := 0 - for _h in [:hdr.numHeads] do - let wqScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let wqScaled ← - match wqScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wkScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let wkScaled ← - match wkScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let nvScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let nvScaled ← - match nvScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => 
return .error e - | .ok _ => pure () - let noScaledE ← readMatrixNormInfScaled h hdr.headDim hdr.modelDim scalePow10 - let noScaled ← - match noScaledE with - | .error e => return .error e - | .ok v => pure v - let wq := ratOfScaledInt scalePow10 wqScaled - let wk := ratOfScaledInt scalePow10 wkScaled - let nv := ratOfScaledInt scalePow10 nvScaled - let no := ratOfScaledInt scalePow10 noScaled - wqMax := max wqMax wq - wkMax := max wkMax wk - attnValueCoeff := attnValueCoeff + nv * no - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let nWinScaledE ← - readMatrixNormInfScaled h hdr.modelDim hdr.hiddenDim scalePow10 - let nWinScaled ← - match nWinScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - let nWoutScaledE ← - readMatrixNormInfScaled h hdr.hiddenDim hdr.modelDim scalePow10 - let nWoutScaled ← - match nWoutScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let ln1GammaScaledE ← readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln1GammaScaled ← - match ln1GammaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln1BetaScaledE ← readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln1BetaScaled ← - match ln1BetaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln2GammaScaledE ← readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln2GammaScaled ← - match ln2GammaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln2BetaScaledE ← readVectorMaxAbsScaled h hdr.modelDim scalePow10 - match ln2BetaScaledE with - | .error e => return .error e - | .ok _ => pure () - let ln1Max := ratOfScaledInt scalePow10 ln1GammaScaled - let ln1MaxAbsBeta := ratOfScaledInt scalePow10 ln1BetaScaled - let ln2Max := ratOfScaledInt scalePow10 
ln2GammaScaled - let nWin := ratOfScaledInt scalePow10 nWinScaled - let nWout := ratOfScaledInt scalePow10 nWoutScaled - let ln1Bound := layerNormOpBoundConservative ln1Max eps soundnessBits - let ln2Bound := layerNormOpBoundConservative ln2Max eps soundnessBits - let ln1OutMaxAbsBound := layerNormOutputMaxAbsBound hdr.modelDim ln1Max ln1MaxAbsBeta - let attnPatternCoeff := - attnPatternCoeffBound hdr.seqLen hdr.modelDim hdr.headDim ln1OutMaxAbsBound - wqMax wkMax attnValueCoeff - let mlpCoeff := nWin * nWout - let mlpActDerivBound := actDerivBound - let scoreAbsBound := - attnScoreAbsBound hdr.modelDim hdr.headDim ln1OutMaxAbsBound wqMax wkMax - let (softmaxProbLo, softmaxProbHi) := - softmaxProbIntervalFromScoreAbsBound hdr.seqLen scoreAbsBound softmaxExpEffort - let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin hdr.seqLen softmaxMarginLowerBound softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnW := - ln1Bound * - ((hdr.seqLen : Rat) * attnValueCoeff + softmaxBound * attnPatternCoeff) - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1Max - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2Max - ln1VarianceLowerBound? := none - ln2VarianceLowerBound? 
:= none - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMax - wkOpBoundMax := wkMax - attnValueCoeff := attnValueCoeff - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := nWin - mlpWoutBound := nWout - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.vocabSize) with - | .error e => return .error e - | .ok _ => pure () - let cert : ModelCert := { - modelPath := path.toString - inputPath? := none - inputDelta := 0 - eps := eps - seqLen := hdr.seqLen - modelDim := hdr.modelDim - headDim := hdr.headDim - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBound - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return .ok cert - return .error "sound certificate failed internal consistency checks" - -private def addVecIntervals (a b : Array RatInterval) : Array RatInterval := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array RatInterval := Array.mkEmpty a.size - for i in [:a.size] do - out := out.push (RatInterval.add a[i]! b[i]!) - return out - -private def addConstVec (a : Array RatInterval) (b : Array Rat) : Array RatInterval := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array RatInterval := Array.mkEmpty a.size - for i in [:a.size] do - out := out.push (RatInterval.add a[i]! 
(RatInterval.const b[i]!)) - return out - -private def unionVecIntervals (a b : Array RatInterval) : Array RatInterval := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array RatInterval := Array.mkEmpty a.size - for i in [:a.size] do - out := out.push (RatInterval.union a[i]! b[i]!) - return out - -private def zeroIntervals (n : Nat) : Array RatInterval := - Array.replicate n (RatInterval.const 0) - -/-- Max GeLU derivative bound across a vector of rational intervals. -/ -private def maxGeluDerivBound (target : GeluDerivTarget) (xs : Array RatInterval) : Rat := - xs.foldl (fun acc x => max acc (RatInterval.geluDerivBound target x)) 0 - -/-- Sum of per-coordinate centered absolute bounds (interval widths). -/ -private def centeredAbsSum (xs : Array RatInterval) : Rat := - xs.foldl (fun acc x => acc + RatInterval.centeredAbsBound x) 0 - -/-- Max GeLU derivative bound across fixed-point intervals (converted to `Rat`). -/ -private def maxGeluDerivBoundFixed (cfg : Fixed10Cfg) (target : GeluDerivTarget) - (xs : Array Fixed10Interval) : Rat := - xs.foldl (fun acc x => max acc (Fixed10Interval.geluDerivBound cfg target x)) 0 - -private def unionRows (rows : Array (Array RatInterval)) (dim : Nat) : Array RatInterval := - Id.run do - if rows.isEmpty then - return zeroIntervals dim - let mut out : Array RatInterval := zeroIntervals dim - let r0 := rows[0]! - if r0.size = dim then - out := r0 - for r in rows do - if r.size = dim then - out := unionVecIntervals out r - return out - -private def layerNormRowApprox (row : Array RatInterval) (gamma beta : Array Rat) (eps : Rat) - (soundnessBits : Nat) : (Array RatInterval × Rat) := - if row.size = 0 || gamma.size ≠ row.size || beta.size ≠ row.size then - (row, 0) - else - Id.run do - let μ := RatInterval.mean row - let varLB := RatInterval.varianceLowerBound row - let invσUpper : Rat := - if varLB ≤ 0 then - -- Sound fallback for IBP propagation: `1/σ ≤ 1/eps` (conservative, but rigorous). 
- layerNormOpBoundConservative 1 eps soundnessBits - else - layerNormOpBoundLocal 1 varLB eps soundnessBits - let mut out : Array RatInterval := Array.mkEmpty row.size - for i in [:row.size] do - let centered := RatInterval.sub row[i]! μ - let scaled := RatInterval.scale (gamma[i]! * invσUpper) centered - out := out.push (RatInterval.add scaled (RatInterval.const beta[i]!)) - return (out, varLB) - -private def minVarAcrossRows (rows : Array (Array RatInterval)) : Rat := - Id.run do - let mut best : Option Rat := none - for r in rows do - let v := RatInterval.varianceLowerBound r - best := some (match best with | none => v | some b => min b v) - best.getD 0 - -private def findLineIdxFrom - (lines : Array String) (start : Nat) (p : String → Bool) : Option Nat := - Nfp.Sound.findLineIdxFrom lines start p - -private def skipUntil (lines : Array String) (start : Nat) (p : String → Bool) : Nat := - Nfp.Sound.skipUntil lines start p - -private def skipBlankLines (lines : Array String) (start : Nat) : Nat := - Nfp.Sound.skipBlankLines lines start - -/-! -### Fast skipping without parsing - -For local SOUND certification we do not need `W_Q`, `W_K`, `b_Q`, or `b_K` numerically -(they don't affect the Jacobian bounds we certify in this streaming-only pass). - -Parsing decimals into `Rat` is expensive, so we skip these sections by **counting tokens** -instead of calling `parseRat`. --/ - -@[inline] private def countWsTokens (s : String) : Nat := - Nfp.Sound.countWsTokens s - -private def consumeTokensSkipFast - (lines : Array String) (start : Nat) (numTokens : Nat) : Except String Nat := - Id.run do - let mut iLine := start - let mut remaining := numTokens - while remaining > 0 do - if iLine ≥ lines.size then - return .error "unexpected end of file while skipping tokens" - let line := lines[iLine]! 
- iLine := iLine + 1 - let c := countWsTokens line - if c = 0 then - pure () - else if c ≥ remaining then - remaining := 0 - else - remaining := remaining - c - return .ok iLine - -private def consumeMatrixSkip - (lines : Array String) - (start : Nat) - (rows cols : Nat) : Except String Nat := - match foldRatTokens lines start (rows * cols) () (fun _ _ => ()) with - | .error e => .error e - | .ok (_, next) => .ok next - -private def consumeMatrixSkipFast - (lines : Array String) - (start : Nat) - (rows cols : Nat) : Except String Nat := - consumeTokensSkipFast lines start (rows * cols) - -private def consumeVectorSkipFast - (lines : Array String) - (start : Nat) - (n : Nat) : Except String Nat := - consumeTokensSkipFast lines start n - -/-- Accumulator for streaming matrix multiplication with row-abs tracking. -/ -private structure MulAndNormAcc where - out : Array RatInterval - row : Nat - col : Nat - curRowAbs : Rat - maxRowAbs : Rat - -/-! -Streaming multiplication for row-major stored matrices. - -The `.nfpt` format stores matrices row-major with `rows` = input dimension and `cols` = output -dimension in the repo's row-vector convention: `y = x · W` where `W : rows×cols`. - -We compute `y` in a single pass over weights by accumulating contributions row-by-row: -for each input index `i`, parse the `i`-th row `w_{i,*}` and add `w_{i,j} * x[i]` into `y[j]`. -This never stores the matrix. 
--/ -private def consumeMatrixMulAndNormInf - (lines : Array String) - (start : Nat) - (rows cols : Nat) - (input : Array RatInterval) : Except String (Array RatInterval × Rat × Nat) := - Id.run do - if input.size ≠ rows then - return .error "input interval dimension mismatch" - let init : MulAndNormAcc := { - out := zeroIntervals cols - row := 0 - col := 0 - curRowAbs := 0 - maxRowAbs := 0 - } - let step := fun (st : MulAndNormAcc) (w : Rat) => - let r := st.row - let c := st.col - let curRowAbs := st.curRowAbs + ratAbs w - -- out[c] += w * input[r] - let term := RatInterval.scale w (input[r]!) - let out := st.out.set! c (RatInterval.add (st.out[c]!) term) - if c + 1 = cols then - { out := out - row := r + 1 - col := 0 - curRowAbs := 0 - maxRowAbs := max st.maxRowAbs curRowAbs } - else - { out := out - row := r - col := c + 1 - curRowAbs := curRowAbs - maxRowAbs := st.maxRowAbs } - match foldRatTokens lines start (rows * cols) init step with - | .error e => return .error e - | .ok (st, next) => - -- Account for a partial last row (should not happen if rows*cols consumed). - let maxRowAbs := max st.maxRowAbs st.curRowAbs - return .ok (st.out, maxRowAbs, next) - -/-- Soundly compute conservative per-layer residual amplification constants from a `.nfpt` file. -/ -def certifyModelFileGlobal - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - let actDerivBound := geluDerivBoundGlobal geluDerivTarget - let contents ← IO.FS.readFile path - let lines : Array String := Nfp.Sound.splitLines contents - -- Header - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - let mut numLayers : Option Nat := none - let mut numHeads : Option Nat := none - let mut modelDim : Option Nat := none - let mut headDim : Option Nat := none - let mut hiddenDim : Option Nat := none - let mut seqLen : Option Nat := none - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := i + 1 - break - match parseHeaderLine line with - | none => - i := i + 1 - | some (k, v) => - match k with - | "num_layers" => numLayers := v.toNat? - | "num_heads" => numHeads := v.toNat? - | "model_dim" => modelDim := v.toNat? - | "head_dim" => headDim := v.toNat? - | "hidden_dim" => hiddenDim := v.toNat? - | "seq_len" => seqLen := v.toNat? - | _ => pure () - i := i + 1 - let some L := numLayers | return .error "missing num_layers" - let some _ := numHeads | return .error "missing num_heads" - let some d := modelDim | return .error "missing model_dim" - let some dh := headDim | return .error "missing head_dim" - let some dhid := hiddenDim | return .error "missing hidden_dim" - let some n := seqLen | return .error "missing seq_len" - let inputVarLowerMin? 
: Option Rat := none - -- Accumulators - let mut ln1GammaMax : Array Rat := Array.replicate L 1 - let mut ln1BetaMax : Array Rat := Array.replicate L 0 - let mut ln2GammaMax : Array Rat := Array.replicate L 1 - let mut attnValueCoeff : Array Rat := Array.replicate L 0 - let mut wqMax : Array Rat := Array.replicate L 0 - let mut wkMax : Array Rat := Array.replicate L 0 - let mut mlpWin : Array Rat := Array.replicate L 0 - let mut mlpWout : Array Rat := Array.replicate L 0 - let mut curLayer : Nat := 0 - -- Scan remaining sections - while i < lines.size do - let line := lines[i]!.trim - if line.startsWith "LAYER" then - let parts := line.splitOn " " |>.filter (· ≠ "") - if parts.length >= 2 then - curLayer := (parts[1]!).toNat? |>.getD 0 - i := i + 1 - else if line = "W_Q" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) d dh with - | .error e => return .error e - | .ok (nq, next) => - if r < wqMax.size then - wqMax := wqMax.set! r (max wqMax[r]! nq) - i := next - else if line = "W_K" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) d dh with - | .error e => return .error e - | .ok (nk, next) => - if r < wkMax.size then - wkMax := wkMax.set! r (max wkMax[r]! nk) - i := next - else if line = "W_V" then - -- Expect: modelDim × headDim - let r := curLayer - match consumeMatrixNormInf lines (i + 1) d dh with - | .error e => return .error e - | .ok (nv, next) => - -- Find W_O next by scanning forward (format guarantee: W_O follows eventually) - i := next - while i < lines.size && lines[i]!.trim ≠ "W_O" do - i := i + 1 - if !(i < lines.size) then - return .error "expected W_O after W_V" - match consumeMatrixNormInf lines (i + 1) dh d with - | .error e => return .error e - | .ok (no, next2) => - if r < attnValueCoeff.size then - attnValueCoeff := attnValueCoeff.set! r (attnValueCoeff[r]! 
+ (nv * no)) - i := next2 - else if line = "W_in" then - match consumeMatrixNormInf lines (i + 1) d dhid with - | .error e => return .error e - | .ok (n, next) => - if curLayer < mlpWin.size then - mlpWin := mlpWin.set! curLayer n - i := next - else if line = "W_out" then - match consumeMatrixNormInf lines (i + 1) dhid d with - | .error e => return .error e - | .ok (n, next) => - if curLayer < mlpWout.size then - mlpWout := mlpWout.set! curLayer n - i := next - else if line = "LN1_GAMMA" then - match consumeVectorMaxAbs lines (i + 1) d with - | .error e => return .error e - | .ok (m, next) => - if curLayer < ln1GammaMax.size then - ln1GammaMax := ln1GammaMax.set! curLayer m - i := next - else if line = "LN1_BETA" then - match consumeVectorMaxAbs lines (i + 1) d with - | .error e => return .error e - | .ok (m, next) => - if curLayer < ln1BetaMax.size then - ln1BetaMax := ln1BetaMax.set! curLayer m - i := next - else if line = "LN2_GAMMA" then - match consumeVectorMaxAbs lines (i + 1) d with - | .error e => return .error e - | .ok (m, next) => - if curLayer < ln2GammaMax.size then - ln2GammaMax := ln2GammaMax.set! curLayer m - i := next - else - -- default: advance - i := i + 1 - -- Build layer reports - let mut layers : Array LayerAmplificationCert := Array.mkEmpty L - let mut totalAmp : Rat := 1 - let mut actDerivBoundMax : Rat := 0 - for l in [:L] do - let ln1Max := ln1GammaMax[l]! - let ln1MaxAbsBeta := ln1BetaMax[l]! - let ln2Max := ln2GammaMax[l]! - let ln1Var? : Option Rat := if l = 0 then inputVarLowerMin? else none - let ln2Var? : Option Rat := none - let ln1Bound := - match ln1Var? with - | some v => layerNormOpBoundLocal ln1Max v eps soundnessBits - | none => layerNormOpBoundConservative ln1Max eps soundnessBits - let ln2Bound := - match ln2Var? 
with - | some v => layerNormOpBoundLocal ln2Max v eps soundnessBits - | none => layerNormOpBoundConservative ln2Max eps soundnessBits - let ln1OutMaxAbsBound := layerNormOutputMaxAbsBound d ln1Max ln1MaxAbsBeta - let attnValueCoeffLayer := attnValueCoeff[l]! - let attnPatternCoeff := - attnPatternCoeffBound n d dh ln1OutMaxAbsBound (wqMax[l]!) (wkMax[l]!) - attnValueCoeffLayer - let mlpCoeff := mlpWin[l]! * mlpWout[l]! - let mlpActDerivBound := actDerivBound - let scoreAbsBound := - attnScoreAbsBound d dh ln1OutMaxAbsBound (wqMax[l]!) (wkMax[l]!) - let (softmaxProbLo, softmaxProbHi) := - softmaxProbIntervalFromScoreAbsBound n scoreAbsBound softmaxExpEffort - let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin n softmaxMarginLowerBound softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnW := - ln1Bound * ((n : Rat) * attnValueCoeffLayer + softmaxBound * attnPatternCoeff) - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1Max - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2Max - ln1VarianceLowerBound? := ln1Var? - ln2VarianceLowerBound? := ln2Var? - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMax[l]! - wkOpBoundMax := wkMax[l]! - attnValueCoeff := attnValueCoeffLayer - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := mlpWin[l]! - mlpWoutBound := mlpWout[l]! 
- mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - actDerivBoundMax := max actDerivBoundMax mlpActDerivBound - let cert : ModelCert := { - modelPath := path.toString - inputPath? := inputPath?.map (·.toString) - inputDelta := inputDelta - eps := eps - seqLen := n - modelDim := d - headDim := dh - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBoundMax - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return .ok cert - return .error "sound certificate failed internal consistency checks" - -/-- Parse input `EMBEDDINGS` from an `.nfpt` file and return intervals `xᵢ ∈ [xᵢ-δ, xᵢ+δ]` -as an array of rows (`seqLen` rows, each of length `modelDim`). -/ -private def loadEmbeddingsIntervals - (path : System.FilePath) (seqLen modelDim : Nat) (delta : Rat) : - IO (Except String (Array (Array RatInterval))) := do - let contents ← IO.FS.readFile path - let lines : Array String := Nfp.Sound.splitLines contents - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty input file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected input header '{headerTag}'" - i := i + 1 - -- Scan to EMBEDDINGS (optionally skipping TOKENS). 
- i := skipUntil lines i (fun s => s = "EMBEDDINGS") - if !(i < lines.size) then - return .error "Missing EMBEDDINGS section in input file" - i := i + 1 - let step := - fun (st : (Array (Array RatInterval) × Array RatInterval)) (x : Rat) => - let (rows, cur) := st - let cur := cur.push { lo := x - delta, hi := x + delta } - if cur.size = modelDim then - (rows.push cur, #[]) - else - (rows, cur) - match foldRatTokens lines i (seqLen * modelDim) (#[], #[]) step with - | .error e => return .error e - | .ok ((rows, cur), _) => - if cur.size ≠ 0 then - return .error "EMBEDDINGS parse ended mid-row" - if rows.size ≠ seqLen then - return .error s!"EMBEDDINGS length mismatch: expected {seqLen} rows, got {rows.size}" - return .ok rows - -private structure LayerNormParams where - gamma : Array Rat - beta : Array Rat - -private structure LayerNormParamsFixed where - gamma : Array Fixed10Interval - beta : Array Fixed10Interval - -private def intervalsFromScaled (xs : Array Int) (slack : Int) : Array Fixed10Interval := - xs.map (fun x => { lo := x - slack, hi := x + slack }) - -private def collectLayerNormParams - (lines : Array String) (L d : Nat) : - Except String (Array LayerNormParams × Array LayerNormParams) := - Id.run do - let defP : LayerNormParams := { gamma := Array.replicate d 1, beta := Array.replicate d 0 } - let mut ln1 : Array LayerNormParams := - Array.replicate L defP - let mut ln2 : Array LayerNormParams := - Array.replicate L defP - let mut i : Nat := 0 - let mut curLayer : Nat := 0 - while i < lines.size do - let s := lines[i]!.trim - if s.startsWith "LAYER" then - let parts := s.splitOn " " |>.filter (· ≠ "") - if parts.length >= 2 then - curLayer := (parts[1]!).toNat? |>.getD curLayer - i := i + 1 - else if s = "LN1_GAMMA" then - match consumeVector lines (i + 1) d with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < L then - let old := ln1.getD curLayer defP - ln1 := ln1.set! 
curLayer { old with gamma := xs } - i := next - else if s = "LN1_BETA" then - match consumeVector lines (i + 1) d with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < L then - let old := ln1.getD curLayer defP - ln1 := ln1.set! curLayer { old with beta := xs } - i := next - else if s = "LN2_GAMMA" then - match consumeVector lines (i + 1) d with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < L then - let old := ln2.getD curLayer defP - ln2 := ln2.set! curLayer { old with gamma := xs } - i := next - else if s = "LN2_BETA" then - match consumeVector lines (i + 1) d with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < L then - let old := ln2.getD curLayer defP - ln2 := ln2.set! curLayer { old with beta := xs } - i := next - else - i := i + 1 - return .ok (ln1, ln2) - -private def collectLayerNormParamsBinary - (path : System.FilePath) - (scalePow10 : Nat) - (slack : Int) : - IO - (Except String - (BinaryHeader × Array LayerNormParamsFixed × Array LayerNormParamsFixed)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut ln1 : Array LayerNormParamsFixed := Array.replicate hdr.numLayers defP - let mut ln2 : Array LayerNormParamsFixed := Array.replicate hdr.numLayers defP - for l in [:hdr.numLayers] do - for _h in [:hdr.numHeads] do - match ← skipF64Array h (hdr.modelDim * hdr.headDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← 
skipF64Array h (hdr.modelDim * hdr.headDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.headDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.headDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.hiddenDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.hiddenDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let ln1GammaE ← readScaledFloatArray h hdr.modelDim scalePow10 - let ln1Gamma ← - match ln1GammaE with - | .error e => return .error e - | .ok xs => pure (intervalsFromScaled xs slack) - let ln1BetaE ← readScaledFloatArray h hdr.modelDim scalePow10 - let ln1Beta ← - match ln1BetaE with - | .error e => return .error e - | .ok xs => pure (intervalsFromScaled xs slack) - let ln2GammaE ← readScaledFloatArray h hdr.modelDim scalePow10 - let ln2Gamma ← - match ln2GammaE with - | .error e => return .error e - | .ok xs => pure (intervalsFromScaled xs slack) - let ln2BetaE ← readScaledFloatArray h hdr.modelDim scalePow10 - let ln2Beta ← - match ln2BetaE with - | .error e => return .error e - | .ok xs => pure (intervalsFromScaled xs slack) - ln1 := ln1.set! l { gamma := ln1Gamma, beta := ln1Beta } - ln2 := ln2.set! l { gamma := ln2Gamma, beta := ln2Beta } - return .ok (hdr, ln1, ln2) - -/-! 
-## Cached fixed-point local certification (fast path) - -The original local path (`RatInterval` + `parseRat`) is mathematically rigorous but too slow for -large models because it performs gcd-based normalization on the hot path. - -We therefore prefer a cached fixed-point representation (`sound_cache/*.nfpc`) and run local IBP -in scaled-`Int` arithmetic with conservative outward rounding. --/ - -private def defaultFixedScalePow10 : Nat := 9 -private def fixedUlpSlack : Int := 1 - -private def scaleCfgOfPow10 (p : Nat) : Fixed10Cfg := { scalePow10 := p } - -private def ratCeilMulNat (x : Rat) (k : Nat) : Int := - if x ≤ 0 then - 0 - else - let num : Int := x.num - let den : Nat := x.den - let numK : Int := num * (Int.ofNat k) - let q := numK.ediv (Int.ofNat den) - let r := numK.emod (Int.ofNat den) - if r = 0 then q else q + 1 - -private def ratFloorMulNat (x : Rat) (k : Nat) : Int := - let num : Int := x.num - let den : Nat := x.den - let numK : Int := num * (Int.ofNat k) - numK.ediv (Int.ofNat den) - -private def fixedMeanInterval (xs : Array Fixed10Interval) : Fixed10Interval := - if xs.isEmpty then - { lo := 0, hi := 0 } - else - Id.run do - let n : Nat := xs.size - let mut loSum : Int := 0 - let mut hiSum : Int := 0 - for x in xs do - loSum := loSum + x.lo - hiSum := hiSum + x.hi - let loμ := loSum.ediv (Int.ofNat n) - let hiμ := - let q := hiSum.ediv (Int.ofNat n) - let r := hiSum.emod (Int.ofNat n) - if r = 0 then q else q + 1 - { lo := loμ, hi := hiμ } - -private def fixedVarianceLowerBoundRange (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := - if xs.size < 2 then - 0 - else - Id.run do - let n : Nat := xs.size - let nRat : Rat := (n : Nat) - let mut loMax : Int := xs[0]!.lo - let mut hiMin : Int := xs[0]!.hi - for x in xs do - loMax := max loMax x.lo - hiMin := min hiMin x.hi - let δInt : Int := max 0 (loMax - hiMin) - if δInt = 0 then - return 0 - let δRat : Rat := - Rat.normalize δInt cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := 
by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let δSq : Rat := δRat * δRat - return δSq / ((2 : Rat) * nRat) - -private def absInt (x : Int) : Int := if x < 0 then -x else x - -/-- Lower bound on variance using midpoint + radius deviation. -/ -private def fixedVarianceLowerBoundMidpoint (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : - Rat := - if xs.size < 2 then - 0 - else - Id.run do - let n : Nat := xs.size - let nInt : Int := Int.ofNat n - let d : Nat := 2 * cfg.scaleNat - let mut sumM : Int := 0 - let mut sumR : Int := 0 - for x in xs do - sumM := sumM + (x.lo + x.hi) - sumR := sumR + (x.hi - x.lo) - let mut varNum : Int := 0 - let mut errNum : Int := 0 - for x in xs do - let mInt := x.lo + x.hi - let rInt := x.hi - x.lo - let aNum := nInt * mInt - sumM - let rNum := nInt * rInt + sumR - varNum := varNum + aNum * aNum - errNum := errNum + (absInt aNum) * rNum - let num := varNum - 2 * errNum - if num <= 0 then - return 0 - let denNat : Nat := d * d * n * n * n - return (num : Rat) / (denNat : Rat) - -/-- Exact variance lower bound by converting to `RatInterval` and using the exact routine. -/ -private def fixedVarianceLowerBoundExact (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := - if xs.size < 2 then - 0 - else - let ratXs := - xs.map (fun x => { lo := ratOfScaledInt cfg.scalePow10 x.lo, - hi := ratOfScaledInt cfg.scalePow10 x.hi }) - RatInterval.varianceLowerBound ratXs - -/-- Best available variance lower bound from range + midpoint deviation. -/ -private def fixedVarianceLowerBound (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := - let rangeLB := fixedVarianceLowerBoundRange cfg xs - let midLB := fixedVarianceLowerBoundMidpoint cfg xs - let approxLB := max rangeLB midLB - -- Avoid the exact Rat-based bound on large rows (expensive and stack-heavy), - -- but recover it when the fast bounds collapse to zero for medium sizes. 
- if xs.size ≤ 256 then - let exactLB := fixedVarianceLowerBoundExact cfg xs - max approxLB exactLB - else if approxLB = 0 && xs.size ≤ 1024 then - let exactLB := fixedVarianceLowerBoundExact cfg xs - max approxLB exactLB - else - approxLB - -private def fixedLayerNormRowApprox - (cfg : Fixed10Cfg) - (row : Array Fixed10Interval) - (gamma beta : Array Fixed10Interval) - (eps : Rat) - (soundnessBits : Nat) : - (Array Fixed10Interval × Rat) := - if row.size = 0 || gamma.size ≠ row.size || beta.size ≠ row.size then - (row, 0) - else - Id.run do - let μ := fixedMeanInterval row - let varLB := fixedVarianceLowerBound cfg row - let invσUpper : Rat := - if varLB ≤ 0 then - layerNormOpBoundConservative 1 eps soundnessBits - else - layerNormOpBoundLocal 1 varLB eps soundnessBits - let invσUpperInt : Int := ratCeilMulNat invσUpper cfg.scaleNat - let invσFix : Fixed10Interval := { lo := invσUpperInt, hi := invσUpperInt } - let mut out : Array Fixed10Interval := Array.mkEmpty row.size - for i in [:row.size] do - let centered := Fixed10Interval.sub row[i]! μ - let coeff := Fixed10Interval.mul cfg gamma[i]! invσFix - let scaled := Fixed10Interval.mul cfg coeff centered - out := out.push (Fixed10Interval.add scaled beta[i]!) 
- return (out, varLB) - -private def fixedLayerNormRowApproxExact - (cfg : Fixed10Cfg) - (row : Array Fixed10Interval) - (gamma beta : Array Fixed10Interval) - (eps : Rat) - (soundnessBits : Nat) : Array Fixed10Interval := - if row.size = 0 || gamma.size ≠ row.size || beta.size ≠ row.size then - row - else - Id.run do - let μ := fixedMeanInterval row - let varLB := fixedVarianceLowerBoundExact cfg row - let invσUpper : Rat := - if varLB ≤ 0 then - layerNormOpBoundConservative 1 eps soundnessBits - else - layerNormOpBoundLocal 1 varLB eps soundnessBits - let invσUpperInt : Int := ratCeilMulNat invσUpper cfg.scaleNat - let invσFix : Fixed10Interval := { lo := invσUpperInt, hi := invσUpperInt } - let mut out : Array Fixed10Interval := Array.mkEmpty row.size - for i in [:row.size] do - let centered := Fixed10Interval.sub row[i]! μ - let coeff := Fixed10Interval.mul cfg gamma[i]! invσFix - let scaled := Fixed10Interval.mul cfg coeff centered - out := out.push (Fixed10Interval.add scaled beta[i]!) - return out - -private def fixedLayerNormRowsApprox - (cfg : Fixed10Cfg) - (rows : Array (Array Fixed10Interval)) - (p : LayerNormParamsFixed) - (eps : Rat) - (soundnessBits : Nat) : - Array (Array Fixed10Interval) := - let useTasks := rows.size > 32 - if useTasks then - Id.run do - let chunkSize : Nat := 16 - let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min rows.size (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) - let mut i := start - while i < stop do - outChunk := outChunk.push - (fixedLayerNormRowApprox cfg rows[i]! 
p.gamma p.beta eps soundnessBits).1 - i := i + 1 - return outChunk) - chunkIdx := chunkIdx + 1 - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for t in tasks do - for row in t.get do - out := out.push row - return out - else - rows.map (fun row => (fixedLayerNormRowApprox cfg row p.gamma p.beta eps soundnessBits).1) - -private def fixedLayerNormRowsApproxExact - (cfg : Fixed10Cfg) - (rows : Array (Array Fixed10Interval)) - (p : LayerNormParamsFixed) - (eps : Rat) - (soundnessBits : Nat) : - Array (Array Fixed10Interval) := - let useTasks := rows.size > 32 - if useTasks then - Id.run do - let chunkSize : Nat := 16 - let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min rows.size (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) - let mut i := start - while i < stop do - outChunk := outChunk.push - (fixedLayerNormRowApproxExact cfg rows[i]! 
p.gamma p.beta eps soundnessBits) - i := i + 1 - return outChunk) - chunkIdx := chunkIdx + 1 - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for t in tasks do - for row in t.get do - out := out.push row - return out - else - rows.map (fun row => fixedLayerNormRowApproxExact cfg row p.gamma p.beta eps soundnessBits) - -private def readVecIntervals - (r : SoundCache.I32Reader) (n : Nat) (slack : Int) : - IO (Array Fixed10Interval × SoundCache.I32Reader) := do - let mut rr := r - let mut out : Array Fixed10Interval := Array.mkEmpty n - for _ in [:n] do - let (x, rr2) ← Nfp.Untrusted.SoundCacheIO.I32Reader.readI32 rr - rr := rr2 - out := out.push { lo := x - slack, hi := x + slack } - return (out, rr) - -private def readVecIntervalsBinary - (h : IO.FS.Handle) (n : Nat) (slack : Int) (scalePow10 : Nat) : - IO (Except String (Array Fixed10Interval)) := do - match ← readScaledFloatArray h n scalePow10 with - | .error e => return .error e - | .ok xs => return .ok (intervalsFromScaled xs slack) - -private def matMulIntervalsFromScaledCore - (cfg : Fixed10Cfg) - (slack : Int) - (rows cols : Nat) - (weights : Array Int) - (input : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - let mut rowIdx : Nat := 0 - while rowIdx < rows do - let xi := input[rowIdx]! - let mut colIdx : Nat := 0 - while colIdx < cols do - let idx := rowIdx * cols + colIdx - let w := weights[idx]! - let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) 
term) - colIdx := colIdx + 1 - rowIdx := rowIdx + 1 - return out - -private def matMulIntervalsFromScaledNoTask - (cfg : Fixed10Cfg) - (slack : Int) - (rows cols : Nat) - (weights : Array Int) - (input : Array Fixed10Interval) : Array Fixed10Interval := - if input.size ≠ rows || weights.size ≠ rows * cols then - Array.replicate cols { lo := 0, hi := 0 } - else - matMulIntervalsFromScaledCore cfg slack rows cols weights input - -private def matMulIntervalsFromIntervalsNoTask - (cfg : Fixed10Cfg) - (rows cols : Nat) - (weights : Array Fixed10Interval) - (input : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if input.size ≠ rows || weights.size ≠ rows * cols then - return Array.replicate cols { lo := 0, hi := 0 } - let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - let mut rowIdx : Nat := 0 - while rowIdx < rows do - let xi := input[rowIdx]! - let mut colIdx : Nat := 0 - while colIdx < cols do - let idx := rowIdx * cols + colIdx - let wI := weights[idx]! - let term := Fixed10Interval.mul cfg wI xi - out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) 
term) - colIdx := colIdx + 1 - rowIdx := rowIdx + 1 - return out - -private def matMulIntervalsFromIntervals - (cfg : Fixed10Cfg) - (rows cols : Nat) - (weights : Array Fixed10Interval) - (input : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if input.size ≠ rows || weights.size ≠ rows * cols then - return Array.replicate cols { lo := 0, hi := 0 } - let useTasks := rows * cols > 16384 && cols > 1 - if useTasks then - let chunkSize : Nat := 32 - let numChunks : Nat := (cols + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array Fixed10Interval)) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min cols (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array Fixed10Interval := Array.mkEmpty (stop - start) - let mut colIdx : Nat := start - while colIdx < stop do - let mut acc : Fixed10Interval := { lo := 0, hi := 0 } - let mut rowIdx : Nat := 0 - while rowIdx < rows do - let xi := input[rowIdx]! - let idx := rowIdx * cols + colIdx - let wI := weights[idx]! 
- let term := Fixed10Interval.mul cfg wI xi - acc := Fixed10Interval.add acc term - rowIdx := rowIdx + 1 - outChunk := outChunk.push acc - colIdx := colIdx + 1 - return outChunk) - chunkIdx := chunkIdx + 1 - let mut out : Array Fixed10Interval := Array.mkEmpty cols - for t in tasks do - let chunk := t.get - for v in chunk do - out := out.push v - return out - else - return matMulIntervalsFromIntervalsNoTask cfg rows cols weights input - -private def matMulIntervalsFromScaled - (cfg : Fixed10Cfg) - (slack : Int) - (rows cols : Nat) - (weights : Array Int) - (input : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if input.size ≠ rows || weights.size ≠ rows * cols then - return Array.replicate cols { lo := 0, hi := 0 } - let useTasks := rows * cols > 16384 && cols > 1 - if useTasks then - let chunkSize : Nat := 32 - let numChunks : Nat := (cols + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array Fixed10Interval)) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min cols (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array Fixed10Interval := Array.mkEmpty (stop - start) - let mut colIdx : Nat := start - while colIdx < stop do - let mut acc : Fixed10Interval := { lo := 0, hi := 0 } - let mut rowIdx : Nat := 0 - while rowIdx < rows do - let xi := input[rowIdx]! - let idx := rowIdx * cols + colIdx - let w := weights[idx]! 
- let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - acc := Fixed10Interval.add acc term - rowIdx := rowIdx + 1 - outChunk := outChunk.push acc - colIdx := colIdx + 1 - return outChunk) - chunkIdx := chunkIdx + 1 - let mut out : Array Fixed10Interval := Array.mkEmpty cols - for t in tasks do - let chunk := t.get - for v in chunk do - out := out.push v - return out - else - return matMulIntervalsFromScaledCore cfg slack rows cols weights input - -private def fixedDotInterval - (cfg : Fixed10Cfg) - (a b : Array Fixed10Interval) : Fixed10Interval := - if a.size = 0 || a.size ≠ b.size then - { lo := 0, hi := 0 } - else - Id.run do - let mut acc : Fixed10Interval := { lo := 0, hi := 0 } - for i in [:a.size] do - let term := Fixed10Interval.mul cfg a[i]! b[i]! - acc := Fixed10Interval.add acc term - return acc - -private def centerRadiusOfFixed - (cfg : Fixed10Cfg) (a : Fixed10Interval) : Rat × Rat := - let lo := ratOfScaledInt cfg.scalePow10 a.lo - let hi := ratOfScaledInt cfg.scalePow10 a.hi - let center := (lo + hi) / (2 : Rat) - let radius := (hi - lo) / (2 : Rat) - (center, radius) - -private def rowCentersRadiiAbs - (cfg : Fixed10Cfg) - (row : Array Fixed10Interval) : Array Rat × Array Rat × Rat := - Id.run do - let mut centers : Array Rat := Array.mkEmpty row.size - let mut radii : Array Rat := Array.mkEmpty row.size - let mut absSum : Rat := 0 - for x in row do - let lo := ratOfScaledInt cfg.scalePow10 x.lo - let hi := ratOfScaledInt cfg.scalePow10 x.hi - let center := (lo + hi) / (2 : Rat) - let radius := (hi - lo) / (2 : Rat) - centers := centers.push center - radii := radii.push radius - absSum := absSum + max (ratAbs lo) (ratAbs hi) - return (centers, radii, absSum) - -private def weightsRatFromScaled (cfg : Fixed10Cfg) (weights : Array Int) : Array Rat := - weights.map (ratOfScaledInt cfg.scalePow10) - -private def affineMatMulRowExact - (rows cols : Nat) - (weights : Array Rat) - (centers radii : 
Array Rat) : Array AffineForm := - Id.run do - if centers.size ≠ rows || radii.size ≠ rows || weights.size ≠ rows * cols then - return Array.replicate cols (AffineForm.const 0) - let mut out : Array AffineForm := Array.mkEmpty cols - for colIdx in [:cols] do - let mut center : Rat := 0 - let mut coeffs : Array Rat := Array.mkEmpty rows - for rowIdx in [:rows] do - let idx := rowIdx * cols + colIdx - let w := weights[idx]! - center := center + w * centers[rowIdx]! - coeffs := coeffs.push (w * radii[rowIdx]!) - out := out.push { center := center, coeffs := coeffs } - return out - -private def affineAddBiasCenters - (biasCenters : Array Rat) - (row : Array AffineForm) : Array AffineForm := - Id.run do - if biasCenters.size ≠ row.size then - return row - let mut out : Array AffineForm := Array.mkEmpty row.size - for i in [:row.size] do - let a := row.getD i (AffineForm.const 0) - let bias := biasCenters.getD i 0 - out := out.push { a with center := a.center + bias } - return out - -private def affineAbsSum (row : Array AffineForm) : Rat := - row.foldl (fun acc a => acc + ratAbs a.center + AffineForm.radius a) 0 - -private def affineDotDisjoint - (a b : Array AffineForm) : AffineForm := - if a.size = 0 || a.size ≠ b.size then - AffineForm.const 0 - else - Id.run do - let mut acc := AffineForm.const 0 - for i in [:a.size] do - let ai := a.getD i (AffineForm.const 0) - let bi := b.getD i (AffineForm.const 0) - let term := AffineForm.mulDisjoint ai bi - acc := AffineForm.add acc term - return acc - -private def sumRat (xs : Array Rat) : Rat := - Id.run do - let mut acc : Rat := 0 - let mut i := 0 - while i < xs.size do - acc := acc + xs[i]! - i := i + 1 - return acc - -private def sumAbsRat (xs : Array Rat) : Rat := - Id.run do - let mut acc : Rat := 0 - let mut i := 0 - while i < xs.size do - acc := acc + ratAbs xs[i]! 
- i := i + 1 - return acc - -private def addVecRat (a b : Array Rat) : Array Rat := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array Rat := Array.mkEmpty a.size - let mut i := 0 - while i < a.size do - out := out.push (a[i]! + b[i]!) - i := i + 1 - return out - -private def dotRat (a b : Array Rat) : Rat := - if a.size = 0 || a.size ≠ b.size then - 0 - else - Id.run do - let mut acc : Rat := 0 - let mut i := 0 - while i < a.size do - acc := acc + a[i]! * b[i]! - i := i + 1 - return acc - -private def matMulCentersRadii - (rows cols : Nat) - (weights : Array Rat) - (centers radii : Array Rat) : Array Rat × Array Rat := - Id.run do - if centers.size ≠ rows || radii.size ≠ rows || weights.size ≠ rows * cols then - return (Array.replicate cols 0, Array.replicate cols 0) - let mut outCenters : Array Rat := Array.mkEmpty cols - let mut outRadii : Array Rat := Array.mkEmpty cols - let mut colIdx := 0 - while colIdx < cols do - let mut center : Rat := 0 - let mut radius : Rat := 0 - let mut rowIdx := 0 - while rowIdx < rows do - let idx := rowIdx * cols + colIdx - let w := weights.getD idx 0 - let c := centers.getD rowIdx 0 - let r := radii.getD rowIdx 0 - center := center + w * c - radius := radius + ratAbs w * r - rowIdx := rowIdx + 1 - outCenters := outCenters.push center - outRadii := outRadii.push radius - colIdx := colIdx + 1 - return (outCenters, outRadii) - -private def coeffSumFromCenters - (rows cols : Nat) - (weights : Array Rat) - (inputRadii : Array Rat) - (otherCenters : Array Rat) : Rat := - if inputRadii.size ≠ rows || otherCenters.size ≠ cols || weights.size ≠ rows * cols then - 0 - else - Id.run do - let mut acc : Rat := 0 - let mut rowIdx := 0 - while rowIdx < rows do - let mut sum : Rat := 0 - let mut colIdx := 0 - while colIdx < cols do - let idx := rowIdx * cols + colIdx - sum := sum + weights.getD idx 0 * otherCenters.getD colIdx 0 - colIdx := colIdx + 1 - let coeff := inputRadii.getD rowIdx 0 * sum - acc := acc + ratAbs coeff - 
rowIdx := rowIdx + 1 - return acc - -private def sumInt (xs : Array Int) : Int := - Id.run do - let mut acc : Int := 0 - let mut i := 0 - while i < xs.size do - acc := acc + xs[i]! - i := i + 1 - return acc - -private def sumAbsInt (xs : Array Int) : Int := - Id.run do - let mut acc : Int := 0 - let mut i := 0 - while i < xs.size do - acc := acc + absInt xs[i]! - i := i + 1 - return acc - -private def addVecScaledInt (a : Array Int) (b : Array Int) (scale : Int) : Array Int := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array Int := Array.mkEmpty a.size - let mut i := 0 - while i < a.size do - out := out.push (a[i]! + b[i]! * scale) - i := i + 1 - return out - -private def dotInt (a b : Array Int) : Int := - if a.size = 0 || a.size ≠ b.size then - 0 - else - Id.run do - let mut acc : Int := 0 - let mut i := 0 - while i < a.size do - acc := acc + a[i]! * b[i]! - i := i + 1 - return acc - -private def rowCentersRadiiAbsInt - (row : Array Fixed10Interval) : Array Int × Array Int × Int := - Id.run do - let mut centers : Array Int := Array.mkEmpty row.size - let mut radii : Array Int := Array.mkEmpty row.size - let mut absSum : Int := 0 - for x in row do - let sum := x.lo + x.hi - let width := x.hi - x.lo - let center := sum.ediv (Int.ofNat 2) - let half := width.ediv (Int.ofNat 2) - let radius := if width.emod (Int.ofNat 2) = 0 then half else half + 1 - centers := centers.push center - radii := radii.push radius - absSum := absSum + Fixed10Interval.absUpper x - return (centers, radii, absSum) - -private def matMulCentersRadiiInt - (cfg : Fixed10Cfg) - (rows cols : Nat) - (weights : Array Int) - (centers radii : Array Int) : Array Int × Array Int := - Id.run do - if centers.size ≠ rows || radii.size ≠ rows || weights.size ≠ rows * cols then - return (Array.replicate cols 0, Array.replicate cols 0) - let mut outCenters : Array Int := Array.mkEmpty cols - let mut outRadii : Array Int := Array.mkEmpty cols - let mut colIdx := 0 - while colIdx < cols do 
- let mut centerI : Fixed10Interval := { lo := 0, hi := 0 } - let mut radiusAcc : Int := 0 - let mut rowIdx := 0 - while rowIdx < rows do - let idx := rowIdx * cols + colIdx - let w := weights.getD idx 0 - let c := centers.getD rowIdx 0 - let r := radii.getD rowIdx 0 - let term := Fixed10Interval.mul cfg { lo := w, hi := w } { lo := c, hi := c } - centerI := Fixed10Interval.add centerI term - if r ≠ 0 && w ≠ 0 then - let wAbs := absInt w - let termR := Fixed10Interval.mul cfg { lo := wAbs, hi := wAbs } { lo := r, hi := r } - radiusAcc := radiusAcc + termR.hi - rowIdx := rowIdx + 1 - let width := centerI.hi - centerI.lo - let center := (centerI.lo + centerI.hi).ediv (Int.ofNat 2) - let half := width.ediv (Int.ofNat 2) - let radiusMid := if width.emod (Int.ofNat 2) = 0 then half else half + 1 - let radius := radiusMid + radiusAcc - outCenters := outCenters.push center - outRadii := outRadii.push radius - colIdx := colIdx + 1 - return (outCenters, outRadii) - -private def intervalRadiusInt (x : Fixed10Interval) : Int := - let width := x.hi - x.lo - let half := width.ediv (Int.ofNat 2) - if width.emod (Int.ofNat 2) = 0 then half else half + 1 - -private def matMulCentersRadiiIntSlack - (cfg : Fixed10Cfg) - (slack : Int) - (rows cols : Nat) - (weights : Array Int) - (centers radii : Array Int) : Array Int × Array Int := - Id.run do - if centers.size ≠ rows || radii.size ≠ rows || weights.size ≠ rows * cols then - return (Array.replicate cols 0, Array.replicate cols 0) - let mut outCenters : Array Int := Array.mkEmpty cols - let mut outRadii : Array Int := Array.mkEmpty cols - let mut colIdx := 0 - while colIdx < cols do - let mut centerI : Fixed10Interval := { lo := 0, hi := 0 } - let mut radiusAcc : Int := 0 - let mut rowIdx := 0 - while rowIdx < rows do - let idx := rowIdx * cols + colIdx - let w := weights.getD idx 0 - let c := centers.getD rowIdx 0 - let r := radii.getD rowIdx 0 - let term := Fixed10Interval.mul cfg { lo := w, hi := w } { lo := c, hi := c } - 
centerI := Fixed10Interval.add centerI term - if r ≠ 0 || slack ≠ 0 then - let wAbs := absInt w - let cAbs := absInt c - let term1 := Fixed10Interval.mul cfg { lo := wAbs, hi := wAbs } { lo := r, hi := r } - let term2 := - if slack = 0 then 0 - else - (Fixed10Interval.mul cfg { lo := slack, hi := slack } - { lo := cAbs, hi := cAbs }).hi - let term3 := - if slack = 0 then 0 - else - (Fixed10Interval.mul cfg { lo := slack, hi := slack } { lo := r, hi := r }).hi - radiusAcc := radiusAcc + term1.hi + term2 + term3 - rowIdx := rowIdx + 1 - let width := centerI.hi - centerI.lo - let center := (centerI.lo + centerI.hi).ediv (Int.ofNat 2) - let half := width.ediv (Int.ofNat 2) - let radiusMid := if width.emod (Int.ofNat 2) = 0 then half else half + 1 - let radius := radiusMid + radiusAcc - outCenters := outCenters.push center - outRadii := outRadii.push radius - colIdx := colIdx + 1 - return (outCenters, outRadii) - -private def coeffSumFromCentersInt - (cfg : Fixed10Cfg) - (rows cols : Nat) - (weights : Array Int) - (inputRadii : Array Int) - (otherCenters : Array Int) : Int := - if inputRadii.size ≠ rows || otherCenters.size ≠ cols || weights.size ≠ rows * cols then - 0 - else - Id.run do - let mut acc : Int := 0 - let mut rowIdx := 0 - while rowIdx < rows do - let mut sum : Fixed10Interval := { lo := 0, hi := 0 } - let mut colIdx := 0 - while colIdx < cols do - let idx := rowIdx * cols + colIdx - let w := weights.getD idx 0 - let c := otherCenters.getD colIdx 0 - let term := Fixed10Interval.mul cfg { lo := w, hi := w } { lo := c, hi := c } - sum := Fixed10Interval.add sum term - colIdx := colIdx + 1 - let r := inputRadii.getD rowIdx 0 - let coeff := Fixed10Interval.mul cfg sum { lo := r, hi := r } - acc := acc + Fixed10Interval.absUpper coeff - rowIdx := rowIdx + 1 - return acc - -private def dotIntervalFromCentersInt - (cfg : Fixed10Cfg) - (a b : Array Int) : Fixed10Interval := - if a.size = 0 || a.size ≠ b.size then - { lo := 0, hi := 0 } - else - Id.run do - let mut 
acc : Fixed10Interval := { lo := 0, hi := 0 } - let mut i := 0 - while i < a.size do - let term := Fixed10Interval.mul cfg - { lo := a[i]!, hi := a[i]! } - { lo := b[i]!, hi := b[i]! } - acc := Fixed10Interval.add acc term - i := i + 1 - return acc - -private def dotIntervalFromCentersRadiiInt - (cfg : Fixed10Cfg) - (aCenters aRadii bCenters bRadii : Array Int) : Fixed10Interval := - if aCenters.size = 0 || aCenters.size ≠ bCenters.size || - aCenters.size ≠ aRadii.size || bCenters.size ≠ bRadii.size then - { lo := 0, hi := 0 } - else - Id.run do - let mut centerI : Fixed10Interval := { lo := 0, hi := 0 } - let mut radiusAcc : Int := 0 - let mut i := 0 - while i < aCenters.size do - let ac := aCenters[i]! - let ar := aRadii[i]! - let bc := bCenters[i]! - let br := bRadii[i]! - let term := Fixed10Interval.mul cfg { lo := ac, hi := ac } { lo := bc, hi := bc } - centerI := Fixed10Interval.add centerI term - if ar ≠ 0 || br ≠ 0 then - let acAbs := absInt ac - let bcAbs := absInt bc - let term1 := Fixed10Interval.mul cfg { lo := acAbs, hi := acAbs } { lo := br, hi := br } - let term2 := Fixed10Interval.mul cfg { lo := bcAbs, hi := bcAbs } { lo := ar, hi := ar } - let term3 := Fixed10Interval.mul cfg { lo := ar, hi := ar } { lo := br, hi := br } - radiusAcc := radiusAcc + term1.hi + term2.hi + term3.hi - i := i + 1 - let width := centerI.hi - centerI.lo - let center := (centerI.lo + centerI.hi).ediv (Int.ofNat 2) - let half := width.ediv (Int.ofNat 2) - let radiusMid := if width.emod (Int.ofNat 2) = 0 then half else half + 1 - let radius := radiusMid + radiusAcc - return { lo := center - radius, hi := center + radius } - -private def sumMulUpperInt - (cfg : Fixed10Cfg) - (a b : Array Int) : Int := - if a.size = 0 || a.size ≠ b.size then - 0 - else - Id.run do - let mut acc : Int := 0 - let mut i := 0 - while i < a.size do - let term := Fixed10Interval.mul cfg - { lo := a[i]!, hi := a[i]! } - { lo := b[i]!, hi := b[i]! 
} - acc := acc + term.hi - i := i + 1 - return acc - -private def floorDivNat (a : Int) (d : Nat) : Int := - a.ediv (Int.ofNat d) - -private def ceilDivNat (a : Int) (d : Nat) : Int := - let di : Int := Int.ofNat d - let q := a.ediv di - let r := a.emod di - if r = 0 then q else q + 1 - -private def maxAbsVecFixed (xs : Array Fixed10Interval) : Int := - xs.foldl (fun acc x => max acc (Fixed10Interval.absUpper x)) 0 - -/-- Sum of per-coordinate centered absolute bounds (interval widths), as a `Rat`. -/ -private def centeredAbsSumFixed (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := - let sumWidth : Int := xs.foldl (fun acc x => acc + Fixed10Interval.centeredAbsBound x) 0 - Rat.normalize sumWidth cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - -private def ratIntervalOfFixed (cfg : Fixed10Cfg) (a : Fixed10Interval) : RatInterval := - { lo := ratOfScaledInt cfg.scalePow10 a.lo, hi := ratOfScaledInt cfg.scalePow10 a.hi } - -private def fixedIntervalOfRat (cfg : Fixed10Cfg) (a : RatInterval) : Fixed10Interval := - { lo := ratFloorMulNat a.lo cfg.scaleNat, hi := ratCeilMulNat a.hi cfg.scaleNat } - -private def defaultGeluExpEffort : Nat := 2 -private def defaultGeluSplitDepth : Nat := 1 - -private def geluOverapproxRat (target : GeluDerivTarget) (a : RatInterval) : RatInterval := - match target with - | .tanh => RatInterval.geluOverapproxTanhSplit a defaultGeluExpEffort defaultGeluSplitDepth - | .exact => RatInterval.geluOverapprox a - -private def geluOverapproxFixed (cfg : Fixed10Cfg) (target : GeluDerivTarget) - (a : Fixed10Interval) : Fixed10Interval := - match target with - | .tanh => - let r := ratIntervalOfFixed cfg a - fixedIntervalOfRat cfg - (RatInterval.geluOverapproxTanhSplit r defaultGeluExpEffort defaultGeluSplitDepth) - | .exact => - Fixed10Interval.geluOverapprox a - -private def geluOverapproxFixedVec (cfg : Fixed10Cfg) (target : GeluDerivTarget) - (xs : 
Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - let mut out : Array Fixed10Interval := Array.mkEmpty xs.size - let mut i : Nat := 0 - while i < xs.size do - out := out.push (geluOverapproxFixed cfg target xs[i]!) - i := i + 1 - return out - -private def geluOverapproxFixedVecLinear - (xs : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - let mut out : Array Fixed10Interval := Array.mkEmpty xs.size - let mut i : Nat := 0 - while i < xs.size do - out := out.push (Fixed10Interval.geluOverapprox xs[i]!) - i := i + 1 - return out - -private def addVecFixed (a b : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array Fixed10Interval := Array.mkEmpty a.size - let mut i : Nat := 0 - while i < a.size do - out := out.push (Fixed10Interval.add a[i]! b[i]!) - i := i + 1 - return out - -private def addVecFixedRows - (rows : Array (Array Fixed10Interval)) - (v : Array Fixed10Interval) : Array (Array Fixed10Interval) := - Id.run do - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - let mut i : Nat := 0 - while i < rows.size do - out := out.push (addVecFixed rows[i]! v) - i := i + 1 - return out - -private def addRowsFixed - (rows : Array (Array Fixed10Interval)) - (adds : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := - Id.run do - if rows.size ≠ adds.size then - return rows - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - let mut i : Nat := 0 - while i < rows.size do - out := out.push (addVecFixed rows[i]! adds[i]!) 
- i := i + 1 - return out - -private def takePrefix {α : Type} (xs : Array α) (n : Nat) : Array α := - if xs.size ≤ n then xs else xs.extract 0 n - -private def mlpRowFromScaled - (cfg : Fixed10Cfg) - (geluDerivTarget : GeluDerivTarget) - (slack : Int) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) - (row : Array Fixed10Interval) : Array Fixed10Interval := - let hidden0 := matMulIntervalsFromScaled cfg slack modelDim hiddenDim wIn row - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB - let mlpOut0 := matMulIntervalsFromScaled cfg slack hiddenDim modelDim wOut actHidden - addVecFixed mlpOut0 bOut - -private def mlpRowFromScaledNoTask - (cfg : Fixed10Cfg) - (geluDerivTarget : GeluDerivTarget) - (slack : Int) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) - (row : Array Fixed10Interval) : Array Fixed10Interval := - let hidden0 := matMulIntervalsFromScaledNoTask cfg slack modelDim hiddenDim wIn row - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB - let mlpOut0 := matMulIntervalsFromScaledNoTask cfg slack hiddenDim modelDim wOut actHidden - addVecFixed mlpOut0 bOut - -/-- Linear GeLU-hull MLP row used to avoid the tanh/exp path in hot loops. 
-/ -private def mlpRowFromScaledLinear - (cfg : Fixed10Cfg) - (slack : Int) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) - (row : Array Fixed10Interval) : Array Fixed10Interval := - let hidden0 := matMulIntervalsFromScaled cfg slack modelDim hiddenDim wIn row - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVecLinear hiddenB - let mlpOut0 := matMulIntervalsFromScaled cfg slack hiddenDim modelDim wOut actHidden - addVecFixed mlpOut0 bOut - -private def mlpRowFromScaledLinearNoTask - (cfg : Fixed10Cfg) - (slack : Int) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) - (row : Array Fixed10Interval) : Array Fixed10Interval := - let hidden0 := matMulIntervalsFromScaledNoTask cfg slack modelDim hiddenDim wIn row - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVecLinear hiddenB - let mlpOut0 := matMulIntervalsFromScaledNoTask cfg slack hiddenDim modelDim wOut actHidden - addVecFixed mlpOut0 bOut - -private def mlpRowFromIntervalsNoTask - (cfg : Fixed10Cfg) - (geluDerivTarget : GeluDerivTarget) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Fixed10Interval) - (bIn bOut : Array Fixed10Interval) - (row : Array Fixed10Interval) : Array Fixed10Interval := - let hidden0 := matMulIntervalsFromIntervalsNoTask cfg modelDim hiddenDim wIn row - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB - let mlpOut0 := matMulIntervalsFromIntervalsNoTask cfg hiddenDim modelDim wOut actHidden - addVecFixed mlpOut0 bOut - -private def mlpRowFromIntervals - (cfg : Fixed10Cfg) - (geluDerivTarget : GeluDerivTarget) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Fixed10Interval) - (bIn bOut : Array Fixed10Interval) - (row : Array Fixed10Interval) : Array Fixed10Interval := - let hidden0 := matMulIntervalsFromIntervals cfg modelDim hiddenDim wIn row - let hiddenB := addVecFixed hidden0 bIn - 
let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB - let mlpOut0 := matMulIntervalsFromIntervals cfg hiddenDim modelDim wOut actHidden - addVecFixed mlpOut0 bOut - -private def mlpRowsFromIntervals - (cfg : Fixed10Cfg) - (geluDerivTarget : GeluDerivTarget) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Fixed10Interval) - (bIn bOut : Array Fixed10Interval) - (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := - let useTasks := rows.size > 32 - if useTasks then - Id.run do - let chunkSize : Nat := 16 - let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min rows.size (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) - let mut i := start - while i < stop do - outChunk := outChunk.push - (mlpRowFromIntervalsNoTask cfg geluDerivTarget modelDim hiddenDim wIn wOut bIn - bOut rows[i]!) 
- i := i + 1 - return outChunk) - chunkIdx := chunkIdx + 1 - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for t in tasks do - for row in t.get do - out := out.push row - return out - else - rows.map (mlpRowFromIntervalsNoTask cfg geluDerivTarget modelDim hiddenDim wIn wOut bIn bOut) - -private def mlpRowsFromScaled - (cfg : Fixed10Cfg) - (geluDerivTarget : GeluDerivTarget) - (slack : Int) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) - (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := - let useTasks := rows.size > 32 - if useTasks then - Id.run do - let chunkSize : Nat := 16 - let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min rows.size (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) - let mut i := start - while i < stop do - outChunk := outChunk.push - (mlpRowFromScaledNoTask cfg geluDerivTarget slack modelDim hiddenDim wIn wOut bIn - bOut rows[i]!) - i := i + 1 - return outChunk) - chunkIdx := chunkIdx + 1 - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for t in tasks do - for row in t.get do - out := out.push row - return out - else - rows.map (mlpRowFromScaledNoTask cfg geluDerivTarget slack modelDim hiddenDim wIn wOut bIn bOut) - -/-- Linear GeLU-hull per-row MLP for best-match induction hot paths. 
-/ -private def mlpRowsFromScaledLinear - (cfg : Fixed10Cfg) - (slack : Int) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) - (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := - let wInIntervals := intervalsFromScaled wIn slack - let wOutIntervals := intervalsFromScaled wOut slack - let mlpRowFromIntervals (row : Array Fixed10Interval) : Array Fixed10Interval := - let hidden0 := matMulIntervalsFromIntervalsNoTask cfg modelDim hiddenDim wInIntervals row - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVecLinear hiddenB - let mlpOut0 := matMulIntervalsFromIntervalsNoTask cfg hiddenDim modelDim wOutIntervals actHidden - addVecFixed mlpOut0 bOut - let useTasks := rows.size > 32 - if useTasks then - Id.run do - let chunkSize : Nat := 16 - let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min rows.size (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) - let mut i := start - while i < stop do - outChunk := outChunk.push (mlpRowFromIntervals rows[i]!) - i := i + 1 - return outChunk) - chunkIdx := chunkIdx + 1 - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for t in tasks do - for row in t.get do - out := out.push row - return out - else - rows.map mlpRowFromIntervals - -private def groupUnionRowsByToken - (rows : Array (Array Fixed10Interval)) - (tokens : Array Int) : Array (Array Fixed10Interval) := - Id.run do - if rows.size ≠ tokens.size then - return rows - let mut uniqTokens : Array Int := #[] - let mut uniqRows : Array (Array Fixed10Interval) := #[] - let mut i : Nat := 0 - while i < rows.size do - let tok := tokens[i]! 
- match uniqTokens.findIdx? (· == tok) with - | some idx => - let merged := Fixed10Interval.unionVec (uniqRows[idx]!) rows[i]! - uniqRows := uniqRows.set! idx merged - | none => - uniqTokens := uniqTokens.push tok - uniqRows := uniqRows.push rows[i]! - i := i + 1 - return uniqRows - -private def unionRowsFixed - (rows : Array (Array Fixed10Interval)) : Array Fixed10Interval := - if rows.isEmpty then - #[] - else - Id.run do - let mut out := rows[0]! - let mut i : Nat := 1 - while i < rows.size do - let row := rows[i]! - if row.size = out.size then - let mut j : Nat := 0 - while j < out.size do - let cur := out[j]! - let r := row[j]! - out := out.set! j { lo := min cur.lo r.lo, hi := max cur.hi r.hi } - j := j + 1 - i := i + 1 - return out - -private def prefixUnionRowsFixed - (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := - if rows.isEmpty then - #[] - else - Id.run do - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - let mut acc := rows[0]! - out := out.push acc - let mut i : Nat := 1 - while i < rows.size do - acc := Fixed10Interval.unionVec acc rows[i]! - out := out.push acc - i := i + 1 - return out - -private def consumeMatrixMulAndNormInfFixed - (cfg : Fixed10Cfg) - (slack : Int) - (r : SoundCache.I32Reader) - (rows cols : Nat) - (input : Array Fixed10Interval) : - IO (Array Fixed10Interval × Rat × SoundCache.I32Reader) := do - if input.size ≠ rows then - return (Array.replicate cols { lo := 0, hi := 0 }, 0, r) - let mut rr := r - let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - let mut curRowAbs : Int := 0 - let mut maxRowAbs : Int := 0 - let mut rowIdx : Nat := 0 - while rowIdx < rows do - let xi := input[rowIdx]! 
- let mut colIdx : Nat := 0 - while colIdx < cols do - let (w, rr2) ← Nfp.Untrusted.SoundCacheIO.I32Reader.readI32 rr - rr := rr2 - let wAbsBound : Int := (if w < 0 then -w else w) + slack - curRowAbs := curRowAbs + wAbsBound - let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) term) - colIdx := colIdx + 1 - maxRowAbs := max maxRowAbs curRowAbs - curRowAbs := 0 - rowIdx := rowIdx + 1 - let normInf : Rat := - Rat.normalize maxRowAbs cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - return (out, normInf, rr) - -private def consumeMatrixMulAndNormInfFixedBinary - (cfg : Fixed10Cfg) - (slack : Int) - (h : IO.FS.Handle) - (rows cols : Nat) - (input : Array Fixed10Interval) - (scalePow10 : Nat) : - IO (Except String (Array Fixed10Interval × Rat)) := do - if input.size ≠ rows then - match ← skipF64Array h (rows * cols) with - | .error e => return .error e - | .ok _ => return .ok (Array.replicate cols { lo := 0, hi := 0 }, 0) - match ← readScaledFloatArray h (rows * cols) scalePow10 with - | .error e => return .error e - | .ok vals => - let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - let mut curRowAbs : Int := 0 - let mut maxRowAbs : Int := 0 - let mut rowIdx : Nat := 0 - while rowIdx < rows do - let xi := input[rowIdx]! - let mut colIdx : Nat := 0 - while colIdx < cols do - let idx := rowIdx * cols + colIdx - let w := vals[idx]! - let wAbsBound : Int := (if w < 0 then -w else w) + slack - curRowAbs := curRowAbs + wAbsBound - let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) 
term) - colIdx := colIdx + 1 - maxRowAbs := max maxRowAbs curRowAbs - curRowAbs := 0 - rowIdx := rowIdx + 1 - let normInf : Rat := - Rat.normalize maxRowAbs cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - return .ok (out, normInf) - -private def consumeMatrixMulFixedBinaryStreaming - (cfg : Fixed10Cfg) - (slack : Int) - (h : IO.FS.Handle) - (rows cols : Nat) - (input : Array Fixed10Interval) - (scalePow10 : Nat) : - IO (Except String (Array Fixed10Interval)) := do - if input.size ≠ rows then - match ← skipF64Array h (rows * cols) with - | .error e => return .error e - | .ok _ => return .ok (Array.replicate cols { lo := 0, hi := 0 }) - let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - let mut rowIdx : Nat := 0 - while rowIdx < rows do - let rowWeightsE ← readScaledFloatArray h cols scalePow10 - match rowWeightsE with - | .error e => return .error e - | .ok rowWeights => - let xi := input[rowIdx]! - let mut colIdx : Nat := 0 - while colIdx < cols do - let w := rowWeights[colIdx]! - let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) term) - colIdx := colIdx + 1 - rowIdx := rowIdx + 1 - return .ok out - -/-- Apply union-MLP propagation for binary bounds using streaming matmul. 
-/ -private def mlpUnionStepBinary - (cfg : Fixed10Cfg) - (slack : Int) - (h : IO.FS.Handle) - (modelDim hiddenDim : Nat) - (ln2Rows : Array (Array Fixed10Interval)) - (residuals : Array (Array Fixed10Interval)) - (scalePow10 : Nat) : - IO (Except String (Array (Array Fixed10Interval))) := do - let ln2Union := unionRowsFixed ln2Rows - let hidden0E ← - consumeMatrixMulFixedBinaryStreaming cfg slack h modelDim hiddenDim ln2Union scalePow10 - match hidden0E with - | .error e => return .error e - | .ok hidden0 => - let bInE ← readVecIntervalsBinary h hiddenDim slack scalePow10 - match bInE with - | .error e => return .error e - | .ok bIn => - let hiddenB := addVecFixed hidden0 bIn - -- Linear GeLU hull keeps the union path fast and avoids heavy tanh bounds. - let actHidden := geluOverapproxFixedVecLinear hiddenB - let mut mlpOut0 : Array Fixed10Interval := - Array.replicate modelDim { lo := 0, hi := 0 } - let mut rowIdx : Nat := 0 - while rowIdx < hiddenDim do - let rowWeightsE ← readScaledFloatArray h modelDim scalePow10 - match rowWeightsE with - | .error e => return .error e - | .ok rowWeights => - let xi := actHidden[rowIdx]! - let mut colIdx : Nat := 0 - while colIdx < modelDim do - let w := rowWeights[colIdx]! - let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - mlpOut0 := mlpOut0.set! colIdx - (Fixed10Interval.add (mlpOut0[colIdx]!) 
term) - colIdx := colIdx + 1 - rowIdx := rowIdx + 1 - let bOutE ← readVecIntervalsBinary h modelDim slack scalePow10 - match bOutE with - | .error e => return .error e - | .ok bOut => - let mlpOut := addVecFixed mlpOut0 bOut - let residuals' := addVecFixedRows residuals mlpOut - return .ok residuals' - -private def loadEmbeddingsUnionFixed - (cfg : Fixed10Cfg) - (path : System.FilePath) - (expectedModelDim : Nat) - (delta : Rat) : IO (Except String (Array Fixed10Interval × Nat)) := do - let deltaInt : Int := ratCeilMulNat delta cfg.scaleNat - let mut out : Array Fixed10Interval := Array.replicate expectedModelDim { lo := 0, hi := 0 } - let mut iCol : Nat := 0 - let mut remaining : Nat := 0 - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - -- Header: read until blank line. - let mut seqLen : Option Nat := none - let mut modelDim : Option Nat := none - let mut seenHeader : Bool := false - while true do - let line ← h.getLine - if line.isEmpty then - return .error "unexpected EOF while reading input header" - let s := line.trim - if !seenHeader then - if s.startsWith "NFP_TEXT" then - seenHeader := true - continue - if s.isEmpty then - break - match parseHeaderLine s with - | none => pure () - | some (k, v) => - if k = "seq_len" then - seqLen := v.toNat? - else if k = "model_dim" then - modelDim := v.toNat? - else - pure () - let some n := seqLen | return .error "missing seq_len in input file" - let some d := modelDim | return .error "missing model_dim in input file" - if d ≠ expectedModelDim then - return .error s!"input model_dim mismatch (expected {expectedModelDim}, got {d})" - remaining := n * d - -- Scan to EMBEDDINGS marker. 
- let mut found : Bool := false - while !found do - let line ← h.getLine - if line.isEmpty then - return .error "unexpected EOF while scanning for EMBEDDINGS" - if line.trim = "EMBEDDINGS" then - found := true - while remaining > 0 do - let line ← h.getLine - if line.isEmpty then - return .error "unexpected EOF while reading EMBEDDINGS" - let s := line.trim - if s.isEmpty then - continue - let bytes := s.toUTF8 - let mut j : Nat := 0 - while j < bytes.size && remaining > 0 do - while j < bytes.size && (bytes[j]! = 32 || bytes[j]! = 9) do - j := j + 1 - if j ≥ bytes.size then - break - let tokStart := j - while j < bytes.size && (bytes[j]! ≠ 32 && bytes[j]! ≠ 9) do - j := j + 1 - let tokStop := j - match parseFixed10Rounded cfg.scalePow10 bytes tokStart tokStop with - | .error e => return .error e - | .ok x => - let lo := x - fixedUlpSlack - deltaInt - let hi := x + fixedUlpSlack + deltaInt - let cur := out[iCol]! - out := out.set! iCol { lo := min cur.lo lo, hi := max cur.hi hi } - iCol := (iCol + 1) % expectedModelDim - remaining := remaining - 1 - return .ok (out, n) - -/-- Parse binary embeddings into a union-box of fixed-point intervals. 
-/ -private def loadEmbeddingsUnionFixedBinary - (path : System.FilePath) - (expectedModelDim : Nat) - (delta : Rat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array Fixed10Interval)) := do - if delta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - if hdr.modelDim ≠ expectedModelDim then - return .error - s!"input model_dim mismatch (expected {expectedModelDim}, got {hdr.modelDim})" - let total := hdr.seqLen * hdr.modelDim - let deltaScaled : Int := ratCeilMulNat delta (Nat.pow 10 scalePow10) - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← readScaledFloatArray h total scalePow10 with - | .error e => return .error e - | .ok scaled => - if total = 0 then - return .ok #[] - let mut out : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for col in [:hdr.modelDim] do - let v := scaled[col]! - out := out.push { lo := v - deltaScaled, hi := v + deltaScaled } - for i in [hdr.modelDim:total] do - let v := scaled[i]! - let col := i % hdr.modelDim - let lo := v - deltaScaled - let hi := v + deltaScaled - let cur := out[col]! - out := out.set! col { lo := min cur.lo lo, hi := max cur.hi hi } - return .ok out - -/-- Parse binary embeddings into per-position fixed-point intervals. 
-/ -private def loadEmbeddingsIntervalsBinary - (path : System.FilePath) - (expectedModelDim : Nat) - (delta : Rat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array (Array Fixed10Interval))) := do - if delta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - if hdr.modelDim ≠ expectedModelDim then - return .error - s!"input model_dim mismatch (expected {expectedModelDim}, got {hdr.modelDim})" - let total := hdr.seqLen * hdr.modelDim - let deltaScaled : Int := ratCeilMulNat delta (Nat.pow 10 scalePow10) - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← readScaledFloatArray h total scalePow10 with - | .error e => return .error e - | .ok scaled => - if total = 0 then - return .ok #[] - let useTasks := hdr.seqLen > 32 - if useTasks then - let chunkSize : Nat := 16 - let numChunks : Nat := (hdr.seqLen + chunkSize - 1) / chunkSize - let mut tasks : - Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min hdr.seqLen (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut rowsChunk : - Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) - let mut rowIdx := start - while rowIdx < stop do - let mut row : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for colIdx in [:hdr.modelDim] do - let idx := rowIdx * hdr.modelDim + colIdx - let v := scaled[idx]! 
- row := row.push { lo := v - deltaScaled, hi := v + deltaScaled } - rowsChunk := rowsChunk.push row - rowIdx := rowIdx + 1 - return rowsChunk) - chunkIdx := chunkIdx + 1 - let mut rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for t in tasks do - for row in t.get do - rows := rows.push row - return .ok rows - else - let mut rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for rowIdx in [:hdr.seqLen] do - let mut row : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for colIdx in [:hdr.modelDim] do - let idx := rowIdx * hdr.modelDim + colIdx - let v := scaled[idx]! - row := row.push { lo := v - deltaScaled, hi := v + deltaScaled } - rows := rows.push row - return .ok rows - -private def loadTokensBinary - (path : System.FilePath) : IO (Except String (BinaryHeader × Array Int)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - match ← readI32Array h hdr.seqLen with - | .error e => return .error e - | .ok toks => return .ok (hdr, toks) - -/-- Shared binary inputs for repeated local bound checks. -/ -private structure SharedBinaryInputs where - hdr : BinaryHeader - ln1Params : Array LayerNormParamsFixed - ln2Params : Array LayerNormParamsFixed - tokens : Array Int - residuals0 : Array (Array Fixed10Interval) - inputDelta : Rat - scalePow10 : Nat - -/-- Cached prefix views for a fixed query position. -/ -private structure SharedBinaryPrefix where - seqLenEff : Nat - residuals : Thunk (Array (Array Fixed10Interval)) - tokens : Thunk (Array Int) - -/-- Load shared model/input data once for reuse across best-match configs. 
-/ -private def loadSharedBinaryInputs - (path : System.FilePath) - (inputPath : System.FilePath) - (inputDelta : Rat) - (scalePow10 : Nat) : - IO (Except String SharedBinaryInputs) := do - let slack : Int := fixedUlpSlack - let action : ExceptT String IO SharedBinaryInputs := do - let paramsTask ← - ExceptT.lift <| IO.asTask (collectLayerNormParamsBinary path scalePow10 slack) - let tokensTask ← - ExceptT.lift <| IO.asTask (loadTokensBinary inputPath) - let (hdr, ln1Params, ln2Params) ← - match paramsTask.get with - | .error e => throw (toString e) - | .ok (.error msg) => throw msg - | .ok (.ok v) => pure v - let (hdrTok, tokens) ← - match tokensTask.get with - | .error e => throw (toString e) - | .ok (.error msg) => throw msg - | .ok (.ok v) => pure v - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - return { - hdr := hdr - ln1Params := ln1Params - ln2Params := ln2Params - tokens := tokens - residuals0 := residuals0 - inputDelta := inputDelta - scalePow10 := scalePow10 - } - action.run - -/-- Build cached prefix arrays for a fixed query position. 
-/ -private def mkSharedBinaryPrefix - (shared : SharedBinaryInputs) - (queryPos : Nat) - (causalPattern : Bool) : - SharedBinaryPrefix := - let seqLenEff : Nat := if causalPattern then queryPos + 1 else shared.hdr.seqLen - { - seqLenEff := seqLenEff - residuals := Thunk.mk (fun () => - if causalPattern then takePrefix shared.residuals0 seqLenEff else shared.residuals0) - tokens := Thunk.mk (fun () => - if causalPattern then takePrefix shared.tokens seqLenEff else shared.tokens) - } - -private def skipToUnembeddingBinary - (h : IO.FS.Handle) (hdr : BinaryHeader) : IO (Except String Unit) := do - let action : ExceptT String IO Unit := do - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - for _l in [:hdr.numLayers] do - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.hiddenDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.hiddenDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.hiddenDim * hdr.modelDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - action.run - -/-- Compute local head output lower bounds at a specific query position (binary only). 
-/ -private def certifyHeadValueLowerBoundLocalBinaryAt - (path : System.FilePath) - (layerIdx headIdx queryPos coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (keyOffset : Int) - (matchWeightLowerBound : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (causalPattern : Bool := true) - (shared? : Option SharedBinaryInputs := none) - (prefix? : Option SharedBinaryPrefix := none) : - IO (Except String HeadValueLowerBoundPosCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadValueLowerBoundPosCert := do - let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← - match shared? with - | some shared => - if shared.scalePow10 ≠ scalePow10 then - throw "shared scalePow10 mismatch" - if shared.inputDelta ≠ inputDelta then - throw "shared inputDelta mismatch" - pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) - | none => - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - pure (hdr, ln1Params, ln2Params, residuals0, tokens) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if coord ≥ hdr.modelDim then - throw s!"coord index {coord} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let 
seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen - let (residuals0, tokens) ← - match prefix? with - | some pref => - if pref.seqLenEff ≠ seqLenEff then - throw "prefix seq_len mismatch" - pure (pref.residuals.get, pref.tokens.get) - | none => - let residuals0 := - if causalPattern then takePrefix residualsBase seqLenEff else residualsBase - let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase - pure (residuals0, tokens) - let keyOffsetNat? : Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let h ← ExceptT.lift <| IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let ln1Rows := fixedLayerNormRowsApprox cfg residuals p1 eps soundnessBits - if l = layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? 
:= some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? : Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let row := vOutRows[j]! - let vCoord := row[coord]!.lo - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLo? := - match matchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - pure () - let matchLo ← - match matchLo? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLo := - match nonmatchLo? 
with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := matchWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadValueLowerBoundPosCert := { - layerIdx := layerIdx - headIdx := headIdx - queryPos := queryPos - coord := coord - matchWeightLowerBound := weightLB - matchCoordLowerBound := matchLoRat - nonmatchCoordLowerBound := nonmatchLoRat - outputCoordLowerBound := outputLB - } - if cert.check then - return cert - throw "head value lower bound (pos) failed internal consistency checks" - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && layerIdx ≤ l + tightLayers then - if causalPattern then - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnRows : Array (Array Fixed10Interval) := - Array.replicate hdr.seqLen zeroRow - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - let mut rowIdx : Nat := 0 - while rowIdx < ln1Rows.size do - let row := ln1Rows[rowIdx]! 
- let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - rowIdx := rowIdx + 1 - let headRows := prefixUnionRowsFixed vOutRows - attnRows := addRowsFixed attnRows headRows - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnRows := addVecFixedRows attnRows attnBias - residuals := addRowsFixed residuals attnRows - else - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - let mut rowIdx : Nat := 0 - while rowIdx < groupRows.size do - let row := groupRows[rowIdx]! 
- let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - rowIdx := rowIdx + 1 - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv ln1Union - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let wo ← - ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let ln2Rows := fixedLayerNormRowsApprox cfg residuals p2 eps soundnessBits - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk 
(readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Combined value + optional logit certs for a single query position (binary only). -/ -private structure HeadValueLogitCert where - value : HeadValueLowerBoundPosCert - logit? : Option HeadLogitDiffLowerBoundPosCert - -/-- Compute value and optional logit bounds for a head at a query position (binary only). 
-/ -private def certifyHeadValueLogitLowerBoundLocalBinaryAt - (path : System.FilePath) - (layerIdx headIdx queryPos coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (keyOffset : Int) - (matchWeightLowerBound : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (causalPattern : Bool := true) - (shared? : Option SharedBinaryInputs := none) - (prefix? : Option SharedBinaryPrefix := none) - (targetToken? : Option Nat := none) - (negativeToken? : Option Nat := none) - (direction? : Option (Thunk (Array Fixed10Interval)) := none) : - IO (Except String HeadValueLogitCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadValueLogitCert := do - let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← - match shared? 
with - | some shared => - if shared.scalePow10 ≠ scalePow10 then - throw "shared scalePow10 mismatch" - if shared.inputDelta ≠ inputDelta then - throw "shared inputDelta mismatch" - pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) - | none => - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - pure (hdr, ln1Params, ln2Params, residuals0, tokens) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if coord ≥ hdr.modelDim then - throw s!"coord index {coord} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen - let (residuals0, tokens) ← - match prefix? with - | some pref => - if pref.seqLenEff ≠ seqLenEff then - throw "prefix seq_len mismatch" - pure (pref.residuals.get, pref.tokens.get) - | none => - let residuals0 := - if causalPattern then takePrefix residualsBase seqLenEff else residualsBase - let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase - pure (residuals0, tokens) - let keyOffsetNat? 
: Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let ln1Rows := fixedLayerNormRowsApprox cfg residuals p1 eps soundnessBits - if l = layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? := some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? 
with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? : Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let row := vOutRows[j]! - let vCoord := row[coord]!.lo - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLo? := - match matchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - pure () - let matchLo ← - match matchLo? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLo := - match nonmatchLo? 
with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let outputLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat - let value : HeadValueLowerBoundPosCert := { - layerIdx := layerIdx - headIdx := headIdx - queryPos := queryPos - coord := coord - matchWeightLowerBound := matchWeightLowerBound - matchCoordLowerBound := matchLoRat - nonmatchCoordLowerBound := nonmatchLoRat - outputCoordLowerBound := outputLB - } - if !value.check then - throw "head value certificate failed internal consistency checks" - let logit? ← - match targetToken?, negativeToken?, direction? with - | none, none, none => pure none - | some targetToken, some negativeToken, some direction => do - let dir := direction.get - if dir.size ≠ hdr.modelDim then - throw "logit direction size mismatch" - let vDotRows := - let useTasks := vOutRows.size > 32 - if useTasks then - let tasks := vOutRows.map (fun row => - Task.spawn (fun _ => fixedDotInterval cfg row dir)) - tasks.map (fun t => t.get) - else - Id.run do - let mut out : Array Fixed10Interval := Array.mkEmpty seqLenEff - for row in vOutRows do - out := out.push (fixedDotInterval cfg row dir) - return out - let mut matchLoLogit? : Option Int := none - let mut nonmatchLoLogit? : Option Int := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let vLo := (vDotRows[j]!).lo - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLoLogit? := - match matchLoLogit? with - | none => some vLo - | some m => some (min m vLo) - else - nonmatchLoLogit? := - match nonmatchLoLogit? with - | none => some vLo - | some m => some (min m vLo) - else - pure () - let matchLoLogit ← - match matchLoLogit? 
with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLoLogit := - match nonmatchLoLogit? with - | none => matchLoLogit - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLoLogit - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLoLogit - let logitLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat - let logitCert : HeadLogitDiffLowerBoundPosCert := { - layerIdx := layerIdx - headIdx := headIdx - queryPos := queryPos - targetToken := targetToken - negativeToken := negativeToken - matchWeightLowerBound := matchWeightLowerBound - matchLogitLowerBound := matchLoRat - nonmatchLogitLowerBound := nonmatchLoRat - logitDiffLowerBound := logitLB - } - if logitCert.check then - pure (some logitCert) - else - throw "head logit certificate failed internal consistency checks" - | _, _, _ => - throw "use both target and negative tokens (or neither)" - return { value := value, logit? := logit? } - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && layerIdx ≤ l + tightLayers then - if causalPattern then - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnRows : Array (Array Fixed10Interval) := - Array.replicate hdr.seqLen zeroRow - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := 
matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let headRows := prefixUnionRowsFixed vOutRows - attnRows := addRowsFixed attnRows headRows - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnRows := addVecFixedRows attnRows attnBias - residuals := addRowsFixed residuals attnRows - else - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let 
_ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv ln1Union - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let wo ← - ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union 
scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - - -private def readUnembeddingColumnsBinary - (path : System.FilePath) - (tokenA tokenB : Nat) - (scalePow10 : Nat) : - IO (Except String (BinaryHeader × Array Int × Array Int)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let action : ExceptT String IO (BinaryHeader × Array Int × Array Int) := do - let hdr ← ExceptT.mk (readBinaryHeader h) - if tokenA ≥ hdr.vocabSize || tokenB ≥ hdr.vocabSize then - throw "token index out of range for unembedding" - if tokenA = tokenB then - throw "target and negative tokens must differ" - let _ ← ExceptT.mk (skipToUnembeddingBinary h hdr) - let loTok := min tokenA tokenB - let hiTok := max tokenA tokenB - let swapped : Bool := tokenA > tokenB - let mut colA : Array Int := Array.mkEmpty hdr.modelDim - let mut colB : Array Int := Array.mkEmpty hdr.modelDim - for _r in [:hdr.modelDim] do - let _ ← ExceptT.mk (skipF64Array h loTok) - let vLo ← ExceptT.mk (readScaledFloat h scalePow10) - let _ ← ExceptT.mk (skipF64Array h (hiTok - loTok - 1)) - let vHi ← ExceptT.mk (readScaledFloat h scalePow10) - let _ ← ExceptT.mk (skipF64Array h (hdr.vocabSize - hiTok - 1)) - if swapped then - colA := colA.push vHi - colB := colB.push vLo - else - colA := colA.push vLo - colB := 
colB.push vHi - return (hdr, colA, colB) - action.run - -private def readLogitDiffDirectionBinary - (path : System.FilePath) - (targetToken negativeToken : Nat) - (scalePow10 : Nat) - (slack : Int) : - IO (Except String (BinaryHeader × Array Fixed10Interval)) := do - let action : ExceptT String IO (BinaryHeader × Array Fixed10Interval) := do - let (hdr, colTarget, colNeg) ← - ExceptT.mk (readUnembeddingColumnsBinary path targetToken negativeToken scalePow10) - if colTarget.size ≠ hdr.modelDim || colNeg.size ≠ hdr.modelDim then - throw "unembedding column size mismatch" - let targetIntervals := intervalsFromScaled colTarget slack - let negIntervals := intervalsFromScaled colNeg slack - let mut dir : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for i in [:hdr.modelDim] do - dir := dir.push (Fixed10Interval.sub targetIntervals[i]! negIntervals[i]!) - return (hdr, dir) - action.run - -/-- Compute local head logit-difference lower bounds at a specific query position (binary only). -/ -private def certifyHeadLogitDiffLowerBoundLocalBinaryAt - (path : System.FilePath) - (layerIdx headIdx queryPos : Nat) - (targetToken negativeToken : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (keyOffset : Int) - (matchWeightLowerBound : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (causalPattern : Bool := true) - (shared? : Option SharedBinaryInputs := none) - (prefix? : Option SharedBinaryPrefix := none) - (direction? : Option (Thunk (Array Fixed10Interval)) := none) : - IO (Except String HeadLogitDiffLowerBoundPosCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadLogitDiffLowerBoundPosCert := do - let (direction, hdrDir?) ← - match direction? 
with - | some thunk => pure (thunk.get, none) - | none => - let (hdrDir, dir) ← - ExceptT.mk <| - readLogitDiffDirectionBinary path targetToken negativeToken scalePow10 slack - pure (dir, some hdrDir) - let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← - match shared? with - | some shared => - if shared.scalePow10 ≠ scalePow10 then - throw "shared scalePow10 mismatch" - if shared.inputDelta ≠ inputDelta then - throw "shared inputDelta mismatch" - pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) - | none => - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - pure (hdr, ln1Params, ln2Params, residuals0, tokens) - match hdrDir? with - | some hdrDir => - if hdr.modelDim ≠ hdrDir.modelDim then - throw "unembedding model_dim mismatch" - | none => - if direction.size ≠ hdr.modelDim then - throw "logit direction size mismatch" - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen - if direction.size ≠ hdr.modelDim then - throw "logit direction size mismatch" - let (residuals0, tokens) ← - match prefix? 
with - | some pref => - if pref.seqLenEff ≠ seqLenEff then - throw "prefix seq_len mismatch" - pure (pref.residuals.get, pref.tokens.get) - | none => - let residuals0 := - if causalPattern then takePrefix residualsBase seqLenEff else residualsBase - let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase - pure (residuals0, tokens) - let keyOffsetNat? : Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? 
:= some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let mut vDotRows : Array Fixed10Interval := Array.mkEmpty seqLenEff - for row in vOutRows do - vDotRows := vDotRows.push (fixedDotInterval cfg row direction) - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? : Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let vLo := (vDotRows[j]!).lo - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLo? := - match matchLo? with - | none => some vLo - | some m => some (min m vLo) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vLo - | some m => some (min m vLo) - else - pure () - let matchLo ← - match matchLo? 
with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLo := - match nonmatchLo? with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := matchWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadLogitDiffLowerBoundPosCert := { - layerIdx := layerIdx - headIdx := headIdx - queryPos := queryPos - targetToken := targetToken - negativeToken := negativeToken - matchWeightLowerBound := weightLB - matchLogitLowerBound := matchLoRat - nonmatchLogitLowerBound := nonmatchLoRat - logitDiffLowerBound := outputLB - } - if cert.check then - return cert - throw "head logit lower bound (pos) failed internal consistency checks" - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && layerIdx ≤ l + tightLayers then - if causalPattern then - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnRows : Array (Array Fixed10Interval) := - Array.replicate hdr.seqLen zeroRow - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - 
hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let headRows := prefixUnionRowsFixed vOutRows - attnRows := addRowsFixed attnRows headRows - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnRows := addVecFixedRows attnRows attnBias - residuals := addRowsFixed residuals attnRows - else - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * 
hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk 
(consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -private def ensureSoundCache - (modelPath : System.FilePath) - (scalePow10 : Nat := defaultFixedScalePow10) : - IO (Except String (System.FilePath × SoundCache.Header)) := do - Nfp.Untrusted.SoundCacheIO.ensureCacheDir - let modelHash ← Nfp.Untrusted.SoundCacheIO.fnv1a64File modelPath - let mdata ← modelPath.metadata - let modelSize : UInt64 := mdata.byteSize - let cpath := SoundCache.cachePath modelPath modelHash scalePow10 - if !(← cpath.pathExists) then - match (← Nfp.Untrusted.SoundCacheIO.buildCacheFile modelPath cpath scalePow10) with - | .error e => return .error e - | .ok _ => pure () - let h ← IO.FS.Handle.mk cpath IO.FS.Mode.read - let hdr ← Nfp.Untrusted.SoundCacheIO.readHeader h - if hdr.modelHash ≠ modelHash then - return .error "sound cache hash mismatch" - if hdr.modelSize ≠ modelSize then - return .error "sound cache size mismatch" - return .ok (cpath, hdr) - -private def readWqWkMaxBinary - (path : System.FilePath) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array Rat × Array Rat)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let mut wqMax : Array Rat := Array.replicate hdr.numLayers 0 - let mut wkMax : Array Rat 
:= Array.replicate hdr.numLayers 0 - for l in [:hdr.numLayers] do - let mut wqLayer : Rat := 0 - let mut wkLayer : Rat := 0 - for _h in [:hdr.numHeads] do - let wqScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let wqScaled ← - match wqScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wkScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let wkScaled ← - match wkScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.headDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.headDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let wq := ratOfScaledInt scalePow10 wqScaled - let wk := ratOfScaledInt scalePow10 wkScaled - wqLayer := max wqLayer wq - wkLayer := max wkLayer wk - wqMax := wqMax.set! l wqLayer - wkMax := wkMax.set! 
l wkLayer - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.hiddenDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.hiddenDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - return .ok (wqMax, wkMax) - -/-- Local (input-dependent) certificate path using streaming interval propagation. - -This is conservative in two key ways to remain streaming/memory-safe: -- it uses a **union box** over tokens throughout (so we never hold `seqLen×modelDim` intervals), - which is sound (a superset) but can be looser than per-token tracking, -- it uses union boxes for attention/MLP linear maps to avoid `seqLen×hiddenDim` blowups. 
--/ -private def certifyModelFileLocalText - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (partitionDepth : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - let contents ← IO.FS.readFile path - let lines : Array String := Nfp.Sound.splitLines contents - -- Header - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty model file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - let mut numLayers : Option Nat := none - let mut numHeads : Option Nat := none - let mut modelDim : Option Nat := none - let mut headDim : Option Nat := none - let mut hiddenDim : Option Nat := none - let mut seqLen : Option Nat := none - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := i + 1 - break - match parseHeaderLine line with - | none => i := i + 1 - | some (k, v) => - match k with - | "num_layers" => numLayers := v.toNat? - | "num_heads" => numHeads := v.toNat? - | "model_dim" => modelDim := v.toNat? - | "head_dim" => headDim := v.toNat? - | "hidden_dim" => hiddenDim := v.toNat? - | "seq_len" => seqLen := v.toNat? - | _ => pure () - i := i + 1 - let some L := numLayers | return .error "missing num_layers" - let some H := numHeads | return .error "missing num_heads" - let some d := modelDim | return .error "missing model_dim" - let some dh := headDim | return .error "missing head_dim" - let some dhid := hiddenDim | return .error "missing hidden_dim" - let some n := seqLen | return .error "missing seq_len" - -- Prepass: collect LN parameters. 
- let (ln1Params, ln2Params) ← - match collectLayerNormParams lines L d with - | .error e => return .error e - | .ok x => pure x - let defLn : LayerNormParams := { gamma := Array.replicate d 1, beta := Array.replicate d 0 } - -- Input: per-token residual intervals. - let residual0 ← loadEmbeddingsIntervals inputPath n d inputDelta - match residual0 with - | .error e => return .error e - | .ok residualRows0 => - -- Use a single union box for all tokens (sound superset, much faster than - -- `seqLen×modelDim`). - let mut residualUnion := unionRows residualRows0 d - -- Start scanning at first layer marker. - let mut pos : Nat := skipUntil lines 0 (fun s => s.startsWith "LAYER") - let mut layers : Array LayerAmplificationCert := Array.mkEmpty L - let mut totalAmp : Rat := 1 - let mut actDerivBoundMax : Rat := 0 - for l in [:L] do - -- Ensure we're at the next layer. - pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - if pos ≥ lines.size then - return .error s!"unexpected end of file while scanning layer {l}" - pos := pos + 1 - -- LN1: compute per-row outputs (for union) and min variance LB (for Jacobian bound). - let p1 := ln1Params.getD l defLn - let (ln1Out, ln1VarLB) := - layerNormRowApprox residualUnion p1.gamma p1.beta eps soundnessBits - let ln1MaxAbsGamma := maxAbsOfVector p1.gamma - let ln1MaxAbsBeta := maxAbsOfVector p1.beta - let ln1Bound := - if ln1VarLB > 0 then - layerNormOpBoundLocal ln1MaxAbsGamma ln1VarLB eps soundnessBits - else - layerNormOpBoundConservative ln1MaxAbsGamma eps soundnessBits - let ln1OutMaxAbsBound := layerNormOutputMaxAbsBound d ln1MaxAbsGamma ln1MaxAbsBeta - let ln1Union := ln1Out - -- Attention (streaming): use union input box. 
- let mut attnUnion : Array RatInterval := zeroIntervals d - let mut attnValueCoeff : Rat := 0 - let mut wqMax : Rat := 0 - let mut wkMax : Rat := 0 - for _h in [:H] do - pos := skipBlankLines lines pos - if !(pos < lines.size) then - return .error "unexpected end of file while scanning HEAD" - if !(lines[pos]!.trim.startsWith "HEAD") then - return .error "expected HEAD marker before per-head matrices" - pos := pos + 1 - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_Q") then - return .error "missing W_Q" - match consumeMatrixNormInf lines (pos + 1) d dh with - | .error e => return .error e - | .ok (nq, next) => - wqMax := max wqMax nq - pos := next - -- Optional per-head Q bias (does not affect Jacobian, - -- but must be parsed to stay in sync). - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_Q" then - match consumeVectorSkipFast lines (pos + 1) dh with - | .error e => return .error e - | .ok next => pos := next - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_K") then - return .error "missing W_K" - match consumeMatrixNormInf lines (pos + 1) d dh with - | .error e => return .error e - | .ok (nk, next) => - wkMax := max wkMax nk - pos := next - -- Optional per-head K bias (does not affect Jacobian, - -- but must be parsed to stay in sync). - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_K" then - match consumeVectorSkipFast lines (pos + 1) dh with - | .error e => return .error e - | .ok next => pos := next - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_V") then - return .error "missing W_V" - match consumeMatrixMulAndNormInf lines (pos + 1) d dh ln1Union with - | .error e => return .error e - | .ok (vHidden, _nWv, nextV) => - pos := nextV - -- Optional per-head V bias (affects forward activations / variance, - -- so we include it). 
- pos := skipBlankLines lines pos - let mut vHidden := vHidden - if pos < lines.size && lines[pos]!.trim = "b_V" then - match consumeVector lines (pos + 1) dh with - | .error e => return .error e - | .ok (bv, nextBv) => - pos := nextBv - vHidden := addConstVec vHidden bv - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_O") then - return .error "missing W_O" - let vCenteredOpBound := centeredAbsSum vHidden - match consumeMatrixMulAndNormInf lines (pos + 1) dh d vHidden with - | .error e => return .error e - | .ok (vOut, no, nextO) => - pos := nextO - attnUnion := addVecIntervals attnUnion vOut - attnValueCoeff := attnValueCoeff + vCenteredOpBound * no - -- Shared attention projection bias (affects forward activations / variance, - -- so we include it). - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "ATTN_BIAS" then - match consumeVector lines (pos + 1) d with - | .error e => return .error e - | .ok (bAttn, nextB) => - pos := nextB - attnUnion := addConstVec attnUnion bAttn - residualUnion := addVecIntervals residualUnion attnUnion - -- LN2: compute per-row outputs and min variance LB. - let p2 := ln2Params.getD l defLn - let (ln2Out, ln2VarLB) := - layerNormRowApprox residualUnion p2.gamma p2.beta eps soundnessBits - let ln2MaxAbsGamma := maxAbsOfVector p2.gamma - let ln2Bound := - if ln2VarLB > 0 then - layerNormOpBoundLocal ln2MaxAbsGamma ln2VarLB eps soundnessBits - else - layerNormOpBoundConservative ln2MaxAbsGamma eps soundnessBits - let ln2Union := ln2Out - -- MLP (streaming): W_in, b_in, W_out, b_out. 
- pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "MLP") then - return .error "missing MLP section" - pos := pos + 1 - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_in") then - return .error "missing W_in" - match consumeMatrixMulAndNormInf lines (pos + 1) d dhid ln2Union with - | .error e => return .error e - | .ok (hidden, nWin, nextWin) => - pos := nextWin - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "b_in") then - return .error "missing b_in" - match consumeVector lines (pos + 1) dhid with - | .error e => return .error e - | .ok (bin, nextBin) => - pos := nextBin - let hiddenB := addConstVec hidden bin - let mlpActDerivBound := maxGeluDerivBound geluDerivTarget hiddenB - let actHidden := hiddenB.map (geluOverapproxRat geluDerivTarget) - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_out") then - return .error "missing W_out" - match consumeMatrixMulAndNormInf lines (pos + 1) dhid d actHidden with - | .error e => return .error e - | .ok (mlpOut0, nWout, nextWout) => - pos := nextWout - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "b_out") then - return .error "missing b_out" - match consumeVector lines (pos + 1) d with - | .error e => return .error e - | .ok (bout, nextBout) => - pos := nextBout - let mlpOut := addConstVec mlpOut0 bout - residualUnion := addVecIntervals residualUnion mlpOut - let scoreAbsBound := - attnScoreAbsBound d dh ln1OutMaxAbsBound wqMax wkMax - let (softmaxProbLo, softmaxProbHi) := - softmaxProbIntervalFromScoreAbsBound n scoreAbsBound softmaxExpEffort - let softmaxIntervalBound := - softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin n softmaxMarginLowerBound - softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnPatternCoeff := - attnPatternCoeffBound n d dh 
ln1OutMaxAbsBound wqMax wkMax - attnValueCoeff - let attnW := - ln1Bound * - ((n : Rat) * attnValueCoeff + softmaxBound * attnPatternCoeff) - let mlpCoeff := nWin * nWout - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2MaxAbsGamma - ln1VarianceLowerBound? := some ln1VarLB - ln2VarianceLowerBound? := some ln2VarLB - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMax - wkOpBoundMax := wkMax - attnValueCoeff := attnValueCoeff - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := nWin - mlpWoutBound := nWout - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - actDerivBoundMax := max actDerivBoundMax mlpActDerivBound - pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - let cert : ModelCert := { - modelPath := path.toString - inputPath? 
:= some inputPath.toString - inputDelta := inputDelta - eps := eps - seqLen := n - modelDim := d - headDim := dh - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBoundMax - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return .ok cert - return .error "sound certificate failed internal consistency checks" - -private def certifyModelFileLocal - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (partitionDepth : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - -- Prefer cached fixed-point path; fall back to the (slow) Rat-based path on any cache error. - match (← ensureSoundCache path) with - | .error _ => - certifyModelFileLocalText path eps geluDerivTarget soundnessBits partitionDepth - inputPath inputDelta softmaxMarginLowerBound softmaxExpEffort - | .ok (cpath, hdr) => - let cfg : Fixed10Cfg := scaleCfgOfPow10 hdr.scalePow10.toNat - let slack : Int := fixedUlpSlack - let modelDim := hdr.modelDim.toNat - let headDim := hdr.headDim.toNat - let hiddenDim := hdr.hiddenDim.toNat - let L := hdr.numLayers.toNat - let H := hdr.numHeads.toNat - let wqWkE ← readWqWkMaxBinary path (scalePow10 := hdr.scalePow10.toNat) - let (wqMaxArr, wkMaxArr) ← - match wqWkE with - | .error e => return .error e - | .ok v => pure v - -- For now we read embeddings from the input `.nfpt` file and use a union box. 
- let residualUnionE ← loadEmbeddingsUnionFixed cfg inputPath modelDim inputDelta - match residualUnionE with - | .error e => return .error e - | .ok (residualUnion0, inputSeqLen) => - let mut residualUnion := residualUnion0 - -- Open cache and position reader after header. - let ch ← IO.FS.Handle.mk cpath IO.FS.Mode.read - let _ ← Nfp.Untrusted.SoundCacheIO.readHeader ch - let mut rr ← Nfp.Untrusted.SoundCacheIO.I32Reader.init ch - let mut layers : Array LayerAmplificationCert := Array.mkEmpty L - let mut totalAmp : Rat := 1 - let mut actDerivBoundMax : Rat := 0 - for l in [:L] do - -- LN params from cache - let (ln1Gamma, rr1) ← readVecIntervals rr modelDim slack - let (ln1Beta, rr2) ← readVecIntervals rr1 modelDim slack - let (ln2Gamma, rr3) ← readVecIntervals rr2 modelDim slack - let (ln2Beta, rr4) ← readVecIntervals rr3 modelDim slack - rr := rr4 - -- LN1 - let (ln1Out, ln1VarLB) := - fixedLayerNormRowApprox cfg residualUnion ln1Gamma ln1Beta eps soundnessBits - let ln1MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed ln1Gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1MaxAbsBeta : Rat := - Rat.normalize (maxAbsVecFixed ln1Beta) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1Bound := - if ln1VarLB > 0 then - layerNormOpBoundLocal ln1MaxAbsGamma ln1VarLB eps soundnessBits - else - layerNormOpBoundConservative ln1MaxAbsGamma eps soundnessBits - let ln1OutMaxAbsBound := - layerNormOutputMaxAbsBound modelDim ln1MaxAbsGamma ln1MaxAbsBeta - -- Attention (streaming from cache) - let mut attnUnion : Array Fixed10Interval := - Array.replicate modelDim { lo := 0, hi := 0 } - let mut attnValueCoeff : Rat := 0 - for _h in [:H] do - let (vHidden0, _nWv, rrV) ← - consumeMatrixMulAndNormInfFixed cfg slack rr modelDim headDim ln1Out - rr := rrV - let (bV, rrBv) ← 
readVecIntervals rr headDim slack - rr := rrBv - let vHidden := addVecFixed vHidden0 bV - let vCenteredOpBound := centeredAbsSumFixed cfg vHidden - let (vOut, nWo, rrO) ← - consumeMatrixMulAndNormInfFixed cfg slack rr headDim modelDim vHidden - rr := rrO - attnUnion := addVecFixed attnUnion vOut - attnValueCoeff := attnValueCoeff + vCenteredOpBound * nWo - let (attnBias, rrB) ← readVecIntervals rr modelDim slack - rr := rrB - attnUnion := addVecFixed attnUnion attnBias - residualUnion := addVecFixed residualUnion attnUnion - -- LN2 - let (ln2Out, ln2VarLB) := - fixedLayerNormRowApprox cfg residualUnion ln2Gamma ln2Beta eps soundnessBits - let ln2MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed ln2Gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln2Bound := - if ln2VarLB > 0 then - layerNormOpBoundLocal ln2MaxAbsGamma ln2VarLB eps soundnessBits - else - layerNormOpBoundConservative ln2MaxAbsGamma eps soundnessBits - -- MLP - let (hidden0, nWin, rrWin) ← - consumeMatrixMulAndNormInfFixed cfg slack rr modelDim hiddenDim ln2Out - rr := rrWin - let (bIn, rrBin) ← readVecIntervals rr hiddenDim slack - rr := rrBin - let hiddenB := addVecFixed hidden0 bIn - let mlpActDerivBound := maxGeluDerivBoundFixed cfg geluDerivTarget hiddenB - let actHidden := geluOverapproxFixedVec cfg geluDerivTarget hiddenB - let (mlpOut0, nWout, rrWout) ← - consumeMatrixMulAndNormInfFixed cfg slack rr hiddenDim modelDim actHidden - rr := rrWout - let (bOut, rrBout) ← readVecIntervals rr modelDim slack - rr := rrBout - let mlpOut := addVecFixed mlpOut0 bOut - residualUnion := addVecFixed residualUnion mlpOut - let scoreAbsBound := - attnScoreAbsBound modelDim headDim ln1OutMaxAbsBound (wqMaxArr[l]!) - (wkMaxArr[l]!) 
- let (softmaxProbLo, softmaxProbHi) := - softmaxProbIntervalFromScoreAbsBound inputSeqLen scoreAbsBound softmaxExpEffort - let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin inputSeqLen softmaxMarginLowerBound - softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnPatternCoeff := - attnPatternCoeffBound inputSeqLen modelDim headDim ln1OutMaxAbsBound - (wqMaxArr[l]!) (wkMaxArr[l]!) attnValueCoeff - let attnW := - ln1Bound * - ((inputSeqLen : Rat) * attnValueCoeff + softmaxBound * attnPatternCoeff) - let mlpCoeff := nWin * nWout - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2MaxAbsGamma - ln1VarianceLowerBound? := some ln1VarLB - ln2VarianceLowerBound? := some ln2VarLB - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMaxArr[l]! - wkOpBoundMax := wkMaxArr[l]! - attnValueCoeff := attnValueCoeff - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := nWin - mlpWoutBound := nWout - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - actDerivBoundMax := max actDerivBoundMax mlpActDerivBound - let cert : ModelCert := { - modelPath := path.toString - inputPath? 
:= some inputPath.toString - inputDelta := inputDelta - eps := eps - seqLen := inputSeqLen - modelDim := modelDim - headDim := headDim - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBoundMax - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return .ok cert - return .error "sound certificate failed internal consistency checks" - -private def certifyModelFileLocalBinary - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (partitionDepth : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - let scalePow10 := defaultBinaryScalePow10 - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO ModelCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residualUnion0 ← - ExceptT.mk (loadEmbeddingsUnionFixedBinary inputPath hdr.modelDim inputDelta scalePow10) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residualUnion := residualUnion0 - let mut layers : Array LayerAmplificationCert := Array.mkEmpty hdr.numLayers - let mut totalAmp : Rat := 1 - let mut actDerivBoundMax : Rat := 0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let (ln1Out, ln1VarLB) := - fixedLayerNormRowApprox cfg 
residualUnion p1.gamma p1.beta eps soundnessBits - let ln1MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed p1.gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1MaxAbsBeta : Rat := - Rat.normalize (maxAbsVecFixed p1.beta) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1Bound := - if ln1VarLB > 0 then - layerNormOpBoundLocal ln1MaxAbsGamma ln1VarLB eps soundnessBits - else - layerNormOpBoundConservative ln1MaxAbsGamma eps soundnessBits - let ln1OutMaxAbsBound := - layerNormOutputMaxAbsBound hdr.modelDim ln1MaxAbsGamma ln1MaxAbsBeta - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnValueCoeff : Rat := 0 - let mut wqMax : Rat := 0 - let mut wkMax : Rat := 0 - for _h in [:hdr.numHeads] do - let wqScaled ← - ExceptT.mk (readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10) - wqMax := max wqMax (ratOfScaledInt scalePow10 wqScaled) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wkScaled ← - ExceptT.mk (readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10) - wkMax := max wkMax (ratOfScaledInt scalePow10 wkScaled) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Out scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let vCenteredOpBound := centeredAbsSumFixed cfg vHidden - let (vOut, nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - attnValueCoeff := attnValueCoeff + vCenteredOpBound * nWo - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - 
attnUnion := addVecFixed attnUnion attnBias - residualUnion := addVecFixed residualUnion attnUnion - let p2 := ln2Params.getD l defP - let (ln2Out, ln2VarLB) := - fixedLayerNormRowApprox cfg residualUnion p2.gamma p2.beta eps soundnessBits - let ln2MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed p2.gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln2Bound := - if ln2VarLB > 0 then - layerNormOpBoundLocal ln2MaxAbsGamma ln2VarLB eps soundnessBits - else - layerNormOpBoundConservative ln2MaxAbsGamma eps soundnessBits - let (hidden0, nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Out scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let mlpActDerivBound := maxGeluDerivBoundFixed cfg geluDerivTarget hiddenB - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residualUnion := addVecFixed residualUnion mlpOut - let scoreAbsBound := - attnScoreAbsBound hdr.modelDim hdr.headDim ln1OutMaxAbsBound wqMax wkMax - let (softmaxProbLo, softmaxProbHi) := - softmaxProbIntervalFromScoreAbsBound hdr.seqLen scoreAbsBound softmaxExpEffort - let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin hdr.seqLen softmaxMarginLowerBound softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnPatternCoeff := - attnPatternCoeffBound hdr.seqLen hdr.modelDim hdr.headDim ln1OutMaxAbsBound - wqMax wkMax attnValueCoeff - let attnW := - ln1Bound 
* - ((hdr.seqLen : Rat) * attnValueCoeff + softmaxBound * attnPatternCoeff) - let mlpCoeff := nWin * nWout - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2MaxAbsGamma - ln1VarianceLowerBound? := some ln1VarLB - ln2VarianceLowerBound? := some ln2VarLB - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMax - wkOpBoundMax := wkMax - attnValueCoeff := attnValueCoeff - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := nWin - mlpWoutBound := nWout - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - actDerivBoundMax := max actDerivBoundMax mlpActDerivBound - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let cert : ModelCert := { - modelPath := path.toString - inputPath? := some inputPath.toString - inputDelta := inputDelta - eps := eps - seqLen := hdr.seqLen - modelDim := hdr.modelDim - headDim := hdr.headDim - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBoundMax - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return cert - throw "sound certificate failed internal consistency checks" - action.run - -/-- Compute local per-head attention contribution bounds from a binary `.nfpt`. 
-/ -private def certifyHeadBoundsLocalBinary - (path : System.FilePath) - (eps : Rat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (soundnessBits : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array HeadLocalContributionCert)) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO (Array HeadLocalContributionCert) := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residualUnion0 ← - ExceptT.mk (loadEmbeddingsUnionFixedBinary inputPath hdr.modelDim inputDelta scalePow10) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residualUnion := residualUnion0 - let mut heads : Array HeadLocalContributionCert := - Array.mkEmpty (hdr.numLayers * hdr.numHeads) - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let (ln1Out, ln1VarLB) := - fixedLayerNormRowApprox cfg residualUnion p1.gamma p1.beta eps soundnessBits - let ln1MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed p1.gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1Bound := - if ln1VarLB > 0 then - layerNormOpBoundLocal ln1MaxAbsGamma ln1VarLB eps soundnessBits - else - layerNormOpBoundConservative ln1MaxAbsGamma eps soundnessBits - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for hIdx in [:hdr.numHeads] do - let wqScaledE ← - ExceptT.mk (readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10) - let wqOp := ratOfScaledNat scalePow10 wqScaledE - 
let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wkScaledE ← - ExceptT.mk (readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10) - let wkOp := ratOfScaledNat scalePow10 wkScaledE - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Out scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let vCenteredOpBound := centeredAbsSumFixed cfg vHidden - let (vOut, nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let softmaxJacobianBound := softmaxJacobianNormInfWorst - let attnW := ln1Bound * softmaxJacobianBound * vCenteredOpBound * nWo - let cert : HeadLocalContributionCert := { - layerIdx := l - headIdx := hIdx - soundnessBits := soundnessBits - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1VarianceLowerBound := ln1VarLB - ln1Bound := ln1Bound - wqOpBound := wqOp - wkOpBound := wkOp - wvOpBound := vCenteredOpBound - woOpBound := nWo - qkFactorBound := wqOp * wkOp - softmaxJacobianNormInfUpperBound := softmaxJacobianBound - attnJacBound := attnW - } - if cert.check eps then - heads := heads.push cert - else - throw "local head contribution certificate failed internal checks" - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residualUnion := addVecFixed residualUnion attnUnion - let p2 := ln2Params.getD l defP - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg residualUnion p2.gamma p2.beta eps soundnessBits - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Out scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let 
actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residualUnion := addVecFixed residualUnion mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - return heads - action.run - -/-- Compute local attention pattern bounds for a specific binary head. -/ -private def certifyHeadPatternLocalBinary - (path : System.FilePath) - (layerIdx headIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (keyOffset : Int) - (maxSeqLen : Nat) - (tightPattern : Bool) - (tightPatternLayers : Nat) - (perRowPatternLayers : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String HeadPatternCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadPatternCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - let keyOffsetNat? 
: Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let ln1Rows := fixedLayerNormRowsApprox cfg residuals p1 eps soundnessBits - if l = layerIdx then - let mut wq? : Option (Array Int) := none - let mut bq? : Option (Array Int) := none - let mut wk? : Option (Array Int) := none - let mut bk? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - if hIdx = headIdx then - let wq ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wq? := some wq - let bQ ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bq? := some bQ - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - pure () - if hIdx = headIdx then - let wk ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wk? := some wk - let bK ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bk? := some bK - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - pure () - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wq ← - match wq? with - | none => throw "missing W_Q for requested head" - | some xs => pure xs - let bQ ← - match bq? 
with - | none => throw "missing b_Q for requested head" - | some xs => pure xs - let wk ← - match wk? with - | none => throw "missing W_K for requested head" - | some xs => pure xs - let bK ← - match bk? with - | none => throw "missing b_K for requested head" - | some xs => pure xs - let bQIntervals := intervalsFromScaled bQ slack - let bKIntervals := intervalsFromScaled bK slack - let mut qRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - let mut rowIdx : Nat := 0 - while rowIdx < ln1Rows.size do - let row := ln1Rows[rowIdx]! - let qRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq row - let kRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row - qRows := qRows.push (addVecFixed qRow0 bQIntervals) - kRows := kRows.push (addVecFixed kRow0 bKIntervals) - rowIdx := rowIdx + 1 - let mut minTargetLower? : Option Int := none - let mut maxOtherUpper? : Option Int := none - let mut minTargetCount? : Option Nat := none - let mut i : Nat := 0 - while i < hdr.seqLen do - let ti : Int := (Int.ofNat i) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - pure () - else - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let qRow := qRows[i]! - let mut targetLower? : Option Int := none - let mut targetMaxLower? : Option Int := none - let mut maxOtherUpperRow? : Option Int := none - let mut targetCount : Nat := 0 - let mut j : Nat := 0 - while j < hdr.seqLen do - if !causalPattern || j ≤ i then - let dot := fixedDotInterval cfg qRow (kRows[j]!) - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < hdr.seqLen && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - targetCount := targetCount + 1 - targetLower? := - match targetLower? 
with - | none => some dot.lo - | some m => some (min m dot.lo) - targetMaxLower? := - match targetMaxLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - let cur := dot.hi - maxOtherUpperRow? := - match maxOtherUpperRow? with - | none => some cur - | some m => some (max m cur) - else - pure () - j := j + 1 - let targetLowerRow? := - if tightPattern then targetMaxLower? else targetLower? - match targetLowerRow? with - | none => pure () - | some targetLower => - let maxOtherUpperRow := - match maxOtherUpperRow? with - | none => targetLower - | some v => v - minTargetLower? := - match minTargetLower? with - | none => some targetLower - | some m => some (min m targetLower) - maxOtherUpper? := - match maxOtherUpper? with - | none => some maxOtherUpperRow - | some m => some (max m maxOtherUpperRow) - minTargetCount? := - match minTargetCount? with - | none => some targetCount - | some m => some (min m targetCount) - i := i + 1 - let minTargetLower ← - match minTargetLower? with - | none => throw "no valid target positions for the requested offset" - | some v => pure v - let minTargetCount : Nat := - match minTargetCount? with - | none => 0 - | some v => v - let targetCountLB : Nat := - if tightPattern then (if minTargetCount > 0 then 1 else 0) else minTargetCount - let maxOtherUpper := - match maxOtherUpper? 
with - | none => minTargetLower - | some v => v - let marginInt : Int := minTargetLower - maxOtherUpper - let targetLower := ratOfScaledInt scalePow10 minTargetLower - let otherUpper := ratOfScaledInt scalePow10 maxOtherUpper - let margin := ratOfScaledInt scalePow10 marginInt - let weightLB : Rat := - softmaxTargetWeightLowerBound hdr.seqLen targetCountLB margin softmaxExpEffort - let cert : HeadPatternCert := { - layerIdx := layerIdx - headIdx := headIdx - seqLen := hdr.seqLen - targetOffset := targetOffset - keyOffset := keyOffset - targetCountLowerBound := targetCountLB - targetLogitLowerBound := targetLower - otherLogitUpperBound := otherUpper - marginLowerBound := margin - softmaxExpEffort := softmaxExpEffort - targetWeightLowerBound := weightLB - } - if cert.check then - return cert - throw "head pattern certificate failed internal consistency checks" - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && layerIdx ≤ l + tightLayers then - if causalPattern then - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnRows : Array (Array Fixed10Interval) := - Array.replicate hdr.seqLen zeroRow - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - let mut rowIdx : Nat := 0 - while rowIdx < ln1Rows.size do - let row := ln1Rows[rowIdx]! 
- let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - rowIdx := rowIdx + 1 - let headRows := prefixUnionRowsFixed vOutRows - attnRows := addRowsFixed attnRows headRows - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnRows := addVecFixedRows attnRows attnBias - residuals := addRowsFixed residuals attnRows - else - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - let mut rowIdx : Nat := 0 - while rowIdx < groupRows.size do - let row := groupRows[rowIdx]! 
- let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - rowIdx := rowIdx + 1 - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let ln2Rows := fixedLayerNormRowsApprox cfg residuals p2 eps soundnessBits - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| 
readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let residuals' ← - ExceptT.mk (mlpUnionStepBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Rows residuals scalePow10) - residuals := residuals' - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Minimum relative improvement required to keep increasing softmax exp effort. -/ -private def defaultSoftmaxEffortMinRelImprove : Rat := (1 : Rat) / 100 - -/-- Choose a softmax exp effort by iterating until improvements are negligible. -/ -private def chooseSoftmaxExpEffort - (seqLen : Nat) (margin : Rat) (maxEffort : Nat) : - Nat × Rat × Rat := - let startEffort : Nat := if maxEffort = 0 then 0 else 1 - let weight0 : Rat := softmaxMaxProbLowerBound seqLen margin startEffort - let jac0 : Rat := softmaxJacobianNormInfBoundFromMargin seqLen margin startEffort - if startEffort ≥ maxEffort then - (startEffort, weight0, jac0) - else - Id.run do - let mut bestEff : Nat := startEffort - let mut bestWeight : Rat := weight0 - let mut bestJac : Rat := jac0 - let mut eff : Nat := startEffort - while eff < maxEffort do - eff := eff + 1 - let weight := softmaxMaxProbLowerBound seqLen margin eff - let jac := softmaxJacobianNormInfBoundFromMargin seqLen margin eff - if jac < bestJac then - let relImprove := - if bestJac = 0 then 0 else (bestJac - jac) / bestJac - bestEff := eff - bestWeight := weight - bestJac := jac - if relImprove < defaultSoftmaxEffortMinRelImprove then - eff := maxEffort - else - eff := maxEffort - return (bestEff, bestWeight, bestJac) - -/-- 
Compute local head best-match pattern bounds for a specific `.nfpt` head (binary only). -/ -private def certifyHeadPatternBestMatchLocalBinary - (path : System.FilePath) - (layerIdx headIdx : Nat) - (queryPos? : Option Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (keyOffset : Int) - (maxSeqLen : Nat) - (tightPattern : Bool) - (tightPatternLayers : Nat) - (perRowPatternLayers : Nat) - (useAffine : Bool) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) - (shared? : Option SharedBinaryInputs := none) - (prefix? : Option SharedBinaryPrefix := none) : - IO (Except String HeadBestMatchPatternCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadBestMatchPatternCert := do - let timingEnabled ← ExceptT.lift <| IO.getEnv "NFP_TIMING" - let timing : Bool := timingEnabled.isSome - let timeIt {α : Type} (label : String) (work : ExceptT String IO α) : - ExceptT String IO α := do - if !timing then - work - else - let t0 ← ExceptT.lift IO.monoNanosNow - let r ← work - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - ExceptT.lift <| IO.eprintln s!"timing:{label} {dtMs}ms" - return r - let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← - timeIt "load_shared" <| match shared? 
with - | some shared => do - if shared.scalePow10 ≠ scalePow10 then - throw "shared scalePow10 mismatch" - if shared.inputDelta ≠ inputDelta then - throw "shared inputDelta mismatch" - pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) - | none => do - let (hdr, ln1Params, ln2Params) ← - timeIt "load_ln_params" <| - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residuals0 ← - timeIt "load_embeddings" <| - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← - timeIt "load_tokens" <| ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - pure (hdr, ln1Params, ln2Params, residuals0, tokens) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let queryPos : Nat := - match queryPos? with - | some q => q - | none => - if hdr.seqLen = 0 then 0 else hdr.seqLen - 1 - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen - let (residuals0, tokens) ← - match prefix? with - | some pref => - if pref.seqLenEff ≠ seqLenEff then - throw "prefix seq_len mismatch" - pure (pref.residuals.get, pref.tokens.get) - | none => - let residuals0 := - if causalPattern then takePrefix residualsBase seqLenEff else residualsBase - let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase - pure (residuals0, tokens) - let keyOffsetNat? 
: Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = layerIdx then - let tPattern0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let mut wq? : Option (Array Int) := none - let mut bq? : Option (Array Int) := none - let mut wk? : Option (Array Int) := none - let mut bk? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - if hIdx = headIdx then - let wq ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wq? := some wq - let bQ ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bq? := some bQ - let wk ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wk? := some wk - let bK ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bk? 
:= some bK - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wq ← - match wq? with - | none => throw "missing W_Q for requested head" - | some xs => pure xs - let bQ ← - match bq? with - | none => throw "missing b_Q for requested head" - | some xs => pure xs - let wk ← - match wk? with - | none => throw "missing W_K for requested head" - | some xs => pure xs - let bK ← - match bk? with - | none => throw "missing b_K for requested head" - | some xs => pure xs - let bQIntervals := intervalsFromScaled bQ slack - let bKIntervals := intervalsFromScaled bK slack - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut bestMatchLower? : Option Int := none - let mut bestNonmatchUpper? : Option Int := none - if useAffine then - let (qInputCenters, qInputRadii, _qAbsInput) := - rowCentersRadiiAbsInt (ln1Rows[queryPos]!) 
- let (qCenters0, qRadii0) := - matMulCentersRadiiIntSlack cfg slack - hdr.modelDim hdr.headDim wq qInputCenters qInputRadii - let bQCenters := bQ - let bKCenters := bK - let bQRadii := bQIntervals.map intervalRadiusInt - let bKRadii := bKIntervals.map intervalRadiusInt - let qCenters := addVecScaledInt qCenters0 bQCenters 1 - let qRadii := addVecScaledInt qRadii0 bQRadii 1 - let useTasks := seqLenEff > 32 - if useTasks then - let chunkSize : Nat := 16 - let numChunks : Nat := (seqLenEff + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Option Int × Option Int)) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min seqLenEff (start + chunkSize) - tasks := tasks.push <| Task.spawn (fun _ => - Id.run do - let mut bestMatchLower? : Option Int := none - let mut bestNonmatchUpper? : Option Int := none - let mut j := start - while j < stop do - if !causalPattern || j ≤ queryPos then - let (kInputCenters, kInputRadii, _kAbsInput) := - rowCentersRadiiAbsInt (ln1Rows[j]!) - let (kCenters0, kRadii0) := - matMulCentersRadiiIntSlack cfg slack - hdr.modelDim hdr.headDim wk kInputCenters kInputRadii - let kCenters := addVecScaledInt kCenters0 bKCenters 1 - let kRadii := addVecScaledInt kRadii0 bKRadii 1 - let dot := - dotIntervalFromCentersRadiiInt cfg qCenters qRadii kCenters kRadii - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? 
with - | none => some dot.hi - | some m => some (max m dot.hi) - j := j + 1 - return (bestMatchLower?, bestNonmatchUpper?)) - chunkIdx := chunkIdx + 1 - for t in tasks do - let (matchChunk?, nonmatchChunk?) := t.get - if matchChunk?.isSome then - bestMatchLower? := - match bestMatchLower?, matchChunk? with - | none, some v => some v - | some cur, some v => some (max cur v) - | some cur, none => some cur - | none, none => none - if nonmatchChunk?.isSome then - bestNonmatchUpper? := - match bestNonmatchUpper?, nonmatchChunk? with - | none, some v => some v - | some cur, some v => some (max cur v) - | some cur, none => some cur - | none, none => none - else - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let (kInputCenters, kInputRadii, _kAbsInput) := - rowCentersRadiiAbsInt (ln1Rows[j]!) - let (kCenters0, kRadii0) := - matMulCentersRadiiIntSlack cfg slack - hdr.modelDim hdr.headDim wk kInputCenters kInputRadii - let kCenters := addVecScaledInt kCenters0 bKCenters 1 - let kRadii := addVecScaledInt kRadii0 bKRadii 1 - let dot := - dotIntervalFromCentersRadiiInt cfg qCenters qRadii kCenters kRadii - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) - else - pure () - else - let qRow := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) 
- let qRow := addVecFixed qRow bQIntervals - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in ln1Rows do - let kRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row - kRows := kRows.push (addVecFixed kRow0 bKIntervals) - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let dot := fixedDotInterval cfg qRow (kRows[j]!) - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) - else - pure () - let bestMatchLower ← - match bestMatchLower? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let bestNonmatchUpper := - match bestNonmatchUpper? 
with - | none => bestMatchLower - | some v => v - let marginInt : Int := bestMatchLower - bestNonmatchUpper - let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower - let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper - let margin := ratOfScaledInt scalePow10 marginInt - let (effortUsed, weightLB, softmaxJacobianUB) := - chooseSoftmaxExpEffort hdr.seqLen margin softmaxExpEffort - let cert : HeadBestMatchPatternCert := { - layerIdx := layerIdx - headIdx := headIdx - seqLen := hdr.seqLen - queryPos := queryPos - targetOffset := targetOffset - keyOffset := keyOffset - targetToken := targetTok - bestMatchLogitLowerBound := bestMatchLowerRat - bestNonmatchLogitUpperBound := bestNonmatchUpperRat - marginLowerBound := margin - softmaxExpEffort := effortUsed - bestMatchWeightLowerBound := weightLB - softmaxJacobianNormInfUpperBound := softmaxJacobianUB - } - if cert.check then - if let some t0 := tPattern0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:pattern {dtMs}ms" - return cert - throw "best-match head pattern certificate failed internal consistency checks" - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && layerIdx ≤ l + tightLayers then - if causalPattern then - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnRows : Array (Array Fixed10Interval) := - Array.replicate seqLenEff zeroRow - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| 
- readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let headRows := prefixUnionRowsFixed vOutRows - attnRows := addRowsFixed attnRows headRows - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnRows := addVecFixedRows attnRows attnBias - residuals := addRowsFixed residuals attnRows - else - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion 
- else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - 
ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Compute local head best-match pattern bounds for all valid query positions (binary only). -/ -private def certifyHeadPatternBestMatchLocalBinarySweep - (path : System.FilePath) - (layerIdx headIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (keyOffset : Int) - (maxSeqLen : Nat) - (tightPattern : Bool) - (tightPatternLayers : Nat) - (perRowPatternLayers : Nat) - (useAffine : Bool) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) - (shared? : Option SharedBinaryInputs := none) : - IO (Except String (Array HeadBestMatchPatternCert)) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO (Array HeadBestMatchPatternCert) := do - let (hdr, ln1Params, ln2Params, residuals0, tokens) ← - match shared? 
with - | some shared => - if shared.scalePow10 ≠ scalePow10 then - throw "shared scalePow10 mismatch" - if shared.inputDelta ≠ inputDelta then - throw "shared inputDelta mismatch" - pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) - | none => - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - pure (hdr, ln1Params, ln2Params, residuals0, tokens) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if useAffine then - throw "affine sweep is unsupported; use --bestMatch without --sweep" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = layerIdx then - let mut wq? : Option (Array Int) := none - let mut bq? : Option (Array Int) := none - let mut wk? : Option (Array Int) := none - let mut bk? 
: Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - if hIdx = headIdx then - let wq ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wq? := some wq - let bQ ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bq? := some bQ - let wk ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wk? := some wk - let bK ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bk? := some bK - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wq ← - match wq? with - | none => throw "missing W_Q for requested head" - | some xs => pure xs - let bQ ← - match bq? with - | none => throw "missing b_Q for requested head" - | some xs => pure xs - let wk ← - match wk? with - | none => throw "missing W_K for requested head" - | some xs => pure xs - let bK ← - match bk? 
with - | none => throw "missing b_K for requested head" - | some xs => pure xs - let bQIntervals := intervalsFromScaled bQ slack - let bKIntervals := intervalsFromScaled bK slack - let mut qRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let qRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq row - let kRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row - qRows := qRows.push (addVecFixed qRow0 bQIntervals) - kRows := kRows.push (addVecFixed kRow0 bKIntervals) - let validPositions : Array Nat := Id.run do - let mut out : Array Nat := Array.mkEmpty hdr.seqLen - for i in [:hdr.seqLen] do - let ti : Int := (Int.ofNat i) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - pure () - else - out := out.push i - out - if validPositions.isEmpty then - throw "no valid query positions for the requested offset" - let keyOffsetNat? : Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let computeCert : Nat → Except String HeadBestMatchPatternCert := fun queryPos => do - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let qRow := qRows[queryPos]! - let mut bestMatchLower? : Option Int := none - let mut bestNonmatchUpper? : Option Int := none - for j in [:hdr.seqLen] do - if !causalPattern || j ≤ queryPos then - let dot := fixedDotInterval cfg qRow (kRows[j]!) - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < hdr.seqLen && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - bestMatchLower? := - match bestMatchLower? 
with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) - else - pure () - let bestMatchLower ← - match bestMatchLower? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let bestNonmatchUpper := - match bestNonmatchUpper? with - | none => bestMatchLower - | some v => v - let marginInt : Int := bestMatchLower - bestNonmatchUpper - let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower - let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper - let margin := ratOfScaledInt scalePow10 marginInt - let (effortUsed, weightLB, softmaxJacobianUB) := - chooseSoftmaxExpEffort hdr.seqLen margin softmaxExpEffort - let cert : HeadBestMatchPatternCert := { - layerIdx := layerIdx - headIdx := headIdx - seqLen := hdr.seqLen - queryPos := queryPos - targetOffset := targetOffset - keyOffset := keyOffset - targetToken := targetTok - bestMatchLogitLowerBound := bestMatchLowerRat - bestNonmatchLogitUpperBound := bestNonmatchUpperRat - marginLowerBound := margin - softmaxExpEffort := effortUsed - bestMatchWeightLowerBound := weightLB - softmaxJacobianNormInfUpperBound := softmaxJacobianUB - } - if cert.check then - return cert - throw "best-match head pattern certificate failed internal consistency checks" - let useTasks := validPositions.size > 32 - let mut certs : Array HeadBestMatchPatternCert := Array.mkEmpty validPositions.size - if useTasks then - let tasks := validPositions.map (fun i => - Task.spawn (fun _ => computeCert i)) - for t in tasks do - match t.get with - | .ok cert => certs := certs.push cert - | .error e => throw e - else - for i in validPositions do - match computeCert i with - | .ok cert => certs := certs.push cert - | .error e => throw e - return certs - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && 
layerIdx ≤ l + tightLayers then - if causalPattern then - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnRows : Array (Array Fixed10Interval) := - Array.replicate hdr.seqLen zeroRow - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let headRows := prefixUnionRowsFixed vOutRows - attnRows := addRowsFixed attnRows headRows - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnRows := addVecFixedRows attnRows attnBias - residuals := addRowsFixed residuals attnRows - else - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← 
ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - 
let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Compute local head output lower bounds for a single coordinate (binary only). 
-/ -private def certifyHeadValueLowerBoundLocalBinary - (path : System.FilePath) - (pattern : HeadPatternCert) - (coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (causalPattern : Bool := true) : - IO (Except String HeadValueLowerBoundCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadValueLowerBoundCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if pattern.layerIdx ≥ hdr.numLayers then - throw s!"layer index {pattern.layerIdx} out of range" - if pattern.headIdx ≥ hdr.numHeads then - throw s!"head index {pattern.headIdx} out of range" - if coord ≥ hdr.modelDim then - throw s!"coord index {coord} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - if pattern.seqLen ≠ hdr.seqLen then - throw "pattern seq_len mismatch" - let keyOffsetNat? 
: Option Nat := - if pattern.keyOffset ≥ 0 then some (Int.toNat pattern.keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-pattern.keyOffset) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = pattern.layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = pattern.headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? := some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? 
with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let mut minMatchLo? : Option Int := none - let mut minNonmatchLo? : Option Int := none - for i in [:hdr.seqLen] do - let ti : Int := (Int.ofNat i) + pattern.targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - pure () - else - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? : Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:hdr.seqLen] do - if !causalPattern || j ≤ i then - let row := vOutRows[j]! - let vCoord := row[coord]!.lo - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < hdr.seqLen && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLo? := - match matchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - pure () - let matchLo := - match matchLo? with - | none => 0 - | some v => v - let nonmatchLo := - match nonmatchLo? with - | none => matchLo - | some v => v - minMatchLo? := - match minMatchLo? with - | none => some matchLo - | some m => some (min m matchLo) - minNonmatchLo? := - match minNonmatchLo? with - | none => some nonmatchLo - | some m => some (min m nonmatchLo) - let matchLo := - match minMatchLo? 
with - | none => 0 - | some v => v - let nonmatchLo := - match minNonmatchLo? with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := pattern.targetWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadValueLowerBoundCert := { - layerIdx := pattern.layerIdx - headIdx := pattern.headIdx - coord := coord - matchWeightLowerBound := weightLB - matchCoordLowerBound := matchLoRat - nonmatchCoordLowerBound := nonmatchLoRat - outputCoordLowerBound := outputLB - } - if cert.check then - return cert - throw "head value lower bound failed internal consistency checks" - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && pattern.layerIdx ≤ l + tightLayers then - if causalPattern then - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnRows : Array (Array Fixed10Interval) := - Array.replicate hdr.seqLen zeroRow - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let headRows := 
prefixUnionRowsFixed vOutRows - attnRows := addRowsFixed attnRows headRows - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnRows := addVecFixedRows attnRows attnBias - residuals := addRowsFixed residuals attnRows - else - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) 
← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && pattern.layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim 
actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Compute local head logit-difference lower bounds for a specific head (binary only). -/ -private def certifyHeadLogitDiffLowerBoundLocalBinary - (path : System.FilePath) - (pattern : HeadPatternCert) - (targetToken negativeToken : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (causalPattern : Bool := true) : - IO (Except String HeadLogitDiffLowerBoundCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadLogitDiffLowerBoundCert := do - let (hdrDir, direction) ← - ExceptT.mk (readLogitDiffDirectionBinary path targetToken negativeToken scalePow10 slack) - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if hdr.modelDim ≠ hdrDir.modelDim then - throw "unembedding model_dim mismatch" - if pattern.layerIdx ≥ hdr.numLayers then - throw s!"layer index {pattern.layerIdx} out of range" - if pattern.headIdx ≥ hdr.numHeads then - throw s!"head index {pattern.headIdx} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if direction.size ≠ hdr.modelDim then - throw "logit direction size mismatch" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) 
← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - if pattern.seqLen ≠ hdr.seqLen then - throw "pattern seq_len mismatch" - let keyOffsetNat? : Option Nat := - if pattern.keyOffset ≥ 0 then some (Int.toNat pattern.keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-pattern.keyOffset) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = pattern.layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = pattern.headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? 
:= some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let mut vDotRows : Array Fixed10Interval := Array.mkEmpty hdr.seqLen - for row in vOutRows do - vDotRows := vDotRows.push (fixedDotInterval cfg row direction) - let mut minMatchLo? : Option Int := none - let mut minNonmatchLo? : Option Int := none - for i in [:hdr.seqLen] do - let ti : Int := (Int.ofNat i) + pattern.targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - pure () - else - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? : Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:hdr.seqLen] do - if !causalPattern || j ≤ i then - let vLo := (vDotRows[j]!).lo - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < hdr.seqLen && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLo? := - match matchLo? with - | none => some vLo - | some m => some (min m vLo) - else - nonmatchLo? := - match nonmatchLo? 
with - | none => some vLo - | some m => some (min m vLo) - else - pure () - let matchLo := - match matchLo? with - | none => 0 - | some v => v - let nonmatchLo := - match nonmatchLo? with - | none => matchLo - | some v => v - minMatchLo? := - match minMatchLo? with - | none => some matchLo - | some m => some (min m matchLo) - minNonmatchLo? := - match minNonmatchLo? with - | none => some nonmatchLo - | some m => some (min m nonmatchLo) - let matchLo := - match minMatchLo? with - | none => 0 - | some v => v - let nonmatchLo := - match minNonmatchLo? with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := pattern.targetWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadLogitDiffLowerBoundCert := { - layerIdx := pattern.layerIdx - headIdx := pattern.headIdx - targetToken := targetToken - negativeToken := negativeToken - matchWeightLowerBound := weightLB - matchLogitLowerBound := matchLoRat - nonmatchLogitLowerBound := nonmatchLoRat - logitDiffLowerBound := outputLB - } - if cert.check then - return cert - throw "head logit lower bound failed internal consistency checks" - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && pattern.layerIdx ≤ l + tightLayers then - if causalPattern then - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnRows : Array (Array Fixed10Interval) := - Array.replicate hdr.seqLen zeroRow - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk 
(readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let headRows := prefixUnionRowsFixed vOutRows - attnRows := addRowsFixed attnRows headRows - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnRows := addVecFixedRows attnRows attnBias - residuals := addRowsFixed residuals attnRows - else - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := 
addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && pattern.layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals 
mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := geluOverapproxFixedVec cfg hdr.geluDerivTarget hiddenB - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Soundly compute certification bounds from a `.nfpt` model file. - -If an input is provided via `inputPath?`, the certificate uses streaming rational IBP to obtain -local (input-dependent) LayerNorm variance lower bounds at every layer. -Otherwise it falls back to the weight-only global certificate. --/ -def certifyModelFile - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - if inputDelta < 0 then - return .error "delta must be nonnegative" - match inputPath? 
with - | none => - if inputDelta = 0 then - certifyModelFileGlobalBinary path eps geluDerivTarget soundnessBits partitionDepth - softmaxMarginLowerBound softmaxExpEffort - else - certifyModelFileLocalBinary path eps geluDerivTarget soundnessBits partitionDepth - path inputDelta softmaxMarginLowerBound softmaxExpEffort - | some ip => - certifyModelFileLocalBinary path eps geluDerivTarget soundnessBits partitionDepth - ip inputDelta softmaxMarginLowerBound softmaxExpEffort - else - match inputPath? with - | none => - certifyModelFileGlobal path eps geluDerivTarget soundnessBits - (inputPath? := none) (inputDelta := inputDelta) (partitionDepth := partitionDepth) - (softmaxMarginLowerBound := softmaxMarginLowerBound) - (softmaxExpEffort := softmaxExpEffort) - | some ip => - if inputDelta < 0 then - return .error "delta must be nonnegative" - certifyModelFileLocal path eps geluDerivTarget soundnessBits partitionDepth ip inputDelta - softmaxMarginLowerBound softmaxExpEffort - -/-- Compute weight-only per-head contribution bounds for a `.nfpt` model file. -/ -def certifyHeadBounds - (path : System.FilePath) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array HeadContributionCert)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - certifyHeadBoundsBinary path scalePow10 - else - return .error "head contribution bounds require NFP_BINARY_V1" - -/-- Compute local per-head attention contribution bounds for a `.nfpt` model file. -/ -def certifyHeadBoundsLocal - (path : System.FilePath) - (eps : Rat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array HeadLocalContributionCert)) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - certifyHeadBoundsLocalBinary path eps inputPath inputDelta soundnessBits scalePow10 - else - return .error "local head contribution bounds require NFP_BINARY_V1" - -/-- Compute local attention pattern bounds for a specific `.nfpt` head (binary only). -/ -def certifyHeadPatternLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String HeadPatternCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers - scalePow10 - softmaxExpEffort causalPattern - else - return .error "head pattern bounds require NFP_BINARY_V1" - -/-- Compute local best-match pattern bounds for a specific `.nfpt` head (binary only). -/ -def certifyHeadPatternBestMatchLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (queryPos? 
: Option Nat := none) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (useAffine : Bool := false) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String HeadBestMatchPatternCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - certifyHeadPatternBestMatchLocalBinary path layerIdx headIdx queryPos? eps soundnessBits - inputPath - inputDelta targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers useAffine scalePow10 softmaxExpEffort causalPattern - else - return .error "head pattern bounds require NFP_BINARY_V1" - -/-- Compute local best-match pattern bounds for all valid query positions (binary only). -/ -def certifyHeadPatternBestMatchLocalSweep - (path : System.FilePath) - (layerIdx headIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (useAffine : Bool := false) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String (Array HeadBestMatchPatternCert)) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - certifyHeadPatternBestMatchLocalBinarySweep path layerIdx headIdx eps soundnessBits inputPath - inputDelta targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers - useAffine scalePow10 softmaxExpEffort causalPattern - else - return .error "head pattern bounds require NFP_BINARY_V1" - -/-- Compute layer-level best-match margin evidence for a `.nfpt` layer (binary only). -/ -def certifyLayerBestMatchMarginLocal - (path : System.FilePath) - (layerIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String LayerBestMatchMarginCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let hdrE ← readBinaryHeader h - match hdrE with - | .error e => return .error e - | .ok hdr => - if layerIdx ≥ hdr.numLayers then - return .error s!"layer index {layerIdx} out of range" - if hdr.seqLen > maxSeqLen then - return .error s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let inputPath := inputPath?.getD path - let mut headCerts : Array HeadBestMatchPatternCert := Array.mkEmpty 0 - for hIdx in [:hdr.numHeads] do - match ← - certifyHeadPatternBestMatchLocalBinarySweep - path layerIdx hIdx eps soundnessBits inputPath inputDelta targetOffset keyOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers false scalePow10 - softmaxExpEffort causalPattern with - | .error e => return .error e - | .ok certs => - for cert in certs do - headCerts := headCerts.push cert - match marginsFromBestMatchCerts hdr.numHeads hdr.seqLen headCerts with - | none => return .error "best-match margin coverage failed" - | some margins => - let marginLowerBound := minMarginArray margins - let cert : LayerBestMatchMarginCert := { - layerIdx := layerIdx - seqLen := hdr.seqLen - numHeads := hdr.numHeads - softmaxExpEffort := softmaxExpEffort - marginLowerBound := marginLowerBound - margins := margins - headCerts := headCerts - } - if cert.check then - return .ok cert - return .error "layer best-match margin certificate failed internal checks" - else - 
return .error "layer best-match margins require NFP_BINARY_V1" - -/-- Compute local head value lower bounds for a specific `.nfpt` head (binary only). -/ -def certifyHeadValueLowerBoundLocal - (path : System.FilePath) - (layerIdx headIdx coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) - (causalPattern : Bool := true) : - IO (Except String HeadValueLowerBoundCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let patternE ← - certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers - scalePow10 - defaultSoftmaxExpEffort causalPattern - match patternE with - | .error e => return .error e - | .ok pattern => - certifyHeadValueLowerBoundLocalBinary path pattern coord eps soundnessBits inputPath - inputDelta maxSeqLen scalePow10 tightPattern tightPatternLayers perRowPatternLayers - causalPattern - else - return .error "head value bounds require NFP_BINARY_V1" - -/-- Compute local head logit-difference lower bounds for a specific `.nfpt` head (binary only). -/ -def certifyHeadLogitDiffLowerBoundLocal - (path : System.FilePath) - (layerIdx headIdx targetToken negativeToken : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (keyOffset : Int := 0) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) - (causalPattern : Bool := true) : - IO (Except String HeadLogitDiffLowerBoundCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let patternE ← - certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset keyOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers - scalePow10 - defaultSoftmaxExpEffort causalPattern - match patternE with - | .error e => return .error e - | .ok pattern => - certifyHeadLogitDiffLowerBoundLocalBinary path pattern targetToken negativeToken - eps soundnessBits inputPath inputDelta maxSeqLen scalePow10 tightPattern - tightPatternLayers perRowPatternLayers causalPattern - else - return .error "head logit bounds require NFP_BINARY_V1" - -/-- Compute a combined sound certificate for an induction-style head pair (binary only). -/ -def certifyInductionSound - (path : System.FilePath) - (layer1 head1 layer2 head2 coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (offset1 : Int := -1) - (offset2 : Int := -1) - (keyOffset1 : Int := 0) - (keyOffset2 : Int := 0) - (maxSeqLen : Nat := 256) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (targetToken? : Option Nat := none) - (negativeToken? 
: Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String InductionHeadSoundCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let p1E ← - certifyHeadPatternLocalBinary path layer1 head1 eps soundnessBits inputPath inputDelta - offset1 keyOffset1 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort causalPattern - match p1E with - | .error e => return .error e - | .ok p1 => - let p2E ← - certifyHeadPatternLocalBinary path layer2 head2 eps soundnessBits inputPath inputDelta - offset2 keyOffset2 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers - scalePow10 - softmaxExpEffort causalPattern - match p2E with - | .error e => return .error e - | .ok p2 => - let vE ← - certifyHeadValueLowerBoundLocalBinary path p2 coord eps soundnessBits inputPath - inputDelta maxSeqLen scalePow10 tightPattern tightPatternLayers perRowPatternLayers - causalPattern - match vE with - | .error e => return .error e - | .ok v => - let logitE ← - match targetToken?, negativeToken? with - | none, none => pure (.ok none) - | some targetToken, some negativeToken => do - let logitE ← certifyHeadLogitDiffLowerBoundLocalBinary path p2 - targetToken negativeToken eps soundnessBits inputPath inputDelta - maxSeqLen scalePow10 tightPattern tightPatternLayers perRowPatternLayers - causalPattern - pure (logitE.map some) - | _, _ => - pure (.error "use both target and negative tokens (or neither)") - match logitE with - | .error e => return .error e - | .ok logit? => - let cert : InductionHeadSoundCert := { - layer1Pattern := p1 - layer2Pattern := p2 - layer2Value := v - layer2Logit? := logit? 
- deltaLowerBound := v.outputCoordLowerBound - } - if cert.check then - return .ok cert - return .error "induction head certificate failed internal consistency checks" - else - return .error "induction sound cert requires NFP_BINARY_V1" - -/-- Compute a best-match induction-head certificate in a single binary pass. -/ -private def certifyInductionSoundBestMatchLocalBinaryPair - (path : System.FilePath) - (layer1 head1 layer2 head2 coord queryPos : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (offset1 : Int) - (offset2 : Int) - (keyOffset1 : Int) - (keyOffset2 : Int) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool) - (tightPatternLayers : Nat) - (perRowPatternLayers : Nat) - (useAffine : Bool) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) - (shared? : Option SharedBinaryInputs := none) - (prefix? : Option SharedBinaryPrefix := none) - (targetToken? : Option Nat := none) - (negativeToken? : Option Nat := none) - (direction? 
: Option (Thunk (Array Fixed10Interval)) := none) : - IO (Except String InductionHeadBestMatchSoundCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO InductionHeadBestMatchSoundCert := do - let timingEnabled ← ExceptT.lift <| IO.getEnv "NFP_TIMING" - let timing : Bool := timingEnabled.isSome - let debugMarginEnabled ← ExceptT.lift <| IO.getEnv "NFP_MARGIN_DEBUG" - let debugMargin : Bool := debugMarginEnabled.isSome - let timeIt {α : Type} (label : String) (work : ExceptT String IO α) : - ExceptT String IO α := do - if !timing then - work - else - let t0 ← ExceptT.lift IO.monoNanosNow - let r ← work - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - ExceptT.lift <| IO.eprintln s!"timing:{label} {dtMs}ms" - return r - let (hdr, ln1Params, ln2Params, residualsBase, tokensBase) ← - timeIt "load_shared" <| match shared? with - | some shared => do - if shared.scalePow10 ≠ scalePow10 then - throw "shared scalePow10 mismatch" - if shared.inputDelta ≠ inputDelta then - throw "shared inputDelta mismatch" - pure (shared.hdr, shared.ln1Params, shared.ln2Params, shared.residuals0, shared.tokens) - | none => do - let (hdr, ln1Params, ln2Params) ← - timeIt "load_ln_params" <| - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residuals0 ← - timeIt "load_embeddings" <| - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← - timeIt "load_tokens" <| ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - pure (hdr, ln1Params, ln2Params, residuals0, tokens) - if layer1 ≥ hdr.numLayers then - throw s!"layer1 index {layer1} out of range" - if layer2 ≥ hdr.numLayers then - throw s!"layer2 index {layer2} out of range" - if head1 ≥ hdr.numHeads then - throw s!"head1 index {head1} out of range" - if head2 ≥ hdr.numHeads then - throw 
s!"head2 index {head2} out of range" - if coord ≥ hdr.modelDim then - throw s!"coord index {coord} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let seqLenEff : Nat := if causalPattern then queryPos + 1 else hdr.seqLen - let (residuals0, tokens) ← - match prefix? with - | some pref => - if pref.seqLenEff ≠ seqLenEff then - throw "prefix seq_len mismatch" - pure (pref.residuals.get, pref.tokens.get) - | none => - let residuals0 := - if causalPattern then takePrefix residualsBase seqLenEff else residualsBase - let tokens := if causalPattern then takePrefix tokensBase seqLenEff else tokensBase - pure (residuals0, tokens) - let matchRows - (targetOffset : Int) - (keyOffset : Int) : Array Nat := - Id.run do - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then - return #[] - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let keyOffsetNat? : Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let mut rows : Array Nat := Array.mkEmpty seqLenEff - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - rows := rows.push j - else - pure () - return rows - let selectedRows : Array Nat := - let r1 := matchRows offset1 keyOffset1 - let r2 := matchRows offset2 keyOffset2 - if r1.isEmpty && r2.isEmpty then - #[] - else - Id.run do - let mut acc : Array Nat := Array.mkEmpty (r1.size + r2.size) - for v in r1 do - if !acc.contains v then - acc := acc.push v - for v in r2 do - if !acc.contains v then - acc := acc.push v - acc - let selectedRows? 
: Option (Array Nat) := - if selectedRows.isEmpty then none else some selectedRows - let useLogit ← - match targetToken?, negativeToken?, direction? with - | none, none, none => pure false - | some _, some _, some _ => pure true - | _, _, _ => throw "use both target and negative tokens (or neither)" - let calcLnRows - (rows : Array (Array Fixed10Interval)) - (p : LayerNormParamsFixed) : - Array (Array Fixed10Interval) := - fixedLayerNormRowsApprox cfg rows p eps soundnessBits - let calcLnRowsExact - (rows : Array (Array Fixed10Interval)) - (p : LayerNormParamsFixed) : - Array (Array Fixed10Interval) := - fixedLayerNormRowsApproxExact cfg rows p eps soundnessBits - let logRowsWidth (tag label : String) (rows : Array (Array Fixed10Interval)) : - ExceptT String IO Unit := do - if debugMargin then - if rows.isEmpty then - ExceptT.lift <| IO.eprintln s!"{tag}:{label} empty" - else - let mut maxW : Rat := 0 - for row in rows do - let w := centeredAbsSumFixed cfg row - if w > maxW then - maxW := w - let qW := - if queryPos < rows.size then centeredAbsSumFixed cfg rows[queryPos]! else 0 - ExceptT.lift <| - IO.eprintln s!"{tag}:{label} rows={rows.size} queryWidth={qW} maxWidth={maxW}" - let logVecWidth (tag label : String) (row : Array Fixed10Interval) : - ExceptT String IO Unit := do - if debugMargin then - let w := centeredAbsSumFixed cfg row - ExceptT.lift <| IO.eprintln s!"{tag}:{label} width={w}" - let calcVOutRowsIntervalsNoTask - (cfg : Fixed10Cfg) - (modelDim headDim : Nat) - (wvIntervals woIntervals : Array Fixed10Interval) - (bV : Array Fixed10Interval) - (rows : Array (Array Fixed10Interval)) - (start stop : Nat) : - Array (Array Fixed10Interval) := - Id.run do - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty (stop - start) - let mut i := start - while i < stop do - let vHidden0 := matMulIntervalsFromIntervalsNoTask cfg modelDim headDim wvIntervals - (rows[i]!) 
- let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromIntervalsNoTask cfg headDim modelDim woIntervals vHidden - out := out.push vOut - i := i + 1 - return out - let calcVOutRowsIntervals - (rows : Array (Array Fixed10Interval)) - (wvIntervals woIntervals : Array Fixed10Interval) - (bV : Array Fixed10Interval) : - Array (Array Fixed10Interval) := - let useTasks := rows.size > 32 - if useTasks then - Id.run do - let chunkSize : Nat := 16 - let numChunks : Nat := (rows.size + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array (Array Fixed10Interval))) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min rows.size (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - calcVOutRowsIntervalsNoTask cfg hdr.modelDim hdr.headDim wvIntervals - woIntervals bV rows start stop) - chunkIdx := chunkIdx + 1 - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for t in tasks do - for row in t.get do - out := out.push row - return out - else - calcVOutRowsIntervalsNoTask cfg hdr.modelDim hdr.headDim wvIntervals woIntervals bV rows 0 - rows.size - let calcVOutIntervals - (row : Array Fixed10Interval) - (wvIntervals woIntervals : Array Fixed10Interval) - (bV : Array Fixed10Interval) : - Array Fixed10Interval := - let vHidden0 := matMulIntervalsFromIntervalsNoTask cfg - hdr.modelDim hdr.headDim wvIntervals row - let vHidden := addVecFixed vHidden0 bV - matMulIntervalsFromIntervalsNoTask cfg hdr.headDim hdr.modelDim woIntervals vHidden - let bestMatchPattern - (layerIdx headIdx : Nat) - (ln1Rows : Array (Array Fixed10Interval)) - (wq wk : Array Int) - (bQ bK : Array Fixed10Interval) - (targetOffset : Int) - (keyOffset : Int) - (useTasks : Bool := true) : - ExceptT String IO HeadBestMatchPatternCert := do - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then - throw "query position 
has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let keyOffsetNat? : Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let mut bestMatchLower? : Option Int := none - let mut bestNonmatchUpper? : Option Int := none - let mut matchCount : Nat := 0 - let mut nonmatchCount : Nat := 0 - let mut matchLowerMin? : Option Int := none - let mut matchUpperMax? : Option Int := none - let mut nonmatchLowerMax? : Option Int := none - let mut nonmatchUpperMax? : Option Int := none - let mut matchWidthMax : Int := 0 - let mut nonmatchWidthMax : Int := 0 - if useAffine then - let bQCenters := bQ.map (fun x => (x.lo + x.hi).ediv (Int.ofNat 2)) - let bKCenters := bK.map (fun x => (x.lo + x.hi).ediv (Int.ofNat 2)) - let bQRadii := bQ.map intervalRadiusInt - let bKRadii := bK.map intervalRadiusInt - let (qInputCenters, qInputRadii, _qAbsInput) := - rowCentersRadiiAbsInt (ln1Rows[queryPos]!) - let (qCenters0, qRadii0) := - matMulCentersRadiiIntSlack cfg slack - hdr.modelDim hdr.headDim wq qInputCenters qInputRadii - let qCenters := addVecScaledInt qCenters0 bQCenters 1 - let qRadii := addVecScaledInt qRadii0 bQRadii 1 - let useTasksHere := useTasks && !debugMargin && seqLenEff > 32 - if useTasksHere then - let chunkSize : Nat := 32 - let numChunks : Nat := (seqLenEff + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Option Int × Option Int)) := Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min seqLenEff (start + chunkSize) - tasks := tasks.push <| Task.spawn (fun _ => - Id.run do - let mut bestMatchLower? : Option Int := none - let mut bestNonmatchUpper? : Option Int := none - let mut j := start - while j < stop do - if !causalPattern || j ≤ queryPos then - let (kInputCenters, kInputRadii, _kAbsInput) := - rowCentersRadiiAbsInt (ln1Rows[j]!) 
- let (kCenters0, kRadii0) := - matMulCentersRadiiIntSlack cfg slack - hdr.modelDim hdr.headDim wk kInputCenters kInputRadii - let kCenters := addVecScaledInt kCenters0 bKCenters 1 - let kRadii := addVecScaledInt kRadii0 bKRadii 1 - let dot := - dotIntervalFromCentersRadiiInt cfg qCenters qRadii kCenters kRadii - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) - j := j + 1 - return (bestMatchLower?, bestNonmatchUpper?)) - chunkIdx := chunkIdx + 1 - for t in tasks do - let (matchChunk?, nonmatchChunk?) := t.get - if matchChunk?.isSome then - bestMatchLower? := - match bestMatchLower?, matchChunk? with - | none, some v => some v - | some cur, some v => some (max cur v) - | some cur, none => some cur - | none, none => none - if nonmatchChunk?.isSome then - bestNonmatchUpper? := - match bestNonmatchUpper?, nonmatchChunk? with - | none, some v => some v - | some cur, some v => some (max cur v) - | some cur, none => some cur - | none, none => none - else - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let (kInputCenters, kInputRadii, _kAbsInput) := - rowCentersRadiiAbsInt (ln1Rows[j]!) - let (kCenters0, kRadii0) := - matMulCentersRadiiIntSlack cfg slack - hdr.modelDim hdr.headDim wk kInputCenters kInputRadii - let kCenters := addVecScaledInt kCenters0 bKCenters 1 - let kRadii := addVecScaledInt kRadii0 bKRadii 1 - let dot := - dotIntervalFromCentersRadiiInt cfg qCenters qRadii kCenters kRadii - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! 
= targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - if debugMargin then - matchCount := matchCount + 1 - matchLowerMin? := - match matchLowerMin? with - | none => some dot.lo - | some v => some (min v dot.lo) - matchUpperMax? := - match matchUpperMax? with - | none => some dot.hi - | some v => some (max v dot.hi) - let width := dot.hi - dot.lo - if width > matchWidthMax then - matchWidthMax := width - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - if debugMargin then - nonmatchCount := nonmatchCount + 1 - nonmatchLowerMax? := - match nonmatchLowerMax? with - | none => some dot.lo - | some v => some (max v dot.lo) - nonmatchUpperMax? := - match nonmatchUpperMax? with - | none => some dot.hi - | some v => some (max v dot.hi) - let width := dot.hi - dot.lo - if width > nonmatchWidthMax then - nonmatchWidthMax := width - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) - else - pure () - else - let qRow0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) - let qRow := addVecFixed qRow0 bQ - let kRows := - let useTasksHere := useTasks && !debugMargin && ln1Rows.size > 32 - if useTasksHere then - let tasks := ln1Rows.map (fun row => - Task.spawn (fun _ => - let kRow0 := matMulIntervalsFromScaledNoTask cfg slack - hdr.modelDim hdr.headDim wk row - addVecFixed kRow0 bK)) - tasks.map (fun t => t.get) - else - Id.run do - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty seqLenEff - for row in ln1Rows do - let kRow0 := matMulIntervalsFromScaledNoTask cfg slack - hdr.modelDim hdr.headDim wk row - out := out.push (addVecFixed kRow0 bK) - return out - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let dot := fixedDotInterval cfg qRow (kRows[j]!) - let isMatch : Bool := - match keyOffsetNat? 
with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - if debugMargin then - matchCount := matchCount + 1 - matchLowerMin? := - match matchLowerMin? with - | none => some dot.lo - | some v => some (min v dot.lo) - matchUpperMax? := - match matchUpperMax? with - | none => some dot.hi - | some v => some (max v dot.hi) - let width := dot.hi - dot.lo - if width > matchWidthMax then - matchWidthMax := width - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - if debugMargin then - nonmatchCount := nonmatchCount + 1 - nonmatchLowerMax? := - match nonmatchLowerMax? with - | none => some dot.lo - | some v => some (max v dot.lo) - nonmatchUpperMax? := - match nonmatchUpperMax? with - | none => some dot.hi - | some v => some (max v dot.hi) - let width := dot.hi - dot.lo - if width > nonmatchWidthMax then - nonmatchWidthMax := width - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) - else - pure () - let bestMatchLower ← - match bestMatchLower? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let bestNonmatchUpper := - match bestNonmatchUpper? 
with - | none => bestMatchLower - | some v => v - let marginInt : Int := bestMatchLower - bestNonmatchUpper - if debugMargin then - let matchLowerMin := matchLowerMin?.getD bestMatchLower - let matchUpperMax := matchUpperMax?.getD bestMatchLower - let nonmatchLowerMax := nonmatchLowerMax?.getD bestNonmatchUpper - let nonmatchUpperMax := nonmatchUpperMax?.getD bestNonmatchUpper - let msg := - s!"pattern_debug:layer{layerIdx}:head{headIdx} " ++ - s!"targetTok={targetTok} queryPos={queryPos} offset={targetOffset} " ++ - s!"keyOffset={keyOffset} scalePow10={scalePow10} " ++ - s!"matches={matchCount} nonmatches={nonmatchCount} " ++ - s!"matchLoMaxInt={bestMatchLower} matchLoMinInt={matchLowerMin} " ++ - s!"matchHiMaxInt={matchUpperMax} nonmatchLoMaxInt={nonmatchLowerMax} " ++ - s!"nonmatchHiMaxInt={nonmatchUpperMax} marginInt={marginInt} " ++ - s!"matchWidthMaxInt={matchWidthMax} nonmatchWidthMaxInt={nonmatchWidthMax}" - ExceptT.lift <| IO.eprintln msg - let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower - let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper - let margin := ratOfScaledInt scalePow10 marginInt - let (effortUsed, weightLB, softmaxJacobianUB) := - chooseSoftmaxExpEffort hdr.seqLen margin softmaxExpEffort - let cert : HeadBestMatchPatternCert := { - layerIdx := layerIdx - headIdx := headIdx - seqLen := hdr.seqLen - queryPos := queryPos - targetOffset := targetOffset - keyOffset := keyOffset - targetToken := targetTok - bestMatchLogitLowerBound := bestMatchLowerRat - bestNonmatchLogitUpperBound := bestNonmatchUpperRat - marginLowerBound := margin - softmaxExpEffort := effortUsed - bestMatchWeightLowerBound := weightLB - softmaxJacobianNormInfUpperBound := softmaxJacobianUB - } - if cert.check then - return cert - throw "best-match head pattern certificate failed internal consistency checks" - let valueLogit - (ln1Rows : Array (Array Fixed10Interval)) - (matchWeightLowerBound : Rat) - (wvIntervals woIntervals : Array 
Fixed10Interval) - (bV : Array Fixed10Interval) - (targetOffset : Int) - (keyOffset : Int) : - ExceptT String IO HeadValueLogitCert := do - let vOutRows := calcVOutRowsIntervals ln1Rows wvIntervals woIntervals bV - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let keyOffsetNat? : Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let mut matchLo? : Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let row := vOutRows[j]! - let vCoord := row[coord]!.lo - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLo? := - match matchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - pure () - let matchLo ← - match matchLo? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLo := - match nonmatchLo? 
with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let outputLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat - let value : HeadValueLowerBoundPosCert := { - layerIdx := layer2 - headIdx := head2 - queryPos := queryPos - coord := coord - matchWeightLowerBound := matchWeightLowerBound - matchCoordLowerBound := matchLoRat - nonmatchCoordLowerBound := nonmatchLoRat - outputCoordLowerBound := outputLB - } - if !value.check then - throw "head value certificate failed internal consistency checks" - let logit? ← - if !useLogit then - pure none - else - match targetToken?, negativeToken?, direction? with - | some targetToken, some negativeToken, some direction => do - let dir := direction.get - if dir.size ≠ hdr.modelDim then - throw "logit direction size mismatch" - let mut vDotRows : Array Fixed10Interval := Array.mkEmpty seqLenEff - for row in vOutRows do - vDotRows := vDotRows.push (fixedDotInterval cfg row dir) - let mut matchLoLogit? : Option Int := none - let mut nonmatchLoLogit? : Option Int := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let vLo := (vDotRows[j]!).lo - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLoLogit? := - match matchLoLogit? with - | none => some vLo - | some m => some (min m vLo) - else - nonmatchLoLogit? := - match nonmatchLoLogit? with - | none => some vLo - | some m => some (min m vLo) - else - pure () - let matchLoLogit ← - match matchLoLogit? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLoLogit := - match nonmatchLoLogit? 
with - | none => matchLoLogit - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLoLogit - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLoLogit - let logitLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat - let logitCert : HeadLogitDiffLowerBoundPosCert := { - layerIdx := layer2 - headIdx := head2 - queryPos := queryPos - targetToken := targetToken - negativeToken := negativeToken - matchWeightLowerBound := matchWeightLowerBound - matchLogitLowerBound := matchLoRat - nonmatchLogitLowerBound := nonmatchLoRat - logitDiffLowerBound := logitLB - } - if logitCert.check then - pure (some logitCert) - else - throw "head logit certificate failed internal consistency checks" - | _, _, _ => - throw "use both target and negative tokens (or neither)" - return { value := value, logit? := logit? } - let tightenQueryRowLower - (baseRow : Array Fixed10Interval) - (vOutRows : Array (Array Fixed10Interval)) - (matchWeightLowerBound : Rat) - (targetOffset : Int) - (keyOffset : Int) : - ExceptT String IO (Array Fixed10Interval) := do - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat seqLenEff) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let keyOffsetNat? : Option Nat := - if keyOffset ≥ 0 then some (Int.toNat keyOffset) else none - let keyOffsetNeg : Nat := Int.toNat (-keyOffset) - let mut matchLo? : Option (Array Int) := none - let mut nonmatchLo? : Option (Array Int) := none - for j in [:seqLenEff] do - if !causalPattern || j ≤ queryPos then - let row := vOutRows[j]! - let rowLo : Array Int := row.map (fun x => x.lo) - let isMatch : Bool := - match keyOffsetNat? with - | some k => - let idx := j + k - idx < seqLenEff && tokens[idx]! = targetTok - | none => - if j < keyOffsetNeg then - false - else - tokens[j - keyOffsetNeg]! = targetTok - if isMatch then - matchLo? := - match matchLo? 
with - | none => some rowLo - | some cur => - some <| Id.run do - let mut out : Array Int := Array.mkEmpty hdr.modelDim - for i in [:hdr.modelDim] do - out := out.push (min cur[i]! rowLo[i]!) - out - else - nonmatchLo? := - match nonmatchLo? with - | none => some rowLo - | some cur => - some <| Id.run do - let mut out : Array Int := Array.mkEmpty hdr.modelDim - for i in [:hdr.modelDim] do - out := out.push (min cur[i]! rowLo[i]!) - out - let matchLo ← - match matchLo? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLo := - match nonmatchLo? with - | none => matchLo - | some v => v - let mut tightened : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for i in [:hdr.modelDim] do - let matchLoRat := ratOfScaledInt scalePow10 matchLo[i]! - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo[i]! - let outLB := mixLowerBound matchWeightLowerBound matchLoRat nonmatchLoRat - let outLBInt := ratFloorMulNat outLB cfg.scaleNat - let base := baseRow[i]! - let newLo := max base.lo outLBInt - tightened := tightened.push { lo := newLo, hi := base.hi } - return tightened - let addAttn - (useTight : Bool) - (ln1Rows : Array (Array Fixed10Interval)) - (ln1Union? : Option (Array Fixed10Interval)) - (groupRows? : Option (Array (Array Fixed10Interval))) - (attnRows? : Option (Array (Array Fixed10Interval))) - (attnUnion? : Option (Array Fixed10Interval)) - (wvIntervals woIntervals : Array Fixed10Interval) - (bV : Array Fixed10Interval) : - ExceptT String IO - (Option (Array (Array Fixed10Interval)) × Option (Array Fixed10Interval)) := do - if useTight then - if causalPattern then - let vOutRows := calcVOutRowsIntervals ln1Rows wvIntervals woIntervals bV - match attnRows? with - | some rows => - if rows.size ≠ vOutRows.size then - return (some rows, attnUnion?) - if vOutRows.isEmpty then - return (some rows, attnUnion?) - let mut outRows := rows - let mut acc := vOutRows[0]! - outRows := outRows.set! 
0 (addVecFixed rows[0]! acc) - let mut i : Nat := 1 - while i < vOutRows.size do - acc := Fixed10Interval.unionVec acc vOutRows[i]! - outRows := outRows.set! i (addVecFixed rows[i]! acc) - i := i + 1 - return (some outRows, attnUnion?) - | none => throw "missing attnRows" - else - let groupRows ← - match groupRows? with - | some rows => pure rows - | none => throw "missing group rows" - let vOutRows := calcVOutRowsIntervals groupRows wvIntervals woIntervals bV - let vUnion := unionRowsFixed vOutRows - match attnUnion? with - | some u => return (attnRows?, some (addVecFixed u vUnion)) - | none => throw "missing attnUnion" - else - let ln1Union ← - match ln1Union? with - | some row => pure row - | none => throw "missing ln1Union" - let vOut := calcVOutIntervals ln1Union wvIntervals woIntervals bV - match attnUnion? with - | some u => return (attnRows?, some (addVecFixed u vOut)) - | none => throw "missing attnUnion" - let applyAttn - (label : String) - (rows : Array (Array Fixed10Interval)) - (useTight : Bool) - (attnRows? : Option (Array (Array Fixed10Interval))) - (attnUnion? : Option (Array Fixed10Interval)) - (attnBias : Array Fixed10Interval) : - ExceptT String IO (Array (Array Fixed10Interval)) := do - if useTight && causalPattern then - match attnRows? with - | some attnRows => - let attnRows := addVecFixedRows attnRows attnBias - logRowsWidth "attn_debug" s!"{label}:attn_rows" attnRows - let out := addRowsFixed rows attnRows - logRowsWidth "attn_debug" s!"{label}:out" out - return out - | none => throw "missing attnRows" - else - match attnUnion? 
with - | some attnUnion => - let attnUnion := addVecFixed attnUnion attnBias - logVecWidth "attn_debug" s!"{label}:attn_union" attnUnion - let out := addVecFixedRows rows attnUnion - logRowsWidth "attn_debug" s!"{label}:out" out - return out - | none => throw "missing attnUnion" - let applyMlp - (label : String) - (rows : Array (Array Fixed10Interval)) - (usePerRow : Bool) - (p : LayerNormParamsFixed) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) : - ExceptT String IO (Array (Array Fixed10Interval)) := do - let ln2Rows := calcLnRows rows p - logRowsWidth "ln2_debug" s!"{label}:ln2" ln2Rows - let geluTargetUnion : GeluDerivTarget := - if hdr.geluDerivTarget = .tanh then .exact else hdr.geluDerivTarget - if usePerRow then - match selectedRows? with - | none => - let wInIntervals := intervalsFromScaled wIn slack - let wOutIntervals := intervalsFromScaled wOut slack - let mlpRows := mlpRowsFromIntervals cfg geluTargetUnion - hdr.modelDim hdr.hiddenDim wInIntervals wOutIntervals bIn bOut ln2Rows - logRowsWidth "mlp_debug" s!"{label}:mlp_rows" mlpRows - let out := addRowsFixed rows mlpRows - logRowsWidth "mlp_debug" s!"{label}:out" out - return out - | some idxs => - let ln2Union := unionRowsFixed ln2Rows - let mlpUnion := mlpRowFromScaled cfg geluTargetUnion slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union - logVecWidth "mlp_debug" s!"{label}:mlp_union" mlpUnion - let idxsValid := idxs.filter (fun idx => idx < ln2Rows.size) - let useTasksHere := idxsValid.size > 4 - let out := - if useTasksHere then - Id.run do - let chunkSize : Nat := 8 - let numChunks : Nat := (idxsValid.size + chunkSize - 1) / chunkSize - let mut tasks : Array (Task (Array (Nat × Array Fixed10Interval))) := - Array.mkEmpty numChunks - let mut chunkIdx : Nat := 0 - while chunkIdx < numChunks do - let start := chunkIdx * chunkSize - let stop := min idxsValid.size (start + chunkSize) - tasks := tasks.push <| - Task.spawn (fun _ => - Id.run do - let mut outChunk : Array (Nat 
× Array Fixed10Interval) := - Array.mkEmpty (stop - start) - let mut i := start - while i < stop do - let idx := idxsValid[i]! - let mlpRow := mlpRowFromScaledNoTask cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut (ln2Rows[idx]!) - outChunk := outChunk.push (idx, mlpRow) - i := i + 1 - return outChunk) - chunkIdx := chunkIdx + 1 - let mut out := addVecFixedRows rows mlpUnion - for t in tasks do - for (idx, mlpRow) in t.get do - out := out.set! idx (addVecFixed (rows[idx]!) mlpRow) - return out - else - Id.run do - let mut out := addVecFixedRows rows mlpUnion - for idx in idxsValid do - let mlpRow := mlpRowFromScaledNoTask cfg hdr.geluDerivTarget slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut (ln2Rows[idx]!) - out := out.set! idx (addVecFixed (rows[idx]!) mlpRow) - return out - logRowsWidth "mlp_debug" s!"{label}:out" out - return out - else - let ln2Union := unionRowsFixed ln2Rows - let mlpOut := mlpRowFromScaled cfg geluTargetUnion slack - hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Union - logVecWidth "mlp_debug" s!"{label}:mlp_union" mlpOut - let out := addVecFixedRows rows mlpOut - logRowsWidth "mlp_debug" s!"{label}:out" out - return out - let awaitPattern - (pattern? : Option HeadBestMatchPatternCert) - (task? : Option (Task (Except IO.Error (Except String HeadBestMatchPatternCert)))) - (label : String) : - ExceptT String IO HeadBestMatchPatternCert := do - match pattern? with - | some cert => pure cert - | none => - match task? 
with - | none => throw label - | some task => - match task.get with - | .error e => throw (toString e) - | .ok (.error msg) => throw msg - | .ok (.ok cert) => pure cert - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals1 := residuals0 - let mut residuals2 := residuals0 - let mut residualsSame : Bool := true - let mut residualsV := residuals0 - let mut residualsSameV : Bool := true - let mut p1? : Option HeadBestMatchPatternCert := none - let mut p2? : Option HeadBestMatchPatternCert := none - let mut p1Task? : - Option (Task (Except IO.Error (Except String HeadBestMatchPatternCert))) := none - let mut p2Task? : - Option (Task (Except IO.Error (Except String HeadBestMatchPatternCert))) := none - let mut vlogit? : Option HeadValueLogitCert := none - for l in [:hdr.numLayers] do - let at1 := l = layer1 && p1?.isNone - let at2 := l = layer2 && p2?.isNone - let needUpdate1 := l < layer1 && p1?.isNone - let needUpdate2 := l < layer2 && p2?.isNone - let needUpdateV := needUpdate2 - let needRows1 := at1 || needUpdate1 - let needRows2 := at2 || needUpdate2 - let needRowsV := needRows2 - let ln1P := ln1Params.getD l defP - let tLn10? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let mut ln1RowsShared? : Option (Array (Array Fixed10Interval)) := none - if residualsSame && (needRows1 || needRows2) then - let rows := calcLnRows residuals1 ln1P - ln1RowsShared? := some rows - logRowsWidth "ln1_debug" s!"layer{l}:ln1_shared" rows - let mut ln1Rows1? : Option (Array (Array Fixed10Interval)) := none - let mut ln1Rows2? 
: Option (Array (Array Fixed10Interval)) := none - if needRows1 then - match ln1RowsShared? with - | some rows => ln1Rows1? := some rows - | none => - let rows := calcLnRows residuals1 ln1P - ln1Rows1? := some rows - logRowsWidth "ln1_debug" s!"layer{l}:ln1_1" rows - if needRows2 then - match ln1RowsShared? with - | some rows => ln1Rows2? := some rows - | none => - let rows := calcLnRows residuals2 ln1P - ln1Rows2? := some rows - logRowsWidth "ln1_debug" s!"layer{l}:ln1_2" rows - let mut ln1Rows1Exact? : Option (Array (Array Fixed10Interval)) := none - let mut ln1Rows2Exact? : Option (Array (Array Fixed10Interval)) := none - if at1 then - ln1Rows1Exact? := some (calcLnRowsExact residuals1 ln1P) - if at2 then - ln1Rows2Exact? := some (calcLnRowsExact residuals2 ln1P) - let mut ln1RowsV? : Option (Array (Array Fixed10Interval)) := none - if needRowsV then - if residualsSameV then - ln1RowsV? := ln1Rows2? - else - let rows := calcLnRows residualsV ln1P - ln1RowsV? := some rows - logRowsWidth "ln1_debug" s!"layer{l}:ln1_v" rows - if let some t0 := tLn10? 
then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:ln1 {dtMs}ms" - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - let useTight1 := needUpdate1 && tightLayers > 0 && layer1 ≤ l + tightLayers - let useTight2 := needUpdate2 && tightLayers > 0 && layer2 ≤ l + tightLayers - let usePerRow1 := - needUpdate1 && perRowPatternLayers > 0 && layer1 ≤ l + perRowPatternLayers - let usePerRow2 := - needUpdate2 && perRowPatternLayers > 0 && layer2 ≤ l + perRowPatternLayers - let useTightV := useTight2 - let usePerRowV := usePerRow2 - let usePatternTasks : Bool := perRowPatternLayers = 0 - let needTightenNow : Bool := l == layer1 && useTight2 && causalPattern - let skipAttnV := useTightV && causalPattern && seqLenEff < hdr.seqLen - let shareUpdateV := residualsSameV && needUpdateV && !skipAttnV - let shareUpdate := - residualsSame && needUpdate1 && needUpdate2 && - useTight1 = useTight2 && usePerRow1 = usePerRow2 - let zeroRow : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut ln1Union1? : Option (Array Fixed10Interval) := none - let mut ln1Union2? : Option (Array Fixed10Interval) := none - let mut groupRows1? : Option (Array (Array Fixed10Interval)) := none - let mut groupRows2? : Option (Array (Array Fixed10Interval)) := none - let mut attnRows1? : Option (Array (Array Fixed10Interval)) := none - let mut attnRows2? : Option (Array (Array Fixed10Interval)) := none - let mut attnUnion1? : Option (Array Fixed10Interval) := none - let mut attnUnion2? : Option (Array Fixed10Interval) := none - let mut ln1UnionV? : Option (Array Fixed10Interval) := none - let mut groupRowsV? : Option (Array (Array Fixed10Interval)) := none - let mut attnRowsV? : Option (Array (Array Fixed10Interval)) := none - let mut attnUnionV? : Option (Array Fixed10Interval) := none - let mut ln1UnionShared? 
: Option (Array Fixed10Interval) := none - let mut groupRowsShared? : Option (Array (Array Fixed10Interval)) := none - let mut attnRowsShared? : Option (Array (Array Fixed10Interval)) := none - let mut attnUnionShared? : Option (Array Fixed10Interval) := none - let ln1Rows1 := ln1Rows1?.getD #[] - let ln1Rows2 := ln1Rows2?.getD #[] - let ln1RowsV := ln1RowsV?.getD #[] - let ln1RowsShared := ln1RowsShared?.getD #[] - if shareUpdate then - if useTight1 then - if causalPattern then - attnRowsShared? := some (Array.replicate seqLenEff zeroRow) - else - groupRowsShared? := some (groupUnionRowsByToken ln1RowsShared tokens) - attnUnionShared? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) - else - ln1UnionShared? := some (unionRowsFixed ln1RowsShared) - attnUnionShared? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) - else - if needUpdate1 then - if useTight1 then - if causalPattern then - attnRows1? := some (Array.replicate seqLenEff zeroRow) - else - groupRows1? := some (groupUnionRowsByToken ln1Rows1 tokens) - attnUnion1? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) - else - ln1Union1? := some (unionRowsFixed ln1Rows1) - attnUnion1? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) - if needUpdate2 then - if useTight2 then - if causalPattern then - attnRows2? := some (Array.replicate seqLenEff zeroRow) - else - groupRows2? := some (groupUnionRowsByToken ln1Rows2 tokens) - attnUnion2? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) - else - ln1Union2? := some (unionRowsFixed ln1Rows2) - attnUnion2? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) - if needUpdateV && !shareUpdateV && !skipAttnV then - if useTightV then - if causalPattern then - attnRowsV? := some (Array.replicate seqLenEff zeroRow) - else - groupRowsV? := some (groupUnionRowsByToken ln1RowsV tokens) - attnUnionV? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) - else - ln1UnionV? 
:= some (unionRowsFixed ln1RowsV) - attnUnionV? := some (Array.replicate hdr.modelDim { lo := 0, hi := 0 }) - let needUpdate := needUpdate1 || needUpdate2 - let tHeads0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let mut qkReadMs : Nat := 0 - let mut vReadMs : Nat := 0 - let mut addAttnMs : Nat := 0 - let mut tightenMs : Nat := 0 - let mut tightenVOutMs : Nat := 0 - let mut tightenPrefixMs : Nat := 0 - let mut tightenRowMs : Nat := 0 - let mut tightenWaitMs : Nat := 0 - for hIdx in [:hdr.numHeads] do - let needValue := at2 && hIdx = head2 - let needV := needUpdate || needValue - let needQK := (at1 && hIdx = head1) || (at2 && hIdx = head2) - if needQK then - let tQK0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let wq ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bQ ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 - let wk ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bK ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 - if let some t0 := tQK0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - qkReadMs := qkReadMs + dtMs - let bQIntervals := intervalsFromScaled bQ slack - let bKIntervals := intervalsFromScaled bK slack - if at1 && hIdx = head1 then - let ln1Rows1Exact := ln1Rows1Exact?.getD ln1Rows1 - if needV && !needTightenNow then - let task ← - ExceptT.lift <| - IO.asTask - (timeIt s!"layer{layer1}:pattern" <| - bestMatchPattern - layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 - keyOffset1 - (useTasks := usePatternTasks)).run - p1Task? := some task - else - let p1 ← - timeIt s!"layer{layer1}:pattern" <| - bestMatchPattern - layer1 head1 ln1Rows1Exact wq wk bQIntervals bKIntervals offset1 keyOffset1 - p1? 
:= some p1 - if at2 && hIdx = head2 then - let ln1Rows2Exact := ln1Rows2Exact?.getD ln1Rows2 - if needV then - let task ← - ExceptT.lift <| - IO.asTask - (timeIt s!"layer{layer2}:pattern" <| - bestMatchPattern - layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 - keyOffset2 - (useTasks := usePatternTasks)).run - p2Task? := some task - else - let p2 ← - timeIt s!"layer{layer2}:pattern" <| - bestMatchPattern - layer2 head2 ln1Rows2Exact wq wk bQIntervals bKIntervals offset2 keyOffset2 - p2? := some p2 - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if needV then - let tV0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let wv ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk <| readScaledFloatArray h hdr.headDim scalePow10 - let wo ← - ExceptT.mk <| readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - if let some t0 := tV0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - vReadMs := vReadMs + dtMs - let bVIntervals := intervalsFromScaled bV slack - let wvIntervals := intervalsFromScaled wv slack - let woIntervals := intervalsFromScaled wo slack - if needUpdate then - if shareUpdate then - let tAdd0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let (attnRows', attnUnion') ← - addAttn useTight1 ln1RowsShared ln1UnionShared? groupRowsShared? - attnRowsShared? attnUnionShared? wvIntervals woIntervals bVIntervals - if let some t0 := tAdd0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - addAttnMs := addAttnMs + dtMs - attnRowsShared? := attnRows' - attnUnionShared? := attnUnion' - else - if needUpdate1 then - let tAdd0? 
← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let (attnRows', attnUnion') ← - addAttn useTight1 ln1Rows1 ln1Union1? groupRows1? - attnRows1? attnUnion1? wvIntervals woIntervals bVIntervals - if let some t0 := tAdd0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - addAttnMs := addAttnMs + dtMs - attnRows1? := attnRows' - attnUnion1? := attnUnion' - if needUpdate2 then - if l == layer1 && hIdx == head1 && useTight2 && causalPattern then - let tTight0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let tWait0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let p1 ← - awaitPattern p1? p1Task? "missing best-match pattern cert for tightening" - if let some t0 := tWait0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - tightenWaitMs := tightenWaitMs + dtMs - p1? := some p1 - let tVOut0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let vOutRows := calcVOutRowsIntervals ln1Rows2 wvIntervals woIntervals bVIntervals - if let some t0 := tVOut0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - tightenVOutMs := tightenVOutMs + dtMs - let tPrefix0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let mut headRows := prefixUnionRowsFixed vOutRows - if let some t0 := tPrefix0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - tightenPrefixMs := tightenPrefixMs + dtMs - let baseRow := headRows[queryPos]! - let tRow0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let tightRow ← - tightenQueryRowLower baseRow vOutRows p1.bestMatchWeightLowerBound offset1 - keyOffset1 - if let some t0 := tRow0? 
then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - tightenRowMs := tightenRowMs + dtMs - headRows := headRows.set! queryPos tightRow - if let some t0 := tTight0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - tightenMs := tightenMs + dtMs - match attnRows2? with - | some rows => attnRows2? := some (addRowsFixed rows headRows) - | none => throw "missing attnRows" - else - let tAdd0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let (attnRows', attnUnion') ← - addAttn useTight2 ln1Rows2 ln1Union2? groupRows2? - attnRows2? attnUnion2? wvIntervals woIntervals bVIntervals - if let some t0 := tAdd0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - addAttnMs := addAttnMs + dtMs - attnRows2? := attnRows' - attnUnion2? := attnUnion' - if needUpdateV && !shareUpdateV && !skipAttnV then - let tAdd0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let (attnRows', attnUnion') ← - addAttn useTightV ln1RowsV ln1UnionV? groupRowsV? - attnRowsV? attnUnionV? wvIntervals woIntervals bVIntervals - if let some t0 := tAdd0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - addAttnMs := addAttnMs + dtMs - attnRowsV? := attnRows' - attnUnionV? := attnUnion' - if needValue then - let p2 ← - awaitPattern p2? p2Task? "missing best-match pattern cert for value bound" - p2? := some p2 - let vlogit ← - timeIt s!"layer{layer2}:value_logit" <| - valueLogit ln1RowsV p2.bestMatchWeightLowerBound wvIntervals woIntervals - bVIntervals offset2 keyOffset2 - vlogit? := some vlogit - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - if let some t0 := tHeads0? 
then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:heads {dtMs}ms" - if timing then - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:qk_read {qkReadMs}ms" - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:v_read {vReadMs}ms" - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:add_attn {addAttnMs}ms" - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten {tightenMs}ms" - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten_vout {tightenVOutMs}ms" - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten_prefix {tightenPrefixMs}ms" - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten_row {tightenRowMs}ms" - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:tighten_wait {tightenWaitMs}ms" - if p1?.isSome && p2?.isSome && vlogit?.isSome && !(needUpdate1 || needUpdate2) then - match p1?, p2?, vlogit? with - | some p1, some p2, some vlogit => - let cert : InductionHeadBestMatchSoundCert := { - layer1Pattern := p1 - layer2Pattern := p2 - layer2Value := vlogit.value - layer2Logit? := vlogit.logit? - deltaLowerBound := vlogit.value.outputCoordLowerBound - } - if cert.check then - return cert - throw "induction head certificate failed internal consistency checks" - | _, _, _ => throw "induction head certificate failed internal consistency checks" - if needUpdate1 || needUpdate2 then - let tAttn0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - if shareUpdate then - residuals1 ← - applyAttn s!"layer{l}:attn_shared" residuals1 useTight1 attnRowsShared? - attnUnionShared? attnBias - residuals2 := residuals1 - else - if needUpdate1 then - residuals1 ← - applyAttn s!"layer{l}:attn_1" residuals1 useTight1 attnRows1? attnUnion1? - attnBias - if needUpdate2 then - residuals2 ← - applyAttn s!"layer{l}:attn_2" residuals2 useTight2 attnRows2? attnUnion2? 
- attnBias - if needUpdateV && !shareUpdateV && !skipAttnV then - residualsV ← - applyAttn s!"layer{l}:attn_v" residualsV useTightV attnRowsV? attnUnionV? - attnBias - if let some t0 := tAttn0? then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:attn_update {dtMs}ms" - let tMlp0? ← - if timing then - let t0 ← ExceptT.lift IO.monoNanosNow - pure (some t0) - else - pure none - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let ln2P := ln2Params.getD l defP - if shareUpdate then - residuals1 ← - applyMlp s!"layer{l}:mlp_shared" residuals1 usePerRow1 ln2P wIn wOut bIn bOut - residuals2 := residuals1 - else - if needUpdate1 then - residuals1 ← - applyMlp s!"layer{l}:mlp_1" residuals1 usePerRow1 ln2P wIn wOut bIn bOut - if needUpdate2 then - residuals2 ← - applyMlp s!"layer{l}:mlp_2" residuals2 usePerRow2 ln2P wIn wOut bIn bOut - if needUpdateV then - if shareUpdateV then - residualsV := residuals2 - else - residualsV ← - applyMlp s!"layer{l}:mlp_v" residualsV usePerRowV ln2P wIn wOut bIn bOut - if shareUpdate then - residualsSame := true - else if needUpdate1 && needUpdate2 then - residualsSame := false - if needUpdateV then - if shareUpdateV then - residualsSameV := true - else - residualsSameV := false - if let some t0 := tMlp0? 
then - let t1 ← ExceptT.lift IO.monoNanosNow - let dtMs := (t1 - t0) / 1000000 - ExceptT.lift <| IO.eprintln s!"timing:layer{l}:mlp_update {dtMs}ms" - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - if l == layer1 && p1?.isNone then - let p1 ← - awaitPattern p1? p1Task? "missing best-match pattern cert for layer1" - p1? := some p1 - if l == layer2 && p2?.isNone then - let p2 ← - awaitPattern p2? p2Task? "missing best-match pattern cert for layer2" - p2? := some p2 - match p1?, p2?, vlogit? with - | some p1, some p2, some vlogit => - let cert : InductionHeadBestMatchSoundCert := { - layer1Pattern := p1 - layer2Pattern := p2 - layer2Value := vlogit.value - layer2Logit? := vlogit.logit? - deltaLowerBound := vlogit.value.outputCoordLowerBound - } - if cert.check then - return cert - throw "induction head certificate failed internal consistency checks" - | _, _, _ => - throw "target layer not reached" - action.run - - -/-- Compute a combined sound certificate for an induction-style head pair (best-match, -binary only). -/ -def certifyInductionSoundBestMatch - (path : System.FilePath) - (layer1 head1 layer2 head2 coord : Nat) - (queryPos? : Option Nat := none) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (offset1 : Int := -1) - (offset2 : Int := -1) - (keyOffset1 : Int := 0) - (keyOffset2 : Int := 0) - (maxSeqLen : Nat := 256) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (useAffine : Bool := false) - (iterTighten : Bool := false) - (targetToken? : Option Nat := none) - (negativeToken? 
: Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) - (causalPattern : Bool := true) : - IO (Except String InductionHeadBestMatchSoundCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let action : ExceptT String IO InductionHeadBestMatchSoundCert := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let timingEnabled ← ExceptT.lift <| IO.getEnv "NFP_TIMING" - let timing : Bool := timingEnabled.isSome - let timeIt {α : Type} (label : String) (action : ExceptT String IO α) : - ExceptT String IO α := do - if !timing then - action - else - let t0 ← ExceptT.lift IO.monoNanosNow - let r ← action - let t1 ← ExceptT.lift IO.monoNanosNow - let dtNs := t1 - t0 - let dtMs := dtNs / 1000000 - ExceptT.lift <| IO.eprintln s!"timing:{label} {dtMs}ms" - return r - let loadSharedAndDirection (scalePow10 : Nat) : - ExceptT String IO (SharedBinaryInputs × Option (Thunk (Array Fixed10Interval))) := do - let sharedTask ← - ExceptT.lift <| IO.asTask (loadSharedBinaryInputs path inputPath inputDelta scalePow10) - let directionTask? ← - match targetToken?, negativeToken? with - | none, none => pure none - | some targetToken, some negativeToken => - let task ← - ExceptT.lift <| - IO.asTask - (readLogitDiffDirectionBinary - path targetToken negativeToken scalePow10 fixedUlpSlack) - pure (some task) - | _, _ => - throw "use both target and negative tokens (or neither)" - let shared ← - match sharedTask.get with - | .error e => throw (toString e) - | .ok (.error msg) => throw msg - | .ok (.ok v) => pure v - let direction? ← - match directionTask? 
with - | none => pure none - | some task => - match task.get with - | .error e => throw (toString e) - | .ok (.error msg) => throw msg - | .ok (.ok (hdrDir, dir)) => - if hdrDir.modelDim ≠ shared.hdr.modelDim then - throw "unembedding model_dim mismatch" - pure (some (Thunk.mk (fun () => dir))) - return (shared, direction?) - let computeBestAtScale (scalePow10 : Nat) - (configs : Array (Bool × Nat × Nat)) : - ExceptT String IO (Rat × InductionHeadBestMatchSoundCert) := do - let (shared, direction?) ← loadSharedAndDirection scalePow10 - let queryPos : Nat := - match queryPos? with - | some q => q - | none => - if shared.hdr.seqLen = 0 then 0 else shared.hdr.seqLen - 1 - if queryPos ≥ shared.hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let prefixCache := mkSharedBinaryPrefix shared queryPos causalPattern - let computeCert (useTight : Bool) (tightLayers perRowLayers : Nat) : - ExceptT String IO InductionHeadBestMatchSoundCert := do - let label := - s!"scale={scalePow10} tight={useTight} tl={tightLayers} pr={perRowLayers}" - let cert ← - timeIt (s!"{label}:pair") <| - ExceptT.mk <| - certifyInductionSoundBestMatchLocalBinaryPair - path layer1 head1 layer2 head2 coord queryPos eps soundnessBits inputPath - inputDelta offset1 offset2 keyOffset1 keyOffset2 maxSeqLen scalePow10 useTight - tightLayers perRowLayers useAffine softmaxExpEffort causalPattern - (shared? := some shared) (prefix? := some prefixCache) - (targetToken? := targetToken?) (negativeToken? := negativeToken?) - (direction? := direction?) - return cert - let metricOf (cert : InductionHeadBestMatchSoundCert) : Rat := - match cert.layer2Logit? with - | some logit => logit.logitDiffLowerBound - | none => cert.deltaLowerBound - -- Avoid nested task pools when per-row MLP already spawns tasks. 
- let parallelConfigs : Bool := - configs.size > 1 && configs.all (fun (_, _, perRowLayers) => perRowLayers = 0) - let mut best : Option (Rat × InductionHeadBestMatchSoundCert) := none - if parallelConfigs then - let tasks ← - ExceptT.lift <| - configs.mapM fun (useTight, tightLayers, perRowLayers) => - IO.asTask (computeCert useTight tightLayers perRowLayers).run - let results := tasks.map (fun t => t.get) - for i in [:configs.size] do - let res := results[i]! - match res with - | .error e => throw (toString e) - | .ok (.error msg) => throw msg - | .ok (.ok cert) => - let metric := metricOf cert - best := - match best with - | none => some (metric, cert) - | some (bestMetric, bestCert) => - if metric > bestMetric then - some (metric, cert) - else - some (bestMetric, bestCert) - else - for i in [:configs.size] do - let (useTight, tightLayers, perRowLayers) := configs[i]! - let cert ← computeCert useTight tightLayers perRowLayers - let metric := metricOf cert - best := - match best with - | none => some (metric, cert) - | some (bestMetric, bestCert) => - if metric > bestMetric then - some (metric, cert) - else - some (bestMetric, bestCert) - match best with - | none => throw "no induction certs computed" - | some bestPair => return bestPair - let computeBestAtScaleOrdered (scalePow10 : Nat) - (configs : Array (Bool × Nat × Nat)) - (stopAtPositive : Bool) : - ExceptT String IO (Rat × InductionHeadBestMatchSoundCert) := do - let (shared, direction?) ← loadSharedAndDirection scalePow10 - let queryPos : Nat := - match queryPos? 
with - | some q => q - | none => - if shared.hdr.seqLen = 0 then 0 else shared.hdr.seqLen - 1 - if queryPos ≥ shared.hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let prefixCache := mkSharedBinaryPrefix shared queryPos causalPattern - let computeCert (useTight : Bool) (tightLayers perRowLayers : Nat) : - ExceptT String IO InductionHeadBestMatchSoundCert := do - let label := - s!"scale={scalePow10} tight={useTight} tl={tightLayers} pr={perRowLayers}" - let cert ← - timeIt (s!"{label}:pair") <| - ExceptT.mk <| - certifyInductionSoundBestMatchLocalBinaryPair - path layer1 head1 layer2 head2 coord queryPos eps soundnessBits inputPath - inputDelta offset1 offset2 keyOffset1 keyOffset2 maxSeqLen scalePow10 useTight - tightLayers perRowLayers useAffine softmaxExpEffort causalPattern - (shared? := some shared) (prefix? := some prefixCache) - (targetToken? := targetToken?) (negativeToken? := negativeToken?) - (direction? := direction?) - return cert - let metricOf (cert : InductionHeadBestMatchSoundCert) : Rat := - match cert.layer2Logit? with - | some logit => logit.logitDiffLowerBound - | none => cert.deltaLowerBound - let mut best : Option (Rat × InductionHeadBestMatchSoundCert) := none - for i in [:configs.size] do - let (useTight, tightLayers, perRowLayers) := configs[i]! 
- let cert ← computeCert useTight tightLayers perRowLayers - let metric := metricOf cert - if stopAtPositive && metric > 0 then - return (metric, cert) - best := - match best with - | none => some (metric, cert) - | some (bestMetric, bestCert) => - if metric > bestMetric then - some (metric, cert) - else - some (bestMetric, bestCert) - match best with - | none => throw "no induction certs computed" - | some bestPair => return bestPair - let maxLayer := Nat.max layer1 layer2 - let tightFull := Nat.max 1 maxLayer - let perRowFull := maxLayer - let normalizeConfig (useTight : Bool) (tightLayers perRowLayers : Nat) : - Bool × Nat × Nat := - if useTight then - (true, Nat.max 1 tightLayers, perRowLayers) - else - (false, 0, perRowLayers) - let pushUnique (configs : Array (Bool × Nat × Nat)) (cfg : Bool × Nat × Nat) : - Array (Bool × Nat × Nat) := - if configs.any (fun c => c == cfg) then configs else configs.push cfg - let baseCfg : Bool × Nat × Nat := - normalizeConfig tightPattern tightPatternLayers perRowPatternLayers - if !iterTighten then - let (_, cert) ← computeBestAtScale scalePow10 #[baseCfg] - return cert - else - let mut configs : Array (Bool × Nat × Nat) := #[baseCfg] - let needTightFull := (!tightPattern) || tightPatternLayers < tightFull - if needTightFull then - configs := pushUnique configs (normalizeConfig true tightFull perRowPatternLayers) - if perRowPatternLayers < perRowFull then - configs := pushUnique configs (normalizeConfig true tightFull perRowFull) - let scales : List Nat := [scalePow10, scalePow10 + 1, scalePow10 + 2] - let mut bestOverall : Option (Rat × InductionHeadBestMatchSoundCert) := none - for scale in scales do - let (metric, cert) ← - computeBestAtScaleOrdered scale configs (stopAtPositive := true) - bestOverall := - match bestOverall with - | none => some (metric, cert) - | some (bestMetric, bestCert) => - if metric > bestMetric then - some (metric, cert) - else - some (bestMetric, bestCert) - if metric > 0 then - return cert - 
match bestOverall with - | none => throw "no induction certs computed" - | some (_, cert) => return cert - else - throw "induction sound cert requires NFP_BINARY_V1" - action.run - -/-! ### Specs -/ - -theorem defaultBinaryScalePow10_spec_io : - defaultBinaryScalePow10 = defaultBinaryScalePow10 := rfl - -theorem maxAbsOfVector_spec_io : - maxAbsOfVector = maxAbsOfVector := rfl - -theorem certifyHeadBoundsBinary_spec_io : - certifyHeadBoundsBinary = certifyHeadBoundsBinary := rfl - -theorem certifyModelFileGlobalBinary_spec_io : - certifyModelFileGlobalBinary = certifyModelFileGlobalBinary := rfl - -theorem addVecIntervals_spec_io : - addVecIntervals = addVecIntervals := rfl - -theorem addConstVec_spec_io : - addConstVec = addConstVec := rfl - -theorem unionVecIntervals_spec_io : - unionVecIntervals = unionVecIntervals := rfl - -theorem zeroIntervals_spec_io : - zeroIntervals = zeroIntervals := rfl - -theorem unionRows_spec_io : - unionRows = unionRows := rfl - -theorem layerNormRowApprox_spec_io : - layerNormRowApprox = layerNormRowApprox := rfl - -theorem minVarAcrossRows_spec_io : - minVarAcrossRows = minVarAcrossRows := rfl - -theorem findLineIdxFrom_spec_io : - findLineIdxFrom = findLineIdxFrom := rfl - -theorem skipUntil_spec_io : - skipUntil = skipUntil := rfl - -theorem skipBlankLines_spec_io : - skipBlankLines = skipBlankLines := rfl - -theorem countWsTokens_spec_io : - countWsTokens = countWsTokens := rfl - -theorem consumeTokensSkipFast_spec_io : - consumeTokensSkipFast = consumeTokensSkipFast := rfl - -theorem consumeMatrixSkip_spec_io : - consumeMatrixSkip = consumeMatrixSkip := rfl - -theorem consumeMatrixSkipFast_spec_io : - consumeMatrixSkipFast = consumeMatrixSkipFast := rfl - -theorem consumeVectorSkipFast_spec_io : - consumeVectorSkipFast = consumeVectorSkipFast := rfl - -theorem consumeMatrixMulAndNormInf_spec_io : - consumeMatrixMulAndNormInf = consumeMatrixMulAndNormInf := rfl - -theorem certifyModelFileGlobal_spec_io : - certifyModelFileGlobal 
= certifyModelFileGlobal := rfl - -theorem loadEmbeddingsIntervals_spec_io : - loadEmbeddingsIntervals = loadEmbeddingsIntervals := rfl - -theorem intervalsFromScaled_spec_io : - intervalsFromScaled = intervalsFromScaled := rfl - -theorem collectLayerNormParams_spec_io : - collectLayerNormParams = collectLayerNormParams := rfl - -theorem collectLayerNormParamsBinary_spec_io : - collectLayerNormParamsBinary = collectLayerNormParamsBinary := rfl - -theorem defaultFixedScalePow10_spec_io : - defaultFixedScalePow10 = defaultFixedScalePow10 := rfl - -theorem fixedUlpSlack_spec_io : - fixedUlpSlack = fixedUlpSlack := rfl - -theorem scaleCfgOfPow10_spec_io : - scaleCfgOfPow10 = scaleCfgOfPow10 := rfl - -theorem ratCeilMulNat_spec_io : - ratCeilMulNat = ratCeilMulNat := rfl - -theorem fixedMeanInterval_spec_io : - fixedMeanInterval = fixedMeanInterval := rfl - -theorem fixedVarianceLowerBoundRange_spec_io : - fixedVarianceLowerBoundRange = fixedVarianceLowerBoundRange := rfl - -theorem fixedLayerNormRowApprox_spec_io : - fixedLayerNormRowApprox = fixedLayerNormRowApprox := rfl - -theorem readVecIntervals_spec_io : - readVecIntervals = readVecIntervals := rfl - -theorem readVecIntervalsBinary_spec_io : - readVecIntervalsBinary = readVecIntervalsBinary := rfl - -theorem matMulIntervalsFromScaled_spec_io : - matMulIntervalsFromScaled = matMulIntervalsFromScaled := rfl - -theorem fixedDotInterval_spec_io : - fixedDotInterval = fixedDotInterval := rfl - -theorem maxAbsVecFixed_spec_io : - maxAbsVecFixed = maxAbsVecFixed := rfl - -theorem addVecFixed_spec_io : - addVecFixed = addVecFixed := rfl - -theorem addVecFixedRows_spec_io : - addVecFixedRows = addVecFixedRows := rfl - -theorem addRowsFixed_spec_io : - addRowsFixed = addRowsFixed := rfl - -theorem mlpRowFromScaled_spec_io : - mlpRowFromScaled = mlpRowFromScaled := rfl - -theorem mlpRowsFromScaled_spec_io : - mlpRowsFromScaled = mlpRowsFromScaled := rfl - -theorem groupUnionRowsByToken_spec_io : - groupUnionRowsByToken = 
groupUnionRowsByToken := rfl - -theorem unionRowsFixed_spec_io : - unionRowsFixed = unionRowsFixed := rfl - -theorem consumeMatrixMulAndNormInfFixed_spec_io : - consumeMatrixMulAndNormInfFixed = consumeMatrixMulAndNormInfFixed := rfl - -theorem consumeMatrixMulAndNormInfFixedBinary_spec_io : - consumeMatrixMulAndNormInfFixedBinary = consumeMatrixMulAndNormInfFixedBinary := rfl - -theorem loadEmbeddingsUnionFixed_spec_io : - loadEmbeddingsUnionFixed = loadEmbeddingsUnionFixed := rfl - -theorem loadEmbeddingsUnionFixedBinary_spec_io : - loadEmbeddingsUnionFixedBinary = loadEmbeddingsUnionFixedBinary := rfl - -theorem loadEmbeddingsIntervalsBinary_spec_io : - loadEmbeddingsIntervalsBinary = loadEmbeddingsIntervalsBinary := rfl - -theorem loadTokensBinary_spec_io : - loadTokensBinary = loadTokensBinary := rfl - -theorem skipToUnembeddingBinary_spec_io : - skipToUnembeddingBinary = skipToUnembeddingBinary := rfl - -theorem certifyHeadValueLowerBoundLocalBinaryAt_spec_io : - certifyHeadValueLowerBoundLocalBinaryAt = certifyHeadValueLowerBoundLocalBinaryAt := rfl - -theorem readUnembeddingColumnsBinary_spec_io : - readUnembeddingColumnsBinary = readUnembeddingColumnsBinary := rfl - -theorem readLogitDiffDirectionBinary_spec_io : - readLogitDiffDirectionBinary = readLogitDiffDirectionBinary := rfl - -theorem certifyHeadLogitDiffLowerBoundLocalBinaryAt_spec_io : - certifyHeadLogitDiffLowerBoundLocalBinaryAt = - certifyHeadLogitDiffLowerBoundLocalBinaryAt := rfl - -theorem ensureSoundCache_spec_io : - ensureSoundCache = ensureSoundCache := rfl - -theorem certifyModelFileLocalText_spec_io : - certifyModelFileLocalText = certifyModelFileLocalText := rfl - -theorem certifyModelFileLocal_spec_io : - certifyModelFileLocal = certifyModelFileLocal := rfl - -theorem certifyModelFileLocalBinary_spec_io : - certifyModelFileLocalBinary = certifyModelFileLocalBinary := rfl - -theorem certifyHeadBoundsLocalBinary_spec_io : - certifyHeadBoundsLocalBinary = certifyHeadBoundsLocalBinary := 
rfl - -theorem certifyHeadPatternLocalBinary_spec_io : - certifyHeadPatternLocalBinary = certifyHeadPatternLocalBinary := rfl - -theorem certifyHeadPatternBestMatchLocalBinary_spec_io : - certifyHeadPatternBestMatchLocalBinary = certifyHeadPatternBestMatchLocalBinary := rfl - -theorem certifyHeadPatternBestMatchLocalBinarySweep_spec_io : - certifyHeadPatternBestMatchLocalBinarySweep = - certifyHeadPatternBestMatchLocalBinarySweep := rfl - -theorem certifyHeadValueLowerBoundLocalBinary_spec_io : - certifyHeadValueLowerBoundLocalBinary = certifyHeadValueLowerBoundLocalBinary := rfl - -theorem certifyHeadLogitDiffLowerBoundLocalBinary_spec_io : - certifyHeadLogitDiffLowerBoundLocalBinary = certifyHeadLogitDiffLowerBoundLocalBinary := rfl - -theorem certifyModelFile_spec_io : - certifyModelFile = certifyModelFile := rfl - -theorem certifyHeadBounds_spec_io : - certifyHeadBounds = certifyHeadBounds := rfl - -theorem certifyHeadBoundsLocal_spec_io : - certifyHeadBoundsLocal = certifyHeadBoundsLocal := rfl - -theorem certifyHeadPatternLocal_spec_io : - certifyHeadPatternLocal = certifyHeadPatternLocal := rfl - -theorem certifyHeadPatternBestMatchLocal_spec_io : - certifyHeadPatternBestMatchLocal = certifyHeadPatternBestMatchLocal := rfl - -theorem certifyHeadPatternBestMatchLocalSweep_spec_io : - certifyHeadPatternBestMatchLocalSweep = certifyHeadPatternBestMatchLocalSweep := rfl - -theorem certifyLayerBestMatchMarginLocal_spec_io : - certifyLayerBestMatchMarginLocal = certifyLayerBestMatchMarginLocal := rfl - -theorem certifyHeadValueLowerBoundLocal_spec_io : - certifyHeadValueLowerBoundLocal = certifyHeadValueLowerBoundLocal := rfl - -theorem certifyHeadLogitDiffLowerBoundLocal_spec_io : - certifyHeadLogitDiffLowerBoundLocal = certifyHeadLogitDiffLowerBoundLocal := rfl - -theorem certifyInductionSound_spec_io : - certifyInductionSound = certifyInductionSound := rfl - -theorem certifyInductionSoundBestMatch_spec_io : - certifyInductionSoundBestMatch = 
certifyInductionSoundBestMatch := rfl - -end Nfp.Untrusted.SoundCompute diff --git a/Legacy/Nfp/Verification.lean b/Legacy/Nfp/Verification.lean deleted file mode 100644 index d4d56af..0000000 --- a/Legacy/Nfp/Verification.lean +++ /dev/null @@ -1,399 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Nfp.Discovery - -/-! -# Causal Verification (Head Ablation) - -This module implements **causal verification** of candidate circuits using an executable -forward pass with **head ablation** (a.k.a. zero-ablation). - -The core API is `verifyCircuit`, which: -- selects **energy-matched control heads** (same layer, closest output norm), -- checks runtime **axioms** (baseline competence, control independence, energy equivalence), -- and (only if axioms hold) measures **logit-difference impact** under ablation. - -This is intended to operationalize interchange-intervention style causality checks in a -fully executable setting. --/ - -namespace Nfp - -/-! ## Head references -/ - -/-- A reference to a specific attention head `(layer, head)`. -/ -structure HeadRef where - layerIdx : Nat - headIdx : Nat -deriving BEq, DecidableEq - -namespace HeadRef - -instance : Inhabited HeadRef := ⟨{ layerIdx := 0, headIdx := 0 }⟩ - -def toString (h : HeadRef) : String := - s!"L{h.layerIdx}H{h.headIdx}" - -instance : ToString HeadRef := ⟨toString⟩ - -def toComponentId (h : HeadRef) : ComponentId := - .head h.layerIdx h.headIdx - -end HeadRef - -/-! ## Induction target utilities -/ - -/-- Derive the **induction target token** from the model's `inputTokens`, if available. - -Let `T` be the token sequence and `t_curr = T[last]`. If `t_curr` appeared previously at index `k`, -the induction target is `t_tgt = T[k+1]` (the successor of the previous match). - -Returns `none` if there is no prior repetition of the last token or no token history. 
-/ -def inductionTargetTokenFromHistory (model : ConcreteModel) : Option Nat := do - let tokens ← model.inputTokens - if tokens.size = 0 then none else - let lastIdx := tokens.size - 1 - let tCurr := tokens[lastIdx]! - let mut foundIdx : Option Nat := none - let mut idx := lastIdx - while idx > 0 && foundIdx.isNone do - idx := idx - 1 - if tokens[idx]! = tCurr then - foundIdx := some idx - let k ← foundIdx - some (tokens[k + 1]!) - -/-! ## Verification configuration and results -/ - -/-- Runtime parameters for causal verification via head ablation. -/ -structure VerificationConfig where - /-- Competence threshold ε for baseline logit-difference Δ. -/ - competenceEpsilon : Float := 1e-3 - /-- Relative tolerance δ for energy matching: |‖out_cand‖ - ‖out_ctrl‖| < δ · ‖out_cand‖. -/ - energyRelTol : Float := 0.05 - /-- Absolute fallback tolerance for tiny-norm heads. -/ - energyAbsTol : Float := 1e-6 - /-- Precision (bits) for dyadic sqrt bounds in SOUND-mode checks. -/ - soundnessBits : Nat := 20 - /-- Whether to run attention causally (autoregressive). -/ - causal : Bool := true - -/-- The outcome of checking the runtime axioms required for causal interpretation. -/ -structure AxiomStatus where - baselineCompetence : Bool - controlIndependence : Bool - energyEquivalence : Bool - /-- Human-readable failures (empty iff all axioms hold). -/ - failures : Array String := #[] - -namespace AxiomStatus - -def verified (s : AxiomStatus) : Bool := - s.failures.isEmpty && s.baselineCompetence && s.controlIndependence && s.energyEquivalence - -end AxiomStatus - -/-- A single verification row for a candidate circuit. -/ -structure CircuitVerificationRow where - /-- Candidate circuit heads (ablated together). -/ - candidateHeads : Array HeadRef - /-- Selected control heads (one per candidate head, same layers). -/ - controlHeads : Option (Array HeadRef) - /-- Baseline logit-difference Δ = logit(t_tgt) - logit(t_neg). 
-/ - baseDelta : Float - /-- Ablated Δ for the candidate circuit, if axioms verified. -/ - ablatedDelta : Option Float - /-- Candidate causal impact = Δ_base - Δ_ablated, if axioms verified. -/ - impact : Option Float - /-- Relative score = impact / Δ_base (only defined if baseline competence holds). -/ - relScore : Option Float - /-- Control impact = Δ_base - Δ_ctrlAblated, if axioms verified. -/ - controlImpact : Option Float - /-- Axiom status (checked *before* computing impact scores). -/ - axioms : AxiomStatus - -namespace CircuitVerificationRow - -def candidateLabel (r : CircuitVerificationRow) : String := - if r.candidateHeads.size = 2 then - let h1 := r.candidateHeads[0]! - let h2 := r.candidateHeads[1]! - s!"{h1} -> {h2}" - else - String.intercalate "," (r.candidateHeads.toList.map (fun h => toString h)) - -end CircuitVerificationRow - -/-! ## Low-level helpers: logits, controls, and head ablation -/ - -private def logitAt (residual : ConcreteMatrix) (pos : Nat) - (W_U : ConcreteMatrix) (token : Nat) : Option Float := - if residual.numCols = W_U.numRows ∧ pos < residual.numRows ∧ token < W_U.numCols then - some <| Id.run do - let d := residual.numCols - let vocab := W_U.numCols - let rowBase := pos * d - let mut acc : Float := 0.0 - for k in [:d] do - -- SAFETY: `pos < residual.numRows` and `k < d = residual.numCols`. - let x := residual.data[rowBase + k]! - -- SAFETY: `k < W_U.numRows` and `token < vocab = W_U.numCols`. - let w := W_U.data[k * vocab + token]! 
- acc := acc + x * w - return acc - else none - -private def deltaAt (residual : ConcreteMatrix) (pos : Nat) - (W_U : ConcreteMatrix) (targetToken negativeToken : Nat) : Float := - let targetLogit := (logitAt residual pos W_U targetToken).getD 0.0 - let negLogit := (logitAt residual pos W_U negativeToken).getD 0.0 - targetLogit - negLogit - -private def topNonTargetToken (residual : ConcreteMatrix) (pos : Nat) - (W_U : ConcreteMatrix) (targetToken : Nat) : Option (Nat × Float) := Id.run do - if residual.numCols = W_U.numRows ∧ pos < residual.numRows ∧ - targetToken < W_U.numCols ∧ W_U.numCols ≥ 2 then - let d := residual.numCols - let vocab := W_U.numCols - let rowBase := pos * d - let mut bestTok : Nat := 0 - let mut bestLogit : Float := (-Float.inf) - let mut found : Bool := false - for tok in [:vocab] do - if tok ≠ targetToken then - found := true - let mut acc : Float := 0.0 - for k in [:d] do - -- SAFETY: `pos < residual.numRows` and `k < d = residual.numCols`. - let x := residual.data[rowBase + k]! - -- SAFETY: `k < W_U.numRows` and `tok < vocab = W_U.numCols`. - let w := W_U.data[k * vocab + tok]! - acc := acc + x * w - if acc > bestLogit then - bestTok := tok - bestLogit := acc - if found then - return some (bestTok, bestLogit) - else return none - else - return none - -private def fullCircuit (model : ConcreteModel) : ConcreteCircuit := - let headsPerLayer := - Array.ofFn (fun l : Fin model.numLayers => (model.layers.getD l.1 #[]).size) - let neuronsPerLayer := - Array.ofFn (fun l : Fin model.numLayers => model.numNeuronsAtLayer l.1) - ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - -private def runForwardAblatingHeads (model : ConcreteModel) (heads : Array HeadRef) - (causal : Bool := true) : ForwardPassResult := - let base := fullCircuit model - let circuit := heads.foldl (fun c h => c.removeComponent h.toComponentId) base - model.runAblatedForward circuit causal - -private def headOutputNorm? 
(fwd : ForwardPassResult) (h : HeadRef) : Option Float := - if hl : h.layerIdx < fwd.attnOutputs.size then - let layerOut := fwd.attnOutputs[h.layerIdx] - if hh : h.headIdx < layerOut.size then - some ((layerOut[h.headIdx]'hh).frobeniusNorm) - else none - else none - -private structure ControlSelection where - candidate : HeadRef - control : HeadRef - candNorm : Float - ctrlNorm : Float - absDiff : Float - -private def selectEnergyMatchedControl (fwd : ForwardPassResult) - (cand : HeadRef) (exclude : Array HeadRef) : Option ControlSelection := Id.run do - match headOutputNorm? fwd cand with - | none => none - | some candNorm => - if hl : cand.layerIdx < fwd.attnOutputs.size then - let layerOut := fwd.attnOutputs[cand.layerIdx] - let best : Option (HeadRef × Float × Float) := Id.run do - let mut best : Option (HeadRef × Float × Float) := none - for hIdx in [:layerOut.size] do - let h : HeadRef := { layerIdx := cand.layerIdx, headIdx := hIdx } - if !exclude.contains h then - if hh : hIdx < layerOut.size then - let norm := (layerOut[hIdx]'hh).frobeniusNorm - let diff := Float.abs (candNorm - norm) - match best with - | none => - best := some (h, norm, diff) - | some (_, _, bestDiff) => - if diff < bestDiff then - best := some (h, norm, diff) - return best - match best with - | none => none - | some (ctrl, ctrlNorm, absDiff) => - some { candidate := cand, control := ctrl, candNorm, ctrlNorm, absDiff } - else - none - -private def energyEquivalent (cfg : VerificationConfig) (candNorm : Float) - (absDiff : Float) : Bool := - let thresh := max cfg.energyAbsTol (cfg.energyRelTol * candNorm) - absDiff < thresh - -/-! ## Verification context and core API -/ - -/-- Shared baseline information for circuit verification on a fixed prompt. -/ -structure VerificationContext where - model : ConcreteModel - W_U : ConcreteMatrix - /-- Logits are evaluated at this position (typically the last token). 
-/ - pos : Nat - targetToken : Nat - negativeToken : Nat - baseTargetLogit : Float - baseNegLogit : Float - baseDelta : Float - baselineForward : ForwardPassResult - cfg : VerificationConfig - -namespace VerificationContext - -/-- Build a verification context for a fixed target token. - -The negative token `t_neg` is chosen as the **top non-target** logit at the evaluation position. --/ -def build (model : ConcreteModel) (targetToken : Nat) - (cfg : VerificationConfig := {}) : Except String VerificationContext := - match model.unembedding with - | none => .error "Model is missing UNEMBEDDING; cannot compute logits." - | some W_U => - if model.seqLen = 0 then - .error "Model has seqLen = 0; cannot evaluate logits." - else - let pos := model.seqLen - 1 - let baselineForward := model.runForward cfg.causal - let residual := baselineForward.finalOutput - match logitAt residual pos W_U targetToken with - | none => - .error s!"Cannot compute logit(target={targetToken}) at pos={pos} \ - (dimension mismatch or token OOB)." - | some baseTargetLogit => - match topNonTargetToken residual pos W_U targetToken with - | none => - .error "Cannot select top non-target token (need vocab ≥ 2 and target in-bounds)." - | some (negativeToken, baseNegLogit) => - let baseDelta := baseTargetLogit - baseNegLogit - .ok { - model := model - W_U := W_U - pos := pos - targetToken := targetToken - negativeToken := negativeToken - baseTargetLogit := baseTargetLogit - baseNegLogit := baseNegLogit - baseDelta := baseDelta - baselineForward := baselineForward - cfg := cfg - } - -end VerificationContext - -/-- Verify a candidate circuit by **head ablation** with an energy-matched control. - -This function enforces the runtime axioms: -1. Baseline competence: `Δ_base > ε` -2. Control independence: control heads disjoint from candidate heads -3. Energy equivalence: per-head output norms match within tolerance - -Only if these axioms are verified do we run the ablations and compute impact scores. 
--/ -def verifyCircuit (ctx : VerificationContext) (candidateHeads : Array HeadRef) : - CircuitVerificationRow := Id.run do - let baseDelta := ctx.baseDelta - let competence := baseDelta > ctx.cfg.competenceEpsilon - let mut failures : Array String := Array.mkEmpty (candidateHeads.size + 3) - if !competence then - failures := failures.push s!"Axiom1(baseline competence) failed: Δ_base={baseDelta} ≤ \ - ε={ctx.cfg.competenceEpsilon}" - let axioms : AxiomStatus := { - baselineCompetence := false - controlIndependence := true - energyEquivalence := true - failures := failures - } - return { - candidateHeads := candidateHeads - controlHeads := none - baseDelta := baseDelta - ablatedDelta := none - impact := none - relScore := none - controlImpact := none - axioms := axioms - } - - -- Choose one control head per candidate head (same layer, closest output norm). - let mut selections : Array ControlSelection := Array.mkEmpty candidateHeads.size - for cand in candidateHeads do - match selectEnergyMatchedControl ctx.baselineForward cand candidateHeads with - | some sel => selections := selections.push sel - | none => - failures := failures.push s!"No control head found in layer {cand.layerIdx} for {cand}" - - let controlsComplete := selections.size = candidateHeads.size - let controlHeads : Array HeadRef := selections.map (·.control) - let independence := !(controlHeads.any candidateHeads.contains) - if !independence then - failures := failures.push "Axiom2(control independence) failed: control overlaps candidate." - - let energyOk := - controlsComplete && selections.all (fun s => energyEquivalent ctx.cfg s.candNorm s.absDiff) - if !energyOk then - failures := failures.push "Axiom3(energy equivalence) failed: no sufficiently matched control." 
- - let axioms : AxiomStatus := { - baselineCompetence := competence - controlIndependence := independence - energyEquivalence := energyOk - failures := failures - } - - if !axioms.verified then - return { - candidateHeads := candidateHeads - controlHeads := if controlsComplete then some controlHeads else none - baseDelta := baseDelta - ablatedDelta := none - impact := none - relScore := none - controlImpact := none - axioms := axioms - } - - -- Candidate ablation - let fwdAblated := runForwardAblatingHeads ctx.model candidateHeads ctx.cfg.causal - let residualAblated := fwdAblated.finalOutput - let ablatedDelta := - deltaAt residualAblated ctx.pos ctx.W_U ctx.targetToken ctx.negativeToken - let impact := baseDelta - ablatedDelta - let relScore := if baseDelta > 0.0 then impact / baseDelta else 0.0 - - -- Control ablation (energy-matched, layer-matched) - let fwdCtrl := runForwardAblatingHeads ctx.model controlHeads ctx.cfg.causal - let residualCtrl := fwdCtrl.finalOutput - let ctrlDelta := - deltaAt residualCtrl ctx.pos ctx.W_U ctx.targetToken ctx.negativeToken - let controlImpact := baseDelta - ctrlDelta - - return { - candidateHeads := candidateHeads - controlHeads := some controlHeads - baseDelta := baseDelta - ablatedDelta := some ablatedDelta - impact := some impact - relScore := some relScore - controlImpact := some controlImpact - axioms := axioms - } - -end Nfp diff --git a/README.md b/README.md index 5ae409d..247dd5d 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ This is research tooling. Interfaces may change; please treat results as experim ## Tabula Rasa Rewrite (current state) -The `tabula-rasa` branch is a fresh, minimal Lean 4 core focused on circuit certification. The legacy system remains in `Legacy/Nfp/` and is not built by default. +The `tabula-rasa` branch is a fresh, minimal Lean 4 core focused on circuit certification. 
Current core modules (new): - `Nfp/Core`, `Nfp/Prob`, `Nfp/Mixer`, `Nfp/System` define basic mass/probability, mixers, and DAG-backed local systems. From 27fabed979546fd58582f51687d07141b4e36406 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 23:22:48 +0100 Subject: [PATCH 088/244] Add one-hot bounds certification --- Nfp/Circuit/Layers/Induction.lean | 183 ++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index aa7ede7..d618fec 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ -1,5 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Nfp.Circuit.Layers.Attention /-! @@ -14,6 +15,8 @@ namespace Layers universe v +open scoped BigOperators + section Weights variable {Val : Type v} [NonAssocSemiring Val] @@ -57,6 +60,59 @@ def prevIndex : Fin (Nat.succ n) → Fin (Nat.succ n) end Spec +section Bounds + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Numeric bounds certifying one-hot weights on nonzero queries. -/ +structure OneHotBoundsOn (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on nonzero queries. -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on nonzero queries. -/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- Non-prev weights are nonpositive on nonzero queries. -/ + other_le_zero : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ 0 + +/-- Certified bounds imply one-hot weights on nonzero queries. 
-/ +theorem oneHot_of_boundsOn (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) [DecidableEq (Fin seq)] + (h : OneHotBoundsOn prev weights) : + ∀ q, q ≠ 0 → weights q = Pi.single (prev q) 1 := by + intro q hq + funext k + by_cases hk : k = prev q + · subst hk + have hzero : + (∑ k ∈ (Finset.univ.erase (prev q)), weights q k) = 0 := by + refine Finset.sum_eq_zero ?_ + intro k hk' + have hkne : k ≠ prev q := (Finset.mem_erase.1 hk').1 + have hle : weights q k ≤ 0 := h.other_le_zero q hq k hkne + have hge : 0 ≤ weights q k := h.nonneg q hq k + exact le_antisymm hle hge + have hsum : + weights q (prev q) + + ∑ k ∈ (Finset.univ.erase (prev q)), weights q k = + ∑ k, weights q k := by + simpa using + (Finset.add_sum_erase + (s := (Finset.univ : Finset (Fin seq))) + (f := weights q) (a := prev q) (by simp)) + have hprev : weights q (prev q) = 1 := by + have hsum' : + weights q (prev q) + 0 = 1 := by + simpa [hzero, h.sum_one q hq] using hsum + simpa using hsum' + simp [Pi.single, hprev] + · have hle : weights q k ≤ 0 := h.other_le_zero q hq k hk + have hge : 0 ≤ weights q k := h.nonneg q hq k + have hzero : weights q k = 0 := le_antisymm hle hge + simp [Pi.single, hk, hzero] + +end Bounds + section Attention variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] @@ -226,6 +282,133 @@ end Typed end Attention +section InductionSpecTyped + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {heads dim n : Nat} +variable {Val : Type v} [NonAssocSemiring Val] + +variable (scale : Val) +variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) + +/-- One-hot weights on nonzero queries imply the induction spec for typed evaluation. 
-/ +theorem attentionTyped_eval_inductionSpec_of_oneHot + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + InductionSpec (n := n) prev + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + intro q hq + have hweights_q := hweights q hq + exact attentionTyped_eval_out_eq_of_oneHot + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (scale := scale) + (softmax := softmax) + (prev := prev) + (input := input) + (b := b) + (h := h) + (q := q) + (d := d) + hweights_q + +/-- Induction spec for `prevIndex` under one-hot weight hypotheses. 
-/ +theorem attentionTyped_eval_inductionSpec_prevIndex + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prevIndex (n := n) q) 1) : + InductionSpec (n := n) (prevIndex (n := n)) + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + exact attentionTyped_eval_inductionSpec_of_oneHot + (Batch := Batch) + (heads := heads) + (dim := dim) + (n := n) + (scale := scale) + (softmax := softmax) + (prev := prevIndex (n := n)) + (input := input) + (b := b) + (h := h) + (d := d) + hweights + +end InductionSpecTyped + end Layers end Circuit From cf4a9268cbf4dbf387efac0dc0a1ec0ff6598e24 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 23:53:50 +0100 Subject: [PATCH 089/244] Add approximate induction specs --- Nfp/Circuit/Layers/Induction.lean | 184 ++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index d618fec..2f8862d 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.Monoid.Unbundled.Basic import Nfp.Circuit.Layers.Attention /-! 
@@ -60,6 +61,32 @@ def prevIndex : Fin (Nat.succ n) → Fin (Nat.succ n) end Spec +section ApproxSpec + +variable {Val : Type v} [AddCommMonoid Val] [PartialOrder Val] [IsOrderedAddMonoid Val] +variable {n : Nat} + +/-- Approximate induction-head spec: outputs are within `ε` of `prev` values. -/ +def InductionSpecApprox (ε : Val) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, q ≠ 0 → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε + +/-- Exact induction spec implies the approximate spec for any nonnegative tolerance. -/ +theorem inductionSpecApprox_of_spec (ε : Val) (hε : 0 ≤ ε) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) + (h : InductionSpec prev out vals) : + InductionSpecApprox (Val := Val) (n := n) ε prev out vals := by + intro q hq + have hq' : out q = vals (prev q) := h q hq + constructor <;> + simpa [hq'] using + (le_add_of_nonneg_right hε : + vals (prev q) ≤ vals (prev q) + ε) + +end ApproxSpec + section Bounds variable {Val : Type v} [Semiring Val] [PartialOrder Val] @@ -113,6 +140,94 @@ theorem oneHot_of_boundsOn (prev : Fin seq → Fin seq) end Bounds +section ApproxBounds + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Approximate one-hot bounds for attention weights on nonzero queries. -/ +structure OneHotApproxBoundsOn (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on nonzero queries. -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on nonzero queries. -/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on nonzero queries. -/ + prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on nonzero queries. 
-/ + other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Approximate induction weights: prev weight near one, others at most `ε`. -/ +def InductionWeightsApprox (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop := + ∀ q, q ≠ 0 → + 1 ≤ weights q (prev q) + ε ∧ + ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Approximate bounds imply approximate induction weights. -/ +theorem inductionWeightsApprox_of_boundsOn (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) + (h : OneHotApproxBoundsOn ε prev weights) : + InductionWeightsApprox (Val := Val) ε prev weights := by + intro q hq + exact ⟨h.prev_large q hq, h.other_le q hq⟩ + +end ApproxBounds + +section SoftmaxMargin + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Softmax margin certificates for approximate one-hot weights. -/ +structure SoftmaxMarginBounds (ε margin : Val) (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) : Prop where + /-- Score gap between `prev` and other keys on nonzero queries. -/ + score_margin : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) + /-- All weights are nonnegative on nonzero queries. -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on nonzero queries. -/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on nonzero queries. -/ + prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on nonzero queries. -/ + other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Margin certificates yield approximate one-hot bounds for the weights. 
-/ +theorem oneHotApproxBounds_of_softmaxMargin (ε margin : Val) (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : + OneHotApproxBoundsOn (Val := Val) ε prev weights := by + exact + { nonneg := h.nonneg + sum_one := h.sum_one + prev_large := h.prev_large + other_le := h.other_le } + +/-- Margin certificates imply approximate induction-weight bounds. -/ +theorem inductionWeightsApprox_of_softmaxMargin (ε margin : Val) + (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : + InductionWeightsApprox (Val := Val) ε prev weights := by + exact inductionWeightsApprox_of_boundsOn + (Val := Val) + (seq := seq) + (ε := ε) + (prev := prev) + (weights := weights) + (h := oneHotApproxBounds_of_softmaxMargin + (Val := Val) + (seq := seq) + (ε := ε) + (margin := margin) + (prev := prev) + (scores := scores) + (weights := weights) + h) + +end SoftmaxMargin + section Attention variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] @@ -409,6 +524,75 @@ theorem attentionTyped_eval_inductionSpec_prevIndex end InductionSpecTyped +section InductionSpecApproxTyped + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {heads dim n : Nat} +variable {Val : Type v} [NonAssocSemiring Val] [PartialOrder Val] [IsOrderedAddMonoid Val] + +variable (scale : Val) +variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) + +/-- One-hot weights imply the approximate induction spec for any nonnegative tolerance. 
-/ +theorem attentionTyped_eval_inductionSpecApprox_of_oneHot (ε : Val) (hε : 0 ≤ ε) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + InductionSpecApprox (Val := Val) (n := n) ε prev + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + apply inductionSpecApprox_of_spec (Val := Val) (n := n) (ε := ε) hε + exact attentionTyped_eval_inductionSpec_of_oneHot + (Batch := Batch) + (heads := heads) + (dim := dim) + (n := n) + (scale := scale) + (softmax := softmax) + (prev := prev) + (input := input) + (b := b) + (h := h) + (d := d) + hweights + +end InductionSpecApproxTyped + end Layers end Circuit From e007989c8bf8136d8971d5f71ee9f5bf93742061 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 2 Jan 2026 23:54:16 +0100 Subject: [PATCH 090/244] Add softmax margin certificate checker --- AGENTS.md | 2 + Nfp/Circuit.lean | 1 + Nfp/Circuit/Cert/SoftmaxMargin.lean | 335 ++++++++++++++++++++++++++++ 3 files changed, 338 insertions(+) create mode 100644 Nfp/Circuit/Cert/SoftmaxMargin.lean diff --git a/AGENTS.md b/AGENTS.md index 921e6cc..e893def 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -283,6 +283,8 @@ but you **must** update this list in the same commit. 
- Basic well-formedness conditions for circuit inputs. - `Nfp/Circuit/Cert.lean` - Equivalence definition and finite checker. +- `Nfp/Circuit/Cert/SoftmaxMargin.lean` + - Softmax-margin certificate payloads and checker soundness. - `Nfp/Circuit/Typed.lean` - Typed circuit wrapper and interface-level equivalence checker. - `Nfp/Circuit/Compose.lean` diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index 51917ea..2f57e89 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -6,6 +6,7 @@ import Nfp.Circuit.Interface import Nfp.Circuit.Semantics import Nfp.Circuit.WellFormed import Nfp.Circuit.Cert +import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Typed import Nfp.Circuit.Compose import Nfp.Circuit.Gates diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean b/Nfp/Circuit/Cert/SoftmaxMargin.lean new file mode 100644 index 0000000..a59ffe6 --- /dev/null +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -0,0 +1,335 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Circuit.Cert +import Nfp.Circuit.Layers.Induction + +/-! +Softmax-margin certificates for approximate one-hot attention weights. +-/ + +namespace Nfp + +namespace Circuit + +open scoped BigOperators + +variable {seq : Nat} + +/-- Certificate payload for softmax-margin bounds (Rat-valued). -/ +structure SoftmaxMarginCert (seq : Nat) where + /-- Weight tolerance. -/ + eps : Rat + /-- Score margin used to justify weight bounds. -/ + margin : Rat + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Score matrix entries. -/ + scores : Fin seq → Fin seq → Rat + /-- Attention weight entries. -/ + weights : Fin seq → Fin seq → Rat + +/-- Boolean checker for softmax-margin certificates. 
-/ +def checkSoftmaxMarginCert [NeZero seq] (c : SoftmaxMarginCert seq) : Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + if q = 0 then + true + else + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1)) + +/-- `checkSoftmaxMarginCert` is sound for `SoftmaxMarginBounds`. -/ +theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : + checkSoftmaxMarginCert c = true → + Layers.SoftmaxMarginBounds (Val := Rat) c.eps c.margin c.prev c.scores c.weights := by + classical + intro hcheck + have hqall : + ∀ q ∈ (Finset.univ : Finset (Fin seq)), + (if q = 0 then + true + else + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1)) = true := by + have hcheck' : checkSoftmaxMarginCert c = true := hcheck + have hcheck'' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + if q = 0 then + true + else + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1)) = true := by + simpa [checkSoftmaxMarginCert] using 
hcheck' + exact (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hcheck'' + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + have hqcheck := hqall q (by simp) + have hqcheck' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true := by + have hqcheck'' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [hq] using hqcheck + have hqcheck''' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ + decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [Bool.and_eq_true, and_assoc] using hqcheck'' + rcases hqcheck''' with ⟨_, hscoreOk, _, _⟩ + exact hscoreOk + have hscoreall := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hqcheck' + have hscorek := hscoreall k (by simp) + have hscorek' : + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q)) = true := by + simpa [hk] using hscorek + exact (decide_eq_true_iff).1 hscorek' + · intro q hq k + have hqcheck := hqall q (by simp) + have hqcheck' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) = true := 
by + have hqcheck'' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [hq] using hqcheck + have hqcheck''' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ + decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [Bool.and_eq_true, and_assoc] using hqcheck'' + rcases hqcheck''' with ⟨hweightsOk, _, _, _⟩ + exact hweightsOk + have hweightsall := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hqcheck' + have hweightsk := hweightsall k (by simp) + have hweightsk' : + decide (0 ≤ c.weights q k) = true ∧ + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps)) = true := by + simpa [Bool.and_eq_true] using hweightsk + exact (decide_eq_true_iff).1 hweightsk'.1 + · intro q hq + have hqcheck := hqall q (by simp) + have hqcheck' : + decide ((∑ k, c.weights q k) = 1) = true := by + have hqcheck'' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [hq] using hqcheck + have 
hqcheck''' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ + decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [Bool.and_eq_true, and_assoc] using hqcheck'' + rcases hqcheck''' with ⟨_, _, _, hsumOk⟩ + exact hsumOk + exact (decide_eq_true_iff).1 hqcheck' + · intro q hq + have hqcheck := hqall q (by simp) + have hqcheck' : + decide (1 ≤ c.weights q (c.prev q) + c.eps) = true := by + have hqcheck'' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [hq] using hqcheck + have hqcheck''' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ + decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [Bool.and_eq_true, and_assoc] using hqcheck'' + rcases hqcheck''' with ⟨_, _, hprevOk, _⟩ + exact hprevOk + exact (decide_eq_true_iff).1 hqcheck' + · intro q hq k hk + have hqcheck := hqall q (by simp) + have hqcheck' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + 
decide (c.weights q k ≤ c.eps))) = true := by + have hqcheck'' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [hq] using hqcheck + have hqcheck''' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ + decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [Bool.and_eq_true, and_assoc] using hqcheck'' + rcases hqcheck''' with ⟨hweightsOk, _, _, _⟩ + exact hweightsOk + have hweightsall := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hqcheck' + have hweightsk := hweightsall k (by simp) + have hweightsk' : + decide (0 ≤ c.weights q k) = true ∧ + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps)) = true := by + simpa [Bool.and_eq_true] using hweightsk + have hother : + decide (c.weights q k ≤ c.eps) = true := by + simpa [hk] using hweightsk'.2 + exact (decide_eq_true_iff).1 hother + +end Circuit + +end Nfp From 328bc38277c2ed09499ab4f7dad284f57aa5ed25 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 00:06:50 +0100 Subject: [PATCH 091/244] Add induction cert parser and CLI --- Nfp/Cli.lean | 23 ++++++ Nfp/IO.lean | 46 +++++++++++ Nfp/IO/Pure.lean | 199 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 Nfp/IO.lean create mode 100644 Nfp/IO/Pure.lean diff --git 
a/Nfp/Cli.lean b/Nfp/Cli.lean index 3c109f2..e6f4b4e 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Cli +import Nfp.IO /-! Minimal CLI surface for the NFP rewrite. @@ -24,12 +25,34 @@ def versionCmd : Cmd := `[Cli| "Print the NFP version." ] +/-- Check a softmax-margin certificate for induction heads. -/ +def runInductionCertify (p : Parsed) : IO UInt32 := do + let scoresPath := p.flag! "scores" |>.as! String + IO.runInductionCertify scoresPath + +/-- `nfp induction certify` subcommand. -/ +def inductionCertifyCmd : Cmd := `[Cli| + certify VIA runInductionCertify; + "Check a softmax-margin certificate for induction heads." + FLAGS: + scores : String; "Path to the softmax-margin certificate file." +] + +/-- Induction-head subcommands. -/ +def inductionCmd : Cmd := `[Cli| + induction NOOP; + "Induction-head utilities." + SUBCOMMANDS: + inductionCertifyCmd +] + /-- The root CLI command. -/ def nfpCmd : Cmd := `[Cli| nfp NOOP; "NFP: Neural Formal Pathways (rewrite in progress)." SUBCOMMANDS: versionCmd + inductionCmd ] /-- Main entry point for the CLI. -/ diff --git a/Nfp/IO.lean b/Nfp/IO.lean new file mode 100644 index 0000000..7d17ea3 --- /dev/null +++ b/Nfp/IO.lean @@ -0,0 +1,46 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Pure + +/-! +IO wrappers for loading and checking softmax-margin certificates. +-/ + +namespace Nfp + +namespace IO + +open Nfp.Circuit + +/-- Load a softmax-margin certificate from disk. -/ +def loadSoftmaxMarginCert (path : System.FilePath) : + IO (Except String (Sigma SoftmaxMarginCert)) := do + let data ← IO.FS.readFile path + return Pure.parseSoftmaxMarginCert data + +/-- Check a softmax-margin certificate file and print a short status line. 
-/ +def runInductionCertify (path : System.FilePath) : IO UInt32 := do + let parsed ← loadSoftmaxMarginCert path + match parsed with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 1 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simpa using Nat.succ_ne_zero n⟩ + let ok := Circuit.checkSoftmaxMarginCert cert + if ok then + IO.println s!"ok: certificate accepted (seq={seq})" + return 0 + else + IO.eprintln s!"error: certificate rejected (seq={seq})" + return 2 + +end IO + +end Nfp diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean new file mode 100644 index 0000000..938194c --- /dev/null +++ b/Nfp/IO/Pure.lean @@ -0,0 +1,199 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Circuit.Cert.SoftmaxMargin + +/-! +Pure parsing helpers for softmax-margin certificates. +-/ + +namespace Nfp + +namespace IO + +namespace Pure + +open Nfp.Circuit + +private def splitWords (line : String) : List String := + line.split (fun c => c = ' ' || c = '\t') |>.filter (· ≠ "") + +private def cleanTokens (line : String) : Option (List String) := + let trimmed := line.trim + if trimmed.isEmpty then + none + else if trimmed.startsWith "#" then + none + else + some (splitWords trimmed) + +private def parseNat (s : String) : Except String Nat := + match s.toNat? with + | some n => Except.ok n + | none => Except.error s!"expected Nat, got '{s}'" + +private def parseInt (s : String) : Except String Int := + match s.toInt? 
with + | some n => Except.ok n + | none => Except.error s!"expected Int, got '{s}'" + +private def parseRat (s : String) : Except String Rat := + match s.splitOn "/" with + | [num] => return Rat.ofInt (← parseInt num) + | [num, den] => + let n ← parseInt num + let d ← parseNat den + if d = 0 then + throw s!"invalid rational '{s}': zero denominator" + else + return Rat.ofInt n / Rat.ofInt (Int.ofNat d) + | _ => throw s!"invalid rational '{s}'" + +private structure SoftmaxMarginParseState (seq : Nat) where + eps : Option Rat + margin : Option Rat + prev : Array (Option (Fin seq)) + scores : Array (Array (Option Rat)) + weights : Array (Array (Option Rat)) + +private def initState (seq : Nat) : SoftmaxMarginParseState seq := + let prev := Array.mkArray seq none + let row : Array (Option Rat) := Array.mkArray seq none + let mat : Array (Array (Option Rat)) := Array.mkArray seq row + { eps := none + margin := none + prev := prev + scores := mat + weights := mat } + +private def setPrev {seq : Nat} (st : SoftmaxMarginParseState seq) + (q k : Nat) : Except String (SoftmaxMarginParseState seq) := do + if hq : q < seq then + if hk : k < seq then + let qFin : Fin seq := ⟨q, hq⟩ + let kFin : Fin seq := ⟨k, hk⟩ + match st.prev.get qFin with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + let prev' := st.prev.set qFin (some kFin) + return { st with prev := prev' } + else + throw s!"prev index out of range: k={k}" + else + throw s!"prev index out of range: q={q}" + +private def setMatrixEntry {seq : Nat} (mat : Array (Array (Option Rat))) + (q k : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do + if hq : q < seq then + if hk : k < seq then + let qFin : Fin seq := ⟨q, hq⟩ + let kFin : Fin seq := ⟨k, hk⟩ + let row := mat.get qFin + match row.get kFin with + | some _ => + throw s!"duplicate matrix entry at ({q}, {k})" + | none => + let row' := row.set kFin (some v) + return mat.set qFin row' + else + throw s!"index out of range: k={k}" + 
else + throw s!"index out of range: q={q}" + +private def parseLine {seq : Nat} (st : SoftmaxMarginParseState seq) + (tokens : List String) : Except String (SoftmaxMarginParseState seq) := do + match tokens with + | ["eps", val] => + if st.eps.isSome then + throw "duplicate eps entry" + else + return { st with eps := some (← parseRat val) } + | ["margin", val] => + if st.margin.isSome then + throw "duplicate margin entry" + else + return { st with margin := some (← parseRat val) } + | ["prev", q, k] => + setPrev st (← parseNat q) (← parseNat k) + | ["score", q, k, val] => + let mat ← setMatrixEntry st.scores (← parseNat q) (← parseNat k) (← parseRat val) + return { st with scores := mat } + | ["weight", q, k, val] => + let mat ← setMatrixEntry st.weights (← parseNat q) (← parseNat k) (← parseRat val) + return { st with weights := mat } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def allSomeArray {α : Type} (arr : Array (Option α)) : Bool := + arr.all (fun v => v.isSome) + +private def allSomeMatrix {α : Type} (mat : Array (Array (Option α))) : Bool := + mat.all (fun row => row.all (fun v => v.isSome)) + +private def finalizeState {seq : Nat} (hpos : 0 < seq) + (st : SoftmaxMarginParseState seq) : Except String (SoftmaxMarginCert seq) := do + let eps ← + match st.eps with + | some v => pure v + | none => throw "missing eps entry" + let margin ← + match st.margin with + | some v => pure v + | none => throw "missing margin entry" + if !allSomeArray st.prev then + throw "missing prev entries" + if !allSomeMatrix st.scores then + throw "missing score entries" + if !allSomeMatrix st.weights then + throw "missing weight entries" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev.get q).getD defaultPrev + let scoresFun : Fin seq → Fin seq → Rat := fun q k => + (st.scores.get q).get k |>.getD 0 + let weightsFun : Fin seq → Fin seq → Rat := fun q k => + (st.weights.get q).get k |>.getD 0 
+ return + { eps := eps + margin := margin + prev := prevFun + scores := scoresFun + weights := weightsFun } + +/-- Parse a softmax-margin certificate from a text payload. -/ +def parseSoftmaxMarginCert (input : String) : + Except String (Sigma SoftmaxMarginCert) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + let seq ← + match seq? with + | some v => pure v + | none => throw "missing seq entry" + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let hpos : 0 < seq := Nat.succ_pos n + let st0 : SoftmaxMarginParseState seq := initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => parseLine st t) st0 + let cert ← finalizeState hpos st + return ⟨seq, cert⟩ + +end Pure + +end IO + +end Nfp From a98c8e56ed858338bfafb9dc77e8c87d963b6a37 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 00:07:05 +0100 Subject: [PATCH 092/244] Update module map for IO --- AGENTS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index e893def..4ce8e53 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -317,6 +317,10 @@ but you **must** update this list in the same commit. - Aggregator for circuit modules. ### 5.6 CLI surface +- `Nfp/IO/Pure.lean` + - Pure parsing helpers for CLI inputs. +- `Nfp/IO.lean` + - IO-only wrappers for loading inputs and running checks. - `Nfp/Cli.lean` - CLI commands and `main` implementation. 
- `Main.lean` From c7baf68be626153a50302db695ae94a98673c452 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 01:17:39 +0100 Subject: [PATCH 093/244] Add value-range certificate support for induction bounds --- AGENTS.md | 2 + Nfp/Circuit.lean | 1 + Nfp/Circuit/Cert/ValueRange.lean | 68 ++++++++ Nfp/Circuit/Layers/Induction.lean | 236 +++++++++++++++++++++++++++ Nfp/Cli.lean | 10 +- Nfp/IO.lean | 86 +++++++--- Nfp/IO/Pure.lean | 167 +++++++++++++++---- README.md | 45 +++++ scripts/build_gpt2_induction_cert.py | 205 +++++++++++++++++++++++ 9 files changed, 762 insertions(+), 58 deletions(-) create mode 100644 Nfp/Circuit/Cert/ValueRange.lean create mode 100644 scripts/build_gpt2_induction_cert.py diff --git a/AGENTS.md b/AGENTS.md index 4ce8e53..549fda9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -285,6 +285,8 @@ but you **must** update this list in the same commit. - Equivalence definition and finite checker. - `Nfp/Circuit/Cert/SoftmaxMargin.lean` - Softmax-margin certificate payloads and checker soundness. +- `Nfp/Circuit/Cert/ValueRange.lean` + - Value-range certificate payloads and checker soundness. - `Nfp/Circuit/Typed.lean` - Typed circuit wrapper and interface-level equivalence checker. 
- `Nfp/Circuit/Compose.lean` diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index 2f57e89..3b2b6ca 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -7,6 +7,7 @@ import Nfp.Circuit.Semantics import Nfp.Circuit.WellFormed import Nfp.Circuit.Cert import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.Circuit.Cert.ValueRange import Nfp.Circuit.Typed import Nfp.Circuit.Compose import Nfp.Circuit.Gates diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean new file mode 100644 index 0000000..c18a46b --- /dev/null +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -0,0 +1,68 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Circuit.Cert +import Nfp.Circuit.Layers.Induction + +/-! +Value-range certificates for attention value vectors. +-/ + +namespace Nfp + +namespace Circuit + +open scoped BigOperators + +variable {seq : Nat} + +/-- Certificate payload for value-range bounds (Rat-valued). -/ +structure ValueRangeCert (seq : Nat) where + /-- Lower bound for values. -/ + lo : Rat + /-- Upper bound for values. -/ + hi : Rat + /-- Value entries. -/ + vals : Fin seq → Rat + +/-- Boolean checker for value-range certificates. -/ +def checkValueRangeCert [NeZero seq] (c : ValueRangeCert seq) : Bool := + decide (c.lo ≤ c.hi) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (c.lo ≤ c.vals k) && decide (c.vals k ≤ c.hi)) + +/-- `checkValueRangeCert` is sound for `ValueRangeBounds`. 
-/ +theorem checkValueRangeCert_sound [NeZero seq] (c : ValueRangeCert seq) : + checkValueRangeCert c = true → + Layers.ValueRangeBounds (Val := Rat) c.lo c.hi c.vals := by + classical + intro hcheck + have hcheck' : + decide (c.lo ≤ c.hi) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (c.lo ≤ c.vals k) && decide (c.vals k ≤ c.hi)) = true := by + simpa [checkValueRangeCert, Bool.and_eq_true] using hcheck + rcases hcheck' with ⟨hlohi, hall⟩ + have hlohi' : c.lo ≤ c.hi := (decide_eq_true_iff).1 hlohi + have hall' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hall + have hlo : ∀ k, c.lo ≤ c.vals k := by + intro k + have hk := hall' k (by simp) + have hk' : + decide (c.lo ≤ c.vals k) = true ∧ decide (c.vals k ≤ c.hi) = true := by + simpa [Bool.and_eq_true] using hk + exact (decide_eq_true_iff).1 hk'.1 + have hhi : ∀ k, c.vals k ≤ c.hi := by + intro k + have hk := hall' k (by simp) + have hk' : + decide (c.lo ≤ c.vals k) = true ∧ decide (c.vals k ≤ c.hi) = true := by + simpa [Bool.and_eq_true] using hk + exact (decide_eq_true_iff).1 hk'.2 + exact { lo_le_hi := hlohi', lo_le := hlo, le_hi := hhi } + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index 2f8862d..33d2dcc 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.BigOperators.Ring.Finset import Mathlib.Algebra.Order.Monoid.Unbundled.Basic +import Mathlib.Algebra.Order.Ring.Defs import Nfp.Circuit.Layers.Attention /-! @@ -87,6 +89,22 @@ theorem inductionSpecApprox_of_spec (ε : Val) (hε : 0 ≤ ε) end ApproxSpec +section ValueRange + +variable {Val : Type v} [PartialOrder Val] +variable {seq : Nat} + +/-- Value-range bounds for a vector of attention values. 
-/ +structure ValueRangeBounds (lo hi : Val) (vals : Fin seq → Val) : Prop where + /-- Lower and upper bounds are ordered. -/ + lo_le_hi : lo ≤ hi + /-- All values are at least `lo`. -/ + lo_le : ∀ k, lo ≤ vals k + /-- All values are at most `hi`. -/ + le_hi : ∀ k, vals k ≤ hi + +end ValueRange + section Bounds variable {Val : Type v} [Semiring Val] [PartialOrder Val] @@ -174,6 +192,224 @@ theorem inductionWeightsApprox_of_boundsOn (ε : Val) (prev : Fin seq → Fin se end ApproxBounds +section ApproxOutput + +variable {Val : Type v} [Ring Val] [LinearOrder Val] [IsOrderedRing Val] +variable {n : Nat} + +local instance : NeZero (Nat.succ n) := ⟨by simp⟩ + +/-- Approximate one-hot weights plus bounded values yield an approximate induction spec. -/ +theorem inductionSpecApprox_of_oneHotApprox_valueRange + (ε lo hi : Val) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) + (vals : Fin (Nat.succ n) → Val) + (hweights : OneHotApproxBoundsOn (Val := Val) ε prev weights) + (hvals : ValueRangeBounds (Val := Val) lo hi vals) : + InductionSpecApprox (Val := Val) (n := n) (ε * (hi - lo)) prev + (fun q => dotProduct (weights q) vals) vals := by + classical + intro q hq + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (prev q) + have hsum_decomp : + weights q (prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (prev q) + ∑ k ∈ others, weights q k = 1 := by + simpa [hweights.sum_one q hq] using hsum_decomp + have hsum_others_le : (∑ k ∈ others, weights q k) ≤ ε := by + have hprev : 1 ≤ weights q (prev q) + ε := hweights.prev_large q hq + have hprev' : + weights q (prev q) + ∑ k ∈ others, weights q k ≤ weights q (prev q) + ε := by + simpa [hsum] using hprev + exact (add_le_add_iff_left (weights q (prev q))).1 hprev' + have hsum_others_nonneg : 0 ≤ ∑ k ∈ others, weights q k := by + refine Finset.sum_nonneg ?_ + intro k hk + exact 
hweights.nonneg q hq k + have hvals_hi : ∀ k, vals k ≤ hi := hvals.le_hi + have hvals_lo : ∀ k, lo ≤ vals k := hvals.lo_le + have hdiff_nonneg : 0 ≤ hi - lo := sub_nonneg.mpr hvals.lo_le_hi + have hsum_vals_le : + (∑ k ∈ others, weights q k * vals k) ≤ (∑ k ∈ others, weights q k) * hi := by + have hle : ∀ k ∈ others, weights q k * vals k ≤ weights q k * hi := by + intro k hk + have hval : vals k ≤ hi := hvals_hi k + have hnonneg : 0 ≤ weights q k := hweights.nonneg q hq k + exact mul_le_mul_of_nonneg_left hval hnonneg + calc + ∑ k ∈ others, weights q k * vals k + ≤ ∑ k ∈ others, weights q k * hi := Finset.sum_le_sum hle + _ = (∑ k ∈ others, weights q k) * hi := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := hi)).symm + have hsum_vals_ge : + (∑ k ∈ others, weights q k) * lo ≤ (∑ k ∈ others, weights q k * vals k) := by + have hle : ∀ k ∈ others, weights q k * lo ≤ weights q k * vals k := by + intro k hk + have hval : lo ≤ vals k := hvals_lo k + have hnonneg : 0 ≤ weights q k := hweights.nonneg q hq k + exact mul_le_mul_of_nonneg_left hval hnonneg + calc + (∑ k ∈ others, weights q k) * lo + = ∑ k ∈ others, weights q k * lo := by + exact + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := lo)) + _ ≤ ∑ k ∈ others, weights q k * vals k := Finset.sum_le_sum hle + have hsum_prod : + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k = + ∑ k, weights q k * vals k := by + simp [others] + have hout_eq : + dotProduct (weights q) vals = + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := by + simpa [dotProduct] using hsum_prod.symm + have hsum_val_prev : + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) = + vals (prev q) := by + calc + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) = + (weights q (prev q) + ∑ k ∈ others, weights q k) * vals (prev q) := by + simpa using + (add_mul (weights q (prev q)) (∑ k ∈ 
others, weights q k) (vals (prev q))).symm + _ = 1 * vals (prev q) := by + simp [hsum] + _ = vals (prev q) := by simp + have hsplit : + (∑ k ∈ others, weights q k) * hi = + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + (∑ k ∈ others, weights q k) * hi = + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * hi - + (∑ k ∈ others, weights q k) * lo := by + exact + (add_sub_cancel_left + ((∑ k ∈ others, weights q k) * lo) ((∑ k ∈ others, weights q k) * hi)).symm + _ = (∑ k ∈ others, weights q k) * lo + + ((∑ k ∈ others, weights q k) * hi - + (∑ k ∈ others, weights q k) * lo) := by + simp [sub_eq_add_neg, add_assoc] + _ = (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + simp [mul_sub] + have hsum_prev_le : + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo ≤ + vals (prev q) := by + have hmul : (∑ k ∈ others, weights q k) * lo ≤ + (∑ k ∈ others, weights q k) * vals (prev q) := + mul_le_mul_of_nonneg_left (hvals_lo (prev q)) hsum_others_nonneg + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo + ≤ weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) := by + have h := + add_le_add_left hmul (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ = vals (prev q) := hsum_val_prev + have hupper_mid : + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi ≤ + vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi = + weights q (prev q) * vals (prev q) + + ((∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo)) := by + simp [hsplit] + _ = weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + simp [add_assoc] + _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by 
+ have h := + add_le_add_right hsum_prev_le ((∑ k ∈ others, weights q k) * (hi - lo)) + simpa [add_comm, add_left_comm, add_assoc] using h + have hupper : + dotProduct (weights q) vals ≤ vals (prev q) + ε * (hi - lo) := by + have hmul : + (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := + mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + calc + dotProduct (weights q) vals = + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := hout_eq + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have h := + add_le_add_left hsum_vals_le (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := hupper_mid + _ ≤ vals (prev q) + ε * (hi - lo) := by + have h := add_le_add_left hmul (vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + have hprev_le : + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have hmul : (∑ k ∈ others, weights q k) * vals (prev q) ≤ + (∑ k ∈ others, weights q k) * hi := + mul_le_mul_of_nonneg_left (hvals_hi (prev q)) hsum_others_nonneg + have hmul' : + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have h := + add_le_add_left hmul (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + calc + vals (prev q) = + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) := by + simpa using hsum_val_prev.symm + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := hmul' + have hprev_le' : + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := 
hprev_le + _ = + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + simp [hsplit, add_assoc] + have hsub : + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo := by + exact (sub_le_iff_le_add).2 hprev_le' + have hlowershift : + vals (prev q) - ε * (hi - lo) ≤ + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := by + have hmul : + (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := + mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + exact sub_le_sub_left hmul (vals (prev q)) + have hlow : + vals (prev q) - ε * (hi - lo) ≤ dotProduct (weights q) vals := by + calc + vals (prev q) - ε * (hi - lo) ≤ + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := hlowershift + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo := hsub + _ ≤ dotProduct (weights q) vals := by + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo + ≤ weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := by + have h := + add_le_add_left hsum_vals_ge (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ = dotProduct (weights q) vals := by + simp [hout_eq] + have hlower : + vals (prev q) ≤ dotProduct (weights q) vals + ε * (hi - lo) := by + exact (sub_le_iff_le_add).1 hlow + exact ⟨hupper, hlower⟩ + +end ApproxOutput + section SoftmaxMargin variable {Val : Type v} [Semiring Val] [PartialOrder Val] diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index e6f4b4e..bf61b4a 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -25,17 +25,19 @@ def versionCmd : Cmd := `[Cli| "Print the NFP version." ] -/-- Check a softmax-margin certificate for induction heads. -/ +/-- Check induction certificates for induction heads. -/ def runInductionCertify (p : Parsed) : IO UInt32 := do let scoresPath := p.flag! "scores" |>.as! 
String - IO.runInductionCertify scoresPath + let valuesPath? := (p.flag? "values").map (·.as! String) + IO.runInductionCertify scoresPath valuesPath? /-- `nfp induction certify` subcommand. -/ def inductionCertifyCmd : Cmd := `[Cli| certify VIA runInductionCertify; - "Check a softmax-margin certificate for induction heads." + "Check induction certificates for induction heads." FLAGS: scores : String; "Path to the softmax-margin certificate file." + values : String; "Optional path to a value-range certificate file." ] /-- Induction-head subcommands. -/ @@ -51,7 +53,7 @@ def nfpCmd : Cmd := `[Cli| nfp NOOP; "NFP: Neural Formal Pathways (rewrite in progress)." SUBCOMMANDS: - versionCmd + versionCmd; inductionCmd ] diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 7d17ea3..e42ea6e 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -3,7 +3,7 @@ import Nfp.IO.Pure /-! -IO wrappers for loading and checking softmax-margin certificates. +IO wrappers for loading and checking induction certificates. -/ namespace Nfp @@ -18,28 +18,76 @@ def loadSoftmaxMarginCert (path : System.FilePath) : let data ← IO.FS.readFile path return Pure.parseSoftmaxMarginCert data -/-- Check a softmax-margin certificate file and print a short status line. -/ -def runInductionCertify (path : System.FilePath) : IO UInt32 := do - let parsed ← loadSoftmaxMarginCert path - match parsed with +/-- Load a value-range certificate from disk. 
-/ +def loadValueRangeCert (path : System.FilePath) : + IO (Except String (Sigma ValueRangeCert)) := do + let data ← IO.FS.readFile path + return Pure.parseValueRangeCert data + +private def checkSoftmaxMargin (seq : Nat) (cert : SoftmaxMarginCert seq) : + IO (Except String Unit) := + match seq with + | 0 => return Except.error "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + let ok := Circuit.checkSoftmaxMarginCert cert + if ok then + return Except.ok () + else + return Except.error "softmax-margin certificate rejected" + +private def checkValueRange (seq : Nat) (cert : ValueRangeCert seq) : + IO (Except String Unit) := + match seq with + | 0 => return Except.error "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + let ok := Circuit.checkValueRangeCert cert + if ok then + return Except.ok () + else + return Except.error "value-range certificate rejected" + +/-- Check induction certificates and print a short status line. -/ +def runInductionCertify (scoresPath : System.FilePath) + (valuesPath? : Option System.FilePath) : IO UInt32 := do + let parsedScores ← loadSoftmaxMarginCert scoresPath + match parsedScores with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨seq, cert⟩ => - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 1 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simpa using Nat.succ_ne_zero n⟩ - let ok := Circuit.checkSoftmaxMarginCert cert - if ok then - IO.println s!"ok: certificate accepted (seq={seq})" - return 0 - else - IO.eprintln s!"error: certificate rejected (seq={seq})" - return 2 + let scoresOk ← checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + match valuesPath? 
with + | none => + IO.println s!"ok: softmax-margin certificate accepted (seq={seq})" + return 0 + | some valuesPath => + let parsedValues ← loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + let valuesOk ← checkValueRange seqVals certVals + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let tol := cert.eps * (certVals.hi - certVals.lo) + IO.println s!"ok: induction bound certified (seq={seq}, tol={tol})" + return 0 end IO diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 938194c..da2fb2d 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -2,9 +2,10 @@ import Mathlib.Algebra.Order.Ring.Rat import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.Circuit.Cert.ValueRange /-! -Pure parsing helpers for softmax-margin certificates. +Pure parsing helpers for softmax-margin and value-range certificates. 
-/ namespace Nfp @@ -16,7 +17,7 @@ namespace Pure open Nfp.Circuit private def splitWords (line : String) : List String := - line.split (fun c => c = ' ' || c = '\t') |>.filter (· ≠ "") + line.splitToList (fun c => c = ' ' || c = '\t') |>.filter (· ≠ "") private def cleanTokens (line : String) : Option (List String) := let trimmed := line.trim @@ -37,9 +38,10 @@ private def parseInt (s : String) : Except String Int := | some n => Except.ok n | none => Except.error s!"expected Int, got '{s}'" -private def parseRat (s : String) : Except String Rat := +private def parseRat (s : String) : Except String Rat := do match s.splitOn "/" with - | [num] => return Rat.ofInt (← parseInt num) + | [num] => + return Rat.ofInt (← parseInt num) | [num, den] => let n ← parseInt num let d ← parseNat den @@ -47,24 +49,22 @@ private def parseRat (s : String) : Except String Rat := throw s!"invalid rational '{s}': zero denominator" else return Rat.ofInt n / Rat.ofInt (Int.ofNat d) - | _ => throw s!"invalid rational '{s}'" + | _ => + throw s!"invalid rational '{s}'" private structure SoftmaxMarginParseState (seq : Nat) where eps : Option Rat margin : Option Rat - prev : Array (Option (Fin seq)) - scores : Array (Array (Option Rat)) - weights : Array (Array (Option Rat)) + prev : Fin seq → Option (Fin seq) + scores : Fin seq → Fin seq → Option Rat + weights : Fin seq → Fin seq → Option Rat private def initState (seq : Nat) : SoftmaxMarginParseState seq := - let prev := Array.mkArray seq none - let row : Array (Option Rat) := Array.mkArray seq none - let mat : Array (Array (Option Rat)) := Array.mkArray seq row { eps := none margin := none - prev := prev - scores := mat - weights := mat } + prev := fun _ => none + scores := fun _ _ => none + weights := fun _ _ => none } private def setPrev {seq : Nat} (st : SoftmaxMarginParseState seq) (q k : Nat) : Except String (SoftmaxMarginParseState seq) := do @@ -72,30 +72,40 @@ private def setPrev {seq : Nat} (st : SoftmaxMarginParseState seq) if hk 
: k < seq then let qFin : Fin seq := ⟨q, hq⟩ let kFin : Fin seq := ⟨k, hk⟩ - match st.prev.get qFin with + match st.prev qFin with | some _ => throw s!"duplicate prev entry for q={q}" | none => - let prev' := st.prev.set qFin (some kFin) + let prev' : Fin seq → Option (Fin seq) := fun q' => + if q' = qFin then + some kFin + else + st.prev q' return { st with prev := prev' } else throw s!"prev index out of range: k={k}" else throw s!"prev index out of range: q={q}" -private def setMatrixEntry {seq : Nat} (mat : Array (Array (Option Rat))) - (q k : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do +private def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Rat) + (q k : Nat) (v : Rat) : Except String (Fin seq → Fin seq → Option Rat) := do if hq : q < seq then if hk : k < seq then let qFin : Fin seq := ⟨q, hq⟩ let kFin : Fin seq := ⟨k, hk⟩ - let row := mat.get qFin - match row.get kFin with + match mat qFin kFin with | some _ => throw s!"duplicate matrix entry at ({q}, {k})" | none => - let row' := row.set kFin (some v) - return mat.set qFin row' + let mat' : Fin seq → Fin seq → Option Rat := fun q' k' => + if q' = qFin then + if k' = kFin then + some v + else + mat q' k' + else + mat q' k' + return mat' else throw s!"index out of range: k={k}" else @@ -125,12 +135,6 @@ private def parseLine {seq : Nat} (st : SoftmaxMarginParseState seq) | _ => throw s!"unrecognized line: '{String.intercalate " " tokens}'" -private def allSomeArray {α : Type} (arr : Array (Option α)) : Bool := - arr.all (fun v => v.isSome) - -private def allSomeMatrix {α : Type} (mat : Array (Array (Option α))) : Bool := - mat.all (fun row => row.all (fun v => v.isSome)) - private def finalizeState {seq : Nat} (hpos : 0 < seq) (st : SoftmaxMarginParseState seq) : Except String (SoftmaxMarginCert seq) := do let eps ← @@ -141,20 +145,22 @@ private def finalizeState {seq : Nat} (hpos : 0 < seq) match st.margin with | some v => pure v | none => throw "missing margin entry" - 
if !allSomeArray st.prev then + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then throw "missing prev entries" - if !allSomeMatrix st.scores then + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.scores q k).isSome)) then throw "missing score entries" - if !allSomeMatrix st.weights then + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.weights q k).isSome)) then throw "missing weight entries" let defaultPrev : Fin seq := ⟨0, hpos⟩ let prevFun : Fin seq → Fin seq := fun q => - (st.prev.get q).getD defaultPrev + (st.prev q).getD defaultPrev let scoresFun : Fin seq → Fin seq → Rat := fun q k => - (st.scores.get q).get k |>.getD 0 + (st.scores q k).getD 0 let weightsFun : Fin seq → Fin seq → Rat := fun q k => - (st.weights.get q).get k |>.getD 0 - return + (st.weights q k).getD 0 + pure { eps := eps margin := margin prev := prevFun @@ -192,6 +198,97 @@ def parseSoftmaxMarginCert (input : String) : let cert ← finalizeState hpos st return ⟨seq, cert⟩ +private structure ValueRangeParseState (seq : Nat) where + lo : Option Rat + hi : Option Rat + vals : Fin seq → Option Rat + +private def initValueRangeState (seq : Nat) : ValueRangeParseState seq := + { lo := none + hi := none + vals := fun _ => none } + +private def setVal {seq : Nat} (st : ValueRangeParseState seq) + (k : Nat) (v : Rat) : Except String (ValueRangeParseState seq) := do + if hk : k < seq then + let kFin : Fin seq := ⟨k, hk⟩ + match st.vals kFin with + | some _ => + throw s!"duplicate value entry for k={k}" + | none => + let vals' : Fin seq → Option Rat := fun k' => + if k' = kFin then + some v + else + st.vals k' + return { st with vals := vals' } + else + throw s!"value index out of range: k={k}" + +private def parseValueLine {seq : Nat} (st : ValueRangeParseState seq) + (tokens : List String) : Except String (ValueRangeParseState seq) := 
do + match tokens with + | ["lo", val] => + if st.lo.isSome then + throw "duplicate lo entry" + else + return { st with lo := some (← parseRat val) } + | ["hi", val] => + if st.hi.isSome then + throw "duplicate hi entry" + else + return { st with hi := some (← parseRat val) } + | ["val", k, val] => + setVal st (← parseNat k) (← parseRat val) + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeValueState {seq : Nat} (st : ValueRangeParseState seq) : + Except String (ValueRangeCert seq) := do + let lo ← + match st.lo with + | some v => pure v + | none => throw "missing lo entry" + let hi ← + match st.hi with + | some v => pure v + | none => throw "missing hi entry" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then + throw "missing value entries" + let valsFun : Fin seq → Rat := fun k => + (st.vals k).getD 0 + return { lo := lo, hi := hi, vals := valsFun } + +/-- Parse a value-range certificate from a text payload. -/ +def parseValueRangeCert (input : String) : + Except String (Sigma ValueRangeCert) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + let seq ← + match seq? 
with + | some v => pure v + | none => throw "missing seq entry" + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let st0 : ValueRangeParseState seq := initValueRangeState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => parseValueLine st t) st0 + let cert ← finalizeValueState st + return ⟨seq, cert⟩ + end Pure end IO diff --git a/README.md b/README.md index 247dd5d..61b2c42 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,51 @@ Current core modules (new): Module map and invariants are tracked in `AGENTS.md`. +## Induction Certification (prototype) + +The current end-to-end prototype checks a **softmax-margin certificate** for a single GPT-2-small +head. The certificate is produced by an **untrusted** helper script and verified by the CLI. + +Generate certificates (untrusted): + +```bash +python scripts/build_gpt2_induction_cert.py \ + --output reports/gpt2_induction.cert \ + --layer 5 --head 1 --seq 32 --pattern-length 16 \ + --values-out reports/gpt2_induction.values --value-dim 0 +``` + +Verify it (trusted checker): + +```bash +lake exe nfp induction certify --scores reports/gpt2_induction.cert \ + --values reports/gpt2_induction.values +``` + +Softmax-margin certificate format (line-oriented): + +``` +seq +eps +margin +prev +score +weight +``` + +Value-range certificate format (line-oriented): + +``` +seq +lo +hi +val +``` + +The checker validates that the provided scores/weights satisfy `SoftmaxMarginBounds` and that the +value entries are bounded by `lo`/`hi`. When both are provided, the CLI reports a tolerance +`eps * (hi - lo)` for the approximate induction spec. + ## Soundness statement (what is proven vs checked) The Lean library defines the core math objects (finite probability, mixers, linearizations, and operator-norm-style bounds) and proves a number of lemmas about them. 
The CLI sound path produces certificates using exact `Rat` arithmetic and a trusted checker that verifies internal arithmetic relationships between certificate fields. diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py new file mode 100644 index 0000000..83084e0 --- /dev/null +++ b/scripts/build_gpt2_induction_cert.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Build a softmax-margin certificate for a GPT-2-small induction head. + +This script is untrusted and uses floating-point arithmetic to produce a +rational certificate compatible with `nfp induction certify`. + +Usage: + python scripts/build_gpt2_induction_cert.py --output reports/gpt2_induction.cert \ + --layer 5 --head 1 --seq 32 --pattern-length 16 \ + --values-out reports/gpt2_induction.values --value-dim 0 +""" + +import argparse +import math +from fractions import Fraction +from pathlib import Path + +import numpy as np + +try: + import torch + from transformers import GPT2Model +except ImportError: + raise SystemExit( + "Missing dependencies. 
Install with: uv add transformers torch" + ) + + +def rat_from_float(x: float, decimals: int) -> Fraction: + scale = 10 ** decimals + return Fraction(int(round(x * scale)), scale) + + +def rat_to_str(q: Fraction) -> str: + if q.denominator == 1: + return str(q.numerator) + return f"{q.numerator}/{q.denominator}" + + +def build_tokens(seq: int, pattern_len: int, random_pattern: bool, seed: int) -> np.ndarray: + if random_pattern: + rng = np.random.default_rng(seed) + pattern = rng.integers(1000, 30000, size=pattern_len, endpoint=False) + else: + pattern = np.arange(pattern_len) + repeats = (seq // pattern_len) + 1 + return np.tile(pattern, repeats)[:seq] + + +def build_prev(tokens: np.ndarray) -> np.ndarray: + prev = np.zeros_like(tokens) + last_seen = {} + for idx, tok in enumerate(tokens): + if idx == 0: + prev[idx] = 0 + else: + prev[idx] = last_seen.get(tok, 0) + last_seen[tok] = idx + return prev + + +def compute_scores_weights(model, input_ids, layer: int, head: int, device: str): + model.eval() + with torch.no_grad(): + outputs = model(input_ids, output_hidden_states=True) + hidden_states = outputs.hidden_states[layer] + block = model.h[layer] + x = block.ln_1(hidden_states) + qkv = block.attn.c_attn(x) + n_head = model.config.n_head + head_dim = qkv.shape[-1] // (3 * n_head) + q, k, _v = qkv.split(n_head * head_dim, dim=2) + q = q.view(1, -1, n_head, head_dim).transpose(1, 2) + k = k.view(1, -1, n_head, head_dim).transpose(1, 2) + v = _v.view(1, -1, n_head, head_dim).transpose(1, 2) + qh = q[:, head] + kh = k[:, head] + vh = v[:, head] + scores = torch.matmul(qh, kh.transpose(-2, -1)) / math.sqrt(head_dim) + seq = scores.shape[-1] + mask = torch.triu(torch.ones(seq, seq, device=device), diagonal=1).bool() + scores = scores.masked_fill(mask, -1e9) + weights = torch.softmax(scores, dim=-1) + return (scores.squeeze(0).cpu().numpy(), + weights.squeeze(0).cpu().numpy(), + vh.squeeze(0).cpu().numpy()) + + +def write_scores(path: Path, seq: int, prev: np.ndarray, 
scores, weights, eps=None, margin=None): + with path.open("w", encoding="ascii") as f: + f.write(f"seq {seq}\n") + if eps is not None: + f.write(f"eps {rat_to_str(eps)}\n") + if margin is not None: + f.write(f"margin {rat_to_str(margin)}\n") + for q, k in enumerate(prev.tolist()): + f.write(f"prev {q} {k}\n") + for q in range(seq): + for k in range(seq): + f.write(f"score {q} {k} {rat_to_str(scores[q][k])}\n") + for q in range(seq): + for k in range(seq): + f.write(f"weight {q} {k} {rat_to_str(weights[q][k])}\n") + + +def write_value_range(path: Path, seq: int, values, decimals: int) -> None: + vals_rat = [rat_from_float(float(values[k]), decimals) for k in range(seq)] + lo = min(vals_rat) + hi = max(vals_rat) + with path.open("w", encoding="ascii") as f: + f.write(f"seq {seq}\n") + f.write(f"lo {rat_to_str(lo)}\n") + f.write(f"hi {rat_to_str(hi)}\n") + for k, val in enumerate(vals_rat): + f.write(f"val {k} {rat_to_str(val)}\n") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--output", required=True, help="Path to write certificate") + parser.add_argument("--scores-out", help="Optional path for raw scores/weights dump") + parser.add_argument("--layer", type=int, default=0, help="Transformer layer index") + parser.add_argument("--head", type=int, default=0, help="Attention head index") + parser.add_argument("--seq", type=int, default=32, help="Sequence length") + parser.add_argument("--pattern-length", type=int, default=16, help="Pattern length") + parser.add_argument("--random-pattern", action="store_true", help="Use random token pattern") + parser.add_argument("--seed", type=int, default=0, help="RNG seed for random pattern") + parser.add_argument("--decimals", type=int, default=6, help="Decimal rounding for rationals") + parser.add_argument("--model", default="gpt2", help="HuggingFace model name") + parser.add_argument("--device", default="cpu", help="Torch device") + parser.add_argument("--values-out", 
help="Optional path for a value-range certificate") + parser.add_argument("--value-dim", type=int, default=0, + help="Value dimension index for the value-range certificate") + args = parser.parse_args() + + if args.seq <= 0: + raise SystemExit("seq must be positive") + + tokens = build_tokens(args.seq, args.pattern_length, args.random_pattern, args.seed) + prev = build_prev(tokens) + + model = GPT2Model.from_pretrained(args.model) + model.to(args.device) + input_ids = torch.tensor(tokens, dtype=torch.long, device=args.device).unsqueeze(0) + + scores, weights, values = compute_scores_weights(model, input_ids, args.layer, args.head, + args.device) + + scores_rat = [[rat_from_float(float(scores[q, k]), args.decimals) for k in range(args.seq)] + for q in range(args.seq)] + weights_rat = [[rat_from_float(float(weights[q, k]), args.decimals) for k in range(args.seq)] + for q in range(args.seq)] + + for q in range(args.seq): + total = sum(weights_rat[q], Fraction(0)) + if total == 0: + raise SystemExit(f"zero weight sum at q={q}") + weights_rat[q] = [w / total for w in weights_rat[q]] + + eps = Fraction(0) + margin = None + for q in range(1, args.seq): + prev_q = prev[q] + prev_w = weights_rat[q][prev_q] + max_other = max(weights_rat[q][k] for k in range(args.seq) if k != prev_q) + deficit = Fraction(1) - prev_w + eps = max(eps, max(max_other, deficit)) + + diffs = [scores_rat[q][prev_q] - scores_rat[q][k] + for k in range(args.seq) if k != prev_q] + if diffs: + min_diff = min(diffs) + margin = min_diff if margin is None else min(margin, min_diff) + + if margin is None: + margin = Fraction(0) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + write_scores(output_path, args.seq, prev, scores_rat, weights_rat, eps=eps, margin=margin) + + if args.scores_out: + scores_path = Path(args.scores_out) + scores_path.parent.mkdir(parents=True, exist_ok=True) + write_scores(scores_path, args.seq, prev, scores_rat, weights_rat) + + if 
args.values_out: + if args.value_dim < 0 or args.value_dim >= values.shape[1]: + raise SystemExit(f"value-dim must be in [0, {values.shape[1] - 1}]") + values_path = Path(args.values_out) + values_path.parent.mkdir(parents=True, exist_ok=True) + write_value_range(values_path, args.seq, values[:, args.value_dim], args.decimals) + + print(f"Wrote certificate to {output_path}") + if args.scores_out: + print(f"Wrote scores dump to {scores_path}") + if args.values_out: + print(f"Wrote value-range certificate to {values_path}") + + +if __name__ == "__main__": + main() From 943e80b2f490d9314bf7ab4069592e23a27789e8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 14:13:54 +0100 Subject: [PATCH 094/244] Add induction certification pipeline and downstream bounds --- AGENTS.md | 26 ++ Nfp.lean | 2 + Nfp/Circuit.lean | 2 + Nfp/Circuit/Cert/DownstreamLinear.lean | 62 +++ Nfp/Circuit/Cert/LogitDiff.lean | 58 +++ Nfp/Circuit/Cert/SoftmaxMargin.lean | 31 +- Nfp/Circuit/Cert/ValueRange.lean | 9 + Nfp/Circuit/Layers.lean | 1 + Nfp/Circuit/Layers/Induction.lean | 130 ++++++- Nfp/Circuit/Layers/Softmax.lean | 127 ++++++ Nfp/Cli.lean | 99 ++++- Nfp/IO.lean | 471 ++++++++++++++++++++-- Nfp/IO/Pure.lean | 496 +++++++++++++++++++++++- Nfp/Model.lean | 9 + Nfp/Model/Gpt2.lean | 64 +++ Nfp/Model/InductionHead.lean | 45 +++ Nfp/Model/InductionPrompt.lean | 32 ++ Nfp/Sound.lean | 9 + Nfp/Sound/Gpt2/HeadInputs.lean | 57 +++ Nfp/Sound/Induction.lean | 439 +++++++++++++++++++++ Nfp/Sound/Linear/FinFold.lean | 52 +++ README.md | 128 +++++- scripts/build_downstream_linear_cert.py | 66 ++++ scripts/build_gpt2_head_inputs.py | 295 ++++++++++++++ scripts/build_gpt2_induction_cert.py | 96 ++++- 25 files changed, 2732 insertions(+), 74 deletions(-) create mode 100644 Nfp/Circuit/Cert/DownstreamLinear.lean create mode 100644 Nfp/Circuit/Cert/LogitDiff.lean create mode 100644 Nfp/Circuit/Layers/Softmax.lean create mode 100644 Nfp/Model.lean create mode 100644 Nfp/Model/Gpt2.lean create 
mode 100644 Nfp/Model/InductionHead.lean create mode 100644 Nfp/Model/InductionPrompt.lean create mode 100644 Nfp/Sound.lean create mode 100644 Nfp/Sound/Gpt2/HeadInputs.lean create mode 100644 Nfp/Sound/Induction.lean create mode 100644 Nfp/Sound/Linear/FinFold.lean create mode 100644 scripts/build_downstream_linear_cert.py create mode 100644 scripts/build_gpt2_head_inputs.py diff --git a/AGENTS.md b/AGENTS.md index 549fda9..4500f6a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -287,6 +287,10 @@ but you **must** update this list in the same commit. - Softmax-margin certificate payloads and checker soundness. - `Nfp/Circuit/Cert/ValueRange.lean` - Value-range certificate payloads and checker soundness. +- `Nfp/Circuit/Cert/LogitDiff.lean` + - Logit-diff lower-bound computation for induction certificates. +- `Nfp/Circuit/Cert/DownstreamLinear.lean` + - Downstream linear error certificates for end-to-end induction bounds. - `Nfp/Circuit/Typed.lean` - Typed circuit wrapper and interface-level equivalence checker. - `Nfp/Circuit/Compose.lean` @@ -307,6 +311,8 @@ but you **must** update this list in the same commit. - Reshape combinators for product-typed circuit interfaces. - `Nfp/Circuit/Layers/Heads.lean` - Head split/merge combinators for transformer-shaped indices. +- `Nfp/Circuit/Layers/Softmax.lean` + - Softmax helpers and margin-based bounds for layer reasoning. - `Nfp/Circuit/Layers/Attention.lean` - Q/K/V, output projection wiring, and attention score/mixing core. - `Nfp/Circuit/Layers/Induction.lean` @@ -330,6 +336,26 @@ but you **must** update this list in the same commit. - `Nfp.lean` - Top-level reexports and axioms dashboard (`#print axioms`). +### 5.7 Sound certification +- `Nfp/Sound/Induction.lean` + - Sound builders for induction certificates from exact inputs. +- `Nfp/Sound/Linear/FinFold.lean` + - Tail-recursive folds and sums for sound linear computations. +- `Nfp/Sound/Gpt2/HeadInputs.lean` + - Sound construction of GPT-2 induction head inputs. 
+- `Nfp/Sound.lean` + - Aggregator for sound certification modules. + +### 5.8 Model inputs +- `Nfp/Model/InductionHead.lean` + - Exact induction-head input payloads (embeddings and projection weights). +- `Nfp/Model/InductionPrompt.lean` + - Prompt utilities (`prev` map and active set for periodic prompts). +- `Nfp/Model/Gpt2.lean` + - Exact GPT-2 head-slice data and embedding helpers. +- `Nfp/Model.lean` + - Aggregator for model input modules. + If you introduce a new conceptual layer: - either extend the closest existing file, - or add a new module with a clear name + top docstring, diff --git a/Nfp.lean b/Nfp.lean index 5b4254f..ee1adc1 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -5,6 +5,8 @@ import Nfp.Prob import Nfp.Mixer import Nfp.System import Nfp.Circuit +import Nfp.Model +import Nfp.Sound /-! Top-level reexports and trust dashboard for the NFP rewrite. diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index 3b2b6ca..cb952a7 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -8,6 +8,8 @@ import Nfp.Circuit.WellFormed import Nfp.Circuit.Cert import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Cert.ValueRange +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Typed import Nfp.Circuit.Compose import Nfp.Circuit.Gates diff --git a/Nfp/Circuit/Cert/DownstreamLinear.lean b/Nfp/Circuit/Cert/DownstreamLinear.lean new file mode 100644 index 0000000..69bb959 --- /dev/null +++ b/Nfp/Circuit/Cert/DownstreamLinear.lean @@ -0,0 +1,62 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Circuit.Cert + +/-! +Downstream linear certificates for end-to-end induction bounds. + +These certificates record a nonnegative error bound computed externally. +The checker only verifies arithmetic consistency (`error = gain * inputBound`) +and nonnegativity of the reported quantities. 
+-/ + +namespace Nfp + +namespace Circuit + +/-- Certificate payload for downstream linear error bounds. -/ +structure DownstreamLinearCert where + /-- Upper bound on the downstream logit-diff error. -/ + error : Rat + /-- Operator gain bound used to justify the error. -/ + gain : Rat + /-- Input magnitude bound used to justify the error. -/ + inputBound : Rat + +/-- Arithmetic properties enforced by `checkDownstreamLinearCert`. -/ +structure DownstreamLinearBounds (c : DownstreamLinearCert) : Prop where + /-- Error bound is nonnegative. -/ + error_nonneg : 0 ≤ c.error + /-- Gain bound is nonnegative. -/ + gain_nonneg : 0 ≤ c.gain + /-- Input bound is nonnegative. -/ + input_nonneg : 0 ≤ c.inputBound + /-- Error bound matches the reported gain/input product. -/ + error_eq : c.error = c.gain * c.inputBound + +/-- Boolean checker for downstream linear certificates. -/ +def checkDownstreamLinearCert (c : DownstreamLinearCert) : Bool := + decide (0 ≤ c.error) && + decide (0 ≤ c.gain) && + decide (0 ≤ c.inputBound) && + decide (c.error = c.gain * c.inputBound) + +/-- `checkDownstreamLinearCert` is sound for `DownstreamLinearBounds`. 
-/ +theorem checkDownstreamLinearCert_sound (c : DownstreamLinearCert) : + checkDownstreamLinearCert c = true → DownstreamLinearBounds c := by + intro h + have h' : + ((0 ≤ c.error ∧ 0 ≤ c.gain) ∧ 0 ≤ c.inputBound) ∧ + c.error = c.gain * c.inputBound := by + simpa [checkDownstreamLinearCert, Bool.and_eq_true] using h + rcases h' with ⟨⟨⟨herror, hgain⟩, hinput⟩, heq⟩ + refine + { error_nonneg := herror + gain_nonneg := hgain + input_nonneg := hinput + error_eq := heq } + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean new file mode 100644 index 0000000..7d5f731 --- /dev/null +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -0,0 +1,58 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.Finset.Image +import Nfp.Circuit.Layers.Induction + +/-! +Lower bounds for logit-diff contributions from induction-style heads. +-/ + +namespace Nfp + +namespace Circuit + +variable {seq : Nat} + +/-- Compute a lower bound on the logit-diff contribution over active queries. -/ +def logitDiffLowerBound (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (eps lo hi : Rat) (vals : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let gap := eps * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap + let img := active.image f + have himg : img.Nonempty := h.image f + exact some (Finset.min' img himg) + else + exact none + +/-- The computed lower bound is below every active `prev` value minus the tolerance gap. 
-/ +theorem logitDiffLowerBound_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (eps lo hi : Rat) (vals : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBound active prev eps lo hi vals = some lb → + lb ≤ vals (prev q) - eps * (hi - lo) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + have hbound' : + (active.image (fun q => vals (prev q) - eps * (hi - lo))).min' + (hnonempty.image (fun q => vals (prev q) - eps * (hi - lo))) = lb := by + simpa [logitDiffLowerBound, hnonempty] using hbound + let gap := eps * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap + have hmem : f q ∈ (active.image f) := by + refine Finset.mem_image.2 ?_ + exact ⟨q, hq, rfl⟩ + have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := + Finset.min'_le _ _ hmem + have hlb : lb = (active.image f).min' (hnonempty.image f) := by + simpa [f, gap] using hbound'.symm + simpa [f, gap, hlb] using hmin + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean b/Nfp/Circuit/Cert/SoftmaxMargin.lean index a59ffe6..f8157a9 100644 --- a/Nfp/Circuit/Cert/SoftmaxMargin.lean +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -23,6 +23,8 @@ structure SoftmaxMarginCert (seq : Nat) where eps : Rat /-- Score margin used to justify weight bounds. -/ margin : Rat + /-- Active queries for which bounds are checked. -/ + active : Finset (Fin seq) /-- `prev` selector for induction-style attention. -/ prev : Fin seq → Fin seq /-- Score matrix entries. -/ @@ -33,9 +35,7 @@ structure SoftmaxMarginCert (seq : Nat) where /-- Boolean checker for softmax-margin certificates. 
-/ def checkSoftmaxMarginCert [NeZero seq] (c : SoftmaxMarginCert seq) : Bool := finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - if q = 0 then - true - else + if q ∈ c.active then finsetAll (Finset.univ : Finset (Fin seq)) (fun k => decide (0 ≤ c.weights q k) && (if k = c.prev q then @@ -48,19 +48,20 @@ def checkSoftmaxMarginCert [NeZero seq] (c : SoftmaxMarginCert seq) : Bool := else decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && decide (1 ≤ c.weights q (c.prev q) + c.eps) && - decide ((∑ k, c.weights q k) = 1)) + decide ((∑ k, c.weights q k) = 1) + else + true) -/-- `checkSoftmaxMarginCert` is sound for `SoftmaxMarginBounds`. -/ +/-- `checkSoftmaxMarginCert` is sound for `SoftmaxMarginBoundsOn`. -/ theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : checkSoftmaxMarginCert c = true → - Layers.SoftmaxMarginBounds (Val := Rat) c.eps c.margin c.prev c.scores c.weights := by + Layers.SoftmaxMarginBoundsOn (Val := Rat) c.eps c.margin (fun q => q ∈ c.active) + c.prev c.scores c.weights := by classical intro hcheck have hqall : ∀ q ∈ (Finset.univ : Finset (Fin seq)), - (if q = 0 then - true - else + (if q ∈ c.active then finsetAll (Finset.univ : Finset (Fin seq)) (fun k => decide (0 ≤ c.weights q k) && (if k = c.prev q then @@ -73,13 +74,13 @@ theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : else decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && decide (1 ≤ c.weights q (c.prev q) + c.eps) && - decide ((∑ k, c.weights q k) = 1)) = true := by + decide ((∑ k, c.weights q k) = 1) + else + true) = true := by have hcheck' : checkSoftmaxMarginCert c = true := hcheck have hcheck'' : finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - if q = 0 then - true - else + if q ∈ c.active then finsetAll (Finset.univ : Finset (Fin seq)) (fun k => decide (0 ≤ c.weights q k) && (if k = c.prev q then @@ -92,7 +93,9 @@ theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : 
else decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && decide (1 ≤ c.weights q (c.prev q) + c.eps) && - decide ((∑ k, c.weights q k) = 1)) = true := by + decide ((∑ k, c.weights q k) = 1) + else + true) = true := by simpa [checkSoftmaxMarginCert] using hcheck' exact (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hcheck'' refine diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean index c18a46b..33d03f8 100644 --- a/Nfp/Circuit/Cert/ValueRange.lean +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -17,6 +17,13 @@ open scoped BigOperators variable {seq : Nat} +/-- Metadata describing a logit-diff direction (target minus negative token). -/ +structure DirectionSpec where + /-- Target token id for the logit-diff direction. -/ + target : Nat + /-- Negative token id for the logit-diff direction. -/ + negative : Nat + /-- Certificate payload for value-range bounds (Rat-valued). -/ structure ValueRangeCert (seq : Nat) where /-- Lower bound for values. -/ @@ -25,6 +32,8 @@ structure ValueRangeCert (seq : Nat) where hi : Rat /-- Value entries. -/ vals : Fin seq → Rat + /-- Optional logit-diff direction metadata (ignored by the checker). -/ + direction : Option DirectionSpec /-- Boolean checker for value-range certificates. 
-/ def checkValueRangeCert [NeZero seq] (c : ValueRangeCert seq) : Bool := diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean index dfbe548..5a54a6d 100644 --- a/Nfp/Circuit/Layers.lean +++ b/Nfp/Circuit/Layers.lean @@ -5,6 +5,7 @@ import Nfp.Circuit.Layers.Tensor import Nfp.Circuit.Layers.Reshape import Nfp.Circuit.Layers.Heads import Nfp.Circuit.Layers.Attention +import Nfp.Circuit.Layers.Softmax import Nfp.Circuit.Layers.Induction import Nfp.Circuit.Layers.TransformerBlock diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index 33d2dcc..9b2ac97 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ -74,6 +74,12 @@ def InductionSpecApprox (ε : Val) (out vals : Fin (Nat.succ n) → Val) : Prop := ∀ q, q ≠ 0 → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε +/-- Approximate induction-head spec restricted to active queries. -/ +def InductionSpecApproxOn (ε : Val) (active : Fin (Nat.succ n) → Prop) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, active q → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε + /-- Exact induction spec implies the approximate spec for any nonnegative tolerance. -/ theorem inductionSpecApprox_of_spec (ε : Val) (hε : 0 ≤ ε) (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) @@ -175,6 +181,34 @@ structure OneHotApproxBoundsOn (ε : Val) (prev : Fin seq → Fin seq) /-- Non-prev weights are at most `ε` on nonzero queries. -/ other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε +/-- Approximate one-hot bounds for attention weights on active queries. -/ +structure OneHotApproxBoundsOnActive (ε : Val) (active : Fin seq → Prop) + (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on active queries. -/ + nonneg : ∀ q, active q → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on active queries. 
-/ + sum_one : ∀ q, active q → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on active queries. -/ + prev_large : ∀ q, active q → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on active queries. -/ + other_le : ∀ q, active q → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Lift global approximate bounds to an active-set version. -/ +theorem oneHotApproxBoundsOnActive_of_on (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) + (h : OneHotApproxBoundsOn (Val := Val) ε prev weights) : + OneHotApproxBoundsOnActive (Val := Val) ε (fun q => q ≠ 0) prev weights := by + refine { nonneg := ?_, sum_one := ?_, prev_large := ?_, other_le := ?_ } + · intro q hq k + exact h.nonneg q hq k + · intro q hq + exact h.sum_one q hq + · intro q hq + exact h.prev_large q hq + · intro q hq k hk + exact h.other_le q hq k hk + /-- Approximate induction weights: prev weight near one, others at most `ε`. -/ def InductionWeightsApprox (ε : Val) (prev : Fin seq → Fin seq) (weights : Fin seq → Fin seq → Val) : Prop := @@ -199,15 +233,16 @@ variable {n : Nat} local instance : NeZero (Nat.succ n) := ⟨by simp⟩ -/-- Approximate one-hot weights plus bounded values yield an approximate induction spec. -/ -theorem inductionSpecApprox_of_oneHotApprox_valueRange - (ε lo hi : Val) +/-- Approximate one-hot weights plus bounded values yield an approximate induction spec + on active queries. 
-/ +theorem inductionSpecApproxOn_of_oneHotApprox_valueRange + (ε lo hi : Val) (active : Fin (Nat.succ n) → Prop) (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) (vals : Fin (Nat.succ n) → Val) - (hweights : OneHotApproxBoundsOn (Val := Val) ε prev weights) + (hweights : OneHotApproxBoundsOnActive (Val := Val) ε active prev weights) (hvals : ValueRangeBounds (Val := Val) lo hi vals) : - InductionSpecApprox (Val := Val) (n := n) (ε * (hi - lo)) prev + InductionSpecApproxOn (Val := Val) (n := n) (ε * (hi - lo)) active prev (fun q => dotProduct (weights q) vals) vals := by classical intro q hq @@ -408,6 +443,34 @@ theorem inductionSpecApprox_of_oneHotApprox_valueRange exact (sub_le_iff_le_add).1 hlow exact ⟨hupper, hlower⟩ +/-- Approximate one-hot weights plus bounded values yield an approximate induction spec. -/ +theorem inductionSpecApprox_of_oneHotApprox_valueRange + (ε lo hi : Val) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) + (vals : Fin (Nat.succ n) → Val) + (hweights : OneHotApproxBoundsOn (Val := Val) ε prev weights) + (hvals : ValueRangeBounds (Val := Val) lo hi vals) : + InductionSpecApprox (Val := Val) (n := n) (ε * (hi - lo)) prev + (fun q => dotProduct (weights q) vals) vals := by + have hweights' : + OneHotApproxBoundsOnActive (Val := Val) ε (fun q => q ≠ 0) prev weights := + oneHotApproxBoundsOnActive_of_on (Val := Val) (seq := Nat.succ n) + (ε := ε) (prev := prev) (weights := weights) hweights + exact + inductionSpecApproxOn_of_oneHotApprox_valueRange + (Val := Val) + (n := n) + (ε := ε) + (lo := lo) + (hi := hi) + (active := fun q => q ≠ 0) + (prev := prev) + (weights := weights) + (vals := vals) + (hweights := hweights') + (hvals := hvals) + end ApproxOutput section SoftmaxMargin @@ -429,6 +492,43 @@ structure SoftmaxMarginBounds (ε margin : Val) (prev : Fin seq → Fin seq) /-- Non-prev weights are at most `ε` on nonzero queries. 
-/ other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε +/-- Softmax margin certificates for approximate one-hot weights on active queries. -/ +structure SoftmaxMarginBoundsOn (ε margin : Val) (active : Fin seq → Prop) + (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) : Prop where + /-- Score gap between `prev` and other keys on active queries. -/ + score_margin : ∀ q, active q → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) + /-- All weights are nonnegative on active queries. -/ + nonneg : ∀ q, active q → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on active queries. -/ + sum_one : ∀ q, active q → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on active queries. -/ + prev_large : ∀ q, active q → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on active queries. -/ + other_le : ∀ q, active q → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Lift global softmax-margin bounds to an active-set version. -/ +theorem softmaxMarginBoundsOn_of_on (ε margin : Val) (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : + SoftmaxMarginBoundsOn (Val := Val) ε margin (fun q => q ≠ 0) prev scores weights := by + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + exact h.score_margin q hq k hk + · intro q hq k + exact h.nonneg q hq k + · intro q hq + exact h.sum_one q hq + · intro q hq + exact h.prev_large q hq + · intro q hq k hk + exact h.other_le q hq k hk + /-- Margin certificates yield approximate one-hot bounds for the weights. 
-/ theorem oneHotApproxBounds_of_softmaxMargin (ε margin : Val) (prev : Fin seq → Fin seq) (scores weights : Fin seq → Fin seq → Val) @@ -464,6 +564,26 @@ theorem inductionWeightsApprox_of_softmaxMargin (ε margin : Val) end SoftmaxMargin +section SoftmaxMarginActive + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} + +/-- Margin certificates yield approximate one-hot bounds on active queries. -/ +theorem oneHotApproxBoundsOnActive_of_softmaxMargin (ε margin : Val) + (active : Fin seq → Prop) + (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBoundsOn (Val := Val) ε margin active prev scores weights) : + OneHotApproxBoundsOnActive (Val := Val) ε active prev weights := by + exact + { nonneg := h.nonneg + sum_one := h.sum_one + prev_large := h.prev_large + other_le := h.other_le } + +end SoftmaxMarginActive + section Attention variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] diff --git a/Nfp/Circuit/Layers/Softmax.lean b/Nfp/Circuit/Layers/Softmax.lean new file mode 100644 index 0000000..8a11656 --- /dev/null +++ b/Nfp/Circuit/Layers/Softmax.lean @@ -0,0 +1,127 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Field +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Analysis.Complex.Exponential +import Mathlib.Data.Finset.Card + +/-! +Real-valued softmax utilities and margin-based bounds. + +These lemmas provide the analytical bridge from score gaps to softmax weight +upper bounds. +-/ + +namespace Nfp + +namespace Circuit + +open scoped BigOperators + +noncomputable section + +variable {seq : Nat} + +/-- Real softmax over a finite score vector. 
-/ +def softmax (scores : Fin seq → Real) (k : Fin seq) : Real := + Real.exp (scores k) / ∑ j, Real.exp (scores j) + +private lemma softmax_denom_pos [NeZero seq] (scores : Fin seq → Real) : + 0 < ∑ j, Real.exp (scores j) := by + classical + have hnonempty : (Finset.univ : Finset (Fin seq)).Nonempty := by + refine ⟨⟨0, ?_⟩, by simp⟩ + exact Nat.pos_of_ne_zero (NeZero.ne seq) + exact Finset.sum_pos (fun _ _ => Real.exp_pos _) hnonempty + +lemma softmax_nonneg [NeZero seq] (scores : Fin seq → Real) (k : Fin seq) : + 0 ≤ softmax scores k := by + have hdenom : 0 < ∑ j, Real.exp (scores j) := softmax_denom_pos scores + exact (div_nonneg (Real.exp_pos _).le (le_of_lt hdenom)) + +lemma softmax_sum_one [NeZero seq] (scores : Fin seq → Real) : + (∑ k, softmax scores k) = 1 := by + classical + have hdenom : (∑ j, Real.exp (scores j)) ≠ 0 := + ne_of_gt (softmax_denom_pos scores) + have hsum : + (∑ k, Real.exp (scores k) / ∑ j, Real.exp (scores j)) = + (∑ k, Real.exp (scores k)) / ∑ j, Real.exp (scores j) := by + simpa using + (Finset.sum_div (Finset.univ) (fun k => Real.exp (scores k)) + (∑ j, Real.exp (scores j))).symm + calc + ∑ k, softmax scores k + = ∑ k, Real.exp (scores k) / ∑ j, Real.exp (scores j) := by + simp [softmax] + _ = (∑ k, Real.exp (scores k)) / ∑ j, Real.exp (scores j) := hsum + _ = 1 := by + simp [hdenom] + +lemma softmax_le_one [NeZero seq] (scores : Fin seq → Real) (k : Fin seq) : + softmax scores k ≤ 1 := by + classical + have hdenom_pos : 0 < ∑ j, Real.exp (scores j) := softmax_denom_pos scores + have hnum_le : Real.exp (scores k) ≤ ∑ j, Real.exp (scores j) := by + have hnonneg : ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ Real.exp (scores j) := + fun _ _ => (Real.exp_pos _).le + simpa using (Finset.single_le_sum hnonneg (by simp)) + have hdiv := (div_le_one hdenom_pos).2 hnum_le + simpa [softmax] using hdiv + +lemma exp_neg_le_inv_one_add {m : Real} (hm : 0 ≤ m) : + Real.exp (-m) ≤ 1 / (1 + m) := by + have hpos : 0 < 1 + m := add_pos_of_pos_of_nonneg 
zero_lt_one hm + have hle : 1 + m ≤ Real.exp m := by + simpa [add_comm] using (Real.add_one_le_exp m) + have hdiv : 1 / Real.exp m ≤ 1 / (1 + m) := + one_div_le_one_div_of_le hpos hle + simpa [Real.exp_neg] using hdiv + +lemma softmax_other_le_exp_neg [NeZero seq] (scores : Fin seq → Real) + {prev k : Fin seq} {m : Real} (hmargin : scores k + m ≤ scores prev) : + softmax scores k ≤ Real.exp (-m) := by + classical + let denom : Real := ∑ j, Real.exp (scores j) + have hdenom_pos : 0 < denom := softmax_denom_pos scores + have hdenom_ge : Real.exp (scores prev) ≤ denom := by + have hnonneg : ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ Real.exp (scores j) := + fun _ _ => (Real.exp_pos _).le + simpa [denom] using + (Finset.single_le_sum hnonneg (by simp : prev ∈ (Finset.univ : Finset (Fin seq)))) + have hinv : 1 / denom ≤ 1 / Real.exp (scores prev) := + one_div_le_one_div_of_le (Real.exp_pos _) hdenom_ge + have hmul := + mul_le_mul_of_nonneg_left hinv (Real.exp_pos (scores k)).le + have hratio : + Real.exp (scores k) / Real.exp (scores prev) = + Real.exp (scores k - scores prev) := by + symm + exact Real.exp_sub (scores k) (scores prev) + have hk : scores k ≤ scores prev - m := (le_sub_iff_add_le).2 hmargin + have hdiff : scores k - scores prev ≤ -m := by + have hsub := sub_le_sub_right hk (scores prev) + simpa [sub_eq_add_neg, add_assoc, add_left_comm, add_comm] using hsub + have hle : Real.exp (scores k - scores prev) ≤ Real.exp (-m) := + Real.exp_le_exp.mpr hdiff + have hsoft : + Real.exp (scores k) / denom ≤ Real.exp (scores k) / Real.exp (scores prev) := by + simpa [denom, div_eq_mul_inv] using hmul + calc + softmax scores k + = Real.exp (scores k) / denom := by + simp [softmax, denom] + _ ≤ Real.exp (scores k) / Real.exp (scores prev) := hsoft + _ = Real.exp (scores k - scores prev) := hratio + _ ≤ Real.exp (-m) := hle + +lemma softmax_other_le_inv_one_add [NeZero seq] (scores : Fin seq → Real) + {prev k : Fin seq} {m : Real} (hm : 0 ≤ m) (hmargin : scores k + m ≤ 
scores prev) : + softmax scores k ≤ 1 / (1 + m) := + (softmax_other_le_exp_neg (scores := scores) hmargin).trans (exp_neg_le_inv_one_add hm) + +end + +end Circuit + +end Nfp diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index bf61b4a..0fcf11b 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -29,7 +29,12 @@ def versionCmd : Cmd := `[Cli| def runInductionCertify (p : Parsed) : IO UInt32 := do let scoresPath := p.flag! "scores" |>.as! String let valuesPath? := (p.flag? "values").map (·.as! String) - IO.runInductionCertify scoresPath valuesPath? + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertify scoresPath valuesPath? minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? /-- `nfp induction certify` subcommand. -/ def inductionCertifyCmd : Cmd := `[Cli| @@ -38,6 +43,93 @@ def inductionCertifyCmd : Cmd := `[Cli| FLAGS: scores : String; "Path to the softmax-margin certificate file." values : String; "Optional path to a value-range certificate file." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal; requires --values). Defaults \ + to 0 when direction metadata is present." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + +/-- `nfp induction certify-sound` subcommand. -/ +def runInductionCertifySound (p : Parsed) : IO UInt32 := do + let scoresPath := p.flag! "scores" |>.as! String + let valuesPath := p.flag! "values" |>.as! String + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! 
String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifySound scoresPath valuesPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? + +/-- `nfp induction certify_sound` subcommand. -/ +def inductionCertifySoundCmd : Cmd := `[Cli| + certify_sound VIA runInductionCertifySound; + "Check induction certificates from raw scores/values." + FLAGS: + scores : String; "Path to the raw scores/weights file." + values : String; "Path to the raw value entries file." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal). Defaults to 0 when \ + direction metadata is present." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + +/-- `nfp induction certify_end_to_end` subcommand. -/ +def runInductionCertifyEndToEnd (p : Parsed) : IO UInt32 := do + let scoresPath := p.flag! "scores" |>.as! String + let valuesPath := p.flag! "values" |>.as! String + let downstreamPath := p.flag! "downstream" |>.as! String + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyEndToEnd scoresPath valuesPath downstreamPath + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + +/-- `nfp induction certify_end_to_end` subcommand. -/ +def inductionCertifyEndToEndCmd : Cmd := `[Cli| + certify_end_to_end VIA runInductionCertifyEndToEnd; + "Check end-to-end induction bounds with a downstream error certificate." + FLAGS: + scores : String; "Path to the softmax-margin certificate file." 
+ values : String; "Path to the value-range certificate file." + downstream : String; "Path to the downstream linear certificate file." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal). Defaults to 0 when \ + direction metadata is present." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + +/-- `nfp induction certify_head` subcommand. -/ +def runInductionCertifyHead (p : Parsed) : IO UInt32 := do + let inputsPath := p.flag! "inputs" |>.as! String + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? + +/-- `nfp induction certify_head` subcommand. -/ +def inductionCertifyHeadCmd : Cmd := `[Cli| + certify_head VIA runInductionCertifyHead; + "Check induction certificates from exact head inputs." + FLAGS: + inputs : String; "Path to the induction head input file." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal). Defaults to 0." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] /-- Induction-head subcommands. -/ @@ -45,7 +137,10 @@ def inductionCmd : Cmd := `[Cli| induction NOOP; "Induction-head utilities." 
SUBCOMMANDS: - inductionCertifyCmd + inductionCertifyCmd; + inductionCertifySoundCmd; + inductionCertifyEndToEndCmd; + inductionCertifyHeadCmd ] /-- The root CLI command. -/ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index e42ea6e..4d1a85d 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -1,6 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.IO.Pure +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Sound.Induction /-! IO wrappers for loading and checking induction certificates. @@ -18,12 +21,37 @@ def loadSoftmaxMarginCert (path : System.FilePath) : let data ← IO.FS.readFile path return Pure.parseSoftmaxMarginCert data +/-- Load raw softmax-margin inputs from disk. -/ +def loadSoftmaxMarginRaw (path : System.FilePath) : + IO (Except String (Sigma Pure.SoftmaxMarginRaw)) := do + let data ← IO.FS.readFile path + return Pure.parseSoftmaxMarginRaw data + /-- Load a value-range certificate from disk. -/ def loadValueRangeCert (path : System.FilePath) : IO (Except String (Sigma ValueRangeCert)) := do let data ← IO.FS.readFile path return Pure.parseValueRangeCert data +/-- Load a downstream linear certificate from disk. -/ +def loadDownstreamLinearCert (path : System.FilePath) : + IO (Except String DownstreamLinearCert) := do + let data ← IO.FS.readFile path + return Pure.parseDownstreamLinearCert data + +/-- Load raw value-range inputs from disk. -/ +def loadValueRangeRaw (path : System.FilePath) : + IO (Except String (Sigma Pure.ValueRangeRaw)) := do + let data ← IO.FS.readFile path + return Pure.parseValueRangeRaw data + +/-- Load induction head inputs from disk. 
-/ +def loadInductionHeadInputs (path : System.FilePath) : + IO (Except String (Sigma (fun seq => + Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead))))) := do + let data ← IO.FS.readFile path + return Pure.parseInductionHeadInputs data + private def checkSoftmaxMargin (seq : Nat) (cert : SoftmaxMarginCert seq) : IO (Except String Unit) := match seq with @@ -50,44 +78,439 @@ private def checkValueRange (seq : Nat) (cert : ValueRangeCert seq) : else return Except.error "value-range certificate rejected" +private def parseRatOpt (label : String) (raw? : Option String) : + Except String (Option Rat) := + match raw? with + | none => Except.ok none + | some raw => + match Pure.parseRat raw with + | Except.ok v => Except.ok (some v) + | Except.error msg => Except.error s!"invalid {label}: {msg}" + /-- Check induction certificates and print a short status line. -/ def runInductionCertify (scoresPath : System.FilePath) - (valuesPath? : Option System.FilePath) : IO UInt32 := do - let parsedScores ← loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => + (valuesPath? : Option System.FilePath) (minActive? : Option Nat) + (minLogitDiffStr? : Option String) (minMarginStr? : Option String) + (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← checkSoftmaxMargin seq cert - match scoresOk with + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + if minLogitDiff?.isSome && valuesPath?.isNone then + IO.eprintln "error: min-logit-diff requires --values" + return 2 + let parsedScores ← loadSoftmaxMarginCert scoresPath + match parsedScores with | Except.error msg => IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - match valuesPath? with - | none => - IO.println s!"ok: softmax-margin certificate accepted (seq={seq})" - return 0 - | some valuesPath => + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + match valuesPath? 
with + | none => + IO.println + s!"ok: softmax-margin certificate accepted \ + (seq={seq}, active={activeCount})" + return 0 + | some valuesPath => + let parsedValues ← loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let tol := cert.eps * (certVals'.hi - certVals'.lo) + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, tol={tol}, \ + logitDiffLB={logitDiffLB})" + return 0 + +/-- Build and check induction certificates from raw scores/values. -/ +def runInductionCertifySound (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (minActive? : Option Nat) + (minLogitDiffStr? : Option String) (minMarginStr? : Option String) + (maxEpsStr? 
: Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let parsedScores ← loadSoftmaxMarginRaw scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, raw⟩ => + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildSoftmaxMarginCert? 
raw.active raw.prev raw.scores raw.weights with + | none => + IO.eprintln "error: softmax-margin inputs rejected" + return 2 + | some ⟨cert, _⟩ => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← loadValueRangeRaw valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, rawVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln + s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let rawVals' : Pure.ValueRangeRaw seq := by + simpa [hseq'] using rawVals + match Sound.buildValueRangeCert? rawVals'.vals rawVals'.direction with + | none => + IO.eprintln "error: value-range inputs rejected" + return 2 + | some ⟨certVals, _⟩ => + let tol := cert.eps * (certVals.hi - certVals.lo) + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals.lo certVals.hi certVals.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? 
with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return 2 + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={tol}, logitDiffLB={logitDiffLB})" + return 0 + +/-- Check end-to-end induction certificates with a downstream error bound. -/ +def runInductionCertifyEndToEnd (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (downstreamPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let parsedScores ← loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 let parsedValues ← loadValueRangeCert valuesPath match parsedValues with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨seqVals, certVals⟩ => - if seqVals ≠ seq then + if hseq : seqVals ≠ seq then IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" return 2 - let valuesOk ← checkValueRange seqVals certVals - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? 
with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let parsedDownstream ← loadDownstreamLinearCert downstreamPath + match parsedDownstream with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok downstream => + let downstreamOk := Circuit.checkDownstreamLinearCert downstream + if downstreamOk then + let finalLB := logitDiffLB - downstream.error + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if finalLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end logitDiffLB {finalLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: end-to-end induction bound certified \ + (seq={seq}, active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstream.error}, \ + finalLB={finalLB})" + return 0 + else + IO.eprintln "error: downstream certificate rejected" + return 2 + +/-- Build and check induction certificates from exact head inputs. -/ +def runInductionCertifyHead (inputsPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let parsedInputs ← loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, ⟨dModel, ⟨dHead, inputs⟩⟩⟩ => + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildInductionCertFromHead? inputs with + | none => + IO.eprintln "error: head inputs rejected" + return 2 + | some ⟨cert, _hcert⟩ => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let tol := cert.eps * (cert.values.hi - cert.values.lo) + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + cert.values.lo cert.values.hi cert.values.vals + let effectiveMinLogitDiff := + match minLogitDiff? with + | some v => some v + | none => some (0 : Rat) + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" return 2 - | Except.ok () => - let tol := cert.eps * (certVals.hi - certVals.lo) - IO.println s!"ok: induction bound certified (seq={seq}, tol={tol})" - return 0 + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? 
with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return 2 + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={tol}, logitDiffLB={logitDiffLB})" + return 0 end IO diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index da2fb2d..688e209 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -1,11 +1,14 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.Finset.Insert import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Cert.ValueRange +import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Model.InductionHead /-! -Pure parsing helpers for softmax-margin and value-range certificates. +Pure parsing helpers for softmax-margin, value-range, and downstream certificates. -/ namespace Nfp @@ -38,7 +41,8 @@ private def parseInt (s : String) : Except String Int := | some n => Except.ok n | none => Except.error s!"expected Int, got '{s}'" -private def parseRat (s : String) : Except String Rat := do +/-- Parse a rational literal of the form `a` or `a/b`. 
-/ +def parseRat (s : String) : Except String Rat := do match s.splitOn "/" with | [num] => return Rat.ofInt (← parseInt num) @@ -55,6 +59,8 @@ private def parseRat (s : String) : Except String Rat := do private structure SoftmaxMarginParseState (seq : Nat) where eps : Option Rat margin : Option Rat + active : Finset (Fin seq) + activeSeen : Bool prev : Fin seq → Option (Fin seq) scores : Fin seq → Fin seq → Option Rat weights : Fin seq → Fin seq → Option Rat @@ -62,6 +68,8 @@ private structure SoftmaxMarginParseState (seq : Nat) where private def initState (seq : Nat) : SoftmaxMarginParseState seq := { eps := none margin := none + active := ∅ + activeSeen := false prev := fun _ => none scores := fun _ _ => none weights := fun _ _ => none } @@ -87,6 +95,17 @@ private def setPrev {seq : Nat} (st : SoftmaxMarginParseState seq) else throw s!"prev index out of range: q={q}" +private def setActive {seq : Nat} (st : SoftmaxMarginParseState seq) + (q : Nat) : Except String (SoftmaxMarginParseState seq) := do + if hq : q < seq then + let qFin : Fin seq := ⟨q, hq⟩ + if qFin ∈ st.active then + throw s!"duplicate active entry for q={q}" + else + return { st with active := insert qFin st.active, activeSeen := true } + else + throw s!"active index out of range: q={q}" + private def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Rat) (q k : Nat) (v : Rat) : Except String (Fin seq → Fin seq → Option Rat) := do if hq : q < seq then @@ -124,6 +143,8 @@ private def parseLine {seq : Nat} (st : SoftmaxMarginParseState seq) throw "duplicate margin entry" else return { st with margin := some (← parseRat val) } + | ["active", q] => + setActive st (← parseNat q) | ["prev", q, k] => setPrev st (← parseNat q) (← parseNat k) | ["score", q, k, val] => @@ -160,9 +181,15 @@ private def finalizeState {seq : Nat} (hpos : 0 < seq) (st.scores q k).getD 0 let weightsFun : Fin seq → Fin seq → Rat := fun q k => (st.weights q k).getD 0 + let active := + if st.activeSeen then + st.active + 
else + (Finset.univ : Finset (Fin seq)).erase defaultPrev pure { eps := eps margin := margin + active := active prev := prevFun scores := scoresFun weights := weightsFun } @@ -198,15 +225,89 @@ def parseSoftmaxMarginCert (input : String) : let cert ← finalizeState hpos st return ⟨seq, cert⟩ +/-- Raw softmax-margin payload without `eps`/`margin`. -/ +structure SoftmaxMarginRaw (seq : Nat) where + /-- Active queries for which bounds are required. -/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Score matrix entries. -/ + scores : Fin seq → Fin seq → Rat + /-- Attention weight entries. -/ + weights : Fin seq → Fin seq → Rat + +private def finalizeRawState {seq : Nat} (hpos : 0 < seq) + (st : SoftmaxMarginParseState seq) : Except String (SoftmaxMarginRaw seq) := do + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then + throw "missing prev entries" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.scores q k).isSome)) then + throw "missing score entries" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.weights q k).isSome)) then + throw "missing weight entries" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev q).getD defaultPrev + let scoresFun : Fin seq → Fin seq → Rat := fun q k => + (st.scores q k).getD 0 + let weightsFun : Fin seq → Fin seq → Rat := fun q k => + (st.weights q k).getD 0 + let active := + if st.activeSeen then + st.active + else + (Finset.univ : Finset (Fin seq)).erase defaultPrev + pure + { active := active + prev := prevFun + scores := scoresFun + weights := weightsFun } + +/-- Parse a raw softmax-margin payload from text (ignores any `eps`/`margin`). 
-/ +def parseSoftmaxMarginRaw (input : String) : + Except String (Sigma SoftmaxMarginRaw) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + let seq ← + match seq? with + | some v => pure v + | none => throw "missing seq entry" + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let hpos : 0 < seq := Nat.succ_pos n + let st0 : SoftmaxMarginParseState seq := initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => parseLine st t) st0 + let raw ← finalizeRawState hpos st + return ⟨seq, raw⟩ + private structure ValueRangeParseState (seq : Nat) where lo : Option Rat hi : Option Rat vals : Fin seq → Option Rat + directionTarget : Option Nat + directionNegative : Option Nat private def initValueRangeState (seq : Nat) : ValueRangeParseState seq := { lo := none hi := none - vals := fun _ => none } + vals := fun _ => none + directionTarget := none + directionNegative := none } private def setVal {seq : Nat} (st : ValueRangeParseState seq) (k : Nat) (v : Rat) : Except String (ValueRangeParseState seq) := do @@ -240,6 +341,16 @@ private def parseValueLine {seq : Nat} (st : ValueRangeParseState seq) return { st with hi := some (← parseRat val) } | ["val", k, val] => setVal st (← parseNat k) (← parseRat val) + | ["direction-target", tok] => + if st.directionTarget.isSome then + throw "duplicate direction-target entry" + else + return { st with directionTarget := some (← parseNat tok) } + | ["direction-negative", tok] => + if st.directionNegative.isSome then + throw "duplicate direction-negative entry" + else + return { st with directionNegative := some (← parseNat tok) } | _ => throw s!"unrecognized line: '{String.intercalate " " tokens}'" @@ -257,7 +368,14 @@ 
private def finalizeValueState {seq : Nat} (st : ValueRangeParseState seq) : throw "missing value entries" let valsFun : Fin seq → Rat := fun k => (st.vals k).getD 0 - return { lo := lo, hi := hi, vals := valsFun } + let direction ← + match st.directionTarget, st.directionNegative with + | none, none => pure none + | some target, some negative => + pure (some { target := target, negative := negative }) + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + return { lo := lo, hi := hi, vals := valsFun, direction := direction } /-- Parse a value-range certificate from a text payload. -/ def parseValueRangeCert (input : String) : @@ -289,6 +407,376 @@ def parseValueRangeCert (input : String) : let cert ← finalizeValueState st return ⟨seq, cert⟩ +/-- Raw value-range payload without `lo`/`hi` bounds. -/ +structure ValueRangeRaw (seq : Nat) where + /-- Value entries. -/ + vals : Fin seq → Rat + /-- Optional logit-diff direction metadata. -/ + direction : Option Circuit.DirectionSpec + +private def finalizeValueRawState {seq : Nat} (st : ValueRangeParseState seq) : + Except String (ValueRangeRaw seq) := do + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then + throw "missing value entries" + let valsFun : Fin seq → Rat := fun k => + (st.vals k).getD 0 + let direction ← + match st.directionTarget, st.directionNegative with + | none, none => pure none + | some target, some negative => + pure (some { target := target, negative := negative }) + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + return { vals := valsFun, direction := direction } + +/-- Parse a raw value-range payload from text (ignores any `lo`/`hi`). -/ +def parseValueRangeRaw (input : String) : + Except String (Sigma ValueRangeRaw) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let mut seq? 
: Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + let seq ← + match seq? with + | some v => pure v + | none => throw "missing seq entry" + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let st0 : ValueRangeParseState seq := initValueRangeState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => parseValueLine st t) st0 + let raw ← finalizeValueRawState st + return ⟨seq, raw⟩ + +private structure DownstreamLinearParseState where + error : Option Rat + gain : Option Rat + inputBound : Option Rat + +private def initDownstreamLinearState : DownstreamLinearParseState := + { error := none, gain := none, inputBound := none } + +private def parseDownstreamLinearLine (st : DownstreamLinearParseState) + (tokens : List String) : Except String DownstreamLinearParseState := do + match tokens with + | ["error", val] => + if st.error.isSome then + throw "duplicate error entry" + else + return { st with error := some (← parseRat val) } + | ["gain", val] => + if st.gain.isSome then + throw "duplicate gain entry" + else + return { st with gain := some (← parseRat val) } + | ["input-bound", val] => + if st.inputBound.isSome then + throw "duplicate input-bound entry" + else + return { st with inputBound := some (← parseRat val) } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeDownstreamLinearState (st : DownstreamLinearParseState) : + Except String Circuit.DownstreamLinearCert := do + let error ← + match st.error with + | some v => pure v + | none => throw "missing error entry" + let gain ← + match st.gain with + | some v => pure v + | none => throw "missing gain entry" + let inputBound ← + match st.inputBound with + | some v => pure v + | none => throw "missing input-bound entry" + return { error := error, gain := 
gain, inputBound := inputBound } + +/-- Parse a downstream linear certificate from a text payload. -/ +def parseDownstreamLinearCert (input : String) : + Except String Circuit.DownstreamLinearCert := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let st0 := initDownstreamLinearState + let st ← tokens.foldlM (fun st t => parseDownstreamLinearLine st t) st0 + finalizeDownstreamLinearState st + +private def setVecEntry {n : Nat} (vec : Fin n → Option Rat) + (i : Nat) (v : Rat) : Except String (Fin n → Option Rat) := do + if hi : i < n then + let iFin : Fin n := ⟨i, hi⟩ + match vec iFin with + | some _ => + throw s!"duplicate entry for index={i}" + | none => + let vec' : Fin n → Option Rat := fun i' => + if i' = iFin then + some v + else + vec i' + return vec' + else + throw s!"index out of range: i={i}" + +private def setMatEntry {m n : Nat} (mat : Fin m → Fin n → Option Rat) + (i j : Nat) (v : Rat) : Except String (Fin m → Fin n → Option Rat) := do + if hi : i < m then + if hj : j < n then + let iFin : Fin m := ⟨i, hi⟩ + let jFin : Fin n := ⟨j, hj⟩ + match mat iFin jFin with + | some _ => + throw s!"duplicate entry for indices={i},{j}" + | none => + let mat' : Fin m → Fin n → Option Rat := fun i' j' => + if i' = iFin then + if j' = jFin then + some v + else + mat i' j' + else + mat i' j' + return mat' + else + throw s!"index out of range: j={j}" + else + throw s!"index out of range: i={i}" + +private structure HeadParseState (seq dModel dHead : Nat) where + scale : Option Rat + active : Finset (Fin seq) + activeSeen : Bool + prev : Fin seq → Option (Fin seq) + embed : Fin seq → Fin dModel → Option Rat + wq : Fin dModel → Fin dHead → Option Rat + wk : Fin dModel → Fin dHead → Option Rat + wv : Fin dModel → Fin dHead → Option Rat + wo : Fin dModel → Fin dHead → Option Rat + directionTarget : Option Nat + directionNegative : Option Nat + direction : Fin dModel → Option Rat + +private def initHeadState (seq dModel dHead : Nat) : 
HeadParseState seq dModel dHead := + { scale := none + active := ∅ + activeSeen := false + prev := fun _ => none + embed := fun _ _ => none + wq := fun _ _ => none + wk := fun _ _ => none + wv := fun _ _ => none + wo := fun _ _ => none + directionTarget := none + directionNegative := none + direction := fun _ => none } + +private def setHeadActive {seq dModel dHead : Nat} + (st : HeadParseState seq dModel dHead) (q : Nat) : + Except String (HeadParseState seq dModel dHead) := do + if hq : q < seq then + let qFin : Fin seq := ⟨q, hq⟩ + return { st with active := st.active ∪ {qFin}, activeSeen := true } + else + throw s!"active index out of range: q={q}" + +private def setHeadPrev {seq dModel dHead : Nat} + (st : HeadParseState seq dModel dHead) (q k : Nat) : + Except String (HeadParseState seq dModel dHead) := do + if hq : q < seq then + if hk : k < seq then + let qFin : Fin seq := ⟨q, hq⟩ + let kFin : Fin seq := ⟨k, hk⟩ + match st.prev qFin with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + let prev' : Fin seq → Option (Fin seq) := fun q' => + if q' = qFin then + some kFin + else + st.prev q' + return { st with prev := prev' } + else + throw s!"prev index out of range: k={k}" + else + throw s!"prev index out of range: q={q}" + +private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dModel dHead) + (tokens : List String) : Except String (HeadParseState seq dModel dHead) := do + match tokens with + | ["scale", val] => + if st.scale.isSome then + throw "duplicate scale entry" + else + return { st with scale := some (← parseRat val) } + | ["active", q] => + setHeadActive st (← parseNat q) + | ["prev", q, k] => + setHeadPrev st (← parseNat q) (← parseNat k) + | ["embed", q, d, val] => + let mat ← setMatEntry st.embed (← parseNat q) (← parseNat d) (← parseRat val) + return { st with embed := mat } + | ["wq", i, j, val] => + let mat ← setMatEntry st.wq (← parseNat i) (← parseNat j) (← parseRat val) + return { st with wq := mat 
} + | ["wk", i, j, val] => + let mat ← setMatEntry st.wk (← parseNat i) (← parseNat j) (← parseRat val) + return { st with wk := mat } + | ["wv", i, j, val] => + let mat ← setMatEntry st.wv (← parseNat i) (← parseNat j) (← parseRat val) + return { st with wv := mat } + | ["wo", i, j, val] => + let mat ← setMatEntry st.wo (← parseNat i) (← parseNat j) (← parseRat val) + return { st with wo := mat } + | ["direction-target", tok] => + if st.directionTarget.isSome then + throw "duplicate direction-target entry" + else + return { st with directionTarget := some (← parseNat tok) } + | ["direction-negative", tok] => + if st.directionNegative.isSome then + throw "duplicate direction-negative entry" + else + return { st with directionNegative := some (← parseNat tok) } + | ["direction", d, val] => + let vec ← setVecEntry st.direction (← parseNat d) (← parseRat val) + return { st with direction := vec } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) + (st : HeadParseState seq dModel dHead) : + Except String (Model.InductionHeadInputs seq dModel dHead) := do + let scale ← + match st.scale with + | some v => pure v + | none => throw "missing scale entry" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then + throw "missing prev entries" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.embed q d).isSome)) then + throw "missing embed entries" + if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => + finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wq i j).isSome)) then + throw "missing wq entries" + if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => + finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wk i j).isSome)) then + throw "missing wk entries" + if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => + finsetAll (Finset.univ : Finset 
(Fin dHead)) (fun j => (st.wv i j).isSome)) then + throw "missing wv entries" + if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => + finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wo i j).isSome)) then + throw "missing wo entries" + if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.direction d).isSome) then + throw "missing direction entries" + let directionSpec ← + match st.directionTarget, st.directionNegative with + | some target, some negative => pure { target := target, negative := negative } + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev q).getD defaultPrev + let embedFun : Fin seq → Fin dModel → Rat := fun q d => + (st.embed q d).getD 0 + let wqFun : Fin dModel → Fin dHead → Rat := fun i j => + (st.wq i j).getD 0 + let wkFun : Fin dModel → Fin dHead → Rat := fun i j => + (st.wk i j).getD 0 + let wvFun : Fin dModel → Fin dHead → Rat := fun i j => + (st.wv i j).getD 0 + let woFun : Fin dModel → Fin dHead → Rat := fun i j => + (st.wo i j).getD 0 + let directionFun : Fin dModel → Rat := fun d => + (st.direction d).getD 0 + let active := + if st.activeSeen then + st.active + else + (Finset.univ : Finset (Fin seq)).erase defaultPrev + pure + { scale := scale + active := active + prev := prevFun + embed := embedFun + wq := wqFun + wk := wkFun + wv := wvFun + wo := woFun + directionSpec := directionSpec + direction := directionFun } + +/-- Parse a raw induction head input payload from text. -/ +def parseInductionHeadInputs (input : String) : + Except String (Sigma (fun seq => + Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead)))) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let mut seq? : Option Nat := none + let mut dModel? : Option Nat := none + let mut dHead? 
: Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | ["d_model", n] => + if dModel?.isSome then + throw "duplicate d_model entry" + else + dModel? := some (← parseNat n) + | ["d_head", n] => + if dHead?.isSome then + throw "duplicate d_head entry" + else + dHead? := some (← parseNat n) + | _ => pure () + let seq ← + match seq? with + | some v => pure v + | none => throw "missing seq entry" + let dModel ← + match dModel? with + | some v => pure v + | none => throw "missing d_model entry" + let dHead ← + match dHead? with + | some v => pure v + | none => throw "missing d_head entry" + match seq, dModel, dHead with + | 0, _, _ => throw "seq must be positive" + | _, 0, _ => throw "d_model must be positive" + | _, _, 0 => throw "d_head must be positive" + | Nat.succ n, Nat.succ m, Nat.succ h => + let seq := Nat.succ n + let dModel := Nat.succ m + let dHead := Nat.succ h + let hpos : 0 < seq := Nat.succ_pos n + let st0 : HeadParseState seq dModel dHead := initHeadState seq dModel dHead + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | ["d_model", _] => pure st + | ["d_head", _] => pure st + | _ => parseHeadLine st t) st0 + let inputs ← finalizeHeadState hpos st + return ⟨seq, ⟨dModel, ⟨dHead, inputs⟩⟩⟩ + end Pure end IO diff --git a/Nfp/Model.lean b/Nfp/Model.lean new file mode 100644 index 0000000..9b9f188 --- /dev/null +++ b/Nfp/Model.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Model.Gpt2 +import Nfp.Model.InductionHead +import Nfp.Model.InductionPrompt + +/-! +Model-specific data containers for the NFP rewrite. 
+-/ diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean new file mode 100644 index 0000000..4607c7b --- /dev/null +++ b/Nfp/Model/Gpt2.lean @@ -0,0 +1,64 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Circuit.Cert.ValueRange + +/-! +Exact GPT-2 head-slice data for induction certification. + +This module holds the precise token embeddings, position embeddings, and head +projection weights needed to build `InductionHeadInputs` for a single head. +-/ + +namespace Nfp + +namespace Model + +open Nfp.Circuit + +/-- Token indices describing a logit-diff direction (target minus negative). -/ +structure DirectionTokens (vocab : Nat) where + /-- Target token index. -/ + target : Fin vocab + /-- Negative token index. -/ + negative : Fin vocab + +/-- Convert `DirectionTokens` to a `DirectionSpec`. -/ +def DirectionTokens.spec {vocab : Nat} (dir : DirectionTokens vocab) : DirectionSpec := + { target := dir.target.val, negative := dir.negative.val } + +/-- Exact GPT-2 head slice needed to build induction-head inputs. -/ +structure Gpt2HeadSlice (seq dModel dHead vocab : Nat) where + /-- Softmax scale factor (e.g. `1/8` for head dim 64). -/ + scale : Rat + /-- Token ids for the prompt. -/ + tokens : Fin seq → Fin vocab + /-- Token embedding matrix. -/ + wte : Fin vocab → Fin dModel → Rat + /-- Positional embedding matrix. -/ + wpe : Fin seq → Fin dModel → Rat + /-- Query projection weights. -/ + wq : Fin dModel → Fin dHead → Rat + /-- Key projection weights. -/ + wk : Fin dModel → Fin dHead → Rat + /-- Value projection weights. -/ + wv : Fin dModel → Fin dHead → Rat + /-- Output projection weights for this head slice. -/ + wo : Fin dModel → Fin dHead → Rat + /-- Direction tokens for logit-diff certification. -/ + direction : DirectionTokens vocab + +/-- Token-plus-position embeddings for a GPT-2 head slice. 
-/ +def Gpt2HeadSlice.embed {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) : + Fin seq → Fin dModel → Rat := + fun q d => slice.wte (slice.tokens q) d + slice.wpe q d + +/-- Direction vector in model space for a GPT-2 head slice. -/ +def Gpt2HeadSlice.directionVec {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) : Fin dModel → Rat := + fun d => slice.wte slice.direction.target d - slice.wte slice.direction.negative d + +end Model + +end Nfp diff --git a/Nfp/Model/InductionHead.lean b/Nfp/Model/InductionHead.lean new file mode 100644 index 0000000..d402d26 --- /dev/null +++ b/Nfp/Model/InductionHead.lean @@ -0,0 +1,45 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Finset.Basic +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Circuit.Cert.ValueRange + +/-! +Exact inputs for induction-head scoring and value-direction computations. + +These structures store exact rational inputs (embeddings and weights) for a +single attention head. They are intended to be consumed by sound builders. +-/ + +namespace Nfp + +namespace Model + +open Nfp.Circuit + +/-- Exact head inputs for induction certification. -/ +structure InductionHeadInputs (seq dModel dHead : Nat) where + /-- Softmax scale factor (e.g. `1/8` for GPT-2-small head dim 64). -/ + scale : Rat + /-- Active queries for which bounds are required. -/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Token embeddings for the sequence. -/ + embed : Fin seq → Fin dModel → Rat + /-- Query projection weights. -/ + wq : Fin dModel → Fin dHead → Rat + /-- Key projection weights. -/ + wk : Fin dModel → Fin dHead → Rat + /-- Value projection weights. -/ + wv : Fin dModel → Fin dHead → Rat + /-- Output projection weights (head slice). -/ + wo : Fin dModel → Fin dHead → Rat + /-- Logit-diff direction metadata. 
-/ + directionSpec : DirectionSpec + /-- Logit-diff direction vector in model space. -/ + direction : Fin dModel → Rat + +end Model + +end Nfp diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean new file mode 100644 index 0000000..8c5ff61 --- /dev/null +++ b/Nfp/Model/InductionPrompt.lean @@ -0,0 +1,32 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Fintype.Basic + +/-! +Helpers for induction-style prompts. + +These are small, deterministic utilities for constructing the `prev` map and +active-query set from a fixed period. They keep the prompt bookkeeping +separate from the model weights. +-/ + +namespace Nfp + +namespace Model + +/-- `prev` map for a periodic induction prompt: `q ↦ q - period` (truncated at 0). -/ +def prevOfPeriod {seq : Nat} (period : Nat) (q : Fin seq) : Fin seq := + ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + +/-- Active queries for a periodic induction prompt (`period ≤ q`). -/ +def activeOfPeriod {seq : Nat} (period : Nat) : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun q => period ≤ q.val) + +/-- Membership characterization for `activeOfPeriod`. -/ +theorem mem_activeOfPeriod {seq : Nat} {period : Nat} {q : Fin seq} : + q ∈ activeOfPeriod (seq := seq) period ↔ period ≤ q.val := by + simp [activeOfPeriod] + +end Model + +end Nfp diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean new file mode 100644 index 0000000..2f3e713 --- /dev/null +++ b/Nfp/Sound.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Sound.Gpt2.HeadInputs +import Nfp.Sound.Induction +import Nfp.Sound.Linear.FinFold + +/-! +Sound certificate builders and verified helpers. 
+-/ diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean new file mode 100644 index 0000000..045a406 --- /dev/null +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -0,0 +1,57 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Model.Gpt2 +import Nfp.Model.InductionHead +import Nfp.Model.InductionPrompt + +/-! +Sound builder for GPT-2 induction head inputs. + +This converts exact GPT-2 head slices into `InductionHeadInputs` using a +periodic prompt description. The construction is purely definitional and is +captured by an explicit theorem, so the trusted core does not hide any logic. +-/ + +namespace Nfp + +namespace Sound + +namespace Gpt2 + +open Nfp.Model + +/-- Build induction-head inputs from a GPT-2 head slice and prompt period. -/ +def buildInductionHeadInputs {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : + Model.InductionHeadInputs seq dModel dHead := + { scale := slice.scale + active := activeOfPeriod (seq := seq) period + prev := prevOfPeriod (seq := seq) period + embed := slice.embed + wq := slice.wq + wk := slice.wk + wv := slice.wv + wo := slice.wo + directionSpec := slice.direction.spec + direction := slice.directionVec } + +/-- Definitional characterization of `buildInductionHeadInputs`. 
-/ +theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : + buildInductionHeadInputs slice period = + { scale := slice.scale + active := activeOfPeriod (seq := seq) period + prev := prevOfPeriod (seq := seq) period + embed := slice.embed + wq := slice.wq + wk := slice.wk + wv := slice.wv + wo := slice.wo + directionSpec := slice.direction.spec + direction := slice.directionVec } := rfl + +end Gpt2 + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean new file mode 100644 index 0000000..a62eb4a --- /dev/null +++ b/Nfp/Sound/Induction.lean @@ -0,0 +1,439 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.Finset.Lattice.Fold +import Mathlib.Data.Rat.Cast.Order +import Mathlib.Data.Vector.Defs +import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.Circuit.Cert.ValueRange +import Nfp.Circuit.Layers.Softmax +import Nfp.Model.InductionHead +import Nfp.Sound.Linear.FinFold + +/-! +Sound builders for induction certificates. + +These builders recompute certificate bounds inside Lean from exact inputs and +return proof-carrying results. The head-input path derives softmax tolerances +from score margins rather than trusting external weight dumps. +-/ + +namespace Nfp + +namespace Sound + +open scoped BigOperators + +open Nfp.Circuit + +variable {seq : Nat} + +/-- Cached query projections for head inputs (opaque to avoid kernel reduction). -/ +private opaque qVecVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := + Vector.ofFn (fun q : Fin seq => + Vector.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wq j d))) + +/-- Cached key projections for head inputs (opaque to avoid kernel reduction). 
-/ +private opaque kVecVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := + Vector.ofFn (fun q : Fin seq => + Vector.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wk j d))) + +/-- Cached value projections for head inputs (opaque to avoid kernel reduction). -/ +private opaque vVecVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := + Vector.ofFn (fun q : Fin seq => + Vector.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wv j d))) + +/-- Cached attention scores for head inputs (opaque to avoid kernel reduction). -/ +private opaque scoresVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qVecVec kVecVec : Vector (Vector Rat dHead) seq) : Vector (Vector Rat seq) seq := + Vector.ofFn (fun q : Fin seq => + Vector.ofFn (fun k : Fin seq => + let qVec : Fin dHead → Rat := fun d => (qVecVec.get q).get d + let kVec : Fin dHead → Rat := fun d => (kVecVec.get k).get d + inputs.scale * (Linear.dotFin dHead (fun d => qVec d) (fun d => kVec d)))) + +/-- Cached direction head for head inputs (opaque to avoid kernel reduction). -/ +private opaque dirHeadVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := + Vector.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) + +/-- Cached value projections for head inputs (opaque to avoid kernel reduction). 
-/ +private opaque valsVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vVecVec : Vector (Vector Rat dHead) seq) (dirHeadVec : Vector Rat dHead) : + Vector Rat seq := + Vector.ofFn (fun k : Fin seq => + let vVec : Fin dHead → Rat := fun d => (vVecVec.get k).get d + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + Linear.dotFin dHead (fun d => vVec d) (fun d => dirHead d)) + +/-- Sound induction-certificate payload built from exact head inputs. -/ +structure InductionHeadCert (seq : Nat) where + /-- Weight tolerance. -/ + eps : Rat + /-- Score margin used to justify the weight tolerance. -/ + margin : Rat + /-- Active queries for which bounds are required. -/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Score matrix entries. -/ + scores : Fin seq → Fin seq → Rat + /-- Value-range certificate for the direction values. -/ + values : ValueRangeCert seq + +/-- Soundness predicate for `InductionHeadCert`. -/ +structure InductionHeadCertSound [NeZero seq] (c : InductionHeadCert seq) : Prop where + /-- Softmax weights respect the derived margin bounds. -/ + softmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Real) (c.eps : Real) (c.margin : Real) + (fun q => q ∈ c.active) c.prev + (fun q k => (c.scores q k : Real)) + (fun q k => Circuit.softmax (fun j => (c.scores q j : Real)) k) + /-- Value-range bounds hold for the certificate values. -/ + value_bounds : + Layers.ValueRangeBounds (Val := Rat) c.values.lo c.values.hi c.values.vals + +/-- Build and certify a softmax-margin certificate from exact scores/weights. -/ +def buildSoftmaxMarginCert? 
[NeZero seq] + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scores : Fin seq → Fin seq → Rat) + (weights : Fin seq → Fin seq → Rat) : + Option {c : SoftmaxMarginCert seq // checkSoftmaxMarginCert c = true} := by + classical + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + let epsAt : Fin seq → Rat := fun q => + let other := otherKeys q + let maxOther := + if h : other.Nonempty then + other.sup' h (fun k => weights q k) + else + (0 : Rat) + let deficit := (1 : Rat) - weights q (prev q) + max maxOther deficit + let marginAt : Fin seq → Rat := fun q => + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scores q (prev q) - scores q k) + else + (0 : Rat) + let eps := + if h : active.Nonempty then + active.sup' h epsAt + else + (0 : Rat) + let margin := + if h : active.Nonempty then + active.inf' h marginAt + else + (0 : Rat) + let cert : SoftmaxMarginCert seq := + { eps := eps + margin := margin + active := active + prev := prev + scores := scores + weights := weights } + if h : checkSoftmaxMarginCert cert = true then + exact some ⟨cert, h⟩ + else + exact none + +/-- Build and certify a value-range certificate from exact values. -/ +def buildValueRangeCert? [NeZero seq] + (vals : Fin seq → Rat) + (direction : Option DirectionSpec) : + Option {c : ValueRangeCert seq // checkValueRangeCert c = true} := by + classical + let _ : Nonempty (Fin seq) := by + refine ⟨⟨0, ?_⟩⟩ + exact Nat.pos_of_ne_zero (NeZero.ne seq) + let univ : Finset (Fin seq) := Finset.univ + let hnonempty : univ.Nonempty := Finset.univ_nonempty + let lo := univ.inf' hnonempty vals + let hi := univ.sup' hnonempty vals + let cert : ValueRangeCert seq := + { lo := lo + hi := hi + vals := vals + direction := direction } + if h : checkValueRangeCert cert = true then + exact some ⟨cert, h⟩ + else + exact none + +/-- Build and certify induction certificates from exact head inputs. 
-/ +def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option {c : InductionHeadCert seq // InductionHeadCertSound c} := by + classical + let qVecVec := qVecVecOfInputs inputs + let kVecVec := kVecVecOfInputs inputs + let vVecVec := vVecVecOfInputs inputs + let scoresVec := scoresVecOfInputs inputs qVecVec kVecVec + let scores : Fin seq → Fin seq → Rat := fun q k => + (scoresVec.get q).get k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let marginAt : Fin seq → Rat := fun q => + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scores q (inputs.prev q) - scores q k) + else + (0 : Rat) + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + (seq - 1 : Rat) / (1 + margin) + have hseq : (1 : Nat) ≤ seq := + Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) + let dirHeadVec := dirHeadVecOfInputs inputs + let valsVec := valsVecOfInputs inputs vVecVec dirHeadVec + let vals : Fin seq → Rat := fun k => valsVec.get k + exact + match buildValueRangeCert? 
vals (some inputs.directionSpec) with + | none => none + | some ⟨valCert, hval⟩ => + let cert : InductionHeadCert seq := + { eps := eps + margin := margin + active := inputs.active + prev := inputs.prev + scores := scores + values := valCert } + have hvalues : Layers.ValueRangeBounds (Val := Rat) valCert.lo valCert.hi valCert.vals := + Circuit.checkValueRangeCert_sound valCert hval + let scoresReal : Fin seq → Fin seq → Real := fun q k => (scores q k : Real) + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + have hscore_margin : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scores q k + margin ≤ scores q (inputs.prev q) := by + intro q hq k hk + by_cases hactive : inputs.active.Nonempty + · have hmargin_le : margin ≤ marginAt q := by + have hle : margin ≤ inputs.active.inf' hactive marginAt := by + simp [margin, hactive] + have hle_all := + (Finset.le_inf'_iff (s := inputs.active) (H := hactive) (f := marginAt) + (a := margin)).1 hle + exact hle_all q hq + have hother : (otherKeys q).Nonempty := ⟨k, by simp [otherKeys, hk]⟩ + have hgap_le : + marginAt q ≤ scores q (inputs.prev q) - scores q k := by + have hle : marginAt q ≤ + (otherKeys q).inf' hother + (fun k => scores q (inputs.prev q) - scores q k) := by + simp [marginAt, hother] + have hle_all := + (Finset.le_inf'_iff (s := otherKeys q) (H := hother) + (f := fun k => scores q (inputs.prev q) - scores q k) + (a := marginAt q)).1 hle + exact hle_all k (by simp [otherKeys, hk]) + have hgap : margin ≤ scores q (inputs.prev q) - scores q k := + le_trans hmargin_le hgap_le + have hgap' := + add_le_add_left hgap (scores q k) + simpa [sub_eq_add_neg, add_assoc, add_left_comm, add_comm] using hgap' + · exact (hactive ⟨q, hq⟩).elim + have hscore_margin_real : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (margin : Real) ≤ scoresReal q 
(inputs.prev q) := by + intro q hq k hk + have hrat := hscore_margin q hq k hk + have hreal : + ((scores q k + margin : Rat) : Real) ≤ scores q (inputs.prev q) := by + exact (Rat.cast_le (K := Real)).2 hrat + simpa [scoresReal, Rat.cast_add] using hreal + have hsoftmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) + (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by + classical + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + exact hscore_margin_real q hq k hk + · intro q _ k + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + · intro q _ + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + · intro q hq + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (eps : Real) := by + by_cases hneg : margin < 0 + · have heps : (eps : Real) = 1 := by + simp [eps, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hnonneg : + ∀ k ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q k := by + intro k _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k hk _; exact hnonneg k hk) + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by + simpa [hsum_one] using hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ margin := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (margin : Real) := by + exact (Rat.cast_nonneg (K := Real)).2 hnonneg + have hbound : + ∀ k ∈ others q, + weights q k ≤ (1 + (margin : Real))⁻¹ := by + intro k hk + have hkne : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 + have hscore := hscore_margin_real q hq k 
hkne + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := inputs.prev q) (k := k) (m := (margin : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (1 + (margin : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ k ∈ others q, (1 + (margin : Real))⁻¹) = + (others q).card * (1 + (margin : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ k ∈ others q, weights q k) ≤ + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have heps : + (eps : Real) = (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by + simp [eps, hneg, Rat.cast_add, div_eq_mul_inv] + simpa [heps] using hsum_le' + have hsum_eq : + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + calc + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (inputs.prev q) + (eps : Real) := by + have hsum_le'' := add_le_add_left hsum_others_le (weights q (inputs.prev q)) + simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' + have hprev : + 1 ≤ weights q (inputs.prev q) + (eps : Real) := by + simpa [hsum_eq] using hsum_le' + exact hprev + · intro q hq k hk + have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (eps : Real) := by + by_cases hneg : margin < 0 + · have heps : (eps : Real) = 1 := by + simp [eps, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin 
seq)) := by + intro j hj + simp + have hnonneg : + ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q j := by + intro j _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) j) + have hsum_le : + (∑ j ∈ others q, weights q j) ≤ + ∑ j ∈ (Finset.univ : Finset (Fin seq)), weights q j := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro j hj _; exact hnonneg j hj) + have hsum_one : (∑ j, weights q j) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_le' : (∑ j ∈ others q, weights q j) ≤ 1 := by + simpa [hsum_one] using hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ margin := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (margin : Real) := by + exact (Rat.cast_nonneg (K := Real)).2 hnonneg + have hbound : + ∀ j ∈ others q, + weights q j ≤ (1 + (margin : Real))⁻¹ := by + intro j hj + have hjne : j ≠ inputs.prev q := (Finset.mem_erase.mp hj).1 + have hscore := hscore_margin_real q hq j hjne + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := inputs.prev q) (k := j) (m := (margin : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ j ∈ others q, weights q j) ≤ + ∑ j ∈ others q, (1 + (margin : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ j ∈ others q, (1 + (margin : Real))⁻¹) = + (others q).card * (1 + (margin : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ j ∈ others q, weights q j) ≤ + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have heps : + (eps : Real) = (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by + simp [eps, hneg, Rat.cast_add, div_eq_mul_inv] + simpa [heps] using hsum_le' + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ 
j ∈ others q, 0 ≤ weights q j := by + intro j _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) j) + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + some ⟨cert, { softmax_bounds := hsoftmax_bounds, value_bounds := hvalues }⟩ + +end Sound + +end Nfp diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean new file mode 100644 index 0000000..d764309 --- /dev/null +++ b/Nfp/Sound/Linear/FinFold.lean @@ -0,0 +1,52 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Rat +import Batteries.Data.Fin.Fold + +/-! +Tail-recursive folds and sums over `Fin`. + +These helpers keep sound computations stack-safe while remaining explicit. +-/ + +namespace Nfp + +namespace Sound + +namespace Linear + +variable {α : Type _} + +/-- Tail-recursive fold over `Fin n`. -/ +def foldlFin (n : Nat) (f : α → Fin n → α) (init : α) : α := + Fin.dfoldl n (fun _ => α) (fun i acc => f acc i) init + +/-- `foldlFin` matches `Fin.foldl`. -/ +theorem foldlFin_eq_foldl (n : Nat) (f : α → Fin n → α) (init : α) : + foldlFin n f init = Fin.foldl n f init := by + simpa [foldlFin] using + (Fin.dfoldl_eq_foldl (n := n) (f := fun i acc => f acc i) (x := init)) + +/-- Tail-recursive sum over `Fin n` (Rat-valued). -/ +def sumFin (n : Nat) (f : Fin n → Rat) : Rat := + foldlFin n (fun acc i => acc + f i) 0 + +/-- `sumFin` as a left fold over the finite range list. -/ +theorem sumFin_eq_list_foldl (n : Nat) (f : Fin n → Rat) : + sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa [sumFin, foldlFin_eq_foldl] using + (Fin.foldl_eq_foldl_finRange (f := fun acc i => acc + f i) (x := (0 : Rat)) (n := n)) + +/-- Dot product over `Fin n` (Rat-valued). -/ +def dotFin (n : Nat) (x y : Fin n → Rat) : Rat := + sumFin n (fun i => x i * y i) + +/-- Unfolding lemma for `dotFin`. 
-/ +theorem dotFin_def (n : Nat) (x y : Fin n → Rat) : + dotFin n x y = sumFin n (fun i => x i * y i) := rfl + +end Linear + +end Sound + +end Nfp diff --git a/README.md b/README.md index 61b2c42..4e6a690 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,9 @@ Module map and invariants are tracked in `AGENTS.md`. ## Induction Certification (prototype) -The current end-to-end prototype checks a **softmax-margin certificate** for a single GPT-2-small -head. The certificate is produced by an **untrusted** helper script and verified by the CLI. +The current prototype checks **head-level induction certificates** and can optionally compose +them with a **downstream error bound**. Certificates are produced by **untrusted** helper scripts +and verified by the CLI. Generate certificates (untrusted): @@ -41,7 +42,14 @@ Generate certificates (untrusted): python scripts/build_gpt2_induction_cert.py \ --output reports/gpt2_induction.cert \ --layer 5 --head 1 --seq 32 --pattern-length 16 \ - --values-out reports/gpt2_induction.values --value-dim 0 + --values-out reports/gpt2_induction.values --value-dim 0 \ + --active-eps-max 1/2 +``` + +To produce value-range certificates aligned with a logit-diff direction, add: + +``` +--direction-target --direction-negative ``` Verify it (trusted checker): @@ -51,35 +59,141 @@ lake exe nfp induction certify --scores reports/gpt2_induction.cert \ --values reports/gpt2_induction.values ``` +You can enforce non-vacuity checks with: + +``` +--min-margin --max-eps --min-active --min-logit-diff +``` + +To recompute `eps`/`margin` and `lo`/`hi` inside Lean (sound builder), run: + +```bash +lake exe nfp induction certify_sound --scores reports/gpt2_induction.cert \ + --values reports/gpt2_induction.values +``` + +To compute scores/values inside Lean from exact head inputs, run: + +```bash +lake exe nfp induction certify_head --inputs reports/gpt2_induction.head +``` + +To add a downstream error bound (end-to-end check), supply a downstream 
certificate +that records a nonnegative error bound computed externally: + +```bash +python scripts/build_downstream_linear_cert.py \ + --output reports/gpt2_downstream.cert \ + --gain 3/2 --input-bound 5/4 + +lake exe nfp induction certify_end_to_end --scores reports/gpt2_induction.cert \ + --values reports/gpt2_induction.values --downstream reports/gpt2_downstream.cert +``` + +To build a head-input file from an exported `.nfpt` binary: + +```bash +python scripts/build_gpt2_head_inputs.py --model models/gpt2_rigorous.nfpt \ + --layer 5 --head 1 --direction-target 17850 --direction-negative 31215 \ + --output reports/gpt2_induction.head +``` + +This extractor is **untrusted** and currently ignores LN/bias terms, so treat it as a +convenience path for exercising the `certify_head` pipeline rather than a full +end-to-end verification of GPT-2 internals. + Softmax-margin certificate format (line-oriented): ``` seq eps margin +active prev score weight ``` +`active ` lines declare the queries on which the bounds are required; if omitted, +the checker defaults to all nonzero queries. + Value-range certificate format (line-oriented): ``` seq +direction-target +direction-negative lo hi val ``` -The checker validates that the provided scores/weights satisfy `SoftmaxMarginBounds` and that the -value entries are bounded by `lo`/`hi`. When both are provided, the CLI reports a tolerance -`eps * (hi - lo)` for the approximate induction spec. +`direction-*` lines are optional metadata for directional (logit-diff) values. + +Downstream linear certificate format (line-oriented): + +``` +error +gain +input-bound +``` + +The checker enforces `error = gain * input-bound` and nonnegativity of all fields. + +Head input format for `certify_head` (line-oriented): + +``` +seq +d_model +d_head +scale +direction-target +direction-negative +direction +active +prev +embed +wq +wk +wv +wo +``` + +All `direction`, `embed`, and projection matrices must be fully specified. 
If no +`active` lines appear, the checker defaults to all nonzero queries. + +The checker derives a softmax tolerance from the score margins and validates the value-range +bounds. The CLI reports a tolerance `eps * (hi - lo)` for the approximate induction spec. + +For tighter, non-vacuous bounds, use `--active-eps-max` when building the certificate to restrict +`active` queries to positions with small `eps` (at the cost of fewer certified positions). +You can enforce a minimum active coverage at check time with `--min-active <n>`. +The default minimum is `max 1 (seq / 8)` when the flag is omitted. + +`certify`/`certify_sound`/`certify_end_to_end` also accept `--min-margin` and `--max-eps` to +reject vacuous score gaps or overly large tolerances (defaults: `0` and `1/2`). + +If the value-range certificate is built from a logit-diff direction (see below), +the checker also reports `logitDiffLB`. When `direction-target`/`direction-negative` +metadata is present, the checker defaults `--min-logit-diff` to `0` to avoid +vacuous directional bounds. You can override with a higher rational literal. + +`certify_sound` ignores any supplied `eps`/`margin`/`lo`/`hi` lines and recomputes +those bounds from the raw entries. + +`certify_head` reads a single input file with exact head inputs (embeddings, +projection weights, direction vector, and scale) and recomputes +scores/values inside Lean. ## Soundness statement (what is proven vs checked) The Lean library defines the core math objects (finite probability, mixers, linearizations, and operator-norm-style bounds) and proves a number of lemmas about them. The CLI sound path produces certificates using exact `Rat` arithmetic and a trusted checker that verifies internal arithmetic relationships between certificate fields. -At present, the checker does **not** include a bridge theorem that connects certificate validity to the Lean-defined Jacobian bounds (for example, a theorem of the form `||layerJacobian - I|| <= C`).
Treat sound certificates as **internally consistent bound reports**, not as a fully formal end-to-end verification of transformer Jacobians. +At present, the checker does **not** include a bridge theorem that connects certificate validity to +Lean-defined Jacobian bounds (for example, a theorem of the form `||layerJacobian - I|| <= C`). +The downstream error certificate is only checked for internal arithmetic consistency. +Treat sound certificates as **internally consistent bound reports**, not as a fully formal +end-to-end verification of transformer Jacobians. Margin-based softmax tightening exists, but only **best-match margin evidence** is accepted today. Direct `--softmaxMargin` is rejected by the checker, and best-match logit bounds are generated in untrusted code and only checked for internal consistency. diff --git a/scripts/build_downstream_linear_cert.py b/scripts/build_downstream_linear_cert.py new file mode 100644 index 0000000..206da9c --- /dev/null +++ b/scripts/build_downstream_linear_cert.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Build a downstream linear certificate from externally computed bounds. + +This script is untrusted: it only formats rational inputs into the certificate +format expected by `nfp induction certify_end_to_end`. 
+ +Usage: + python scripts/build_downstream_linear_cert.py \ + --output reports/gpt2_downstream.cert \ + --gain 3/2 \ + --input-bound 5/4 + +Optional: + --error 15/8 # override gain * input-bound +""" + +import argparse +from fractions import Fraction +from pathlib import Path + + +def parse_rat(raw: str) -> Fraction: + if "/" in raw: + num, den = raw.split("/", 1) + return Fraction(int(num.strip()), int(den.strip())) + return Fraction(int(raw.strip()), 1) + + +def rat_to_str(q: Fraction) -> str: + if q.denominator == 1: + return str(q.numerator) + return f"{q.numerator}/{q.denominator}" + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--output", required=True, help="Path to write certificate") + parser.add_argument("--gain", required=True, help="Nonnegative gain bound (Rat)") + parser.add_argument("--input-bound", required=True, + help="Nonnegative input bound (Rat)") + parser.add_argument("--error", + help="Optional error override (Rat). 
Defaults to gain * input-bound.") + args = parser.parse_args() + + gain = parse_rat(args.gain) + input_bound = parse_rat(args.input_bound) + if gain < 0 or input_bound < 0: + raise SystemExit("gain and input-bound must be nonnegative") + error = parse_rat(args.error) if args.error else gain * input_bound + if error < 0: + raise SystemExit("error must be nonnegative") + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="ascii") as f: + f.write(f"error {rat_to_str(error)}\n") + f.write(f"gain {rat_to_str(gain)}\n") + f.write(f"input-bound {rat_to_str(input_bound)}\n") + print(f"Wrote downstream certificate to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/build_gpt2_head_inputs.py b/scripts/build_gpt2_head_inputs.py new file mode 100644 index 0000000..0065566 --- /dev/null +++ b/scripts/build_gpt2_head_inputs.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Build an induction head input file from an NFP_BINARY_V1 model. + +This is an untrusted helper that extracts a single head slice plus the +prompt embeddings from an `.nfpt` file and writes the text format consumed by +`nfp induction certify_head`. 
+"""
+
+from __future__ import annotations
+
+import argparse
+import math
+import struct
+from fractions import Fraction
+from pathlib import Path
+from typing import Dict, Tuple
+
+import numpy as np
+
+
+def rat_from_float_exact(x: float) -> Fraction:
+    if not math.isfinite(x):
+        raise SystemExit(f"non-finite float encountered: {x}")
+    num, den = x.as_integer_ratio()
+    return Fraction(num, den)
+
+
+def rat_to_str(q: Fraction) -> str:
+    if q.denominator == 1:
+        return str(q.numerator)
+    return f"{q.numerator}/{q.denominator}"
+
+
+def parse_header(f) -> Dict[str, str]:
+    header: Dict[str, str] = {}
+    magic = f.readline().decode("ascii").strip()
+    if magic != "NFP_BINARY_V1":
+        raise SystemExit(f"Unsupported magic header: {magic}")
+    while True:
+        line = f.readline()
+        if line == b"":
+            raise SystemExit("Unexpected EOF while reading header.")
+        text = line.decode("ascii").strip()
+        if text == "BINARY_START":
+            break
+        if "=" in text:
+            key, value = text.split("=", 1)
+            header[key.strip()] = value.strip()
+    return header
+
+
+def read_i32(f, count: int) -> np.ndarray:
+    raw = f.read(count * 4)
+    if len(raw) != count * 4:
+        raise SystemExit("Unexpected EOF while reading int32 payload.")
+    return np.frombuffer(raw, dtype="<i4")
+
+
+def read_f64(f, count: int) -> np.ndarray:
+    raw = f.read(count * 8)
+    if len(raw) != count * 8:
+        raise SystemExit("Unexpected EOF while reading float64 payload.")
+    return np.frombuffer(raw, dtype="<f8")
+
+
+def skip_f64(f, count: int) -> None:
+    offset = count * 8
+    f.seek(offset, 1)
+
+
+def build_prev(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    prev = np.zeros_like(tokens)
+    active = np.zeros_like(tokens, dtype=bool)
+    last_seen: Dict[int, int] = {}
+    for idx, tok in enumerate(tokens.tolist()):
+        if idx == 0:
+            prev[idx] = 0
+            active[idx] = False
+        else:
+            if tok in last_seen:
+                prev[idx] = last_seen[tok]
+                active[idx] = True
+            else:
+                prev[idx] = 0
+                active[idx] = False
+        last_seen[tok] = idx
+    return prev, active
+
+
+def read_head_weights(
+    f,
+    num_layers: int,
+    num_heads: int,
+    model_dim: 
int,
+    head_dim: int,
+    hidden_dim: int,
+    layer: int,
+    head: int,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    target = (layer, head)
+    wq = wk = wv = wo = None
+    for layer_idx in range(num_layers):
+        for head_idx in range(num_heads):
+            wq_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim)
+            _ = read_f64(f, head_dim) # b_Q
+            wk_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim)
+            _ = read_f64(f, head_dim) # b_K
+            wv_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim)
+            _ = read_f64(f, head_dim) # b_V
+            wo_block = read_f64(f, head_dim * model_dim).reshape(head_dim, model_dim)
+            if (layer_idx, head_idx) == target:
+                wq = wq_block
+                wk = wk_block
+                wv = wv_block
+                wo = wo_block
+        # Skip per-layer non-head data.
+        skip_f64(f, model_dim) # attn_bias
+        skip_f64(f, model_dim * hidden_dim) # w_in
+        skip_f64(f, hidden_dim) # b_in
+        skip_f64(f, hidden_dim * model_dim) # w_out
+        skip_f64(f, model_dim) # b_out
+        skip_f64(f, model_dim) # ln1_gamma
+        skip_f64(f, model_dim) # ln1_beta
+        skip_f64(f, model_dim) # ln2_gamma
+        skip_f64(f, model_dim) # ln2_beta
+    if wq is None or wk is None or wv is None or wo is None:
+        raise SystemExit("Failed to locate head weights.")
+    return wq, wk, wv, wo
+
+
+def read_unembed_columns(
+    f,
+    start: int,
+    model_dim: int,
+    vocab_size: int,
+    target: int,
+    negative: int,
+) -> Tuple[np.ndarray, np.ndarray]:
+    row_bytes = vocab_size * 8
+    col_t = np.zeros(model_dim, dtype=np.float64)
+    col_n = np.zeros(model_dim, dtype=np.float64)
+    for row in range(model_dim):
+        base = start + row * row_bytes
+        f.seek(base + target * 8)
+        col_t[row] = struct.unpack("<d", f.read(8))[0]
+        f.seek(base + negative * 8)
+        col_n[row] = struct.unpack("<d", f.read(8))[0]
+    return col_t, col_n
+
+
+def write_head_inputs(
+    path: Path,
+    scale: Fraction,
+    tokens: np.ndarray,
+    embeddings: np.ndarray,
+    prev: np.ndarray,
+    active: np.ndarray,
+    wq: np.ndarray,
+    wk: np.ndarray,
+    wv: np.ndarray,
+    wo: np.ndarray,
+    direction_target: int,
+    direction_negative: int,
+    direction: np.ndarray,
+) -> None:
+    seq, model_dim = embeddings.shape
+    _, head_dim = wq.shape
+    with path.open("w", encoding="ascii") as f:
+        f.write(f"seq {seq}\n")
+        f.write(f"d_model {model_dim}\n")
+        f.write(f"d_head {head_dim}\n")
+        f.write(f"scale {rat_to_str(scale)}\n")
+        for q, flag in enumerate(active.tolist()):
+            if flag:
+                
f.write(f"active {q}\n") + for q, k in enumerate(prev.tolist()): + f.write(f"prev {q} {k}\n") + for q in range(seq): + for d in range(model_dim): + f.write(f"embed {q} {d} {rat_to_str(rat_from_float_exact(float(embeddings[q, d])))}\n") + for i in range(model_dim): + for j in range(head_dim): + f.write(f"wq {i} {j} {rat_to_str(rat_from_float_exact(float(wq[i, j])))}\n") + for i in range(model_dim): + for j in range(head_dim): + f.write(f"wk {i} {j} {rat_to_str(rat_from_float_exact(float(wk[i, j])))}\n") + for i in range(model_dim): + for j in range(head_dim): + f.write(f"wv {i} {j} {rat_to_str(rat_from_float_exact(float(wv[i, j])))}\n") + for i in range(model_dim): + for j in range(head_dim): + f.write(f"wo {i} {j} {rat_to_str(rat_from_float_exact(float(wo[i, j])))}\n") + f.write(f"direction-target {direction_target}\n") + f.write(f"direction-negative {direction_negative}\n") + for d in range(model_dim): + f.write(f"direction {d} {rat_to_str(rat_from_float_exact(float(direction[d])))}\n") + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--model", type=Path, required=True, help="Path to NFP_BINARY_V1 model") + ap.add_argument("--layer", type=int, required=True, help="Layer index") + ap.add_argument("--head", type=int, required=True, help="Head index") + ap.add_argument("--output", type=Path, required=True, help="Path for the head input file") + ap.add_argument("--direction-target", type=int, required=True, help="Target token id") + ap.add_argument("--direction-negative", type=int, required=True, help="Negative token id") + args = ap.parse_args() + + if not args.model.exists(): + raise SystemExit(f"Missing model file: {args.model}") + + with args.model.open("rb") as f: + header = parse_header(f) + num_layers = int(header["num_layers"]) + num_heads = int(header["num_heads"]) + model_dim = int(header["model_dim"]) + head_dim = int(header["head_dim"]) + vocab_size = int(header["vocab_size"]) + seq_len = 
int(header["seq_len"]) + hidden_dim = int(header["hidden_dim"]) + + if args.layer < 0 or args.layer >= num_layers: + raise SystemExit("layer index out of range") + if args.head < 0 or args.head >= num_heads: + raise SystemExit("head index out of range") + if not (0 <= args.direction_target < vocab_size): + raise SystemExit("direction-target out of vocab range") + if not (0 <= args.direction_negative < vocab_size): + raise SystemExit("direction-negative out of vocab range") + + tokens = read_i32(f, seq_len) + embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) + + wq, wk, wv, wo_raw = read_head_weights( + f, + num_layers, + num_heads, + model_dim, + head_dim, + hidden_dim, + args.layer, + args.head, + ) + + # Skip final layer norm parameters. + skip_f64(f, model_dim) # ln_f_gamma + skip_f64(f, model_dim) # ln_f_beta + + unembed_start = f.tell() + col_target, col_negative = read_unembed_columns( + f, + unembed_start, + model_dim, + vocab_size, + args.direction_target, + args.direction_negative, + ) + + prev, active = build_prev(tokens) + direction = col_target - col_negative + scale_denom = int(math.isqrt(head_dim)) + if scale_denom * scale_denom != head_dim: + scale = rat_from_float_exact(1.0 / math.sqrt(head_dim)) + else: + scale = Fraction(1, scale_denom) + + # Stored W_O is (head_dim, model_dim); transpose to model_dim × head_dim. 
+ wo = wo_raw.T + + args.output.parent.mkdir(parents=True, exist_ok=True) + write_head_inputs( + args.output, + scale, + tokens, + embeddings, + prev, + active, + wq, + wk, + wv, + wo, + args.direction_target, + args.direction_negative, + direction, + ) + + print(f"Wrote head inputs to {args.output}") + print(f"seq={seq_len} d_model={model_dim} d_head={head_dim}") + + +if __name__ == "__main__": + main() diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index 83084e0..05b38d5 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -5,12 +5,17 @@ Build a softmax-margin certificate for a GPT-2-small induction head. This script is untrusted and uses floating-point arithmetic to produce a -rational certificate compatible with `nfp induction certify`. +rational certificate compatible with `nfp induction certify`. Active +induction positions are recorded as `active ` lines in the output. Usage: python scripts/build_gpt2_induction_cert.py --output reports/gpt2_induction.cert \ --layer 5 --head 1 --seq 32 --pattern-length 16 \ - --values-out reports/gpt2_induction.values --value-dim 0 + --values-out reports/gpt2_induction.values --value-dim 0 \ + --active-eps-max 0.2 + +Optionally, provide a logit-diff direction: + --direction-target --direction-negative """ import argparse @@ -50,16 +55,23 @@ def build_tokens(seq: int, pattern_len: int, random_pattern: bool, seed: int) -> return np.tile(pattern, repeats)[:seq] -def build_prev(tokens: np.ndarray) -> np.ndarray: +def build_prev(tokens: np.ndarray) -> tuple[np.ndarray, np.ndarray]: prev = np.zeros_like(tokens) + active = np.zeros_like(tokens, dtype=bool) last_seen = {} for idx, tok in enumerate(tokens): if idx == 0: prev[idx] = 0 + active[idx] = False else: - prev[idx] = last_seen.get(tok, 0) + if tok in last_seen: + prev[idx] = last_seen[tok] + active[idx] = True + else: + prev[idx] = 0 + active[idx] = False last_seen[tok] = idx - return prev + 
return prev, active def compute_scores_weights(model, input_ids, layer: int, head: int, device: str): @@ -89,13 +101,17 @@ def compute_scores_weights(model, input_ids, layer: int, head: int, device: str) vh.squeeze(0).cpu().numpy()) -def write_scores(path: Path, seq: int, prev: np.ndarray, scores, weights, eps=None, margin=None): +def write_scores(path: Path, seq: int, prev: np.ndarray, scores, weights, eps=None, margin=None, + active=None): with path.open("w", encoding="ascii") as f: f.write(f"seq {seq}\n") if eps is not None: f.write(f"eps {rat_to_str(eps)}\n") if margin is not None: f.write(f"margin {rat_to_str(margin)}\n") + if active is not None: + for q in active: + f.write(f"active {q}\n") for q, k in enumerate(prev.tolist()): f.write(f"prev {q} {k}\n") for q in range(seq): @@ -106,12 +122,16 @@ def write_scores(path: Path, seq: int, prev: np.ndarray, scores, weights, eps=No f.write(f"weight {q} {k} {rat_to_str(weights[q][k])}\n") -def write_value_range(path: Path, seq: int, values, decimals: int) -> None: +def write_value_range(path: Path, seq: int, values, decimals: int, + direction_target=None, direction_negative=None) -> None: vals_rat = [rat_from_float(float(values[k]), decimals) for k in range(seq)] lo = min(vals_rat) hi = max(vals_rat) with path.open("w", encoding="ascii") as f: f.write(f"seq {seq}\n") + if direction_target is not None and direction_negative is not None: + f.write(f"direction-target {direction_target}\n") + f.write(f"direction-negative {direction_negative}\n") f.write(f"lo {rat_to_str(lo)}\n") f.write(f"hi {rat_to_str(hi)}\n") for k, val in enumerate(vals_rat): @@ -134,13 +154,20 @@ def main() -> None: parser.add_argument("--values-out", help="Optional path for a value-range certificate") parser.add_argument("--value-dim", type=int, default=0, help="Value dimension index for the value-range certificate") + parser.add_argument("--active-eps-max", default="1/2", + help="Maximum eps to include an active position (default: 1/2).") + 
parser.add_argument("--direction-target", type=int, + help="Target token id for logit-diff direction (optional)") + parser.add_argument("--direction-negative", type=int, + help="Negative token id for logit-diff direction (optional)") args = parser.parse_args() if args.seq <= 0: raise SystemExit("seq must be positive") tokens = build_tokens(args.seq, args.pattern_length, args.random_pattern, args.seed) - prev = build_prev(tokens) + prev, active_mask = build_prev(tokens) + candidate_positions = [int(i) for i, flag in enumerate(active_mask) if flag] model = GPT2Model.from_pretrained(args.model) model.to(args.device) @@ -162,37 +189,70 @@ def main() -> None: eps = Fraction(0) margin = None - for q in range(1, args.seq): + eps_by_q: dict[int, Fraction] = {} + margin_by_q: dict[int, Fraction] = {} + for q in candidate_positions: prev_q = prev[q] prev_w = weights_rat[q][prev_q] max_other = max(weights_rat[q][k] for k in range(args.seq) if k != prev_q) deficit = Fraction(1) - prev_w - eps = max(eps, max(max_other, deficit)) + eps_by_q[q] = max(max_other, deficit) diffs = [scores_rat[q][prev_q] - scores_rat[q][k] for k in range(args.seq) if k != prev_q] if diffs: - min_diff = min(diffs) - margin = min_diff if margin is None else min(margin, min_diff) + margin_by_q[q] = min(diffs) + + active_positions = candidate_positions + eps_threshold = Fraction(args.active_eps_max) + active_positions = [q for q in candidate_positions if eps_by_q[q] <= eps_threshold] + if not active_positions and candidate_positions: + print("Warning: no active positions satisfy active-eps-max; certificate may be vacuous.") - if margin is None: + if active_positions: + eps = max(eps_by_q[q] for q in active_positions) + margin = min((margin_by_q[q] for q in active_positions), default=Fraction(0)) + else: margin = Fraction(0) + if candidate_positions: + print(f"Active positions: {len(active_positions)}/{len(candidate_positions)}") + output_path = Path(args.output) output_path.parent.mkdir(parents=True, 
exist_ok=True) - write_scores(output_path, args.seq, prev, scores_rat, weights_rat, eps=eps, margin=margin) + write_scores(output_path, args.seq, prev, scores_rat, weights_rat, + eps=eps, margin=margin, active=active_positions) if args.scores_out: scores_path = Path(args.scores_out) scores_path.parent.mkdir(parents=True, exist_ok=True) - write_scores(scores_path, args.seq, prev, scores_rat, weights_rat) + write_scores(scores_path, args.seq, prev, scores_rat, weights_rat, + active=active_positions) if args.values_out: - if args.value_dim < 0 or args.value_dim >= values.shape[1]: - raise SystemExit(f"value-dim must be in [0, {values.shape[1] - 1}]") values_path = Path(args.values_out) values_path.parent.mkdir(parents=True, exist_ok=True) - write_value_range(values_path, args.seq, values[:, args.value_dim], args.decimals) + if (args.direction_target is None) != (args.direction_negative is None): + raise SystemExit("direction-target and direction-negative must be provided together") + if args.direction_target is not None: + wte = model.wte.weight.detach().cpu().numpy() + if args.direction_target < 0 or args.direction_target >= wte.shape[0]: + raise SystemExit("direction-target out of vocab range") + if args.direction_negative < 0 or args.direction_negative >= wte.shape[0]: + raise SystemExit("direction-negative out of vocab range") + direction = wte[args.direction_target] - wte[args.direction_negative] + head_dim = model.config.n_embd // model.config.n_head + start, end = args.head * head_dim, (args.head + 1) * head_dim + w_o = model.h[args.layer].attn.c_proj.weight.detach().cpu().numpy()[:, start:end] + dir_head = w_o.T @ direction + dir_vals = values @ dir_head + write_value_range(values_path, args.seq, dir_vals, args.decimals, + direction_target=args.direction_target, + direction_negative=args.direction_negative) + else: + if args.value_dim < 0 or args.value_dim >= values.shape[1]: + raise SystemExit(f"value-dim must be in [0, {values.shape[1] - 1}]") + 
write_value_range(values_path, args.seq, values[:, args.value_dim], args.decimals) print(f"Wrote certificate to {output_path}") if args.scores_out: From b7dad210b7372050c080a6d2d926ff4e1573deca Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 14:33:53 +0100 Subject: [PATCH 095/244] Add verified matrix norm bounds and rewrite README --- AGENTS.md | 2 + Nfp/Sound.lean | 1 + Nfp/Sound/Bounds/MatrixNorm.lean | 125 ++++++ README.md | 631 ++++--------------------------- 4 files changed, 199 insertions(+), 560 deletions(-) create mode 100644 Nfp/Sound/Bounds/MatrixNorm.lean diff --git a/AGENTS.md b/AGENTS.md index 4500f6a..55f0a23 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -339,6 +339,8 @@ but you **must** update this list in the same commit. ### 5.7 Sound certification - `Nfp/Sound/Induction.lean` - Sound builders for induction certificates from exact inputs. +- `Nfp/Sound/Bounds/MatrixNorm.lean` + - Row-sum matrix norms and downstream linear certificate builders. - `Nfp/Sound/Linear/FinFold.lean` - Tail-recursive folds and sums for sound linear computations. - `Nfp/Sound/Gpt2/HeadInputs.lean` diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index 2f3e713..99456bb 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -2,6 +2,7 @@ import Nfp.Sound.Gpt2.HeadInputs import Nfp.Sound.Induction +import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Linear.FinFold /-! diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean new file mode 100644 index 0000000..8cf54b1 --- /dev/null +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -0,0 +1,125 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Ring.Abs +import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.Matrix.Mul +import Nfp.Circuit.Cert.DownstreamLinear + +/-! +Row-sum matrix norms for downstream linear certificates. 
+ +These bounds are used to compute verified downstream error certificates +from explicit Rat matrices. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +/-- Row-sum of absolute values for a matrix row. -/ +def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : Rat := + ∑ j, |W i j| + +/-- Maximum row-sum norm (defaults to `0` on empty matrices). -/ +def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := + if h : (Finset.univ : Finset (Fin m)).Nonempty then + (Finset.univ).sup' h (fun i => rowSum W i) + else + 0 + +/-- Row-sums are nonnegative. -/ +theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : + 0 ≤ rowSum W i := by + refine Finset.sum_nonneg ?_ + intro j _ + exact abs_nonneg (W i j) + +/-- Each row-sum is bounded by the row-sum norm. -/ +theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : + rowSum W i ≤ rowSumNorm W := by + classical + have h : (Finset.univ : Finset (Fin m)).Nonempty := ⟨i, by simp⟩ + have hle : + rowSum W i ≤ (Finset.univ).sup' h (fun i => rowSum W i) := by + simpa using + (Finset.le_sup' + (s := (Finset.univ : Finset (Fin m))) + (f := fun i => rowSum W i) + (by simp : i ∈ (Finset.univ : Finset (Fin m)))) + simpa [rowSumNorm, h] using hle + +/-- The row-sum norm is nonnegative. -/ +theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : + 0 ≤ rowSumNorm W := by + classical + by_cases h : (Finset.univ : Finset (Fin m)).Nonempty + · rcases h with ⟨i, hi⟩ + have hrow : 0 ≤ rowSum W i := rowSum_nonneg W i + have hle : rowSum W i ≤ rowSumNorm W := rowSum_le_rowSumNorm W i + exact le_trans hrow hle + · simp [rowSumNorm, h] + +/-- Row-sum norm bounds a matrix-vector product under a uniform input bound. 
-/ +theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (x : Fin n → Rat) (inputBound : Rat) + (hx : ∀ j, |x j| ≤ inputBound) (hinput : 0 ≤ inputBound) : + ∀ i, |Matrix.mulVec W x i| ≤ rowSumNorm W * inputBound := by + intro i + have hrow : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by + have h1 : |∑ j, W i j * x j| ≤ ∑ j, |W i j * x j| := by + simpa using + (Finset.abs_sum_le_sum_abs + (f := fun j => W i j * x j) + (s := (Finset.univ : Finset (Fin n)))) + have h2 : ∑ j, |W i j * x j| ≤ ∑ j, |W i j| * inputBound := by + refine Finset.sum_le_sum ?_ + intro j _ + have hxj := hx j + have hnonneg : 0 ≤ |W i j| := abs_nonneg (W i j) + calc + |W i j * x j| = |W i j| * |x j| := by + simp [abs_mul] + _ ≤ |W i j| * inputBound := by + exact mul_le_mul_of_nonneg_left hxj hnonneg + have h3 : ∑ j, |W i j| * inputBound = rowSum W i * inputBound := by + have hsum : + (∑ j, |W i j|) * inputBound = ∑ j, |W i j| * inputBound := by + simpa using + (Finset.sum_mul + (s := (Finset.univ : Finset (Fin n))) + (f := fun j => |W i j|) + (a := inputBound)) + simpa [rowSum] using hsum.symm + have hmul : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by + simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) + exact hmul + have hle : rowSum W i ≤ rowSumNorm W := rowSum_le_rowSumNorm W i + have hmul : rowSum W i * inputBound ≤ rowSumNorm W * inputBound := + mul_le_mul_of_nonneg_right hle hinput + exact hrow.trans hmul + +/-- Build a downstream linear certificate from a matrix and input bound. 
-/ +def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (inputBound : Rat) (hinput : 0 ≤ inputBound) : + {c : Circuit.DownstreamLinearCert // Circuit.DownstreamLinearBounds c} := by + let gain := rowSumNorm W + let error := gain * inputBound + refine ⟨{ error := error, gain := gain, inputBound := inputBound }, ?_⟩ + refine + { error_nonneg := ?_ + gain_nonneg := ?_ + input_nonneg := hinput + error_eq := rfl } + · exact mul_nonneg (rowSumNorm_nonneg W) hinput + · exact rowSumNorm_nonneg W + +end Bounds + +end Sound + +end Nfp diff --git a/README.md b/README.md index 4e6a690..f3c78ce 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,51 @@ # NFP -NFP is a Lean 4 project for **mathematically rigorous** reasoning about transformer-style computations, with a focus on mechanistic interpretability (e.g. induction heads) and provable norm/error bounds. +NFP is a Lean 4 project for **mathematically rigorous** reasoning about transformer-style +computations, with a focus on mechanistic interpretability (e.g. induction heads) and provable +norm/error bounds. NFP stands for **Neural Formal Pathways**. -This repo contains: +## Status -- A **Lean library** (under `Nfp/`) for finite probability and a lightweight “transformer semantics” layer. -- A **CLI executable** (`lake exe nfp …`) that loads transformer weights stored in a compact binary format (`.nfpt`) and produces rigorous bounds and diagnostics. +This repository is in a **tabula rasa rewrite**. The new core is intentionally minimal and the API +surface is still settling. Expect breaking changes. -> Goal: *no “hand-wavy” numerics in the bound path.* Heuristic estimates (e.g. power iteration) may exist for diagnostics, but the bounds reported as “rigorous” are computed via conservative inequalities. +## Build -## Status +```bash +lake build -q --wfail +lake build nfp -q --wfail +``` -This is research tooling. 
Interfaces may change; please treat results as experimental unless they are backed by a certificate/check you trust. +## CLI -## Tabula Rasa Rewrite (current state) +```bash +lake exe nfp --help +lake exe nfp induction --help +``` + +Current subcommands are limited to **induction certificate checking**. The CLI does **not** run a +full model forward pass and does **not** ingest `.nfpt` weights directly; weight ingestion is done +by untrusted helper scripts (see below). -The `tabula-rasa` branch is a fresh, minimal Lean 4 core focused on circuit certification. +## Module map -Current core modules (new): -- `Nfp/Core`, `Nfp/Prob`, `Nfp/Mixer`, `Nfp/System` define basic mass/probability, mixers, and DAG-backed local systems. -- `Nfp/Circuit` defines DAG-based circuits with typed interfaces, well-formedness, and equivalence checkers. -- `Nfp/Circuit/Compose` adds sequential and residual wiring combinators for typed circuits. -- `Nfp/Circuit/Layers/Attention` contains Q/K/V projection wiring plus an attention score/mixing core. -- `Nfp/Circuit/Layers/Induction` provides induction-head specs and the core attention one-hot lemma. -- `Nfp/Circuit/Layers/TransformerBlock` wires LN/attention/MLP into a GPT-style block skeleton. -- `Nfp/Cli` and `Main.lean` remain thin placeholders (no full transformer pipeline yet). +The authoritative module map and invariants are tracked in `AGENTS.md`. -Module map and invariants are tracked in `AGENTS.md`. +High-level layout: +- `Nfp/Core`, `Nfp/Prob`, `Nfp/Mixer`, `Nfp/System`: core math infrastructure. +- `Nfp/Circuit`: circuits, typed interfaces, and layer wiring (attention, induction). +- `Nfp/Sound`: sound builders and verified helpers. +- `Nfp/IO`, `Nfp/Cli`: parsing and CLI entrypoints. ## Induction Certification (prototype) -The current prototype checks **head-level induction certificates** and can optionally compose -them with a **downstream error bound**. 
Certificates are produced by **untrusted** helper scripts -and verified by the CLI. +The current prototype checks **head-level induction certificates** and can optionally compose them +with a **downstream error bound**. Certificates are produced by **untrusted** helper scripts and +verified by the CLI. -Generate certificates (untrusted): +### Build a head certificate (untrusted) ```bash python scripts/build_gpt2_induction_cert.py \ @@ -46,63 +55,64 @@ python scripts/build_gpt2_induction_cert.py \ --active-eps-max 1/2 ``` -To produce value-range certificates aligned with a logit-diff direction, add: +If you want values aligned to a logit-diff direction, add: ``` --direction-target --direction-negative ``` -Verify it (trusted checker): +### Verify a head certificate (trusted checker) ```bash -lake exe nfp induction certify --scores reports/gpt2_induction.cert \ +lake exe nfp induction certify \ + --scores reports/gpt2_induction.cert \ --values reports/gpt2_induction.values ``` -You can enforce non-vacuity checks with: +Non-vacuity gates (optional): ``` --min-margin --max-eps --min-active --min-logit-diff ``` -To recompute `eps`/`margin` and `lo`/`hi` inside Lean (sound builder), run: +### Recompute bounds inside Lean (sound builder) ```bash -lake exe nfp induction certify_sound --scores reports/gpt2_induction.cert \ +lake exe nfp induction certify_sound \ + --scores reports/gpt2_induction.cert \ --values reports/gpt2_induction.values ``` -To compute scores/values inside Lean from exact head inputs, run: +This ignores any `eps`/`margin`/`lo`/`hi` lines and recomputes them from the raw entries. + +### Compute exact head inputs inside Lean (experimental) ```bash lake exe nfp induction certify_head --inputs reports/gpt2_induction.head ``` -To add a downstream error bound (end-to-end check), supply a downstream certificate -that records a nonnegative error bound computed externally: +This path recomputes scores/values in Lean from exact head inputs. 
It is **experimental** and can +be slow for nontrivial sequence lengths. + +### End-to-end check with downstream bound (prototype) ```bash python scripts/build_downstream_linear_cert.py \ --output reports/gpt2_downstream.cert \ --gain 3/2 --input-bound 5/4 -lake exe nfp induction certify_end_to_end --scores reports/gpt2_induction.cert \ - --values reports/gpt2_induction.values --downstream reports/gpt2_downstream.cert +lake exe nfp induction certify_end_to_end \ + --scores reports/gpt2_induction.cert \ + --values reports/gpt2_induction.values \ + --downstream reports/gpt2_downstream.cert ``` -To build a head-input file from an exported `.nfpt` binary: - -```bash -python scripts/build_gpt2_head_inputs.py --model models/gpt2_rigorous.nfpt \ - --layer 5 --head 1 --direction-target 17850 --direction-negative 31215 \ - --output reports/gpt2_induction.head -``` +The downstream certificate is **checked for internal arithmetic consistency** but is still +externally computed. Work is ongoing to compute this bound inside Lean from model weights. -This extractor is **untrusted** and currently ignores LN/bias terms, so treat it as a -convenience path for exercising the `certify_head` pipeline rather than a full -end-to-end verification of GPT-2 internals. +## File formats -Softmax-margin certificate format (line-oriented): +### Softmax-margin certificate ``` seq @@ -114,10 +124,10 @@ score weight ``` -`active ` lines declare the queries on which the bounds are required; if omitted, -the checker defaults to all nonzero queries. +`active ` lines declare the queries on which bounds are required; if omitted, the checker +defaults to all nonzero queries. -Value-range certificate format (line-oriented): +### Value-range certificate ``` seq @@ -130,7 +140,7 @@ val `direction-*` lines are optional metadata for directional (logit-diff) values. 
-Downstream linear certificate format (line-oriented): +### Downstream linear certificate ``` error @@ -140,7 +150,7 @@ input-bound The checker enforces `error = gain * input-bound` and nonnegativity of all fields. -Head input format for `certify_head` (line-oriented): +### Head input format (for `certify_head`) ``` seq @@ -159,522 +169,23 @@ wv wo ``` -All `direction`, `embed`, and projection matrices must be fully specified. If no -`active` lines appear, the checker defaults to all nonzero queries. - -The checker derives a softmax tolerance from the score margins and validates the value-range -bounds. The CLI reports a tolerance `eps * (hi - lo)` for the approximate induction spec. - -For tighter, non-vacuous bounds, use `--active-eps-max` when building the certificate to restrict -`active` queries to positions with small `eps` (at the cost of fewer certified positions). -You can enforce a minimum active coverage at check time with `--min-active `. -The default minimum is `max 1 (seq / 8)` when the flag is omitted. - -`certify`/`certify_sound`/`certify_end_to_end` also accept `--min-margin` and `--max-eps` to -reject vacuous score gaps or overly large tolerances (defaults: `0` and `1/2`). - -If the value-range certificate is built from a logit-diff direction (see below), -the checker also reports `logitDiffLB`. When `direction-target`/`direction-negative` -metadata is present, the checker defaults `--min-logit-diff` to `0` to avoid -vacuous directional bounds. You can override with a higher rational literal. - -`certify-sound` ignores any supplied `eps`/`margin`/`lo`/`hi` lines and recomputes -those bounds from the raw entries. - -`certify_head` reads a single input file with exact head inputs (embeddings, -projection weights, direction vector, and scale) and recomputes -scores/values inside Lean. 
- -## Soundness statement (what is proven vs checked) - -The Lean library defines the core math objects (finite probability, mixers, linearizations, and operator-norm-style bounds) and proves a number of lemmas about them. The CLI sound path produces certificates using exact `Rat` arithmetic and a trusted checker that verifies internal arithmetic relationships between certificate fields. - -At present, the checker does **not** include a bridge theorem that connects certificate validity to -Lean-defined Jacobian bounds (for example, a theorem of the form `||layerJacobian - I|| <= C`). -The downstream error certificate is only checked for internal arithmetic consistency. -Treat sound certificates as **internally consistent bound reports**, not as a fully formal -end-to-end verification of transformer Jacobians. - -Margin-based softmax tightening exists, but only **best-match margin evidence** is accepted today. Direct `--softmaxMargin` is rejected by the checker, and best-match logit bounds are generated in untrusted code and only checked for internal consistency. - -For known gaps and ongoing upgrades, see `SOUNDNESS_LIMITATIONS.md`. - -## North Star - -NFP’s long-term direction is **verified circuit discovery**: - -- Use fast, exploratory tooling to **propose** candidate circuits (e.g. induction-style head interactions), -- then produce **checkable evidence** (bounds / certificates) that a skeptical reader can re-run and validate. - -Concretely, the intended split is: - -- **Discovery / exploration (untrusted, fast):** - Heuristic search, ranking, and diagnostics are allowed here (and should be clearly labelled as such). - This includes things like candidate search (`induction`) and comparison estimates printed under diagnostics/verbose flags. - -- **Certification / checking (trusted, boring):** - Anything described as “rigorous” should be justified by conservative inequalities or by a certificate that a checker can validate. 
- The long-term aim is that Lean does as little “real inference” as possible: instead of running large forward passes, - it should mostly **check small, structured proof obligations** (e.g. inequality chains, norm bounds, interval/rational arithmetic). - -Current state: `certify` is already an example of this direction (sound-mode reporting using exact `Rat` arithmetic rather than trusted floats), -but the certificate story is still evolving and interfaces may change. - -Model trajectory: GPT-2 support is currently a proving ground for the end-to-end workflow (export → analyze/search → bound/certify). -The goal is to gradually cover more modern decoder blocks (e.g. RoPE-style position handling) while keeping the certification/checking layer lightweight. - -## Reproduce results - -Minimal local demo (no network needed): - -```bash -lake build -q --wfail -lake build nfp -q --wfail -lake exe nfp certify tests/fixtures/tiny_sound_binary.nfpt \ - --output reports/tiny_sound_demo.txt -``` - -Expected artifacts: -- `reports/tiny_sound_demo.txt` - -Optional (rebuild the tiny binary from text fixtures and run a fixed induction cert): - -```bash -./scripts/demo_tiny_local_binary.sh -./scripts/demo_tiny_induction_cert.sh -``` - -Expected artifacts (optional path): -- `reports/tiny_sound_local_binary.txt` -- `reports/tiny_induction_cert.txt` - -End-to-end GPT-2 demo (requires network/model download): - -```bash -./scripts/demo_gpt2_sound.sh -./scripts/demo_gpt2_induction_sound.sh -``` +All `direction`, `embed`, and projection matrices must be fully specified. If no `active` lines +appear, the checker defaults to all nonzero queries. -Expected artifacts: -- `reports/gpt2_sound_demo.txt` -- `reports/gpt2_induction_sound_scan.txt` +## Soundness boundary -Notes: -- If a legacy `.nfpt` header is missing `gelu_kind`, `demo_gpt2_sound.sh` writes - `models/gpt2_with_gelu_kind.nfpt` and uses that for certification. 
-- `demo_gpt2_induction_sound.sh` can take a while on CPU; use `--top 1`, - `--fast`, or `--jobs 2` to shorten the scan or run it on a larger machine. -- You can also set `NFP_BIN=./.lake/build/bin/nfp` to avoid repeated `lake exe` - startup overhead. +- Untrusted scripts may use floating-point numerics to generate candidate certificates. +- The CLI **only verifies** certificate constraints inside Lean; it does not search for witnesses. +- Downstream error certificates are currently **not derived in Lean** (work in progress). +For known gaps, see `SOUNDNESS_LIMITATIONS.md`. ## Requirements -- **Lean 4** (pinned by `lean-toolchain`) and **Lake**. - - Easiest install: `elan` (Lean toolchain manager). -- A standard build toolchain for Lean (C/C++ compiler, `make`, etc.). -- (Optional) **Python** for the export scripts in `scripts/`. - -Lean version is pinned in `lean-toolchain` (currently `leanprover/lean4:v4.26`). - -## Getting started - -Clone and build: - -```bash -lake update -lake build -``` - -Run the CLI (see subcommands below): - -```bash -lake exe nfp --help -``` - -## Models - -The CLI expects a model file in **`.nfpt`** format (NFP_BINARY_V1). -Most commands (analysis/induction/diagnostics) require `NFP_BINARY_V1`; legacy `NFP_TEXT_V1/V2` -is supported only for local SOUND certification. - -- Create a local `models/` directory and place your `.nfpt` files there (the repo does not version model files; the author’s setup may have used local symlinks). -- You can export GPT-2 weights from Hugging Face using the scripts in `scripts/`. - -`.nfpt` files use a small text header followed by a binary payload: - -``` -NFP_BINARY_V1 -num_layers=... -num_heads=... -model_dim=... -head_dim=... -hidden_dim=... -vocab_size=... -seq_len=... -layer_norm_eps=... -gelu_kind=... -BINARY_START -``` - -The payload is raw little-endian bytes in a fixed order (tokens, embeddings, then weights). 
- -Notes: -- `layer_norm_eps` (or legacy `eps`) and `gelu_kind` (or legacy `gelu_deriv`) are required for - SOUND certification. -- Global sound certification supports `NFP_BINARY_V1`. Local sound certification supports - `NFP_BINARY_V1` (fixed-point union-box) and legacy `NFP_TEXT_V1/V2`. - -### Exporting GPT-2 to `.nfpt` - -The export scripts use `torch` + `transformers`. - -Example (write `models/gpt2_rigorous.nfpt`): - -```bash -python scripts/export_gpt2.py models/gpt2_rigorous.nfpt -``` - -If you prefer a locked Python environment, use `uv` or a venv and install dependencies from `pyproject.toml`: - -```bash -uv run python scripts/export_gpt2.py models/gpt2_rigorous.nfpt -``` - -### GPT-2 sound demo (global) - -This demo downloads GPT-2 weights on demand, exports a binary `.nfpt`, and runs the -global sound certificate. - -```bash -./scripts/demo_gpt2_sound.sh -``` - -Artifacts: -- `models/gpt2.nfpt` (binary export) -- `reports/gpt2_sound_demo.txt` (sound certificate report) - -### GPT-2 induction sound scan - -This demo builds the rigorous induction dataset (if needed), finds candidate -induction head pairs, and ranks them by sound logit-diff lower bounds. - -```bash -./scripts/demo_gpt2_induction_sound.sh -``` - -Artifacts: -- `models/gpt2_rigorous.nfpt` (binary export) -- `reports/gpt2_induction_sound_scan.txt` (sound scan report) - -### Tiny local binary demo - -This demo converts the tiny text fixtures into a binary `.nfpt` and runs a local -sound certificate (with `--delta`). - -```bash -./scripts/demo_tiny_local_binary.sh -``` - -Artifacts: -- `tests/fixtures/tiny_sound_binary.nfpt` (binary fixture) -- `reports/tiny_sound_local_binary.txt` (local sound certificate report) - -### Tiny induction cert demo - -This demo computes a minimal induction head certificate on the tiny fixture. 
- -```bash -./scripts/demo_tiny_induction_cert.sh -``` - -Artifacts: -- `reports/tiny_induction_cert.txt` (induction cert report) - -## CLI overview - -The main entrypoint is: - -```bash -lake exe nfp [args] [flags] -``` - -By default, `nfp` mirrors everything printed to stdout into `logs/` as a timestamped `.log` file. - -### `analyze` - -Runs the default end-to-end analysis for the supplied model and prints a human-readable report. - -```bash -lake exe nfp analyze models/gpt2_rigorous.nfpt \ - --threshold 0.1 --verify --verbose --output report.txt -``` - -- `--threshold` (`-t`) sets the minimum effect threshold used for verification (default: `0.1`). -- `--verify` optionally runs causal verification using model-provided inputs. -- `--verbose` prints model metadata and per-stage status messages. -- `--output` (`-o`) writes the report to a file instead of stdout. - -### `induction` - -Searches for **candidate induction circuits** and ranks head pairs by a mechanical score. - -```bash -lake exe nfp induction models/gpt2_rigorous.nfpt \ - --threshold 0.0 --diagnostics --diagTop 5 --adaptive --verbose -``` - -- `--threshold` (`-t`) sets the minimum normalized effect (default: `0.0`). -- `--correct` / `--incorrect` manually pick logit IDs for the induction target (otherwise the target is inferred from tokens). -- `--verify` runs causal verification via head ablation on the top-10 candidates. -- `--diagnostics` enables bound breakdowns; `--diagTop` controls how many candidates receive diagnostics (default: `5`). -- `--adaptive` turns on the adaptive bound scheduler. Tuning flags include `--targetSlack` (default: `8.0`), - `--maxUpgrades` (default: `120`), `--minRelImprove` (default: `0.01`), `--krylovSteps` (default: `2`), - and `--adaptiveScope` (`layernorm | all`, default: `layernorm`). -- `--verbose` prints detailed scoring metrics for each candidate. 
- -### `certify` - -Computes a conservative **certificate report** in sound mode using exact `Rat` arithmetic (no trusted floats). - -Note: global sound certification supports `NFP_BINARY_V1`. Local sound certification -supports `NFP_BINARY_V1` (fixed-point union-box) and legacy `NFP_TEXT_V1/V2`. - -`certify` supports both: -- **global certification** (weights only), and -- **local certification** (weights + a small input region around a concrete prompt/input). - -```bash -lake exe nfp certify models/gpt2_rigorous.nfpt \ - --output cert.txt -``` - -- For local (input-dependent) LayerNorm certification, pass an ℓ∞ radius `δ`: - -```bash -lake exe nfp certify models/gpt2_rigorous.nfpt \ - --delta 0.01 -``` - -If you want to override the embedded input, pass a separate input `.nfpt`: - -- LayerNorm ε is read from the model header (`layer_norm_eps`). -- `gelu_kind` in the model header selects the GeLU derivative target (`tanh` or `exact`). -- `--delta` sets the local ℓ∞ radius `δ` (default: `0`). Providing `--delta` enables local certification. -- `--input` optionally provides an input `.nfpt` file used for local certification; if omitted and the - model file embeds `EMBEDDINGS`, `certify` reuses the model file as its input source. -- `--softmaxMargin` provides a logit-margin lower bound, but it is currently **rejected** by the - verifier (use `--bestMatchMargins` instead). -- `--softmaxExpEffort` controls exp lower-bound effort used for margin-based softmax tightening (default: `1`). -- `--bestMatchMargins` runs a full best-match sweep (binary + local only) and tightens layer - softmax bounds using verified margin evidence. It is incompatible with `--softmaxMargin`. -- `--targetOffset` selects the target-token offset for best-match margins (default: `-1`). -- `--maxSeqLen` caps the sequence length used in best-match margin sweeps (default: `0` = full `seq_len`). 
-- `--tightPattern`, `--tightPatternLayers`, and `--perRowPatternLayers` control pattern tightening - during best-match sweeps. -- `--scalePow10` sets fixed-point scaling for best-match sweeps (default: `9`). -- `--noncausalPattern` disables the causal-prefix restriction (required for non-causal models). -- `--soundnessBits` sets dyadic sqrt precision for LayerNorm bounds (default: `20`). -- `--partitionDepth` requests input partitioning depth (default: `0`; scaffold only, must remain `0` for now). -- `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). - -### `head_bounds` - -Computes sound per-head contribution bounds (global weight-only, or local with `--delta`). - -```bash -lake exe nfp head_bounds models/gpt2_rigorous.nfpt -``` - -For local bounds (uses input embeddings in the model file when present): - -```bash -lake exe nfp head_bounds models/gpt2_rigorous.nfpt --delta 0.01 -``` - -- `--delta` enables local head bounds; `--input` can override the embedded input. -- LayerNorm ε is read from the model header (`layer_norm_eps`). -- `--soundnessBits` sets dyadic sqrt precision for LayerNorm bounds (default: `20`). -- `--scalePow10` controls fixed-point scaling for global bounds (default: `9`). -- `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). - -### `head_pattern` - -Computes a sound local attention pattern bound for a single head (binary only), -propagating per-position intervals up to the target layer (bounded by `maxSeqLen`). -The pattern compares logits for keys whose **shifted-key token** matches the -query’s **offset token** (e.g., `--offset -1` matches the previous token, and -`--offset 0 --keyOffset -1` matches the copy-next pattern). - -```bash -lake exe nfp head_pattern models/gpt2_rigorous.nfpt --layer 0 --head 0 --delta 0.01 --offset -1 -``` - -- `--offset` selects the target key position relative to the query (default: `-1` for previous token). 
-- `--keyOffset` selects which key-position token is matched (default: `0` for the key token itself). -- `--maxSeqLen` caps the sequence length analyzed for pattern bounds (default: `256`). -- `--input` optionally provides an input `.nfpt` file; required for legacy text models. -- `--delta` sets the local input radius; LayerNorm ε is read from the model header (`layer_norm_eps`). -- `--soundnessBits` sets dyadic sqrt precision for LayerNorm bounds (default: `20`). -- `--tightPattern` enables a slower but tighter pattern bound near the target layer. -- `--tightPatternLayers` sets how many layers use tight bounds (default: `1`; implies `--tightPattern`). -- `--perRowPatternLayers` sets how many layers use per-row MLP propagation (default: `0`). -- `--softmaxExpEffort` sets the exp lower-bound effort for margin-derived softmax bounds (default: `1`). -- `--scalePow10` sets fixed-point scaling for best-match bounds (default: `9`). -- `--noncausalPattern` disables the causal-prefix restriction (required for non-causal models). -- `--bestMatch` switches to a single-query best-match bound (default query: last position). -- `--affine` uses affine Q/K dot bounds in best-match mode. -- `--sweep` prints best-match bounds for all valid query positions (requires `--bestMatch`). -- `--queryPos` chooses the query position for best-match bounds (default: last position). -- `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). - -### `induction_cert` - -Computes a minimal sound induction-head certificate by combining two pattern -certificates and a value-coordinate lower bound (binary only). - -```bash -lake exe nfp induction_cert models/gpt2_rigorous.nfpt \ - --layer1 0 --head1 0 --layer2 1 --head2 0 --coord 0 --delta 0.01 \ - --target 42 --negative 17 -``` - -- `--layer1/--head1` selects the previous-token head; `--layer2/--head2` selects the - token-match head. -- `--coord` chooses the output coordinate used for the value lower bound. 
-- `--offset1/--offset2` adjust the token-match offsets (default: `-1`). -- `--keyOffset1/--keyOffset2` adjust the key-token offsets (default: `0`; - use `--offset2 0 --keyOffset2 -1` for copy-next induction). -- `--target/--negative` optionally add a logit-diff lower bound using unembedding columns. -- `--input` optionally provides an input `.nfpt` file; required for legacy text models. -- `--delta` sets the local input radius (default: `0`). -- `--soundnessBits` sets dyadic sqrt precision for LayerNorm bounds (default: `20`). -- `--tightPattern` enables a slower but tighter pattern bound near the target layer. -- `--tightPatternLayers` sets how many layers use tight bounds (default: `1`; implies `--tightPattern`). -- `--perRowPatternLayers` sets how many layers use per-row MLP propagation (default: `0`). -- `--softmaxExpEffort` sets the exp lower-bound effort for margin-derived softmax bounds (default: `1`). -- `--maxSeqLen` caps the sequence length analyzed for best-match bounds (default: `256`). -- `--scalePow10` sets fixed-point scaling for best-match bounds (default: `9`). -- `--noncausalPattern` disables the causal-prefix restriction (required for non-causal models). -- `--bestMatch` switches to single-query best-match bounds (default query: last position). -- `--affine` uses affine Q/K dot bounds in best-match mode. -- `--queryPos` chooses the query position for best-match bounds (default: last position). -- `--iterTighten` iteratively tightens best-match bounds (tight/per-row layers and scale precision). -- `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). - -### `rope` - -Generates RoPE-related linearization bounds used by the certificate/checking pipeline. - -```bash -lake exe nfp rope --seqLen 4 --pairs 8 -``` - -- `--seqLen` instantiates the bound at the given sequence length (default: `4`). -- `--pairs` sets the number of RoPE pairs; the dimension is `2 * pairs` (default: `8`). 
- -### `bench` - -Runs repeatable microbenchmarks for analysis or induction search. - -```bash -lake exe nfp bench models/gpt2_rigorous.nfpt --mode analysis --runs 5 --repeats 1 -``` - -- `--mode` selects `analysis` or `induction` (default: `analysis`). -- `--runs` sets the number of timed runs (default: `5`). -- `--repeats` repeats the inner workload per run (default: `1`). -- `--threshold` sets the analyze threshold (default: `0.1`). -- `--minEffect` sets the induction minEffect (default: `0.0`). -- `--correct/--incorrect` override induction target tokens. -- `--verbose` prints per-run timing details. -- `--breakdown` emits per-phase averages (analysis only). - -### `sound_cache_check` - -Checks SOUND fixed-point cache soundness (CI / small fixtures). - -```bash -lake exe nfp sound_cache_check tests/fixtures/tiny_sound_binary.nfpt -``` - -- `--scalePow10` sets the fixed-point scale exponent (default: `9`). -- `--maxTokens` checks at most this many numeric tokens (default: `0` = all). - -### `sound_cache_bench` - -Benchmarks SOUND fixed-point cache build (text or binary). - -```bash -lake exe nfp sound_cache_bench models/gpt2_rigorous.nfpt --runs 3 -``` - -- `--scalePow10` sets the fixed-point scale exponent (default: `9`). -- `--runs` sets the number of benchmark runs (default: `1`). - -### `dump` - -Dumps a small forward-pass slice for PyTorch sanity checking. - -```bash -lake exe nfp dump models/gpt2_rigorous.nfpt --layer 0 --pos 0 --kind afterLayer -``` - -- `--layer` selects the layer index (default: `0`). -- `--pos` selects the token position / row index (default: `0`). -- `--take` limits columns from the start (default: `16`). -- `--kind` chooses `embeddings | layerInput | postAttn | afterLayer` (default: `afterLayer`). - -### `logit_diff` - -Computes an empirical logit-difference for a target vs. negative token. 
- -```bash -lake exe nfp logit_diff models/gpt2_rigorous.nfpt 42 17 --autoNegative -``` - -- `--pos` selects the token position (default: last position). -- `--input` provides an input `.nfpt` with TOKENS + EMBEDDINGS. -- `--autoNegative` uses the top non-target logit as the negative token. - -### `--version` - -Prints the CLI version string. - -## What “rigorous” means here - -At a high level, the “rigorous” path avoids heuristic operator-norm estimation and instead uses **upper bounds** derived from standard inequalities (examples you may see in logs): - -- Frobenius-norm based bounds. -- Gram-matrix based bounds. -- Schur / Brauer-style eigenvalue bounds for symmetric matrices. -- Row-wise softmax operator bounds using quantities like `rowMaxP`, `rowTrace`, Gershgorin-style estimates, and a “moment” bound. - -The CLI may still compute **power-iteration estimates** for comparison, but those are explicitly labelled as diagnostics and are not used to produce the rigorous `ub=…` values. - -## Reproducing the example command - -A typical workflow: - -```bash -# 1) Build -lake update -lake build - -# 2) Export a model (optional) -python scripts/export_gpt2.py models/gpt2_rigorous.nfpt - -# 3) Run induction search with diagnostics -lake exe nfp induction models/gpt2_rigorous.nfpt -v -d | sed -n '1,220p' -``` - -## Project layout - -- `Main.lean` — CLI wiring and command definitions. -- `Nfp/` — library code (probability, transformer semantics, soundness/cert machinery, discovery routines). -- `scripts/` — Python helpers to export models and generate induction datasets. -- `models/` — local model files (not versioned here if large; author’s setup may have used local symlinks). +- **Lean 4** (pinned in `lean-toolchain`) and **Lake**. +- Optional: **Python** for helper scripts (`scripts/`). -## License +## Contributing -This project is licensed under the GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later). See the LICENSE file. 
+Please follow the project rules in `AGENTS.md` (no `sorry`, no linter disables, total soundness in +trusted namespaces). From 76320d4fcb3aa6faab82de082c531294590329ba Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 14:37:14 +0100 Subject: [PATCH 096/244] Update soundness docs and remove changelog notes --- CHANGELOG_NOTES.md | 46 --------------------------- CLAIMS.md | 49 ++++++++++++++++++++++------- SOUNDNESS_LIMITATIONS.md | 67 ++++++++++++++-------------------------- 3 files changed, 60 insertions(+), 102 deletions(-) delete mode 100644 CHANGELOG_NOTES.md diff --git a/CHANGELOG_NOTES.md b/CHANGELOG_NOTES.md deleted file mode 100644 index eeb5be8..0000000 --- a/CHANGELOG_NOTES.md +++ /dev/null @@ -1,46 +0,0 @@ -# CHANGELOG / NOTES - -## 2025-12-24 -- Added attention pattern-term coefficients using max `W_Q/W_K` row-sum norms and a conservative - LayerNorm output magnitude bound; updated layer cert formulas and reports accordingly. -- Added `modelDim`/`headDim` metadata to sound certificates and threaded through the checker. - -## 2025-12-23 -- Added margin-derived softmax max-probability and Jacobian bounds for best-match pattern certificates. -- Added effort-indexed exp lower bounds (scaled Taylor + squaring) and wired them into best-match softmax bounds. -- Extended best-match head pattern certs with a recorded softmax Jacobian upper bound and wired untrusted computation to populate it. -- Noted that the exp lower-bound correctness is not yet formalized in Lean. -- Layer-level sound certificates now use a portfolio softmax Jacobian bound field, with margin-based - tightening available when margins are supplied (defaults remain worst-case today). -- Added `nfp certify --softmaxMargin/--softmaxExpEffort` flags and report fields to pass margin - evidence into layer-level softmax portfolio bounds. 
- -## 2025-12-22 -- Optimized induction head discovery by caching per-head induction scores and per-layer input norms, eliminating redundant pattern scans and repeated Frobenius norm computations. -- Tightened induction error bounds by using data-dependent V norms (Frobenius/op) in pattern-term calculations. -- Tightened per-head weight operator-norm bounds using Brauer/moment Gram candidates, improving pattern-term bounds. - -## 2025-12-21 -- Updated sound certificate algebra to include the attn*mlp cross term and surfaced it in output. -- Added CLAIMS.md and clarified soundness limitations and reproducibility documentation. -- Added operator-norm bound lemmas for SignedMixers, including a residual composition bound that takes external operator-norm bounds. -- Added a helper lemma to extract the `C` identity from `LayerAmplificationCert.Valid`. -- Added a bridge lemma to bound residual composition from component bounds plus the cast `C` identity. -- Added cast and `Valid`-based bridge lemmas to move from certificate validity to the residual bound. -- Added a `DeepLinearization` lemma that turns per-component operator-norm bounds into a layer residual bound. -- Added a certificate-to-Jacobian bridge lemma tying layer certificates to residual bounds (under component-bound assumptions). -- Added attention/MLP component bound lemmas that connect certificate identities to operator-norm bounds. -- Added an assumption-based bridge theorem combining component bounds into a full layer residual bound. -- Added a `LayerComponentNormAssumptions` structure to package the remaining component-norm obligations. -- Added an operator-norm bound lemma for diagonal mixers based on uniform entry bounds. -- Added an operator-norm bound lemma for `A ∘ diag(d) ∘ B` from component bounds. -- Replaced the MLP component assumption with a factored-Jacobian assumption and derived the MLP bound from it. 
-- Added `MLPFactorization` and `mlpFactors` to `DeepLinearization`, with MLP Jacobians derived from the factorization data. -- Added `mlpWinBound`/`mlpWoutBound` fields to the sound certificate and wired the bridge to use them for MLP coefficient bounds. -- Updated README tiny demo instructions to use the binary fixture directly and document the optional scripted path. -- Added a helper to patch missing `gelu_kind` in legacy `.nfpt` headers and wired it into the GPT-2 sound demo script. -- Made demo scripts fall back to `python3` when `python` is unavailable. -- Documented the GPT-2 header patch behavior in README. -- Updated the GPT-2 induction demo to patch legacy headers and use the current Python executable when generating rigorous models. -- Documented GPT-2 induction scan runtime notes in README. -- Added `--fast`, `--jobs`, and `--nfp-bin` options to the induction scan script for faster runs. diff --git a/CLAIMS.md b/CLAIMS.md index 613ab25..84052bd 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -1,15 +1,40 @@ # CLAIMS This file lists what is formally proven in Lean, what is soundly checked by the trusted checker, -what is heuristic, and what is not yet proven. - -| Claim | Status | Where | -| --- | --- | --- | -| Definitions of mixers/signed mixers and linearizations (ReLU, GeLU, LayerNorm, softmax) with basic lemmas (composition, diagonality, etc.) 
| Proven in Lean | `Nfp/SignedMixer.lean`, `Nfp/Linearization.lean` | -| Model-level SOUND certificate checker validates internal arithmetic consistency and recomputes weight-derived bounds from model files | Soundly checked (Lean) | `Nfp/Sound/Cert.lean`, `Nfp/Sound/IO.lean`, `Nfp/Sound/BinaryPure.lean`, `Nfp/Sound/TextPure.lean` | -| Per-head contribution, head-pattern, and induction-head certificates (including best-match variants) have internal consistency checks | Soundly checked (Lean) | `Nfp/Sound/HeadCert.lean`, `Nfp/Sound/IO.lean` | -| Sound bound formulas use exact `Rat` arithmetic (LayerNorm/softmax/GeLU envelopes); witness values are produced in untrusted code and then checked | Soundly checked formulas; untrusted witnesses | `Nfp/Sound/Bounds.lean`, `Nfp/Untrusted/SoundCompute.lean`, `Nfp/Untrusted/SoundBinary.lean` | -| Best-match margin tightening uses untrusted logit bounds; verification checks only internal margin/softmax consistency | Partially checked (internal consistency only) | `Nfp/Sound/HeadCert.lean`, `Nfp/Sound/IO.lean`, `Nfp/Untrusted/SoundCompute.lean` | -| Heuristic discovery and ranking of induction-style candidates | Heuristic | `Nfp/Discovery.lean`, CLI `induction` | -| Empirical causal verification via head ablation (competence/control/energy checks) | Heuristic | `Nfp/Verification.lean`, CLI `analyze --verify` / `induction --verify` | -| End-to-end statement that certificate validity implies `||layerJacobian - I|| <= C` for Lean-defined Jacobians | Not yet proven | See `SOUNDNESS_LIMITATIONS.md` | +what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewrite. + +## Proven in Lean + +- Circuit core definitions and semantics (typed circuits, evaluation, interfaces). +- Softmax-margin certificate soundness: `checkSoftmaxMarginCert` implies + `SoftmaxMarginBoundsOn`. +- Value-range certificate soundness: `checkValueRangeCert` implies `ValueRangeBounds`. +- Logit-diff lower bound lemma: `logitDiffLowerBound_le`. 
+- Downstream linear certificate soundness: `checkDownstreamLinearCert` implies + `DownstreamLinearBounds`. +- Row-sum matrix norm bounds for `mulVec` under uniform input magnitude. + +## Soundly checked by the trusted CLI + +- `nfp induction certify` verifies softmax-margin certificates, value-range certificates, + and computes a logit-diff lower bound. +- `nfp induction certify_sound` recomputes `eps`/`margin` and `lo`/`hi` from raw entries + and verifies the resulting certificates. +- `nfp induction certify_head` recomputes scores/values from exact head inputs and verifies + the resulting induction certificate (experimental, potentially slow). +- `nfp induction certify_end_to_end` composes a head-level logit-diff lower bound with a + downstream error certificate (arithmetic consistency only). + +## Untrusted / heuristic + +- Python helpers that generate certificates from GPT-2 weights or head inputs: + `scripts/build_gpt2_induction_cert.py`, `scripts/build_gpt2_head_inputs.py`, + `scripts/build_downstream_linear_cert.py`. +- The head-input extractor currently ignores LayerNorm and bias terms. +- Any downstream error bound provided externally (until it is computed in Lean). + +## Not yet proven + +- End-to-end claims about GPT-2 logits or Jacobians derived from certificates. +- Sound, verified downstream bounds computed from GPT-2 weights inside Lean. +- A bridge theorem connecting certificate validity to full circuit/model semantics. diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 46840a1..0b2201c 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -1,47 +1,26 @@ -## SOUNDNESS upgrade status +# SOUNDNESS_LIMITATIONS -This file tracks **current limitations** and **remaining work** for the rigorous -soundness upgrade. It is intentionally brief and human-readable. +This file tracks **current limitations** and **remaining work** for the tabula rasa rewrite. +It is intentionally brief and focused on the soundness boundary. 
-### Current limitations -- The bridge theorem in `Nfp/Sound/Bridge.lean` links `LayerAmplificationCert` bounds to - `DeepLinearization` residual Jacobians, but it requires external operator-norm assumptions - (LN Jacobians, attention full Jacobian, and MLP factors). The trusted checker now recomputes - weight-derived bounds (W_Q/W_K/W_V/W_O, MLP W_in/W_out, LN1 gamma/beta, LN2 gamma) from model files, - but it still treats softmax probability or margin evidence as external and does not derive those - bounds from logits. -- `partitionDepth > 0` is rejected with an explicit error (no partitioning logic yet). -- Affine arithmetic is available via `--affine` for best-match Q/K dot bounds, but those dot-bound - computations are untrusted witness generation; the checker only validates the downstream - margin-to-probability derivations. -- Softmax Jacobian bounds in the standard `certify` path derive a probability interval from a - global attention-score magnitude bound (LN1 max-abs + W_Q/W_K norms), but it is typically very - loose and often collapses to `[0,1]`. Direct `--softmaxMargin` is still rejected because margin - evidence is unverified. -- Best-match margin tightening is now available via `nfp certify --bestMatchMargins` (binary + local - inputs with EMBEDDINGS). It runs a full best-match sweep across heads and query positions, which - can be expensive and will fail if coverage is incomplete. -- Local pattern/value/logit bounds now assume **causal attention** by default (prefix-only keys). - Use `--noncausalPattern` for non-causal models; otherwise these bounds are not sound. -- Per-head best-match tightening (used by head-pattern/induction certs) now records the **actual** - `softmaxExpEffort` chosen by iterative exp-portfolio tightening (early stop on low relative - improvement). The verifier accepts any per-head effort ≤ the requested cap, but model-level - certification still requires `--bestMatchMargins`. 
-- Best-match pattern certificates rely on untrusted interval/affine logit bounds to produce a - margin, and then use a margin-derived softmax Jacobian bound with an effort-indexed `expLB` - (scaled Taylor + squaring). The lower-bound correctness of `expLB` is not yet formalized in Lean. -- GeLU derivative bounds are conservative envelopes; the exact interval supremum is not computed yet. -- Attention Jacobian bounds now include an explicit pattern-term coefficient using max `W_Q/W_K` - row-sum norms and a conservative LayerNorm output magnitude bound (`max|gamma|*sqrt(d)+max|beta|`), - but this is still very conservative and only connected to the Lean Jacobian theorems - under the external norm assumptions above. +## Current limitations -### Remaining work -- Implement input-space partitioning in the SOUND local path and plumb it through the certify pipeline. -- Replace or augment interval propagation with affine forms to preserve correlations. -- Add sound probability interval extraction for softmax (requires sound exp/log-sum-exp bounds). -- Verify or compute margin evidence in the trusted path so margin-derived softmax tightening can be - enabled without a best-match sweep and without rejecting `--softmaxMargin`. -- Tighten GeLU derivative envelopes to the exact interval supremum if desired. -- Discharge the bridge theorem’s component-norm assumptions from certificates/model weights, and - connect the resulting statement to the `Linearization` Jacobian theorems. +- The trusted CLI only **checks certificates**; it does not search for witnesses or run a model. +- Induction certificates are **head-level** (softmax-margin + value-range + logit-diff lower bound). + They do **not** yet imply end-to-end model behavior. +- Downstream error bounds are still **external**. `certify_end_to_end` checks arithmetic + consistency, but it does not derive the downstream bound from GPT-2 weights. 
+- The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor + currently ignores LayerNorm and bias terms, so it is not end-to-end faithful. +- Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. +- There is no bridge theorem connecting certificate validity to a full circuit/model semantics + statement (for example, a formal statement about logits under a transformer block stack). + +## Remaining work + +- Compute the downstream bound **inside Lean** from model weights and verified bounds + (row-sum norms, LayerNorm/GeLU/softmax envelopes), and wire this into `certify_end_to_end`. +- Replace untrusted extraction with a verified parser for model weight slices. +- Add a formal bridge from certificates to circuit semantics and (eventually) to end-to-end + transformer claims. +- Improve performance for the exact head-input path without weakening soundness. From 2b6225128a50c7b85d7ff3bdc5dac05ce1c172bc Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 15:08:04 +0100 Subject: [PATCH 097/244] Add nfpt parsing and end-to-end model certification --- AGENTS.md | 2 + CLAIMS.md | 7 +- Nfp/Cli.lean | 62 +++++++++ Nfp/IO.lean | 283 ++++++++++++++++++++++++++++++++++++++- Nfp/IO/NfptPure.lean | 216 ++++++++++++++++++++++++++++++ Nfp/IO/Pure.lean | 111 +++++++++++++++ README.md | 33 ++++- SOUNDNESS_LIMITATIONS.md | 9 +- 8 files changed, 715 insertions(+), 8 deletions(-) create mode 100644 Nfp/IO/NfptPure.lean diff --git a/AGENTS.md b/AGENTS.md index 55f0a23..702d495 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -327,6 +327,8 @@ but you **must** update this list in the same commit. ### 5.6 CLI surface - `Nfp/IO/Pure.lean` - Pure parsing helpers for CLI inputs. +- `Nfp/IO/NfptPure.lean` + - Pure parsing helpers for `NFP_BINARY_V1` model slices. - `Nfp/IO.lean` - IO-only wrappers for loading inputs and running checks. 
- `Nfp/Cli.lean` diff --git a/CLAIMS.md b/CLAIMS.md index 84052bd..c8c747a 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -24,6 +24,11 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri the resulting induction certificate (experimental, potentially slow). - `nfp induction certify_end_to_end` composes a head-level logit-diff lower bound with a downstream error certificate (arithmetic consistency only). +- `nfp induction certify_end_to_end_matrix` computes a downstream bound from a matrix payload + using verified row-sum norms, then composes it with the head-level logit-diff lower bound. +- `nfp induction certify_end_to_end_model` derives a downstream matrix from an `NFP_BINARY_V1` + model file (unembedding direction only) and composes it with the head-level logit-diff + lower bound. ## Untrusted / heuristic @@ -31,7 +36,7 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri `scripts/build_gpt2_induction_cert.py`, `scripts/build_gpt2_head_inputs.py`, `scripts/build_downstream_linear_cert.py`. - The head-input extractor currently ignores LayerNorm and bias terms. -- Any downstream error bound provided externally (until it is computed in Lean). +- Any downstream error bound provided externally (outside the matrix-payload or model-based path). ## Not yet proven diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 0fcf11b..5026942 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -108,6 +108,66 @@ def inductionCertifyEndToEndCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] +/-- `nfp induction certify_end_to_end_matrix` subcommand. -/ +def runInductionCertifyEndToEndMatrix (p : Parsed) : IO UInt32 := do + let scoresPath := p.flag! "scores" |>.as! String + let valuesPath := p.flag! "values" |>.as! String + let matrixPath := p.flag! "matrix" |>.as! String + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? 
"min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyEndToEndMatrix scoresPath valuesPath matrixPath + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + +/-- `nfp induction certify_end_to_end_matrix` subcommand. -/ +def inductionCertifyEndToEndMatrixCmd : Cmd := `[Cli| + certify_end_to_end_matrix VIA runInductionCertifyEndToEndMatrix; + "Check end-to-end induction bounds using a downstream matrix payload." + FLAGS: + scores : String; "Path to the softmax-margin certificate file." + values : String; "Path to the value-range certificate file." + matrix : String; "Path to the downstream matrix payload file." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal). Defaults to 0 when \ + direction metadata is present." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + +/-- `nfp induction certify_end_to_end_model` subcommand. -/ +def runInductionCertifyEndToEndModel (p : Parsed) : IO UInt32 := do + let scoresPath := p.flag! "scores" |>.as! String + let valuesPath := p.flag! "values" |>.as! String + let modelPath := p.flag! "model" |>.as! String + let inputBound := p.flag! "input-bound" |>.as! String + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath inputBound + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + +/-- `nfp induction certify_end_to_end_model` subcommand. 
-/ +def inductionCertifyEndToEndModelCmd : Cmd := `[Cli| + certify_end_to_end_model VIA runInductionCertifyEndToEndModel; + "Check end-to-end induction bounds using a model file for the downstream matrix." + FLAGS: + scores : String; "Path to the softmax-margin certificate file." + values : String; "Path to the value-range certificate file." + model : String; "Path to the NFP_BINARY_V1 model file." + "input-bound" : String; "Nonnegative input bound for the downstream matrix (rational literal)." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal). Defaults to 0 when \ + direction metadata is present." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + /-- `nfp induction certify_head` subcommand. -/ def runInductionCertifyHead (p : Parsed) : IO UInt32 := do let inputsPath := p.flag! "inputs" |>.as! String @@ -140,6 +200,8 @@ def inductionCmd : Cmd := `[Cli| inductionCertifyCmd; inductionCertifySoundCmd; inductionCertifyEndToEndCmd; + inductionCertifyEndToEndMatrixCmd; + inductionCertifyEndToEndModelCmd; inductionCertifyHeadCmd ] diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 4d1a85d..5fa2de9 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -1,8 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.IO.Pure +import Nfp.IO.NfptPure import Nfp.Circuit.Cert.LogitDiff import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Induction /-! @@ -39,6 +41,13 @@ def loadDownstreamLinearCert (path : System.FilePath) : let data ← IO.FS.readFile path return Pure.parseDownstreamLinearCert data +/-- Load a downstream matrix payload from disk. 
-/ +def loadDownstreamMatrixRaw (path : System.FilePath) : + IO (Except String (Sigma (fun rows => + Sigma (fun cols => Pure.DownstreamMatrixRaw rows cols)))) := do + let data ← IO.FS.readFile path + return Pure.parseDownstreamMatrixRaw data + /-- Load raw value-range inputs from disk. -/ def loadValueRangeRaw (path : System.FilePath) : IO (Except String (Sigma Pure.ValueRangeRaw)) := do @@ -420,11 +429,283 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) logitDiffLB={logitDiffLB}, \ downstreamError={downstream.error}, \ finalLB={finalLB})" - return 0 + return 0 else IO.eprintln "error: downstream certificate rejected" return 2 +/-- Check end-to-end induction certificates with a downstream matrix. -/ +def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (matrixPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let parsedScores ← loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? 
with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let parsedMatrix ← loadDownstreamMatrixRaw matrixPath + match parsedMatrix with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨rows, ⟨cols, raw⟩⟩ => + let inputBound := raw.inputBound + if hneg : inputBound < 0 then + IO.eprintln + s!"error: input-bound {inputBound} must be nonnegative" + return 2 + else + have hinput : 0 ≤ inputBound := by + exact le_of_not_gt hneg + let W : Matrix (Fin rows) (Fin cols) Rat := raw.entries + let downstream := + (Sound.Bounds.buildDownstreamLinearCert W inputBound hinput).1 + let finalLB := logitDiffLB - downstream.error + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if finalLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end logitDiffLB {finalLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: end-to-end induction bound certified \ + (seq={seq}, active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstream.error}, \ + finalLB={finalLB})" + return 0 + +/-- Check end-to-end induction certificates using a model file to derive the downstream matrix. -/ +def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (modelPath : System.FilePath) + (inputBoundStr : String) (minActive? : Option Nat) + (minLogitDiffStr? : Option String) (minMarginStr? : Option String) + (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ let inputBoundE := Pure.parseRat inputBoundStr + match minLogitDiff?E, minMargin?E, maxEps?E, inputBoundE with + | Except.error msg, _, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, _, Except.error msg => + IO.eprintln s!"error: invalid input-bound: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps?, Except.ok inputBound => do + if hneg : inputBound < 0 then + IO.eprintln s!"error: input-bound {inputBound} must be nonnegative" + return 2 + else + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let parsedScores ← loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let 
valuesOk ← checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + match certVals'.direction with + | none => + IO.eprintln + "error: value-range certificate missing direction \ + metadata" + return 2 + | some dirSpec => + let data ← IO.FS.readBinFile modelPath + match NfptPure.parseHeader data with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let dirPos := dirSpec.target + let dirNeg := dirSpec.negative + match NfptPure.readUnembedColumn data start header dirPos with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok colTarget => + match + NfptPure.readUnembedColumn data start header dirNeg + with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok colNeg => + let dirVec : Fin header.modelDim → Rat := + fun i => colTarget i - colNeg i + have hinput : 0 ≤ inputBound := by + exact le_of_not_gt hneg + let W : Matrix (Fin 1) (Fin header.modelDim) Rat := + fun _ j => dirVec j + let downstream := + (Sound.Bounds.buildDownstreamLinearCert + W inputBound hinput).1 + let finalLB := logitDiffLB - downstream.error + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if finalLB < minLogitDiff then + some minLogitDiff + else + none + match violation? 
with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end logitDiffLB {finalLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: end-to-end induction bound certified \ + (seq={seq}, active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstream.error}, \ + finalLB={finalLB})" + return 0 + /-- Build and check induction certificates from exact head inputs. -/ def runInductionCertifyHead (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean new file mode 100644 index 0000000..7193ce9 --- /dev/null +++ b/Nfp/IO/NfptPure.lean @@ -0,0 +1,216 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.List.Range + +/-! +Pure parsing utilities for `NFP_BINARY_V1` model files. + +These helpers parse headers and extract selected weight slices as exact rationals. +-/ + +namespace Nfp + +namespace IO + +namespace NfptPure + +/-- Required header fields for NFP binary models. -/ +structure NfptHeader where + /-- Number of transformer layers. -/ + numLayers : Nat + /-- Number of attention heads per layer. -/ + numHeads : Nat + /-- Model dimension. -/ + modelDim : Nat + /-- Head dimension. -/ + headDim : Nat + /-- MLP hidden dimension. -/ + hiddenDim : Nat + /-- Vocabulary size. -/ + vocabSize : Nat + /-- Sequence length used in the binary. -/ + seqLen : Nat + +private def parseNat (s : String) : Except String Nat := + match s.toNat? with + | some n => Except.ok n + | none => Except.error s!"expected Nat, got '{s}'" + +private def splitKV (line : String) : Option (String × String) := + match line.splitOn "=" with + | [k, v] => some (k.trim, v.trim) + | _ => none + +private def readHeaderField (name : String) (fields : List (String × String)) : + Except String Nat := do + match fields.find? 
(fun kv => kv.1 = name) with + | some kv => parseNat kv.2 + | none => throw s!"missing header field '{name}'" + +private def sentinelBytes : ByteArray := + "BINARY_START\n".toUTF8 + +private def findSentinel (data : ByteArray) : Option Nat := + let n := data.size + let m := sentinelBytes.size + if m ≤ n then + let maxStart := n - m + let rec loop (i : Nat) (remaining : Nat) : Option Nat := + match remaining with + | 0 => none + | Nat.succ rem => + let ok := + (List.range m).all (fun j => data.get! (i + j) = sentinelBytes.get! j) + if ok then + some i + else + loop (i + 1) rem + loop 0 (maxStart + 1) + else + none + +/-- Parse the NFP binary header and return the binary start offset. -/ +def parseHeader (data : ByteArray) : Except String (NfptHeader × Nat) := do + let idx ← + match findSentinel data with + | some i => pure i + | none => throw "missing BINARY_START sentinel" + let headerBytes := data.extract 0 idx + let headerStr ← + match String.fromUTF8? headerBytes with + | some s => pure s + | none => throw "invalid UTF-8 in header" + let lines := headerStr.splitOn "\n" |>.filter (· ≠ "") + match lines with + | [] => throw "empty header" + | magic :: rest => + if magic != "NFP_BINARY_V1" then + throw s!"unexpected magic '{magic}'" + let fields := rest.filterMap splitKV + let numLayers ← readHeaderField "num_layers" fields + let numHeads ← readHeaderField "num_heads" fields + let modelDim ← readHeaderField "model_dim" fields + let headDim ← readHeaderField "head_dim" fields + let hiddenDim ← readHeaderField "hidden_dim" fields + let vocabSize ← readHeaderField "vocab_size" fields + let seqLen ← readHeaderField "seq_len" fields + if numLayers = 0 then + throw "num_layers must be positive" + if numHeads = 0 then + throw "num_heads must be positive" + if modelDim = 0 then + throw "model_dim must be positive" + if headDim = 0 then + throw "head_dim must be positive" + if hiddenDim = 0 then + throw "hidden_dim must be positive" + if vocabSize = 0 then + throw 
"vocab_size must be positive"
+  if seqLen = 0 then
+    throw "seq_len must be positive"
+  let start := idx + sentinelBytes.size
+  return ({ numLayers := numLayers
+            numHeads := numHeads
+            modelDim := modelDim
+            headDim := headDim
+            hiddenDim := hiddenDim
+            vocabSize := vocabSize
+            seqLen := seqLen }, start)
+
+private def pow2 (k : Nat) : Nat :=
+  Nat.pow 2 k
+
+private def getBits (n hi lo : Nat) : Nat :=
+  (n / pow2 lo) % pow2 (hi - lo + 1)
+
+/-- Decode an IEEE-754 binary64 bit pattern into an exact rational.
+Returns `none` for NaN/infinity (`expBits = 2047`). Subnormals use the fixed
+scale `2^(-1074)`; normals scale the implicit-leading-bit mantissa by
+`2^(expBits - 1023 - 52)`. The unbiased exponent is negative for all finite
+values below 1, so it must be computed in `Int`, not `Nat`. -/
+private def ratOfFloatBits (bits : Nat) : Option Rat :=
+  let signBit := getBits bits 63 63
+  let expBits := getBits bits 62 52
+  let mantBits := getBits bits 51 0
+  let sign : Int := if signBit = 0 then 1 else -1
+  if expBits = 2047 then
+    none
+  else if expBits = 0 then
+    if mantBits = 0 then
+      some 0
+    else
+      let num : Int := sign * Int.ofNat mantBits
+      let denom : Int := Int.ofNat (pow2 1074)
+      some (Rat.ofInt num / Rat.ofInt denom)
+  else
+    let mant := mantBits + pow2 52
+    -- Unbias in `Int`: `Nat` subtraction (`expBits - 1023`) truncates to 0
+    -- whenever `expBits < 1023`, which would misdecode every |x| < 1
+    -- (e.g. 0.5 would decode as 1).
+    let exp : Int := Int.ofNat expBits - 1023
+    let shift : Int := exp - 52
+    let base : Rat := Rat.ofInt (sign * Int.ofNat mant)
+    if 0 ≤ shift then
+      let k : Nat := Int.toNat shift
+      some (base * Rat.ofInt (Int.ofNat (pow2 k)))
+    else
+      let k : Nat := Int.toNat (-shift)
+      some (base / Rat.ofInt (Int.ofNat (pow2 k)))
+
+private def readNatLE (data : ByteArray) (off : Nat) (count : Nat) : Option Nat :=
+  if off + count ≤ data.size then
+    let rec loop (i : Nat) (acc : Nat) : Nat :=
+      if i < count then
+        let byte := data.get!
(off + i) + loop (i + 1) (acc + byte.toNat * pow2 (8 * i)) + else + acc + some (loop 0 0) + else + none + +private def readF64Rat (data : ByteArray) (off : Nat) : Option Rat := do + let bits ← readNatLE data off 8 + ratOfFloatBits bits + +private def bytesI32 (n : Nat) : Nat := + n * 4 + +private def bytesF64 (n : Nat) : Nat := + n * 8 + +private def f64CountPerHead (h : NfptHeader) : Nat := + 4 * h.modelDim * h.headDim + 3 * h.headDim + +private def f64CountPerLayer (h : NfptHeader) : Nat := + h.numHeads * f64CountPerHead h + + (2 * h.modelDim * h.hiddenDim + h.hiddenDim) + + (6 * h.modelDim) + +private def f64CountBeforeUnembed (h : NfptHeader) : Nat := + h.seqLen * h.modelDim + + h.numLayers * f64CountPerLayer h + + (2 * h.modelDim) + +/-- Byte offset from the binary start to the unembedding matrix. -/ +def unembedOffset (h : NfptHeader) : Nat := + bytesI32 h.seqLen + bytesF64 (f64CountBeforeUnembed h) + +/-- Read a single unembedding column as exact rationals. -/ +def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : Nat) : + Except String (Fin h.modelDim → Rat) := do + if col < h.vocabSize then + let base := start + unembedOffset h + let rows := List.range h.modelDim + let vals ← rows.mapM (fun row => do + let off := base + bytesF64 (row * h.vocabSize + col) + match readF64Rat data off with + | some v => pure v + | none => throw s!"invalid f64 at offset {off}") + if hlen : vals.length = h.modelDim then + let vec : Fin h.modelDim → Rat := fun i => + vals.get ⟨i.val, by simp [hlen]⟩ + return vec + else + throw "internal error: unembed column length mismatch" + else + throw s!"column out of range: {col}" + +end NfptPure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 688e209..804cea3 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -777,6 +777,117 @@ def parseInductionHeadInputs (input : String) : let inputs ← finalizeHeadState hpos st return ⟨seq, ⟨dModel, ⟨dHead, inputs⟩⟩⟩ +/-- Raw downstream matrix 
payload with an input bound. -/ +structure DownstreamMatrixRaw (rows cols : Nat) where + /-- Input magnitude bound. -/ + inputBound : Rat + /-- Matrix entries. -/ + entries : Fin rows → Fin cols → Rat + +private structure DownstreamMatrixParseState (rows cols : Nat) where + inputBound : Option Rat + entries : Fin rows → Fin cols → Option Rat + +private def initDownstreamMatrixState (rows cols : Nat) : + DownstreamMatrixParseState rows cols := + { inputBound := none, entries := fun _ _ => none } + +private def setRectEntry {rows cols : Nat} (mat : Fin rows → Fin cols → Option Rat) + (i j : Nat) (v : Rat) : Except String (Fin rows → Fin cols → Option Rat) := do + if hi : i < rows then + if hj : j < cols then + let iFin : Fin rows := ⟨i, hi⟩ + let jFin : Fin cols := ⟨j, hj⟩ + match mat iFin jFin with + | some _ => + throw s!"duplicate matrix entry at ({i}, {j})" + | none => + let mat' : Fin rows → Fin cols → Option Rat := fun i' j' => + if i' = iFin then + if j' = jFin then + some v + else + mat i' j' + else + mat i' j' + return mat' + else + throw s!"index out of range: col={j}" + else + throw s!"index out of range: row={i}" + +private def parseDownstreamMatrixLine {rows cols : Nat} + (st : DownstreamMatrixParseState rows cols) (tokens : List String) : + Except String (DownstreamMatrixParseState rows cols) := do + match tokens with + | ["input-bound", val] => + if st.inputBound.isSome then + throw "duplicate input-bound entry" + else + return { st with inputBound := some (← parseRat val) } + | ["w", i, j, val] => + let mat ← setRectEntry st.entries (← parseNat i) (← parseNat j) (← parseRat val) + return { st with entries := mat } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeDownstreamMatrixState {rows cols : Nat} + (st : DownstreamMatrixParseState rows cols) : + Except String (DownstreamMatrixRaw rows cols) := do + let inputBound ← + match st.inputBound with + | some v => pure v + | none => throw "missing 
input-bound entry" + if !finsetAll (Finset.univ : Finset (Fin rows)) (fun i => + finsetAll (Finset.univ : Finset (Fin cols)) (fun j => (st.entries i j).isSome)) then + throw "missing matrix entries" + let entries : Fin rows → Fin cols → Rat := fun i j => + (st.entries i j).getD 0 + return { inputBound := inputBound, entries := entries } + +/-- Parse a downstream matrix payload from text. -/ +def parseDownstreamMatrixRaw (input : String) : + Except String (Sigma (fun rows => Sigma (fun cols => DownstreamMatrixRaw rows cols))) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let mut rows? : Option Nat := none + let mut cols? : Option Nat := none + for t in tokens do + match t with + | ["rows", n] => + if rows?.isSome then + throw "duplicate rows entry" + else + rows? := some (← parseNat n) + | ["cols", n] => + if cols?.isSome then + throw "duplicate cols entry" + else + cols? := some (← parseNat n) + | _ => pure () + let rows ← + match rows? with + | some v => pure v + | none => throw "missing rows entry" + let cols ← + match cols? with + | some v => pure v + | none => throw "missing cols entry" + match rows, cols with + | 0, _ => throw "rows must be positive" + | _, 0 => throw "cols must be positive" + | Nat.succ r, Nat.succ c => + let rows := Nat.succ r + let cols := Nat.succ c + let st0 := initDownstreamMatrixState rows cols + let st ← tokens.foldlM (fun st t => + match t with + | ["rows", _] => pure st + | ["cols", _] => pure st + | _ => parseDownstreamMatrixLine st t) st0 + let raw ← finalizeDownstreamMatrixState st + return ⟨rows, ⟨cols, raw⟩⟩ + end Pure end IO diff --git a/README.md b/README.md index f3c78ce..8fa2f64 100644 --- a/README.md +++ b/README.md @@ -107,8 +107,26 @@ lake exe nfp induction certify_end_to_end \ --downstream reports/gpt2_downstream.cert ``` -The downstream certificate is **checked for internal arithmetic consistency** but is still -externally computed. 
Work is ongoing to compute this bound inside Lean from model weights. +The downstream certificate is **checked for internal arithmetic consistency** but is externally +computed. You can also compute the downstream bound inside Lean from a matrix payload: + +```bash +lake exe nfp induction certify_end_to_end_matrix \ + --scores reports/gpt2_induction.cert \ + --values reports/gpt2_induction.values \ + --matrix reports/gpt2_downstream.matrix +``` + +Or derive the downstream matrix directly from an `NFP_BINARY_V1` model file +(currently uses the unembedding direction only): + +```bash +lake exe nfp induction certify_end_to_end_model \ + --scores reports/gpt2_induction.cert \ + --values reports/gpt2_induction.values \ + --model models/gpt2_rigorous.nfpt \ + --input-bound 1/2 +``` ## File formats @@ -150,6 +168,17 @@ input-bound The checker enforces `error = gain * input-bound` and nonnegativity of all fields. +### Downstream matrix payload + +``` +rows +cols +input-bound +w +``` + +The checker computes a row-sum norm bound from the matrix entries. + ### Head input format (for `certify_head`) ``` diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 0b2201c..f6fee44 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -8,8 +8,9 @@ It is intentionally brief and focused on the soundness boundary. - The trusted CLI only **checks certificates**; it does not search for witnesses or run a model. - Induction certificates are **head-level** (softmax-margin + value-range + logit-diff lower bound). They do **not** yet imply end-to-end model behavior. -- Downstream error bounds are still **external**. `certify_end_to_end` checks arithmetic - consistency, but it does not derive the downstream bound from GPT-2 weights. +- Downstream error bounds can be computed from a **matrix payload** inside Lean. A model-based + path exists, but it currently uses only the unembedding direction and still requires an + external `input-bound` assumption. 
- The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor currently ignores LayerNorm and bias terms, so it is not end-to-end faithful. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. @@ -18,8 +19,8 @@ It is intentionally brief and focused on the soundness boundary. ## Remaining work -- Compute the downstream bound **inside Lean** from model weights and verified bounds - (row-sum norms, LayerNorm/GeLU/softmax envelopes), and wire this into `certify_end_to_end`. +- Compute the downstream bound **inside Lean** from model weights (not just matrix payloads), + and wire this into `certify_end_to_end`. - Replace untrusted extraction with a verified parser for model weight slices. - Add a formal bridge from certificates to circuit semantics and (eventually) to end-to-end transformer claims. From 1bbcbe21b866fd6588df02bb0092dbb5195bf85e Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 15:33:13 +0100 Subject: [PATCH 098/244] Add residual-bound certificates and generator --- AGENTS.md | 2 + CLAIMS.md | 9 +- Nfp/Circuit.lean | 1 + Nfp/Circuit/Cert/ResidualBound.lean | 47 +++++ Nfp/Cli.lean | 6 +- Nfp/IO.lean | 297 +++++++++++++++------------ Nfp/IO/Pure.lean | 62 ++++++ Nfp/Sound/Bounds/MatrixNorm.lean | 62 ++++++ README.md | 12 +- SOUNDNESS_LIMITATIONS.md | 9 +- scripts/build_residual_bound_cert.py | 151 ++++++++++++++ 11 files changed, 512 insertions(+), 146 deletions(-) create mode 100644 Nfp/Circuit/Cert/ResidualBound.lean create mode 100644 scripts/build_residual_bound_cert.py diff --git a/AGENTS.md b/AGENTS.md index 702d495..972d2d6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -291,6 +291,8 @@ but you **must** update this list in the same commit. - Logit-diff lower-bound computation for induction certificates. - `Nfp/Circuit/Cert/DownstreamLinear.lean` - Downstream linear error certificates for end-to-end induction bounds. 
+- `Nfp/Circuit/Cert/ResidualBound.lean` + - Residual-stream bound certificates for downstream error computation. - `Nfp/Circuit/Typed.lean` - Typed circuit wrapper and interface-level equivalence checker. - `Nfp/Circuit/Compose.lean` diff --git a/CLAIMS.md b/CLAIMS.md index c8c747a..f57dbc3 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -12,6 +12,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - Logit-diff lower bound lemma: `logitDiffLowerBound_le`. - Downstream linear certificate soundness: `checkDownstreamLinearCert` implies `DownstreamLinearBounds`. +- Residual-bound certificate soundness: `checkResidualBoundCert` implies + `ResidualBoundBounds`. - Row-sum matrix norm bounds for `mulVec` under uniform input magnitude. ## Soundly checked by the trusted CLI @@ -27,8 +29,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - `nfp induction certify_end_to_end_matrix` computes a downstream bound from a matrix payload using verified row-sum norms, then composes it with the head-level logit-diff lower bound. - `nfp induction certify_end_to_end_model` derives a downstream matrix from an `NFP_BINARY_V1` - model file (unembedding direction only) and composes it with the head-level logit-diff - lower bound. + model file (unembedding direction only), computes a downstream error bound from a + residual-bound certificate, and composes it with the head-level logit-diff lower bound. ## Untrusted / heuristic @@ -36,7 +38,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri `scripts/build_gpt2_induction_cert.py`, `scripts/build_gpt2_head_inputs.py`, `scripts/build_downstream_linear_cert.py`. - The head-input extractor currently ignores LayerNorm and bias terms. -- Any downstream error bound provided externally (outside the matrix-payload or model-based path). +- Residual-bound certificates are generated externally (unchecked beyond nonnegativity). 
+- Any downstream error bound provided externally (outside the matrix-payload path). ## Not yet proven diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index cb952a7..78bd9e5 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -10,6 +10,7 @@ import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Cert.ValueRange import Nfp.Circuit.Cert.LogitDiff import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.ResidualBound import Nfp.Circuit.Typed import Nfp.Circuit.Compose import Nfp.Circuit.Gates diff --git a/Nfp/Circuit/Cert/ResidualBound.lean b/Nfp/Circuit/Cert/ResidualBound.lean new file mode 100644 index 0000000..4cc7a3a --- /dev/null +++ b/Nfp/Circuit/Cert/ResidualBound.lean @@ -0,0 +1,47 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Circuit.Cert + +/-! +Residual-stream bound certificates. + +These certificates record per-coordinate absolute bounds for residual vectors. +-/ + +namespace Nfp + +namespace Circuit + +/-- Certificate payload for per-coordinate residual absolute bounds. -/ +structure ResidualBoundCert (n : Nat) where + /-- Absolute bound per coordinate. -/ + bound : Fin n → Rat + +/-- Properties enforced by `checkResidualBoundCert`. -/ +structure ResidualBoundBounds {n : Nat} (c : ResidualBoundCert n) : Prop where + /-- Residual bounds are nonnegative. -/ + bound_nonneg : ∀ i, 0 ≤ c.bound i + +/-- Boolean checker for residual-bound certificates. -/ +def checkResidualBoundCert {n : Nat} (c : ResidualBoundCert n) : Bool := + finsetAll (Finset.univ : Finset (Fin n)) (fun i => decide (0 ≤ c.bound i)) + +/-- `checkResidualBoundCert` is sound for `ResidualBoundBounds`. 
-/ +theorem checkResidualBoundCert_sound {n : Nat} (c : ResidualBoundCert n) : + checkResidualBoundCert c = true → ResidualBoundBounds c := by + intro hcheck + have hall : + finsetAll (Finset.univ : Finset (Fin n)) (fun i => + decide (0 ≤ c.bound i)) = true := by + simpa [checkResidualBoundCert] using hcheck + have hall' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin n)))).1 hall + refine { bound_nonneg := ?_ } + intro i + have hi := hall' i (by simp) + exact (decide_eq_true_iff).1 hi + +end Circuit + +end Nfp diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 5026942..2bfd4a3 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -142,12 +142,12 @@ def runInductionCertifyEndToEndModel (p : Parsed) : IO UInt32 := do let scoresPath := p.flag! "scores" |>.as! String let valuesPath := p.flag! "values" |>.as! String let modelPath := p.flag! "model" |>.as! String - let inputBound := p.flag! "input-bound" |>.as! String + let residualPath := p.flag! "residual-bound" |>.as! String let minActive? := (p.flag? "min-active").map (·.as! Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath inputBound + IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath residualPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? /-- `nfp induction certify_end_to_end_model` subcommand. -/ @@ -158,7 +158,7 @@ def inductionCertifyEndToEndModelCmd : Cmd := `[Cli| scores : String; "Path to the softmax-margin certificate file." values : String; "Path to the value-range certificate file." model : String; "Path to the NFP_BINARY_V1 model file." - "input-bound" : String; "Nonnegative input bound for the downstream matrix (rational literal)." + "residual-bound" : String; "Path to the residual-bound certificate file." 
"min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 5fa2de9..b43307e 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -4,6 +4,7 @@ import Nfp.IO.Pure import Nfp.IO.NfptPure import Nfp.Circuit.Cert.LogitDiff import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.ResidualBound import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Induction @@ -48,6 +49,12 @@ def loadDownstreamMatrixRaw (path : System.FilePath) : let data ← IO.FS.readFile path return Pure.parseDownstreamMatrixRaw data +/-- Load a residual-bound certificate from disk. -/ +def loadResidualBoundCert (path : System.FilePath) : + IO (Except String (Sigma (fun n => ResidualBoundCert n))) := do + let data ← IO.FS.readFile path + return Pure.parseResidualBoundCert data + /-- Load raw value-range inputs from disk. -/ def loadValueRangeRaw (path : System.FilePath) : IO (Except String (Sigma Pure.ValueRangeRaw)) := do @@ -556,155 +563,175 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) finalLB={finalLB})" return 0 -/-- Check end-to-end induction certificates using a model file to derive the downstream matrix. -/ +/-- Check end-to-end induction certificates using a model file and residual bounds. -/ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) (valuesPath : System.FilePath) (modelPath : System.FilePath) - (inputBoundStr : String) (minActive? : Option Nat) + (residualPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
- let inputBoundE := Pure.parseRat inputBoundStr - match minLogitDiff?E, minMargin?E, maxEps?E, inputBoundE with - | Except.error msg, _, _, _ => + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" return 2 - | _, Except.error msg, _, _ => + | _, Except.error msg, _ => IO.eprintln s!"error: {msg}" return 2 - | _, _, Except.error msg, _ => + | _, _, Except.error msg => IO.eprintln s!"error: {msg}" return 2 - | _, _, _, Except.error msg => - IO.eprintln s!"error: invalid input-bound: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps?, Except.ok inputBound => do - if hneg : inputBound < 0 then - IO.eprintln s!"error: input-bound {inputBound} must be nonnegative" - return 2 - else - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) - let parsedScores ← loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let parsedScores ← loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let logitDiffLB? 
:= - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - match certVals'.direction with - | none => - IO.eprintln - "error: value-range certificate missing direction \ - metadata" - return 2 - | some dirSpec => - let data ← IO.FS.readBinFile modelPath - match NfptPure.parseHeader data with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let dirPos := dirSpec.target - let dirNeg := dirSpec.negative - match NfptPure.readUnembedColumn data start header dirPos with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok colTarget => - match - NfptPure.readUnembedColumn data start header dirNeg - with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok colNeg => - let dirVec : Fin header.modelDim → Rat := - fun i => colTarget i - colNeg i - have hinput : 0 ≤ inputBound := by - exact le_of_not_gt hneg - let W : Matrix (Fin 1) (Fin header.modelDim) Rat := - fun _ j => dirVec j - let downstream := - (Sound.Bounds.buildDownstreamLinearCert - W inputBound hinput).1 - let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if finalLB < minLogitDiff then - some minLogitDiff - else - none - match violation? 
with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end logitDiffLB {finalLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: end-to-end induction bound certified \ - (seq={seq}, active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstream.error}, \ - finalLB={finalLB})" - return 0 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? 
with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + match certVals'.direction with + | none => + IO.eprintln + "error: value-range certificate missing direction \ + metadata" + return 2 + | some dirSpec => + let data ← IO.FS.readBinFile modelPath + match NfptPure.parseHeader data with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let parsedResidual ← loadResidualBoundCert residualPath + match parsedResidual with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨dim, residualCert⟩ => + if hdim : dim = header.modelDim then + let residualCert' : + ResidualBoundCert header.modelDim := by + simpa [hdim] using residualCert + let residualOk := + Circuit.checkResidualBoundCert residualCert' + if residualOk then + let dirPos := dirSpec.target + let dirNeg := dirSpec.negative + match + NfptPure.readUnembedColumn data start header dirPos + with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok colTarget => + match + NfptPure.readUnembedColumn + data start header dirNeg + with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok colNeg => + let dirVec : + Fin header.modelDim → Rat := + fun i => colTarget i - colNeg i + let W : + Matrix (Fin 1) + (Fin header.modelDim) Rat := + fun _ j => dirVec j + let downstreamError := + Sound.Bounds.downstreamErrorFromBounds + W residualCert'.bound + let finalLB := logitDiffLB - downstreamError + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if finalLB < minLogitDiff then + some minLogitDiff + else + none + match violation? 
with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end logitDiffLB \ + {finalLB} below minimum \ + {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: end-to-end induction \ + bound certified (seq={seq}, \ + active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstreamError}, \ + finalLB={finalLB})" + return 0 + else + IO.eprintln + "error: residual-bound certificate rejected" + return 2 + else + IO.eprintln + s!"error: residual bound dim {dim} \ + does not match model dim {header.modelDim}" + return 2 /-- Build and check induction certificates from exact head inputs. -/ def runInductionCertifyHead (inputsPath : System.FilePath) diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 804cea3..3b5930a 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -5,6 +5,7 @@ import Mathlib.Data.Finset.Insert import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Cert.ValueRange import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.ResidualBound import Nfp.Model.InductionHead /-! 
@@ -888,6 +889,67 @@ def parseDownstreamMatrixRaw (input : String) : let raw ← finalizeDownstreamMatrixState st return ⟨rows, ⟨cols, raw⟩⟩ +private structure ResidualBoundParseState (n : Nat) where + bounds : Fin n → Option Rat + +private def initResidualBoundState (n : Nat) : ResidualBoundParseState n := + { bounds := fun _ => none } + +private def setVectorEntry {n : Nat} (bounds : Fin n → Option Rat) + (i : Nat) (v : Rat) : Except String (Fin n → Option Rat) := do + if hi : i < n then + let iFin : Fin n := ⟨i, hi⟩ + match bounds iFin with + | some _ => + throw s!"duplicate bound entry at index {i}" + | none => + let bounds' : Fin n → Option Rat := fun i' => + if i' = iFin then + some v + else + bounds i' + return bounds' + else + throw s!"index out of range: {i}" + +private def parseResidualBoundLine {n : Nat} (st : ResidualBoundParseState n) + (tokens : List String) : Except String (ResidualBoundParseState n) := do + match tokens with + | ["bound", i, val] => + let bounds ← setVectorEntry st.bounds (← parseNat i) (← parseRat val) + return { st with bounds := bounds } + | ["dim", _] => + throw "duplicate dim entry" + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeResidualBoundState {n : Nat} (st : ResidualBoundParseState n) : + Except String (Circuit.ResidualBoundCert n) := do + if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.bounds i).isSome) then + throw "missing bound entries" + let bound : Fin n → Rat := fun i => + (st.bounds i).getD 0 + return { bound := bound } + +/-- Parse a residual-bound payload from text. 
-/ +def parseResidualBoundCert (input : String) : + Except String (Sigma (fun n => Circuit.ResidualBoundCert n)) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + match tokens with + | [] => throw "empty residual-bound payload" + | ["dim", nStr] :: rest => + let n ← parseNat nStr + match n with + | 0 => throw "dim must be positive" + | Nat.succ n' => + let dim := Nat.succ n' + let st0 := initResidualBoundState dim + let st ← rest.foldlM (fun st t => parseResidualBoundLine st t) st0 + let cert ← finalizeResidualBoundState st + return ⟨dim, cert⟩ + | _ => throw "expected header 'dim '" + end Pure end IO diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index 8cf54b1..5aa60f2 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -25,6 +25,11 @@ open scoped BigOperators def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : Rat := ∑ j, |W i j| +/-- Weighted row-sum using per-coordinate bounds. -/ +def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) : Rat := + ∑ j, |W i j| * bound j + /-- Maximum row-sum norm (defaults to `0` on empty matrices). -/ def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := if h : (Finset.univ : Finset (Fin m)).Nonempty then @@ -32,6 +37,14 @@ def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := else 0 +/-- Maximum weighted row-sum (defaults to `0` on empty matrices). -/ +def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : Rat := + if h : (Finset.univ : Finset (Fin m)).Nonempty then + (Finset.univ).sup' h (fun i => rowSumWeighted W bound i) + else + 0 + /-- Row-sums are nonnegative. 
-/ theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : 0 ≤ rowSum W i := by @@ -39,6 +52,14 @@ theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : intro j _ exact abs_nonneg (W i j) +/-- Weighted row-sums are nonnegative under nonnegative bounds. -/ +theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : + 0 ≤ rowSumWeighted W bound i := by + refine Finset.sum_nonneg ?_ + intro j _ + exact mul_nonneg (abs_nonneg (W i j)) (hbound j) + /-- Each row-sum is bounded by the row-sum norm. -/ theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : rowSum W i ≤ rowSumNorm W := by @@ -53,6 +74,22 @@ theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : F (by simp : i ∈ (Finset.univ : Finset (Fin m)))) simpa [rowSumNorm, h] using hle +/-- Each weighted row-sum is bounded by the weighted row-sum norm. -/ +theorem rowSumWeighted_le_rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) : + rowSumWeighted W bound i ≤ rowSumWeightedNorm W bound := by + classical + have h : (Finset.univ : Finset (Fin m)).Nonempty := ⟨i, by simp⟩ + have hle : + rowSumWeighted W bound i ≤ + (Finset.univ).sup' h (fun i => rowSumWeighted W bound i) := by + simpa using + (Finset.le_sup' + (s := (Finset.univ : Finset (Fin m))) + (f := fun i => rowSumWeighted W bound i) + (by simp : i ∈ (Finset.univ : Finset (Fin m)))) + simpa [rowSumWeightedNorm, h] using hle + /-- The row-sum norm is nonnegative. -/ theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : 0 ≤ rowSumNorm W := by @@ -64,6 +101,31 @@ theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : exact le_trans hrow hle · simp [rowSumNorm, h] +/-- Weighted row-sum norm is nonnegative under nonnegative bounds. 
-/ +theorem rowSumWeightedNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (hbound : ∀ j, 0 ≤ bound j) : + 0 ≤ rowSumWeightedNorm W bound := by + classical + by_cases h : (Finset.univ : Finset (Fin m)).Nonempty + · rcases h with ⟨i, hi⟩ + have hrow : 0 ≤ rowSumWeighted W bound i := + rowSumWeighted_nonneg W bound i hbound + have hle : rowSumWeighted W bound i ≤ rowSumWeightedNorm W bound := + rowSumWeighted_le_rowSumWeightedNorm W bound i + exact le_trans hrow hle + · simp [rowSumWeightedNorm, h] + +/-- Downstream error from per-coordinate residual bounds. -/ +def downstreamErrorFromBounds {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : Rat := + rowSumWeightedNorm W bound + +/-- `downstreamErrorFromBounds` is nonnegative. -/ +theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (hbound : ∀ j, 0 ≤ bound j) : + 0 ≤ downstreamErrorFromBounds W bound := by + simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound hbound + /-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (x : Fin n → Rat) (inputBound : Rat) diff --git a/README.md b/README.md index 8fa2f64..a3e43ec 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,7 @@ lake exe nfp induction certify_end_to_end_model \ --scores reports/gpt2_induction.cert \ --values reports/gpt2_induction.values \ --model models/gpt2_rigorous.nfpt \ - --input-bound 1/2 + --residual-bound reports/gpt2_residual.bound ``` ## File formats @@ -179,6 +179,16 @@ w The checker computes a row-sum norm bound from the matrix entries. +### Residual-bound certificate + +``` +dim +bound +``` + +Each `bound` entry supplies a nonnegative absolute bound for the residual vector +coordinate `i`, used to compute downstream error. 
+ ### Head input format (for `certify_head`) ``` diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index f6fee44..51042ff 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -9,8 +9,8 @@ It is intentionally brief and focused on the soundness boundary. - Induction certificates are **head-level** (softmax-margin + value-range + logit-diff lower bound). They do **not** yet imply end-to-end model behavior. - Downstream error bounds can be computed from a **matrix payload** inside Lean. A model-based - path exists, but it currently uses only the unembedding direction and still requires an - external `input-bound` assumption. + path exists, but it currently uses only the unembedding direction and relies on an external + **residual-bound certificate** (per-coordinate absolute bounds). - The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor currently ignores LayerNorm and bias terms, so it is not end-to-end faithful. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. @@ -19,8 +19,9 @@ It is intentionally brief and focused on the soundness boundary. ## Remaining work -- Compute the downstream bound **inside Lean** from model weights (not just matrix payloads), - and wire this into `certify_end_to_end`. +- Compute the downstream bound **inside Lean** from model weights and certified residual + bounds (not just matrix payloads), and wire this into `certify_end_to_end`. +- Replace untrusted residual-bound generation with a verified derivation from upstream bounds. - Replace untrusted extraction with a verified parser for model weight slices. - Add a formal bridge from certificates to circuit semantics and (eventually) to end-to-end transformer claims. 
diff --git a/scripts/build_residual_bound_cert.py b/scripts/build_residual_bound_cert.py new file mode 100644 index 0000000..a34b5c6 --- /dev/null +++ b/scripts/build_residual_bound_cert.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Build a residual-bound certificate from a GPT-2 forward pass. + +This script is untrusted. It computes per-coordinate absolute bounds by +taking maxima over a fixed input sequence (optionally restricted to active +positions from a softmax-margin certificate). The resulting bounds are +rounded up to rationals for checking by `nfp induction certify_end_to_end_model`. + +Usage: + uv run scripts/build_residual_bound_cert.py \ + --output reports/gpt2_residual.bound \ + --seq 32 --pattern-length 16 \ + --scores reports/gpt2_induction.cert + +Optional: + --tokens tokens.txt # whitespace-separated token ids + --random-pattern --seed 0 + --decimals 6 --safety 1e-6 +""" + +import argparse +import math +from fractions import Fraction +from pathlib import Path + +import numpy as np + +try: + import torch + from transformers import GPT2Model +except ImportError: + raise SystemExit( + "Missing dependencies. 
Install with: uv add transformers torch" + ) + + +def rat_to_str(q: Fraction) -> str: + if q.denominator == 1: + return str(q.numerator) + return f"{q.numerator}/{q.denominator}" + + +def build_tokens(seq: int, pattern_len: int, random_pattern: bool, seed: int) -> np.ndarray: + if random_pattern: + rng = np.random.default_rng(seed) + pattern = rng.integers(1000, 30000, size=pattern_len, endpoint=False) + else: + pattern = np.arange(pattern_len) + repeats = (seq // pattern_len) + 1 + return np.tile(pattern, repeats)[:seq] + + +def parse_tokens(path: Path) -> np.ndarray: + raw = path.read_text(encoding="ascii") + tokens = [int(tok) for tok in raw.split() if tok.strip()] + if not tokens: + raise SystemExit(f"no tokens found in {path}") + return np.array(tokens, dtype=np.int64) + + +def parse_active_positions(path: Path) -> tuple[int | None, list[int]]: + seq = None + active: list[int] = [] + for line in path.read_text(encoding="ascii").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if parts[0] == "seq" and len(parts) >= 2: + seq = int(parts[1]) + elif parts[0] == "active" and len(parts) >= 2: + active.append(int(parts[1])) + return seq, active + + +def ceil_rat(x: float, decimals: int, safety: float) -> Fraction: + scale = 10 ** decimals + scaled = abs(x) * (1.0 + safety) * scale + return Fraction(int(math.ceil(scaled)), scale) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--output", required=True, help="Path to write certificate") + parser.add_argument("--seq", type=int, default=32, help="Sequence length") + parser.add_argument("--pattern-length", type=int, default=16, help="Pattern length") + parser.add_argument("--random-pattern", action="store_true", + help="Use random token pattern") + parser.add_argument("--seed", type=int, default=0, help="RNG seed for random pattern") + parser.add_argument("--tokens", help="Optional path to 
whitespace-separated tokens") + parser.add_argument("--scores", help="Optional softmax-margin certificate for active queries") + parser.add_argument("--model", default="gpt2", help="HuggingFace model name") + parser.add_argument("--device", default="cpu", help="Torch device") + parser.add_argument("--decimals", type=int, default=6, + help="Decimal rounding for rationals (ceil)") + parser.add_argument("--safety", type=float, default=1e-6, + help="Relative safety slack added before rounding") + args = parser.parse_args() + + if args.seq <= 0: + raise SystemExit("seq must be positive") + if args.decimals < 0: + raise SystemExit("decimals must be nonnegative") + if args.safety < 0: + raise SystemExit("safety must be nonnegative") + + if args.tokens: + tokens = parse_tokens(Path(args.tokens)) + seq = len(tokens) + else: + seq = args.seq + tokens = build_tokens(seq, args.pattern_length, args.random_pattern, args.seed) + + positions = list(range(seq)) + if args.scores: + cert_seq, active = parse_active_positions(Path(args.scores)) + if cert_seq is not None and cert_seq != seq: + raise SystemExit(f"seq mismatch: scores={cert_seq} tokens={seq}") + if active: + positions = active + + model = GPT2Model.from_pretrained(args.model) + model.to(args.device) + model.eval() + input_ids = torch.tensor(tokens, dtype=torch.long, device=args.device).unsqueeze(0) + with torch.no_grad(): + outputs = model(input_ids) + hidden = outputs.last_hidden_state.squeeze(0).cpu().numpy() + + if hidden.shape[0] != seq: + raise SystemExit(f"hidden state seq mismatch: {hidden.shape[0]} vs {seq}") + + chosen = hidden[positions] + max_abs = np.max(np.abs(chosen), axis=0) + bounds = [ceil_rat(float(val), args.decimals, args.safety) for val in max_abs] + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="ascii") as f: + f.write(f"dim {len(bounds)}\n") + for i, bound in enumerate(bounds): + f.write(f"bound {i} 
{rat_to_str(bound)}\n") + + print(f"Wrote residual-bound certificate to {output_path}") + + +if __name__ == "__main__": + main() From e7ee745ea35fc5bb8777a203b96be0f55f177092 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 15:48:46 +0100 Subject: [PATCH 099/244] Add residual-interval bounds for downstream error --- AGENTS.md | 2 + CLAIMS.md | 8 +- Nfp/Circuit.lean | 1 + Nfp/Circuit/Cert/ResidualInterval.lean | 49 +++++++ Nfp/Cli.lean | 6 +- Nfp/IO.lean | 28 ++-- Nfp/IO/Pure.lean | 53 +++++++ Nfp/Sound/Bounds/MatrixNorm.lean | 76 ++++++++++ README.md | 11 +- SOUNDNESS_LIMITATIONS.md | 4 +- scripts/build_residual_interval_cert.py | 179 ++++++++++++++++++++++++ 11 files changed, 391 insertions(+), 26 deletions(-) create mode 100644 Nfp/Circuit/Cert/ResidualInterval.lean create mode 100644 scripts/build_residual_interval_cert.py diff --git a/AGENTS.md b/AGENTS.md index 972d2d6..d69c644 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -293,6 +293,8 @@ but you **must** update this list in the same commit. - Downstream linear error certificates for end-to-end induction bounds. - `Nfp/Circuit/Cert/ResidualBound.lean` - Residual-stream bound certificates for downstream error computation. +- `Nfp/Circuit/Cert/ResidualInterval.lean` + - Residual-stream interval certificates for downstream dot-product bounds. - `Nfp/Circuit/Typed.lean` - Typed circuit wrapper and interface-level equivalence checker. - `Nfp/Circuit/Compose.lean` diff --git a/CLAIMS.md b/CLAIMS.md index f57dbc3..5d60637 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -12,8 +12,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - Logit-diff lower bound lemma: `logitDiffLowerBound_le`. - Downstream linear certificate soundness: `checkDownstreamLinearCert` implies `DownstreamLinearBounds`. -- Residual-bound certificate soundness: `checkResidualBoundCert` implies - `ResidualBoundBounds`. 
+- Residual-interval certificate soundness: `checkResidualIntervalCert` implies + `ResidualIntervalBounds`. - Row-sum matrix norm bounds for `mulVec` under uniform input magnitude. ## Soundly checked by the trusted CLI @@ -30,7 +30,7 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri using verified row-sum norms, then composes it with the head-level logit-diff lower bound. - `nfp induction certify_end_to_end_model` derives a downstream matrix from an `NFP_BINARY_V1` model file (unembedding direction only), computes a downstream error bound from a - residual-bound certificate, and composes it with the head-level logit-diff lower bound. + residual-interval certificate, and composes it with the head-level logit-diff lower bound. ## Untrusted / heuristic @@ -38,7 +38,7 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri `scripts/build_gpt2_induction_cert.py`, `scripts/build_gpt2_head_inputs.py`, `scripts/build_downstream_linear_cert.py`. - The head-input extractor currently ignores LayerNorm and bias terms. -- Residual-bound certificates are generated externally (unchecked beyond nonnegativity). +- Residual-interval certificates are generated externally (unchecked beyond consistency). - Any downstream error bound provided externally (outside the matrix-payload path). 
## Not yet proven diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index 78bd9e5..dc663ad 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -11,6 +11,7 @@ import Nfp.Circuit.Cert.ValueRange import Nfp.Circuit.Cert.LogitDiff import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Cert.ResidualBound +import Nfp.Circuit.Cert.ResidualInterval import Nfp.Circuit.Typed import Nfp.Circuit.Compose import Nfp.Circuit.Gates diff --git a/Nfp/Circuit/Cert/ResidualInterval.lean b/Nfp/Circuit/Cert/ResidualInterval.lean new file mode 100644 index 0000000..dca299b --- /dev/null +++ b/Nfp/Circuit/Cert/ResidualInterval.lean @@ -0,0 +1,49 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Circuit.Cert + +/-! +Residual-stream interval certificates. + +These certificates record per-coordinate lower/upper bounds for residual vectors. +-/ + +namespace Nfp + +namespace Circuit + +/-- Certificate payload for per-coordinate residual intervals. -/ +structure ResidualIntervalCert (n : Nat) where + /-- Lower bound per coordinate. -/ + lo : Fin n → Rat + /-- Upper bound per coordinate. -/ + hi : Fin n → Rat + +/-- Properties enforced by `checkResidualIntervalCert`. -/ +structure ResidualIntervalBounds {n : Nat} (c : ResidualIntervalCert n) : Prop where + /-- Lower bounds are at most upper bounds. -/ + lo_le_hi : ∀ i, c.lo i ≤ c.hi i + +/-- Boolean checker for residual-interval certificates. -/ +def checkResidualIntervalCert {n : Nat} (c : ResidualIntervalCert n) : Bool := + finsetAll (Finset.univ : Finset (Fin n)) (fun i => decide (c.lo i ≤ c.hi i)) + +/-- `checkResidualIntervalCert` is sound for `ResidualIntervalBounds`. 
-/ +theorem checkResidualIntervalCert_sound {n : Nat} (c : ResidualIntervalCert n) : + checkResidualIntervalCert c = true → ResidualIntervalBounds c := by + intro hcheck + have hall : + finsetAll (Finset.univ : Finset (Fin n)) (fun i => + decide (c.lo i ≤ c.hi i)) = true := by + simpa [checkResidualIntervalCert] using hcheck + have hall' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin n)))).1 hall + refine { lo_le_hi := ?_ } + intro i + have hi := hall' i (by simp) + exact (decide_eq_true_iff).1 hi + +end Circuit + +end Nfp diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 2bfd4a3..c8da323 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -142,12 +142,12 @@ def runInductionCertifyEndToEndModel (p : Parsed) : IO UInt32 := do let scoresPath := p.flag! "scores" |>.as! String let valuesPath := p.flag! "values" |>.as! String let modelPath := p.flag! "model" |>.as! String - let residualPath := p.flag! "residual-bound" |>.as! String + let residualIntervalPath := p.flag! "residual-interval" |>.as! String let minActive? := (p.flag? "min-active").map (·.as! Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath residualPath + IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath residualIntervalPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? /-- `nfp induction certify_end_to_end_model` subcommand. -/ @@ -158,7 +158,7 @@ def inductionCertifyEndToEndModelCmd : Cmd := `[Cli| scores : String; "Path to the softmax-margin certificate file." values : String; "Path to the value-range certificate file." model : String; "Path to the NFP_BINARY_V1 model file." - "residual-bound" : String; "Path to the residual-bound certificate file." + "residual-interval" : String; "Path to the residual-interval certificate file." 
"min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index b43307e..7ba3feb 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -5,6 +5,7 @@ import Nfp.IO.NfptPure import Nfp.Circuit.Cert.LogitDiff import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Cert.ResidualBound +import Nfp.Circuit.Cert.ResidualInterval import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Induction @@ -55,6 +56,12 @@ def loadResidualBoundCert (path : System.FilePath) : let data ← IO.FS.readFile path return Pure.parseResidualBoundCert data +/-- Load a residual-interval certificate from disk. -/ +def loadResidualIntervalCert (path : System.FilePath) : + IO (Except String (Sigma (fun n => ResidualIntervalCert n))) := do + let data ← IO.FS.readFile path + return Pure.parseResidualIntervalCert data + /-- Load raw value-range inputs from disk. -/ def loadValueRangeRaw (path : System.FilePath) : IO (Except String (Sigma Pure.ValueRangeRaw)) := do @@ -566,7 +573,7 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) /-- Check end-to-end induction certificates using a model file and residual bounds. -/ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) (valuesPath : System.FilePath) (modelPath : System.FilePath) - (residualPath : System.FilePath) (minActive? : Option Nat) + (residualIntervalPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? 
@@ -658,7 +665,8 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨header, start⟩ => - let parsedResidual ← loadResidualBoundCert residualPath + let parsedResidual ← + loadResidualIntervalCert residualIntervalPath match parsedResidual with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -666,10 +674,10 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) | Except.ok ⟨dim, residualCert⟩ => if hdim : dim = header.modelDim then let residualCert' : - ResidualBoundCert header.modelDim := by + ResidualIntervalCert header.modelDim := by simpa [hdim] using residualCert let residualOk := - Circuit.checkResidualBoundCert residualCert' + Circuit.checkResidualIntervalCert residualCert' if residualOk then let dirPos := dirSpec.target let dirNeg := dirSpec.negative @@ -691,13 +699,9 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) let dirVec : Fin header.modelDim → Rat := fun i => colTarget i - colNeg i - let W : - Matrix (Fin 1) - (Fin header.modelDim) Rat := - fun _ j => dirVec j let downstreamError := - Sound.Bounds.downstreamErrorFromBounds - W residualCert'.bound + Sound.Bounds.dotIntervalAbsBound + dirVec residualCert'.lo residualCert'.hi let finalLB := logitDiffLB - downstreamError let violation? 
: Option Rat := match effectiveMinLogitDiff with @@ -725,11 +729,11 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) return 0 else IO.eprintln - "error: residual-bound certificate rejected" + "error: residual-interval certificate rejected" return 2 else IO.eprintln - s!"error: residual bound dim {dim} \ + s!"error: residual interval dim {dim} \ does not match model dim {header.modelDim}" return 2 diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 3b5930a..10ca560 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -6,6 +6,7 @@ import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Cert.ValueRange import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Cert.ResidualBound +import Nfp.Circuit.Cert.ResidualInterval import Nfp.Model.InductionHead /-! @@ -950,6 +951,58 @@ def parseResidualBoundCert (input : String) : return ⟨dim, cert⟩ | _ => throw "expected header 'dim '" +private structure ResidualIntervalParseState (n : Nat) where + lo : Fin n → Option Rat + hi : Fin n → Option Rat + +private def initResidualIntervalState (n : Nat) : ResidualIntervalParseState n := + { lo := fun _ => none, hi := fun _ => none } + +private def parseResidualIntervalLine {n : Nat} (st : ResidualIntervalParseState n) + (tokens : List String) : Except String (ResidualIntervalParseState n) := do + match tokens with + | ["lo", i, val] => + let lo ← setVectorEntry st.lo (← parseNat i) (← parseRat val) + return { st with lo := lo } + | ["hi", i, val] => + let hi ← setVectorEntry st.hi (← parseNat i) (← parseRat val) + return { st with hi := hi } + | ["dim", _] => + throw "duplicate dim entry" + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeResidualIntervalState {n : Nat} (st : ResidualIntervalParseState n) : + Except String (Circuit.ResidualIntervalCert n) := do + if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.lo i).isSome) then + throw "missing lo entries" + if !finsetAll (Finset.univ : Finset 
(Fin n)) (fun i => (st.hi i).isSome) then + throw "missing hi entries" + let lo : Fin n → Rat := fun i => + (st.lo i).getD 0 + let hi : Fin n → Rat := fun i => + (st.hi i).getD 0 + return { lo := lo, hi := hi } + +/-- Parse a residual-interval payload from text. -/ +def parseResidualIntervalCert (input : String) : + Except String (Sigma (fun n => Circuit.ResidualIntervalCert n)) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + match tokens with + | [] => throw "empty residual-interval payload" + | ["dim", nStr] :: rest => + let n ← parseNat nStr + match n with + | 0 => throw "dim must be positive" + | Nat.succ n' => + let dim := Nat.succ n' + let st0 := initResidualIntervalState dim + let st ← rest.foldlM (fun st t => parseResidualIntervalLine st t) st0 + let cert ← finalizeResidualIntervalState st + return ⟨dim, cert⟩ + | _ => throw "expected header 'dim '" + end Pure end IO diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index 5aa60f2..c6ebf94 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -126,6 +126,82 @@ theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) 0 ≤ downstreamErrorFromBounds W bound := by simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound hbound +/-- Lower interval endpoint for a dot product with per-coordinate bounds. -/ +def dotIntervalLower {n : Nat} (v lo hi : Fin n → Rat) : Rat := + ∑ j, if 0 ≤ v j then v j * lo j else v j * hi j + +/-- Upper interval endpoint for a dot product with per-coordinate bounds. -/ +def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Rat) : Rat := + ∑ j, if 0 ≤ v j then v j * hi j else v j * lo j + +/-- Absolute bound from interval endpoints for a dot product. 
-/ +def dotIntervalAbsBound {n : Nat} (v lo hi : Fin n → Rat) : Rat := + max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| + +theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Rat) + (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + dotIntervalLower v lo hi ≤ dotProduct v x := by + classical + refine Finset.sum_le_sum ?_ + intro j _ + by_cases hv : 0 ≤ v j + · have h1 : v j * lo j ≤ v j * x j := + mul_le_mul_of_nonneg_left (hlo j) hv + simpa [hv] using h1 + · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) + have h1 : v j * hi j ≤ v j * x j := + mul_le_mul_of_nonpos_left (hhi j) hv' + simpa [hv] using h1 + +theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) + (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + dotProduct v x ≤ dotIntervalUpper v lo hi := by + classical + refine Finset.sum_le_sum ?_ + intro j _ + by_cases hv : 0 ≤ v j + · have h1 : v j * x j ≤ v j * hi j := + mul_le_mul_of_nonneg_left (hhi j) hv + simpa [hv] using h1 + · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) + have h1 : v j * x j ≤ v j * lo j := + mul_le_mul_of_nonpos_left (hlo j) hv' + simpa [hv] using h1 + +theorem abs_le_max_abs_abs_of_interval {a b x : Rat} (hlo : a ≤ x) (hhi : x ≤ b) : + |x| ≤ max |a| |b| := by + by_cases hx : 0 ≤ x + · have hb : 0 ≤ b := le_trans hx hhi + have hx' : |x| = x := abs_of_nonneg hx + have hb' : |b| = b := abs_of_nonneg hb + calc + |x| = x := hx' + _ ≤ b := hhi + _ = |b| := hb'.symm + _ ≤ max |a| |b| := le_max_right _ _ + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have ha : a ≤ 0 := le_trans hlo hx' + have hxabs : |x| = -x := abs_of_nonpos hx' + have haabs : |a| = -a := abs_of_nonpos ha + calc + |x| = -x := hxabs + _ ≤ -a := neg_le_neg hlo + _ = |a| := by simp [haabs] + _ ≤ max |a| |b| := le_max_left _ _ + +theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → Rat) + (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + |dotProduct v x| ≤ dotIntervalAbsBound v lo hi := by 
+ have hlow : dotIntervalLower v lo hi ≤ dotProduct v x := + dotIntervalLower_le_dotProduct v lo hi x hlo hhi + have hhigh : dotProduct v x ≤ dotIntervalUpper v lo hi := + dotProduct_le_dotIntervalUpper v lo hi x hlo hhi + have habs : |dotProduct v x| ≤ + max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| := + abs_le_max_abs_abs_of_interval hlow hhigh + unfold dotIntervalAbsBound + exact habs + /-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (x : Fin n → Rat) (inputBound : Rat) diff --git a/README.md b/README.md index a3e43ec..0c3718f 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,7 @@ lake exe nfp induction certify_end_to_end_model \ --scores reports/gpt2_induction.cert \ --values reports/gpt2_induction.values \ --model models/gpt2_rigorous.nfpt \ - --residual-bound reports/gpt2_residual.bound + --residual-interval reports/gpt2_residual.interval ``` ## File formats @@ -179,15 +179,16 @@ w The checker computes a row-sum norm bound from the matrix entries. -### Residual-bound certificate +### Residual-interval certificate ``` dim -bound +lo +hi ``` -Each `bound` entry supplies a nonnegative absolute bound for the residual vector -coordinate `i`, used to compute downstream error. +Each `lo`/`hi` entry supplies an interval bound for residual vector coordinate `i`, +used to compute downstream error. ### Head input format (for `certify_head`) diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 51042ff..0e8110f 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -10,7 +10,7 @@ It is intentionally brief and focused on the soundness boundary. They do **not** yet imply end-to-end model behavior. - Downstream error bounds can be computed from a **matrix payload** inside Lean. 
A model-based path exists, but it currently uses only the unembedding direction and relies on an external - **residual-bound certificate** (per-coordinate absolute bounds). + **residual-interval certificate** (per-coordinate lower/upper bounds). - The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor currently ignores LayerNorm and bias terms, so it is not end-to-end faithful. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. @@ -21,7 +21,7 @@ It is intentionally brief and focused on the soundness boundary. - Compute the downstream bound **inside Lean** from model weights and certified residual bounds (not just matrix payloads), and wire this into `certify_end_to_end`. -- Replace untrusted residual-bound generation with a verified derivation from upstream bounds. +- Replace untrusted residual-interval generation with a verified derivation from upstream bounds. - Replace untrusted extraction with a verified parser for model weight slices. - Add a formal bridge from certificates to circuit semantics and (eventually) to end-to-end transformer claims. diff --git a/scripts/build_residual_interval_cert.py b/scripts/build_residual_interval_cert.py new file mode 100644 index 0000000..c70b755 --- /dev/null +++ b/scripts/build_residual_interval_cert.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Build a residual-interval certificate from a GPT-2 forward pass. + +This script is untrusted. It computes per-coordinate min/max bounds by +taking extrema over a fixed input sequence (optionally restricted to active +positions from a softmax-margin certificate). The resulting intervals are +expanded slightly and rounded outwards to rationals for checking by +`nfp induction certify_end_to_end_model`. 
+ +Usage: + uv run scripts/build_residual_interval_cert.py \ + --output reports/gpt2_residual.interval \ + --seq 32 --pattern-length 16 \ + --scores reports/gpt2_induction.cert + +Optional: + --tokens tokens.txt # whitespace-separated token ids + --random-pattern --seed 0 + --decimals 6 --safety 1e-6 +""" + +import argparse +import math +from fractions import Fraction +from pathlib import Path + +import numpy as np + +try: + import torch + from transformers import GPT2Model +except ImportError: + raise SystemExit( + "Missing dependencies. Install with: uv add transformers torch" + ) + + +def rat_to_str(q: Fraction) -> str: + if q.denominator == 1: + return str(q.numerator) + return f"{q.numerator}/{q.denominator}" + + +def build_tokens(seq: int, pattern_len: int, random_pattern: bool, seed: int) -> np.ndarray: + if random_pattern: + rng = np.random.default_rng(seed) + pattern = rng.integers(1000, 30000, size=pattern_len, endpoint=False) + else: + pattern = np.arange(pattern_len) + repeats = (seq // pattern_len) + 1 + return np.tile(pattern, repeats)[:seq] + + +def parse_tokens(path: Path) -> np.ndarray: + raw = path.read_text(encoding="ascii") + tokens = [int(tok) for tok in raw.split() if tok.strip()] + if not tokens: + raise SystemExit(f"no tokens found in {path}") + return np.array(tokens, dtype=np.int64) + + +def parse_active_positions(path: Path) -> tuple[int | None, list[int]]: + seq = None + active: list[int] = [] + for line in path.read_text(encoding="ascii").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if parts[0] == "seq" and len(parts) >= 2: + seq = int(parts[1]) + elif parts[0] == "active" and len(parts) >= 2: + active.append(int(parts[1])) + return seq, active + + +def expand_lo(val: float, safety: float) -> float: + slack = safety * max(1.0, abs(val)) + return val - slack + + +def expand_hi(val: float, safety: float) -> float: + slack = safety * max(1.0, abs(val)) + return val + slack + 
+ +def floor_rat(val: float, decimals: int) -> Fraction: + scale = 10 ** decimals + return Fraction(int(math.floor(val * scale)), scale) + + +def ceil_rat(val: float, decimals: int) -> Fraction: + scale = 10 ** decimals + return Fraction(int(math.ceil(val * scale)), scale) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--output", required=True, help="Path to write certificate") + parser.add_argument("--seq", type=int, default=32, help="Sequence length") + parser.add_argument("--pattern-length", type=int, default=16, help="Pattern length") + parser.add_argument("--random-pattern", action="store_true", + help="Use random token pattern") + parser.add_argument("--seed", type=int, default=0, help="RNG seed for random pattern") + parser.add_argument("--tokens", help="Optional path to whitespace-separated tokens") + parser.add_argument("--scores", help="Optional softmax-margin certificate for active queries") + parser.add_argument("--model", default="gpt2", help="HuggingFace model name") + parser.add_argument("--device", default="cpu", help="Torch device") + parser.add_argument("--decimals", type=int, default=6, + help="Decimal rounding for rationals (outward)") + parser.add_argument("--safety", type=float, default=1e-6, + help="Relative safety slack added before rounding") + args = parser.parse_args() + + if args.seq <= 0: + raise SystemExit("seq must be positive") + if args.decimals < 0: + raise SystemExit("decimals must be nonnegative") + if args.safety < 0: + raise SystemExit("safety must be nonnegative") + + if args.tokens: + tokens = parse_tokens(Path(args.tokens)) + seq = len(tokens) + else: + seq = args.seq + tokens = build_tokens(seq, args.pattern_length, args.random_pattern, args.seed) + + positions = list(range(seq)) + if args.scores: + cert_seq, active = parse_active_positions(Path(args.scores)) + if cert_seq is not None and cert_seq != seq: + raise SystemExit(f"seq mismatch: scores={cert_seq} 
tokens={seq}") + if active: + positions = active + + model = GPT2Model.from_pretrained(args.model) + model.to(args.device) + model.eval() + input_ids = torch.tensor(tokens, dtype=torch.long, device=args.device).unsqueeze(0) + with torch.no_grad(): + outputs = model(input_ids) + hidden = outputs.last_hidden_state.squeeze(0).cpu().numpy() + + if hidden.shape[0] != seq: + raise SystemExit(f"hidden state seq mismatch: {hidden.shape[0]} vs {seq}") + + chosen = hidden[positions] + mins = np.min(chosen, axis=0) + maxs = np.max(chosen, axis=0) + + lo_bounds = [] + hi_bounds = [] + for lo_val, hi_val in zip(mins.tolist(), maxs.tolist(), strict=True): + lo_adj = expand_lo(float(lo_val), args.safety) + hi_adj = expand_hi(float(hi_val), args.safety) + lo_rat = floor_rat(lo_adj, args.decimals) + hi_rat = ceil_rat(hi_adj, args.decimals) + if lo_rat > hi_rat: + lo_rat, hi_rat = hi_rat, lo_rat + lo_bounds.append(lo_rat) + hi_bounds.append(hi_rat) + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="ascii") as f: + f.write(f"dim {len(lo_bounds)}\n") + for i, (lo, hi) in enumerate(zip(lo_bounds, hi_bounds, strict=True)): + f.write(f"lo {i} {rat_to_str(lo)}\n") + f.write(f"hi {i} {rat_to_str(hi)}\n") + + print(f"Wrote residual-interval certificate to {output_path}") + + +if __name__ == "__main__": + main() From 1f504f474f8ab715df99247eda5c5ebba0a893a9 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 16:16:58 +0100 Subject: [PATCH 100/244] Add model-based head input certification --- CLAIMS.md | 2 + Nfp/Cli.lean | 37 ++++++++- Nfp/IO.lean | 165 ++++++++++++++++++++++++--------------- Nfp/IO/NfptPure.lean | 122 +++++++++++++++++++++++++++++ README.md | 9 +++ SOUNDNESS_LIMITATIONS.md | 3 + 6 files changed, 276 insertions(+), 62 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index 5d60637..c14e010 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -24,6 +24,8 @@ what is untrusted/heuristic, and what is 
not yet proven in the tabula rasa rewri and verifies the resulting certificates. - `nfp induction certify_head` recomputes scores/values from exact head inputs and verifies the resulting induction certificate (experimental, potentially slow). +- `nfp induction certify_head_model` reads a model binary, derives head inputs in Lean, and + verifies the resulting induction certificate (still ignores attention biases). - `nfp induction certify_end_to_end` composes a head-level logit-diff lower bound with a downstream error certificate (arithmetic consistency only). - `nfp induction certify_end_to_end_matrix` computes a downstream bound from a matrix payload diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index c8da323..bfccc40 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -192,6 +192,40 @@ def inductionCertifyHeadCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] +/-- `nfp induction certify_head_model` subcommand. -/ +def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do + let modelPath := p.flag! "model" |>.as! String + let layer := p.flag! "layer" |>.as! Nat + let head := p.flag! "head" |>.as! Nat + let period := p.flag! "period" |>.as! Nat + let dirTarget := p.flag! "direction-target" |>.as! Nat + let dirNegative := p.flag! "direction-negative" |>.as! Nat + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyHeadModel modelPath layer head period dirTarget dirNegative + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + +/-- `nfp induction certify_head_model` subcommand. -/ +def inductionCertifyHeadModelCmd : Cmd := `[Cli| + certify_head_model VIA runInductionCertifyHeadModel; + "Check induction certificates by reading a model binary directly." 
+ FLAGS: + model : String; "Path to the NFP_BINARY_V1 model file." + layer : Nat; "Layer index for the induction head." + head : Nat; "Head index for the induction head." + period : Nat; "Prompt period used to derive active/prev." + "direction-target" : Nat; "Target token id for logit-diff direction." + "direction-negative" : Nat; "Negative token id for logit-diff direction." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal). Defaults to 0." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + /-- Induction-head subcommands. -/ def inductionCmd : Cmd := `[Cli| induction NOOP; @@ -202,7 +236,8 @@ def inductionCmd : Cmd := `[Cli| inductionCertifyEndToEndCmd; inductionCertifyEndToEndMatrixCmd; inductionCertifyEndToEndModelCmd; - inductionCertifyHeadCmd + inductionCertifyHeadCmd; + inductionCertifyHeadModelCmd ] /-- The root CLI command. -/ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 7ba3feb..a4abfc3 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -737,6 +737,71 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) does not match model dim {header.modelDim}" return 2 +private def checkInductionHeadInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (minActive? : Option Nat) (minLogitDiff? : Option Rat) + (minMargin maxEps : Rat) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildInductionCertFromHead? 
inputs with + | none => + IO.eprintln "error: head inputs rejected" + return 2 + | some ⟨cert, _hcert⟩ => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let tol := cert.eps * (cert.values.hi - cert.values.lo) + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + cert.values.lo cert.values.hi cert.values.vals + let effectiveMinLogitDiff := + match minLogitDiff? with + | some v => some v + | none => some (0 : Rat) + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return 2 + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={tol}, logitDiffLB={logitDiffLB})" + return 0 + /-- Build and check induction certificates from exact head inputs. -/ def runInductionCertifyHead (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? 
: Option String) @@ -762,67 +827,45 @@ def runInductionCertifyHead (inputsPath : System.FilePath) | Except.error msg => IO.eprintln s!"error: {msg}" return 1 - | Except.ok ⟨seq, ⟨dModel, ⟨dHead, inputs⟩⟩⟩ => - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildInductionCertFromHead? inputs with - | none => - IO.eprintln "error: head inputs rejected" - return 2 - | some ⟨cert, _hcert⟩ => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let tol := cert.eps * (cert.values.hi - cert.values.lo) - let logitDiffLB? := - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - cert.values.lo cert.values.hi cert.values.vals - let effectiveMinLogitDiff := - match minLogitDiff? with - | some v => some v - | none => some (0 : Rat) - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? 
with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return 2 - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={tol}, logitDiffLB={logitDiffLB})" - return 0 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps + +/-- Build and check induction certificates from a model binary. -/ +def runInductionCertifyHeadModel (modelPath : System.FilePath) + (layer head period dirTarget dirNegative : Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let data ← IO.FS.readBinFile modelPath + match NfptPure.parseHeader data with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + match + NfptPure.readInductionHeadInputs + data start header layer head period dirTarget dirNegative + with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputs inputs minActive? minLogitDiff? 
minMargin maxEps end IO diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index 7193ce9..cd13a24 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -2,6 +2,8 @@ import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.List.Range +import Nfp.Model.InductionHead +import Nfp.Model.InductionPrompt /-! Pure parsing utilities for `NFP_BINARY_V1` model files. @@ -172,6 +174,58 @@ private def bytesI32 (n : Nat) : Nat := private def bytesF64 (n : Nat) : Nat := n * 8 +private def sqrtNat? (n : Nat) : Option Nat := + let k := Nat.sqrt n + if k * k = n then + some k + else + none + +private def scaleOfHeadDim (dHead : Nat) : Except String Rat := do + match sqrtNat? dHead with + | some k => + if k = 0 then + throw "head_dim must be positive" + else + pure (Rat.ofInt 1 / Rat.ofInt (Int.ofNat k)) + | none => + throw "head_dim must be a perfect square to compute scale" + +private def matrixIndex {rows cols : Nat} (i : Fin rows) (j : Fin cols) : Fin (rows * cols) := + let idx := i.val * cols + j.val + have hstep : i.val * cols + j.val < (i.val + 1) * cols := by + have h' : i.val * cols + j.val < i.val * cols + cols := + Nat.add_lt_add_left j.isLt _ + have hmul : (i.val + 1) * cols = i.val * cols + cols := by + simpa [Nat.succ_eq_add_one] using (Nat.succ_mul i.val cols) + exact hmul ▸ h' + have hle : (i.val + 1) * cols ≤ rows * cols := + Nat.mul_le_mul_right cols (Nat.succ_le_iff.mpr i.isLt) + ⟨idx, lt_of_lt_of_le hstep hle⟩ + +private def readF64List (data : ByteArray) (off : Nat) (count : Nat) : + Except String {xs : List Rat // xs.length = count} := do + match count with + | 0 => return ⟨[], rfl⟩ + | Nat.succ n => + match readF64Rat data off with + | some v => + let rest ← readF64List data (off + bytesF64 1) n + return ⟨v :: rest.1, by simp [rest.2]⟩ + | none => throw s!"invalid f64 at offset {off}" + +private def readF64Matrix (data : ByteArray) (off : Nat) (rows cols : Nat) : + Except String (Fin rows → Fin cols → Rat) := do + let count := rows * cols + let 
⟨vals, hlen⟩ ← readF64List data off count + let hlen' : vals.length = rows * cols := by + simpa using hlen + let mat : Fin rows → Fin cols → Rat := fun i j => + let idx := matrixIndex i j + let hidx : idx.val < vals.length := lt_of_lt_of_eq idx.isLt hlen'.symm + vals.get ⟨idx.val, hidx⟩ + return mat + private def f64CountPerHead (h : NfptHeader) : Nat := 4 * h.modelDim * h.headDim + 3 * h.headDim @@ -185,10 +239,52 @@ private def f64CountBeforeUnembed (h : NfptHeader) : Nat := h.numLayers * f64CountPerLayer h + (2 * h.modelDim) +private def f64CountBeforeHeads (h : NfptHeader) : Nat := + h.seqLen * h.modelDim + /-- Byte offset from the binary start to the unembedding matrix. -/ def unembedOffset (h : NfptHeader) : Nat := bytesI32 h.seqLen + bytesF64 (f64CountBeforeUnembed h) +/-- Read input embeddings stored in the binary. -/ +def readEmbeddings (data : ByteArray) (start : Nat) (h : NfptHeader) : + Except String (Fin h.seqLen → Fin h.modelDim → Rat) := do + let base := start + bytesI32 h.seqLen + readF64Matrix data base h.seqLen h.modelDim + +private def headOffset (h : NfptHeader) (layer head : Nat) : Nat := + bytesI32 h.seqLen + + bytesF64 (f64CountBeforeHeads h + + layer * f64CountPerLayer h + + head * f64CountPerHead h) + +private def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) + (layer head : Nat) : + Except String + ((Fin h.modelDim → Fin h.headDim → Rat) × + (Fin h.modelDim → Fin h.headDim → Rat) × + (Fin h.modelDim → Fin h.headDim → Rat) × + (Fin h.modelDim → Fin h.headDim → Rat)) := do + if layer < h.numLayers then + if head < h.numHeads then + let base := start + headOffset h layer head + let wq ← readF64Matrix data base h.modelDim h.headDim + let offbq := base + bytesF64 (h.modelDim * h.headDim) + let offwk := offbq + bytesF64 h.headDim + let wk ← readF64Matrix data offwk h.modelDim h.headDim + let offbk := offwk + bytesF64 (h.modelDim * h.headDim) + let offwv := offbk + bytesF64 h.headDim + let wv ← readF64Matrix data offwv 
h.modelDim h.headDim + let offbv := offwv + bytesF64 (h.modelDim * h.headDim) + let offwo := offbv + bytesF64 h.headDim + let woRaw ← readF64Matrix data offwo h.headDim h.modelDim + let wo : Fin h.modelDim → Fin h.headDim → Rat := fun i j => woRaw j i + return (wq, wk, wv, wo) + else + throw s!"head index out of range: {head}" + else + throw s!"layer index out of range: {layer}" + /-- Read a single unembedding column as exact rationals. -/ def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : Nat) : Except String (Fin h.modelDim → Rat) := do @@ -209,6 +305,32 @@ def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : N else throw s!"column out of range: {col}" +/-- Read induction-head inputs directly from the model binary. -/ +def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) + (layer head period dirTarget dirNegative : Nat) : + Except String (Model.InductionHeadInputs h.seqLen h.modelDim h.headDim) := do + let scale ← scaleOfHeadDim h.headDim + let embed ← readEmbeddings data start h + let (wq, wk, wv, wo) ← readHeadWeights data start h layer head + let colTarget ← readUnembedColumn data start h dirTarget + let colNegative ← readUnembedColumn data start h dirNegative + let direction : Fin h.modelDim → Rat := fun i => colTarget i - colNegative i + let directionSpec : Circuit.DirectionSpec := + { target := dirTarget, negative := dirNegative } + let active := Model.activeOfPeriod (seq := h.seqLen) period + let prev := Model.prevOfPeriod (seq := h.seqLen) period + pure + { scale := scale + active := active + prev := prev + embed := embed + wq := wq + wk := wk + wv := wv + wo := wo + directionSpec := directionSpec + direction := direction } + end NfptPure end IO diff --git a/README.md b/README.md index 0c3718f..f0dc56d 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,15 @@ lake exe nfp induction certify_head --inputs reports/gpt2_induction.head This path recomputes scores/values in Lean from 
exact head inputs. It is **experimental** and can be slow for nontrivial sequence lengths. +You can also derive the head inputs directly from an `NFP_BINARY_V1` model file: + +```bash +lake exe nfp induction certify_head_model \ + --model models/gpt2_rigorous_with_gelu_kind_seq32.nfpt \ + --layer 5 --head 1 --period 16 \ + --direction-target 1 --direction-negative 2 +``` + ### End-to-end check with downstream bound (prototype) ```bash diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 0e8110f..403f99f 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -13,6 +13,9 @@ It is intentionally brief and focused on the soundness boundary. **residual-interval certificate** (per-coordinate lower/upper bounds). - The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor currently ignores LayerNorm and bias terms, so it is not end-to-end faithful. +- The `certify_head_model` path derives head inputs from the model binary in Lean, but it still + ignores attention biases and LayerNorm, and currently requires `head_dim` to be a perfect square + to represent the scale as an exact rational. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. - There is no bridge theorem connecting certificate validity to a full circuit/model semantics statement (for example, a formal statement about logits under a transformer block stack). 
From 1f3489e084c6ccfb27ae28693380c0b66b77c8c2 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 16:30:11 +0100 Subject: [PATCH 101/244] Add interval bounds for matrix products --- Nfp/Sound/Bounds/MatrixNorm.lean | 55 ++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index c6ebf94..fe58dca 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -5,6 +5,7 @@ import Mathlib.Algebra.Order.Ring.Abs import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.Matrix.Mul import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.ResidualInterval /-! Row-sum matrix norms for downstream linear certificates. @@ -138,6 +139,16 @@ def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Rat) : Rat := def dotIntervalAbsBound {n : Nat} (v lo hi : Fin n → Rat) : Rat := max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| +/-- Lower interval endpoint for a matrix-vector product under input intervals. -/ +def mulVecIntervalLower {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) : Fin m → Rat := + fun i => dotIntervalLower (fun j => W i j) lo hi + +/-- Upper interval endpoint for a matrix-vector product under input intervals. -/ +def mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) : Fin m → Rat := + fun i => dotIntervalUpper (fun j => W i j) lo hi + theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : dotIntervalLower v lo hi ≤ dotProduct v x := by @@ -202,6 +213,50 @@ theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → R unfold dotIntervalAbsBound exact habs +/-- Matrix-interval lower bounds dominate matrix-vector products. 
-/ +theorem mulVecIntervalLower_le_mulVec {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + ∀ i, mulVecIntervalLower W lo hi i ≤ Matrix.mulVec W x i := by + intro i + have h := + dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi x hlo hhi + simpa [mulVecIntervalLower, Matrix.mulVec, dotProduct] using h + +/-- Matrix-interval upper bounds dominate matrix-vector products. -/ +theorem mulVec_le_mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + ∀ i, Matrix.mulVec W x i ≤ mulVecIntervalUpper W lo hi i := by + intro i + have h := + dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi x hlo hhi + simpa [mulVecIntervalUpper, Matrix.mulVec, dotProduct] using h + +/-- Interval endpoints for `mulVec` are ordered when the input interval is ordered. -/ +theorem mulVecIntervalLower_le_upper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : + ∀ i, mulVecIntervalLower W lo hi i ≤ mulVecIntervalUpper W lo hi i := by + intro i + have hlow : + dotIntervalLower (fun j => W i j) lo hi ≤ dotProduct (fun j => W i j) lo := + dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi lo + (fun j => le_rfl) hlohi + have hhigh : + dotProduct (fun j => W i j) lo ≤ dotIntervalUpper (fun j => W i j) lo hi := + dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi lo + (fun j => le_rfl) hlohi + exact le_trans hlow hhigh + +/-- Build a residual-interval certificate by applying a matrix to an input interval. 
-/ +def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : + {c : Circuit.ResidualIntervalCert m // Circuit.ResidualIntervalBounds c} := by + let lo' := mulVecIntervalLower W lo hi + let hi' := mulVecIntervalUpper W lo hi + refine ⟨{ lo := lo', hi := hi' }, ?_⟩ + refine { lo_le_hi := ?_ } + intro i + exact mulVecIntervalLower_le_upper W lo hi hlohi i + /-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (x : Fin n → Rat) (inputBound : Rat) From 1cbaa33fcb1341a952a17ccd11fcc9ee50c5803d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 17:11:52 +0100 Subject: [PATCH 102/244] Add head output interval CLI --- Nfp/Cli.lean | 44 ++++++- Nfp/IO.lean | 68 +++++++++++ Nfp/Sound/Induction.lean | 253 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 364 insertions(+), 1 deletion(-) diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index bfccc40..7d6d152 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -226,6 +226,46 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] +/-- `nfp induction head_interval` subcommand. -/ +def runInductionHeadInterval (p : Parsed) : IO UInt32 := do + let inputsPath := p.flag! "inputs" |>.as! String + let outPath? := (p.flag? "out").map (·.as! String) + IO.runInductionHeadInterval inputsPath outPath? + +/-- `nfp induction head_interval` subcommand. -/ +def inductionHeadIntervalCmd : Cmd := `[Cli| + head_interval VIA runInductionHeadInterval; + "Build head-output interval bounds from exact head inputs." + FLAGS: + inputs : String; "Path to the induction head input file." + out : String; "Optional path to write the residual-interval certificate." +] + +/-- `nfp induction head_interval_model` subcommand. 
-/ +def runInductionHeadIntervalModel (p : Parsed) : IO UInt32 := do + let modelPath := p.flag! "model" |>.as! String + let layer := p.flag! "layer" |>.as! Nat + let head := p.flag! "head" |>.as! Nat + let period := p.flag! "period" |>.as! Nat + let dirTarget := p.flag! "direction-target" |>.as! Nat + let dirNegative := p.flag! "direction-negative" |>.as! Nat + let outPath? := (p.flag? "out").map (·.as! String) + IO.runInductionHeadIntervalModel modelPath layer head period dirTarget dirNegative outPath? + +/-- `nfp induction head_interval_model` subcommand. -/ +def inductionHeadIntervalModelCmd : Cmd := `[Cli| + head_interval_model VIA runInductionHeadIntervalModel; + "Build head-output interval bounds by reading a model binary directly." + FLAGS: + model : String; "Path to the NFP_BINARY_V1 model file." + layer : Nat; "Layer index for the induction head." + head : Nat; "Head index for the induction head." + period : Nat; "Prompt period used to derive active/prev." + "direction-target" : Nat; "Target token id for logit-diff direction." + "direction-negative" : Nat; "Negative token id for logit-diff direction." + out : String; "Optional path to write the residual-interval certificate." +] + /-- Induction-head subcommands. -/ def inductionCmd : Cmd := `[Cli| induction NOOP; @@ -237,7 +277,9 @@ def inductionCmd : Cmd := `[Cli| inductionCertifyEndToEndMatrixCmd; inductionCertifyEndToEndModelCmd; inductionCertifyHeadCmd; - inductionCertifyHeadModelCmd + inductionCertifyHeadModelCmd; + inductionHeadIntervalCmd; + inductionHeadIntervalModelCmd ] /-- The root CLI command. 
-/ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index a4abfc3..34d57b4 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -1,5 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Data.List.Range import Nfp.IO.Pure import Nfp.IO.NfptPure import Nfp.Circuit.Cert.LogitDiff @@ -75,6 +76,42 @@ def loadInductionHeadInputs (path : System.FilePath) : let data ← IO.FS.readFile path return Pure.parseInductionHeadInputs data +private def renderResidualIntervalCert {n : Nat} (c : ResidualIntervalCert n) : String := + let header := s!"dim {n}" + let lines := + (List.finRange n).foldr (fun i acc => + s!"lo {i.val} {c.lo i}" :: s!"hi {i.val} {c.hi i}" :: acc) [] + String.intercalate "\n" (header :: lines) + +private def emitResidualIntervalCert {n : Nat} (c : ResidualIntervalCert n) + (outPath? : Option System.FilePath) : IO Unit := do + let payload := renderResidualIntervalCert c + match outPath? with + | some path => IO.FS.writeFile path (payload ++ "\n") + | none => IO.println payload + +private def buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (outPath? : Option System.FilePath) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildHeadOutputIntervalFromHead? inputs with + | none => + IO.eprintln "error: head output interval rejected" + return 2 + | some result => + emitResidualIntervalCert result.cert outPath? + if outPath?.isSome then + let activeCount := result.active.card + IO.println + s!"ok: head output interval built (seq={seq}, dim={dModel}, active={activeCount})" + return 0 + private def checkSoftmaxMargin (seq : Nat) (cert : SoftmaxMarginCert seq) : IO (Except String Unit) := match seq with @@ -867,6 +904,37 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) | Except.ok inputs => checkInductionHeadInputs inputs minActive? minLogitDiff? 
minMargin maxEps +/-- Build head-output interval bounds from exact head inputs. -/ +def runInductionHeadInterval (inputsPath : System.FilePath) + (outPath? : Option System.FilePath) : IO UInt32 := do + let parsedInputs ← loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + buildHeadOutputIntervalFromInputs inputs outPath? + +/-- Build head-output interval bounds from a model binary. -/ +def runInductionHeadIntervalModel (modelPath : System.FilePath) + (layer head period dirTarget dirNegative : Nat) + (outPath? : Option System.FilePath) : IO UInt32 := do + let data ← IO.FS.readBinFile modelPath + match NfptPure.parseHeader data with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + match + NfptPure.readInductionHeadInputs + data start header layer head period dirTarget dirNegative + with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + buildHeadOutputIntervalFromInputs inputs outPath? + end IO end Nfp diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index a62eb4a..19e7ab6 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -5,6 +5,7 @@ import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.Finset.Lattice.Fold import Mathlib.Data.Rat.Cast.Order import Mathlib.Data.Vector.Defs +import Nfp.Circuit.Cert.ResidualInterval import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Cert.ValueRange import Nfp.Circuit.Layers.Softmax @@ -76,6 +77,15 @@ private opaque valsVecOfInputs {seq dModel dHead : Nat} let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d Linear.dotFin dHead (fun d => vVec d) (fun d => dirHead d)) +/-- Cached per-key head outputs in model space (opaque to avoid kernel reduction). 
-/ +private opaque headValueVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vVecVec : Vector (Vector Rat dHead) seq) : Vector (Vector Rat dModel) seq := + Vector.ofFn (fun k : Fin seq => + Vector.ofFn (fun i : Fin dModel => + let vVec : Fin dHead → Rat := fun d => (vVecVec.get k).get d + Linear.dotFin dHead (fun d => vVec d) (fun d => inputs.wo i d))) + /-- Sound induction-certificate payload built from exact head inputs. -/ structure InductionHeadCert (seq : Nat) where /-- Weight tolerance. -/ @@ -434,6 +444,249 @@ def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} exact hle.trans hsum_others_le some ⟨cert, { softmax_bounds := hsoftmax_bounds, value_bounds := hvalues }⟩ +section HeadOutputInterval + +variable {seq dModel dHead : Nat} + +noncomputable section + +/-- Real-valued head output using explicit score inputs. -/ +def headOutputWithScores (scores : Fin seq → Fin seq → Rat) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : Real := + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (fun j => (scores q j : Real)) k + let vVecVec := vVecVecOfInputs inputs + let headValuesVec := headValueVecOfInputs inputs vVecVec + let vals : Fin seq → Real := fun k => (headValuesVec.get k).get i + dotProduct (weights q) vals + +/-- Unfolding lemma for `headOutputWithScores`. -/ +theorem headOutputWithScores_def (scores : Fin seq → Fin seq → Rat) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : + headOutputWithScores scores inputs q i = + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (fun j => (scores q j : Real)) k + let vVecVec := vVecVecOfInputs inputs + let headValuesVec := headValueVecOfInputs inputs vVecVec + let vals : Fin seq → Real := fun k => (headValuesVec.get k).get i + dotProduct (weights q) vals := rfl + +/-- Real-valued head output for a query and model dimension. 
-/ +def headOutput (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : Real := + let qVecVec := qVecVecOfInputs inputs + let kVecVec := kVecVecOfInputs inputs + let scoresVec := scoresVecOfInputs inputs qVecVec kVecVec + let scores : Fin seq → Fin seq → Rat := fun q k => (scoresVec.get q).get k + headOutputWithScores scores inputs q i + +/-- Unfolding lemma for `headOutput`. -/ +theorem headOutput_def (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : + headOutput inputs q i = + let qVecVec := qVecVecOfInputs inputs + let kVecVec := kVecVecOfInputs inputs + let scoresVec := scoresVecOfInputs inputs qVecVec kVecVec + let scores : Fin seq → Fin seq → Rat := fun q k => (scoresVec.get q).get k + headOutputWithScores scores inputs q i := rfl + +/-- Soundness predicate for head-output interval bounds. -/ +structure HeadOutputIntervalSound [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (scores : Fin seq → Fin seq → Rat) + (active : Finset (Fin seq)) + (c : Circuit.ResidualIntervalCert dModel) : Prop where + /-- Interval bounds are ordered coordinatewise. -/ + bounds : Circuit.ResidualIntervalBounds c + /-- Active-query outputs lie inside the interval bounds. -/ + output_mem : + ∀ q, q ∈ active → ∀ i, + (c.lo i : Real) ≤ headOutputWithScores scores inputs q i ∧ + headOutputWithScores scores inputs q i ≤ (c.hi i : Real) + +/-- Certified head-output interval data for a specific active set. -/ +structure HeadOutputIntervalResult [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) where + /-- Scores used to derive softmax weights. -/ + scores : Fin seq → Fin seq → Rat + /-- Active queries covered by the interval bounds. -/ + active : Finset (Fin seq) + /-- Residual-interval certificate for head outputs. -/ + cert : Circuit.ResidualIntervalCert dModel + /-- Soundness proof for the interval bounds. 
-/ + sound : HeadOutputIntervalSound inputs scores active cert + +/-- Build residual-interval bounds for head outputs on active queries. -/ +def buildHeadOutputIntervalFromHead? [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option (HeadOutputIntervalResult inputs) := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + cases hbuild : buildInductionCertFromHead? inputs with + | none => exact none + | some certWithProof => + rcases certWithProof with ⟨cert, hcert⟩ + let vVecVec := vVecVecOfInputs inputs + let headValuesVec := headValueVecOfInputs inputs vVecVec + let headValue : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => + (headValuesVec.get k).get i + let scores : Fin (Nat.succ n) → Fin (Nat.succ n) → Rat := cert.scores + let scoresReal : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := + fun q k => (scores q k : Real) + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresReal q) k + let activeSet : Finset (Fin (Nat.succ n)) := cert.active + let univ : Finset (Fin (Nat.succ n)) := Finset.univ + have huniv : univ.Nonempty := Finset.univ_nonempty + let loVal : Fin dModel → Rat := fun i => + univ.inf' huniv (fun k => headValue k i) + let hiVal : Fin dModel → Rat := fun i => + univ.sup' huniv (fun k => headValue k i) + have hvalsBounds : + ∀ i, Layers.ValueRangeBounds (Val := Rat) (loVal i) (hiVal i) + (fun k => headValue k i) := by + intro i + refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } + · rcases huniv with ⟨k0, hk0⟩ + have hlo := + Finset.inf'_le (s := univ) (f := fun k => headValue k i) hk0 + have hhi := + Finset.le_sup' (s := univ) (f := fun k => headValue k i) hk0 + exact le_trans hlo hhi + · intro k + exact Finset.inf'_le (s := univ) (f := fun k => headValue k i) (by simp [univ]) + · intro k + exact Finset.le_sup' (s := univ) (f := fun k => headValue k i) (by simp [univ]) + have hvalsBoundsReal : + ∀ i, Layers.ValueRangeBounds (Val := 
Real) + (loVal i : Real) (hiVal i : Real) + (fun k => (headValue k i : Real)) := by + intro i + have hvals := hvalsBounds i + refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } + · exact (Rat.cast_le (K := Real)).2 hvals.lo_le_hi + · intro k + exact (Rat.cast_le (K := Real)).2 (hvals.lo_le k) + · intro k + exact (Rat.cast_le (K := Real)).2 (hvals.le_hi k) + have hsoftmax : + Layers.SoftmaxMarginBoundsOn (Val := Real) (cert.eps : Real) (cert.margin : Real) + (fun q => q ∈ activeSet) cert.prev scoresReal weights := by + simpa [scoresReal, weights, activeSet] using hcert.softmax_bounds + have hweights : + Layers.OneHotApproxBoundsOnActive (Val := Real) (cert.eps : Real) + (fun q => q ∈ activeSet) cert.prev weights := + Layers.oneHotApproxBoundsOnActive_of_softmaxMargin + (Val := Real) + (ε := (cert.eps : Real)) + (margin := (cert.margin : Real)) + (active := fun q => q ∈ activeSet) + (prev := cert.prev) + (scores := scoresReal) + (weights := weights) + hsoftmax + have happrox : + ∀ i, Layers.InductionSpecApproxOn (Val := Real) (n := n) + ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real))) + (fun q => q ∈ activeSet) cert.prev + (fun q => dotProduct (weights q) (fun k => (headValue k i : Real))) + (fun k => (headValue k i : Real)) := by + intro i + exact + Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange + (Val := Real) + (n := n) + (ε := (cert.eps : Real)) + (lo := (loVal i : Real)) + (hi := (hiVal i : Real)) + (active := fun q => q ∈ activeSet) + (prev := cert.prev) + (weights := weights) + (vals := fun k => (headValue k i : Real)) + (hweights := hweights) + (hvals := hvalsBoundsReal i) + let delta : Fin dModel → Rat := fun i => hiVal i - loVal i + let boundLoRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => + headValue (cert.prev q) i - cert.eps * delta i + let boundHiRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => + headValue (cert.prev q) i + cert.eps * delta i + let loOut : Fin dModel → Rat := fun i => + if h : activeSet.Nonempty then 
+ activeSet.inf' h (fun q => boundLoRat q i) + else + 0 + let hiOut : Fin dModel → Rat := fun i => + if h : activeSet.Nonempty then + activeSet.sup' h (fun q => boundHiRat q i) + else + 0 + have hout : + ∀ q, q ∈ activeSet → ∀ i, + (loOut i : Real) ≤ headOutputWithScores scores inputs q i ∧ + headOutputWithScores scores inputs q i ≤ (hiOut i : Real) := by + intro q hq i + have hactive : activeSet.Nonempty := ⟨q, hq⟩ + have hspec := (happrox i) q hq + have hout_def : + headOutputWithScores scores inputs q i = + dotProduct (weights q) (fun k => (headValue k i : Real)) := by + simp [headOutputWithScores, scoresReal, weights, headValue, headValuesVec, vVecVec] + have hupper : + headOutputWithScores scores inputs q i ≤ (boundHiRat q i : Real) := by + have hupper' := + (happrox i) q hq |>.1 + simpa [hout_def, boundHiRat, delta] using hupper' + have hlower : + (boundLoRat q i : Real) ≤ headOutputWithScores scores inputs q i := by + have hlower' : + (headValue (cert.prev q) i : Real) - + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ + dotProduct (weights q) (fun k => (headValue k i : Real)) := by + exact (sub_le_iff_le_add).2 hspec.2 + simpa [hout_def, boundLoRat, delta] using hlower' + have hlo : + (loOut i : Real) ≤ (boundLoRat q i : Real) := by + have hloRat : loOut i ≤ boundLoRat q i := by + simpa [loOut, hactive] using + (Finset.inf'_le (s := activeSet) (f := fun q => boundLoRat q i) hq) + exact (Rat.cast_le (K := Real)).2 hloRat + have hhi : + (boundHiRat q i : Real) ≤ (hiOut i : Real) := by + have hhiRat : boundHiRat q i ≤ hiOut i := by + simpa [hiOut, hactive] using + (Finset.le_sup' (s := activeSet) (f := fun q => boundHiRat q i) hq) + exact (Rat.cast_le (K := Real)).2 hhiRat + exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ + have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by + refine { lo_le_hi := ?_ } + intro i + by_cases hactive : activeSet.Nonempty + · rcases hactive with ⟨q, hq⟩ + have hout_i := hout q hq i + have 
hleReal : (loOut i : Real) ≤ (hiOut i : Real) := + le_trans hout_i.1 hout_i.2 + exact (Rat.cast_le (K := Real)).1 hleReal + · simp [loOut, hiOut, hactive] + let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } + exact some + { scores := scores + active := activeSet + cert := certOut + sound := + { bounds := hbounds + output_mem := by + intro q hq i + exact hout q hq i } } + +end + +end HeadOutputInterval + end Sound end Nfp From d00a8afa9cd3f3ad1b3149260c4ec32eb94d16c5 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 18:02:32 +0100 Subject: [PATCH 103/244] Plumb head biases and LN metadata --- CLAIMS.md | 5 +- Nfp/IO/NfptPure.lean | 137 +++++++++++++++++++++++++++--- Nfp/IO/Pure.lean | 72 ++++++++++++++++ Nfp/Model/Gpt2.lean | 14 +++ Nfp/Model/InductionHead.lean | 14 +++ Nfp/Sound/Gpt2/HeadInputs.lean | 14 +++ Nfp/Sound/Induction.lean | 9 +- README.md | 7 ++ SOUNDNESS_LIMITATIONS.md | 10 ++- scripts/build_gpt2_head_inputs.py | 82 +++++++++++++++--- 10 files changed, 333 insertions(+), 31 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index c14e010..234e04c 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -25,7 +25,7 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - `nfp induction certify_head` recomputes scores/values from exact head inputs and verifies the resulting induction certificate (experimental, potentially slow). - `nfp induction certify_head_model` reads a model binary, derives head inputs in Lean, and - verifies the resulting induction certificate (still ignores attention biases). + verifies the resulting induction certificate (includes attention projection biases). - `nfp induction certify_end_to_end` composes a head-level logit-diff lower bound with a downstream error certificate (arithmetic consistency only). 
- `nfp induction certify_end_to_end_matrix` computes a downstream bound from a matrix payload @@ -39,7 +39,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - Python helpers that generate certificates from GPT-2 weights or head inputs: `scripts/build_gpt2_induction_cert.py`, `scripts/build_gpt2_head_inputs.py`, `scripts/build_downstream_linear_cert.py`. -- The head-input extractor currently ignores LayerNorm and bias terms. +- The head-input extractor now emits attention projection biases and LayerNorm metadata, but + the Lean-side computation still ignores LayerNorm and the shared attention output bias. - Residual-interval certificates are generated externally (unchecked beyond consistency). - Any downstream error bound provided externally (outside the matrix-payload path). diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index cd13a24..6119f2b 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -33,6 +33,8 @@ structure NfptHeader where vocabSize : Nat /-- Sequence length used in the binary. -/ seqLen : Nat + /-- LayerNorm epsilon parameter. -/ + layerNormEps : Rat private def parseNat (s : String) : Except String Nat := match s.toNat? with @@ -50,6 +52,64 @@ private def readHeaderField (name : String) (fields : List (String × String)) : | some kv => parseNat kv.2 | none => throw s!"missing header field '{name}'" +private def parseInt (s : String) : Except String Int := + match s.toInt? with + | some n => Except.ok n + | none => Except.error s!"expected Int, got '{s}'" + +private def pow10 (k : Nat) : Nat := + Nat.pow 10 k + +private def parseRatScientific (s : String) : Except String Rat := do + let s := s.trim + let (sign, rest) := + if s.startsWith "-" then + (-1, s.drop 1) + else if s.startsWith "+" then + (1, s.drop 1) + else + (1, s) + let parts := rest.toLower.splitOn "e" + let (mant, expStr?) 
← + match parts with + | [m] => pure (m, none) + | [m, e] => pure (m, some e) + | _ => throw s!"invalid scientific literal '{s}'" + let (intPart, fracPart) ← + match mant.splitOn "." with + | [i] => pure (i, "") + | [i, f] => pure (i, f) + | _ => throw s!"invalid decimal literal '{s}'" + let digits := intPart ++ fracPart + if digits = "" then + throw s!"invalid decimal literal '{s}'" + let n ← parseNat digits + let scale := fracPart.length + let base : Rat := + (Rat.ofInt (sign * Int.ofNat n)) / Rat.ofInt (Int.ofNat (pow10 scale)) + let exp ← + match expStr? with + | none => pure (0 : Int) + | some e => parseInt e + if exp ≥ 0 then + let k := Int.toNat exp + pure (base * Rat.ofInt (Int.ofNat (pow10 k))) + else + let k := Int.toNat (-exp) + pure (base / Rat.ofInt (Int.ofNat (pow10 k))) + +private def readHeaderFieldRat (names : List String) (fields : List (String × String)) : + Except String Rat := do + let rec loop : List String → Option String + | [] => none + | name :: rest => + match fields.find? 
(fun kv => kv.1 = name) with + | some kv => some kv.2 + | none => loop rest + match loop names with + | some raw => parseRatScientific raw + | none => throw s!"missing header field '{String.intercalate "|" names}'" + private def sentinelBytes : ByteArray := "BINARY_START\n".toUTF8 @@ -97,6 +157,7 @@ def parseHeader (data : ByteArray) : Except String (NfptHeader × Nat) := do let hiddenDim ← readHeaderField "hidden_dim" fields let vocabSize ← readHeaderField "vocab_size" fields let seqLen ← readHeaderField "seq_len" fields + let layerNormEps ← readHeaderFieldRat ["layer_norm_eps", "eps"] fields if numLayers = 0 then throw "num_layers must be positive" if numHeads = 0 then @@ -118,7 +179,8 @@ def parseHeader (data : ByteArray) : Except String (NfptHeader × Nat) := do headDim := headDim hiddenDim := hiddenDim vocabSize := vocabSize - seqLen := seqLen }, start) + seqLen := seqLen + layerNormEps := layerNormEps }, start) private def pow2 (k : Nat) : Nat := Nat.pow 2 k @@ -226,6 +288,15 @@ private def readF64Matrix (data : ByteArray) (off : Nat) (rows cols : Nat) : vals.get ⟨idx.val, hidx⟩ return mat +private def readF64Vec (data : ByteArray) (off : Nat) (count : Nat) : + Except String (Fin count → Rat) := do + let ⟨vals, hlen⟩ ← readF64List data off count + let hlen' : vals.length = count := by + simpa using hlen + let vec : Fin count → Rat := fun i => + vals.get ⟨i.val, lt_of_lt_of_eq i.isLt hlen'.symm⟩ + return vec + private def f64CountPerHead (h : NfptHeader) : Nat := 4 * h.modelDim * h.headDim + 3 * h.headDim @@ -258,33 +329,67 @@ private def headOffset (h : NfptHeader) (layer head : Nat) : Nat := layer * f64CountPerLayer h + head * f64CountPerHead h) +private def layerExtrasOffset (h : NfptHeader) (layer : Nat) : Nat := + bytesI32 h.seqLen + + bytesF64 (f64CountBeforeHeads h + + layer * f64CountPerLayer h + + h.numHeads * f64CountPerHead h) + +/-- Head weights plus biases for a single attention head. 
-/ +private structure HeadWeights (dModel dHead : Nat) where + wq : Fin dModel → Fin dHead → Rat + bq : Fin dHead → Rat + wk : Fin dModel → Fin dHead → Rat + bk : Fin dHead → Rat + wv : Fin dModel → Fin dHead → Rat + bv : Fin dHead → Rat + wo : Fin dModel → Fin dHead → Rat + private def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) (layer head : Nat) : - Except String - ((Fin h.modelDim → Fin h.headDim → Rat) × - (Fin h.modelDim → Fin h.headDim → Rat) × - (Fin h.modelDim → Fin h.headDim → Rat) × - (Fin h.modelDim → Fin h.headDim → Rat)) := do + Except String (HeadWeights h.modelDim h.headDim) := do if layer < h.numLayers then if head < h.numHeads then let base := start + headOffset h layer head let wq ← readF64Matrix data base h.modelDim h.headDim let offbq := base + bytesF64 (h.modelDim * h.headDim) + let bq ← readF64Vec data offbq h.headDim let offwk := offbq + bytesF64 h.headDim let wk ← readF64Matrix data offwk h.modelDim h.headDim let offbk := offwk + bytesF64 (h.modelDim * h.headDim) + let bk ← readF64Vec data offbk h.headDim let offwv := offbk + bytesF64 h.headDim let wv ← readF64Matrix data offwv h.modelDim h.headDim let offbv := offwv + bytesF64 (h.modelDim * h.headDim) + let bv ← readF64Vec data offbv h.headDim let offwo := offbv + bytesF64 h.headDim let woRaw ← readF64Matrix data offwo h.headDim h.modelDim let wo : Fin h.modelDim → Fin h.headDim → Rat := fun i j => woRaw j i - return (wq, wk, wv, wo) + return { wq := wq, bq := bq, wk := wk, bk := bk, wv := wv, bv := bv, wo := wo } else throw s!"head index out of range: {head}" else throw s!"layer index out of range: {layer}" +private def readLayerAttnBiasLn1 (data : ByteArray) (start : Nat) (h : NfptHeader) + (layer : Nat) : + Except String ((Fin h.modelDim → Rat) × (Fin h.modelDim → Rat) × + (Fin h.modelDim → Rat)) := do + if layer < h.numLayers then + let base := start + layerExtrasOffset h layer + let attnBias ← readF64Vec data base h.modelDim + let offWIn := base + bytesF64 
h.modelDim + let offBIn := offWIn + bytesF64 (h.modelDim * h.hiddenDim) + let offWOut := offBIn + bytesF64 h.hiddenDim + let offBOut := offWOut + bytesF64 (h.hiddenDim * h.modelDim) + let offLn1Gamma := offBOut + bytesF64 h.modelDim + let ln1Gamma ← readF64Vec data offLn1Gamma h.modelDim + let offLn1Beta := offLn1Gamma + bytesF64 h.modelDim + let ln1Beta ← readF64Vec data offLn1Beta h.modelDim + return (attnBias, ln1Gamma, ln1Beta) + else + throw s!"layer index out of range: {layer}" + /-- Read a single unembedding column as exact rationals. -/ def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : Nat) : Except String (Fin h.modelDim → Rat) := do @@ -311,7 +416,8 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) Except String (Model.InductionHeadInputs h.seqLen h.modelDim h.headDim) := do let scale ← scaleOfHeadDim h.headDim let embed ← readEmbeddings data start h - let (wq, wk, wv, wo) ← readHeadWeights data start h layer head + let weights ← readHeadWeights data start h layer head + let (attnBias, ln1Gamma, ln1Beta) ← readLayerAttnBiasLn1 data start h layer let colTarget ← readUnembedColumn data start h dirTarget let colNegative ← readUnembedColumn data start h dirNegative let direction : Fin h.modelDim → Rat := fun i => colTarget i - colNegative i @@ -324,10 +430,17 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) active := active prev := prev embed := embed - wq := wq - wk := wk - wv := wv - wo := wo + lnEps := h.layerNormEps + ln1Gamma := ln1Gamma + ln1Beta := ln1Beta + wq := weights.wq + bq := weights.bq + wk := weights.wk + bk := weights.bk + wv := weights.wv + bv := weights.bv + wo := weights.wo + attnBias := attnBias directionSpec := directionSpec direction := direction } diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 10ca560..18887ae 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -562,10 +562,17 @@ private structure HeadParseState (seq dModel dHead : 
Nat) where activeSeen : Bool prev : Fin seq → Option (Fin seq) embed : Fin seq → Fin dModel → Option Rat + lnEps : Option Rat + ln1Gamma : Fin dModel → Option Rat + ln1Beta : Fin dModel → Option Rat wq : Fin dModel → Fin dHead → Option Rat + bq : Fin dHead → Option Rat wk : Fin dModel → Fin dHead → Option Rat + bk : Fin dHead → Option Rat wv : Fin dModel → Fin dHead → Option Rat + bv : Fin dHead → Option Rat wo : Fin dModel → Fin dHead → Option Rat + attnBias : Fin dModel → Option Rat directionTarget : Option Nat directionNegative : Option Nat direction : Fin dModel → Option Rat @@ -576,10 +583,17 @@ private def initHeadState (seq dModel dHead : Nat) : HeadParseState seq dModel d activeSeen := false prev := fun _ => none embed := fun _ _ => none + lnEps := none + ln1Gamma := fun _ => none + ln1Beta := fun _ => none wq := fun _ _ => none + bq := fun _ => none wk := fun _ _ => none + bk := fun _ => none wv := fun _ _ => none + bv := fun _ => none wo := fun _ _ => none + attnBias := fun _ => none directionTarget := none directionNegative := none direction := fun _ => none } @@ -630,18 +644,41 @@ private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dMod | ["embed", q, d, val] => let mat ← setMatEntry st.embed (← parseNat q) (← parseNat d) (← parseRat val) return { st with embed := mat } + | ["ln_eps", val] => + if st.lnEps.isSome then + throw "duplicate ln_eps entry" + else + return { st with lnEps := some (← parseRat val) } + | ["ln1_gamma", d, val] => + let vec ← setVecEntry st.ln1Gamma (← parseNat d) (← parseRat val) + return { st with ln1Gamma := vec } + | ["ln1_beta", d, val] => + let vec ← setVecEntry st.ln1Beta (← parseNat d) (← parseRat val) + return { st with ln1Beta := vec } | ["wq", i, j, val] => let mat ← setMatEntry st.wq (← parseNat i) (← parseNat j) (← parseRat val) return { st with wq := mat } + | ["bq", j, val] => + let vec ← setVecEntry st.bq (← parseNat j) (← parseRat val) + return { st with bq := vec } | ["wk", i, j, val] => 
let mat ← setMatEntry st.wk (← parseNat i) (← parseNat j) (← parseRat val) return { st with wk := mat } + | ["bk", j, val] => + let vec ← setVecEntry st.bk (← parseNat j) (← parseRat val) + return { st with bk := vec } | ["wv", i, j, val] => let mat ← setMatEntry st.wv (← parseNat i) (← parseNat j) (← parseRat val) return { st with wv := mat } + | ["bv", j, val] => + let vec ← setVecEntry st.bv (← parseNat j) (← parseRat val) + return { st with bv := vec } | ["wo", i, j, val] => let mat ← setMatEntry st.wo (← parseNat i) (← parseNat j) (← parseRat val) return { st with wo := mat } + | ["attn_bias", d, val] => + let vec ← setVecEntry st.attnBias (← parseNat d) (← parseRat val) + return { st with attnBias := vec } | ["direction-target", tok] => if st.directionTarget.isSome then throw "duplicate direction-target entry" @@ -670,18 +707,34 @@ private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.embed q d).isSome)) then throw "missing embed entries" + let lnEps ← + match st.lnEps with + | some v => pure v + | none => throw "missing ln_eps entry" + if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.ln1Gamma d).isSome) then + throw "missing ln1_gamma entries" + if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.ln1Beta d).isSome) then + throw "missing ln1_beta entries" if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wq i j).isSome)) then throw "missing wq entries" + if !finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.bq j).isSome) then + throw "missing bq entries" if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wk i j).isSome)) then throw "missing wk entries" + if !finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.bk j).isSome) then + throw "missing bk 
entries" if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wv i j).isSome)) then throw "missing wv entries" + if !finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.bv j).isSome) then + throw "missing bv entries" if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wo i j).isSome)) then throw "missing wo entries" + if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.attnBias d).isSome) then + throw "missing attn_bias entries" if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.direction d).isSome) then throw "missing direction entries" let directionSpec ← @@ -694,14 +747,26 @@ private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) (st.prev q).getD defaultPrev let embedFun : Fin seq → Fin dModel → Rat := fun q d => (st.embed q d).getD 0 + let ln1GammaFun : Fin dModel → Rat := fun d => + (st.ln1Gamma d).getD 0 + let ln1BetaFun : Fin dModel → Rat := fun d => + (st.ln1Beta d).getD 0 let wqFun : Fin dModel → Fin dHead → Rat := fun i j => (st.wq i j).getD 0 + let bqFun : Fin dHead → Rat := fun j => + (st.bq j).getD 0 let wkFun : Fin dModel → Fin dHead → Rat := fun i j => (st.wk i j).getD 0 + let bkFun : Fin dHead → Rat := fun j => + (st.bk j).getD 0 let wvFun : Fin dModel → Fin dHead → Rat := fun i j => (st.wv i j).getD 0 + let bvFun : Fin dHead → Rat := fun j => + (st.bv j).getD 0 let woFun : Fin dModel → Fin dHead → Rat := fun i j => (st.wo i j).getD 0 + let attnBiasFun : Fin dModel → Rat := fun d => + (st.attnBias d).getD 0 let directionFun : Fin dModel → Rat := fun d => (st.direction d).getD 0 let active := @@ -714,10 +779,17 @@ private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) active := active prev := prevFun embed := embedFun + lnEps := lnEps + ln1Gamma := ln1GammaFun + ln1Beta := ln1BetaFun wq := wqFun + bq := bqFun wk := wkFun + bk := bkFun wv := wvFun + bv := 
bvFun wo := woFun + attnBias := attnBiasFun directionSpec := directionSpec direction := directionFun } diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean index 4607c7b..c81d0a8 100644 --- a/Nfp/Model/Gpt2.lean +++ b/Nfp/Model/Gpt2.lean @@ -39,12 +39,26 @@ structure Gpt2HeadSlice (seq dModel dHead vocab : Nat) where wpe : Fin seq → Fin dModel → Rat /-- Query projection weights. -/ wq : Fin dModel → Fin dHead → Rat + /-- Query projection bias. -/ + bq : Fin dHead → Rat /-- Key projection weights. -/ wk : Fin dModel → Fin dHead → Rat + /-- Key projection bias. -/ + bk : Fin dHead → Rat /-- Value projection weights. -/ wv : Fin dModel → Fin dHead → Rat + /-- Value projection bias. -/ + bv : Fin dHead → Rat /-- Output projection weights for this head slice. -/ wo : Fin dModel → Fin dHead → Rat + /-- Attention output bias (shared across heads). -/ + attnBias : Fin dModel → Rat + /-- LayerNorm epsilon for the attention input. -/ + lnEps : Rat + /-- LayerNorm scale for the attention input. -/ + ln1Gamma : Fin dModel → Rat + /-- LayerNorm bias for the attention input. -/ + ln1Beta : Fin dModel → Rat /-- Direction tokens for logit-diff certification. -/ direction : DirectionTokens vocab diff --git a/Nfp/Model/InductionHead.lean b/Nfp/Model/InductionHead.lean index d402d26..d90bf17 100644 --- a/Nfp/Model/InductionHead.lean +++ b/Nfp/Model/InductionHead.lean @@ -27,14 +27,28 @@ structure InductionHeadInputs (seq dModel dHead : Nat) where prev : Fin seq → Fin seq /-- Token embeddings for the sequence. -/ embed : Fin seq → Fin dModel → Rat + /-- LayerNorm epsilon used before attention. -/ + lnEps : Rat + /-- LayerNorm scale for pre-attention normalization. -/ + ln1Gamma : Fin dModel → Rat + /-- LayerNorm bias for pre-attention normalization. -/ + ln1Beta : Fin dModel → Rat /-- Query projection weights. -/ wq : Fin dModel → Fin dHead → Rat + /-- Query projection bias. -/ + bq : Fin dHead → Rat /-- Key projection weights. 
-/ wk : Fin dModel → Fin dHead → Rat + /-- Key projection bias. -/ + bk : Fin dHead → Rat /-- Value projection weights. -/ wv : Fin dModel → Fin dHead → Rat + /-- Value projection bias. -/ + bv : Fin dHead → Rat /-- Output projection weights (head slice). -/ wo : Fin dModel → Fin dHead → Rat + /-- Attention output bias (shared across heads). -/ + attnBias : Fin dModel → Rat /-- Logit-diff direction metadata. -/ directionSpec : DirectionSpec /-- Logit-diff direction vector in model space. -/ diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index 045a406..3d9eeab 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -28,10 +28,17 @@ def buildInductionHeadInputs {seq dModel dHead vocab : Nat} active := activeOfPeriod (seq := seq) period prev := prevOfPeriod (seq := seq) period embed := slice.embed + lnEps := slice.lnEps + ln1Gamma := slice.ln1Gamma + ln1Beta := slice.ln1Beta wq := slice.wq + bq := slice.bq wk := slice.wk + bk := slice.bk wv := slice.wv + bv := slice.bv wo := slice.wo + attnBias := slice.attnBias directionSpec := slice.direction.spec direction := slice.directionVec } @@ -43,10 +50,17 @@ theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} active := activeOfPeriod (seq := seq) period prev := prevOfPeriod (seq := seq) period embed := slice.embed + lnEps := slice.lnEps + ln1Gamma := slice.ln1Gamma + ln1Beta := slice.ln1Beta wq := slice.wq + bq := slice.bq wk := slice.wk + bk := slice.bk wv := slice.wv + bv := slice.bv wo := slice.wo + attnBias := slice.attnBias directionSpec := slice.direction.spec direction := slice.directionVec } := rfl diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index 19e7ab6..2a9a986 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -35,21 +35,24 @@ private opaque qVecVecOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := Vector.ofFn (fun q : Fin seq 
=> Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wq j d))) + Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wq j d) + + inputs.bq d)) /-- Cached key projections for head inputs (opaque to avoid kernel reduction). -/ private opaque kVecVecOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := Vector.ofFn (fun q : Fin seq => Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wk j d))) + Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wk j d) + + inputs.bk d)) /-- Cached value projections for head inputs (opaque to avoid kernel reduction). -/ private opaque vVecVecOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := Vector.ofFn (fun q : Fin seq => Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wv j d))) + Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wv j d) + + inputs.bv d)) /-- Cached attention scores for head inputs (opaque to avoid kernel reduction). -/ private opaque scoresVecOfInputs {seq dModel dHead : Nat} diff --git a/README.md b/README.md index f0dc56d..342e36a 100644 --- a/README.md +++ b/README.md @@ -212,10 +212,17 @@ direction active prev embed +ln_eps +ln1_gamma +ln1_beta wq +bq wk +bk wv +bv wo +attn_bias ``` All `direction`, `embed`, and projection matrices must be fully specified. If no `active` lines diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 403f99f..520ffda 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -12,10 +12,12 @@ It is intentionally brief and focused on the soundness boundary. path exists, but it currently uses only the unembedding direction and relies on an external **residual-interval certificate** (per-coordinate lower/upper bounds). 
- The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor - currently ignores LayerNorm and bias terms, so it is not end-to-end faithful. -- The `certify_head_model` path derives head inputs from the model binary in Lean, but it still - ignores attention biases and LayerNorm, and currently requires `head_dim` to be a perfect square - to represent the scale as an exact rational. + now includes attention projection biases and LayerNorm metadata, but the Lean-side computation + still ignores LayerNorm and the shared attention output bias. +- The `certify_head_model` path derives head inputs from the model binary in Lean, includes + attention projection biases, but still ignores LayerNorm and the shared attention output bias. + It currently requires `head_dim` to be a perfect square to represent the scale as an exact + rational. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. - There is no bridge theorem connecting certificate validity to a full circuit/model semantics statement (for example, a formal statement about logits under a transformer block stack). 
diff --git a/scripts/build_gpt2_head_inputs.py b/scripts/build_gpt2_head_inputs.py index 0065566..39f9878 100644 --- a/scripts/build_gpt2_head_inputs.py +++ b/scripts/build_gpt2_head_inputs.py @@ -99,36 +99,67 @@ def read_head_weights( hidden_dim: int, layer: int, head: int, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> Tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +]: target = (layer, head) wq = wk = wv = wo = None + bq = bk = bv = None + attn_bias = ln1_gamma = ln1_beta = None for layer_idx in range(num_layers): for head_idx in range(num_heads): wq_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - _ = read_f64(f, head_dim) # b_Q + bq_block = read_f64(f, head_dim) # b_Q wk_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - _ = read_f64(f, head_dim) # b_K + bk_block = read_f64(f, head_dim) # b_K wv_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - _ = read_f64(f, head_dim) # b_V + bv_block = read_f64(f, head_dim) # b_V wo_block = read_f64(f, head_dim * model_dim).reshape(head_dim, model_dim) if (layer_idx, head_idx) == target: wq = wq_block wk = wk_block wv = wv_block wo = wo_block + bq = bq_block + bk = bk_block + bv = bv_block # Skip per-layer non-head data. 
- skip_f64(f, model_dim) # attn_bias + attn_bias_block = read_f64(f, model_dim) # attn_bias skip_f64(f, model_dim * hidden_dim) # w_in skip_f64(f, hidden_dim) # b_in skip_f64(f, hidden_dim * model_dim) # w_out skip_f64(f, model_dim) # b_out - skip_f64(f, model_dim) # ln1_gamma - skip_f64(f, model_dim) # ln1_beta + ln1_gamma_block = read_f64(f, model_dim) # ln1_gamma + ln1_beta_block = read_f64(f, model_dim) # ln1_beta skip_f64(f, model_dim) # ln2_gamma skip_f64(f, model_dim) # ln2_beta - if wq is None or wk is None or wv is None or wo is None: + if layer_idx == layer: + attn_bias = attn_bias_block + ln1_gamma = ln1_gamma_block + ln1_beta = ln1_beta_block + if ( + wq is None + or wk is None + or wv is None + or wo is None + or bq is None + or bk is None + or bv is None + or attn_bias is None + or ln1_gamma is None + or ln1_beta is None + ): raise SystemExit("Failed to locate head weights.") - return wq, wk, wv, wo + return wq, bq, wk, bk, wv, bv, wo, attn_bias, ln1_gamma, ln1_beta def read_unembed_columns( @@ -159,9 +190,16 @@ def write_head_inputs( prev: np.ndarray, active: np.ndarray, wq: np.ndarray, + bq: np.ndarray, wk: np.ndarray, + bk: np.ndarray, wv: np.ndarray, + bv: np.ndarray, wo: np.ndarray, + attn_bias: np.ndarray, + ln_eps: Fraction, + ln1_gamma: np.ndarray, + ln1_beta: np.ndarray, direction_target: int, direction_negative: int, direction: np.ndarray, @@ -181,18 +219,31 @@ def write_head_inputs( for q in range(seq): for d in range(model_dim): f.write(f"embed {q} {d} {rat_to_str(rat_from_float_exact(float(embeddings[q, d])))}\n") + f.write(f"ln_eps {rat_to_str(ln_eps)}\n") + for d in range(model_dim): + f.write(f"ln1_gamma {d} {rat_to_str(rat_from_float_exact(float(ln1_gamma[d])))}\n") + for d in range(model_dim): + f.write(f"ln1_beta {d} {rat_to_str(rat_from_float_exact(float(ln1_beta[d])))}\n") for i in range(model_dim): for j in range(head_dim): f.write(f"wq {i} {j} {rat_to_str(rat_from_float_exact(float(wq[i, j])))}\n") + for j in range(head_dim): + 
f.write(f"bq {j} {rat_to_str(rat_from_float_exact(float(bq[j])))}\n") for i in range(model_dim): for j in range(head_dim): f.write(f"wk {i} {j} {rat_to_str(rat_from_float_exact(float(wk[i, j])))}\n") + for j in range(head_dim): + f.write(f"bk {j} {rat_to_str(rat_from_float_exact(float(bk[j])))}\n") for i in range(model_dim): for j in range(head_dim): f.write(f"wv {i} {j} {rat_to_str(rat_from_float_exact(float(wv[i, j])))}\n") + for j in range(head_dim): + f.write(f"bv {j} {rat_to_str(rat_from_float_exact(float(bv[j])))}\n") for i in range(model_dim): for j in range(head_dim): f.write(f"wo {i} {j} {rat_to_str(rat_from_float_exact(float(wo[i, j])))}\n") + for d in range(model_dim): + f.write(f"attn_bias {d} {rat_to_str(rat_from_float_exact(float(attn_bias[d])))}\n") f.write(f"direction-target {direction_target}\n") f.write(f"direction-negative {direction_negative}\n") for d in range(model_dim): @@ -234,7 +285,7 @@ def main() -> None: tokens = read_i32(f, seq_len) embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) - wq, wk, wv, wo_raw = read_head_weights( + wq, bq, wk, bk, wv, bv, wo_raw, attn_bias, ln1_gamma, ln1_beta = read_head_weights( f, num_layers, num_heads, @@ -271,6 +322,10 @@ def main() -> None: wo = wo_raw.T args.output.parent.mkdir(parents=True, exist_ok=True) + ln_eps_raw = header.get("layer_norm_eps") + if ln_eps_raw is None: + raise SystemExit("Missing layer_norm_eps in header.") + ln_eps = rat_from_float_exact(float(ln_eps_raw)) write_head_inputs( args.output, scale, @@ -279,9 +334,16 @@ def main() -> None: prev, active, wq, + bq, wk, + bk, wv, + bv, wo, + attn_bias, + ln_eps, + ln1_gamma, + ln1_beta, args.direction_target, args.direction_negative, direction, From 1f390eafef75d48235a33ad979cddbf4c95214ed Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 3 Jan 2026 21:14:19 +0100 Subject: [PATCH 104/244] Fix induction bounds proofs and docs --- CLAIMS.md | 5 +- Nfp/Cli.lean | 12 +- Nfp/IO.lean | 10 +- Nfp/IO/NfptPure.lean | 46 
+- Nfp/Model/InductionPrompt.lean | 23 + Nfp/Sound/Bounds/LayerNorm.lean | 405 +++++++++++++ Nfp/Sound/Bounds/MatrixNorm.lean | 106 ++++ Nfp/Sound/Induction.lean | 985 +++++++++++++++++++++---------- README.md | 6 +- SOUNDNESS_LIMITATIONS.md | 6 +- 10 files changed, 1283 insertions(+), 321 deletions(-) create mode 100644 Nfp/Sound/Bounds/LayerNorm.lean diff --git a/CLAIMS.md b/CLAIMS.md index 234e04c..a725704 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -24,8 +24,9 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri and verifies the resulting certificates. - `nfp induction certify_head` recomputes scores/values from exact head inputs and verifies the resulting induction certificate (experimental, potentially slow). -- `nfp induction certify_head_model` reads a model binary, derives head inputs in Lean, and - verifies the resulting induction certificate (includes attention projection biases). +- `nfp induction certify_head_model` reads a model binary, derives head inputs in Lean, + and verifies the resulting induction certificate (includes attention projection biases + and derives `prev`/active from the stored token sequence by default). - `nfp induction certify_end_to_end` composes a head-level logit-diff lower bound with a downstream error certificate (arithmetic consistency only). - `nfp induction certify_end_to_end_matrix` computes a downstream bound from a matrix payload diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 7d6d152..4ae6d36 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -197,14 +197,14 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do let modelPath := p.flag! "model" |>.as! String let layer := p.flag! "layer" |>.as! Nat let head := p.flag! "head" |>.as! Nat - let period := p.flag! "period" |>.as! Nat + let period? := (p.flag? "period").map (·.as! Nat) let dirTarget := p.flag! "direction-target" |>.as! Nat let dirNegative := p.flag! "direction-negative" |>.as! Nat let minActive? := (p.flag? 
"min-active").map (·.as! Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertifyHeadModel modelPath layer head period dirTarget dirNegative + IO.runInductionCertifyHeadModel modelPath layer head dirTarget dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? /-- `nfp induction certify_head_model` subcommand. -/ @@ -215,7 +215,7 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| model : String; "Path to the NFP_BINARY_V1 model file." layer : Nat; "Layer index for the induction head." head : Nat; "Head index for the induction head." - period : Nat; "Prompt period used to derive active/prev." + period : Nat; "Optional prompt period override (default: derive from tokens)." "direction-target" : Nat; "Target token id for logit-diff direction." "direction-negative" : Nat; "Negative token id for logit-diff direction." "min-active" : Nat; "Optional minimum number of active queries required \ @@ -246,11 +246,11 @@ def runInductionHeadIntervalModel (p : Parsed) : IO UInt32 := do let modelPath := p.flag! "model" |>.as! String let layer := p.flag! "layer" |>.as! Nat let head := p.flag! "head" |>.as! Nat - let period := p.flag! "period" |>.as! Nat + let period? := (p.flag? "period").map (·.as! Nat) let dirTarget := p.flag! "direction-target" |>.as! Nat let dirNegative := p.flag! "direction-negative" |>.as! Nat let outPath? := (p.flag? "out").map (·.as! String) - IO.runInductionHeadIntervalModel modelPath layer head period dirTarget dirNegative outPath? + IO.runInductionHeadIntervalModel modelPath layer head dirTarget dirNegative period? outPath? /-- `nfp induction head_interval_model` subcommand. -/ def inductionHeadIntervalModelCmd : Cmd := `[Cli| @@ -260,7 +260,7 @@ def inductionHeadIntervalModelCmd : Cmd := `[Cli| model : String; "Path to the NFP_BINARY_V1 model file." 
layer : Nat; "Layer index for the induction head." head : Nat; "Head index for the induction head." - period : Nat; "Prompt period used to derive active/prev." + period : Nat; "Optional prompt period override (default: derive from tokens)." "direction-target" : Nat; "Target token id for logit-diff direction." "direction-negative" : Nat; "Negative token id for logit-diff direction." out : String; "Optional path to write the residual-interval certificate." diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 34d57b4..b83dd8f 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -808,7 +808,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let tol := cert.eps * (cert.values.hi - cert.values.lo) let logitDiffLB? := Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - cert.values.lo cert.values.hi cert.values.vals + cert.values.lo cert.values.hi cert.values.valsLo let effectiveMinLogitDiff := match minLogitDiff? with | some v => some v @@ -869,7 +869,7 @@ def runInductionCertifyHead (inputsPath : System.FilePath) /-- Build and check induction certificates from a model binary. -/ def runInductionCertifyHeadModel (modelPath : System.FilePath) - (layer head period dirTarget dirNegative : Nat) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? @@ -896,7 +896,7 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) | Except.ok ⟨header, start⟩ => match NfptPure.readInductionHeadInputs - data start header layer head period dirTarget dirNegative + data start header layer head dirTarget dirNegative period? with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -917,7 +917,7 @@ def runInductionHeadInterval (inputsPath : System.FilePath) /-- Build head-output interval bounds from a model binary. 
-/ def runInductionHeadIntervalModel (modelPath : System.FilePath) - (layer head period dirTarget dirNegative : Nat) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (outPath? : Option System.FilePath) : IO UInt32 := do let data ← IO.FS.readBinFile modelPath match NfptPure.parseHeader data with @@ -927,7 +927,7 @@ def runInductionHeadIntervalModel (modelPath : System.FilePath) | Except.ok ⟨header, start⟩ => match NfptPure.readInductionHeadInputs - data start header layer head period dirTarget dirNegative + data start header layer head dirTarget dirNegative period? with | Except.error msg => IO.eprintln s!"error: {msg}" diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index 6119f2b..f2547fa 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -226,6 +226,15 @@ private def readNatLE (data : ByteArray) (off : Nat) (count : Nat) : Option Nat else none +private def readI32 (data : ByteArray) (off : Nat) : Option Int := do + let bits ← readNatLE data off 4 + let two31 := pow2 31 + let two32 := pow2 32 + if bits < two31 then + some (Int.ofNat bits) + else + some (Int.ofNat bits - Int.ofNat two32) + private def readF64Rat (data : ByteArray) (off : Nat) : Option Rat := do let bits ← readNatLE data off 8 ratOfFloatBits bits @@ -276,6 +285,17 @@ private def readF64List (data : ByteArray) (off : Nat) (count : Nat) : return ⟨v :: rest.1, by simp [rest.2]⟩ | none => throw s!"invalid f64 at offset {off}" +private def readI32List (data : ByteArray) (off : Nat) (count : Nat) : + Except String {xs : List Int // xs.length = count} := do + match count with + | 0 => return ⟨[], rfl⟩ + | Nat.succ n => + match readI32 data off with + | some v => + let rest ← readI32List data (off + bytesI32 1) n + return ⟨v :: rest.1, by simp [rest.2]⟩ + | none => throw s!"invalid i32 at offset {off}" + private def readF64Matrix (data : ByteArray) (off : Nat) (rows cols : Nat) : Except String (Fin rows → Fin cols → Rat) := do let count := rows * cols @@ -323,6 +343,19 
@@ def readEmbeddings (data : ByteArray) (start : Nat) (h : NfptHeader) : let base := start + bytesI32 h.seqLen readF64Matrix data base h.seqLen h.modelDim +/-- Read input token ids stored in the binary. -/ +def readTokens (data : ByteArray) (start : Nat) (h : NfptHeader) : + Except String (Fin h.seqLen → Nat) := do + let ⟨vals, hlen⟩ ← readI32List data start h.seqLen + let ok := vals.all (fun z => decide (0 ≤ z)) + if !ok then + throw "token ids must be nonnegative" + let hlen' : vals.length = h.seqLen := by + simpa using hlen + let tokens : Fin h.seqLen → Nat := fun i => + Int.toNat (vals.get ⟨i.val, lt_of_lt_of_eq i.isLt hlen'.symm⟩) + return tokens + private def headOffset (h : NfptHeader) (layer head : Nat) : Nat := bytesI32 h.seqLen + bytesF64 (f64CountBeforeHeads h + @@ -412,9 +445,10 @@ def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : N /-- Read induction-head inputs directly from the model binary. -/ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) - (layer head period dirTarget dirNegative : Nat) : + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) : Except String (Model.InductionHeadInputs h.seqLen h.modelDim h.headDim) := do let scale ← scaleOfHeadDim h.headDim + let tokens ← readTokens data start h let embed ← readEmbeddings data start h let weights ← readHeadWeights data start h layer head let (attnBias, ln1Gamma, ln1Beta) ← readLayerAttnBiasLn1 data start h layer @@ -423,8 +457,14 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) let direction : Fin h.modelDim → Rat := fun i => colTarget i - colNegative i let directionSpec : Circuit.DirectionSpec := { target := dirTarget, negative := dirNegative } - let active := Model.activeOfPeriod (seq := h.seqLen) period - let prev := Model.prevOfPeriod (seq := h.seqLen) period + let active := + match period? 
with + | some period => Model.activeOfPeriod (seq := h.seqLen) period + | none => Model.activeOfTokens (seq := h.seqLen) tokens + let prev := + match period? with + | some period => Model.prevOfPeriod (seq := h.seqLen) period + | none => Model.prevOfTokens (seq := h.seqLen) tokens pure { scale := scale active := active diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index 8c5ff61..df0ac2f 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -1,5 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Data.Finset.Max import Mathlib.Data.Fintype.Basic /-! @@ -27,6 +28,28 @@ theorem mem_activeOfPeriod {seq : Nat} {period : Nat} {q : Fin seq} : q ∈ activeOfPeriod (seq := seq) period ↔ period ≤ q.val := by simp [activeOfPeriod] +/-- `prev` map induced by token repeats (defaulting to `0` when no prior match exists). -/ +def prevOfTokens {seq : Nat} (tokens : Fin seq → Nat) (q : Fin seq) : Fin seq := by + classical + let hpos : 0 < seq := lt_of_le_of_lt (Nat.zero_le _) q.isLt + let zero : Fin seq := ⟨0, hpos⟩ + let candidates : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun k => + k.val < q.val ∧ tokens k = tokens q) + by_cases h : candidates.Nonempty + · exact Finset.max' candidates h + · exact zero + +/-- Active queries induced by token repeats. -/ +def activeOfTokens {seq : Nat} (tokens : Fin seq → Nat) : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun q => + ∃ k, k.val < q.val ∧ tokens k = tokens q) + +/-- Membership characterization for `activeOfTokens`. 
-/ +theorem mem_activeOfTokens {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} : + q ∈ activeOfTokens tokens ↔ ∃ k, k.val < q.val ∧ tokens k = tokens q := by + simp [activeOfTokens] + end Model end Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean new file mode 100644 index 0000000..2cc9aae --- /dev/null +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -0,0 +1,405 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.Nat.Sqrt +import Mathlib.Data.Real.Sqrt +import Mathlib.Data.Rat.Cast.Order + +/-! +LayerNorm interval bounds for exact rational inputs. + +This module computes rational interval bounds for LayerNorm outputs and proves +those bounds sound for real-valued LayerNorm semantics. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +/-- Mean of a finite vector (defaults to `0` when `n = 0`). -/ +def mean {n : Nat} (x : Fin n → Rat) : Rat := + if n = 0 then + 0 + else + (∑ i, x i) / n + +/-- Unfold `mean` when `n ≠ 0`. -/ +theorem mean_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + mean x = (∑ i, x i) / n := by + simp [mean, h] + +/-- Variance of a finite vector (defaults to `0` when `n = 0`). -/ +def variance {n : Nat} (x : Fin n → Rat) : Rat := + if n = 0 then + 0 + else + let μ := mean x + (∑ i, (x i - μ) ^ 2) / n + +/-- Unfold `variance` when `n ≠ 0`. -/ +theorem variance_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + variance x = + let μ := mean x + (∑ i, (x i - μ) ^ 2) / n := by + simp [variance, h] + +/-- Variance is nonnegative when `n ≠ 0`. 
-/ +theorem variance_nonneg {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + 0 ≤ variance x := by + classical + have hsum : 0 ≤ ∑ i, (x i - mean x) ^ 2 := by + refine Finset.sum_nonneg ?_ + intro i _ + exact sq_nonneg (x i - mean x) + have hden : 0 ≤ (n : Rat) := by + exact_mod_cast (Nat.zero_le n) + have hdiv : 0 ≤ (∑ i, (x i - mean x) ^ 2) / n := + div_nonneg hsum hden + simpa [variance_def x h] using hdiv + +/-- Rational lower bound for a square root. -/ +def sqrtLower (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt num + let b := Nat.sqrt den + (a : Rat) / (b + 1 : Rat) + +/-- Rational upper bound for a square root. -/ +def sqrtUpper (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt num + let b := Nat.sqrt den + (a + 1 : Rat) / (b : Rat) + +/-- `sqrtLower` is nonnegative. -/ +theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by + classical + unfold sqrtLower + have hden : 0 ≤ (Nat.sqrt q.den + 1 : Rat) := by + exact_mod_cast (Nat.zero_le _) + have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) := by + exact_mod_cast (Nat.zero_le _) + exact div_nonneg hnum hden + +/-! Strict positivity helpers. -/ + +/-- `sqrtLower` is positive when its input is positive. -/ +theorem sqrtLower_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLower q := by + classical + unfold sqrtLower + have hnum_pos : 0 < (Nat.sqrt q.num.natAbs : Rat) := by + have hnum_pos' : 0 < q.num.natAbs := by + have hnum : 0 < q.num := (Rat.num_pos (a := q)).2 hq + exact Int.natAbs_pos.mpr hnum.ne' + exact_mod_cast (Nat.sqrt_pos.2 hnum_pos') + have hden_pos : 0 < (Nat.sqrt q.den + 1 : Rat) := by + exact_mod_cast (Nat.succ_pos _) + exact div_pos hnum_pos hden_pos + +/-- `sqrtUpper` is nonnegative. 
-/ +theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by + classical + unfold sqrtUpper + have hden : 0 ≤ (Nat.sqrt q.den : Rat) := by + exact_mod_cast (Nat.zero_le _) + have hnum : 0 ≤ (Nat.sqrt q.num.natAbs + 1 : Rat) := by + exact_mod_cast (Nat.zero_le _) + exact div_nonneg hnum hden + +/-- `sqrtUpper` is always positive. -/ +theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by + classical + unfold sqrtUpper + have hnum_pos : 0 < (Nat.sqrt q.num.natAbs + 1 : Rat) := by + exact_mod_cast (Nat.succ_pos _) + have hden_pos : 0 < (Nat.sqrt q.den : Rat) := by + have hden : 0 < q.den := q.den_pos + exact_mod_cast (Nat.sqrt_pos.2 hden) + exact div_pos hnum_pos hden_pos + +/-- Square-root lower bound in reals. -/ +theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by + classical + -- Set up numerator/denominator witnesses. + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt num + set b : Nat := Nat.sqrt den + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hbpos : 0 < (b + 1 : Real) := by + exact_mod_cast (Nat.succ_pos b) + have hnum_le : (a ^ 2 : Real) ≤ num := by + exact_mod_cast (Nat.sqrt_le' num) + have hden_le : (den : Real) ≤ (b + 1) ^ 2 := by + exact_mod_cast (le_of_lt (Nat.lt_succ_sqrt' den)) + have hmul : (a ^ 2 : Real) * den ≤ (num : Real) * (b + 1) ^ 2 := by + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + have hnum_nonneg : 0 ≤ (num : Real) := by exact_mod_cast (Nat.zero_le num) + exact mul_le_mul hnum_le hden_le hden_nonneg hnum_nonneg + have hbpos2 : 0 < (b + 1 : Real) ^ 2 := by + nlinarith [hbpos] + have hdiv : (a ^ 2 : Real) / (b + 1) ^ 2 ≤ (num : Real) / den := by + exact (div_le_div_iff₀ hbpos2 hden_pos).2 hmul + have hpow : ((a : Real) / (b + 1 : Real)) ^ 2 = (a ^ 2 : Real) / (b + 1) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hq_cast : (q : Real) = (num : Real) / den := by + have hnum_nonneg : 0 ≤ 
q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by + simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) + have hnum_cast : (q.num : Real) = (num : Real) := by + exact_mod_cast hnum_eq.symm + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] + simpa [hnum_cast, den] using hq_rat + have hsq : ((a : Real) / (b + 1 : Real)) ^ 2 ≤ (q : Real) := by + simpa [hpow, hq_cast, den, num] using hdiv + have hnonneg : 0 ≤ (a : Real) / (b + 1 : Real) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (b + 1 : Real) := by exact_mod_cast (Nat.zero_le (b + 1)) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by exact_mod_cast hq + have hle : (a : Real) / (b + 1 : Real) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + simpa [sqrtLower, num, den, a, b] using hle + +/-- Square-root upper bound in reals. -/ +theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt num + set b : Nat := Nat.sqrt den + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hbpos : 0 < (b : Real) := by + have hb : 0 < b := by + have hden : 0 < den := q.den_pos + exact (Nat.sqrt_pos).2 hden + exact_mod_cast hb + have hnum_lt : (num : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' num) + have hden_le : (b ^ 2 : Real) ≤ den := by + exact_mod_cast (Nat.sqrt_le' den) + have hmul : (num : Real) * (b ^ 2) ≤ (a + 1) ^ 2 * den := by + have hb2_nonneg : 0 ≤ (b ^ 2 : Real) := by + exact sq_nonneg (b : Real) + have hsq_nonneg : 0 ≤ (a + 1 : Real) ^ 2 := by + exact sq_nonneg (a + 1 : Real) + exact mul_le_mul (le_of_lt hnum_lt) hden_le hb2_nonneg hsq_nonneg + have hbpos2 : 0 < (b : Real) ^ 2 := by + nlinarith [hbpos] + have hdiv : (num : Real) / den ≤ (a + 1) ^ 2 / (b : 
Real) ^ 2 := by + exact (div_le_div_iff₀ hden_pos hbpos2).2 hmul + have hpow : ((a + 1 : Real) / (b : Real)) ^ 2 = (a + 1) ^ 2 / (b : Real) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hq_cast : (q : Real) = (num : Real) / den := by + have hnum_nonneg : 0 ≤ q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by + simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) + have hnum_cast : (q.num : Real) = (num : Real) := by + exact_mod_cast hnum_eq.symm + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] + simpa [hnum_cast, den] using hq_rat + have hsq : (q : Real) ≤ ((a + 1 : Real) / (b : Real)) ^ 2 := by + simpa [hpow, hq_cast, den, num] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / (b : Real)) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (b : Real) := by exact_mod_cast (Nat.zero_le b) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (b : Real) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + simpa [sqrtUpper, num, den, a, b] using hle + +/-- Bounds for multiplying a scalar by a bounded value. -/ +def scaleInterval (x lo hi : Rat) : Rat × Rat := + if 0 ≤ x then + (x * lo, x * hi) + else + (x * hi, x * lo) + +/-- `scaleInterval` bounds a product. 
-/ +theorem scaleInterval_bounds {x lo hi y : Rat} + (hlo : lo ≤ y) (hhi : y ≤ hi) : + let bounds := scaleInterval x lo hi + bounds.1 ≤ x * y ∧ x * y ≤ bounds.2 := by + by_cases hx : 0 ≤ x + · have h1 : x * lo ≤ x * y := by + exact mul_le_mul_of_nonneg_left hlo hx + have h2 : x * y ≤ x * hi := by + exact mul_le_mul_of_nonneg_left hhi hx + simp [scaleInterval, hx, h1, h2] + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have h1 : x * hi ≤ x * y := by + exact mul_le_mul_of_nonpos_left hhi hx' + have h2 : x * y ≤ x * lo := by + exact mul_le_mul_of_nonpos_left hlo hx' + simp [scaleInterval, hx, h1, h2] + +/-- `scaleInterval` bounds interpreted in the reals. -/ +theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} + (hlo : (lo : Real) ≤ y) (hhi : y ≤ (hi : Real)) : + let bounds := scaleInterval x lo hi + (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by + by_cases hx : 0 ≤ x + · have h1 : (x : Real) * (lo : Real) ≤ (x : Real) * y := by + exact mul_le_mul_of_nonneg_left hlo (by exact_mod_cast hx) + have h2 : (x : Real) * y ≤ (x : Real) * (hi : Real) := by + exact mul_le_mul_of_nonneg_left hhi (by exact_mod_cast hx) + simp [scaleInterval, hx, h1, h2] + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have h1 : (x : Real) * (hi : Real) ≤ (x : Real) * y := by + exact mul_le_mul_of_nonpos_left hhi (by exact_mod_cast hx') + have h2 : (x : Real) * y ≤ (x : Real) * (lo : Real) := by + exact mul_le_mul_of_nonpos_left hlo (by exact_mod_cast hx') + simp [scaleInterval, hx, h1, h2] + +/-- Real-valued LayerNorm output for a vector. -/ +noncomputable def layerNormReal {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : Fin n → Real := + if n = 0 then + fun _ => 0 + else + let μ : Real := mean x + let varEps : Real := (variance x + eps : Rat) + let invStd : Real := (Real.sqrt varEps)⁻¹ + fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) + +/-- Interval bounds for LayerNorm outputs. 
-/ +def layerNormBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if n = 0 then + (fun _ => 0, fun _ => 0) + else + let μ := mean x + let var := variance x + let varEps := var + eps + let sLo := sqrtLower varEps + let sHi := sqrtUpper varEps + let invLo := sHi⁻¹ + let invHi := sLo⁻¹ + let normBounds : Fin n → Rat × Rat := fun i => + let centered := x i - μ + scaleInterval centered invLo invHi + let outBounds : Fin n → Rat × Rat := fun i => + let nb := normBounds i + let sb := scaleInterval (gamma i) nb.1 nb.2 + (sb.1 + beta i, sb.2 + beta i) + (fun i => (outBounds i).1, fun i => (outBounds i).2) + +/-- `layerNormBounds` soundness for real LayerNorm outputs. -/ +theorem layerNormBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) : + let bounds := layerNormBounds eps gamma beta x + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hvar_nonneg : 0 ≤ variance x := variance_nonneg x hne + have hvarEps_pos : 0 < variance x + eps := by + exact add_pos_of_nonneg_of_pos hvar_nonneg heps + have hvarEps_nonneg : 0 ≤ variance x + eps := by + exact le_of_lt hvarEps_pos + let varEps : Rat := variance x + eps + let sLo : Rat := sqrtLower varEps + let sHi : Rat := sqrtUpper varEps + let invLo : Rat := sHi⁻¹ + let invHi : Rat := sLo⁻¹ + let invStd : Real := (Real.sqrt (varEps : Real))⁻¹ + have hsLo : (sLo : Real) ≤ Real.sqrt (varEps : Real) := by + have hsLo' := sqrtLower_le_real_sqrt (q := varEps) hvarEps_nonneg + simpa [sLo, varEps, Rat.cast_add] using hsLo' + have hsHi : Real.sqrt (varEps : Real) ≤ (sHi : Real) := by + have hsHi' := real_sqrt_le_sqrtUpper (q := varEps) hvarEps_nonneg + simpa [sHi, varEps, Rat.cast_add] using hsHi' + have hsqrt_pos : 0 < Real.sqrt (varEps : Real) := by + exact Real.sqrt_pos.2 (by exact_mod_cast hvarEps_pos) + 
have hsLo_pos : 0 < (sLo : Real) := by + exact_mod_cast (sqrtLower_pos (q := varEps) hvarEps_pos) + have hsHi_ne : (sHi : Rat) ≠ 0 := ne_of_gt (sqrtUpper_pos varEps) + have hsLo_ne : (sLo : Rat) ≠ 0 := ne_of_gt (sqrtLower_pos (q := varEps) hvarEps_pos) + have hcast_inv_hi : (invLo : Real) = (sHi : Real)⁻¹ := by + have hnum_ne : (sHi.num : Real) ≠ 0 := by + exact_mod_cast (Rat.num_ne_zero (q := sHi)).2 hsHi_ne + have hcast := Rat.cast_inv_of_ne_zero (q := sHi) hnum_ne + dsimp [invLo] + exact hcast + have hcast_inv_lo : (invHi : Real) = (sLo : Real)⁻¹ := by + have hnum_ne : (sLo.num : Real) ≠ 0 := by + exact_mod_cast (Rat.num_ne_zero (q := sLo)).2 hsLo_ne + have hcast := Rat.cast_inv_of_ne_zero (q := sLo) hnum_ne + dsimp [invHi] + exact hcast + have hinv_lo : (invLo : Real) ≤ invStd := by + have hcalc : (sHi : Real)⁻¹ ≤ invStd := by + have h := one_div_le_one_div_of_le hsqrt_pos hsHi + simpa [one_div, invStd] using h + simpa [hcast_inv_hi] using hcalc + have hinv_hi : invStd ≤ (invHi : Real) := by + have hcalc : invStd ≤ (sLo : Real)⁻¹ := by + have h := one_div_le_one_div_of_le hsLo_pos hsLo + simpa [one_div, invStd] using h + simpa [hcast_inv_lo] using hcalc + let μ : Rat := mean x + let centered : Rat := x i - μ + let nb : Rat × Rat := scaleInterval centered invLo invHi + have hnb : (nb.1 : Real) ≤ (centered : Real) * invStd ∧ + (centered : Real) * invStd ≤ (nb.2 : Real) := by + have hscale := scaleInterval_bounds_real (x := centered) + (lo := invLo) (hi := invHi) (y := invStd) hinv_lo hinv_hi + simpa [nb] using hscale + let sb : Rat × Rat := scaleInterval (gamma i) nb.1 nb.2 + have hsb : + (sb.1 : Real) ≤ (gamma i : Real) * ((centered : Real) * invStd) ∧ + (gamma i : Real) * ((centered : Real) * invStd) ≤ (sb.2 : Real) := by + have hscale := scaleInterval_bounds_real (x := gamma i) + (lo := nb.1) (hi := nb.2) (y := (centered : Real) * invStd) hnb.1 hnb.2 + simpa [sb] using hscale + let lo : Rat := sb.1 + beta i + let hi : Rat := sb.2 + beta i + have hreal : + 
layerNormReal eps gamma beta x i = + (gamma i : Real) * ((centered : Real) * invStd) + (beta i : Real) := by + calc + layerNormReal eps gamma beta x i = + (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) := by + simp [layerNormReal, hne, μ, invStd, varEps] + _ = (gamma i : Real) * (((x i : Real) - μ) * invStd) + (beta i : Real) := by + simp [mul_assoc] + _ = (gamma i : Real) * ((centered : Real) * invStd) + (beta i : Real) := by + simp [centered] + have hlo : (lo : Real) ≤ layerNormReal eps gamma beta x i := by + have hlo' : (sb.1 : Real) ≤ (gamma i : Real) * ((centered : Real) * invStd) := hsb.1 + have hlo'' : (lo : Real) ≤ + (gamma i : Real) * ((centered : Real) * invStd) + (beta i : Real) := by + simpa [lo] using add_le_add_right hlo' (beta i : Real) + simpa [hreal] using hlo'' + have hhi : layerNormReal eps gamma beta x i ≤ (hi : Real) := by + have hhi' : (gamma i : Real) * ((centered : Real) * invStd) ≤ (sb.2 : Real) := hsb.2 + have hhi'' : + (gamma i : Real) * ((centered : Real) * invStd) + (beta i : Real) ≤ (hi : Real) := by + simpa [hi] using add_le_add_right hhi' (beta i : Real) + simpa [hreal] using hhi'' + simpa [bounds, layerNormBounds, hne, μ, varEps, invLo, invHi, centered, nb, sb, lo, hi] using + And.intro hlo hhi + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index fe58dca..9836486 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -4,6 +4,9 @@ import Mathlib.Algebra.Order.BigOperators.Group.Finset import Mathlib.Algebra.Order.Ring.Abs import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.Matrix.Mul +import Mathlib.Data.Rat.BigOperators +import Mathlib.Data.Rat.Cast.Order +import Mathlib.Data.Real.Basic import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Cert.ResidualInterval @@ -213,6 +216,109 @@ theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → R unfold dotIntervalAbsBound exact habs 
+/-! Real-valued bounds from rational intervals. -/ + +theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) + (x : Fin n → Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := by + classical + have hcast : + (dotIntervalLower v lo hi : Real) = + ∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real) := by + conv_lhs => simp [dotIntervalLower] + refine Finset.sum_congr rfl ?_ + intro j _ + by_cases hv : 0 ≤ v j + · simp [hv] + · simp [hv] + have hsum : + (∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real)) ≤ + ∑ j, (v j : Real) * x j := by + refine Finset.sum_le_sum ?_ + intro j _ + by_cases hv : 0 ≤ v j + · have h1 : (v j : Real) * (lo j : Real) ≤ (v j : Real) * x j := by + exact mul_le_mul_of_nonneg_left (hlo j) (by exact_mod_cast hv) + simpa [hv] using h1 + · have hv' : (v j : Real) ≤ 0 := by + exact_mod_cast (le_of_lt (lt_of_not_ge hv)) + have h1 : (v j : Real) * (hi j : Real) ≤ (v j : Real) * x j := by + exact mul_le_mul_of_nonpos_left (hhi j) hv' + simpa [hv] using h1 + simpa [hcast, dotProduct] using hsum + +theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) + (x : Fin n → Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := by + classical + have hcast : + (dotIntervalUpper v lo hi : Real) = + ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by + conv_lhs => simp [dotIntervalUpper] + refine Finset.sum_congr rfl ?_ + intro j _ + by_cases hv : 0 ≤ v j + · simp [hv] + · simp [hv] + have hsum : + ∑ j, (v j : Real) * x j ≤ + ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by + refine Finset.sum_le_sum ?_ + intro j _ + by_cases hv : 0 ≤ v j + · have h1 : (v j : Real) * x j ≤ (v j 
: Real) * (hi j : Real) := by + exact mul_le_mul_of_nonneg_left (hhi j) (by exact_mod_cast hv) + simpa [hv] using h1 + · have hv' : (v j : Real) ≤ 0 := by + exact_mod_cast (le_of_lt (lt_of_not_ge hv)) + have h1 : (v j : Real) * x j ≤ (v j : Real) * (lo j : Real) := by + exact mul_le_mul_of_nonpos_left (hlo j) hv' + simpa [hv] using h1 + simpa [hcast, dotProduct] using hsum + +theorem abs_le_max_abs_abs_of_interval_real {a b x : Real} (hlo : a ≤ x) (hhi : x ≤ b) : + |x| ≤ max |a| |b| := by + by_cases hx : 0 ≤ x + · have hb : 0 ≤ b := le_trans hx hhi + have hx' : |x| = x := abs_of_nonneg hx + have hb' : |b| = b := abs_of_nonneg hb + calc + |x| = x := hx' + _ ≤ b := hhi + _ = |b| := hb'.symm + _ ≤ max |a| |b| := le_max_right _ _ + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have ha : a ≤ 0 := le_trans hlo hx' + have hxabs : |x| = -x := abs_of_nonpos hx' + have haabs : |a| = -a := abs_of_nonpos ha + calc + |x| = -x := hxabs + _ ≤ -a := neg_le_neg hlo + _ = |a| := by simp [haabs] + _ ≤ max |a| |b| := le_max_left _ _ + +theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Rat) + (x : Fin n → Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + |dotProduct (fun j => (v j : Real)) x| ≤ (dotIntervalAbsBound v lo hi : Real) := by + have hlow : + (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := + dotIntervalLower_le_dotProduct_real v lo hi x hlo hhi + have hhigh : + dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := + dotProduct_le_dotIntervalUpper_real v lo hi x hlo hhi + have habs : + |dotProduct (fun j => (v j : Real)) x| ≤ + max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := + abs_le_max_abs_abs_of_interval_real hlow hhigh + have hcast : + (dotIntervalAbsBound v lo hi : Real) = + max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := by + simp [dotIntervalAbsBound] + simpa [hcast] using habs + /-- 
Matrix-interval lower bounds dominate matrix-vector products. -/ theorem mulVecIntervalLower_le_mulVec {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index 2a9a986..08492f5 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.BigOperators.Group.Finset import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.Finset.Lattice.Fold import Mathlib.Data.Rat.Cast.Order @@ -10,6 +11,8 @@ import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Cert.ValueRange import Nfp.Circuit.Layers.Softmax import Nfp.Model.InductionHead +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Linear.FinFold /-! @@ -27,67 +30,84 @@ namespace Sound open scoped BigOperators open Nfp.Circuit +open Nfp.Sound.Bounds variable {seq : Nat} -/-- Cached query projections for head inputs (opaque to avoid kernel reduction). -/ -private opaque qVecVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := - Vector.ofFn (fun q : Fin seq => - Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wq j d) + - inputs.bq d)) - -/-- Cached key projections for head inputs (opaque to avoid kernel reduction). -/ -private opaque kVecVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := - Vector.ofFn (fun q : Fin seq => - Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wk j d) + - inputs.bk d)) - -/-- Cached value projections for head inputs (opaque to avoid kernel reduction). 
-/ -private opaque vVecVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector (Vector Rat dHead) seq := - Vector.ofFn (fun q : Fin seq => - Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.embed q j) (fun j => inputs.wv j d) + - inputs.bv d)) - -/-- Cached attention scores for head inputs (opaque to avoid kernel reduction). -/ -private opaque scoresVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qVecVec kVecVec : Vector (Vector Rat dHead) seq) : Vector (Vector Rat seq) seq := - Vector.ofFn (fun q : Fin seq => - Vector.ofFn (fun k : Fin seq => - let qVec : Fin dHead → Rat := fun d => (qVecVec.get q).get d - let kVec : Fin dHead → Rat := fun d => (kVecVec.get k).get d - inputs.scale * (Linear.dotFin dHead (fun d => qVec d) (fun d => kVec d)))) - -/-- Cached direction head for head inputs (opaque to avoid kernel reduction). -/ -private opaque dirHeadVecOfInputs {seq dModel dHead : Nat} +/-- Cached direction head for head inputs. -/ +private def dirHeadVecOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := Vector.ofFn (fun d : Fin dHead => Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) -/-- Cached value projections for head inputs (opaque to avoid kernel reduction). -/ -private opaque valsVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vVecVec : Vector (Vector Rat dHead) seq) (dirHeadVec : Vector Rat dHead) : - Vector Rat seq := - Vector.ofFn (fun k : Fin seq => - let vVec : Fin dHead → Rat := fun d => (vVecVec.get k).get d - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - Linear.dotFin dHead (fun d => vVec d) (fun d => dirHead d)) - -/-- Cached per-key head outputs in model space (opaque to avoid kernel reduction). 
-/ -private opaque headValueVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vVecVec : Vector (Vector Rat dHead) seq) : Vector (Vector Rat dModel) seq := - Vector.ofFn (fun k : Fin seq => - Vector.ofFn (fun i : Fin dModel => - let vVec : Fin dHead → Rat := fun d => (vVecVec.get k).get d - Linear.dotFin dHead (fun d => vVec d) (fun d => inputs.wo i d))) +/-- Real-valued LayerNorm outputs for head inputs. -/ +private noncomputable def lnRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := + fun q => + Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) + +/-- Real-valued query projections for head inputs. -/ +private noncomputable def qRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wq j d : Real)) (lnRealOfInputs inputs q) + (inputs.bq d : Real) + +/-- Real-valued key projections for head inputs. -/ +private noncomputable def kRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wk j d : Real)) (lnRealOfInputs inputs q) + (inputs.bk d : Real) + +/-- Real-valued value projections for head inputs. -/ +private noncomputable def vRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + (inputs.bv d : Real) + +/-- Real-valued attention scores for head inputs. 
-/ +private noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin seq → Real := + fun q k => + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + +/-- Real-valued per-key head outputs in model space. -/ +private noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := + fun k i => + dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) + +/-- Real-valued direction scores for head inputs. -/ +private noncomputable def valsRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Real := + let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d + fun k => dotProduct dirHead (fun d => vRealOfInputs inputs k d) + +/-- Interval data for direction values. -/ +structure ValueInterval (seq : Nat) where + /-- Lower bound for values. -/ + lo : Rat + /-- Upper bound for values. -/ + hi : Rat + /-- Lower bounds on per-key values. -/ + valsLo : Fin seq → Rat + /-- Upper bounds on per-key values. -/ + valsHi : Fin seq → Rat + /-- Optional logit-diff direction metadata (ignored by the checker). -/ + direction : Option DirectionSpec + +/-- Soundness predicate for direction-value interval data. -/ +structure ValueIntervalBounds {seq : Nat} (vals : Fin seq → Real) + (c : ValueInterval seq) : Prop where + /-- Interval endpoints are ordered. -/ + lo_le_hi : c.lo ≤ c.hi + /-- `lo` is below every lower bound. -/ + lo_le_valsLo : ∀ k, (c.lo : Real) ≤ (c.valsLo k : Real) + /-- Bounds sandwich the real values. -/ + vals_bounds : + ∀ k, (c.valsLo k : Real) ≤ vals k ∧ vals k ≤ (c.valsHi k : Real) + /-- `hi` is above every upper bound. -/ + valsHi_le_hi : ∀ k, (c.valsHi k : Real) ≤ (c.hi : Real) /-- Sound induction-certificate payload built from exact head inputs. 
-/ structure InductionHeadCert (seq : Nat) where @@ -99,22 +119,21 @@ structure InductionHeadCert (seq : Nat) where active : Finset (Fin seq) /-- `prev` selector for induction-style attention. -/ prev : Fin seq → Fin seq - /-- Score matrix entries. -/ - scores : Fin seq → Fin seq → Rat - /-- Value-range certificate for the direction values. -/ - values : ValueRangeCert seq + /-- Value-interval certificate for the direction values. -/ + values : ValueInterval seq /-- Soundness predicate for `InductionHeadCert`. -/ -structure InductionHeadCertSound [NeZero seq] (c : InductionHeadCert seq) : Prop where +structure InductionHeadCertSound [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) : Prop where /-- Softmax weights respect the derived margin bounds. -/ softmax_bounds : Layers.SoftmaxMarginBoundsOn (Val := Real) (c.eps : Real) (c.margin : Real) (fun q => q ∈ c.active) c.prev - (fun q k => (c.scores q k : Real)) - (fun q k => Circuit.softmax (fun j => (c.scores q j : Real)) k) - /-- Value-range bounds hold for the certificate values. -/ + (scoresRealOfInputs inputs) + (fun q k => Circuit.softmax (scoresRealOfInputs inputs q) k) + /-- Interval bounds hold for the direction values. -/ value_bounds : - Layers.ValueRangeBounds (Val := Rat) c.values.lo c.values.hi c.values.vals + ValueIntervalBounds (vals := valsRealOfInputs inputs) c.values /-- Build and certify a softmax-margin certificate from exact scores/weights. -/ def buildSoftmaxMarginCert? [NeZero seq] @@ -189,58 +208,228 @@ def buildValueRangeCert? [NeZero seq] /-- Build and certify induction certificates from exact head inputs. -/ def buildInductionCertFromHead? 
[NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option {c : InductionHeadCert seq // InductionHeadCertSound c} := by + Option {c : InductionHeadCert seq // InductionHeadCertSound inputs c} := by classical - let qVecVec := qVecVecOfInputs inputs - let kVecVec := kVecVecOfInputs inputs - let vVecVec := vVecVecOfInputs inputs - let scoresVec := scoresVecOfInputs inputs qVecVec kVecVec - let scores : Fin seq → Fin seq → Rat := fun q k => - (scoresVec.get q).get k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let marginAt : Fin seq → Rat := fun q => - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scores q (inputs.prev q) - scores q k) - else - (0 : Rat) - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - (seq - 1 : Rat) / (1 + margin) - have hseq : (1 : Nat) ≤ seq := - Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) - let dirHeadVec := dirHeadVecOfInputs inputs - let valsVec := valsVecOfInputs inputs vVecVec dirHeadVec - let vals : Fin seq → Rat := fun k => valsVec.get k - exact - match buildValueRangeCert? 
vals (some inputs.directionSpec) with - | none => none - | some ⟨valCert, hval⟩ => + by_cases hEps : 0 < inputs.lnEps + · by_cases hmodel : dModel = 0 + · exact none + · let lnBounds : Fin seq → (Fin dModel → Rat) × (Fin dModel → Rat) := fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) + let lnLo : Fin seq → Fin dModel → Rat := fun q => (lnBounds q).1 + let lnHi : Fin seq → Fin dModel → Rat := fun q => (lnBounds q).2 + let qLo : Fin seq → Fin dHead → Rat := fun q d => + dotIntervalLower (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d + let qHi : Fin seq → Fin dHead → Rat := fun q d => + dotIntervalUpper (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d + let kLo : Fin seq → Fin dHead → Rat := fun q d => + dotIntervalLower (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d + let kHi : Fin seq → Fin dHead → Rat := fun q d => + dotIntervalUpper (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d + let vLo : Fin seq → Fin dHead → Rat := fun q d => + dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d + let vHi : Fin seq → Fin dHead → Rat := fun q d => + dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d + let qAbs : Fin seq → Fin dHead → Rat := fun q d => + max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => + max |kLo q d| |kHi q d| + let dotAbs : Fin seq → Fin seq → Rat := fun q k => + dotProduct (fun d => qAbs q d) (fun d => kAbs k d) + let scoreAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + -scoreAbs q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + scoreAbs q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let marginAt : Fin seq → Rat := fun q => + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreLo q (inputs.prev q) - scoreHi q 
k) + else + (0 : Rat) + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + (seq - 1 : Rat) / (1 + margin) + have hseq : (1 : Nat) ≤ seq := + Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) + let dirHeadVec := dirHeadVecOfInputs inputs + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let valsLo : Fin seq → Rat := fun k => + dotIntervalLower dirHead (vLo k) (vHi k) + let valsHi : Fin seq → Rat := fun k => + dotIntervalUpper dirHead (vLo k) (vHi k) + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + let valCert : ValueInterval seq := + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec } let cert : InductionHeadCert seq := { eps := eps margin := margin active := inputs.active prev := inputs.prev - scores := scores values := valCert } - have hvalues : Layers.ValueRangeBounds (Val := Rat) valCert.lo valCert.hi valCert.vals := - Circuit.checkValueRangeCert_sound valCert hval - let scoresReal : Fin seq → Fin seq → Real := fun q k => (scores q k : Real) + have hln_bounds : + ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ + lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by + intro q i + have hln := + Bounds.layerNormBounds_spec (eps := inputs.lnEps) + (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) + (x := inputs.embed q) hmodel hEps + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs] using hln i + have hq_bounds : + ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + intro q d + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 + have hlow := + 
dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wq j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wq j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hlow' := add_le_add_right hlow (inputs.bq d : Real) + have hhigh' := add_le_add_right hhigh (inputs.bq d : Real) + constructor + · simpa [qLo, qRealOfInputs, Rat.cast_add] using hlow' + · simpa [qHi, qRealOfInputs, Rat.cast_add] using hhigh' + have hk_bounds : + ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ (kHi q d : Real) := by + intro q d + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wk j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wk j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hlow' := add_le_add_right hlow (inputs.bk d : Real) + have hhigh' := add_le_add_right hhigh (inputs.bk d : Real) + constructor + · simpa [kLo, kRealOfInputs, Rat.cast_add] using hlow' + · simpa [kHi, kRealOfInputs, Rat.cast_add] using hhigh' + have hv_bounds : + ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ + vRealOfInputs inputs q d ≤ (vHi q d : Real) := by + intro q d + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => 
inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hlow' := add_le_add_right hlow (inputs.bv d : Real) + have hhigh' := add_le_add_right hhigh (inputs.bv d : Real) + constructor + · simpa [vLo, vRealOfInputs, Rat.cast_add] using hlow' + · simpa [vHi, vRealOfInputs, Rat.cast_add] using hhigh' + have hscore_bounds : + ∀ q k, (scoreLo q k : Real) ≤ scoresRealOfInputs inputs q k ∧ + scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by + intro q k + let scoresReal := scoresRealOfInputs inputs + have hq_abs : ∀ d, |qRealOfInputs inputs q d| ≤ (qAbs q d : Real) := by + intro d + have hq := hq_bounds q d + have h := abs_le_max_abs_abs_of_interval_real hq.1 hq.2 + simpa [qAbs] using h + have hk_abs : ∀ d, |kRealOfInputs inputs k d| ≤ (kAbs k d : Real) := by + intro d + have hk := hk_bounds k d + have h := abs_le_max_abs_abs_of_interval_real hk.1 hk.2 + simpa [kAbs] using h + have hdot_abs : + |dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d)| ≤ + (dotAbs q k : Real) := by + have hsum : + |∑ d, qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ + ∑ d, |qRealOfInputs inputs q d * kRealOfInputs inputs k d| := by + simpa [dotProduct] using + (Finset.abs_sum_le_sum_abs (s := (Finset.univ : Finset (Fin dHead))) + (f := fun d => qRealOfInputs inputs q d * kRealOfInputs inputs k d)) + have hterm : + ∀ d, + |qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ + (qAbs q d : Real) * (kAbs k d : Real) := by + intro d + have hq := hq_abs d + have hk := hk_abs d + have hqnonneg : 0 ≤ (qAbs q d : Real) := by + have hqnonneg' : 0 ≤ qAbs q d := by + have h1 : 0 ≤ |qLo q d| := abs_nonneg (qLo q d) + exact le_trans h1 (le_max_left _ _) + exact_mod_cast hqnonneg' + calc + |qRealOfInputs inputs q d * kRealOfInputs inputs k d| = + |qRealOfInputs inputs q d| * |kRealOfInputs inputs k d| := by + simp [abs_mul] + _ ≤ (qAbs q d : Real) * (kAbs k d : Real) := + mul_le_mul hq hk (abs_nonneg _) hqnonneg + have 
hsum_le : + ∑ d, |qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ + ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by + refine Finset.sum_le_sum ?_ + intro d _ + exact hterm d + have hcast : + (dotAbs q k : Real) = + ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by + simp [dotAbs, dotProduct] + have hfinal := hsum.trans (hsum_le.trans_eq hcast.symm) + simpa [dotProduct] using hfinal + have hscale_abs : 0 ≤ (|inputs.scale| : Real) := by + exact_mod_cast (abs_nonneg (inputs.scale)) + have hscore_abs : + |scoresReal q k| ≤ (scoreAbs q k : Real) := by + have hdot_abs' := hdot_abs + have hmul : + |scoresReal q k| = + (|inputs.scale| : Real) * + |dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)| := by + simp [scoresReal, scoresRealOfInputs, abs_mul] + have hmul_le : + (|inputs.scale| : Real) * + |dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)| ≤ + (|inputs.scale| : Real) * (dotAbs q k : Real) := by + exact mul_le_mul_of_nonneg_left hdot_abs' hscale_abs + simpa [scoreAbs, hmul] using hmul_le + have hscore_bounds := (abs_le).1 hscore_abs + constructor + · simpa [scoresReal, scoreLo] using hscore_bounds.1 + · simpa [scoresReal, scoreHi] using hscore_bounds.2 + let scoresReal := scoresRealOfInputs inputs let weights : Fin seq → Fin seq → Real := fun q k => Circuit.softmax (scoresReal q) k let others : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - have hscore_margin : + have hscore_margin_real : ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scores q k + margin ≤ scores q (inputs.prev q) := by + scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by intro q hq k hk by_cases hactive : inputs.active.Nonempty · have hmargin_le : margin ≤ marginAt q := by @@ -252,31 +441,51 @@ def buildInductionCertFromHead? 
[NeZero seq] {dModel dHead : Nat} exact hle_all q hq have hother : (otherKeys q).Nonempty := ⟨k, by simp [otherKeys, hk]⟩ have hgap_le : - marginAt q ≤ scores q (inputs.prev q) - scores q k := by + marginAt q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by have hle : marginAt q ≤ (otherKeys q).inf' hother - (fun k => scores q (inputs.prev q) - scores q k) := by + (fun k => scoreLo q (inputs.prev q) - scoreHi q k) := by simp [marginAt, hother] have hle_all := (Finset.le_inf'_iff (s := otherKeys q) (H := hother) - (f := fun k => scores q (inputs.prev q) - scores q k) + (f := fun k => scoreLo q (inputs.prev q) - scoreHi q k) (a := marginAt q)).1 hle exact hle_all k (by simp [otherKeys, hk]) - have hgap : margin ≤ scores q (inputs.prev q) - scores q k := + have hgap : margin ≤ scoreLo q (inputs.prev q) - scoreHi q k := le_trans hmargin_le hgap_le - have hgap' := - add_le_add_left hgap (scores q k) - simpa [sub_eq_add_neg, add_assoc, add_left_comm, add_comm] using hgap' + have hgap_real : (margin : Real) ≤ + (scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real) := by + have hgap_real' : + (margin : Real) ≤ ((scoreLo q (inputs.prev q) - scoreHi q k : Rat) : Real) := + (Rat.cast_le (K := Real)).2 hgap + simpa [Rat.cast_sub] using hgap_real' + have hk_bounds := hscore_bounds q k + have hprev_bounds := hscore_bounds q (inputs.prev q) + have h1 : + scoresReal q k + (margin : Real) ≤ + scoresReal q k + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by + exact (add_le_add_iff_left (scoresReal q k)).2 hgap_real + have h2 : + scoresReal q k + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ + (scoreLo q (inputs.prev q) : Real) := by + have hscore_le' : + scoresReal q k + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ + (scoreHi q k : Real) + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by + exact (add_le_add_iff_right + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real))).2 hk_bounds.2 + calc + 
scoresReal q k + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ + (scoreHi q k : Real) + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by + exact hscore_le' + _ = (scoreLo q (inputs.prev q) : Real) := by + exact add_sub_cancel (scoreHi q k : Real) (scoreLo q (inputs.prev q) : Real) + have h3 : + scoresReal q k + (margin : Real) ≤ (scoreLo q (inputs.prev q) : Real) := + h1.trans h2 + exact h3.trans hprev_bounds.1 · exact (hactive ⟨q, hq⟩).elim - have hscore_margin_real : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hrat := hscore_margin q hq k hk - have hreal : - ((scores q k + margin : Rat) : Real) ≤ scores q (inputs.prev q) := by - exact (Rat.cast_le (K := Real)).2 hrat - simpa [scoresReal, Rat.cast_add] using hreal have hsoftmax_bounds : Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by @@ -445,7 +654,98 @@ def buildInductionCertFromHead? 
[NeZero seq] {dModel dHead : Nat} have h := Finset.single_le_sum hnonneg hk' simpa using h exact hle.trans hsum_others_le - some ⟨cert, { softmax_bounds := hsoftmax_bounds, value_bounds := hvalues }⟩ + have hvals_bounds : + ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by + refine + { lo_le_hi := ?_ + lo_le_valsLo := ?_ + vals_bounds := ?_ + valsHi_le_hi := ?_ } + · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ + have hmem0 : k0 ∈ univ := hk0 + have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by + have hloRat : valCert.lo ≤ valCert.valsLo k0 := by + change lo ≤ valsLo k0 + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k0)).2 ?_ + refine ⟨k0, hmem0, ?_⟩ + exact le_rfl + exact (Rat.cast_le (K := Real)).2 hloRat + have hvals : + (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ + valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by + have hv := hv_bounds k0 + have hlo' : ∀ d, (vLo k0 d : Real) ≤ vRealOfInputs inputs k0 d := fun d => (hv d).1 + have hhi' : ∀ d, vRealOfInputs inputs k0 d ≤ (vHi k0 d : Real) := fun d => (hv d).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := dirHead) + (lo := vLo k0) (hi := vHi k0) + (x := fun d => vRealOfInputs inputs k0 d) hlo' hhi' + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := dirHead) + (lo := vLo k0) (hi := vHi k0) + (x := fun d => vRealOfInputs inputs k0 d) hlo' hhi' + have hlow' : + (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 := by + simpa [valsLo, valCert, dirHead, valsRealOfInputs] using hlow + have hhigh' : + valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by + simpa [valsHi, valCert, dirHead, valsRealOfInputs] using hhigh + exact ⟨hlow', hhigh'⟩ + have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by + have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by + change valsHi k0 ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f 
:= valsHi) (a := valsHi k0)).2 ?_ + exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ + exact (Rat.cast_le (K := Real)).2 hhiRat + have hreal : + (valCert.lo : Real) ≤ (valCert.hi : Real) := + le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) + exact (Rat.cast_le (K := Real)).1 hreal + · intro k + have hmem : k ∈ univ := by simp [univ] + have hloRat : valCert.lo ≤ valCert.valsLo k := by + change lo ≤ valsLo k + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k)).2 ?_ + refine ⟨k, hmem, ?_⟩ + exact le_rfl + exact (Rat.cast_le (K := Real)).2 hloRat + · intro k + have hv := hv_bounds k + have hlo' : ∀ d, (vLo k d : Real) ≤ vRealOfInputs inputs k d := fun d => (hv d).1 + have hhi' : ∀ d, vRealOfInputs inputs k d ≤ (vHi k d : Real) := fun d => (hv d).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := dirHead) + (lo := vLo k) (hi := vHi k) + (x := fun d => vRealOfInputs inputs k d) hlo' hhi' + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := dirHead) + (lo := vLo k) (hi := vHi k) + (x := fun d => vRealOfInputs inputs k d) hlo' hhi' + have hlow' : + (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by + simpa [valsLo, valCert, dirHead, valsRealOfInputs] using hlow + have hhigh' : + valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by + simpa [valsHi, valCert, dirHead, valsRealOfInputs] using hhigh + exact ⟨hlow', hhigh'⟩ + · intro k + have hmem : k ∈ univ := by simp [univ] + have hhiRat : valCert.valsHi k ≤ valCert.hi := by + change valsHi k ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k)).2 ?_ + refine ⟨k, hmem, ?_⟩ + exact le_rfl + exact (Rat.cast_le (K := Real)).2 hhiRat + exact some ⟨cert, { softmax_bounds := hsoftmax_bounds, value_bounds := hvals_bounds }⟩ + · exact none section HeadOutputInterval @@ -454,51 +754,38 @@ variable {seq dModel dHead : Nat} noncomputable section /-- Real-valued head output using explicit score inputs. 
-/ -def headOutputWithScores (scores : Fin seq → Fin seq → Rat) +def headOutputWithScores (scores : Fin seq → Fin seq → Real) (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (i : Fin dModel) : Real := let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax (fun j => (scores q j : Real)) k - let vVecVec := vVecVecOfInputs inputs - let headValuesVec := headValueVecOfInputs inputs vVecVec - let vals : Fin seq → Real := fun k => (headValuesVec.get k).get i + Circuit.softmax (scores q) k + let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i dotProduct (weights q) vals /-- Unfolding lemma for `headOutputWithScores`. -/ -theorem headOutputWithScores_def (scores : Fin seq → Fin seq → Rat) +theorem headOutputWithScores_def (scores : Fin seq → Fin seq → Real) (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (i : Fin dModel) : headOutputWithScores scores inputs q i = let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax (fun j => (scores q j : Real)) k - let vVecVec := vVecVecOfInputs inputs - let headValuesVec := headValueVecOfInputs inputs vVecVec - let vals : Fin seq → Real := fun k => (headValuesVec.get k).get i + Circuit.softmax (scores q) k + let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i dotProduct (weights q) vals := rfl /-- Real-valued head output for a query and model dimension. -/ def headOutput (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (i : Fin dModel) : Real := - let qVecVec := qVecVecOfInputs inputs - let kVecVec := kVecVecOfInputs inputs - let scoresVec := scoresVecOfInputs inputs qVecVec kVecVec - let scores : Fin seq → Fin seq → Rat := fun q k => (scoresVec.get q).get k - headOutputWithScores scores inputs q i + headOutputWithScores (scoresRealOfInputs inputs) inputs q i /-- Unfolding lemma for `headOutput`. 
-/ theorem headOutput_def (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (i : Fin dModel) : headOutput inputs q i = - let qVecVec := qVecVecOfInputs inputs - let kVecVec := kVecVecOfInputs inputs - let scoresVec := scoresVecOfInputs inputs qVecVec kVecVec - let scores : Fin seq → Fin seq → Rat := fun q k => (scoresVec.get q).get k - headOutputWithScores scores inputs q i := rfl + headOutputWithScores (scoresRealOfInputs inputs) inputs q i := rfl /-- Soundness predicate for head-output interval bounds. -/ structure HeadOutputIntervalSound [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (scores : Fin seq → Fin seq → Rat) (active : Finset (Fin seq)) (c : Circuit.ResidualIntervalCert dModel) : Prop where /-- Interval bounds are ordered coordinatewise. -/ @@ -506,20 +793,18 @@ structure HeadOutputIntervalSound [NeZero seq] /-- Active-query outputs lie inside the interval bounds. -/ output_mem : ∀ q, q ∈ active → ∀ i, - (c.lo i : Real) ≤ headOutputWithScores scores inputs q i ∧ - headOutputWithScores scores inputs q i ≤ (c.hi i : Real) + (c.lo i : Real) ≤ headOutput inputs q i ∧ + headOutput inputs q i ≤ (c.hi i : Real) /-- Certified head-output interval data for a specific active set. -/ structure HeadOutputIntervalResult [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) where - /-- Scores used to derive softmax weights. -/ - scores : Fin seq → Fin seq → Rat /-- Active queries covered by the interval bounds. -/ active : Finset (Fin seq) /-- Residual-interval certificate for head outputs. -/ cert : Circuit.ResidualIntervalCert dModel /-- Soundness proof for the interval bounds. -/ - sound : HeadOutputIntervalSound inputs scores active cert + sound : HeadOutputIntervalSound inputs active cert /-- Build residual-interval bounds for head outputs on active queries. -/ def buildHeadOutputIntervalFromHead? [NeZero seq] @@ -530,161 +815,259 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] | zero => cases (NeZero.ne (n := (0 : Nat)) rfl) | succ n => - cases hbuild : buildInductionCertFromHead? inputs with - | none => exact none - | some certWithProof => - rcases certWithProof with ⟨cert, hcert⟩ - let vVecVec := vVecVecOfInputs inputs - let headValuesVec := headValueVecOfInputs inputs vVecVec - let headValue : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => - (headValuesVec.get k).get i - let scores : Fin (Nat.succ n) → Fin (Nat.succ n) → Rat := cert.scores - let scoresReal : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := - fun q k => (scores q k : Real) - let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => - Circuit.softmax (scoresReal q) k - let activeSet : Finset (Fin (Nat.succ n)) := cert.active - let univ : Finset (Fin (Nat.succ n)) := Finset.univ - have huniv : univ.Nonempty := Finset.univ_nonempty - let loVal : Fin dModel → Rat := fun i => - univ.inf' huniv (fun k => headValue k i) - let hiVal : Fin dModel → Rat := fun i => - univ.sup' huniv (fun k => headValue k i) - have hvalsBounds : - ∀ i, Layers.ValueRangeBounds (Val := Rat) (loVal i) (hiVal i) - (fun k => headValue k i) := by - intro i - refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } - · rcases huniv with ⟨k0, hk0⟩ - have hlo := - Finset.inf'_le (s := univ) (f := fun k => headValue k i) hk0 - have hhi := - Finset.le_sup' (s := univ) (f := fun k => headValue k i) hk0 - exact le_trans hlo hhi - · intro k - exact Finset.inf'_le (s := univ) (f := fun k => headValue k i) (by simp [univ]) - · intro k - exact Finset.le_sup' (s := univ) (f := fun k => headValue k i) (by simp [univ]) - have hvalsBoundsReal : - ∀ i, Layers.ValueRangeBounds (Val := Real) - (loVal i : Real) (hiVal i : Real) - (fun k => (headValue k i : Real)) := by - intro i - have hvals := hvalsBounds i - refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } - · exact (Rat.cast_le (K := Real)).2 hvals.lo_le_hi - · intro k - exact (Rat.cast_le (K := Real)).2 (hvals.lo_le k) - · intro k - exact 
(Rat.cast_le (K := Real)).2 (hvals.le_hi k) - have hsoftmax : - Layers.SoftmaxMarginBoundsOn (Val := Real) (cert.eps : Real) (cert.margin : Real) - (fun q => q ∈ activeSet) cert.prev scoresReal weights := by - simpa [scoresReal, weights, activeSet] using hcert.softmax_bounds - have hweights : - Layers.OneHotApproxBoundsOnActive (Val := Real) (cert.eps : Real) - (fun q => q ∈ activeSet) cert.prev weights := - Layers.oneHotApproxBoundsOnActive_of_softmaxMargin - (Val := Real) - (ε := (cert.eps : Real)) - (margin := (cert.margin : Real)) - (active := fun q => q ∈ activeSet) - (prev := cert.prev) - (scores := scoresReal) - (weights := weights) - hsoftmax - have happrox : - ∀ i, Layers.InductionSpecApproxOn (Val := Real) (n := n) - ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real))) - (fun q => q ∈ activeSet) cert.prev - (fun q => dotProduct (weights q) (fun k => (headValue k i : Real))) - (fun k => (headValue k i : Real)) := by - intro i - exact - Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange - (Val := Real) - (n := n) - (ε := (cert.eps : Real)) - (lo := (loVal i : Real)) - (hi := (hiVal i : Real)) - (active := fun q => q ∈ activeSet) - (prev := cert.prev) - (weights := weights) - (vals := fun k => (headValue k i : Real)) - (hweights := hweights) - (hvals := hvalsBoundsReal i) - let delta : Fin dModel → Rat := fun i => hiVal i - loVal i - let boundLoRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => - headValue (cert.prev q) i - cert.eps * delta i - let boundHiRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => - headValue (cert.prev q) i + cert.eps * delta i - let loOut : Fin dModel → Rat := fun i => - if h : activeSet.Nonempty then - activeSet.inf' h (fun q => boundLoRat q i) - else - 0 - let hiOut : Fin dModel → Rat := fun i => - if h : activeSet.Nonempty then - activeSet.sup' h (fun q => boundHiRat q i) - else - 0 - have hout : - ∀ q, q ∈ activeSet → ∀ i, - (loOut i : Real) ≤ headOutputWithScores scores inputs q i ∧ - headOutputWithScores 
scores inputs q i ≤ (hiOut i : Real) := by - intro q hq i - have hactive : activeSet.Nonempty := ⟨q, hq⟩ - have hspec := (happrox i) q hq - have hout_def : - headOutputWithScores scores inputs q i = - dotProduct (weights q) (fun k => (headValue k i : Real)) := by - simp [headOutputWithScores, scoresReal, weights, headValue, headValuesVec, vVecVec] - have hupper : - headOutputWithScores scores inputs q i ≤ (boundHiRat q i : Real) := by - have hupper' := - (happrox i) q hq |>.1 - simpa [hout_def, boundHiRat, delta] using hupper' - have hlower : - (boundLoRat q i : Real) ≤ headOutputWithScores scores inputs q i := by - have hlower' : - (headValue (cert.prev q) i : Real) - - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ - dotProduct (weights q) (fun k => (headValue k i : Real)) := by - exact (sub_le_iff_le_add).2 hspec.2 - simpa [hout_def, boundLoRat, delta] using hlower' - have hlo : - (loOut i : Real) ≤ (boundLoRat q i : Real) := by - have hloRat : loOut i ≤ boundLoRat q i := by - simpa [loOut, hactive] using - (Finset.inf'_le (s := activeSet) (f := fun q => boundLoRat q i) hq) - exact (Rat.cast_le (K := Real)).2 hloRat - have hhi : - (boundHiRat q i : Real) ≤ (hiOut i : Real) := by - have hhiRat : boundHiRat q i ≤ hiOut i := by - simpa [hiOut, hactive] using - (Finset.le_sup' (s := activeSet) (f := fun q => boundHiRat q i) hq) - exact (Rat.cast_le (K := Real)).2 hhiRat - exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ - have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by - refine { lo_le_hi := ?_ } - intro i - by_cases hactive : activeSet.Nonempty - · rcases hactive with ⟨q, hq⟩ - have hout_i := hout q hq i - have hleReal : (loOut i : Real) ≤ (hiOut i : Real) := - le_trans hout_i.1 hout_i.2 - exact (Rat.cast_le (K := Real)).1 hleReal - · simp [loOut, hiOut, hactive] - let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } - exact some - { scores := scores - active := activeSet - cert := certOut - 
sound := - { bounds := hbounds - output_mem := by - intro q hq i - exact hout q hq i } } + by_cases hEps : 0 < inputs.lnEps + · by_cases hmodel : dModel = 0 + · exact none + · cases hbuild : buildInductionCertFromHead? inputs with + | none => exact none + | some certWithProof => + rcases certWithProof with ⟨cert, hcert⟩ + let lnBounds : Fin (Nat.succ n) → (Fin dModel → Rat) × (Fin dModel → Rat) := fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) + let lnLo : Fin (Nat.succ n) → Fin dModel → Rat := fun q => (lnBounds q).1 + let lnHi : Fin (Nat.succ n) → Fin dModel → Rat := fun q => (lnBounds q).2 + let vLo : Fin (Nat.succ n) → Fin dHead → Rat := fun q d => + dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d + let vHi : Fin (Nat.succ n) → Fin dHead → Rat := fun q d => + dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d + let headValueLo : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => + dotIntervalLower (fun d => inputs.wo i d) (vLo k) (vHi k) + let headValueHi : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => + dotIntervalUpper (fun d => inputs.wo i d) (vLo k) (vHi k) + have hln_bounds : + ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ + lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by + intro q i + have hln := + Bounds.layerNormBounds_spec (eps := inputs.lnEps) + (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) + (x := inputs.embed q) hmodel hEps + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs] using hln i + have hv_bounds : + ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ + vRealOfInputs inputs q d ≤ (vHi q d : Real) := by + intro q d + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) (x := 
lnRealOfInputs inputs q) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hlow' := add_le_add_right hlow (inputs.bv d : Real) + have hhigh' := add_le_add_right hhigh (inputs.bv d : Real) + constructor + · simpa [vLo, vRealOfInputs, Rat.cast_add] using hlow' + · simpa [vHi, vRealOfInputs, Rat.cast_add] using hhigh' + have hhead_bounds : + ∀ k i, (headValueLo k i : Real) ≤ headValueRealOfInputs inputs k i ∧ + headValueRealOfInputs inputs k i ≤ (headValueHi k i : Real) := by + intro k i + have hv := hv_bounds k + have hlo : ∀ d, (vLo k d : Real) ≤ vRealOfInputs inputs k d := fun d => (hv d).1 + have hhi : ∀ d, vRealOfInputs inputs k d ≤ (vHi k d : Real) := fun d => (hv d).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun d => inputs.wo i d) + (lo := vLo k) (hi := vHi k) + (x := fun d => vRealOfInputs inputs k d) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun d => inputs.wo i d) + (lo := vLo k) (hi := vHi k) + (x := fun d => vRealOfInputs inputs k d) hlo hhi + constructor + · simpa [headValueLo, headValueRealOfInputs] using hlow + · simpa [headValueHi, headValueRealOfInputs] using hhigh + let scoresReal : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := + scoresRealOfInputs inputs + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresReal q) k + let activeSet : Finset (Fin (Nat.succ n)) := cert.active + let univ : Finset (Fin (Nat.succ n)) := Finset.univ + have huniv : univ.Nonempty := by simp [univ] + let loVal : Fin dModel → Rat := fun i => + univ.inf' huniv (fun k => headValueLo k i) + let hiVal : Fin dModel → Rat := fun i => + univ.sup' huniv (fun k => headValueHi k i) + have hvalsBoundsReal : + ∀ i, Layers.ValueRangeBounds (Val := Real) + (loVal i : Real) (hiVal i : Real) + (fun k => headValueRealOfInputs inputs k i) := by + intro i + refine { lo_le_hi := ?_, lo_le 
:= ?_, le_hi := ?_ } + · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ + have hmem0 : k0 ∈ univ := hk0 + have hloRat : loVal i ≤ headValueLo k0 i := by + change loVal i ≤ headValueLo k0 i + dsimp [loVal] + refine (Finset.inf'_le_iff (s := univ) (H := huniv) + (f := fun k => headValueLo k i) (a := headValueLo k0 i)).2 ?_ + refine ⟨k0, hmem0, ?_⟩ + exact le_rfl + have hhiRat : headValueHi k0 i ≤ hiVal i := by + change headValueHi k0 i ≤ hiVal i + dsimp [hiVal] + refine (Finset.le_sup'_iff (s := univ) (H := huniv) + (f := fun k => headValueHi k i) (a := headValueHi k0 i)).2 ?_ + exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ + have hbounds := hhead_bounds k0 i + have hreal : + (loVal i : Real) ≤ (hiVal i : Real) := + le_trans ((Rat.cast_le (K := Real)).2 hloRat) + (le_trans hbounds.1 (le_trans hbounds.2 ((Rat.cast_le (K := Real)).2 hhiRat))) + exact hreal + · intro k + have hmem : k ∈ univ := by simp [univ] + have hloRat : loVal i ≤ headValueLo k i := by + change loVal i ≤ headValueLo k i + dsimp [loVal] + refine (Finset.inf'_le_iff (s := univ) (H := huniv) + (f := fun k => headValueLo k i) (a := headValueLo k i)).2 ?_ + refine ⟨k, hmem, ?_⟩ + exact le_rfl + have hbounds := hhead_bounds k i + exact (Rat.cast_le (K := Real)).2 hloRat |>.trans hbounds.1 + · intro k + have hmem : k ∈ univ := by simp [univ] + have hhiRat : headValueHi k i ≤ hiVal i := by + change headValueHi k i ≤ hiVal i + dsimp [hiVal] + refine (Finset.le_sup'_iff (s := univ) (H := huniv) + (f := fun k => headValueHi k i) (a := headValueHi k i)).2 ?_ + exact ⟨k, ⟨hmem, le_rfl⟩⟩ + have hbounds := hhead_bounds k i + exact hbounds.2.trans ((Rat.cast_le (K := Real)).2 hhiRat) + have hsoftmax : + Layers.SoftmaxMarginBoundsOn (Val := Real) (cert.eps : Real) (cert.margin : Real) + (fun q => q ∈ activeSet) cert.prev scoresReal weights := by + simpa [scoresReal, weights, activeSet] using hcert.softmax_bounds + have hweights : + Layers.OneHotApproxBoundsOnActive (Val := Real) (cert.eps : Real) + (fun q => q ∈ 
activeSet) cert.prev weights := + Layers.oneHotApproxBoundsOnActive_of_softmaxMargin + (Val := Real) + (ε := (cert.eps : Real)) + (margin := (cert.margin : Real)) + (active := fun q => q ∈ activeSet) + (prev := cert.prev) + (scores := scoresReal) + (weights := weights) + hsoftmax + have happrox : + ∀ i, Layers.InductionSpecApproxOn (Val := Real) (n := n) + ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real))) + (fun q => q ∈ activeSet) cert.prev + (fun q => dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i)) + (fun k => headValueRealOfInputs inputs k i) := by + intro i + exact + Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange + (Val := Real) + (n := n) + (ε := (cert.eps : Real)) + (lo := (loVal i : Real)) + (hi := (hiVal i : Real)) + (active := fun q => q ∈ activeSet) + (prev := cert.prev) + (weights := weights) + (vals := fun k => headValueRealOfInputs inputs k i) + (hweights := hweights) + (hvals := hvalsBoundsReal i) + let delta : Fin dModel → Rat := fun i => hiVal i - loVal i + let boundLoRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => + headValueLo (cert.prev q) i - cert.eps * delta i + let boundHiRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => + headValueHi (cert.prev q) i + cert.eps * delta i + let loOut : Fin dModel → Rat := fun i => + if h : activeSet.Nonempty then + activeSet.inf' h (fun q => boundLoRat q i) + else + 0 + let hiOut : Fin dModel → Rat := fun i => + if h : activeSet.Nonempty then + activeSet.sup' h (fun q => boundHiRat q i) + else + 0 + have hout : + ∀ q, q ∈ activeSet → ∀ i, + (loOut i : Real) ≤ headOutput inputs q i ∧ + headOutput inputs q i ≤ (hiOut i : Real) := by + intro q hq i + have hactive : activeSet.Nonempty := ⟨q, hq⟩ + have hspec := (happrox i) q hq + have hout_def : + headOutput inputs q i = + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by + simp [headOutput, headOutputWithScores, scoresReal, weights] + have hprev_bounds := hhead_bounds (cert.prev q) i + 
have hupper : + headOutput inputs q i ≤ (boundHiRat q i : Real) := by + have hupper' : + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ + headValueRealOfInputs inputs (cert.prev q) i + + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by + exact hspec.1 + have hupper'' : + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ + (headValueHi (cert.prev q) i : Real) + + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by + have hprev_bounds' := + (add_le_add_iff_right + ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))).2 + hprev_bounds.2 + exact le_trans hupper' hprev_bounds' + simpa + [hout_def, boundHiRat, delta, Rat.cast_add, Rat.cast_mul, Rat.cast_sub] using + hupper'' + have hlower : + (boundLoRat q i : Real) ≤ headOutput inputs q i := by + have hlower' : + (headValueRealOfInputs inputs (cert.prev q) i : Real) - + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by + exact (sub_le_iff_le_add).2 hspec.2 + have hlower'' : + (headValueLo (cert.prev q) i : Real) - + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by + exact le_trans (sub_le_sub_right hprev_bounds.1 + ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))) hlower' + simpa [hout_def, boundLoRat, delta, Rat.cast_mul, Rat.cast_sub] using + hlower'' + have hlo : + (loOut i : Real) ≤ (boundLoRat q i : Real) := by + have hloRat : loOut i ≤ boundLoRat q i := by + simpa [loOut, hactive] using + (Finset.inf'_le (s := activeSet) (f := fun q => boundLoRat q i) (b := q) hq) + exact (Rat.cast_le (K := Real)).2 hloRat + have hhi : + (boundHiRat q i : Real) ≤ (hiOut i : Real) := by + have hhiRat : boundHiRat q i ≤ hiOut i := by + simpa [hiOut, hactive] using + (Finset.le_sup' (s := activeSet) (f := fun q => boundHiRat q i) (b := q) hq) + exact (Rat.cast_le (K := Real)).2 
hhiRat + exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ + have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by + refine { lo_le_hi := ?_ } + intro i + by_cases hactive : activeSet.Nonempty + · rcases hactive with ⟨q, hq⟩ + have hout_i := hout q hq i + have hleReal : (loOut i : Real) ≤ (hiOut i : Real) := + le_trans hout_i.1 hout_i.2 + exact (Rat.cast_le (K := Real)).1 hleReal + · simp [loOut, hiOut, hactive] + let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } + exact some + { active := activeSet + cert := certOut + sound := + { bounds := hbounds + output_mem := by + intro q hq i + exact hout q hq i } } + · exact none end diff --git a/README.md b/README.md index 342e36a..0d44d0b 100644 --- a/README.md +++ b/README.md @@ -99,10 +99,14 @@ You can also derive the head inputs directly from an `NFP_BINARY_V1` model file: ```bash lake exe nfp induction certify_head_model \ --model models/gpt2_rigorous_with_gelu_kind_seq32.nfpt \ - --layer 5 --head 1 --period 16 \ + --layer 5 --head 1 \ --direction-target 1 --direction-negative 2 ``` +By default, `certify_head_model` derives the `prev` map and active set from the +token sequence stored in the model file. Use `--period ` to override with a +fixed periodic prompt. + ### End-to-end check with downstream bound (prototype) ```bash diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 520ffda..eb4a27a 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -15,9 +15,9 @@ It is intentionally brief and focused on the soundness boundary. now includes attention projection biases and LayerNorm metadata, but the Lean-side computation still ignores LayerNorm and the shared attention output bias. - The `certify_head_model` path derives head inputs from the model binary in Lean, includes - attention projection biases, but still ignores LayerNorm and the shared attention output bias. 
- It currently requires `head_dim` to be a perfect square to represent the scale as an exact - rational. + attention projection biases, and derives `prev`/active from the stored token sequence by + default, but still ignores LayerNorm and the shared attention output bias. It currently + requires `head_dim` to be a perfect square to represent the scale as an exact rational. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. - There is no bridge theorem connecting certificate validity to a full circuit/model semantics statement (for example, a formal statement about logits under a transformer block stack). From 67e5429eb823974934e2a7e77fe6403c2c3066f8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 4 Jan 2026 00:35:19 +0100 Subject: [PATCH 105/244] Extend induction certs and tighten LayerNorm bounds --- Nfp/Circuit/Cert/LogitDiff.lean | 39 ++ Nfp/Cli.lean | 60 +++ Nfp/IO.lean | 125 ++++++ Nfp/IO/NfptPure.lean | 2 + Nfp/IO/Pure.lean | 24 + Nfp/Model/InductionHead.lean | 4 + Nfp/Sound.lean | 1 + Nfp/Sound/Bounds/LayerNorm.lean | 257 ++++++++++- Nfp/Sound/Gpt2/HeadInputs.lean | 4 + Nfp/Sound/Induction.lean | 179 +++++++- Nfp/Sound/Induction/LogitDiff.lean | 186 ++++++++ Nfp/Sound/Induction/OneHot.lean | 216 +++++++++ scripts/build_gpt2_head_inputs.py | 9 + scripts/build_gpt2_induction_cert.py | 2 +- .../build_gpt2_induction_cert_from_binary.py | 377 ++++++++++++++++ scripts/discover_gpt2_induction_targets.py | 425 ++++++++++++++++++ scripts/generate_rigorous_induction.py | 146 ++++-- scripts/sweep_gpt2_induction_nonvacuous.py | 290 ++++++++++++ 18 files changed, 2264 insertions(+), 82 deletions(-) create mode 100644 Nfp/Sound/Induction/LogitDiff.lean create mode 100644 Nfp/Sound/Induction/OneHot.lean create mode 100644 scripts/build_gpt2_induction_cert_from_binary.py create mode 100644 scripts/discover_gpt2_induction_targets.py create mode 100644 scripts/sweep_gpt2_induction_nonvacuous.py diff --git 
a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index 7d5f731..99da2da 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -28,6 +28,20 @@ def logitDiffLowerBound (active : Finset (Fin seq)) else exact none +/-- Compute a lower bound on the logit-diff contribution with per-query eps. -/ +def logitDiffLowerBoundAt (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (lo hi : Rat) (vals : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap q + let img := active.image f + have himg : img.Nonempty := h.image f + exact some (Finset.min' img himg) + else + exact none + /-- The computed lower bound is below every active `prev` value minus the tolerance gap. -/ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) @@ -53,6 +67,31 @@ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) simpa [f, gap] using hbound'.symm simpa [f, gap, hlb] using hmin +/-- The per-query lower bound is below every active `prev` value minus the local gap. 
-/ +theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (lo hi : Rat) (vals : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBoundAt active prev epsAt lo hi vals = some lb → + lb ≤ vals (prev q) - epsAt q * (hi - lo) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + have hbound' : + (active.image (fun q => vals (prev q) - epsAt q * (hi - lo))).min' + (hnonempty.image (fun q => vals (prev q) - epsAt q * (hi - lo))) = lb := by + simpa [logitDiffLowerBoundAt, hnonempty] using hbound + let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap q + have hmem : f q ∈ (active.image f) := by + refine Finset.mem_image.2 ?_ + exact ⟨q, hq, rfl⟩ + have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := + Finset.min'_le _ _ hmem + have hlb : lb = (active.image f).min' (hnonempty.image f) := by + simpa [f, gap] using hbound'.symm + simpa [f, gap, hlb] using hmin + end Circuit end Nfp diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 4ae6d36..799ff16 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -178,6 +178,16 @@ def runInductionCertifyHead (p : Parsed) : IO UInt32 := do IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? +/-- `nfp induction certify_head_nonvacuous` subcommand. -/ +def runInductionCertifyHeadNonvacuous (p : Parsed) : IO UInt32 := do + let inputsPath := p.flag! "inputs" |>.as! String + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyHeadNonvacuous inputsPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? + /-- `nfp induction certify_head` subcommand. 
-/ def inductionCertifyHeadCmd : Cmd := `[Cli| certify_head VIA runInductionCertifyHead; @@ -192,6 +202,20 @@ def inductionCertifyHeadCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] +/-- `nfp induction certify_head_nonvacuous` subcommand. -/ +def inductionCertifyHeadNonvacuousCmd : Cmd := `[Cli| + certify_head_nonvacuous VIA runInductionCertifyHeadNonvacuous; + "Require a strictly positive logit-diff bound from exact head inputs." + FLAGS: + inputs : String; "Path to the induction head input file." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal; default: 0)." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + /-- `nfp induction certify_head_model` subcommand. -/ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do let modelPath := p.flag! "model" |>.as! String @@ -207,6 +231,21 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do IO.runInductionCertifyHeadModel modelPath layer head dirTarget dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? +/-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ +def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do + let modelPath := p.flag! "model" |>.as! String + let layer := p.flag! "layer" |>.as! Nat + let head := p.flag! "head" |>.as! Nat + let period? := (p.flag? "period").map (·.as! Nat) + let dirTarget := p.flag! "direction-target" |>.as! Nat + let dirNegative := p.flag! "direction-negative" |>.as! Nat + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! 
String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyHeadModelNonvacuous modelPath layer head dirTarget dirNegative period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + /-- `nfp induction certify_head_model` subcommand. -/ def inductionCertifyHeadModelCmd : Cmd := `[Cli| certify_head_model VIA runInductionCertifyHeadModel; @@ -226,6 +265,25 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] +/-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ +def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| + certify_head_model_nonvacuous VIA runInductionCertifyHeadModelNonvacuous; + "Require a strictly positive logit-diff bound from a model binary." + FLAGS: + model : String; "Path to the NFP_BINARY_V1 model file." + layer : Nat; "Layer index for the induction head." + head : Nat; "Head index for the induction head." + period : Nat; "Optional prompt period override (default: derive from tokens)." + "direction-target" : Nat; "Target token id for logit-diff direction." + "direction-negative" : Nat; "Negative token id for logit-diff direction." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal; default: 0)." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + /-- `nfp induction head_interval` subcommand. -/ def runInductionHeadInterval (p : Parsed) : IO UInt32 := do let inputsPath := p.flag! "inputs" |>.as! 
String @@ -277,7 +335,9 @@ def inductionCmd : Cmd := `[Cli| inductionCertifyEndToEndMatrixCmd; inductionCertifyEndToEndModelCmd; inductionCertifyHeadCmd; + inductionCertifyHeadNonvacuousCmd; inductionCertifyHeadModelCmd; + inductionCertifyHeadModelNonvacuousCmd; inductionHeadIntervalCmd; inductionHeadIntervalModelCmd ] diff --git a/Nfp/IO.lean b/Nfp/IO.lean index b83dd8f..7f17ee2 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -9,6 +9,7 @@ import Nfp.Circuit.Cert.ResidualBound import Nfp.Circuit.Cert.ResidualInterval import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Induction +import Nfp.Sound.Induction.LogitDiff /-! IO wrappers for loading and checking induction certificates. @@ -839,6 +840,65 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} tol={tol}, logitDiffLB={logitDiffLB})" return 0 +private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (minActive? : Option Nat) (minLogitDiff? : Option Rat) + (minMargin maxEps : Rat) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildInductionLogitLowerBoundNonvacuous? 
inputs with + | none => + IO.eprintln "error: nonvacuous logit-diff bound unavailable" + return 2 + | some result => + let cert := result.base.cert + let logitDiffLB := result.base.lb + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let tol := cert.eps * (cert.values.hi - cert.values.lo) + let effectiveMinLogitDiff := + match minLogitDiff? with + | some v => some v + | none => some (0 : Rat) + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return 2 + | none => + IO.println + s!"ok: nonvacuous induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={tol}, logitDiffLB={logitDiffLB})" + return 0 + /-- Build and check induction certificates from exact head inputs. -/ def runInductionCertifyHead (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) @@ -867,6 +927,34 @@ def runInductionCertifyHead (inputsPath : System.FilePath) | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps +/-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ +def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? 
: Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let parsedInputs ← loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? minMargin maxEps + /-- Build and check induction certificates from a model binary. -/ def runInductionCertifyHeadModel (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? : Option Nat) @@ -904,6 +992,43 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) | Except.ok inputs => checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps +/-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ +def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (1 / 2 : Rat) + let data ← IO.FS.readBinFile modelPath + match NfptPure.parseHeader data with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + match + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period? + with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? minMargin maxEps + /-- Build head-output interval bounds from exact head inputs. -/ def runInductionHeadInterval (inputsPath : System.FilePath) (outPath? 
: Option System.FilePath) : IO UInt32 := do diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index f2547fa..bf9ab72 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -481,6 +481,8 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) bv := weights.bv wo := weights.wo attnBias := attnBias + maskCausal := true + maskValue := (-10000 : Rat) directionSpec := directionSpec direction := direction } diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 18887ae..dfb5151 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -573,6 +573,8 @@ private structure HeadParseState (seq dModel dHead : Nat) where bv : Fin dHead → Option Rat wo : Fin dModel → Fin dHead → Option Rat attnBias : Fin dModel → Option Rat + maskCausal : Option Bool + maskValue : Option Rat directionTarget : Option Nat directionNegative : Option Nat direction : Fin dModel → Option Rat @@ -594,6 +596,8 @@ private def initHeadState (seq dModel dHead : Nat) : HeadParseState seq dModel d bv := fun _ => none wo := fun _ _ => none attnBias := fun _ => none + maskCausal := none + maskValue := none directionTarget := none directionNegative := none direction := fun _ => none } @@ -679,6 +683,19 @@ private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dMod | ["attn_bias", d, val] => let vec ← setVecEntry st.attnBias (← parseNat d) (← parseRat val) return { st with attnBias := vec } + | ["mask", kind] => + if st.maskCausal.isSome then + throw "duplicate mask entry" + else + match kind with + | "causal" => return { st with maskCausal := some true } + | "none" => return { st with maskCausal := some false } + | _ => throw "mask must be 'causal' or 'none'" + | ["mask_value", val] => + if st.maskValue.isSome then + throw "duplicate mask_value entry" + else + return { st with maskValue := some (← parseRat val) } | ["direction-target", tok] => if st.directionTarget.isSome then throw "duplicate direction-target entry" @@ -767,6 +784,11 @@ private def 
finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) (st.wo i j).getD 0 let attnBiasFun : Fin dModel → Rat := fun d => (st.attnBias d).getD 0 + let maskCausal := st.maskCausal.getD false + let maskValue := + match st.maskValue with + | some v => v + | none => if maskCausal then (-10000 : Rat) else 0 let directionFun : Fin dModel → Rat := fun d => (st.direction d).getD 0 let active := @@ -790,6 +812,8 @@ private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) bv := bvFun wo := woFun attnBias := attnBiasFun + maskCausal := maskCausal + maskValue := maskValue directionSpec := directionSpec direction := directionFun } diff --git a/Nfp/Model/InductionHead.lean b/Nfp/Model/InductionHead.lean index d90bf17..f1e3eb2 100644 --- a/Nfp/Model/InductionHead.lean +++ b/Nfp/Model/InductionHead.lean @@ -49,6 +49,10 @@ structure InductionHeadInputs (seq dModel dHead : Nat) where wo : Fin dModel → Fin dHead → Rat /-- Attention output bias (shared across heads). -/ attnBias : Fin dModel → Rat + /-- Whether to apply a causal mask to attention scores. -/ + maskCausal : Bool + /-- Score value for masked entries (e.g. `-10000` for GPT-2 causal masking). -/ + maskValue : Rat /-- Logit-diff direction metadata. -/ directionSpec : DirectionSpec /-- Logit-diff direction vector in model space. -/ diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index 99456bb..aff7bb1 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -2,6 +2,7 @@ import Nfp.Sound.Gpt2.HeadInputs import Nfp.Sound.Induction +import Nfp.Sound.Induction.LogitDiff import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Linear.FinFold diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 2cc9aae..8f97090 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -62,26 +62,50 @@ theorem variance_nonneg {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : div_nonneg hsum hden simpa [variance_def x h] using hdiv -/-- Rational lower bound for a square root. 
-/ -def sqrtLower (q : Rat) : Rat := +/-! Square-root bounds. -/ + +/-- Base rational lower bound for a square root. -/ +def sqrtLowerBase (q : Rat) : Rat := let num := q.num.natAbs let den := q.den let a := Nat.sqrt num let b := Nat.sqrt den (a : Rat) / (b + 1 : Rat) -/-- Rational upper bound for a square root. -/ -def sqrtUpper (q : Rat) : Rat := +/-- Base rational upper bound for a square root. -/ +def sqrtUpperBase (q : Rat) : Rat := let num := q.num.natAbs let den := q.den let a := Nat.sqrt num let b := Nat.sqrt den (a + 1 : Rat) / (b : Rat) -/-- `sqrtLower` is nonnegative. -/ -theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by +/-- Alternate rational lower bound for a square root. -/ +def sqrtLowerAlt (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den) + (a : Rat) / den + +/-- Alternate rational upper bound for a square root. -/ +def sqrtUpperAlt (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den) + (a + 1 : Rat) / den + +/-- Rational lower bound for a square root (tighter of two bounds). -/ +def sqrtLower (q : Rat) : Rat := + max (sqrtLowerBase q) (sqrtLowerAlt q) + +/-- Rational upper bound for a square root (tighter of two bounds). -/ +def sqrtUpper (q : Rat) : Rat := + min (sqrtUpperBase q) (sqrtUpperAlt q) + +/-- `sqrtLowerBase` is nonnegative. -/ +theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by classical - unfold sqrtLower + unfold sqrtLowerBase have hden : 0 ≤ (Nat.sqrt q.den + 1 : Rat) := by exact_mod_cast (Nat.zero_le _) have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) := by @@ -90,10 +114,12 @@ theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by /-! Strict positivity helpers. -/ -/-- `sqrtLower` is positive when its input is positive. -/ -theorem sqrtLower_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLower q := by +/-! Base bounds. -/ + +/-- `sqrtLowerBase` is positive when its input is positive. 
-/ +theorem sqrtLowerBase_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLowerBase q := by classical - unfold sqrtLower + unfold sqrtLowerBase have hnum_pos : 0 < (Nat.sqrt q.num.natAbs : Rat) := by have hnum_pos' : 0 < q.num.natAbs := by have hnum : 0 < q.num := (Rat.num_pos (a := q)).2 hq @@ -103,20 +129,20 @@ theorem sqrtLower_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLower q := by exact_mod_cast (Nat.succ_pos _) exact div_pos hnum_pos hden_pos -/-- `sqrtUpper` is nonnegative. -/ -theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by +/-- `sqrtUpperBase` is nonnegative. -/ +theorem sqrtUpperBase_nonneg (q : Rat) : 0 ≤ sqrtUpperBase q := by classical - unfold sqrtUpper + unfold sqrtUpperBase have hden : 0 ≤ (Nat.sqrt q.den : Rat) := by exact_mod_cast (Nat.zero_le _) have hnum : 0 ≤ (Nat.sqrt q.num.natAbs + 1 : Rat) := by exact_mod_cast (Nat.zero_le _) exact div_nonneg hnum hden -/-- `sqrtUpper` is always positive. -/ -theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by +/-- `sqrtUpperBase` is always positive. -/ +theorem sqrtUpperBase_pos (q : Rat) : 0 < sqrtUpperBase q := by classical - unfold sqrtUpper + unfold sqrtUpperBase have hnum_pos : 0 < (Nat.sqrt q.num.natAbs + 1 : Rat) := by exact_mod_cast (Nat.succ_pos _) have hden_pos : 0 < (Nat.sqrt q.den : Rat) := by @@ -124,9 +150,81 @@ theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by exact_mod_cast (Nat.sqrt_pos.2 hden) exact div_pos hnum_pos hden_pos +/-! Alternate bounds. -/ + +/-- `sqrtLowerAlt` is nonnegative. -/ +theorem sqrtLowerAlt_nonneg (q : Rat) : 0 ≤ sqrtLowerAlt q := by + classical + unfold sqrtLowerAlt + have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) := by + exact_mod_cast (Nat.zero_le _) + have hden : 0 ≤ (q.den : Rat) := by + exact_mod_cast (Nat.zero_le _) + exact div_nonneg hnum hden + +/-- `sqrtLowerAlt` is positive when its input is positive. 
-/ +theorem sqrtLowerAlt_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLowerAlt q := by + classical + unfold sqrtLowerAlt + have hnum_pos : 0 < (Nat.sqrt (q.num.natAbs * q.den) : Rat) := by + have hnum_pos' : 0 < q.num.natAbs := by + have hnum : 0 < q.num := (Rat.num_pos (a := q)).2 hq + exact Int.natAbs_pos.mpr hnum.ne' + have hden_pos : 0 < q.den := q.den_pos + have hmul_pos : 0 < q.num.natAbs * q.den := by + exact Nat.mul_pos hnum_pos' hden_pos + exact_mod_cast (Nat.sqrt_pos.2 hmul_pos) + have hden_pos : 0 < (q.den : Rat) := by + exact_mod_cast q.den_pos + exact div_pos hnum_pos hden_pos + +/-- `sqrtUpperAlt` is nonnegative. -/ +theorem sqrtUpperAlt_nonneg (q : Rat) : 0 ≤ sqrtUpperAlt q := by + classical + unfold sqrtUpperAlt + have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) := by + exact_mod_cast (Nat.zero_le _) + have hden : 0 ≤ (q.den : Rat) := by + exact_mod_cast (Nat.zero_le _) + exact div_nonneg hnum hden + +/-- `sqrtUpperAlt` is always positive. -/ +theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by + classical + unfold sqrtUpperAlt + have hnum_pos : 0 < (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) := by + exact_mod_cast (Nat.succ_pos _) + have hden_pos : 0 < (q.den : Rat) := by + exact_mod_cast q.den_pos + exact div_pos hnum_pos hden_pos + +/-! Combined bounds. -/ + +/-- `sqrtLower` is nonnegative. -/ +theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by + have hbase : 0 ≤ sqrtLowerBase q := sqrtLowerBase_nonneg q + exact le_trans hbase (le_max_left _ _) + +/-- `sqrtLower` is positive when its input is positive. -/ +theorem sqrtLower_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLower q := by + have hbase : 0 < sqrtLowerBase q := sqrtLowerBase_pos hq + exact lt_of_lt_of_le hbase (le_max_left _ _) + +/-- `sqrtUpper` is nonnegative. 
-/ +theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by + have hbase : 0 ≤ sqrtUpperBase q := sqrtUpperBase_nonneg q + have halt : 0 ≤ sqrtUpperAlt q := sqrtUpperAlt_nonneg q + exact le_min hbase halt + +/-- `sqrtUpper` is always positive. -/ +theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by + have hbase : 0 < sqrtUpperBase q := sqrtUpperBase_pos q + have halt : 0 < sqrtUpperAlt q := sqrtUpperAlt_pos q + exact lt_min hbase halt + /-- Square-root lower bound in reals. -/ -theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : - (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by +theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerBase q : Real) ≤ Real.sqrt (q : Real) := by classical -- Set up numerator/denominator witnesses. set num : Nat := q.num.natAbs @@ -170,11 +268,11 @@ theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : have hq_nonneg : 0 ≤ (q : Real) := by exact_mod_cast hq have hle : (a : Real) / (b + 1 : Real) ≤ Real.sqrt (q : Real) := (Real.le_sqrt hnonneg hq_nonneg).2 hsq - simpa [sqrtLower, num, den, a, b] using hle + simpa [sqrtLowerBase, num, den, a, b] using hle /-- Square-root upper bound in reals. -/ -theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : - Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by +theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperBase q : Real) := by classical set num : Nat := q.num.natAbs set den : Nat := q.den @@ -221,7 +319,122 @@ theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : exact div_nonneg hnum_nonneg hden_nonneg have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (b : Real) := (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ - simpa [sqrtUpper, num, den, a, b] using hle + simpa [sqrtUpperBase, num, den, a, b] using hle + +/-- Alternate square-root lower bound in reals. 
-/ +theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerAlt q : Real) ≤ Real.sqrt (q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hnumden_le : (a ^ 2 : Real) ≤ (num * den : Nat) := by + exact_mod_cast (Nat.sqrt_le' (num * den)) + have hmul : (a ^ 2 : Real) ≤ (num : Real) * den := by + simpa [num, den, Nat.cast_mul] using hnumden_le + have hden_pos2 : 0 < (den : Real) ^ 2 := by + nlinarith [hden_pos] + have hdiv : + (a ^ 2 : Real) / (den : Real) ^ 2 ≤ (num : Real) * den / (den : Real) ^ 2 := by + have hmul' : + (a ^ 2 : Real) * (den : Real) ^ 2 ≤ (num : Real) * den * (den : Real) ^ 2 := by + have hden_sq_nonneg : 0 ≤ (den : Real) ^ 2 := by + exact sq_nonneg (den : Real) + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' + have hden_ne : (den : Real) ≠ 0 := by + exact_mod_cast q.den_pos.ne' + have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by + have hnum_nonneg : 0 ≤ q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by + simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) + have hnum_cast : (q.num : Real) = (num : Real) := by + exact_mod_cast hnum_eq.symm + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] + have hq_eq : + (num : Real) / den = (num : Real) * den / (den : Real) ^ 2 := by + field_simp [hden_ne] + simpa [hnum_cast, den, hq_eq] using hq_rat + have hsq : ((a : Real) / (den : Real)) ^ 2 ≤ (q : Real) := by + simpa [hq_cast, pow_two, div_mul_div_comm] using hdiv + have hnonneg : 0 ≤ (a : Real) / (den : Real) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by 
exact_mod_cast hq + have hle : (a : Real) / (den : Real) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + simpa [sqrtLowerAlt, num, den, a] using hle + +/-- Alternate square-root upper bound in reals. -/ +theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperAlt q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hnumden_lt : (num * den : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' (num * den)) + have hmul : (num : Real) * den ≤ (a + 1 : Real) ^ 2 := by + exact le_of_lt hnumden_lt + have hden_pos2 : 0 < (den : Real) ^ 2 := by + nlinarith [hden_pos] + have hdiv : + (num : Real) * den / (den : Real) ^ 2 ≤ (a + 1 : Real) ^ 2 / (den : Real) ^ 2 := by + have hmul' : + (num : Real) * den * (den : Real) ^ 2 ≤ (a + 1 : Real) ^ 2 * (den : Real) ^ 2 := by + have hden_sq_nonneg : 0 ≤ (den : Real) ^ 2 := by + exact sq_nonneg (den : Real) + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' + have hden_ne : (den : Real) ≠ 0 := by + exact_mod_cast q.den_pos.ne' + have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by + have hnum_nonneg : 0 ≤ q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by + simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) + have hnum_cast : (q.num : Real) = (num : Real) := by + exact_mod_cast hnum_eq.symm + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] + have hq_eq : + (num : Real) / den = (num : Real) * den / (den : Real) ^ 2 := by + field_simp [hden_ne] + simpa [hnum_cast, den, hq_eq] using hq_rat + have hpow : + ((a + 1 : Real) / (den : Real)) ^ 2 = + (a + 1 : Real) ^ 2 / (den : Real) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : (q : Real) ≤ ((a + 1 : Real) / (den : Real)) ^ 
2 := by + simpa [hq_cast, hpow] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / (den : Real)) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (den : Real) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + simpa [sqrtUpperAlt, num, den, a] using hle + +/-- Square-root lower bound in reals (tighter of two bounds). -/ +theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by + have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq + have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq + simpa [sqrtLower] using (max_le_iff).2 ⟨hbase, halt⟩ + +/-- Square-root upper bound in reals (tighter of two bounds). -/ +theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by + have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq + have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq + simpa [sqrtUpper] using (le_min_iff).2 ⟨hbase, halt⟩ /-- Bounds for multiplying a scalar by a bounded value. 
-/ def scaleInterval (x lo hi : Rat) : Rat × Rat := diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index 3d9eeab..7d66d41 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -39,6 +39,8 @@ def buildInductionHeadInputs {seq dModel dHead vocab : Nat} bv := slice.bv wo := slice.wo attnBias := slice.attnBias + maskCausal := true + maskValue := (-10000 : Rat) directionSpec := slice.direction.spec direction := slice.directionVec } @@ -61,6 +63,8 @@ theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} bv := slice.bv wo := slice.wo attnBias := slice.attnBias + maskCausal := true + maskValue := (-10000 : Rat) directionSpec := slice.direction.spec direction := slice.directionVec } := rfl diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index 08492f5..77c91b2 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -13,6 +13,7 @@ import Nfp.Circuit.Layers.Softmax import Nfp.Model.InductionHead import Nfp.Sound.Bounds.LayerNorm import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Induction.OneHot import Nfp.Sound.Linear.FinFold /-! @@ -65,11 +66,19 @@ private noncomputable def vRealOfInputs {seq dModel dHead : Nat} dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + (inputs.bv d : Real) /-- Real-valued attention scores for head inputs. -/ -private noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} +noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin seq → Real := fun q k => - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + if inputs.maskCausal then + if k ≤ q then + base + else + (inputs.maskValue : Real) + else + base /-- Real-valued per-key head outputs in model space. 
-/ private noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} @@ -78,7 +87,7 @@ private noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) /-- Real-valued direction scores for head inputs. -/ -private noncomputable def valsRealOfInputs {seq dModel dHead : Nat} +noncomputable def valsRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Real := let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d fun k => dotProduct dirHead (fun d => vRealOfInputs inputs k d) @@ -113,6 +122,8 @@ structure ValueIntervalBounds {seq : Nat} (vals : Fin seq → Real) structure InductionHeadCert (seq : Nat) where /-- Weight tolerance. -/ eps : Rat + /-- Per-query weight tolerance derived from local margins. -/ + epsAt : Fin seq → Rat /-- Score margin used to justify the weight tolerance. -/ margin : Rat /-- Active queries for which bounds are required. -/ @@ -131,6 +142,12 @@ structure InductionHeadCertSound [NeZero seq] {dModel dHead : Nat} (fun q => q ∈ c.active) c.prev (scoresRealOfInputs inputs) (fun q k => Circuit.softmax (scoresRealOfInputs inputs q) k) + /-- Per-query one-hot bounds derived from local margins. -/ + oneHot_bounds_at : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (c.epsAt q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) /-- Interval bounds hold for the direction values. -/ value_bounds : ValueIntervalBounds (vals := valsRealOfInputs inputs) c.values @@ -233,14 +250,18 @@ def buildInductionCertFromHead? 
[NeZero seq] {dModel dHead : Nat} max |qLo q d| |qHi q d| let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal ∧ q < k let dotAbs : Fin seq → Fin seq → Rat := fun q k => dotProduct (fun d => qAbs q d) (fun d => kAbs k d) - let scoreAbs : Fin seq → Fin seq → Rat := fun q k => + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k + let scoreAbs : Fin seq → Fin seq → Rat := fun q k => + if masked q k then |inputs.maskValue| else scoreBaseAbs q k let scoreLo : Fin seq → Fin seq → Rat := fun q k => - -scoreAbs q k + if masked q k then inputs.maskValue else -scoreBaseAbs q k let scoreHi : Fin seq → Fin seq → Rat := fun q k => - scoreAbs q k + if masked q k then inputs.maskValue else scoreBaseAbs q k let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) let marginAt : Fin seq → Rat := fun q => @@ -249,6 +270,11 @@ def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} other.inf' h (fun k => scoreLo q (inputs.prev q) - scoreHi q k) else (0 : Rat) + let epsAt : Fin seq → Rat := fun q => + if marginAt q < 0 then + (1 : Rat) + else + (seq - 1 : Rat) / (1 + marginAt q) let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt @@ -279,6 +305,7 @@ def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} direction := some inputs.directionSpec } let cert : InductionHeadCert seq := { eps := eps + epsAt := epsAt margin := margin active := inputs.active prev := inputs.prev @@ -351,6 +378,9 @@ def buildInductionCertFromHead? 
[NeZero seq] {dModel dHead : Nat} scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by intro q k let scoresReal := scoresRealOfInputs inputs + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) have hq_abs : ∀ d, |qRealOfInputs inputs q d| ≤ (qAbs q d : Real) := by intro d have hq := hq_bounds q d @@ -402,26 +432,55 @@ def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} simpa [dotProduct] using hfinal have hscale_abs : 0 ≤ (|inputs.scale| : Real) := by exact_mod_cast (abs_nonneg (inputs.scale)) - have hscore_abs : - |scoresReal q k| ≤ (scoreAbs q k : Real) := by + have hbase_abs : + |base| ≤ (scoreBaseAbs q k : Real) := by have hdot_abs' := hdot_abs have hmul : - |scoresReal q k| = + |base| = (|inputs.scale| : Real) * |dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d)| := by - simp [scoresReal, scoresRealOfInputs, abs_mul] + simp [base, abs_mul] have hmul_le : (|inputs.scale| : Real) * |dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d)| ≤ (|inputs.scale| : Real) * (dotAbs q k : Real) := by exact mul_le_mul_of_nonneg_left hdot_abs' hscale_abs - simpa [scoreAbs, hmul] using hmul_le - have hscore_bounds := (abs_le).1 hscore_abs - constructor - · simpa [scoresReal, scoreLo] using hscore_bounds.1 - · simpa [scoresReal, scoreHi] using hscore_bounds.2 + simpa [scoreBaseAbs, hmul] using hmul_le + by_cases hcausal : inputs.maskCausal + · by_cases hle : k ≤ q + · have hnot : ¬ q < k := not_lt_of_ge hle + have hscore_eq : scoresReal q k = base := by + simp [scoresReal, scoresRealOfInputs, hcausal, hle, base] + have hscore_abs' : |scoresReal q k| ≤ (scoreBaseAbs q k : Real) := by + simpa [hscore_eq] using hbase_abs + have hscore_abs : + |scoresReal q k| ≤ (scoreAbs q k : Real) := by + simpa [scoreAbs, masked, hcausal, hnot] using hscore_abs' + have hscore_bounds := (abs_le).1 hscore_abs + constructor + · simpa 
[scoresReal, scoreLo, scoreAbs, masked, hcausal, hnot] + using hscore_bounds.1 + · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal, hnot] + using hscore_bounds.2 + · have hlt : q < k := lt_of_not_ge hle + constructor + · simp [scoresRealOfInputs, scoreLo, masked, hcausal, hle, hlt] + · simp [scoresRealOfInputs, scoreHi, masked, hcausal, hle, hlt] + · have hscore_eq : scoresReal q k = base := by + simp [scoresReal, scoresRealOfInputs, hcausal, base] + have hscore_abs' : |scoresReal q k| ≤ (scoreBaseAbs q k : Real) := by + simpa [hscore_eq] using hbase_abs + have hscore_abs : + |scoresReal q k| ≤ (scoreAbs q k : Real) := by + simpa [scoreAbs, masked, hcausal] using hscore_abs' + have hscore_bounds := (abs_le).1 hscore_abs + constructor + · simpa [scoresReal, scoreLo, scoreAbs, masked, hcausal] + using hscore_bounds.1 + · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal] + using hscore_bounds.2 let scoresReal := scoresRealOfInputs inputs let weights : Fin seq → Fin seq → Real := fun q k => Circuit.softmax (scoresReal q) k @@ -654,6 +713,89 @@ def buildInductionCertFromHead? 
[NeZero seq] {dModel dHead : Nat} have h := Finset.single_le_sum hnonneg hk' simpa using h exact hle.trans hsum_others_le + have hscore_margin_real_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + let other := otherKeys q + have hother : other.Nonempty := by + refine ⟨k, ?_⟩ + simp [other, otherKeys, hk] + have hgap_le : + marginAt q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by + have hkmem : k ∈ other := by + simp [other, otherKeys, hk] + have hle : + other.inf' hother (fun k => scoreLo q (inputs.prev q) - scoreHi q k) ≤ + scoreLo q (inputs.prev q) - scoreHi q k := by + exact (Finset.inf'_le (s := other) (f := fun k => + scoreLo q (inputs.prev q) - scoreHi q k) (b := k) hkmem) + have hmarginAt : + marginAt q = + other.inf' hother (fun k => scoreLo q (inputs.prev q) - scoreHi q k) := by + simp [marginAt, hother, other] + simpa [hmarginAt] using hle + have hgap_real : + (marginAt q : Real) ≤ + (scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real) := by + have hgap_real' : + (marginAt q : Real) ≤ + ((scoreLo q (inputs.prev q) - scoreHi q k : Rat) : Real) := + (Rat.cast_le (K := Real)).2 hgap_le + simpa [Rat.cast_sub] using hgap_real' + have hk_bounds := hscore_bounds q k + have hprev_bounds := hscore_bounds q (inputs.prev q) + have h1 : + scoresReal q k + (marginAt q : Real) ≤ + (scoreHi q k : Real) + (marginAt q : Real) := by + have h1' := add_le_add_right hk_bounds.2 (marginAt q : Real) + simpa [scoresReal] using h1' + have h2 : + (scoreHi q k : Real) + (marginAt q : Real) ≤ + (scoreLo q (inputs.prev q) : Real) := by + have hgap_real' : + (scoreHi q k : Real) + (marginAt q : Real) ≤ + (scoreHi q k : Real) + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by + exact add_le_add_right hgap_real (scoreHi q k : Real) + have hgap_real'' : + (scoreHi q k : Real) + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) = + (scoreLo q 
(inputs.prev q) : Real) := by + calc + (scoreHi q k : Real) + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) = + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) + + (scoreHi q k : Real) := by + exact add_comm _ _ + _ = (scoreLo q (inputs.prev q) : Real) := by + exact sub_add_cancel (scoreLo q (inputs.prev q) : Real) (scoreHi q k : Real) + exact hgap_real'.trans (le_of_eq hgap_real'') + have h3 : + scoresReal q k + (marginAt q : Real) ≤ + (scoreLo q (inputs.prev q) : Real) := h1.trans h2 + exact h3.trans hprev_bounds.1 + have hepsAt : + ∀ q, epsAt q = + if marginAt q < 0 then (1 : Rat) else (seq - 1 : Rat) / (1 + marginAt q) := by + intro q + simp [epsAt] + have oneHot_bounds_at : + ∀ q, q ∈ inputs.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) inputs.prev weights := by + intro q hq + exact + Sound.oneHot_bounds_at_of_marginAt + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (marginAt := marginAt) + (epsAt := epsAt) + (hepsAt := hepsAt) + (hseq := hseq) + (hscore_margin_real_at := hscore_margin_real_at) + q hq have hvals_bounds : ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by refine @@ -744,7 +886,10 @@ def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} refine ⟨k, hmem, ?_⟩ exact le_rfl exact (Rat.cast_le (K := Real)).2 hhiRat - exact some ⟨cert, { softmax_bounds := hsoftmax_bounds, value_bounds := hvals_bounds }⟩ + exact some ⟨cert, + { softmax_bounds := hsoftmax_bounds + oneHot_bounds_at := oneHot_bounds_at + value_bounds := hvals_bounds }⟩ · exact none section HeadOutputInterval diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean new file mode 100644 index 0000000..bdbd8d5 --- /dev/null +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -0,0 +1,186 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Sound.Induction + +/-! 
+Logit-diff bounds derived from induction certificates. +-/ + +namespace Nfp + +namespace Sound + +open Nfp.Circuit + +section LogitDiffLowerBound + +variable {seq dModel dHead : Nat} [NeZero seq] + +/-- Real-valued logit-diff contribution for a query. -/ +noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) : Real := + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + dotProduct (weights q) (valsRealOfInputs inputs) + +/-- Lower bound computed from the per-key lower bounds in an induction certificate. -/ +def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := + Circuit.logitDiffLowerBoundAt c.active c.prev c.epsAt + c.values.lo c.values.hi c.values.valsLo + +theorem logitDiffLowerBoundFromCert_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) + {lb : Rat} (hbound : logitDiffLowerBoundFromCert c = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + have hweights : + Layers.OneHotApproxBoundsOnActive (Val := Real) (c.epsAt q : Real) + (fun q' => q' = q) c.prev weights := + hsound.oneHot_bounds_at q hq + have hvalsRange : + Layers.ValueRangeBounds (Val := Real) (c.values.lo : Real) (c.values.hi : Real) + (valsRealOfInputs inputs) := by + refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } + · exact (Rat.cast_le (K := Real)).2 hsound.value_bounds.lo_le_hi + · intro k + exact + le_trans (hsound.value_bounds.lo_le_valsLo k) + (hsound.value_bounds.vals_bounds k).1 + · intro k + exact + le_trans (hsound.value_bounds.vals_bounds k).2 + (hsound.value_bounds.valsHi_le_hi k) + have happrox := + 
Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange + (Val := Real) + (n := n) + (ε := (c.epsAt q : Real)) + (lo := (c.values.lo : Real)) + (hi := (c.values.hi : Real)) + (active := fun q' => q' = q) + (prev := c.prev) + (weights := weights) + (vals := valsRealOfInputs inputs) + hweights hvalsRange + have hboundRat : + lb ≤ c.values.valsLo (c.prev q) - + c.epsAt q * (c.values.hi - c.values.lo) := by + refine + Circuit.logitDiffLowerBoundAt_le + (active := c.active) + (prev := c.prev) + (epsAt := c.epsAt) + (lo := c.values.lo) + (hi := c.values.hi) + (vals := c.values.valsLo) + q hq lb ?_ + simpa [logitDiffLowerBoundFromCert] using hbound + have hboundReal : + (lb : Real) ≤ + (c.values.valsLo (c.prev q) : Real) - + (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) := by + have hboundReal' : + (lb : Real) ≤ + (c.values.valsLo (c.prev q) - c.epsAt q * (c.values.hi - c.values.lo) : Rat) := by + exact (Rat.cast_le (K := Real)).2 hboundRat + simpa [Rat.cast_sub, Rat.cast_mul] using hboundReal' + have hvalsLo : + (c.values.valsLo (c.prev q) : Real) ≤ + valsRealOfInputs inputs (c.prev q) := by + exact (hsound.value_bounds.vals_bounds (c.prev q)).1 + have hvalsLo' : + (c.values.valsLo (c.prev q) : Real) - + (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) ≤ + valsRealOfInputs inputs (c.prev q) - + (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) := by + exact + sub_le_sub_right hvalsLo + ((c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real))) + have hlow : + valsRealOfInputs inputs (c.prev q) - + (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) ≤ + dotProduct (weights q) (valsRealOfInputs inputs) := by + exact (sub_le_iff_le_add).2 (happrox q rfl).2 + have hle : + (lb : Real) ≤ dotProduct (weights q) (valsRealOfInputs inputs) := + le_trans hboundReal (le_trans hvalsLo' hlow) + simpa [headLogitDiff, weights] using hle + +/-- Certified logit-diff lower bound derived from exact head inputs. 
-/ +structure InductionLogitLowerBoundResult + (inputs : Model.InductionHeadInputs seq dModel dHead) where + /-- Induction certificate built from the head inputs. -/ + cert : InductionHeadCert seq + /-- Soundness proof for the induction certificate. -/ + sound : InductionHeadCertSound inputs cert + /-- Reported lower bound on logit diff. -/ + lb : Rat + /-- `lb` is computed from `logitDiffLowerBoundFromCert`. -/ + lb_def : logitDiffLowerBoundFromCert cert = some lb + /-- The lower bound is sound on active queries. -/ + lb_sound : ∀ q, q ∈ cert.active → (lb : Real) ≤ headLogitDiff inputs q + +/-- Nonvacuous logit-diff bound (strictly positive). -/ +structure InductionLogitLowerBoundNonvacuous + (inputs : Model.InductionHeadInputs seq dModel dHead) where + /-- Base logit-diff bound data. -/ + base : InductionLogitLowerBoundResult inputs + /-- The reported bound is strictly positive. -/ + lb_pos : 0 < base.lb + +/-- Build a logit-diff lower bound from exact head inputs. -/ +def buildInductionLogitLowerBoundFromHead? + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option (InductionLogitLowerBoundResult inputs) := by + classical + cases hcert : buildInductionCertFromHead? inputs with + | none => exact none + | some certWithProof => + rcases certWithProof with ⟨cert, hsound⟩ + cases hlb : logitDiffLowerBoundFromCert cert with + | none => exact none + | some lb => + refine some ?_ + refine + { cert := cert + sound := hsound + lb := lb + lb_def := hlb + lb_sound := ?_ } + intro q hq + exact + logitDiffLowerBoundFromCert_le + (inputs := inputs) + (c := cert) + (hsound := hsound) + (lb := lb) + (hbound := hlb) + (q := q) + hq + +/-- Build a strictly positive logit-diff lower bound from exact head inputs. -/ +def buildInductionLogitLowerBoundNonvacuous? + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option (InductionLogitLowerBoundNonvacuous inputs) := by + classical + cases hbase : buildInductionLogitLowerBoundFromHead? 
inputs with + | none => exact none + | some base => + by_cases hpos : 0 < base.lb + · exact some { base := base, lb_pos := hpos } + · exact none + +end LogitDiffLowerBound + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean new file mode 100644 index 0000000..fd2cee8 --- /dev/null +++ b/Nfp/Sound/Induction/OneHot.lean @@ -0,0 +1,216 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.Rat.Cast.Order +import Nfp.Circuit.Layers.Induction +import Nfp.Circuit.Layers.Softmax + +/-! +Per-query one-hot bounds derived from score margins. +-/ + +namespace Nfp + +namespace Sound + +open scoped BigOperators + +open Nfp.Circuit + +variable {seq : Nat} [NeZero seq] + +/-- One-hot bounds on a single active query, derived from a per-query margin. -/ +theorem oneHot_bounds_at_of_marginAt + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scoresReal : Fin seq → Fin seq → Real) + (marginAt : Fin seq → Rat) + (epsAt : Fin seq → Rat) + (hepsAt : + ∀ q, epsAt q = + if marginAt q < 0 then (1 : Rat) else (seq - 1 : Rat) / (1 + marginAt q)) + (hseq : (1 : Nat) ≤ seq) + (hscore_margin_real_at : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + scoresReal q k + (marginAt q : Real) ≤ scoresReal q (prev q)) : + ∀ q, q ∈ active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) prev + (fun q k => Circuit.softmax (scoresReal q) k) := by + classical + intro q hq + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + refine + { nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q' hq' k + subst q' + change 0 ≤ Circuit.softmax (scoresReal q) k + exact Circuit.softmax_nonneg 
(scores := scoresReal q) k + · intro q' hq' + subst q' + change (∑ k, Circuit.softmax (scoresReal q) k) = 1 + exact Circuit.softmax_sum_one (scores := scoresReal q) + · intro q' hq' + subst q' + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by + by_cases hneg : marginAt q < 0 + · have heps : (epsAt q : Real) = 1 := by + simp [hepsAt, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hnonneg : + ∀ k ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q k := by + intro k _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k hk _; exact hnonneg k hk) + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by + simpa [hsum_one] using hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (marginAt q : Real) := by + exact (Rat.cast_nonneg (K := Real)).2 hnonneg + have hbound : + ∀ k ∈ others q, + weights q k ≤ (1 + (marginAt q : Real))⁻¹ := by + intro k hk + have hkne : k ≠ prev q := (Finset.mem_erase.mp hk).1 + have hscore := hscore_margin_real_at q hq k hkne + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := prev q) (k := k) (m := (marginAt q : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (1 + (marginAt q : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ k ∈ others q, (1 + (marginAt q : Real))⁻¹) = + (others q).card * (1 + (marginAt q : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ k ∈ others q, weights q k) ≤ + (seq 
- 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have heps : + (epsAt q : Real) = (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by + simp [hepsAt, hneg, Rat.cast_add, div_eq_mul_inv] + simpa [heps] using hsum_le' + have hsum_eq : + weights q (prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + calc + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (prev q) + (epsAt q : Real) := by + have hsum_le'' := add_le_add_left hsum_others_le (weights q (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' + have hprev : + 1 ≤ weights q (prev q) + (epsAt q : Real) := by + simpa [hsum_eq] using hsum_le' + exact hprev + · intro q' hq' k hk + subst q' + have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (epsAt q : Real) := by + by_cases hneg : marginAt q < 0 + · have heps : (epsAt q : Real) = 1 := by + simp [hepsAt, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro j hj + simp + have hnonneg : + ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q j := by + intro j _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) j) + have hsum_le : + (∑ j ∈ others q, weights q j) ≤ + ∑ j ∈ (Finset.univ : Finset (Fin seq)), weights q j := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro j hj _; exact hnonneg j hj) + have hsum_one : (∑ j, weights q j) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_le' : (∑ j ∈ others q, weights q j) 
≤ 1 := by + simpa [hsum_one] using hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (marginAt q : Real) := by + exact (Rat.cast_nonneg (K := Real)).2 hnonneg + have hbound : + ∀ j ∈ others q, + weights q j ≤ (1 + (marginAt q : Real))⁻¹ := by + intro j hj + have hjne : j ≠ prev q := (Finset.mem_erase.mp hj).1 + have hscore := hscore_margin_real_at q hq j hjne + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := prev q) (k := j) (m := (marginAt q : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ j ∈ others q, weights q j) ≤ + ∑ j ∈ others q, (1 + (marginAt q : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ j ∈ others q, (1 + (marginAt q : Real))⁻¹) = + (others q).card * (1 + (marginAt q : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ j ∈ others q, weights q j) ≤ + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have heps : + (epsAt q : Real) = (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by + simp [hepsAt, hneg, Rat.cast_add, div_eq_mul_inv] + simpa [heps] using hsum_le' + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ j ∈ others q, 0 ≤ weights q j := by + intro j _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) j) + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + +end Sound + +end Nfp diff --git a/scripts/build_gpt2_head_inputs.py b/scripts/build_gpt2_head_inputs.py index 39f9878..4da754f 100644 --- a/scripts/build_gpt2_head_inputs.py +++ b/scripts/build_gpt2_head_inputs.py @@ -200,6 +200,8 @@ def 
write_head_inputs( ln_eps: Fraction, ln1_gamma: np.ndarray, ln1_beta: np.ndarray, + mask_causal: bool, + mask_value: Fraction, direction_target: int, direction_negative: int, direction: np.ndarray, @@ -244,6 +246,8 @@ def write_head_inputs( f.write(f"wo {i} {j} {rat_to_str(rat_from_float_exact(float(wo[i, j])))}\n") for d in range(model_dim): f.write(f"attn_bias {d} {rat_to_str(rat_from_float_exact(float(attn_bias[d])))}\n") + f.write(f"mask {'causal' if mask_causal else 'none'}\n") + f.write(f"mask_value {rat_to_str(mask_value)}\n") f.write(f"direction-target {direction_target}\n") f.write(f"direction-negative {direction_negative}\n") for d in range(model_dim): @@ -321,6 +325,9 @@ def main() -> None: # Stored W_O is (head_dim, model_dim); transpose to model_dim × head_dim. wo = wo_raw.T + mask_causal = True + mask_value = Fraction(-10000, 1) + args.output.parent.mkdir(parents=True, exist_ok=True) ln_eps_raw = header.get("layer_norm_eps") if ln_eps_raw is None: @@ -344,6 +351,8 @@ def main() -> None: ln_eps, ln1_gamma, ln1_beta, + mask_causal, + mask_value, args.direction_target, args.direction_negative, direction, diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index 05b38d5..80be5f1 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -94,7 +94,7 @@ def compute_scores_weights(model, input_ids, layer: int, head: int, device: str) scores = torch.matmul(qh, kh.transpose(-2, -1)) / math.sqrt(head_dim) seq = scores.shape[-1] mask = torch.triu(torch.ones(seq, seq, device=device), diagonal=1).bool() - scores = scores.masked_fill(mask, -1e9) + scores = scores.masked_fill(mask, -10000.0) weights = torch.softmax(scores, dim=-1) return (scores.squeeze(0).cpu().numpy(), weights.squeeze(0).cpu().numpy(), diff --git a/scripts/build_gpt2_induction_cert_from_binary.py b/scripts/build_gpt2_induction_cert_from_binary.py new file mode 100644 index 0000000..8229320 --- /dev/null +++ 
b/scripts/build_gpt2_induction_cert_from_binary.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Build a softmax-margin certificate and value-range certificate from an NFP_BINARY_V1 model. + +This is untrusted and uses floating-point arithmetic to produce rational certificates +compatible with `nfp induction certify`. +""" + +from __future__ import annotations + +import argparse +import math +import struct +from fractions import Fraction +from pathlib import Path +from typing import Dict, Tuple + +import numpy as np + + +def rat_from_float(x: float, decimals: int) -> Fraction: + scale = 10 ** decimals + return Fraction(int(round(x * scale)), scale) + + +def rat_to_str(q: Fraction) -> str: + if q.denominator == 1: + return str(q.numerator) + return f"{q.numerator}/{q.denominator}" + + +def parse_header(f) -> Dict[str, str]: + header: Dict[str, str] = {} + magic = f.readline().decode("ascii").strip() + if magic != "NFP_BINARY_V1": + raise SystemExit(f"Unsupported magic header: {magic}") + while True: + line = f.readline() + if line == b"": + raise SystemExit("Unexpected EOF while reading header.") + text = line.decode("ascii").strip() + if text == "BINARY_START": + break + if "=" in text: + key, value = text.split("=", 1) + header[key.strip()] = value.strip() + return header + + +def read_i32(f, count: int) -> np.ndarray: + raw = f.read(count * 4) + if len(raw) != count * 4: + raise SystemExit("Unexpected EOF while reading int32 payload.") + return np.frombuffer(raw, dtype=" np.ndarray: + raw = f.read(count * 8) + if len(raw) != count * 8: + raise SystemExit("Unexpected EOF while reading float64 payload.") + return np.frombuffer(raw, dtype=" None: + offset = count * 8 + f.seek(offset, 1) + + +def build_prev(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + prev = np.zeros_like(tokens) + active = np.zeros_like(tokens, dtype=bool) + last_seen: Dict[int, int] = {} + for idx, tok in enumerate(tokens.tolist()): + if 
idx == 0: + prev[idx] = 0 + active[idx] = False + else: + if tok in last_seen: + prev[idx] = last_seen[tok] + active[idx] = True + else: + prev[idx] = 0 + active[idx] = False + last_seen[tok] = idx + return prev, active + + +def read_head_weights( + f, + num_layers: int, + num_heads: int, + model_dim: int, + head_dim: int, + hidden_dim: int, + layer: int, + head: int, +) -> Tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +]: + target = (layer, head) + wq = wk = wv = wo = None + bq = bk = bv = None + attn_bias = ln1_gamma = ln1_beta = None + for layer_idx in range(num_layers): + for head_idx in range(num_heads): + wq_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) + bq_block = read_f64(f, head_dim) + wk_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) + bk_block = read_f64(f, head_dim) + wv_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) + bv_block = read_f64(f, head_dim) + wo_block = read_f64(f, head_dim * model_dim).reshape(head_dim, model_dim) + if (layer_idx, head_idx) == target: + wq = wq_block + wk = wk_block + wv = wv_block + wo = wo_block + bq = bq_block + bk = bk_block + bv = bv_block + attn_bias_block = read_f64(f, model_dim) + skip_f64(f, model_dim * hidden_dim) + skip_f64(f, hidden_dim) + skip_f64(f, hidden_dim * model_dim) + skip_f64(f, model_dim) + ln1_gamma_block = read_f64(f, model_dim) + ln1_beta_block = read_f64(f, model_dim) + skip_f64(f, model_dim) + skip_f64(f, model_dim) + if layer_idx == layer: + attn_bias = attn_bias_block + ln1_gamma = ln1_gamma_block + ln1_beta = ln1_beta_block + if ( + wq is None + or wk is None + or wv is None + or wo is None + or bq is None + or bk is None + or bv is None + or attn_bias is None + or ln1_gamma is None + or ln1_beta is None + ): + raise SystemExit("Failed to locate head weights.") + return wq, bq, wk, bk, wv, bv, wo, attn_bias, ln1_gamma, 
ln1_beta + + +def read_unembed_columns( + f, + start: int, + model_dim: int, + vocab_size: int, + target: int, + negative: int, +) -> Tuple[np.ndarray, np.ndarray]: + row_bytes = vocab_size * 8 + col_t = np.zeros(model_dim, dtype=np.float64) + col_n = np.zeros(model_dim, dtype=np.float64) + for row in range(model_dim): + base = start + row * row_bytes + f.seek(base + target * 8) + col_t[row] = struct.unpack(" np.ndarray: + mean = x.mean(axis=1, keepdims=True) + var = ((x - mean) ** 2).mean(axis=1, keepdims=True) + x_hat = (x - mean) / np.sqrt(var + eps) + return x_hat * gamma + beta + + +def softmax(scores: np.ndarray) -> np.ndarray: + shift = scores - scores.max(axis=1, keepdims=True) + exp = np.exp(shift) + return exp / exp.sum(axis=1, keepdims=True) + + +def write_softmax_cert(path: Path, seq: int, prev: np.ndarray, + scores_rat, weights_rat, eps: Fraction, + margin: Fraction, active_positions) -> None: + with path.open("w", encoding="ascii") as f: + f.write(f"seq {seq}\n") + f.write(f"eps {rat_to_str(eps)}\n") + f.write(f"margin {rat_to_str(margin)}\n") + for q in active_positions: + f.write(f"active {q}\n") + for q, k in enumerate(prev.tolist()): + f.write(f"prev {q} {k}\n") + for q in range(seq): + for k in range(seq): + f.write(f"score {q} {k} {rat_to_str(scores_rat[q][k])}\n") + for q in range(seq): + for k in range(seq): + f.write(f"weight {q} {k} {rat_to_str(weights_rat[q][k])}\n") + + +def write_value_range(path: Path, seq: int, values, decimals: int, + direction_target=None, direction_negative=None) -> None: + vals_rat = [rat_from_float(float(values[k]), decimals) for k in range(seq)] + lo = min(vals_rat) + hi = max(vals_rat) + with path.open("w", encoding="ascii") as f: + f.write(f"seq {seq}\n") + if direction_target is not None and direction_negative is not None: + f.write(f"direction-target {direction_target}\n") + f.write(f"direction-negative {direction_negative}\n") + f.write(f"lo {rat_to_str(lo)}\n") + f.write(f"hi {rat_to_str(hi)}\n") + for k, 
val in enumerate(vals_rat): + f.write(f"val {k} {rat_to_str(val)}\n") + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--model", type=Path, required=True, help="Path to NFP_BINARY_V1 model") + ap.add_argument("--layer", type=int, required=True, help="Layer index") + ap.add_argument("--head", type=int, required=True, help="Head index") + ap.add_argument("--output", type=Path, required=True, help="Path for softmax certificate") + ap.add_argument("--values-out", type=Path, required=True, + help="Path for value-range certificate") + ap.add_argument("--direction-target", type=int, required=True, + help="Target token id for logit-diff direction") + ap.add_argument("--direction-negative", type=int, required=True, + help="Negative token id for logit-diff direction") + ap.add_argument("--decimals", type=int, default=6, + help="Decimal rounding for rationals") + ap.add_argument("--active-eps-max", default="1/2", + help="Maximum eps to include an active position") + args = ap.parse_args() + + if not args.model.exists(): + raise SystemExit(f"Missing model file: {args.model}") + + with args.model.open("rb") as f: + header = parse_header(f) + num_layers = int(header["num_layers"]) + num_heads = int(header["num_heads"]) + model_dim = int(header["model_dim"]) + head_dim = int(header["head_dim"]) + vocab_size = int(header["vocab_size"]) + seq_len = int(header["seq_len"]) + hidden_dim = int(header["hidden_dim"]) + ln_eps = float(header.get("layer_norm_eps", header.get("eps", "0"))) + + if args.layer < 0 or args.layer >= num_layers: + raise SystemExit("layer index out of range") + if args.head < 0 or args.head >= num_heads: + raise SystemExit("head index out of range") + if not (0 <= args.direction_target < vocab_size): + raise SystemExit("direction-target out of vocab range") + if not (0 <= args.direction_negative < vocab_size): + raise SystemExit("direction-negative out of vocab range") + + tokens = read_i32(f, seq_len) + embeddings = 
read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) + + wq, bq, wk, bk, wv, bv, wo_raw, _attn_bias, ln1_gamma, ln1_beta = read_head_weights( + f, + num_layers, + num_heads, + model_dim, + head_dim, + hidden_dim, + args.layer, + args.head, + ) + + skip_f64(f, model_dim) + skip_f64(f, model_dim) + + unembed_start = f.tell() + col_target, col_negative = read_unembed_columns( + f, + unembed_start, + model_dim, + vocab_size, + args.direction_target, + args.direction_negative, + ) + + prev, active_mask = build_prev(tokens) + candidate_positions = [int(i) for i, flag in enumerate(active_mask) if flag] + active_eps_max = Fraction(args.active_eps_max) + + scale_denom = int(math.isqrt(head_dim)) + if scale_denom * scale_denom != head_dim: + scale = 1.0 / math.sqrt(head_dim) + else: + scale = 1.0 / scale_denom + + ln = layer_norm(embeddings, ln1_gamma, ln1_beta, ln_eps) + q = ln @ wq + bq + k = ln @ wk + bk + v = ln @ wv + bv + + scores = scale * (q @ k.T) + mask_value = -10000.0 + mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1) + scores = scores.copy() + scores[mask] = mask_value + weights = softmax(scores) + + scores_rat = [[rat_from_float(float(scores[q, k]), args.decimals) + for k in range(seq_len)] for q in range(seq_len)] + weights_rat = [[rat_from_float(float(weights[q, k]), args.decimals) + for k in range(seq_len)] for q in range(seq_len)] + + for q in range(seq_len): + total = sum(weights_rat[q], Fraction(0)) + if total == 0: + raise SystemExit(f"zero weight sum at q={q}") + weights_rat[q] = [w / total for w in weights_rat[q]] + + eps_by_q: dict[int, Fraction] = {} + margin_by_q: dict[int, Fraction] = {} + for q in candidate_positions: + prev_q = prev[q] + prev_w = weights_rat[q][prev_q] + max_other = max(weights_rat[q][k] for k in range(seq_len) if k != prev_q) + deficit = Fraction(1) - prev_w + eps_by_q[q] = max(max_other, deficit) + + diffs = [scores_rat[q][prev_q] - scores_rat[q][k] + for k in range(seq_len) if k != prev_q] + if diffs: + 
margin_by_q[q] = min(diffs) + + active_positions = [q for q in candidate_positions if eps_by_q[q] <= active_eps_max] + if not active_positions and candidate_positions: + print("Warning: no active positions satisfy active-eps-max; certificate may be vacuous.") + + if active_positions: + eps = max(eps_by_q[q] for q in active_positions) + margin = min((margin_by_q[q] for q in active_positions), default=Fraction(0)) + else: + eps = Fraction(0) + margin = Fraction(0) + + output_path = args.output + output_path.parent.mkdir(parents=True, exist_ok=True) + write_softmax_cert(output_path, seq_len, prev, scores_rat, weights_rat, eps, margin, + active_positions) + + wo = wo_raw.T + direction = col_target - col_negative + dir_head = wo.T @ direction + dir_vals = v @ dir_head + values_path = args.values_out + values_path.parent.mkdir(parents=True, exist_ok=True) + write_value_range(values_path, seq_len, dir_vals, args.decimals, + direction_target=args.direction_target, + direction_negative=args.direction_negative) + + print(f"Wrote softmax certificate to {output_path}") + print(f"Wrote value-range certificate to {values_path}") + if candidate_positions: + print(f"Active positions: {len(active_positions)}/{len(candidate_positions)}") + + +if __name__ == "__main__": + main() diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py new file mode 100644 index 0000000..b4cce0f --- /dev/null +++ b/scripts/discover_gpt2_induction_targets.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Discover promising GPT-2 induction heads and logit-diff directions from an NFP binary. + +This script is untrusted: it uses floating-point arithmetic to score candidates +and optionally invokes the Lean verifier (`nfp induction certify_head_model_nonvacuous`) +to confirm nonvacuous bounds. 
+""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + +import numpy as np + + +@dataclass(frozen=True) +class HeadResult: + layer: int + head: int + target: int + negative: int + logit_lb: float + eps: float + margin: float + min_prev: float + value_range: float + active: int + + +def parse_header(f) -> Dict[str, str]: + header: Dict[str, str] = {} + magic = f.readline().decode("ascii").strip() + if magic != "NFP_BINARY_V1": + raise SystemExit(f"Unsupported magic header: {magic}") + while True: + line = f.readline() + if line == b"": + raise SystemExit("Unexpected EOF while reading header.") + text = line.decode("ascii").strip() + if text == "BINARY_START": + break + if "=" in text: + key, value = text.split("=", 1) + header[key.strip()] = value.strip() + return header + + +def read_i32(f, count: int) -> np.ndarray: + raw = f.read(count * 4) + if len(raw) != count * 4: + raise SystemExit("Unexpected EOF while reading int32 payload.") + return np.frombuffer(raw, dtype=" np.ndarray: + raw = f.read(count * 8) + if len(raw) != count * 8: + raise SystemExit("Unexpected EOF while reading float64 payload.") + return np.frombuffer(raw, dtype=" None: + offset = count * 8 + f.seek(offset, 1) + + +def build_prev(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + prev = np.zeros_like(tokens) + active = np.zeros_like(tokens, dtype=bool) + last_seen: Dict[int, int] = {} + for idx, tok in enumerate(tokens.tolist()): + if idx == 0: + prev[idx] = 0 + active[idx] = False + else: + if tok in last_seen: + prev[idx] = last_seen[tok] + active[idx] = True + else: + prev[idx] = 0 + active[idx] = False + last_seen[tok] = idx + return prev, active + + +def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float) -> np.ndarray: + mean = x.mean(axis=1, keepdims=True) + var = ((x - mean) ** 
2).mean(axis=1, keepdims=True) + x_hat = (x - mean) / np.sqrt(var + eps) + return x_hat * gamma + beta + + +def softmax(scores: np.ndarray) -> np.ndarray: + shift = scores - scores.max(axis=1, keepdims=True) + exp = np.exp(shift) + return exp / exp.sum(axis=1, keepdims=True) + + +def parse_index_list(raw: str | None, max_value: int) -> List[int] | None: + if raw is None: + return None + raw = raw.strip() + if raw.lower() == "all": + return list(range(max_value)) + out: List[int] = [] + for part in raw.split(","): + part = part.strip() + if not part: + continue + idx = int(part) + if idx < 0 or idx >= max_value: + raise ValueError(f"index {idx} out of range [0,{max_value})") + out.append(idx) + return out + + +def resolve_nfp_cmd(nfp_bin: str | None) -> List[str]: + if nfp_bin: + return [nfp_bin] + env_bin = os.environ.get("NFP_BIN") + if env_bin: + return [env_bin] + local_bin = Path(".lake/build/bin/nfp") + if local_bin.exists(): + return [str(local_bin)] + return ["lake", "exe", "nfp"] + + +def read_unembed_column( + f, + start: int, + model_dim: int, + vocab_size: int, + col: int, +) -> np.ndarray: + if col < 0 or col >= vocab_size: + raise ValueError(f"column {col} out of range") + row_bytes = vocab_size * 8 + data = np.zeros(model_dim, dtype=np.float64) + for row in range(model_dim): + base = start + row * row_bytes + f.seek(base + col * 8) + data[row] = np.frombuffer(f.read(8), dtype=" Tuple[float, float]: + eps_vals: List[float] = [] + margin_vals: List[float] = [] + seq = weights.shape[0] + for q in active_positions: + prev_q = int(prev[q]) + prev_w = weights[q, prev_q] + max_other = np.max(np.delete(weights[q], prev_q)) + eps_vals.append(max(max_other, 1.0 - prev_w)) + diffs = scores[q, prev_q] - np.delete(scores[q], prev_q) + margin_vals.append(float(np.min(diffs)) if diffs.size > 0 else 0.0) + if not eps_vals: + return 0.0, 0.0 + return max(eps_vals), min(margin_vals) + + +def format_result(result: HeadResult) -> str: + return ( + 
f"L{result.layer}H{result.head} target={result.target} " + f"negative={result.negative} logitLB={result.logit_lb:.6f} " + f"eps={result.eps:.6f} margin={result.margin:.6f} " + f"minPrev={result.min_prev:.6f} range={result.value_range:.6f} " + f"active={result.active}" + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", required=True, type=Path, help="Path to NFP_BINARY_V1 model") + parser.add_argument("--max-tokens", type=int, default=32, + help="Maximum unique tokens from the prompt to consider") + parser.add_argument("--top", type=int, default=20, help="Number of results to report") + parser.add_argument("--verify-top", type=int, default=0, + help="Run verifier on the top N candidates") + parser.add_argument("--min-eps", type=float, default=0.5, + help="Filter candidates with eps above this value") + parser.add_argument("--min-margin", type=float, default=0.0, + help="Filter candidates with margin below this value") + parser.add_argument("--min-logit-lb", type=float, default=0.0, + help="Filter candidates with logit lower bound below this value") + parser.add_argument("--layers", help="Comma-separated layer list or 'all'") + parser.add_argument("--heads", help="Comma-separated head list or 'all'") + parser.add_argument("--period", type=int, help="Optional prompt period override") + parser.add_argument("--output", type=Path, default=Path("reports/gpt2_induction_discover.txt")) + parser.add_argument("--json-out", type=Path, help="Optional JSON output path") + parser.add_argument("--nfp-bin", help="Path to nfp binary (defaults to .lake/build/bin/nfp)") + args = parser.parse_args() + + if args.max_tokens <= 1: + raise SystemExit("max-tokens must be at least 2") + + if not args.model.exists(): + raise SystemExit(f"Missing model file: {args.model}") + + with args.model.open("rb") as f: + header = parse_header(f) + num_layers = int(header["num_layers"]) + num_heads = int(header["num_heads"]) + model_dim = 
int(header["model_dim"]) + head_dim = int(header["head_dim"]) + vocab_size = int(header["vocab_size"]) + seq_len = int(header["seq_len"]) + hidden_dim = int(header["hidden_dim"]) + ln_eps = float(header.get("layer_norm_eps", header.get("eps", "0"))) + + layers = parse_index_list(args.layers, num_layers) or list(range(num_layers)) + heads = parse_index_list(args.heads, num_heads) or list(range(num_heads)) + + tokens = read_i32(f, seq_len) + embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) + + if args.period is not None: + period = int(args.period) + prev = np.arange(seq_len, dtype=np.int64) + prev = np.where(prev >= period, prev - period, 0) + active_mask = np.arange(seq_len) >= period + else: + prev, active_mask = build_prev(tokens) + + active_positions = [int(i) for i, flag in enumerate(active_mask) if flag] + if not active_positions: + raise SystemExit("No active positions found in the prompt") + + unique_tokens = [] + seen = set() + for tok in tokens.tolist(): + if tok not in seen: + seen.add(tok) + unique_tokens.append(int(tok)) + if len(unique_tokens) >= args.max_tokens: + break + if len(unique_tokens) < 2: + raise SystemExit("Need at least two unique tokens to form directions") + + head_data: Dict[Tuple[int, int], Tuple[np.ndarray, np.ndarray, float, float]] = {} + + for layer_idx in range(num_layers): + head_weights = [] + for _ in range(num_heads): + wq = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) + bq = read_f64(f, head_dim) + wk = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) + bk = read_f64(f, head_dim) + wv = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) + bv = read_f64(f, head_dim) + wo = read_f64(f, head_dim * model_dim).reshape(head_dim, model_dim) + head_weights.append((wq, bq, wk, bk, wv, bv, wo)) + _attn_bias = read_f64(f, model_dim) + skip_f64(f, model_dim * hidden_dim) + skip_f64(f, hidden_dim) + skip_f64(f, hidden_dim * model_dim) + skip_f64(f, model_dim) + ln1_gamma 
= read_f64(f, model_dim) + ln1_beta = read_f64(f, model_dim) + skip_f64(f, model_dim) + skip_f64(f, model_dim) + + if layer_idx not in layers: + continue + + ln = layer_norm(embeddings, ln1_gamma, ln1_beta, ln_eps) + scale = 1.0 / np.sqrt(head_dim) + for head_idx in heads: + wq, bq, wk, bk, wv, bv, wo = head_weights[head_idx] + q = ln @ wq + bq + k = ln @ wk + bk + v = ln @ wv + bv + + scores = scale * (q @ k.T) + mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1) + scores = scores.copy() + scores[mask] = -10000.0 + weights = softmax(scores) + + eps, margin = compute_eps_margin(weights, scores, prev, active_positions) + head_data[(layer_idx, head_idx)] = (v, wo, eps, margin) + + ln_f_gamma = read_f64(f, model_dim) + _ln_f_beta = read_f64(f, model_dim) + _ = ln_f_gamma + unembed_start = f.tell() + + columns: Dict[int, np.ndarray] = {} + for tok in unique_tokens: + columns[tok] = read_unembed_column( + f, + unembed_start, + model_dim, + vocab_size, + tok, + ) + + results: List[HeadResult] = [] + prev_indices = prev[np.array(active_positions, dtype=np.int64)] + for (layer_idx, head_idx), (v, wo, eps, margin) in head_data.items(): + if eps > args.min_eps or margin < args.min_margin: + continue + proj: Dict[int, np.ndarray] = {} + for tok in unique_tokens: + dir_head = wo @ columns[tok] + proj[tok] = v @ dir_head + best: HeadResult | None = None + for target in unique_tokens: + vals_target = proj[target] + for negative in unique_tokens: + if target == negative: + continue + vals = vals_target - proj[negative] + vals_prev = vals[prev_indices] + min_prev = float(vals_prev.min()) if vals_prev.size else 0.0 + value_range = float(vals.max() - vals.min()) + logit_lb = min_prev - eps * value_range + if logit_lb < args.min_logit_lb: + continue + candidate = HeadResult( + layer=layer_idx, + head=head_idx, + target=target, + negative=negative, + logit_lb=logit_lb, + eps=eps, + margin=margin, + min_prev=min_prev, + value_range=value_range, + active=len(active_positions), 
+ ) + if best is None or candidate.logit_lb > best.logit_lb: + best = candidate + if best is not None: + results.append(best) + + results.sort(key=lambda r: r.logit_lb, reverse=True) + args.output.parent.mkdir(parents=True, exist_ok=True) + with args.output.open("w", encoding="ascii") as f: + f.write("Induction discovery (approximate ranking)\n") + f.write(f"model={args.model}\n") + f.write(f"tokens={len(unique_tokens)} active={len(active_positions)}\n") + f.write(f"min-eps={args.min_eps} min-margin={args.min_margin} min-logit-lb={args.min_logit_lb}\n") + for rank, result in enumerate(results[: args.top], start=1): + f.write(f"{rank:02d} {format_result(result)}\n") + + print(f"Wrote report to {args.output}") + for rank, result in enumerate(results[: args.top], start=1): + print(f"{rank:02d} {format_result(result)}") + + if args.json_out is not None: + args.json_out.parent.mkdir(parents=True, exist_ok=True) + payload = { + "model": str(args.model), + "tokens": len(unique_tokens), + "active": len(active_positions), + "min_eps": args.min_eps, + "min_margin": args.min_margin, + "min_logit_lb": args.min_logit_lb, + "results": [ + { + "rank": rank, + "layer": r.layer, + "head": r.head, + "target": r.target, + "negative": r.negative, + "logit_lb": r.logit_lb, + "eps": r.eps, + "margin": r.margin, + "min_prev": r.min_prev, + "value_range": r.value_range, + "active": r.active, + } + for rank, r in enumerate(results[: args.top], start=1) + ], + } + args.json_out.write_text(json.dumps(payload, indent=2), encoding="ascii") + + if args.verify_top > 0 and results: + nfp_cmd = resolve_nfp_cmd(args.nfp_bin) + print("\nVerifying top candidates with Lean checker:") + for result in results[: args.verify_top]: + cmd = nfp_cmd + [ + "induction", + "certify_head_model_nonvacuous", + "--model", + str(args.model), + "--layer", + str(result.layer), + "--head", + str(result.head), + "--direction-target", + str(result.target), + "--direction-negative", + str(result.negative), + ] + if 
args.period is not None: + cmd += ["--period", str(args.period)] + proc = subprocess.run(cmd, capture_output=True, text=True) + status = "ok" if proc.returncode == 0 else "fail" + stdout = proc.stdout.strip().replace("\n", " ") + stderr = proc.stderr.strip().replace("\n", " ") + print(f"{status} {result.layer}/{result.head} tgt={result.target} neg={result.negative}") + if stdout: + print(f" out: {stdout}") + if stderr: + print(f" err: {stderr}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/generate_rigorous_induction.py b/scripts/generate_rigorous_induction.py index 35f5da4..40e339d 100644 --- a/scripts/generate_rigorous_induction.py +++ b/scripts/generate_rigorous_induction.py @@ -12,8 +12,12 @@ This aims to isolate induction-style copying from semantic completion. """ +from __future__ import annotations + +import argparse import sys from pathlib import Path + import numpy as np import torch @@ -25,63 +29,94 @@ sys.exit(1) -def export_rigorous_induction(output_path: str = "models/gpt2_rigorous.nfpt"): - print("Loading GPT-2 Small...") - model = GPT2Model.from_pretrained("gpt2") - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +def select_vocab_candidates( + tokenizer, + vocab_min: int, + vocab_max: int, + min_word_length: int, + require_leading_space: bool, +) -> list[int]: + candidates = [] + for tid in range(vocab_min, vocab_max): + word = tokenizer.decode([tid]) + if len(word.strip()) <= min_word_length: + continue + if require_leading_space and not word.startswith(" "): + continue + candidates.append(tid) + return candidates + + +def export_rigorous_induction( + output_path: str = "models/gpt2_rigorous.nfpt", + seq_len: int = 256, + pattern_len: int = 20, + seed: int = 1337, + vocab_min: int = 1000, + vocab_max: int = 5000, + min_word_length: int = 4, + require_leading_space: bool = True, + model_name: str = "gpt2", +) -> None: + print(f"Loading {model_name}...") + model = GPT2Model.from_pretrained(model_name) 
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name) config = model.config layer_norm_eps = float(config.layer_norm_epsilon) - # 1. Define a vocabulary of common, distinct English nouns - # These token IDs are single-token words in GPT-2 (e.g., " apple", " logic") - # This prevents fragmentation issues. - # IDs: 1000-5000 range usually contains common words. - # We filter for length > 4 to ensure they are substantive. - vocab_candidates = [] - for tid in range(1000, 5000): - word = tokenizer.decode([tid]) - if len(word.strip()) > 4 and word.startswith(" "): - vocab_candidates.append(tid) - - # 2. Construct the Random Repeated Sequence - # Pattern length L=20, repeated to fill seq_len (256 tokens). - seq_len = 256 - pattern_len = 20 + if seq_len <= 0: + raise ValueError("seq_len must be positive") + if pattern_len <= 0 or pattern_len > seq_len: + raise ValueError("pattern_len must be between 1 and seq_len") + if vocab_min < 0 or vocab_max <= vocab_min: + raise ValueError("invalid vocab range") + + vocab_candidates = select_vocab_candidates( + tokenizer, + vocab_min=vocab_min, + vocab_max=vocab_max, + min_word_length=min_word_length, + require_leading_space=require_leading_space, + ) + if len(vocab_candidates) < pattern_len: + raise ValueError( + f"Need at least {pattern_len} vocab candidates; only found {len(vocab_candidates)}" + ) - np.random.seed(1337) # Fixed seed for reproducibility + np.random.seed(seed) unique_pattern = np.random.choice(vocab_candidates, size=pattern_len, replace=False) - # [A, B, C, ...] -> [A, B, C, ..., A, B, C, ...] repeats = (seq_len // pattern_len) + 1 full_sequence = np.tile(unique_pattern, repeats)[:seq_len] - # 3. Verify the "Induction Target" property - # The last token T[N] must have appeared at T[N - pattern_len]. - # The target is T[N - pattern_len + 1]. 
last_token = full_sequence[-1] prev_idx = seq_len - 1 - pattern_len target_token = full_sequence[prev_idx + 1] - print(f"\nSequence Structure:") + print("\nSequence Structure:") print(f" Pattern Length: {pattern_len}") print(f" Total Length: {seq_len}") + print(f" Seed: {seed}") + print(f" Vocab Range: [{vocab_min}, {vocab_max})") + print( + f" Token Filter: min_len>{min_word_length}, " + f"leading_space={require_leading_space}" + ) print(f" Last Token: '{tokenizer.decode([last_token])}' (ID: {last_token})") print(f" Previous Occur: Index {prev_idx}") print( f" True Target: '{tokenizer.decode([target_token])}' (ID: {target_token})" ) - # 4. Compute Embeddings & Weights wte = model.wte.weight.detach().numpy() wpe = model.wpe.weight.detach().numpy() sample_embeddings = wte[full_sequence] + wpe[:seq_len] - # 5. Export output_file = Path(output_path) output_file.parent.mkdir(parents=True, exist_ok=True) print(f"\nExporting to {output_path}...") - with open(output_path, "wb") as f: + with output_file.open("wb") as f: write_header( f, num_layers=config.n_layer, @@ -98,7 +133,6 @@ def export_rigorous_induction(output_path: str = "models/gpt2_rigorous.nfpt"): write_i32(f, full_sequence) write_f64(f, sample_embeddings) - # Export Layers (Standard Loop) for layer_idx in range(config.n_layer): block = model.h[layer_idx] @@ -107,21 +141,21 @@ def export_rigorous_induction(output_path: str = "models/gpt2_rigorous.nfpt"): c_proj = block.attn.c_proj.weight.detach().numpy() c_proj_bias = get_bias(block.attn.c_proj.bias, 768) - W_Q_all = c_attn[:, 0:768] - W_K_all = c_attn[:, 768 : 2 * 768] - W_V_all = c_attn[:, 2 * 768 : 3 * 768] - b_Q_all = c_attn_bias[0:768] - b_K_all = c_attn_bias[768 : 2 * 768] - b_V_all = c_attn_bias[2 * 768 : 3 * 768] + w_q_all = c_attn[:, 0:768] + w_k_all = c_attn[:, 768 : 2 * 768] + w_v_all = c_attn[:, 2 * 768 : 3 * 768] + b_q_all = c_attn_bias[0:768] + b_k_all = c_attn_bias[768 : 2 * 768] + b_v_all = c_attn_bias[2 * 768 : 3 * 768] for h in range(12): 
start, end = h * 64, (h + 1) * 64 - write_f64(f, W_Q_all[:, start:end]) - write_f64(f, b_Q_all[start:end]) - write_f64(f, W_K_all[:, start:end]) - write_f64(f, b_K_all[start:end]) - write_f64(f, W_V_all[:, start:end]) - write_f64(f, b_V_all[start:end]) + write_f64(f, w_q_all[:, start:end]) + write_f64(f, b_q_all[start:end]) + write_f64(f, w_k_all[:, start:end]) + write_f64(f, b_k_all[start:end]) + write_f64(f, w_v_all[:, start:end]) + write_f64(f, b_v_all[start:end]) write_f64(f, c_proj[start:end, :]) write_f64(f, c_proj_bias) @@ -167,4 +201,32 @@ def get_bias(param, size: int) -> np.ndarray: if __name__ == "__main__": - export_rigorous_induction() + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--output", default="models/gpt2_rigorous.nfpt") + parser.add_argument("--seq-len", type=int, default=256) + parser.add_argument("--pattern-len", type=int, default=20) + parser.add_argument("--seed", type=int, default=1337) + parser.add_argument("--vocab-min", type=int, default=1000) + parser.add_argument("--vocab-max", type=int, default=5000) + parser.add_argument("--min-word-length", type=int, default=4) + parser.add_argument("--require-leading-space", action="store_true", default=True) + parser.add_argument( + "--allow-no-leading-space", + action="store_true", + help="Permit tokens without a leading space", + ) + parser.add_argument("--model", default="gpt2") + args = parser.parse_args() + + require_leading_space = args.require_leading_space and not args.allow_no_leading_space + export_rigorous_induction( + output_path=args.output, + seq_len=args.seq_len, + pattern_len=args.pattern_len, + seed=args.seed, + vocab_min=args.vocab_min, + vocab_max=args.vocab_max, + min_word_length=args.min_word_length, + require_leading_space=require_leading_space, + model_name=args.model, + ) diff --git a/scripts/sweep_gpt2_induction_nonvacuous.py b/scripts/sweep_gpt2_induction_nonvacuous.py new file mode 100644 index 0000000..8b996a3 --- /dev/null +++ 
b/scripts/sweep_gpt2_induction_nonvacuous.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Sweep prompt parameters and verify nonvacuous induction bounds for GPT-2. + +This is untrusted orchestration: discovery uses floating-point math and only +Lean verification results are treated as definitive. +""" + +from __future__ import annotations + +import argparse +import json +import re +import shutil +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, List + + +LOGIT_RE = re.compile(r"logitDiffLB=([^\s\)]+)") + + +@dataclass(frozen=True) +class VerifyResult: + ok: bool + logit_lb: str | None + stdout: str + stderr: str + + +def parse_int_list(raw: str, name: str) -> List[int]: + items: List[int] = [] + for part in raw.split(","): + part = part.strip() + if not part: + continue + try: + items.append(int(part)) + except ValueError as exc: + raise ValueError(f"invalid {name} entry: {part}") from exc + if not items: + raise ValueError(f"{name} list is empty") + return items + + +def resolve_python_cmd() -> List[str]: + if shutil.which("uv"): + return ["uv", "run"] + return ["python3"] + + +def resolve_nfp_cmd(nfp_bin: str | None) -> List[str]: + if nfp_bin: + return [nfp_bin] + local_bin = Path(".lake/build/bin/nfp") + if local_bin.exists(): + return [str(local_bin)] + return ["lake", "exe", "nfp"] + + +def run_cmd(cmd: Iterable[str], check: bool = True) -> subprocess.CompletedProcess: + return subprocess.run(list(cmd), check=check, capture_output=True, text=True) + + +def ensure_model( + generator: Path, + output: Path, + seq_len: int, + pattern_len: int, + seed: int, +) -> None: + if output.exists(): + return + output.parent.mkdir(parents=True, exist_ok=True) + cmd = resolve_python_cmd() + [ + str(generator), + "--output", + str(output), + "--seq-len", + str(seq_len), + "--pattern-len", + str(pattern_len), + "--seed", + str(seed), + ] + run_cmd(cmd, check=True) + + 
+def run_discovery( + discover_script: Path, + model: Path, + max_tokens: int, + top: int, + min_eps: float, + min_margin: float, + min_logit_lb: float, + period: int | None, + output_dir: Path, +) -> list[dict]: + output_dir.mkdir(parents=True, exist_ok=True) + json_out = output_dir / f"{model.stem}.json" + cmd = resolve_python_cmd() + [ + str(discover_script), + "--model", + str(model), + "--max-tokens", + str(max_tokens), + "--top", + str(top), + "--min-eps", + str(min_eps), + "--min-margin", + str(min_margin), + "--min-logit-lb", + str(min_logit_lb), + "--json-out", + str(json_out), + ] + if period is not None: + cmd += ["--period", str(period)] + run_cmd(cmd, check=True) + payload = json.loads(json_out.read_text(encoding="ascii")) + return payload.get("results", []) + + +def verify_candidate( + nfp_cmd: List[str], + model: Path, + layer: int, + head: int, + target: int, + negative: int, + period: int | None, +) -> VerifyResult: + cmd = nfp_cmd + [ + "induction", + "certify_head_model_nonvacuous", + "--model", + str(model), + "--layer", + str(layer), + "--head", + str(head), + "--direction-target", + str(target), + "--direction-negative", + str(negative), + ] + if period is not None: + cmd += ["--period", str(period)] + proc = run_cmd(cmd, check=False) + stdout = proc.stdout.strip() + stderr = proc.stderr.strip() + logit_lb = None + match = LOGIT_RE.search(stdout) + if match: + logit_lb = match.group(1) + return VerifyResult(proc.returncode == 0, logit_lb, stdout, stderr) + + +def write_csv_row(path: Path, row: dict) -> None: + header = [ + "model_path", + "seq_len", + "pattern_len", + "seed", + "layer", + "head", + "target", + "negative", + "approx_logit_lb", + "approx_eps", + "approx_margin", + "approx_min_prev", + "approx_value_range", + "active", + "period", + "verify_status", + "verify_logit_lb", + ] + new_file = not path.exists() or path.stat().st_size == 0 + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="ascii") as f: + if 
new_file: + f.write(",".join(header) + "\n") + f.write(",".join(str(row.get(col, "")) for col in header) + "\n") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--output", type=Path, default=Path("reports/gpt2_induction_sweep.csv")) + parser.add_argument("--model-dir", type=Path, default=Path("models")) + parser.add_argument("--generator", type=Path, default=Path("scripts/generate_rigorous_induction.py")) + parser.add_argument("--discover", type=Path, default=Path("scripts/discover_gpt2_induction_targets.py")) + parser.add_argument("--seq-lens", default="64") + parser.add_argument("--pattern-lens", default="16") + parser.add_argument("--seeds", default="1337") + parser.add_argument("--max-tokens", type=int, default=32) + parser.add_argument("--top", type=int, default=10) + parser.add_argument("--verify-top", type=int, default=3) + parser.add_argument("--min-eps", type=float, default=0.5) + parser.add_argument("--min-margin", type=float, default=0.0) + parser.add_argument("--min-logit-lb", type=float, default=0.0) + parser.add_argument("--use-period", action="store_true", + help="Use pattern length as the period override") + parser.add_argument("--nfp-bin", help="Path to nfp binary") + parser.add_argument("--discovery-dir", type=Path, default=Path("reports/discovery")) + args = parser.parse_args() + + seq_lens = parse_int_list(args.seq_lens, "seq-lens") + pattern_lens = parse_int_list(args.pattern_lens, "pattern-lens") + seeds = parse_int_list(args.seeds, "seeds") + + nfp_cmd = resolve_nfp_cmd(args.nfp_bin) + + for seq_len in seq_lens: + for pattern_len in pattern_lens: + for seed in seeds: + model_name = f"gpt2_rigorous_seq{seq_len}_pat{pattern_len}_seed{seed}.nfpt" + model_path = args.model_dir / model_name + ensure_model(args.generator, model_path, seq_len, pattern_len, seed) + period = pattern_len if args.use_period else None + results = run_discovery( + args.discover, + model_path, + args.max_tokens, + 
args.top, + args.min_eps, + args.min_margin, + args.min_logit_lb, + period, + args.discovery_dir, + ) + if not results: + print( + f"no candidates for seq={seq_len} pat={pattern_len} seed={seed}", + flush=True, + ) + continue + for result in results[: args.verify_top]: + verify = verify_candidate( + nfp_cmd, + model_path, + result["layer"], + result["head"], + result["target"], + result["negative"], + period, + ) + status = "ok" if verify.ok else "fail" + if verify.ok: + print( + f"verified L{result['layer']}H{result['head']} " + f"seq={seq_len} pat={pattern_len} seed={seed}", + flush=True, + ) + row = { + "model_path": model_path, + "seq_len": seq_len, + "pattern_len": pattern_len, + "seed": seed, + "layer": result["layer"], + "head": result["head"], + "target": result["target"], + "negative": result["negative"], + "approx_logit_lb": result["logit_lb"], + "approx_eps": result["eps"], + "approx_margin": result["margin"], + "approx_min_prev": result["min_prev"], + "approx_value_range": result["value_range"], + "active": result["active"], + "period": period if period is not None else "", + "verify_status": status, + "verify_logit_lb": verify.logit_lb or "", + } + write_csv_row(args.output, row) + if not verify.ok: + if verify.stdout: + print(f" out: {verify.stdout}", flush=True) + if verify.stderr: + print(f" err: {verify.stderr}", flush=True) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 0f780afb7d7ac9e9181800862990f2792e827a01 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 4 Jan 2026 03:07:58 +0100 Subject: [PATCH 106/244] Add transformer-stack bounds and model-derived residual intervals --- AGENTS.md | 10 +- CLAIMS.md | 12 +- Nfp/Cli.lean | 7 +- Nfp/IO.lean | 83 +++++-- Nfp/IO/NfptPure.lean | 113 ++++++++- Nfp/Model/Gpt2.lean | 52 +++- Nfp/Sound.lean | 4 + Nfp/Sound/Bounds/Attention.lean | 384 ++++++++++++++++++++++++++++++ Nfp/Sound/Bounds/Gelu.lean | 142 +++++++++++ Nfp/Sound/Bounds/LayerNorm.lean | 335 
++++++++++++++++++++++++++ Nfp/Sound/Bounds/MatrixNorm.lean | 50 ++++ Nfp/Sound/Bounds/Mlp.lean | 285 ++++++++++++++++++++++ Nfp/Sound/Bounds/Transformer.lean | 281 ++++++++++++++++++++++ README.md | 9 +- SOUNDNESS_LIMITATIONS.md | 10 +- 15 files changed, 1728 insertions(+), 49 deletions(-) create mode 100644 Nfp/Sound/Bounds/Attention.lean create mode 100644 Nfp/Sound/Bounds/Gelu.lean create mode 100644 Nfp/Sound/Bounds/Mlp.lean create mode 100644 Nfp/Sound/Bounds/Transformer.lean diff --git a/AGENTS.md b/AGENTS.md index d69c644..fea08bc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -347,6 +347,14 @@ but you **must** update this list in the same commit. - Sound builders for induction certificates from exact inputs. - `Nfp/Sound/Bounds/MatrixNorm.lean` - Row-sum matrix norms and downstream linear certificate builders. +- `Nfp/Sound/Bounds/Gelu.lean` + - Tanh-GELU bounds for interval propagation through MLPs. +- `Nfp/Sound/Bounds/Mlp.lean` + - Interval bounds for GPT-2 MLP blocks and LayerNorm composition. +- `Nfp/Sound/Bounds/Attention.lean` + - Interval bounds for multi-head attention and transformer layers. +- `Nfp/Sound/Bounds/Transformer.lean` + - Interval bounds for transformer stacks and final LayerNorm outputs. - `Nfp/Sound/Linear/FinFold.lean` - Tail-recursive folds and sums for sound linear computations. - `Nfp/Sound/Gpt2/HeadInputs.lean` @@ -360,7 +368,7 @@ but you **must** update this list in the same commit. - `Nfp/Model/InductionPrompt.lean` - Prompt utilities (`prev` map and active set for periodic prompts). - `Nfp/Model/Gpt2.lean` - - Exact GPT-2 head-slice data and embedding helpers. + - Exact GPT-2 head-slice data, layer/MLP/LayerNorm parameters, and embedding helpers. - `Nfp/Model.lean` - Aggregator for model input modules. 
diff --git a/CLAIMS.md b/CLAIMS.md index a725704..3b1ac7a 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -15,6 +15,9 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - Residual-interval certificate soundness: `checkResidualIntervalCert` implies `ResidualIntervalBounds`. - Row-sum matrix norm bounds for `mulVec` under uniform input magnitude. +- Tanh-GELU bounds and interval propagation through MLP layers. +- Interval bounds for multi-head attention and full transformer-layer residual blocks. +- Interval bounds for transformer stacks and final LayerNorm outputs. ## Soundly checked by the trusted CLI @@ -31,9 +34,10 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri downstream error certificate (arithmetic consistency only). - `nfp induction certify_end_to_end_matrix` computes a downstream bound from a matrix payload using verified row-sum norms, then composes it with the head-level logit-diff lower bound. -- `nfp induction certify_end_to_end_model` derives a downstream matrix from an `NFP_BINARY_V1` - model file (unembedding direction only), computes a downstream error bound from a - residual-interval certificate, and composes it with the head-level logit-diff lower bound. +- `nfp induction certify_end_to_end_model` derives the unembedding direction from an + `NFP_BINARY_V1` model file, computes a downstream error bound from either a supplied + residual-interval certificate or a verified model-derived interval, and composes it with + the head-level logit-diff lower bound. ## Untrusted / heuristic @@ -42,7 +46,7 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri `scripts/build_downstream_linear_cert.py`. - The head-input extractor now emits attention projection biases and LayerNorm metadata, but the Lean-side computation still ignores LayerNorm and the shared attention output bias. -- Residual-interval certificates are generated externally (unchecked beyond consistency). 
+- External residual-interval scripts remain untrusted; model-derived bounds are now available. - Any downstream error bound provided externally (outside the matrix-payload path). ## Not yet proven diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 799ff16..ec4998f 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -142,12 +142,12 @@ def runInductionCertifyEndToEndModel (p : Parsed) : IO UInt32 := do let scoresPath := p.flag! "scores" |>.as! String let valuesPath := p.flag! "values" |>.as! String let modelPath := p.flag! "model" |>.as! String - let residualIntervalPath := p.flag! "residual-interval" |>.as! String + let residualIntervalPath? := (p.flag? "residual-interval").map (·.as! String) let minActive? := (p.flag? "min-active").map (·.as! Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath residualIntervalPath + IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath residualIntervalPath? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? /-- `nfp induction certify_end_to_end_model` subcommand. -/ @@ -158,7 +158,8 @@ def inductionCertifyEndToEndModelCmd : Cmd := `[Cli| scores : String; "Path to the softmax-margin certificate file." values : String; "Path to the value-range certificate file." model : String; "Path to the NFP_BINARY_V1 model file." - "residual-interval" : String; "Path to the residual-interval certificate file." + "residual-interval" : String; "Optional path to a residual-interval certificate file \ + (defaults to deriving from the model)." "min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." 
"min-logit-diff" : String; "Optional minimum logit-diff lower bound \ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 7f17ee2..bffd260 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -8,6 +8,7 @@ import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Cert.ResidualBound import Nfp.Circuit.Cert.ResidualInterval import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Bounds.Transformer import Nfp.Sound.Induction import Nfp.Sound.Induction.LogitDiff @@ -91,6 +92,36 @@ private def emitResidualIntervalCert {n : Nat} (c : ResidualIntervalCert n) | some path => IO.FS.writeFile path (payload ++ "\n") | none => IO.println payload +/-! Derived residual intervals from model binaries. -/ + +/-- Derive residual-interval bounds from a model binary via interval propagation. -/ +private def deriveResidualIntervalFromModel (data : ByteArray) (start : Nat) + (header : NfptPure.NfptHeader) : + Except String (ResidualIntervalCert header.modelDim) := do + if hseq : header.seqLen = 0 then + throw "seq must be positive" + else + have _ : NeZero header.seqLen := ⟨hseq⟩ + if header.modelDim = 0 then + throw "model dim must be positive" + else if 0 < header.layerNormEps then + let embed ← NfptPure.readEmbeddings data start header + let layerSlices ← NfptPure.readLayerSlices data start header + let headLayers ← NfptPure.readLayerHeads data start header + let finalLn ← NfptPure.readFinalLayerNorm data start header + let layers : Fin header.numLayers → Model.Gpt2LayerSlice header.modelDim header.hiddenDim := + fun l => NfptPure.SizedArray.get layerSlices l + let heads : + Fin header.numLayers → Fin header.numHeads → + Model.Gpt2HeadWeights header.modelDim header.headDim := fun l h => + NfptPure.SizedArray.get (NfptPure.SizedArray.get headLayers l) h + let bounds := + Sound.Bounds.gpt2ResidualIntervalBounds (eps := header.layerNormEps) + layers heads finalLn embed + return { lo := bounds.1, hi := bounds.2 } + else + throw s!"layer norm epsilon {header.layerNormEps} must be positive" + private def 
buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (outPath? : Option System.FilePath) : IO UInt32 := do @@ -608,10 +639,11 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) finalLB={finalLB})" return 0 -/-- Check end-to-end induction certificates using a model file and residual bounds. -/ +/-- Check end-to-end induction certificates using a model file and residual bounds + (loaded from disk or derived from the model). -/ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) (valuesPath : System.FilePath) (modelPath : System.FilePath) - (residualIntervalPath : System.FilePath) (minActive? : Option Nat) + (residualIntervalPath? : Option System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? @@ -703,17 +735,32 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨header, start⟩ => - let parsedResidual ← - loadResidualIntervalCert residualIntervalPath - match parsedResidual with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨dim, residualCert⟩ => - if hdim : dim = header.modelDim then - let residualCert' : - ResidualIntervalCert header.modelDim := by - simpa [hdim] using residualCert + if header.seqLen = seq then + let residualCertE : Except String + (ResidualIntervalCert header.modelDim) ← + match residualIntervalPath? 
with + | some residualIntervalPath => do + let parsedResidual ← + loadResidualIntervalCert residualIntervalPath + match parsedResidual with + | Except.error msg => pure (Except.error msg) + | Except.ok ⟨dim, residualCert⟩ => + if hdim : dim = header.modelDim then + let residualCert' : + ResidualIntervalCert header.modelDim := by + simpa [hdim] using residualCert + pure (Except.ok residualCert') + else + pure (Except.error + s!"residual interval dim {dim} \ + does not match model dim {header.modelDim}") + | none => + pure (deriveResidualIntervalFromModel data start header) + match residualCertE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok residualCert' => let residualOk := Circuit.checkResidualIntervalCert residualCert' if residualOk then @@ -769,11 +816,11 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) IO.eprintln "error: residual-interval certificate rejected" return 2 - else - IO.eprintln - s!"error: residual interval dim {dim} \ - does not match model dim {header.modelDim}" - return 2 + else + IO.eprintln + s!"error: model seq {header.seqLen} \ + does not match cert seq {seq}" + return 2 private def checkInductionHeadInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index bf9ab72..ad9cea3 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -2,6 +2,7 @@ import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.List.Range +import Nfp.Model.Gpt2 import Nfp.Model.InductionHead import Nfp.Model.InductionPrompt @@ -36,6 +37,17 @@ structure NfptHeader where /-- LayerNorm epsilon parameter. -/ layerNormEps : Rat +/-- Array with a fixed size proof. -/ +structure SizedArray (n : Nat) (α : Type) where + /-- Underlying array data. -/ + data : Array α + /-- Size proof for the array. -/ + size_eq : data.size = n + +/-- Index into a `SizedArray` using a `Fin`. 
-/ +def SizedArray.get {n : Nat} {α : Type} (arr : SizedArray n α) (i : Fin n) : α := + arr.data[i.val]'(by simp [arr.size_eq]) + private def parseNat (s : String) : Except String Nat := match s.toNat? with | some n => Except.ok n @@ -337,6 +349,10 @@ private def f64CountBeforeHeads (h : NfptHeader) : Nat := def unembedOffset (h : NfptHeader) : Nat := bytesI32 h.seqLen + bytesF64 (f64CountBeforeUnembed h) +private def finalLayerNormOffset (h : NfptHeader) : Nat := + bytesI32 h.seqLen + + bytesF64 (f64CountBeforeHeads h + h.numLayers * f64CountPerLayer h) + /-- Read input embeddings stored in the binary. -/ def readEmbeddings (data : ByteArray) (start : Nat) (h : NfptHeader) : Except String (Fin h.seqLen → Fin h.modelDim → Rat) := do @@ -368,19 +384,10 @@ private def layerExtrasOffset (h : NfptHeader) (layer : Nat) : Nat := layer * f64CountPerLayer h + h.numHeads * f64CountPerHead h) -/-- Head weights plus biases for a single attention head. -/ -private structure HeadWeights (dModel dHead : Nat) where - wq : Fin dModel → Fin dHead → Rat - bq : Fin dHead → Rat - wk : Fin dModel → Fin dHead → Rat - bk : Fin dHead → Rat - wv : Fin dModel → Fin dHead → Rat - bv : Fin dHead → Rat - wo : Fin dModel → Fin dHead → Rat - -private def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) +/-- Read attention head weights and biases for a specific layer/head. -/ +def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) (layer head : Nat) : - Except String (HeadWeights h.modelDim h.headDim) := do + Except String (Model.Gpt2HeadWeights h.modelDim h.headDim) := do if layer < h.numLayers then if head < h.numHeads then let base := start + headOffset h layer head @@ -423,6 +430,88 @@ private def readLayerAttnBiasLn1 (data : ByteArray) (start : Nat) (h : NfptHeade else throw s!"layer index out of range: {layer}" +/-- Read GPT-2 layer parameters (MLP + LayerNorm) from the model binary. 
-/ +def readLayerSlice (data : ByteArray) (start : Nat) (h : NfptHeader) + (layer : Nat) : Except String (Model.Gpt2LayerSlice h.modelDim h.hiddenDim) := do + if layer < h.numLayers then + let base := start + layerExtrasOffset h layer + let attnBias ← readF64Vec data base h.modelDim + let offWIn := base + bytesF64 h.modelDim + let mlpWIn ← readF64Matrix data offWIn h.modelDim h.hiddenDim + let offBIn := offWIn + bytesF64 (h.modelDim * h.hiddenDim) + let mlpBIn ← readF64Vec data offBIn h.hiddenDim + let offWOut := offBIn + bytesF64 h.hiddenDim + let mlpWOut ← readF64Matrix data offWOut h.hiddenDim h.modelDim + let offBOut := offWOut + bytesF64 (h.hiddenDim * h.modelDim) + let mlpBOut ← readF64Vec data offBOut h.modelDim + let offLn1Gamma := offBOut + bytesF64 h.modelDim + let ln1Gamma ← readF64Vec data offLn1Gamma h.modelDim + let offLn1Beta := offLn1Gamma + bytesF64 h.modelDim + let ln1Beta ← readF64Vec data offLn1Beta h.modelDim + let offLn2Gamma := offLn1Beta + bytesF64 h.modelDim + let ln2Gamma ← readF64Vec data offLn2Gamma h.modelDim + let offLn2Beta := offLn2Gamma + bytesF64 h.modelDim + let ln2Beta ← readF64Vec data offLn2Beta h.modelDim + return { attnBias := attnBias + mlpWIn := mlpWIn + mlpBIn := mlpBIn + mlpWOut := mlpWOut + mlpBOut := mlpBOut + ln1Gamma := ln1Gamma + ln1Beta := ln1Beta + ln2Gamma := ln2Gamma + ln2Beta := ln2Beta } + else + throw s!"layer index out of range: {layer}" + +/-- Read all GPT-2 layer slices from the model binary. 
-/ +def readLayerSlices (data : ByteArray) (start : Nat) (h : NfptHeader) : + Except String (SizedArray h.numLayers (Model.Gpt2LayerSlice h.modelDim h.hiddenDim)) := do + let slices ← (List.finRange h.numLayers).foldlM + (fun (acc : Array (Model.Gpt2LayerSlice h.modelDim h.hiddenDim)) layer => do + let slice ← readLayerSlice data start h layer.val + pure (acc.push slice)) + (#[] : Array (Model.Gpt2LayerSlice h.modelDim h.hiddenDim)) + if hlen : slices.size = h.numLayers then + return { data := slices, size_eq := hlen } + else + throw "internal error: layer slice count mismatch" + +/-- Read all attention head weights from the model binary. -/ +def readLayerHeads (data : ByteArray) (start : Nat) (h : NfptHeader) : + Except String + (SizedArray h.numLayers + (SizedArray h.numHeads (Model.Gpt2HeadWeights h.modelDim h.headDim))) := do + let layers ← (List.finRange h.numLayers).foldlM + (fun (acc : Array + (SizedArray h.numHeads (Model.Gpt2HeadWeights h.modelDim h.headDim))) layer => do + let heads ← (List.finRange h.numHeads).foldlM + (fun (accHead : Array (Model.Gpt2HeadWeights h.modelDim h.headDim)) head => do + let weights ← readHeadWeights data start h layer.val head.val + pure (accHead.push weights)) + (#[] : Array (Model.Gpt2HeadWeights h.modelDim h.headDim)) + if hlen : heads.size = h.numHeads then + let headArray : SizedArray h.numHeads (Model.Gpt2HeadWeights h.modelDim h.headDim) := + { data := heads, size_eq := hlen } + pure (acc.push headArray) + else + throw "internal error: head count mismatch") + (#[] : Array + (SizedArray h.numHeads (Model.Gpt2HeadWeights h.modelDim h.headDim))) + if hlen : layers.size = h.numLayers then + return { data := layers, size_eq := hlen } + else + throw "internal error: layer head count mismatch" + +/-- Read the final LayerNorm parameters from the model binary. 
-/ +def readFinalLayerNorm (data : ByteArray) (start : Nat) (h : NfptHeader) : + Except String (Model.Gpt2FinalLayerNorm h.modelDim) := do + let base := start + finalLayerNormOffset h + let gamma ← readF64Vec data base h.modelDim + let offBeta := base + bytesF64 h.modelDim + let beta ← readF64Vec data offBeta h.modelDim + return { gamma := gamma, beta := beta } + /-- Read a single unembedding column as exact rationals. -/ def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : Nat) : Except String (Fin h.modelDim → Rat) := do diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean index c81d0a8..a360d92 100644 --- a/Nfp/Model/Gpt2.lean +++ b/Nfp/Model/Gpt2.lean @@ -4,10 +4,11 @@ import Mathlib.Algebra.Order.Ring.Rat import Nfp.Circuit.Cert.ValueRange /-! -Exact GPT-2 head-slice data for induction certification. +Exact GPT-2 slices for induction certification and downstream bounds. -This module holds the precise token embeddings, position embeddings, and head -projection weights needed to build `InductionHeadInputs` for a single head. +This module holds token embeddings, head projection weights, and per-layer +MLP/LayerNorm parameters needed to build `InductionHeadInputs` and downstream +bound computations. -/ namespace Nfp @@ -62,6 +63,51 @@ structure Gpt2HeadSlice (seq dModel dHead vocab : Nat) where /-- Direction tokens for logit-diff certification. -/ direction : DirectionTokens vocab +/-- Exact per-head attention weights and biases. -/ +structure Gpt2HeadWeights (dModel dHead : Nat) where + /-- Query projection weights. -/ + wq : Fin dModel → Fin dHead → Rat + /-- Query projection bias. -/ + bq : Fin dHead → Rat + /-- Key projection weights. -/ + wk : Fin dModel → Fin dHead → Rat + /-- Key projection bias. -/ + bk : Fin dHead → Rat + /-- Value projection weights. -/ + wv : Fin dModel → Fin dHead → Rat + /-- Value projection bias. -/ + bv : Fin dHead → Rat + /-- Output projection weights for this head slice. 
-/ + wo : Fin dModel → Fin dHead → Rat + +/-- Exact GPT-2 layer slice with MLP and LayerNorm parameters. -/ +structure Gpt2LayerSlice (dModel hidden : Nat) where + /-- Attention output bias (shared across heads). -/ + attnBias : Fin dModel → Rat + /-- MLP input projection weights. -/ + mlpWIn : Fin dModel → Fin hidden → Rat + /-- MLP input projection bias. -/ + mlpBIn : Fin hidden → Rat + /-- MLP output projection weights. -/ + mlpWOut : Fin hidden → Fin dModel → Rat + /-- MLP output projection bias. -/ + mlpBOut : Fin dModel → Rat + /-- LayerNorm scale for the attention input. -/ + ln1Gamma : Fin dModel → Rat + /-- LayerNorm bias for the attention input. -/ + ln1Beta : Fin dModel → Rat + /-- LayerNorm scale for the MLP input. -/ + ln2Gamma : Fin dModel → Rat + /-- LayerNorm bias for the MLP input. -/ + ln2Beta : Fin dModel → Rat + +/-- Final LayerNorm parameters applied before unembedding. -/ +structure Gpt2FinalLayerNorm (dModel : Nat) where + /-- LayerNorm scale. -/ + gamma : Fin dModel → Rat + /-- LayerNorm bias. -/ + beta : Fin dModel → Rat + /-- Token-plus-position embeddings for a GPT-2 head slice. -/ def Gpt2HeadSlice.embed {seq dModel dHead vocab : Nat} (slice : Gpt2HeadSlice seq dModel dHead vocab) : diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index aff7bb1..dee6446 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -4,6 +4,10 @@ import Nfp.Sound.Gpt2.HeadInputs import Nfp.Sound.Induction import Nfp.Sound.Induction.LogitDiff import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Bounds.Gelu +import Nfp.Sound.Bounds.Mlp +import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.Transformer import Nfp.Sound.Linear.FinFold /-! 
diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean new file mode 100644 index 0000000..0208bf8 --- /dev/null +++ b/Nfp/Sound/Bounds/Attention.lean @@ -0,0 +1,384 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Field +import Mathlib.Algebra.BigOperators.Ring.Finset +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.Rat.Cast.Order +import Mathlib.Data.Real.Basic +import Nfp.Circuit.Layers.Softmax +import Nfp.Model.Gpt2 +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Bounds.Mlp + +/-! +Interval bounds for multi-head attention and transformer layers. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +/-- Real-valued attention output for a query token and model coordinate. -/ +noncomputable def attentionOutputReal {seq dModel dHead numHeads : Nat} [NeZero seq] + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) + (q : Fin seq) (i : Fin dModel) : Real := + let lnOut : Fin seq → Fin dModel → Real := fun k j => + layerNormRealOfReal eps ln1Gamma ln1Beta (x k) j + let headValue : Fin numHeads → Fin seq → Fin dHead → Real := fun h k d => + dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d + let headWeights : Fin numHeads → Fin seq → Fin seq → Real := fun h q k => + Circuit.softmax (scores h q) k + let headOutput : Fin numHeads → Fin seq → Fin dHead → Real := fun h q d => + dotProduct (headWeights h q) (fun k => headValue h k d) + let headProj : Fin numHeads → Fin seq → Fin dModel → Real := fun h q j => + dotProduct (fun d => ((heads h).wo j d : Real)) (fun d => headOutput h q d) + (∑ h, headProj h q i) + (attnBias i : Real) + +/-- Unfolding lemma for 
`attentionOutputReal`. -/ +theorem attentionOutputReal_def {seq dModel dHead numHeads : Nat} [NeZero seq] + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) + (q : Fin seq) (i : Fin dModel) : + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i = + let lnOut : Fin seq → Fin dModel → Real := fun k j => + layerNormRealOfReal eps ln1Gamma ln1Beta (x k) j + let headValue : Fin numHeads → Fin seq → Fin dHead → Real := fun h k d => + dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d + let headWeights : Fin numHeads → Fin seq → Fin seq → Real := fun h q k => + Circuit.softmax (scores h q) k + let headOutput : Fin numHeads → Fin seq → Fin dHead → Real := fun h q d => + dotProduct (headWeights h q) (fun k => headValue h k d) + let headProj : Fin numHeads → Fin seq → Fin dModel → Real := fun h q j => + dotProduct (fun d => ((heads h).wo j d : Real)) (fun d => headOutput h q d) + (∑ h, headProj h q i) + (attnBias i : Real) := rfl + +/-- Interval bounds for multi-head attention outputs from interval inputs. 
-/ +def attentionOutputBounds {dModel dHead numHeads : Nat} + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := + let absBound := intervalAbsBound lo hi + let ln := layerNormAbsBounds eps ln1Gamma ln1Beta absBound + let lnLo := ln.1 + let lnHi := ln.2 + let vLo : Fin numHeads → Fin dHead → Rat := fun h d => + dotIntervalLower (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d + let vHi : Fin numHeads → Fin dHead → Rat := fun h d => + dotIntervalUpper (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d + let headLo : Fin numHeads → Fin dModel → Rat := fun h i => + dotIntervalLower (fun d => (heads h).wo i d) (vLo h) (vHi h) + let headHi : Fin numHeads → Fin dModel → Rat := fun h i => + dotIntervalUpper (fun d => (heads h).wo i d) (vLo h) (vHi h) + let sumLo : Fin dModel → Rat := fun i => ∑ h, headLo h i + let sumHi : Fin dModel → Rat := fun i => ∑ h, headHi h i + (fun i => sumLo i + attnBias i, fun i => sumHi i + attnBias i) + +/-- `attentionOutputBounds` soundness for real attention outputs. 
-/ +theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi + ∀ q i, + (bounds.1 i : Real) ≤ + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ∧ + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ≤ + (bounds.2 i : Real) := by + classical + intro bounds q i + let absBound := intervalAbsBound lo hi + let lnBounds := layerNormAbsBounds eps ln1Gamma ln1Beta absBound + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + let lnOut : Fin seq → Fin dModel → Real := fun k j => + layerNormRealOfReal eps ln1Gamma ln1Beta (x k) j + let vLo : Fin numHeads → Fin dHead → Rat := fun h d => + dotIntervalLower (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d + let vHi : Fin numHeads → Fin dHead → Rat := fun h d => + dotIntervalUpper (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d + let headLo : Fin numHeads → Fin dModel → Rat := fun h j => + dotIntervalLower (fun d => (heads h).wo j d) (vLo h) (vHi h) + let headHi : Fin numHeads → Fin dModel → Rat := fun h j => + dotIntervalUpper (fun d => (heads h).wo j d) (vLo h) (vHi h) + let sumLo : Fin dModel → Rat := fun j => ∑ h, headLo h j + let sumHi : Fin dModel → Rat := fun j => ∑ h, headHi h j + let headValue : Fin numHeads → Fin seq → Fin dHead → Real := fun h k d => + dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d + let headWeights : Fin numHeads → Fin seq → Fin seq → Real := fun h q k => + Circuit.softmax (scores h q) k + let headOutput : Fin numHeads → Fin seq → Fin dHead → Real 
:= fun h q d => + dotProduct (headWeights h q) (fun k => headValue h k d) + let headProj : Fin numHeads → Fin seq → Fin dModel → Real := fun h q j => + dotProduct (fun d => ((heads h).wo j d : Real)) (fun d => headOutput h q d) + have habs : ∀ q i, |x q i| ≤ (absBound : Real) := by + intro q i + have hbound : + |x q i| ≤ max |(lo i : Real)| |(hi i : Real)| := + abs_le_max_abs_abs_of_interval_real (hlo q i) (hhi q i) + have hnonempty : (Finset.univ : Finset (Fin dModel)).Nonempty := ⟨i, by simp⟩ + have hsup : + max |lo i| |hi i| ≤ intervalAbsBound lo hi := by + have hsup' : + max |lo i| |hi i| ≤ + (Finset.univ).sup' hnonempty (fun k => max |lo k| |hi k|) := by + simpa using + (Finset.le_sup' + (s := (Finset.univ : Finset (Fin dModel))) + (f := fun k => max |lo k| |hi k|) + (by simp : i ∈ (Finset.univ : Finset (Fin dModel)))) + simpa [intervalAbsBound, hnonempty] using hsup' + have hsup_real : + max |(lo i : Real)| |(hi i : Real)| ≤ (absBound : Real) := by + exact_mod_cast hsup + exact le_trans hbound hsup_real + have hln_bounds : ∀ q i, (lnLo i : Real) ≤ lnOut q i ∧ lnOut q i ≤ (lnHi i : Real) := by + intro q i + have hln := layerNormAbsBounds_spec_real eps ln1Gamma ln1Beta absBound (x q) hne heps + (fun j => habs q j) + simpa [lnBounds, lnLo, lnHi, lnOut] using hln i + have hval_bounds : + ∀ h k d, + (vLo h d : Real) ≤ headValue h k d ∧ + headValue h k d ≤ (vHi h d : Real) := by + intro h k d + have hln := hln_bounds k + have hlo' : ∀ j, (lnLo j : Real) ≤ lnOut k j := fun j => (hln j).1 + have hhi' : ∀ j, lnOut k j ≤ (lnHi j : Real) := fun j => (hln j).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun j => (heads h).wv j d) + (lo := lnLo) (hi := lnHi) (x := lnOut k) hlo' hhi' + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => (heads h).wv j d) + (lo := lnLo) (hi := lnHi) (x := lnOut k) hlo' hhi' + have hlow' := add_le_add_right hlow ((heads h).bv d : Real) + have hhigh' := add_le_add_right hhigh ((heads h).bv d : Real) + constructor 
+ · simpa [headValue, vLo, Rat.cast_add] using hlow' + · simpa [headValue, vHi, Rat.cast_add] using hhigh' + have weighted_bounds : + ∀ {lo hi : Real} {vals : Fin seq → Real} {w : Fin seq → Real}, + (∀ k, lo ≤ vals k) → (∀ k, vals k ≤ hi) → + (∀ k, 0 ≤ w k) → (∑ k, w k = 1) → + lo ≤ dotProduct w vals ∧ dotProduct w vals ≤ hi := by + intro lo hi vals w hlo' hhi' hnonneg hsum + have hsum_lo : ∑ k, w k * lo ≤ ∑ k, w k * vals k := by + refine Finset.sum_le_sum ?_ + intro k _ + exact mul_le_mul_of_nonneg_left (hlo' k) (hnonneg k) + have hsum_lo' : ∑ k, w k * lo = lo := by + calc + ∑ k, w k * lo = (∑ k, w k) * lo := by + simpa using + (Finset.sum_mul (s := (Finset.univ : Finset (Fin seq))) (f := w) (a := lo)).symm + _ = lo := by simp [hsum] + have hlow : lo ≤ dotProduct w vals := by + have hsum_le : lo ≤ ∑ k, w k * vals k := by + simpa [hsum_lo'] using hsum_lo + simpa [dotProduct] using hsum_le + have hsum_hi : ∑ k, w k * vals k ≤ ∑ k, w k * hi := by + refine Finset.sum_le_sum ?_ + intro k _ + exact mul_le_mul_of_nonneg_left (hhi' k) (hnonneg k) + have hsum_hi' : ∑ k, w k * hi = hi := by + calc + ∑ k, w k * hi = (∑ k, w k) * hi := by + simpa using + (Finset.sum_mul (s := (Finset.univ : Finset (Fin seq))) (f := w) (a := hi)).symm + _ = hi := by simp [hsum] + have hhigh : dotProduct w vals ≤ hi := by + have hsum_le : ∑ k, w k * vals k ≤ hi := by + simpa [hsum_hi'] using hsum_hi + simpa [dotProduct] using hsum_le + exact ⟨hlow, hhigh⟩ + have hhead_output_bounds : + ∀ h q d, + (vLo h d : Real) ≤ headOutput h q d ∧ + headOutput h q d ≤ (vHi h d : Real) := by + intro h q d + have hvals := hval_bounds h + have hlo' : ∀ k, + (vLo h d : Real) ≤ headValue h k d := fun k => (hvals k d).1 + have hhi' : ∀ k, + headValue h k d ≤ (vHi h d : Real) := fun k => (hvals k d).2 + have hnonneg : ∀ k, 0 ≤ headWeights h q k := by + intro k + exact Circuit.softmax_nonneg (scores h q) k + have hsum : ∑ k, headWeights h q k = 1 := by + simpa [headWeights] using Circuit.softmax_sum_one (scores h q) 
+ have h := weighted_bounds (lo := (vLo h d : Real)) (hi := (vHi h d : Real)) + (vals := fun k => headValue h k d) (w := headWeights h q) + hlo' hhi' hnonneg hsum + simpa [headOutput] using h + have hproj_bounds : + ∀ h q i, + (headLo h i : Real) ≤ headProj h q i ∧ headProj h q i ≤ (headHi h i : Real) := by + intro h q i + have hout := hhead_output_bounds h q + have hlo' : ∀ d, + (vLo h d : Real) ≤ headOutput h q d := fun d => (hout d).1 + have hhi' : ∀ d, + headOutput h q d ≤ (vHi h d : Real) := fun d => (hout d).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun d => (heads h).wo i d) + (lo := vLo h) (hi := vHi h) + (x := fun d => headOutput h q d) hlo' hhi' + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun d => (heads h).wo i d) + (lo := vLo h) (hi := vHi h) + (x := fun d => headOutput h q d) hlo' hhi' + constructor + · simpa [headProj, headLo] using hlow + · simpa [headProj, headHi] using hhigh + have hsum_bounds : + (sumLo i : Real) ≤ ∑ h, headProj h q i ∧ + ∑ h, headProj h q i ≤ (sumHi i : Real) := by + have hlow : (sumLo i : Real) ≤ ∑ h, headProj h q i := by + have hsum := + Finset.sum_le_sum (s := (Finset.univ : Finset (Fin numHeads))) + (fun h _ => (hproj_bounds h q i).1) + simpa [sumLo] using hsum + have hhigh : ∑ h, headProj h q i ≤ (sumHi i : Real) := by + have hsum := + Finset.sum_le_sum (s := (Finset.univ : Finset (Fin numHeads))) + (fun h _ => (hproj_bounds h q i).2) + simpa [sumHi] using hsum + exact ⟨hlow, hhigh⟩ + have hlow : + (sumLo i : Real) + (attnBias i : Real) ≤ + (∑ h, headProj h q i) + (attnBias i : Real) := by + have h := add_le_add_left hsum_bounds.1 (attnBias i : Real) + simpa [add_comm, add_left_comm, add_assoc] using h + have hhigh : + (∑ h, headProj h q i) + (attnBias i : Real) ≤ + (sumHi i : Real) + (attnBias i : Real) := by + have h := add_le_add_left hsum_bounds.2 (attnBias i : Real) + simpa [add_comm, add_left_comm, add_assoc] using h + have hreal : + attentionOutputReal eps ln1Gamma ln1Beta heads 
attnBias scores x q i = + (∑ h, headProj h q i) + (attnBias i : Real) := by + simp [attentionOutputReal, lnOut, headValue, headWeights, headOutput, headProj] + have hlo : + (bounds.1 i : Real) ≤ + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i := by + simpa [bounds, attentionOutputBounds, absBound, lnBounds, lnLo, lnHi, vLo, vHi, headLo, headHi, + sumLo, sumHi, hreal] using hlow + have hhi : + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ≤ + (bounds.2 i : Real) := by + simpa [bounds, attentionOutputBounds, absBound, lnBounds, lnLo, lnHi, vLo, vHi, headLo, headHi, + sumLo, sumHi, hreal] using hhigh + exact And.intro hlo hhi + +/-- Interval bounds for the attention residual path. -/ +def attentionResidualBounds {dModel dHead numHeads : Nat} + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := + let attn := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi + (fun i => lo i + attn.1 i, fun i => hi i + attn.2 i) + +/-- `attentionResidualBounds` soundness for attention residual outputs. 
-/ +theorem attentionResidualBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias lo hi + ∀ q i, + (bounds.1 i : Real) ≤ + x q i + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ∧ + x q i + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ≤ + (bounds.2 i : Real) := by + classical + intro bounds q i + let attn := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi + have hattn := + attentionOutputBounds_spec eps ln1Gamma ln1Beta heads attnBias scores lo hi x + hne heps hlo hhi q i + have hlow := add_le_add (hlo q i) hattn.1 + have hhigh := add_le_add (hhi q i) hattn.2 + constructor + · simpa [bounds, attentionResidualBounds, attn, Rat.cast_add] using hlow + · simpa [bounds, attentionResidualBounds, attn, Rat.cast_add] using hhigh + +/-- Interval bounds for a full transformer layer (attention + MLP). 
-/ +def transformerLayerBounds {dModel dHead numHeads hidden : Nat} + (eps : Rat) + (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (mlpWIn : Fin dModel → Fin hidden → Rat) (mlpBIn : Fin hidden → Rat) + (mlpWOut : Fin hidden → Fin dModel → Rat) (mlpBOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := + let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias lo hi + layerNormAbsMlpResidualBounds eps ln2Gamma ln2Beta mlpWIn mlpBIn mlpWOut mlpBOut attn.1 attn.2 + +/-- `transformerLayerBounds` soundness for full transformer-layer outputs. -/ +theorem transformerLayerBounds_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Rat) + (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (mlpWIn : Fin dModel → Fin hidden → Rat) (mlpBIn : Fin hidden → Rat) + (mlpWOut : Fin hidden → Fin dModel → Rat) (mlpBOut : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := transformerLayerBounds eps ln1Gamma ln1Beta ln2Gamma ln2Beta heads attnBias + mlpWIn mlpBIn mlpWOut mlpBOut lo hi + ∀ q i, + (bounds.1 i : Real) ≤ + x q i + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i + + mlpReal mlpWIn mlpBIn mlpWOut mlpBOut + (layerNormRealOfReal eps ln2Gamma ln2Beta + (fun j => + x q j + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q j)) i ∧ + x q i + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i + + mlpReal mlpWIn mlpBIn mlpWOut mlpBOut + (layerNormRealOfReal eps ln2Gamma ln2Beta + (fun j => + x q j + attentionOutputReal eps ln1Gamma ln1Beta 
heads attnBias scores x q j)) i ≤ + (bounds.2 i : Real) := by + classical + intro bounds q i + let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias lo hi + have hattn := attentionResidualBounds_spec eps ln1Gamma ln1Beta heads attnBias scores lo hi x + hne heps hlo hhi q + have hmlp := layerNormAbsMlpResidualBounds_spec eps ln2Gamma ln2Beta mlpWIn mlpBIn mlpWOut + mlpBOut attn.1 attn.2 (fun j => x q j + attentionOutputReal eps ln1Gamma ln1Beta heads + attnBias scores x q j) hne heps + (fun j => (hattn j).1) (fun j => (hattn j).2) + have hmlp_i := hmlp i + simpa [bounds, transformerLayerBounds, attn] using hmlp_i + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean new file mode 100644 index 0000000..be5dc60 --- /dev/null +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -0,0 +1,142 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Ring.Abs +import Mathlib.Analysis.Complex.Trigonometric +import Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic +import Mathlib.Data.Rat.Cast.Order + +/-! +Tanh-based GELU bounds for GPT-2 style MLPs. +These bounds are used to propagate interval constraints through nonlinear gates. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +/-- Tanh-based GELU activation used by GPT-2 (approximate form). -/ +noncomputable def geluTanh (x : Real) : Real := + let k : Real := Real.sqrt (2 / Real.pi) + let c : Real := (44715 : Real) / 1000000 + x * ((1 + Real.tanh (k * (x + c * x ^ 3))) / 2) + +/-- The hyperbolic tangent is bounded in absolute value by `1`. 
-/ +theorem abs_tanh_le_one (x : Real) : |Real.tanh x| ≤ 1 := by + have hpos_exp : 0 < Real.exp x := Real.exp_pos x + have hpos_exp_neg : 0 < Real.exp (-x) := Real.exp_pos (-x) + have hsum_pos : 0 < Real.exp x + Real.exp (-x) := + add_pos hpos_exp hpos_exp_neg + have hsum_nonneg : 0 ≤ Real.exp x + Real.exp (-x) := le_of_lt hsum_pos + have habs : |Real.exp x - Real.exp (-x)| ≤ Real.exp x + Real.exp (-x) := by + have h := abs_add_le (Real.exp x) (-Real.exp (-x)) + simpa [sub_eq_add_neg, abs_neg, abs_of_nonneg (le_of_lt hpos_exp), + abs_of_nonneg (le_of_lt hpos_exp_neg)] using h + calc + |Real.tanh x| = + |(Real.exp x - Real.exp (-x)) / (Real.exp x + Real.exp (-x))| := by + simp [Real.tanh_eq] + _ = |Real.exp x - Real.exp (-x)| / (Real.exp x + Real.exp (-x)) := by + simp [abs_div, abs_of_nonneg hsum_nonneg] + _ ≤ (Real.exp x + Real.exp (-x)) / (Real.exp x + Real.exp (-x)) := by + exact div_le_div_of_nonneg_right habs hsum_nonneg + _ = 1 := by + have hne : Real.exp x + Real.exp (-x) ≠ 0 := ne_of_gt hsum_pos + simp [hne] + +/-- The tanh coefficient in `geluTanh` lies in `[0, 1]`. -/ +theorem geluTanh_coeff_bounds (x : Real) : + 0 ≤ + (1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2 ∧ + (1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2 ≤ 1 := by + have habs := + abs_tanh_le_one + (Real.sqrt (2 / Real.pi) * (x + (44715 : Real) / 1000000 * x ^ 3)) + have hbounds := abs_le.mp habs + constructor <;> nlinarith + +/-- `geluTanh` outputs stay between `min x 0` and `max x 0`. 
-/ +theorem geluTanh_bounds (x : Real) : + min x 0 ≤ geluTanh x ∧ geluTanh x ≤ max x 0 := by + by_cases hx : 0 ≤ x + · have hcoeff := geluTanh_coeff_bounds x + have hnonneg : + 0 ≤ x * + ((1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2) := by + exact mul_nonneg hx hcoeff.1 + have hle : + x * + ((1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2) ≤ x := by + have h := mul_le_mul_of_nonneg_left hcoeff.2 hx + simpa [mul_one] using h + have h0 : 0 ≤ geluTanh x := by + simpa [geluTanh] using hnonneg + have h1 : geluTanh x ≤ x := by + simpa [geluTanh] using hle + simpa [min_eq_right hx, max_eq_left hx] using And.intro h0 h1 + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have hcoeff := geluTanh_coeff_bounds x + have hle0 : + x * + ((1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2) ≤ 0 := by + exact mul_nonpos_of_nonpos_of_nonneg hx' hcoeff.1 + have hxle : + x ≤ + x * + ((1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2) := by + have h := mul_le_mul_of_nonpos_left hcoeff.2 hx' + simpa [mul_one] using h + have h0 : geluTanh x ≤ 0 := by + simpa [geluTanh] using hle0 + have h1 : x ≤ geluTanh x := by + simpa [geluTanh] using hxle + simpa [min_eq_left hx', max_eq_right hx'] using And.intro h1 h0 + +/-- Interval bounds for GELU given input bounds. -/ +def geluInterval (lo hi : Rat) : Rat × Rat := + (min lo 0, max hi 0) + +/-- `geluInterval` soundly bounds `geluTanh` on a real interval. 
-/ +theorem geluInterval_bounds {lo hi : Rat} {x : Real} + (hlo : (lo : Real) ≤ x) (hhi : x ≤ (hi : Real)) : + (geluInterval lo hi).1 ≤ (geluTanh x : Real) ∧ + (geluTanh x : Real) ≤ (geluInterval lo hi).2 := by + have hgelu := geluTanh_bounds x + have hmin : min (lo : Real) 0 ≤ min x 0 := min_le_min hlo le_rfl + have hmax : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl + have hlo' : min (lo : Real) 0 ≤ geluTanh x := le_trans hmin hgelu.1 + have hhi' : geluTanh x ≤ max (hi : Real) 0 := le_trans hgelu.2 hmax + constructor + · simpa [geluInterval, Rat.cast_min] using hlo' + · simpa [geluInterval, Rat.cast_max] using hhi' + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 8f97090..607fe04 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -1,6 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Field.Basic import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.Nat.Sqrt import Mathlib.Data.Real.Sqrt @@ -62,6 +64,107 @@ theorem variance_nonneg {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : div_nonneg hsum hden simpa [variance_def x h] using hdiv +/-- Absolute mean bound from per-coordinate bounds. 
-/ +theorem mean_abs_le_bound {n : Nat} (x : Fin n → Rat) (bound : Rat) + (hne : n ≠ 0) (hbound : ∀ i, |x i| ≤ bound) : + |mean x| ≤ bound := by + classical + have hsum_abs : + |∑ i : Fin n, x i| ≤ ∑ i : Fin n, |x i| := by + simpa using + (Finset.abs_sum_le_sum_abs + (f := fun i : Fin n => x i) + (s := (Finset.univ : Finset (Fin n)))) + have hsum_bound : ∑ i : Fin n, |x i| ≤ ∑ i : Fin n, bound := by + refine Finset.sum_le_sum ?_ + intro i _ + exact hbound i + have hsum_le : |∑ i : Fin n, x i| ≤ (n : Rat) * bound := by + have hsum := le_trans hsum_abs hsum_bound + simpa [Finset.sum_const, Finset.card_univ] using hsum + have hpos : 0 < (n : Rat) := by + exact_mod_cast Nat.pos_of_ne_zero hne + have hsum_le' : |∑ i : Fin n, x i| ≤ bound * (n : Rat) := by + simpa [mul_comm] using hsum_le + have hdiv : |∑ i : Fin n, x i| / (n : Rat) ≤ bound := by + exact (div_le_iff₀ hpos).2 hsum_le' + have habs_mean : + |(∑ i : Fin n, x i) / (n : Rat)| ≤ bound := by + simpa [abs_div, abs_of_nonneg (le_of_lt hpos)] using hdiv + simpa [mean_def x hne] using habs_mean + +/-! Real-valued mean and variance. -/ + +/-- Mean of a real vector (defaults to `0` when `n = 0`). -/ +noncomputable def meanReal {n : Nat} (x : Fin n → Real) : Real := + if n = 0 then + 0 + else + (∑ i, x i) / n + +/-- Unfold `meanReal` when `n ≠ 0`. -/ +theorem meanReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + meanReal x = (∑ i, x i) / n := by + simp [meanReal, h] + +/-- Variance of a real vector (defaults to `0` when `n = 0`). -/ +noncomputable def varianceReal {n : Nat} (x : Fin n → Real) : Real := + if n = 0 then + 0 + else + let μ := meanReal x + (∑ i, (x i - μ) ^ 2) / n + +/-- Unfold `varianceReal` when `n ≠ 0`. -/ +theorem varianceReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + varianceReal x = + let μ := meanReal x + (∑ i, (x i - μ) ^ 2) / n := by + simp [varianceReal, h] + +/-- Variance is nonnegative when `n ≠ 0`. 
-/ +theorem varianceReal_nonneg {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + 0 ≤ varianceReal x := by + classical + have hsum : 0 ≤ ∑ i, (x i - meanReal x) ^ 2 := by + refine Finset.sum_nonneg ?_ + intro i _ + exact sq_nonneg (x i - meanReal x) + have hden : 0 ≤ (n : Real) := by + exact_mod_cast (Nat.zero_le n) + have hdiv : 0 ≤ (∑ i, (x i - meanReal x) ^ 2) / n := + div_nonneg hsum hden + simpa [varianceReal_def x h] using hdiv + +/-- Absolute mean bound from per-coordinate bounds (real inputs). -/ +theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Rat) + (hne : n ≠ 0) (hbound : ∀ i, |x i| ≤ (bound : Real)) : + |meanReal x| ≤ (bound : Real) := by + classical + have hsum_abs : + |∑ i : Fin n, x i| ≤ ∑ i : Fin n, |x i| := by + simpa using + (Finset.abs_sum_le_sum_abs + (f := fun i : Fin n => x i) + (s := (Finset.univ : Finset (Fin n)))) + have hsum_bound : ∑ i : Fin n, |x i| ≤ ∑ i : Fin n, (bound : Real) := by + refine Finset.sum_le_sum ?_ + intro i _ + exact hbound i + have hsum_le : |∑ i : Fin n, x i| ≤ (n : Real) * (bound : Real) := by + have hsum := le_trans hsum_abs hsum_bound + simpa [Finset.sum_const, Finset.card_univ, mul_comm] using hsum + have hpos : 0 < (n : Real) := by + exact_mod_cast Nat.pos_of_ne_zero hne + have hsum_le' : |∑ i : Fin n, x i| ≤ (bound : Real) * (n : Real) := by + simpa [mul_comm] using hsum_le + have hdiv : |∑ i : Fin n, x i| / (n : Real) ≤ (bound : Real) := by + exact (div_le_iff₀ hpos).2 hsum_le' + have habs_mean : + |(∑ i : Fin n, x i) / (n : Real)| ≤ (bound : Real) := by + simpa [abs_div, abs_of_nonneg (le_of_lt hpos)] using hdiv + simpa [meanReal_def x hne] using habs_mean + /-! Square-root bounds. -/ /-- Base rational lower bound for a square root. -/ @@ -490,6 +593,17 @@ noncomputable def layerNormReal {n : Nat} let invStd : Real := (Real.sqrt varEps)⁻¹ fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) +/-- Real-valued LayerNorm output for a real vector. 
-/ +noncomputable def layerNormRealOfReal {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Real) : Fin n → Real := + if n = 0 then + fun _ => 0 + else + let μ : Real := meanReal x + let varEps : Real := varianceReal x + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + fun i => (gamma i : Real) * (x i - μ) * invStd + (beta i : Real) + /-- Interval bounds for LayerNorm outputs. -/ def layerNormBounds {n : Nat} (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : @@ -611,6 +725,227 @@ theorem layerNormBounds_spec {n : Nat} simpa [bounds, layerNormBounds, hne, μ, varEps, invLo, invHi, centered, nb, sb, lo, hi] using And.intro hlo hhi +/-- Interval bounds for LayerNorm outputs from an absolute input bound. -/ +def layerNormAbsBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) : + (Fin n → Rat) × (Fin n → Rat) := + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := (sqrtLower eps)⁻¹ + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound * invStdBound + (fun i => beta i - radius i, fun i => beta i + radius i) + +/-- `layerNormAbsBounds` soundness for real LayerNorm outputs under absolute input bounds. 
-/ +theorem layerNormAbsBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (habs : ∀ i, |x i| ≤ absBound) : + let bounds := layerNormAbsBounds eps gamma beta absBound + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_abs : |mean x| ≤ absBound := + mean_abs_le_bound x absBound hne habs + have hmean_abs_real : |(mean x : Real)| ≤ (absBound : Real) := by + exact_mod_cast hmean_abs + have hbound_nonneg : 0 ≤ absBound := by + have hposn : 0 < n := Nat.pos_of_ne_zero hne + let i0 : Fin n := ⟨0, hposn⟩ + have h0 : 0 ≤ |x i0| := abs_nonneg _ + exact le_trans h0 (habs i0) + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := (sqrtLower eps)⁻¹ + let varEps : Rat := variance x + eps + let μ : Real := mean x + let invStd : Real := (Real.sqrt (varEps : Real))⁻¹ + have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound : Real) := by + have h1 : |(x i : Real) - μ| ≤ |(x i : Real)| + |μ| := by + simpa [sub_eq_add_neg, abs_neg] using abs_add_le (x i : Real) (-μ) + have hx : |(x i : Real)| ≤ (absBound : Real) := by + exact_mod_cast (habs i) + have hmu : |μ| ≤ (absBound : Real) := by + simpa using hmean_abs_real + have h2 : |(x i : Real)| + |μ| ≤ (absBound : Real) + (absBound : Real) := + add_le_add hx hmu + have h12 : |(x i : Real) - μ| ≤ (absBound : Real) + (absBound : Real) := + le_trans h1 h2 + simpa [centeredBound, two_mul] using h12 + have hbound_nonneg_real : 0 ≤ (absBound : Real) := by + exact_mod_cast hbound_nonneg + have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by + have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real + simpa [centeredBound, two_mul] using hsum + have hvar_nonneg : 0 ≤ variance x := variance_nonneg x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt (varEps : Real) := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ 
Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ (varEps : Real) := by + have hle' : eps ≤ varEps := le_add_of_nonneg_left hvar_nonneg + exact_mod_cast hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt (varEps : Real) := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact_mod_cast (sqrtLower_pos (q := eps) heps) + have hinv : invStd ≤ (invStdBound : Real) := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd, invStdBound] using h + have hinv_nonneg : 0 ≤ invStd := by + have hsqrt_nonneg : 0 ≤ Real.sqrt (varEps : Real) := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |(x i : Real) - μ| * invStd ≤ + (centeredBound : Real) * (invStdBound : Real) := by + have hleft : + |(x i : Real) - μ| * invStd ≤ (centeredBound : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound : Real) * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, 
mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h := add_le_add_left hbounds.1 (beta i : Real) + simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h := add_le_add_left hbounds.2 (beta i : Real) + simpa [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hlow + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh + exact And.intro hlo hhi + +/-- `layerNormAbsBounds` soundness for real LayerNorm outputs on real inputs. 
-/ +theorem layerNormAbsBounds_spec_real {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (habs : ∀ i, |x i| ≤ (absBound : Real)) : + let bounds := layerNormAbsBounds eps gamma beta absBound + ∀ i, + (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ + layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_abs : |meanReal x| ≤ (absBound : Real) := + meanReal_abs_le_bound x absBound hne habs + have hbound_nonneg_real : 0 ≤ (absBound : Real) := by + have hposn : 0 < n := Nat.pos_of_ne_zero hne + let i0 : Fin n := ⟨0, hposn⟩ + have h0 : 0 ≤ |x i0| := abs_nonneg _ + exact le_trans h0 (habs i0) + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := (sqrtLower eps)⁻¹ + let varEps : Real := varianceReal x + (eps : Real) + let μ : Real := meanReal x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_abs : |x i - μ| ≤ (centeredBound : Real) := by + have h1 : |x i - μ| ≤ |x i| + |μ| := by + simpa [sub_eq_add_neg, abs_neg] using abs_add_le (x i) (-μ) + have hx : |x i| ≤ (absBound : Real) := habs i + have hmu : |μ| ≤ (absBound : Real) := by + simpa using hmean_abs + have h2 : |x i| + |μ| ≤ (absBound : Real) + (absBound : Real) := + add_le_add hx hmu + have h12 : |x i - μ| ≤ (absBound : Real) + (absBound : Real) := + le_trans h1 h2 + simpa [centeredBound, two_mul] using h12 + have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by + have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real + simpa [centeredBound, two_mul] using hsum + have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ varianceReal x + 
(eps : Real) := by + exact le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact_mod_cast (sqrtLower_pos (q := eps) heps) + have hinv : invStd ≤ (invStdBound : Real) := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd, invStdBound] using h + have hinv_nonneg : 0 ≤ invStd := by + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |x i - μ| * invStd ≤ + (centeredBound : Real) * (invStdBound : Real) := by + have hleft : + |x i - μ| * invStd ≤ (centeredBound : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound : Real) * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |x i - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * (x i - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |x i - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds 
: -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h := add_le_add_left hbounds.1 (beta i : Real) + simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h := add_le_add_left hbounds.2 (beta i : Real) + simpa [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hlow + have hhi : layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh + exact And.intro hlo hhi + end Bounds end Sound diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index 9836486..b3a6243 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -203,6 +203,33 @@ theorem abs_le_max_abs_abs_of_interval {a b x : Rat} (hlo : a ≤ x) (hhi : x _ = |a| := by simp [haabs] _ ≤ max |a| |b| := le_max_left _ _ +/-- Global absolute bound from interval endpoints. -/ +def intervalAbsBound {n : Nat} (lo hi : Fin n → Rat) : Rat := + if h : (Finset.univ : Finset (Fin n)).Nonempty then + (Finset.univ).sup' h (fun i => max |lo i| |hi i|) + else + 0 + +/-- `intervalAbsBound` bounds any element inside the interval. 
-/ +theorem abs_le_intervalAbsBound {n : Nat} (lo hi x : Fin n → Rat) + (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) (i : Fin n) : + |x i| ≤ intervalAbsBound lo hi := by + classical + have hbound : |x i| ≤ max |lo i| |hi i| := + abs_le_max_abs_abs_of_interval (hlo i) (hhi i) + have hnonempty : (Finset.univ : Finset (Fin n)).Nonempty := ⟨i, by simp⟩ + have hsup : + max |lo i| |hi i| ≤ + (Finset.univ).sup' hnonempty (fun j => max |lo j| |hi j|) := by + simpa using + (Finset.le_sup' + (s := (Finset.univ : Finset (Fin n))) + (f := fun j => max |lo j| |hi j|) + (by simp : i ∈ (Finset.univ : Finset (Fin n)))) + have hfinal : |x i| ≤ (Finset.univ).sup' hnonempty (fun j => max |lo j| |hi j|) := + le_trans hbound hsup + simpa [intervalAbsBound, hnonempty] using hfinal + theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : |dotProduct v x| ≤ dotIntervalAbsBound v lo hi := by @@ -299,6 +326,29 @@ theorem abs_le_max_abs_abs_of_interval_real {a b x : Real} (hlo : a ≤ x) (hhi _ = |a| := by simp [haabs] _ ≤ max |a| |b| := le_max_left _ _ +/-- `intervalAbsBound` controls real-valued coordinates inside a rational interval. 
-/ +theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Rat) (x : Fin n → Real) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) (i : Fin n) : + |x i| ≤ (intervalAbsBound lo hi : Real) := by + classical + have hbound : |x i| ≤ max |(lo i : Real)| |(hi i : Real)| := + abs_le_max_abs_abs_of_interval_real (hlo i) (hhi i) + have hnonempty : (Finset.univ : Finset (Fin n)).Nonempty := ⟨i, by simp⟩ + have hsup : + max |lo i| |hi i| ≤ + (Finset.univ).sup' hnonempty (fun j => max |lo j| |hi j|) := by + simpa using + (Finset.le_sup' + (s := (Finset.univ : Finset (Fin n))) + (f := fun j => max |lo j| |hi j|) + (by simp : i ∈ (Finset.univ : Finset (Fin n)))) + have hsup' : max |lo i| |hi i| ≤ intervalAbsBound lo hi := by + simpa [intervalAbsBound, hnonempty] using hsup + have hsup_real : + max |(lo i : Real)| |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by + exact_mod_cast hsup' + exact le_trans hbound hsup_real + theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Rat) (x : Fin n → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : diff --git a/Nfp/Sound/Bounds/Mlp.lean b/Nfp/Sound/Bounds/Mlp.lean new file mode 100644 index 0000000..e699fb1 --- /dev/null +++ b/Nfp/Sound/Bounds/Mlp.lean @@ -0,0 +1,285 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Sound.Bounds.Gelu +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.MatrixNorm + +/-! +Interval bounds for GPT-2 MLP blocks (linear + GELU + linear). +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +/-- Real-valued MLP with tanh-based GELU activations. 
-/ +noncomputable def mlpReal {dModel hidden : Nat} + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (x : Fin dModel → Real) : Fin dModel → Real := + fun i => + let hidden : Fin hidden → Real := fun h => + geluTanh (dotProduct (fun j => (wIn j h : Real)) x + (bIn h : Real)) + dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) + +/-- Interval bounds for a tanh-GELU MLP given input intervals. -/ +def mlpBounds {dModel hidden : Nat} + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let preLo : Fin hidden → Rat := fun h => + dotIntervalLower (fun j => wIn j h) lo hi + bIn h + let preHi : Fin hidden → Rat := fun h => + dotIntervalUpper (fun j => wIn j h) lo hi + bIn h + let geluLo : Fin hidden → Rat := fun h => min (preLo h) 0 + let geluHi : Fin hidden → Rat := fun h => max (preHi h) 0 + let outLo : Fin dModel → Rat := fun i => + dotIntervalLower (fun h => wOut h i) geluLo geluHi + bOut i + let outHi : Fin dModel → Rat := fun i => + dotIntervalUpper (fun h => wOut h i) geluLo geluHi + bOut i + (outLo, outHi) + +/-- `mlpBounds` soundness for real MLP outputs. 
-/ +theorem mlpBounds_spec {dModel hidden : Nat} + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) (x : Fin dModel → Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + let bounds := mlpBounds wIn bIn wOut bOut lo hi + ∀ i, (bounds.1 i : Real) ≤ mlpReal wIn bIn wOut bOut x i ∧ + mlpReal wIn bIn wOut bOut x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let preLo : Fin hidden → Rat := fun h => + dotIntervalLower (fun j => wIn j h) lo hi + bIn h + let preHi : Fin hidden → Rat := fun h => + dotIntervalUpper (fun j => wIn j h) lo hi + bIn h + let pre : Fin hidden → Real := fun h => + dotProduct (fun j => (wIn j h : Real)) x + (bIn h : Real) + have hpre_lower : ∀ h, (preLo h : Real) ≤ pre h := by + intro h + have hdot := + dotIntervalLower_le_dotProduct_real (v := fun j => wIn j h) lo hi x hlo hhi + have hdot' := add_le_add_right hdot (bIn h : Real) + simpa [pre, preLo, Rat.cast_add] using hdot' + have hpre_upper : ∀ h, pre h ≤ (preHi h : Real) := by + intro h + have hdot := + dotProduct_le_dotIntervalUpper_real (v := fun j => wIn j h) lo hi x hlo hhi + have hdot' := add_le_add_right hdot (bIn h : Real) + simpa [pre, preHi, Rat.cast_add] using hdot' + let geluLo : Fin hidden → Rat := fun h => min (preLo h) 0 + let geluHi : Fin hidden → Rat := fun h => max (preHi h) 0 + let hidden : Fin hidden → Real := fun h => geluTanh (pre h) + have hgelu : ∀ h, (geluLo h : Real) ≤ hidden h ∧ hidden h ≤ (geluHi h : Real) := by + intro h + have hbounds := geluInterval_bounds (lo := preLo h) (hi := preHi h) + (hpre_lower h) (hpre_upper h) + dsimp [geluLo, geluHi, hidden, geluInterval] + exact hbounds + let outLo : Fin dModel → Rat := fun i => + dotIntervalLower (fun h => wOut h i) geluLo geluHi + bOut i + let outHi : Fin dModel → Rat := fun i => + dotIntervalUpper (fun h => wOut h i) geluLo geluHi + bOut i + have hout_lower : + (outLo i : 
Real) ≤ dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) := by + have hdot := + dotIntervalLower_le_dotProduct_real (v := fun h => wOut h i) geluLo geluHi hidden + (fun h => (hgelu h).1) (fun h => (hgelu h).2) + have hdot' := add_le_add_right hdot (bOut i : Real) + simpa [outLo, Rat.cast_add] using hdot' + have hout_upper : + dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) ≤ (outHi i : Real) := by + have hdot := + dotProduct_le_dotIntervalUpper_real (v := fun h => wOut h i) geluLo geluHi hidden + (fun h => (hgelu h).1) (fun h => (hgelu h).2) + have hdot' := add_le_add_right hdot (bOut i : Real) + simpa [outHi, Rat.cast_add] using hdot' + have hlo' : (outLo i : Real) ≤ mlpReal wIn bIn wOut bOut x i := by + simpa [mlpReal, hidden, pre] using hout_lower + have hhi' : mlpReal wIn bIn wOut bOut x i ≤ (outHi i : Real) := by + simpa [mlpReal, hidden, pre] using hout_upper + simpa [bounds, mlpBounds, preLo, preHi, geluLo, geluHi, outLo, outHi] using + And.intro hlo' hhi' + +/-- Interval bounds for a LayerNorm + MLP sublayer from exact inputs. -/ +def layerNormMlpBounds {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + let ln := layerNormBounds eps gamma beta x + mlpBounds wIn bIn wOut bOut ln.1 ln.2 + +/-- `layerNormMlpBounds` soundness for real LayerNorm + MLP outputs. 
-/ +theorem layerNormMlpBounds_spec {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) : + let bounds := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x + ∀ i, (bounds.1 i : Real) ≤ + mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ∧ + mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let ln := layerNormBounds eps gamma beta x + have hln := layerNormBounds_spec eps gamma beta x hne heps + have hlo : ∀ j, (ln.1 j : Real) ≤ layerNormReal eps gamma beta x j := fun j => (hln j).1 + have hhi : ∀ j, layerNormReal eps gamma beta x j ≤ (ln.2 j : Real) := fun j => (hln j).2 + have hmlp := mlpBounds_spec wIn bIn wOut bOut ln.1 ln.2 + (layerNormReal eps gamma beta x) hlo hhi + simpa [bounds, layerNormMlpBounds, ln] using hmlp i + +/-- Interval bounds for LayerNorm + MLP sublayer from interval inputs. -/ +def layerNormAbsMlpBounds {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + let absBound := intervalAbsBound lo hi + let ln := layerNormAbsBounds eps gamma beta absBound + mlpBounds wIn bIn wOut bOut ln.1 ln.2 + +/-- `layerNormAbsMlpBounds` soundness for real LayerNorm + MLP outputs. 
-/ +theorem layerNormAbsMlpBounds_spec {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : + let bounds := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi + ∀ i, (bounds.1 i : Real) ≤ + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ∧ + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let absBound := intervalAbsBound lo hi + let ln := layerNormAbsBounds eps gamma beta absBound + have habs : ∀ j, |x j| ≤ (absBound : Real) := by + intro j + have hbound : + |x j| ≤ max |(lo j : Real)| |(hi j : Real)| := + abs_le_max_abs_abs_of_interval_real (hlo j) (hhi j) + have hnonempty : (Finset.univ : Finset (Fin n)).Nonempty := ⟨j, by simp⟩ + have hsup : + max |lo j| |hi j| ≤ intervalAbsBound lo hi := by + have hsup' : + max |lo j| |hi j| ≤ + (Finset.univ).sup' hnonempty (fun k => max |lo k| |hi k|) := by + simpa using + (Finset.le_sup' + (s := (Finset.univ : Finset (Fin n))) + (f := fun k => max |lo k| |hi k|) + (by simp : j ∈ (Finset.univ : Finset (Fin n)))) + simpa [intervalAbsBound, hnonempty] using hsup' + have hsup_real : + max |(lo j : Real)| |(hi j : Real)| ≤ (absBound : Real) := by + exact_mod_cast hsup + exact le_trans hbound hsup_real + have hln := + layerNormAbsBounds_spec_real eps gamma beta absBound x hne heps habs + have hlo_ln : ∀ j, (ln.1 j : Real) ≤ layerNormRealOfReal eps gamma beta x j := fun j => (hln j).1 + have hhi_ln : ∀ j, layerNormRealOfReal eps gamma beta x j ≤ (ln.2 j : Real) := fun j => (hln j).2 + have hmlp := mlpBounds_spec wIn bIn wOut bOut ln.1 ln.2 + (layerNormRealOfReal eps gamma beta x) hlo_ln hhi_ln + simpa [bounds, layerNormAbsMlpBounds, absBound, ln] using hmlp i + 
+/-- Add residual inputs to interval bounds. -/ +def residualAddBounds {n : Nat} (x : Fin n → Rat) (lo hi : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + (fun i => x i + lo i, fun i => x i + hi i) + +/-- `residualAddBounds` soundness for residual addition. -/ +theorem residualAddBounds_spec {n : Nat} (x : Fin n → Rat) + (lo hi : Fin n → Rat) (y : Fin n → Real) + (hlo : ∀ i, (lo i : Real) ≤ y i) (hhi : ∀ i, y i ≤ (hi i : Real)) : + let bounds := residualAddBounds x lo hi + ∀ i, (bounds.1 i : Real) ≤ (x i : Real) + y i ∧ + (x i : Real) + y i ≤ (bounds.2 i : Real) := by + intro bounds i + have hlow := add_le_add_left (hlo i) (x i : Real) + have hhigh := add_le_add_left (hhi i) (x i : Real) + constructor + · simpa [bounds, residualAddBounds, Rat.cast_add] using hlow + · simpa [bounds, residualAddBounds, Rat.cast_add] using hhigh + +/-- Interval bounds for a full MLP residual path (LayerNorm + MLP + residual add). -/ +def layerNormMlpResidualBounds {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + let mlp := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x + residualAddBounds x mlp.1 mlp.2 + +/-- `layerNormMlpResidualBounds` soundness for the MLP residual path. 
-/ +theorem layerNormMlpResidualBounds_spec {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) : + let bounds := layerNormMlpResidualBounds eps gamma beta wIn bIn wOut bOut x + ∀ i, + (bounds.1 i : Real) ≤ + (x i : Real) + + mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ∧ + (x i : Real) + + mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ≤ + (bounds.2 i : Real) := by + classical + intro bounds i + let mlp := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x + have hmlp := layerNormMlpBounds_spec eps gamma beta wIn bIn wOut bOut x hne heps + have hres := residualAddBounds_spec x mlp.1 mlp.2 + (mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x)) + (fun j => (hmlp j).1) (fun j => (hmlp j).2) + simpa [bounds, layerNormMlpResidualBounds, mlp] using hres i + +/-- Interval bounds for a full MLP residual path (LayerNorm + MLP + residual add) from intervals. -/ +def layerNormAbsMlpResidualBounds {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + let mlp := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi + (fun i => lo i + mlp.1 i, fun i => hi i + mlp.2 i) + +/-- `layerNormAbsMlpResidualBounds` soundness for the MLP residual path. 
-/ +theorem layerNormAbsMlpResidualBounds_spec {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : + let bounds := layerNormAbsMlpResidualBounds eps gamma beta wIn bIn wOut bOut lo hi + ∀ i, + (bounds.1 i : Real) ≤ + x i + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ∧ + x i + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ≤ + (bounds.2 i : Real) := by + classical + intro bounds i + let mlp := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi + have hmlp := layerNormAbsMlpBounds_spec eps gamma beta wIn bIn wOut bOut lo hi x hne heps hlo hhi + have hlo' := (hmlp i).1 + have hhi' := (hmlp i).2 + have hlow := add_le_add (hlo i) hlo' + have hhigh := add_le_add (hhi i) hhi' + constructor + · simpa [bounds, layerNormAbsMlpResidualBounds, mlp, Rat.cast_add] using hlow + · simpa [bounds, layerNormAbsMlpResidualBounds, mlp, Rat.cast_add] using hhigh + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/Transformer.lean b/Nfp/Sound/Bounds/Transformer.lean new file mode 100644 index 0000000..f528152 --- /dev/null +++ b/Nfp/Sound/Bounds/Transformer.lean @@ -0,0 +1,281 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.List.Range +import Mathlib.Data.Real.Basic +import Nfp.Model.Gpt2 +import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.MatrixNorm + +/-! +Interval bounds for transformer stacks and final LayerNorm outputs. 
+-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +private lemma fin_univ_nonempty (seq : Nat) [NeZero seq] : + (Finset.univ : Finset (Fin seq)).Nonempty := by + classical + refine ⟨⟨0, ?_⟩, by simp⟩ + exact Nat.pos_of_ne_zero (NeZero.ne (n := seq)) + +/-- Interval bounds across tokens for an embedding map. -/ +def embeddingIntervalBounds {seq dModel : Nat} [NeZero seq] + (x : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let h : (Finset.univ : Finset (Fin seq)).Nonempty := fin_univ_nonempty (seq := seq) + (fun i => (Finset.univ).inf' h (fun q => x q i), + fun i => (Finset.univ).sup' h (fun q => x q i)) + +/-- `embeddingIntervalBounds` bounds embeddings coordinatewise. -/ +theorem embeddingIntervalBounds_spec {seq dModel : Nat} [NeZero seq] + (x : Fin seq → Fin dModel → Rat) : + let bounds := embeddingIntervalBounds x + ∀ q i, + (bounds.1 i : Real) ≤ (x q i : Real) ∧ + (x q i : Real) ≤ (bounds.2 i : Real) := by + classical + intro bounds q i + have hloRat : bounds.1 i ≤ x q i := by + have h := + Finset.inf'_le (s := (Finset.univ : Finset (Fin seq))) + (f := fun k => x k i) (b := q) (by simp) + simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h + have hhiRat : x q i ≤ bounds.2 i := by + have h := + Finset.le_sup' (s := (Finset.univ : Finset (Fin seq))) + (f := fun k => x k i) (b := q) (by simp) + simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h + constructor + · exact_mod_cast hloRat + · exact_mod_cast hhiRat + +/-- Real-valued output of a transformer layer. 
-/ +noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := + x q i + + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias scores x q i + + mlpReal layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut + (layerNormRealOfReal eps layer.ln2Gamma layer.ln2Beta + (fun j => + x q j + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores x q j)) i + +/-- `transformerLayerBounds` soundness for `transformerLayerReal`. -/ +theorem transformerLayerBounds_spec_real {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := transformerLayerBounds eps layer.ln1Gamma layer.ln1Beta layer.ln2Gamma + layer.ln2Beta heads layer.attnBias layer.mlpWIn layer.mlpBIn layer.mlpWOut + layer.mlpBOut lo hi + ∀ q i, + (bounds.1 i : Real) ≤ transformerLayerReal eps layer heads scores x q i ∧ + transformerLayerReal eps layer heads scores x q i ≤ (bounds.2 i : Real) := by + classical + simpa [transformerLayerReal] using + (transformerLayerBounds_spec (eps := eps) + (ln1Gamma := layer.ln1Gamma) (ln1Beta := layer.ln1Beta) + (ln2Gamma := layer.ln2Gamma) (ln2Beta := layer.ln2Beta) + (heads := heads) (attnBias := layer.attnBias) + (mlpWIn := layer.mlpWIn) (mlpBIn := layer.mlpBIn) + (mlpWOut := layer.mlpWOut) (mlpBOut := layer.mlpBOut) + (scores := scores) (lo := lo) (hi := hi) (x := x) + hne heps hlo hhi) + 
+/-- Real-valued transformer stack output (folded left over layers). -/ +noncomputable def transformerStackReal + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) : Fin seq → Fin dModel → Real := + let step := fun x layerIdx => + transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x + (List.finRange numLayers).foldl step x + +/-- Interval bounds for a transformer stack (folded left over layers). -/ +def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let step := fun bounds layerIdx => + transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta + (layers layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) + (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn + (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2 + (List.finRange numLayers).foldl step (lo, hi) + +private theorem transformerStackBounds_spec_list + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) : + ∀ (ls : List (Fin numLayers)) (lo hi : Fin dModel → Rat) + (x : Fin seq → Fin dModel → Real), + (∀ q i, (lo i : Real) ≤ x q i) → + (∀ q i, x q i ≤ (hi i : Real)) → + let bounds := (ls.foldl + (fun bounds layerIdx => + 
transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta + (layers layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) + (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn + (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2) + (lo, hi)) + let x' := (ls.foldl + (fun x layerIdx => + transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x) + x) + ∀ q i, + (bounds.1 i : Real) ≤ x' q i ∧ + x' q i ≤ (bounds.2 i : Real) := by + intro ls lo hi x hlo hhi + induction ls generalizing lo hi x hlo hhi with + | nil => + simpa using fun q i => And.intro (hlo q i) (hhi q i) + | cons l ls ih => + have hstep := + transformerLayerBounds_spec_real eps (layers l) (heads l) (scores l) lo hi x + hne heps hlo hhi + let bounds1 := + transformerLayerBounds eps (layers l).ln1Gamma (layers l).ln1Beta (layers l).ln2Gamma + (layers l).ln2Beta (heads l) (layers l).attnBias (layers l).mlpWIn (layers l).mlpBIn + (layers l).mlpWOut (layers l).mlpBOut lo hi + let x1 := transformerLayerReal eps (layers l) (heads l) (scores l) x + have hlo1 : ∀ q i, (bounds1.1 i : Real) ≤ x1 q i := fun q i => (hstep q i).1 + have hhi1 : ∀ q i, x1 q i ≤ (bounds1.2 i : Real) := fun q i => (hstep q i).2 + have ih' := ih bounds1.1 bounds1.2 x1 hlo1 hhi1 + simpa [bounds1, x1] using ih' + +/-- `transformerStackBounds` soundness for real transformer-stack outputs. 
-/ +theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := transformerStackBounds eps layers heads lo hi + ∀ q i, + (bounds.1 i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ + transformerStackReal eps layers heads scores x q i ≤ (bounds.2 i : Real) := by + classical + simpa [transformerStackBounds, transformerStackReal] using + transformerStackBounds_spec_list eps layers heads scores hne heps + (List.finRange numLayers) lo hi x hlo hhi + +/-- Real-valued transformer stack output after the final LayerNorm. -/ +noncomputable def transformerStackFinalReal {seq dModel dHead numHeads hidden numLayers : Nat} + [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := + layerNormRealOfReal eps finalLn.gamma finalLn.beta + (fun j => transformerStackReal eps layers heads scores x q j) i + +/-- Interval bounds for transformer stack outputs after the final LayerNorm. 
-/ +def transformerStackFinalBounds {dModel dHead numHeads hidden numLayers : Nat} + (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let stack := transformerStackBounds eps layers heads lo hi + let absBound := intervalAbsBound stack.1 stack.2 + layerNormAbsBounds eps finalLn.gamma finalLn.beta absBound + +/-- `transformerStackFinalBounds` soundness for real outputs. -/ +theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} + [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := transformerStackFinalBounds eps finalLn layers heads lo hi + ∀ q i, + (bounds.1 i : Real) ≤ transformerStackFinalReal eps finalLn layers heads scores x q i ∧ + transformerStackFinalReal eps finalLn layers heads scores x q i ≤ (bounds.2 i : Real) := by + classical + intro bounds q i + let stack := transformerStackBounds eps layers heads lo hi + let absBound := intervalAbsBound stack.1 stack.2 + have hstack := + transformerStackBounds_spec eps layers heads scores lo hi x hne heps hlo hhi q + have habs : + ∀ j, |transformerStackReal eps layers heads scores x q j| ≤ (absBound : Real) := by + intro j + have hlo' : ∀ k, (stack.1 k : Real) ≤ transformerStackReal eps layers heads scores x q k := + fun k => (hstack k).1 + have hhi' : ∀ k, transformerStackReal eps layers heads scores x q k ≤ (stack.2 k : Real) := + fun k => 
(hstack k).2 + have hbound := + abs_le_intervalAbsBound_real (lo := stack.1) (hi := stack.2) + (x := fun k => transformerStackReal eps layers heads scores x q k) hlo' hhi' j + simpa [absBound] using hbound + have hln := + layerNormAbsBounds_spec_real eps finalLn.gamma finalLn.beta absBound + (fun j => transformerStackReal eps layers heads scores x q j) hne heps habs + simpa [bounds, transformerStackFinalBounds, absBound, stack, transformerStackFinalReal] using + hln i + +/-- Residual interval bounds for a GPT-2 stack from exact embeddings. -/ +def gpt2ResidualIntervalBounds + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let base := embeddingIntervalBounds embed + transformerStackFinalBounds eps finalLn layers heads base.1 base.2 + +/-- `gpt2ResidualIntervalBounds` soundness for real GPT-2 outputs. 
-/ +theorem gpt2ResidualIntervalBounds_spec + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (embed : Fin seq → Fin dModel → Rat) + (hne : dModel ≠ 0) (heps : 0 < eps) : + let bounds := gpt2ResidualIntervalBounds eps layers heads finalLn embed + ∀ q i, + (bounds.1 i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ∧ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by + classical + intro bounds q i + let base := embeddingIntervalBounds embed + have hbase := embeddingIntervalBounds_spec embed + have hlo : ∀ q i, (base.1 i : Real) ≤ (embed q i : Real) := fun q i => (hbase q i).1 + have hhi : ∀ q i, (embed q i : Real) ≤ (base.2 i : Real) := fun q i => (hbase q i).2 + have hstack := + transformerStackFinalBounds_spec eps finalLn layers heads scores base.1 base.2 + (fun q i => (embed q i : Real)) hne heps hlo hhi q i + simpa [bounds, gpt2ResidualIntervalBounds, base] using hstack + +end Bounds + +end Sound + +end Nfp diff --git a/README.md b/README.md index 0d44d0b..02d1be1 100644 --- a/README.md +++ b/README.md @@ -131,16 +131,19 @@ lake exe nfp induction certify_end_to_end_matrix \ ``` Or derive the downstream matrix directly from an `NFP_BINARY_V1` model file -(currently uses the unembedding direction only): +(currently uses the unembedding direction only). 
If `--residual-interval` is omitted, +the tool derives a conservative residual interval from the model: ```bash lake exe nfp induction certify_end_to_end_model \ --scores reports/gpt2_induction.cert \ --values reports/gpt2_induction.values \ - --model models/gpt2_rigorous.nfpt \ - --residual-interval reports/gpt2_residual.interval + --model models/gpt2_rigorous.nfpt ``` +To use an external residual-interval certificate instead, include +`--residual-interval reports/gpt2_residual.interval`. + ## File formats ### Softmax-margin certificate diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index eb4a27a..fa1af52 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -9,8 +9,9 @@ It is intentionally brief and focused on the soundness boundary. - Induction certificates are **head-level** (softmax-margin + value-range + logit-diff lower bound). They do **not** yet imply end-to-end model behavior. - Downstream error bounds can be computed from a **matrix payload** inside Lean. A model-based - path exists, but it currently uses only the unembedding direction and relies on an external - **residual-interval certificate** (per-coordinate lower/upper bounds). + path exists, but it currently uses only the unembedding direction and derives residual + intervals via conservative interval propagation (ignoring attention-score structure), + which can be loose. - The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor now includes attention projection biases and LayerNorm metadata, but the Lean-side computation still ignores LayerNorm and the shared attention output bias. @@ -24,9 +25,8 @@ It is intentionally brief and focused on the soundness boundary. ## Remaining work -- Compute the downstream bound **inside Lean** from model weights and certified residual - bounds (not just matrix payloads), and wire this into `certify_end_to_end`. 
-- Replace untrusted residual-interval generation with a verified derivation from upstream bounds. +- Tighten model-derived residual intervals (e.g., use attention-weight certificates or + score-aware bounds) to avoid vacuity. - Replace untrusted extraction with a verified parser for model weight slices. - Add a formal bridge from certificates to circuit semantics and (eventually) to end-to-end transformer claims. From fed4e8bfa964c1cccad68fc3c972614164888441 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 4 Jan 2026 05:37:33 +0100 Subject: [PATCH 107/244] Add interval layernorm bounds and tail-recursive sums --- Nfp/Sound/Bounds/LayerNorm.lean | 313 ++++++++++++++++++++++++++++++ Nfp/Sound/Bounds/MatrixNorm.lean | 88 +++++++-- Nfp/Sound/Bounds/Transformer.lean | 35 ++-- 3 files changed, 401 insertions(+), 35 deletions(-) diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 607fe04..139a985 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -3,9 +3,11 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Algebra.Order.BigOperators.Group.Finset import Mathlib.Algebra.Order.Field.Basic +import Mathlib.Algebra.Order.Ring.Basic import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.Nat.Sqrt import Mathlib.Data.Real.Sqrt +import Mathlib.Data.Rat.BigOperators import Mathlib.Data.Rat.Cast.Order /-! @@ -35,6 +37,20 @@ theorem mean_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : mean x = (∑ i, x i) / n := by simp [mean, h] +/-- Mean is monotone under pointwise order (rational inputs). 
-/ +theorem mean_le_mean {n : Nat} (x y : Fin n → Rat) (hne : n ≠ 0) + (hxy : ∀ i, x i ≤ y i) : mean x ≤ mean y := by + classical + have hsum : (∑ i, x i) ≤ ∑ i, y i := by + refine Finset.sum_le_sum ?_ + intro i _ + exact hxy i + have hden : 0 ≤ (n : Rat) := by + exact_mod_cast (Nat.zero_le n) + have hdiv : (∑ i, x i) / n ≤ (∑ i, y i) / n := + div_le_div_of_nonneg_right hsum hden + simpa [mean, hne] using hdiv + /-- Variance of a finite vector (defaults to `0` when `n = 0`). -/ def variance {n : Nat} (x : Fin n → Rat) : Rat := if n = 0 then @@ -93,6 +109,29 @@ theorem mean_abs_le_bound {n : Nat} (x : Fin n → Rat) (bound : Rat) simpa [abs_div, abs_of_nonneg (le_of_lt hpos)] using hdiv simpa [mean_def x hne] using habs_mean +/-! Interval helpers. -/ + +/-- Absolute value bound from endpoint bounds. -/ +theorem abs_le_max_of_bounds {α : Type _} [Ring α] [LinearOrder α] [IsOrderedRing α] + {a b z : α} + (hlo : a ≤ z) (hhi : z ≤ b) : + |z| ≤ max |a| |b| := by + have hleft : -max |a| |b| ≤ z := by + have hneg : -max |a| |b| ≤ a := by + have hneg' : -max |a| |b| ≤ -|a| := by + exact neg_le_neg (le_max_left _ _) + have hneg'' : -|a| ≤ a := by + have h : -a ≤ |a| := neg_le_abs a + simpa using (neg_le_neg h) + exact le_trans hneg' hneg'' + exact le_trans hneg hlo + have hright : z ≤ max |a| |b| := by + have hb : b ≤ |b| := by + exact le_abs_self b + have hb' : b ≤ max |a| |b| := le_trans hb (le_max_right _ _) + exact le_trans hhi hb' + exact (abs_le.mpr ⟨hleft, hright⟩) + /-! Real-valued mean and variance. -/ /-- Mean of a real vector (defaults to `0` when `n = 0`). -/ @@ -107,6 +146,27 @@ theorem meanReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : meanReal x = (∑ i, x i) / n := by simp [meanReal, h] +/-- `meanReal` agrees with `mean` after casting. 
-/ +theorem meanReal_ratCast {n : Nat} (x : Fin n → Rat) : + meanReal (fun i => (x i : Real)) = (mean x : Real) := by + by_cases h : n = 0 + · simp [meanReal, mean, h] + · simp [meanReal, mean, h, Rat.cast_sum, Rat.cast_div] + +/-- Mean is monotone under pointwise order (real inputs). -/ +theorem meanReal_le_meanReal {n : Nat} (x y : Fin n → Real) (hne : n ≠ 0) + (hxy : ∀ i, x i ≤ y i) : meanReal x ≤ meanReal y := by + classical + have hsum : (∑ i, x i) ≤ ∑ i, y i := by + refine Finset.sum_le_sum ?_ + intro i _ + exact hxy i + have hden : 0 ≤ (n : Real) := by + exact_mod_cast (Nat.zero_le n) + have hdiv : (∑ i, x i) / n ≤ (∑ i, y i) / n := + div_le_div_of_nonneg_right hsum hden + simpa [meanReal, hne] using hdiv + /-- Variance of a real vector (defaults to `0` when `n = 0`). -/ noncomputable def varianceReal {n : Nat} (x : Fin n → Real) : Real := if n = 0 then @@ -725,6 +785,138 @@ theorem layerNormBounds_spec {n : Nat} simpa [bounds, layerNormBounds, hne, μ, varEps, invLo, invHi, centered, nb, sb, lo, hi] using And.intro hlo hhi +/-- Interval bounds for LayerNorm outputs from per-coordinate intervals. -/ +def layerNormIntervalBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if n = 0 then + (fun _ => 0, fun _ => 0) + else + let μLo := mean lo + let μHi := mean hi + let centeredBound : Fin n → Rat := fun i => + max |lo i - μHi| |hi i - μLo| + let invStdBound : Rat := (sqrtLower eps)⁻¹ + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound + (fun i => beta i - radius i, fun i => beta i + radius i) + +/-- `layerNormIntervalBounds` soundness for real LayerNorm outputs. 
-/ +theorem layerNormIntervalBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) + (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) : + let bounds := layerNormIntervalBounds eps gamma beta lo hi + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_lo : mean lo ≤ mean x := mean_le_mean lo x hne hlo + have hmean_hi : mean x ≤ mean hi := mean_le_mean x hi hne hhi + let μLo : Rat := mean lo + let μHi : Rat := mean hi + let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Rat := (sqrtLower eps)⁻¹ + let varEps : Rat := variance x + eps + let μ : Real := mean x + let invStd : Real := (Real.sqrt (varEps : Real))⁻¹ + have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by + have h0 : 0 ≤ centeredBound i := by + dsimp [centeredBound] + exact le_trans (abs_nonneg _) (le_max_left _ _) + exact_mod_cast h0 + have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by + have hmean_lo_real : (μLo : Real) ≤ μ := by + have h' : (mean lo : Real) ≤ (mean x : Real) := by + exact_mod_cast hmean_lo + simpa [μLo, μ] using h' + have hmean_hi_real : μ ≤ (μHi : Real) := by + have h' : (mean x : Real) ≤ (mean hi : Real) := by + exact_mod_cast hmean_hi + simpa [μHi, μ] using h' + have hlo' : (lo i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by + have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by + exact sub_le_sub_left hmean_hi_real (lo i : Real) + have h2 : (lo i : Real) - μ ≤ (x i : Real) - μ := by + exact sub_le_sub_right (by exact_mod_cast (hlo i)) μ + exact le_trans h1 h2 + have hhi' : (x i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + have h1 : (x i : Real) - μ ≤ (hi i : Real) - μ := by + exact sub_le_sub_right (by exact_mod_cast (hhi i)) μ + have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + exact 
sub_le_sub_left hmean_lo_real (hi i : Real) + exact le_trans h1 h2 + have hbound := abs_le_max_of_bounds hlo' hhi' + simpa [centeredBound, μLo, μHi, Rat.cast_abs, Rat.cast_sub, Rat.cast_max] using hbound + have hvar_nonneg : 0 ≤ variance x := variance_nonneg x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt (varEps : Real) := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ (varEps : Real) := by + have hle' : eps ≤ varEps := le_add_of_nonneg_left hvar_nonneg + exact_mod_cast hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt (varEps : Real) := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact_mod_cast (sqrtLower_pos (q := eps) heps) + have hinv : invStd ≤ (invStdBound : Real) := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd, invStdBound] using h + have hinv_nonneg : 0 ≤ invStd := by + have hsqrt_nonneg : 0 ≤ Real.sqrt (varEps : Real) := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |(x i : Real) - μ| * invStd ≤ + (centeredBound i : Real) * (invStdBound : Real) := by + have hleft : + |(x i : Real) - μ| * invStd ≤ (centeredBound i : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by + exact 
mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h := add_le_add_left hbounds.1 (beta i : Real) + simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h := add_le_add_left hbounds.2 (beta i : Real) + simpa [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hlow + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hhigh + exact And.intro hlo hhi + /-- Interval bounds for LayerNorm outputs from an absolute input bound. 
-/ def layerNormAbsBounds {n : Nat} (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) : @@ -946,6 +1138,127 @@ theorem layerNormAbsBounds_spec_real {n : Nat} simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh exact And.intro hlo hhi +/-- `layerNormIntervalBounds` soundness for real LayerNorm outputs on real inputs. -/ +theorem layerNormIntervalBounds_spec_real {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : + let bounds := layerNormIntervalBounds eps gamma beta lo hi + ∀ i, + (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ + layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_lo : (mean lo : Real) ≤ meanReal x := by + have h := + meanReal_le_meanReal (x := fun j => (lo j : Real)) (y := x) hne + (fun j => hlo j) + simpa [meanReal_ratCast] using h + have hmean_hi : meanReal x ≤ (mean hi : Real) := by + have h := + meanReal_le_meanReal (x := x) (y := fun j => (hi j : Real)) hne + (fun j => hhi j) + simpa [meanReal_ratCast] using h + let μLo : Rat := mean lo + let μHi : Rat := mean hi + let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Rat := (sqrtLower eps)⁻¹ + let varEps : Real := varianceReal x + (eps : Real) + let μ : Real := meanReal x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by + have h0 : 0 ≤ centeredBound i := by + dsimp [centeredBound] + exact le_trans (abs_nonneg _) (le_max_left _ _) + exact_mod_cast h0 + have hcentered_abs : |x i - μ| ≤ (centeredBound i : Real) := by + have hmean_lo_real : (μLo : Real) ≤ μ := by + simpa [μLo, μ] using hmean_lo + have hmean_hi_real : μ ≤ (μHi : Real) := by + simpa [μHi, μ] using hmean_hi + have hlo' : (lo i : Real) - (μHi : Real) ≤ x i - μ := by + have h1 : (lo i 
: Real) - (μHi : Real) ≤ (lo i : Real) - μ := by + exact sub_le_sub_left hmean_hi_real (lo i : Real) + have h2 : (lo i : Real) - μ ≤ x i - μ := by + exact sub_le_sub_right (hlo i) μ + exact le_trans h1 h2 + have hhi' : x i - μ ≤ (hi i : Real) - (μLo : Real) := by + have h1 : x i - μ ≤ (hi i : Real) - μ := by + exact sub_le_sub_right (hhi i) μ + have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + exact sub_le_sub_left hmean_lo_real (hi i : Real) + exact le_trans h1 h2 + have hbound := abs_le_max_of_bounds hlo' hhi' + simpa [centeredBound, μLo, μHi, Rat.cast_abs, Rat.cast_sub, Rat.cast_max] using hbound + have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ varianceReal x + (eps : Real) := by + exact le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact_mod_cast (sqrtLower_pos (q := eps) heps) + have hinv : invStd ≤ (invStdBound : Real) := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd, invStdBound] using h + have hinv_nonneg : 0 ≤ invStd := by + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |x i - μ| * invStd ≤ + (centeredBound i : Real) * (invStdBound : Real) := by + have hleft : |x i - μ| * invStd ≤ (centeredBound i : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by + exact 
mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |x i - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * (x i - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |x i - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h := add_le_add_left hbounds.1 (beta i : Real) + simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h := add_le_add_left hbounds.2 (beta i : Real) + simpa [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hlow + have hhi : layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormIntervalBounds, hne, 
radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hhigh + exact And.intro hlo hhi + end Bounds end Sound diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index b3a6243..9f68ecd 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -1,5 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Algebra.BigOperators.Fin import Mathlib.Algebra.Order.BigOperators.Group.Finset import Mathlib.Algebra.Order.Ring.Abs import Mathlib.Algebra.Order.Ring.Rat @@ -9,6 +10,7 @@ import Mathlib.Data.Rat.Cast.Order import Mathlib.Data.Real.Basic import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Sound.Linear.FinFold /-! Row-sum matrix norms for downstream linear certificates. @@ -25,14 +27,64 @@ namespace Bounds open scoped BigOperators +private theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Rat) : + Linear.sumFin n f = ∑ i, f i := by + classical + have hfold : + Linear.sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa using Linear.sumFin_eq_list_foldl n f + have hmap : + ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = + (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + have hmap' : + ∀ l : List (Fin n), ∀ init : Rat, + (l.map f).foldl (fun acc x : Rat => acc + x) init = + l.foldl (fun acc i => acc + f i) init := by + intro l + induction l with + | nil => + intro init + simp + | cons a l ih => + intro init + simp [ih] + exact hmap' (List.finRange n) 0 + let _ : Std.Commutative (fun a b : Rat => a + b) := + ⟨by intro a b; exact add_comm _ _⟩ + let _ : Std.Associative (fun a b : Rat => a + b) := + ⟨by intro a b c; exact add_assoc _ _ _⟩ + have hfoldr : + ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = + ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by + simpa using + (List.foldl_eq_foldr (f := fun acc x : Rat => acc + x) + (a := 0) (l := (List.finRange n).map f)) + have 
hsum_list : + ((List.finRange n).map f).sum = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + calc + ((List.finRange n).map f).sum + = ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by + rfl + _ = ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 := by + exact hfoldr.symm + _ = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + exact hmap + have hsum_univ : ((List.finRange n).map f).sum = ∑ i, f i := by + exact (Fin.sum_univ_def f).symm + calc + Linear.sumFin n f + = (List.finRange n).foldl (fun acc i => acc + f i) 0 := hfold + _ = ((List.finRange n).map f).sum := hsum_list.symm + _ = ∑ i, f i := hsum_univ + /-- Row-sum of absolute values for a matrix row. -/ def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : Rat := - ∑ j, |W i j| + Linear.sumFin n (fun j => |W i j|) /-- Weighted row-sum using per-coordinate bounds. -/ def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (bound : Fin n → Rat) (i : Fin m) : Rat := - ∑ j, |W i j| * bound j + Linear.sumFin n (fun j => |W i j| * bound j) /-- Maximum row-sum norm (defaults to `0` on empty matrices). -/ def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := @@ -52,17 +104,25 @@ def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) /-- Row-sums are nonnegative. -/ theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : 0 ≤ rowSum W i := by - refine Finset.sum_nonneg ?_ - intro j _ - exact abs_nonneg (W i j) + have hsum : rowSum W i = ∑ j, |W i j| := by + simp [rowSum, sumFin_eq_sum_univ] + have hnonneg : 0 ≤ ∑ j, |W i j| := by + refine Finset.sum_nonneg ?_ + intro j _ + exact abs_nonneg (W i j) + simpa [hsum] using hnonneg /-- Weighted row-sums are nonnegative under nonnegative bounds. 
-/ theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : 0 ≤ rowSumWeighted W bound i := by - refine Finset.sum_nonneg ?_ - intro j _ - exact mul_nonneg (abs_nonneg (W i j)) (hbound j) + have hsum : rowSumWeighted W bound i = ∑ j, |W i j| * bound j := by + simp [rowSumWeighted, sumFin_eq_sum_univ] + have hnonneg : 0 ≤ ∑ j, |W i j| * bound j := by + refine Finset.sum_nonneg ?_ + intro j _ + exact mul_nonneg (abs_nonneg (W i j)) (hbound j) + simpa [hsum] using hnonneg /-- Each row-sum is bounded by the row-sum norm. -/ theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : @@ -132,11 +192,11 @@ theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) /-- Lower interval endpoint for a dot product with per-coordinate bounds. -/ def dotIntervalLower {n : Nat} (v lo hi : Fin n → Rat) : Rat := - ∑ j, if 0 ≤ v j then v j * lo j else v j * hi j + Linear.sumFin n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) /-- Upper interval endpoint for a dot product with per-coordinate bounds. -/ def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Rat) : Rat := - ∑ j, if 0 ≤ v j then v j * hi j else v j * lo j + Linear.sumFin n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) /-- Absolute bound from interval endpoints for a dot product. 
-/ def dotIntervalAbsBound {n : Nat} (v lo hi : Fin n → Rat) : Rat := @@ -156,6 +216,7 @@ theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : dotIntervalLower v lo hi ≤ dotProduct v x := by classical + simp only [dotIntervalLower, sumFin_eq_sum_univ, dotProduct] refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j @@ -171,6 +232,7 @@ theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : dotProduct v x ≤ dotIntervalUpper v lo hi := by classical + simp only [dotIntervalUpper, sumFin_eq_sum_univ, dotProduct] refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j @@ -253,7 +315,7 @@ theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) have hcast : (dotIntervalLower v lo hi : Real) = ∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real) := by - conv_lhs => simp [dotIntervalLower] + conv_lhs => simp [dotIntervalLower, sumFin_eq_sum_univ] refine Finset.sum_congr rfl ?_ intro j _ by_cases hv : 0 ≤ v j @@ -283,7 +345,7 @@ theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) have hcast : (dotIntervalUpper v lo hi : Real) = ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by - conv_lhs => simp [dotIntervalUpper] + conv_lhs => simp [dotIntervalUpper, sumFin_eq_sum_univ] refine Finset.sum_congr rfl ?_ intro j _ by_cases hv : 0 ≤ v j @@ -443,7 +505,7 @@ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (s := (Finset.univ : Finset (Fin n))) (f := fun j => |W i j|) (a := inputBound)) - simpa [rowSum] using hsum.symm + simpa [rowSum, sumFin_eq_sum_univ] using hsum.symm have hmul : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) exact hmul diff --git a/Nfp/Sound/Bounds/Transformer.lean 
b/Nfp/Sound/Bounds/Transformer.lean index f528152..6a57aeb 100644 --- a/Nfp/Sound/Bounds/Transformer.lean +++ b/Nfp/Sound/Bounds/Transformer.lean @@ -7,7 +7,7 @@ import Mathlib.Data.Real.Basic import Nfp.Model.Gpt2 import Nfp.Sound.Bounds.Attention import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Linear.FinFold /-! Interval bounds for transformer stacks and final LayerNorm outputs. @@ -105,7 +105,7 @@ noncomputable def transformerStackReal (x : Fin seq → Fin dModel → Real) : Fin seq → Fin dModel → Real := let step := fun x layerIdx => transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x - (List.finRange numLayers).foldl step x + Linear.foldlFin numLayers step x /-- Interval bounds for a transformer stack (folded left over layers). -/ def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} @@ -118,7 +118,7 @@ def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} (layers layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2 - (List.finRange numLayers).foldl step (lo, hi) + Linear.foldlFin numLayers step (lo, hi) private theorem transformerStackBounds_spec_list {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] @@ -177,7 +177,8 @@ theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers (bounds.1 i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ transformerStackReal eps layers heads scores x q i ≤ (bounds.2 i : Real) := by classical - simpa [transformerStackBounds, transformerStackReal] using + simpa [transformerStackBounds, transformerStackReal, Linear.foldlFin_eq_foldl, + Fin.foldl_eq_foldl_finRange] using transformerStackBounds_spec_list eps layers heads scores hne heps (List.finRange numLayers) lo hi x hlo hhi @@ -198,8 +199,7 @@ def 
transformerStackFinalBounds {dModel dHead numHeads hidden numLayers : Nat} (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := let stack := transformerStackBounds eps layers heads lo hi - let absBound := intervalAbsBound stack.1 stack.2 - layerNormAbsBounds eps finalLn.gamma finalLn.beta absBound + layerNormIntervalBounds eps finalLn.gamma finalLn.beta stack.1 stack.2 /-- `transformerStackFinalBounds` soundness for real outputs. -/ theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} @@ -217,25 +217,16 @@ theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLa classical intro bounds q i let stack := transformerStackBounds eps layers heads lo hi - let absBound := intervalAbsBound stack.1 stack.2 have hstack := transformerStackBounds_spec eps layers heads scores lo hi x hne heps hlo hhi q - have habs : - ∀ j, |transformerStackReal eps layers heads scores x q j| ≤ (absBound : Real) := by - intro j - have hlo' : ∀ k, (stack.1 k : Real) ≤ transformerStackReal eps layers heads scores x q k := - fun k => (hstack k).1 - have hhi' : ∀ k, transformerStackReal eps layers heads scores x q k ≤ (stack.2 k : Real) := - fun k => (hstack k).2 - have hbound := - abs_le_intervalAbsBound_real (lo := stack.1) (hi := stack.2) - (x := fun k => transformerStackReal eps layers heads scores x q k) hlo' hhi' j - simpa [absBound] using hbound + have hlo' : ∀ k, (stack.1 k : Real) ≤ transformerStackReal eps layers heads scores x q k := + fun k => (hstack k).1 + have hhi' : ∀ k, transformerStackReal eps layers heads scores x q k ≤ (stack.2 k : Real) := + fun k => (hstack k).2 have hln := - layerNormAbsBounds_spec_real eps finalLn.gamma finalLn.beta absBound - (fun j => transformerStackReal eps layers heads scores x q j) hne heps habs - simpa [bounds, transformerStackFinalBounds, absBound, stack, transformerStackFinalReal] using - hln i + 
layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta stack.1 stack.2 + (fun j => transformerStackReal eps layers heads scores x q j) hne heps hlo' hhi' + simpa [bounds, transformerStackFinalBounds, stack, transformerStackFinalReal] using hln i /-- Residual interval bounds for a GPT-2 stack from exact embeddings. -/ def gpt2ResidualIntervalBounds From 739b3428b3593099ed60a6d9d5076b8e6ab0e73d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 7 Jan 2026 02:07:15 +0100 Subject: [PATCH 108/244] Refactor IO/Sound modules and streamline dyadic-real lemmas --- AGENTS.md | 80 +- Nfp.lean | 22 - Nfp/Circuit/Cert/DownstreamLinear.lean | 8 +- Nfp/Circuit/Cert/LogitDiff.lean | 22 +- Nfp/Circuit/Cert/ResidualBound.lean | 4 +- Nfp/Circuit/Cert/ResidualInterval.lean | 6 +- Nfp/Circuit/Cert/SoftmaxMargin.lean | 14 +- Nfp/Circuit/Cert/ValueRange.lean | 12 +- Nfp/Core/Basic.lean | 581 +++++++++ Nfp/IO.lean | 646 ++-------- Nfp/IO/Bench/InductionCore.lean | 229 ++++ Nfp/IO/Bench/InductionCounts.lean | 72 ++ Nfp/IO/Bench/Rational.lean | 362 ++++++ Nfp/IO/Checks.lean | 44 + Nfp/IO/Derive.lean | 136 ++ Nfp/IO/HeadScore.lean | 56 + Nfp/IO/InductionHead.lean | 821 ++++++++++++ Nfp/IO/Loaders.lean | 70 + Nfp/IO/NfptPure.lean | 164 ++- Nfp/IO/Pure.lean | 1108 +--------------- Nfp/IO/Pure/Basic.lean | 75 ++ Nfp/IO/Pure/Downstream.lean | 202 +++ Nfp/IO/Pure/InductionHead.lean | 25 + Nfp/IO/Pure/InductionHead/Bytes.lean | 786 +++++++++++ Nfp/IO/Pure/Residual.lean | 136 ++ Nfp/IO/Pure/SoftmaxMargin.lean | 8 + Nfp/IO/Pure/SoftmaxMargin/Cert.lean | 79 ++ Nfp/IO/Pure/SoftmaxMargin/Raw.lean | 80 ++ Nfp/IO/Pure/SoftmaxMargin/Shared.lean | 138 ++ Nfp/IO/Pure/ValueRange.lean | 8 + Nfp/IO/Pure/ValueRange/Cert.lean | 63 + Nfp/IO/Pure/ValueRange/Raw.lean | 62 + Nfp/IO/Pure/ValueRange/Shared.lean | 103 ++ Nfp/IO/Timing.lean | 204 +++ Nfp/IO/Util.lean | 25 + Nfp/Mixer/Operations.lean | 43 +- Nfp/Model/Gpt2.lean | 70 +- Nfp/Model/InductionHead.lean | 34 +- Nfp/Prob/Operations.lean | 17 +- 
Nfp/Sound.lean | 1 + Nfp/Sound/Bounds/Attention.lean | 294 +++-- Nfp/Sound/Bounds/Gelu.lean | 61 +- Nfp/Sound/Bounds/LayerNorm.lean | 999 +++++++------- Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean | 260 ++++ Nfp/Sound/Bounds/MatrixNorm.lean | 447 +------ Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 436 +++++++ Nfp/Sound/Bounds/Mlp.lean | 195 ++- Nfp/Sound/Bounds/Transformer.lean | 376 +++++- Nfp/Sound/Bounds/Transformer/Embedding.lean | 132 ++ Nfp/Sound/Bounds/UnnormRat.lean | 61 + Nfp/Sound/Gpt2/HeadInputs.lean | 4 +- Nfp/Sound/Induction.lean | 1219 +----------------- Nfp/Sound/Induction/Core.lean | 1191 +++++++++++++++++ Nfp/Sound/Induction/CoreDefs.lean | 149 +++ Nfp/Sound/Induction/HeadBounds.lean | 532 ++++++++ Nfp/Sound/Induction/HeadOutput.lean | 376 ++++++ Nfp/Sound/Induction/LogitDiff.lean | 16 +- Nfp/Sound/Induction/OneHot.lean | 56 +- Nfp/Sound/Linear/FinFold.lean | 92 +- TheoremAxioms.lean | 28 + lakefile.toml | 16 + scripts/build_residual_interval_cert.py | 33 +- scripts/scan_gpt2_induction_sound.py | 190 ++- 63 files changed, 9380 insertions(+), 4399 deletions(-) create mode 100644 Nfp/IO/Bench/InductionCore.lean create mode 100644 Nfp/IO/Bench/InductionCounts.lean create mode 100644 Nfp/IO/Bench/Rational.lean create mode 100644 Nfp/IO/Checks.lean create mode 100644 Nfp/IO/Derive.lean create mode 100644 Nfp/IO/HeadScore.lean create mode 100644 Nfp/IO/InductionHead.lean create mode 100644 Nfp/IO/Loaders.lean create mode 100644 Nfp/IO/Pure/Basic.lean create mode 100644 Nfp/IO/Pure/Downstream.lean create mode 100644 Nfp/IO/Pure/InductionHead.lean create mode 100644 Nfp/IO/Pure/InductionHead/Bytes.lean create mode 100644 Nfp/IO/Pure/Residual.lean create mode 100644 Nfp/IO/Pure/SoftmaxMargin.lean create mode 100644 Nfp/IO/Pure/SoftmaxMargin/Cert.lean create mode 100644 Nfp/IO/Pure/SoftmaxMargin/Raw.lean create mode 100644 Nfp/IO/Pure/SoftmaxMargin/Shared.lean create mode 100644 Nfp/IO/Pure/ValueRange.lean create mode 100644 Nfp/IO/Pure/ValueRange/Cert.lean 
create mode 100644 Nfp/IO/Pure/ValueRange/Raw.lean create mode 100644 Nfp/IO/Pure/ValueRange/Shared.lean create mode 100644 Nfp/IO/Timing.lean create mode 100644 Nfp/IO/Util.lean create mode 100644 Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean create mode 100644 Nfp/Sound/Bounds/MatrixNorm/Interval.lean create mode 100644 Nfp/Sound/Bounds/Transformer/Embedding.lean create mode 100644 Nfp/Sound/Bounds/UnnormRat.lean create mode 100644 Nfp/Sound/Induction/Core.lean create mode 100644 Nfp/Sound/Induction/CoreDefs.lean create mode 100644 Nfp/Sound/Induction/HeadBounds.lean create mode 100644 Nfp/Sound/Induction/HeadOutput.lean create mode 100644 TheoremAxioms.lean diff --git a/AGENTS.md b/AGENTS.md index fea08bc..f68907e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,10 +17,10 @@ but keep the core invariants and the “no fake proofs” ethos. ## 0. Quick Start (What to run) ### Build (warnings are errors) -- `lake build -q --wfail` +- `lake build --wfail` ### Build the CLI -- `lake build nfp -q --wfail` +- `lake build nfp --wfail` ### Run the CLI (preferred integration path) One of these typically works (depending on your Lake setup): @@ -33,8 +33,8 @@ If you add or change CLI behavior, validate at least: - `nfp --version` (if supported) Before you finish any change: -- `lake build -q --wfail` -- `lake build nfp -q --wfail` +- `lake build --wfail` +- `lake build nfp --wfail` Note: `models/` is gitignored, so `rg` will skip it unless you pass `--no-ignore` or `-uuu` (or equivalent) when searching. @@ -330,23 +330,84 @@ but you **must** update this list in the same commit. ### 5.6 CLI surface - `Nfp/IO/Pure.lean` - - Pure parsing helpers for CLI inputs. + - Aggregator for pure parsing helpers. +- `Nfp/IO/Pure/Basic.lean` + - Shared parsing helpers (`Nat`/`Int`/`Rat`, token cleanup). +- `Nfp/IO/Pure/InductionHead.lean` + - Induction-head input payload parsing from text/bytes. +- `Nfp/IO/Pure/InductionHead/Bytes.lean` + - Byte-level parser for induction-head input payloads. 
+- `Nfp/IO/Pure/SoftmaxMargin.lean` + - Aggregator for softmax-margin parsing helpers. +- `Nfp/IO/Pure/SoftmaxMargin/Shared.lean` + - Shared parsing helpers for softmax-margin payloads. +- `Nfp/IO/Pure/SoftmaxMargin/Cert.lean` + - Softmax-margin certificate parser. +- `Nfp/IO/Pure/SoftmaxMargin/Raw.lean` + - Softmax-margin raw-input parser. +- `Nfp/IO/Pure/ValueRange.lean` + - Aggregator for value-range parsing helpers. +- `Nfp/IO/Pure/ValueRange/Shared.lean` + - Shared parsing helpers for value-range payloads. +- `Nfp/IO/Pure/ValueRange/Cert.lean` + - Value-range certificate parser. +- `Nfp/IO/Pure/ValueRange/Raw.lean` + - Value-range raw-input parser. +- `Nfp/IO/Pure/Downstream.lean` + - Downstream linear and matrix payload parsers. +- `Nfp/IO/Pure/Residual.lean` + - Residual-bound and residual-interval payload parsers. - `Nfp/IO/NfptPure.lean` - Pure parsing helpers for `NFP_BINARY_V1` model slices. +- `Nfp/IO/HeadScore.lean` + - Pure task-based cache builder for head score dot-abs bounds. +- `Nfp/IO/Loaders.lean` + - IO loaders for certificates and raw inputs. +- `Nfp/IO/Checks.lean` + - IO checks for certificate validity. +- `Nfp/IO/Derive.lean` + - IO derivations building certificates from model binaries. +- `Nfp/IO/Timing.lean` + - IO timing helpers with microsecond reporting and phase wrappers. +- `Nfp/IO/Util.lean` + - Small CLI parsing utilities shared across IO entrypoints. +- `Nfp/IO/InductionHead.lean` + - Induction-head IO pipeline with timing instrumentation. +- `Nfp/IO/Bench/Rational.lean` + - Microbenchmarks for rational arithmetic and caching. - `Nfp/IO.lean` - IO-only wrappers for loading inputs and running checks. - `Nfp/Cli.lean` - CLI commands and `main` implementation. - `Main.lean` - Thin entrypoint delegating to `Nfp.Cli.main`. + - Benchmark entrypoint for rational microbenchmarks. - `Nfp.lean` - - Top-level reexports and axioms dashboard (`#print axioms`). + - Top-level reexports. 
+- `TheoremAxioms.lean` + - Axiom dashboard for `theorem-axioms` build target (`#print axioms`). ### 5.7 Sound certification - `Nfp/Sound/Induction.lean` - - Sound builders for induction certificates from exact inputs. + - Aggregator for induction soundness modules. +- `Nfp/Sound/Induction/Core.lean` + - Sound builders and core proofs for induction certificates from exact inputs. +- `Nfp/Sound/Induction/CoreDefs.lean` + - Core definitions and soundness predicates for induction certificates. +- `Nfp/Sound/Induction/HeadOutput.lean` + - Head-output interval certificates built from induction head inputs. +- `Nfp/Sound/Induction/HeadBounds.lean` + - Helper bounds used to stage head-induction certificate construction. - `Nfp/Sound/Bounds/MatrixNorm.lean` - Row-sum matrix norms and downstream linear certificate builders. +- `Nfp/Sound/Bounds/MatrixNorm/Interval.lean` + - Dot-product and matrix-vector interval bounds (dyadic and real). +- `Nfp/Sound/Bounds/LayerNorm.lean` + - LayerNorm interval bounds and end-to-end soundness lemmas. +- `Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean` + - Mean/variance helpers for LayerNorm bounds. +- `Nfp/Sound/Bounds/UnnormRat.lean` + - Unnormalized rational helpers for deferred normalization in bounds kernels. - `Nfp/Sound/Bounds/Gelu.lean` - Tanh-GELU bounds for interval propagation through MLPs. - `Nfp/Sound/Bounds/Mlp.lean` @@ -355,6 +416,8 @@ but you **must** update this list in the same commit. - Interval bounds for multi-head attention and transformer layers. - `Nfp/Sound/Bounds/Transformer.lean` - Interval bounds for transformer stacks and final LayerNorm outputs. +- `Nfp/Sound/Bounds/Transformer/Embedding.lean` + - Embedding interval bounds and position-restricted bounds. - `Nfp/Sound/Linear/FinFold.lean` - Tail-recursive folds and sums for sound linear computations. - `Nfp/Sound/Gpt2/HeadInputs.lean` @@ -385,7 +448,8 @@ This repo treats “axioms creep” as a serious regression. - Do not add axioms. 
- Keep an eye on classical assumptions; they may be unavoidable, but should be explicit. -- Use `Nfp.lean` as the “trust dashboard” for `#print axioms` / dependency visibility. +- Use `TheoremAxioms.lean` / `lake build theorem-axioms --wfail` as the trust dashboard for + `#print axioms` / dependency visibility. --- diff --git a/Nfp.lean b/Nfp.lean index ee1adc1..c99e563 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -11,25 +11,3 @@ import Nfp.Sound /-! Top-level reexports and trust dashboard for the NFP rewrite. -/ - -/-! -Axioms used by key definitions/lemmas. -These `#print axioms` lines help ensure we only depend on a small set of axioms -(ideally a subset of: `propext`, `Classical.choice`, `Quot.sound`). --/ - -#print axioms Nfp.ProbVec.sum_mass -#print axioms Nfp.ProbVec.pure -#print axioms Nfp.ProbVec.mix -#print axioms Nfp.Mixer.push -#print axioms Nfp.Mixer.comp -#print axioms Nfp.Mixer.id -#print axioms Nfp.Dag.parents -#print axioms Nfp.LocalSystem.toMixer -#print axioms Nfp.LocalSystem.eval -#print axioms Nfp.LocalSystem.eval_eq -#print axioms Nfp.Circuit.eval -#print axioms Nfp.Circuit.evalInput -#print axioms Nfp.Circuit.Interface.eval -#print axioms Nfp.Circuit.checkEquiv -#print axioms Nfp.Circuit.checkEquivOnInterface diff --git a/Nfp/Circuit/Cert/DownstreamLinear.lean b/Nfp/Circuit/Cert/DownstreamLinear.lean index 69bb959..41a61a9 100644 --- a/Nfp/Circuit/Cert/DownstreamLinear.lean +++ b/Nfp/Circuit/Cert/DownstreamLinear.lean @@ -1,6 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Nfp.Circuit.Cert /-! @@ -18,11 +18,11 @@ namespace Circuit /-- Certificate payload for downstream linear error bounds. -/ structure DownstreamLinearCert where /-- Upper bound on the downstream logit-diff error. -/ - error : Rat + error : Dyadic /-- Operator gain bound used to justify the error. -/ - gain : Rat + gain : Dyadic /-- Input magnitude bound used to justify the error. 
-/ - inputBound : Rat + inputBound : Dyadic /-- Arithmetic properties enforced by `checkDownstreamLinearCert`. -/ structure DownstreamLinearBounds (c : DownstreamLinearCert) : Prop where diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index 99da2da..f51d463 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -1,6 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Mathlib.Data.Finset.Image import Nfp.Circuit.Layers.Induction @@ -17,11 +17,11 @@ variable {seq : Nat} /-- Compute a lower bound on the logit-diff contribution over active queries. -/ def logitDiffLowerBound (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (eps lo hi : Rat) (vals : Fin seq → Rat) : Option Rat := by + (eps lo hi : Dyadic) (vals : Fin seq → Dyadic) : Option Dyadic := by classical if h : active.Nonempty then let gap := eps * (hi - lo) - let f : Fin seq → Rat := fun q => vals (prev q) - gap + let f : Fin seq → Dyadic := fun q => vals (prev q) - gap let img := active.image f have himg : img.Nonempty := h.image f exact some (Finset.min' img himg) @@ -31,11 +31,11 @@ def logitDiffLowerBound (active : Finset (Fin seq)) /-- Compute a lower bound on the logit-diff contribution with per-query eps. 
-/ def logitDiffLowerBoundAt (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (epsAt : Fin seq → Rat) (lo hi : Rat) (vals : Fin seq → Rat) : Option Rat := by + (epsAt : Fin seq → Dyadic) (lo hi : Dyadic) (vals : Fin seq → Dyadic) : Option Dyadic := by classical if h : active.Nonempty then - let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) - let f : Fin seq → Rat := fun q => vals (prev q) - gap q + let gap : Fin seq → Dyadic := fun q => epsAt q * (hi - lo) + let f : Fin seq → Dyadic := fun q => vals (prev q) - gap q let img := active.image f have himg : img.Nonempty := h.image f exact some (Finset.min' img himg) @@ -45,7 +45,7 @@ def logitDiffLowerBoundAt (active : Finset (Fin seq)) /-- The computed lower bound is below every active `prev` value minus the tolerance gap. -/ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (eps lo hi : Rat) (vals : Fin seq → Rat) + (eps lo hi : Dyadic) (vals : Fin seq → Dyadic) (q : Fin seq) (hq : q ∈ active) : ∀ lb, logitDiffLowerBound active prev eps lo hi vals = some lb → lb ≤ vals (prev q) - eps * (hi - lo) := by @@ -57,7 +57,7 @@ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) (hnonempty.image (fun q => vals (prev q) - eps * (hi - lo))) = lb := by simpa [logitDiffLowerBound, hnonempty] using hbound let gap := eps * (hi - lo) - let f : Fin seq → Rat := fun q => vals (prev q) - gap + let f : Fin seq → Dyadic := fun q => vals (prev q) - gap have hmem : f q ∈ (active.image f) := by refine Finset.mem_image.2 ?_ exact ⟨q, hq, rfl⟩ @@ -70,7 +70,7 @@ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) /-- The per-query lower bound is below every active `prev` value minus the local gap. 
-/ theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (epsAt : Fin seq → Rat) (lo hi : Rat) (vals : Fin seq → Rat) + (epsAt : Fin seq → Dyadic) (lo hi : Dyadic) (vals : Fin seq → Dyadic) (q : Fin seq) (hq : q ∈ active) : ∀ lb, logitDiffLowerBoundAt active prev epsAt lo hi vals = some lb → lb ≤ vals (prev q) - epsAt q * (hi - lo) := by @@ -81,8 +81,8 @@ theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) (active.image (fun q => vals (prev q) - epsAt q * (hi - lo))).min' (hnonempty.image (fun q => vals (prev q) - epsAt q * (hi - lo))) = lb := by simpa [logitDiffLowerBoundAt, hnonempty] using hbound - let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) - let f : Fin seq → Rat := fun q => vals (prev q) - gap q + let gap : Fin seq → Dyadic := fun q => epsAt q * (hi - lo) + let f : Fin seq → Dyadic := fun q => vals (prev q) - gap q have hmem : f q ∈ (active.image f) := by refine Finset.mem_image.2 ?_ exact ⟨q, hq, rfl⟩ diff --git a/Nfp/Circuit/Cert/ResidualBound.lean b/Nfp/Circuit/Cert/ResidualBound.lean index 4cc7a3a..2a85061 100644 --- a/Nfp/Circuit/Cert/ResidualBound.lean +++ b/Nfp/Circuit/Cert/ResidualBound.lean @@ -1,6 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Nfp.Circuit.Cert /-! @@ -16,7 +16,7 @@ namespace Circuit /-- Certificate payload for per-coordinate residual absolute bounds. -/ structure ResidualBoundCert (n : Nat) where /-- Absolute bound per coordinate. -/ - bound : Fin n → Rat + bound : Fin n → Dyadic /-- Properties enforced by `checkResidualBoundCert`. 
-/ structure ResidualBoundBounds {n : Nat} (c : ResidualBoundCert n) : Prop where diff --git a/Nfp/Circuit/Cert/ResidualInterval.lean b/Nfp/Circuit/Cert/ResidualInterval.lean index dca299b..9360359 100644 --- a/Nfp/Circuit/Cert/ResidualInterval.lean +++ b/Nfp/Circuit/Cert/ResidualInterval.lean @@ -1,6 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Nfp.Circuit.Cert /-! @@ -16,9 +16,9 @@ namespace Circuit /-- Certificate payload for per-coordinate residual intervals. -/ structure ResidualIntervalCert (n : Nat) where /-- Lower bound per coordinate. -/ - lo : Fin n → Rat + lo : Fin n → Dyadic /-- Upper bound per coordinate. -/ - hi : Fin n → Rat + hi : Fin n → Dyadic /-- Properties enforced by `checkResidualIntervalCert`. -/ structure ResidualIntervalBounds {n : Nat} (c : ResidualIntervalCert n) : Prop where diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean b/Nfp/Circuit/Cert/SoftmaxMargin.lean index f8157a9..0d0a1a2 100644 --- a/Nfp/Circuit/Cert/SoftmaxMargin.lean +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Nfp.Circuit.Cert import Nfp.Circuit.Layers.Induction @@ -17,20 +17,20 @@ open scoped BigOperators variable {seq : Nat} -/-- Certificate payload for softmax-margin bounds (Rat-valued). -/ +/-- Certificate payload for softmax-margin bounds (Dyadic-valued). -/ structure SoftmaxMarginCert (seq : Nat) where /-- Weight tolerance. -/ - eps : Rat + eps : Dyadic /-- Score margin used to justify weight bounds. -/ - margin : Rat + margin : Dyadic /-- Active queries for which bounds are checked. -/ active : Finset (Fin seq) /-- `prev` selector for induction-style attention. -/ prev : Fin seq → Fin seq /-- Score matrix entries. 
-/ - scores : Fin seq → Fin seq → Rat + scores : Fin seq → Fin seq → Dyadic /-- Attention weight entries. -/ - weights : Fin seq → Fin seq → Rat + weights : Fin seq → Fin seq → Dyadic /-- Boolean checker for softmax-margin certificates. -/ def checkSoftmaxMarginCert [NeZero seq] (c : SoftmaxMarginCert seq) : Bool := @@ -55,7 +55,7 @@ def checkSoftmaxMarginCert [NeZero seq] (c : SoftmaxMarginCert seq) : Bool := /-- `checkSoftmaxMarginCert` is sound for `SoftmaxMarginBoundsOn`. -/ theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : checkSoftmaxMarginCert c = true → - Layers.SoftmaxMarginBoundsOn (Val := Rat) c.eps c.margin (fun q => q ∈ c.active) + Layers.SoftmaxMarginBoundsOn (Val := Dyadic) c.eps c.margin (fun q => q ∈ c.active) c.prev c.scores c.weights := by classical intro hcheck diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean index 33d03f8..7e5f5df 100644 --- a/Nfp/Circuit/Cert/ValueRange.lean +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Nfp.Circuit.Cert import Nfp.Circuit.Layers.Induction @@ -24,14 +24,14 @@ structure DirectionSpec where /-- Negative token id for the logit-diff direction. -/ negative : Nat -/-- Certificate payload for value-range bounds (Rat-valued). -/ +/-- Certificate payload for value-range bounds (Dyadic-valued). -/ structure ValueRangeCert (seq : Nat) where /-- Lower bound for values. -/ - lo : Rat + lo : Dyadic /-- Upper bound for values. -/ - hi : Rat + hi : Dyadic /-- Value entries. -/ - vals : Fin seq → Rat + vals : Fin seq → Dyadic /-- Optional logit-diff direction metadata (ignored by the checker). -/ direction : Option DirectionSpec @@ -44,7 +44,7 @@ def checkValueRangeCert [NeZero seq] (c : ValueRangeCert seq) : Bool := /-- `checkValueRangeCert` is sound for `ValueRangeBounds`. 
-/ theorem checkValueRangeCert_sound [NeZero seq] (c : ValueRangeCert seq) : checkValueRangeCert c = true → - Layers.ValueRangeBounds (Val := Rat) c.lo c.hi c.vals := by + Layers.ValueRangeBounds (Val := Dyadic) c.lo c.hi c.vals := by classical intro hcheck have hcheck' : diff --git a/Nfp/Core/Basic.lean b/Nfp/Core/Basic.lean index 13dffe2..7e3fb48 100644 --- a/Nfp/Core/Basic.lean +++ b/Nfp/Core/Basic.lean @@ -1,6 +1,12 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Data.NNReal.Basic +import Mathlib.Algebra.Order.Floor.Defs +import Mathlib.Data.Rat.Cast.Lemmas +import Mathlib.Data.Rat.Cast.Order +import Init.Data.Dyadic +import Init.Data.Dyadic.Inv +import Init.Data.Dyadic.Round /-! Basic shared definitions for the NFP rewrite. @@ -11,4 +17,579 @@ namespace Nfp /-- Nonnegative mass used for probabilities and weights. -/ abbrev Mass := NNReal +instance : ToString Dyadic := + ⟨fun x => toString x.toRat⟩ + +/-- Default dyadic precision (binary digits after the point). -/ +def defaultDyadicPrec : Int := 48 + +/-- One ulp at the given dyadic precision. -/ +def dyadicUlp (prec : Int := defaultDyadicPrec) : Dyadic := + Dyadic.ofIntWithPrec 1 prec + +/-- Round a rational down to dyadic precision. -/ +def dyadicOfRatDown (q : Rat) (prec : Int := defaultDyadicPrec) : Dyadic := + Rat.toDyadic q prec + +/-- Round a rational up to dyadic precision. -/ +def dyadicOfRatUp (q : Rat) (prec : Int := defaultDyadicPrec) : Dyadic := + Rat.toDyadic q prec + dyadicUlp prec + +instance : Coe Dyadic Rat := ⟨Dyadic.toRat⟩ + +/-- Real cast of a dyadic value via `Rat`. 
-/ +def dyadicToReal (x : Dyadic) : Real := + (x.toRat : Real) + +instance : Coe Dyadic Real := ⟨dyadicToReal⟩ + +@[simp] theorem dyadicToReal_zero : dyadicToReal 0 = 0 := by + simp [dyadicToReal] + +@[simp] theorem dyadicToReal_one : dyadicToReal 1 = 1 := by + change ((1 : Dyadic).toRat : Real) = 1 + have h : (1 : Dyadic).toRat = (1 : Rat) := Dyadic.toRat_natCast 1 + simp [h] + +@[simp] theorem dyadicToRat_one : (Dyadic.toRat 1 : Rat) = 1 := by + exact Dyadic.toRat_natCast 1 + +@[simp] theorem dyadicOfInt_zero : Dyadic.ofInt 0 = 0 := by + simp [Dyadic.ofInt, Dyadic.ofIntWithPrec] + +@[simp] theorem dyadicOfInt_toRat (i : Int) : (Dyadic.ofInt i).toRat = i := by + change ((i : Dyadic).toRat = i) + exact Dyadic.toRat_intCast (x := i) + +@[simp] theorem dyadicOfInt_succ (n : Nat) : Dyadic.ofInt (n + 1) = Dyadic.ofInt n + 1 := by + apply (Dyadic.toRat_inj).1 + calc + (Dyadic.ofInt (n + 1)).toRat = (n + 1 : Int) := by + simp + _ = (n : Int) + 1 := by + simp + _ = (Dyadic.ofInt n).toRat + 1 := by + simp + _ = (Dyadic.ofInt n + 1).toRat := by + simp [Dyadic.toRat_add] + +theorem dyadicOfRatDown_le (q : Rat) (prec : Int := defaultDyadicPrec) : + (dyadicOfRatDown q prec : Rat) ≤ q := by + simpa [dyadicOfRatDown] using (Rat.toRat_toDyadic_le (x := q) (prec := prec)) + +theorem dyadicOfRatUp_ge (q : Rat) (prec : Int := defaultDyadicPrec) : + q ≤ (dyadicOfRatUp q prec : Rat) := by + have hlt := Rat.lt_toRat_toDyadic_add (x := q) (prec := prec) + exact le_of_lt (by simpa [dyadicOfRatUp, dyadicUlp] using hlt) + +theorem dyadicOfRatDown_le_real (q : Rat) (prec : Int := defaultDyadicPrec) : + (dyadicOfRatDown q prec : Real) ≤ (q : Real) := by + have h := + (Rat.cast_le (K := Real) (p := (dyadicOfRatDown q prec : Rat)) (q := q)).2 + (dyadicOfRatDown_le q prec) + simpa [dyadicToReal] using h + +theorem real_le_dyadicOfRatUp (q : Rat) (prec : Int := defaultDyadicPrec) : + (q : Real) ≤ (dyadicOfRatUp q prec : Real) := by + have h := + (Rat.cast_le (K := Real) (p := q) (q := 
(dyadicOfRatUp q prec : Rat))).2 + (dyadicOfRatUp_ge q prec) + simpa [dyadicToReal] using h + +theorem dyadicOfRatDown_lt_add_ulp (q : Rat) (prec : Int := defaultDyadicPrec) : + q < (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat) := by + have h := Rat.lt_toRat_toDyadic_add (x := q) (prec := prec) + simpa [dyadicOfRatDown, dyadicUlp, Dyadic.toRat_add] using h + +theorem dyadicOfRatDown_sub_ulp_lt (q : Rat) (prec : Int := defaultDyadicPrec) : + q - (dyadicUlp prec : Rat) < (dyadicOfRatDown q prec : Rat) := by + exact (sub_lt_iff_lt_add).2 (dyadicOfRatDown_lt_add_ulp q prec) + +theorem dyadicOfRatUp_le_add_ulp (q : Rat) (prec : Int := defaultDyadicPrec) : + (dyadicOfRatUp q prec : Rat) ≤ q + (dyadicUlp prec : Rat) := by + have hdown : (dyadicOfRatDown q prec : Rat) ≤ q := dyadicOfRatDown_le q prec + have hsum : (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat) ≤ + q + (dyadicUlp prec : Rat) := by + exact Rat.add_le_add_right.2 hdown + simpa [dyadicOfRatUp, dyadicUlp, dyadicOfRatDown, Dyadic.toRat_add] using hsum + +theorem dyadicOfRatDown_lt_add_ulp_real (q : Rat) (prec : Int := defaultDyadicPrec) : + (q : Real) < (dyadicOfRatDown q prec : Real) + (dyadicUlp prec : Real) := by + have h := + (Rat.cast_lt (K := Real) (p := q) + (q := (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat))).2 + (dyadicOfRatDown_lt_add_ulp q prec) + simpa [dyadicToReal] using h + +theorem dyadicOfRatUp_le_add_ulp_real (q : Rat) (prec : Int := defaultDyadicPrec) : + (dyadicOfRatUp q prec : Real) ≤ (q : Real) + (dyadicUlp prec : Real) := by + have h := + (Rat.cast_le (K := Real) + (p := (dyadicOfRatUp q prec : Rat)) + (q := q + (dyadicUlp prec : Rat))).2 + (dyadicOfRatUp_le_add_ulp q prec) + simpa [dyadicToReal] using h + +theorem dyadicOfRatDown_nonneg {q : Rat} (hq : 0 ≤ q) (prec : Int := defaultDyadicPrec) : + 0 ≤ dyadicOfRatDown q prec := by + apply (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicOfRatDown q prec)).1 + have hrat : + (dyadicOfRatDown q prec : Rat) = + (q * 
2 ^ prec).floor / 2 ^ prec := by + simpa [dyadicOfRatDown] using (Rat.toRat_toDyadic (x := q) (prec := prec)) + have hpow_pos : (0 : Rat) < (2 : Rat) ^ prec := by + exact Rat.zpow_pos (by decide : (0 : Rat) < 2) + have hmul_nonneg : (0 : Rat) ≤ q * 2 ^ prec := by + exact mul_nonneg hq (le_of_lt hpow_pos) + have hfloor_nonneg : 0 ≤ (q * 2 ^ prec).floor := by + exact (Int.floor_nonneg (a := q * 2 ^ prec)).2 hmul_nonneg + have hfloor_nonneg_rat : (0 : Rat) ≤ ((q * 2 ^ prec).floor : Rat) := by + exact_mod_cast hfloor_nonneg + have hdiv_nonneg : + (0 : Rat) ≤ ((q * 2 ^ prec).floor : Rat) / (2 : Rat) ^ prec := by + exact div_nonneg hfloor_nonneg_rat (le_of_lt hpow_pos) + simpa [hrat] using hdiv_nonneg + +private lemma dyadicUlp_rat (prec : Int) : + (dyadicUlp prec : Rat) = (1 : Rat) * 2 ^ (-prec) := by + simp [dyadicUlp, Dyadic.toRat_ofIntWithPrec_eq_mul_two_pow] + +theorem dyadicUlp_nonneg (prec : Int := defaultDyadicPrec) : 0 ≤ dyadicUlp prec := by + apply (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicUlp prec)).1 + have hrat : (dyadicUlp prec : Rat) = (1 : Rat) * 2 ^ (-prec) := dyadicUlp_rat prec + have hpow_pos : (0 : Rat) < (2 : Rat) ^ (-prec) := by + exact Rat.zpow_pos (by decide : (0 : Rat) < 2) + have hnonneg : (0 : Rat) ≤ (1 : Rat) * 2 ^ (-prec) := by + exact mul_nonneg (by decide : (0 : Rat) ≤ 1) (le_of_lt hpow_pos) + simpa [hrat] using hnonneg + +theorem dyadicUlp_pos (prec : Int := defaultDyadicPrec) : 0 < dyadicUlp prec := by + apply (Dyadic.toRat_lt_toRat_iff (x := 0) (y := dyadicUlp prec)).1 + have hrat : (dyadicUlp prec : Rat) = (1 : Rat) * 2 ^ (-prec) := dyadicUlp_rat prec + have hpow_pos : (0 : Rat) < (2 : Rat) ^ (-prec) := by + exact Rat.zpow_pos (by decide : (0 : Rat) < 2) + have hpos : (0 : Rat) < (1 : Rat) * 2 ^ (-prec) := by + exact mul_pos (by decide : (0 : Rat) < 1) hpow_pos + simpa [hrat] using hpos + +theorem dyadicOfRatUp_nonneg {q : Rat} (hq : 0 ≤ q) (prec : Int := defaultDyadicPrec) : + 0 ≤ dyadicOfRatUp q prec := by + have hdown : 0 ≤ 
dyadicOfRatDown q prec := dyadicOfRatDown_nonneg hq prec + have hulp : 0 ≤ dyadicUlp prec := dyadicUlp_nonneg prec + apply (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicOfRatUp q prec)).1 + have hdown_rat : (0 : Rat) ≤ (dyadicOfRatDown q prec : Rat) := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicOfRatDown q prec)).2 hdown + have hulp_rat : (0 : Rat) ≤ (dyadicUlp prec : Rat) := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicUlp prec)).2 hulp + have hsum : (0 : Rat) ≤ (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat) := by + exact Rat.add_nonneg hdown_rat hulp_rat + simpa [dyadicOfRatUp, dyadicUlp, dyadicOfRatDown, Dyadic.toRat_add] using hsum + +theorem dyadicOfRatUp_pos {q : Rat} (hq : 0 < q) (prec : Int := defaultDyadicPrec) : + 0 < dyadicOfRatUp q prec := by + apply (Dyadic.toRat_lt_toRat_iff (x := 0) (y := dyadicOfRatUp q prec)).1 + have hdown_nonneg : (0 : Rat) ≤ (dyadicOfRatDown q prec : Rat) := by + have hdown : 0 ≤ dyadicOfRatDown q prec := dyadicOfRatDown_nonneg hq.le prec + exact (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicOfRatDown q prec)).2 hdown + have hulp_pos : (0 : Rat) < (dyadicUlp prec : Rat) := by + exact (Dyadic.toRat_lt_toRat_iff (x := 0) (y := dyadicUlp prec)).2 (dyadicUlp_pos prec) + have hsum : (0 : Rat) < (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat) := by + nlinarith + simpa [dyadicOfRatUp, dyadicUlp, dyadicOfRatDown, Dyadic.toRat_add] using hsum + +-- TODO: use a tighter upper rounding when needed. + +/-- Dyadic division with downward rounding at the chosen precision. -/ +def dyadicDivDown (x y : Dyadic) (prec : Int := defaultDyadicPrec) : Dyadic := + if y = 0 then + 0 + else + dyadicOfRatDown (x.toRat / y.toRat) prec + +/-- Dyadic division with upward rounding at the chosen precision. 
-/ +def dyadicDivUp (x y : Dyadic) (prec : Int := defaultDyadicPrec) : Dyadic := + if y = 0 then + 0 + else + dyadicOfRatUp (x.toRat / y.toRat) prec + +theorem dyadicDivUp_ge (x y : Dyadic) (hy : y ≠ 0) : + (x.toRat / y.toRat : Rat) ≤ (dyadicDivUp x y : Rat) := by + have hrat : (x.toRat / y.toRat : Rat) ≤ (dyadicOfRatUp (x.toRat / y.toRat) : Rat) := + dyadicOfRatUp_ge (x.toRat / y.toRat) + simpa [dyadicDivUp, hy] using hrat + +theorem dyadicDivUp_ge_real (x y : Dyadic) (hy : y ≠ 0) : + (x.toRat : Real) / (y.toRat : Real) ≤ dyadicToReal (dyadicDivUp x y) := by + have hrat : (x.toRat / y.toRat : Rat) ≤ (dyadicDivUp x y : Rat) := + dyadicDivUp_ge x y hy + have hrat' : ((x.toRat / y.toRat : Rat) : Real) ≤ ((dyadicDivUp x y : Rat) : Real) := + (Rat.cast_le (K := Real) (p := _) (q := _)).2 hrat + simpa [dyadicToReal] using hrat' + +@[simp] theorem dyadicToReal_add (x y : Dyadic) : + dyadicToReal (x + y) = dyadicToReal x + dyadicToReal y := by + simp [dyadicToReal, Dyadic.toRat_add] + +@[simp] theorem dyadicToReal_sub (x y : Dyadic) : + dyadicToReal (x - y) = dyadicToReal x - dyadicToReal y := by + simp [dyadicToReal, Dyadic.toRat_sub] + +@[simp] theorem dyadicToReal_mul (x y : Dyadic) : + dyadicToReal (x * y) = dyadicToReal x * dyadicToReal y := by + simp [dyadicToReal, Dyadic.toRat_mul] + +@[simp] theorem dyadicToReal_neg (x : Dyadic) : + dyadicToReal (-x) = -dyadicToReal x := by + simp [dyadicToReal, Dyadic.toRat_neg] + +@[simp] theorem dyadicToReal_if {p : Prop} [Decidable p] (a b : Dyadic) : + dyadicToReal (if p then a else b) = + if p then dyadicToReal a else dyadicToReal b := by + by_cases hp : p <;> simp [hp] + +theorem dyadicToReal_le_iff {x y : Dyadic} : + dyadicToReal x ≤ dyadicToReal y ↔ x ≤ y := by + constructor + · intro h + have h' : x.toRat ≤ y.toRat := by + have h'' : (x.toRat : Real) ≤ (y.toRat : Real) := by + simpa [dyadicToReal] using h + exact (Rat.cast_le (K := Real) (p := x.toRat) (q := y.toRat)).1 h'' + exact (Dyadic.toRat_le_toRat_iff).1 h' + · 
intro h + have h' : x.toRat ≤ y.toRat := (Dyadic.toRat_le_toRat_iff).2 h + have h'' : (x.toRat : Real) ≤ (y.toRat : Real) := + (Rat.cast_le (K := Real) (p := x.toRat) (q := y.toRat)).2 h' + simpa [dyadicToReal] using h'' + +/-- Dyadic order implies real order after casting. -/ +theorem dyadicToReal_le_of_le {x y : Dyadic} (h : x ≤ y) : + dyadicToReal x ≤ dyadicToReal y := + (dyadicToReal_le_iff (x := x) (y := y)).2 h + +theorem dyadicToReal_lt_iff {x y : Dyadic} : + dyadicToReal x < dyadicToReal y ↔ x < y := by + constructor + · intro h + have h' : x.toRat < y.toRat := by + have h'' : (x.toRat : Real) < (y.toRat : Real) := by + simpa [dyadicToReal] using h + exact (Rat.cast_lt (K := Real) (p := x.toRat) (q := y.toRat)).1 h'' + exact (Dyadic.toRat_lt_toRat_iff).1 h' + · intro h + have h' : x.toRat < y.toRat := (Dyadic.toRat_lt_toRat_iff).2 h + have h'' : (x.toRat : Real) < (y.toRat : Real) := + (Rat.cast_lt (K := Real) (p := x.toRat) (q := y.toRat)).2 h' + simpa [dyadicToReal] using h'' + +theorem dyadicToReal_nonneg_iff {x : Dyadic} : + 0 ≤ dyadicToReal x ↔ 0 ≤ x := by + simpa [dyadicToReal] using (dyadicToReal_le_iff (x := 0) (y := x)) + +theorem dyadicToReal_nonneg_of_nonneg {x : Dyadic} (h : 0 ≤ x) : + 0 ≤ dyadicToReal x := + (dyadicToReal_nonneg_iff (x := x)).2 h + +theorem dyadicToReal_nonpos_iff {x : Dyadic} : + dyadicToReal x ≤ 0 ↔ x ≤ 0 := by + simpa [dyadicToReal_zero] using (dyadicToReal_le_iff (x := x) (y := 0)) + +instance : LinearOrder Dyadic where + le := (· ≤ ·) + lt := (· < ·) + le_refl := Dyadic.le_refl + le_trans := by intro a b c hab hbc; exact Dyadic.le_trans hab hbc + le_antisymm := by intro a b hab hba; exact Dyadic.le_antisymm hab hba + le_total := Dyadic.le_total + toDecidableLE := inferInstance + toDecidableEq := inferInstance + toDecidableLT := inferInstance + lt_iff_le_not_ge := by + intro a b + have hlt : a < b ↔ a.toRat < b.toRat := + (Dyadic.toRat_lt_toRat_iff (x := a) (y := b)).symm + have hle : a ≤ b ↔ a.toRat ≤ b.toRat := + 
(Dyadic.toRat_le_toRat_iff (x := a) (y := b)).symm + have hge : b ≤ a ↔ b.toRat ≤ a.toRat := + (Dyadic.toRat_le_toRat_iff (x := b) (y := a)).symm + have hrat : a.toRat < b.toRat ↔ a.toRat ≤ b.toRat ∧ ¬ b.toRat ≤ a.toRat := by + simpa using (Rat.lt_iff_le_not_ge (a := a.toRat) (b := b.toRat)) + calc + a < b ↔ a.toRat < b.toRat := hlt + _ ↔ a.toRat ≤ b.toRat ∧ ¬ b.toRat ≤ a.toRat := hrat + _ ↔ a ≤ b ∧ ¬ b ≤ a := by + constructor + · intro h + refine ⟨(hle.mpr h.1), ?_⟩ + intro hba + exact h.2 (hge.mp hba) + · intro h + refine ⟨(hle.mp h.1), ?_⟩ + intro hba + exact h.2 (hge.mpr hba) + min_def := by intro a b; rfl + max_def := by intro a b; rfl + compare_eq_compareOfLessAndEq := by intro a b; rfl + +instance : AddMonoid Dyadic where + add := (· + ·) + zero := 0 + add_assoc := Dyadic.add_assoc + zero_add := Dyadic.zero_add + add_zero := Dyadic.add_zero + nsmul := nsmulRec + nsmul_zero := by intro x; rfl + nsmul_succ := by intro n x; rfl + +instance : AddCommMonoid Dyadic := + { (inferInstance : AddMonoid Dyadic) with + add_comm := Dyadic.add_comm } + +instance : AddMonoidWithOne Dyadic where + add := (· + ·) + zero := 0 + add_assoc := Dyadic.add_assoc + zero_add := Dyadic.zero_add + add_zero := Dyadic.add_zero + nsmul := nsmulRec + nsmul_zero := by intro x; rfl + nsmul_succ := by intro n x; rfl + one := 1 + natCast := fun n => Dyadic.ofInt n + natCast_zero := by + simp [dyadicOfInt_zero] + natCast_succ := by + intro n + simp [dyadicOfInt_succ n] + +instance : AddCommMonoidWithOne Dyadic where + add := (· + ·) + zero := 0 + add_assoc := Dyadic.add_assoc + zero_add := Dyadic.zero_add + add_zero := Dyadic.add_zero + nsmul := nsmulRec + nsmul_zero := by intro x; rfl + nsmul_succ := by intro n x; rfl + one := 1 + natCast := fun n => Dyadic.ofInt n + natCast_zero := by + simp [dyadicOfInt_zero] + natCast_succ := by + intro n + simp [dyadicOfInt_succ n] + add_comm := Dyadic.add_comm + +instance : AddGroup Dyadic where + add := (· + ·) + zero := 0 + add_assoc := 
Dyadic.add_assoc + zero_add := Dyadic.zero_add + add_zero := Dyadic.add_zero + nsmul := nsmulRec + nsmul_zero := by intro x; rfl + nsmul_succ := by intro n x; rfl + neg := Neg.neg + sub := fun a b => a + -b + sub_eq_add_neg := by intro a b; rfl + zsmul := zsmulRec + zsmul_zero' := by intro a; rfl + zsmul_succ' := by intro n a; rfl + zsmul_neg' := by intro n a; rfl + neg_add_cancel := Dyadic.neg_add_cancel + +instance : AddCommGroup Dyadic := + { (inferInstance : AddGroup Dyadic) with + add_comm := Dyadic.add_comm } + +instance : IsOrderedAddMonoid Dyadic where + add_le_add_left a b h c := by + have hrat : a.toRat ≤ b.toRat := + (Dyadic.toRat_le_toRat_iff (x := a) (y := b)).2 h + have hrat' : a.toRat + c.toRat ≤ b.toRat + c.toRat := by + exact Rat.add_le_add_right.2 hrat + have hrat'' : + (a + c).toRat ≤ (b + c).toRat := by + simpa [Dyadic.toRat_add] using hrat' + exact (Dyadic.toRat_le_toRat_iff (x := a + c) (y := b + c)).1 hrat'' + +instance : ExistsAddOfLE Dyadic where + exists_add_of_le {a b} h := by + refine ⟨b - a, ?_⟩ + simp [sub_eq_add_neg] + +instance : Monoid Dyadic where + mul := (· * ·) + one := 1 + mul_assoc := Dyadic.mul_assoc + one_mul := Dyadic.one_mul + mul_one := Dyadic.mul_one + +instance : MonoidWithZero Dyadic := + { (inferInstance : Monoid Dyadic) with + zero := 0 + zero_mul := Dyadic.zero_mul + mul_zero := Dyadic.mul_zero } + +instance : Distrib Dyadic where + left_distrib := Dyadic.mul_add + right_distrib := Dyadic.add_mul + +instance : Semiring Dyadic where + add := (· + ·) + zero := 0 + add_assoc := Dyadic.add_assoc + zero_add := Dyadic.zero_add + add_zero := Dyadic.add_zero + add_comm := Dyadic.add_comm + nsmul := nsmulRec + nsmul_zero := by intro x; rfl + nsmul_succ := by intro n x; rfl + one := 1 + natCast := fun n => Dyadic.ofInt n + natCast_zero := by + simp [dyadicOfInt_zero] + natCast_succ := by + intro n + simp [dyadicOfInt_succ n] + mul := (· * ·) + mul_assoc := Dyadic.mul_assoc + one_mul := Dyadic.one_mul + mul_one := 
Dyadic.mul_one + left_distrib := Dyadic.mul_add + right_distrib := Dyadic.add_mul + zero_mul := Dyadic.zero_mul + mul_zero := Dyadic.mul_zero + +instance : CommSemiring Dyadic where + toSemiring := (inferInstance : Semiring Dyadic) + mul_comm := by intro a b; exact Dyadic.mul_comm a b + +instance : AddGroupWithOne Dyadic := + { (inferInstance : AddMonoidWithOne Dyadic), + (inferInstance : AddGroup Dyadic) with + intCast := Int.castDef + intCast_ofNat := by + intro n + apply (Dyadic.toRat_inj).1 + have hleft : (Int.castDef (R := Dyadic) (n : ℤ)).toRat = (n : Rat) := by + simp [Int.castDef, Dyadic.toRat_natCast] + have hright : (n : Dyadic).toRat = (n : Rat) := Dyadic.toRat_natCast n + exact hleft.trans hright.symm + intCast_negSucc := by + intro n + apply (Dyadic.toRat_inj).1 + simp [Int.castDef, Dyadic.toRat_natCast, Dyadic.toRat_neg, Dyadic.toRat_add] } + +instance : Ring Dyadic := + { (inferInstance : Semiring Dyadic), + (inferInstance : AddCommGroup Dyadic), + (inferInstance : AddGroupWithOne Dyadic) with } + +instance : CommRing Dyadic := + { (inferInstance : Ring Dyadic), (inferInstance : CommSemiring Dyadic) with } + +instance : Nontrivial Dyadic := + ⟨0, 1, by + intro h + have hrat : (0 : Rat) = (1 : Rat) := by + simpa [Dyadic.toRat_zero, dyadicToRat_one] using congrArg Dyadic.toRat h + exact (zero_ne_one (α := Rat)) hrat⟩ + +instance : ZeroLEOneClass Dyadic where + zero_le_one := by + have hrat : (0 : Rat) ≤ (1 : Rat) := by decide + exact (Dyadic.toRat_le_toRat_iff (x := (0 : Dyadic)) (y := (1 : Dyadic))).1 hrat + +instance : PosMulMono Dyadic where + mul_le_mul_of_nonneg_left {a} ha {b c} hbc := by + have ha' : (0 : Dyadic) ≤ a := by simpa using ha + have ha'' : (0 : Rat) ≤ a.toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := a)).2 ha' + have hbc' : b.toRat ≤ c.toRat := + (Dyadic.toRat_le_toRat_iff (x := b) (y := c)).2 hbc + have hrat : a.toRat * b.toRat ≤ a.toRat * c.toRat := by + exact Rat.mul_le_mul_of_nonneg_left hbc' ha'' + have hrat' : + (a * 
b).toRat ≤ (a * c).toRat := by + simpa [Dyadic.toRat_mul] using hrat + exact (Dyadic.toRat_le_toRat_iff (x := a * b) (y := a * c)).1 hrat' + +instance : MulPosMono Dyadic where + mul_le_mul_of_nonneg_right {a} ha {b c} hbc := by + have h := (PosMulMono.mul_le_mul_of_nonneg_left (a := a) ha hbc) + simpa [mul_comm, mul_left_comm, mul_assoc] using h + +instance : IsOrderedRing Dyadic := + IsOrderedRing.of_mul_nonneg (R := Dyadic) (mul_nonneg := by + intro a b ha hb + have ha' : (0 : Dyadic) ≤ a := by simpa using ha + have hb' : (0 : Dyadic) ≤ b := by simpa using hb + have ha'' : (0 : Rat) ≤ a.toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := a)).2 ha' + have hb'' : (0 : Rat) ≤ b.toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := b)).2 hb' + have hrat : (0 : Rat) ≤ a.toRat * b.toRat := Rat.mul_nonneg ha'' hb'' + have hrat' : (0 : Rat) ≤ (a * b).toRat := by + simpa [Dyadic.toRat_mul] using hrat + exact (Dyadic.toRat_le_toRat_iff (x := 0) (y := a * b)).1 hrat') + +@[simp] theorem dyadicToReal_abs (x : Dyadic) : + dyadicToReal |x| = |dyadicToReal x| := by + by_cases hx : 0 ≤ x + · have hx' : 0 ≤ dyadicToReal x := (dyadicToReal_nonneg_iff).2 hx + calc + dyadicToReal |x| = dyadicToReal x := by simp [abs_of_nonneg hx] + _ = |dyadicToReal x| := (abs_of_nonneg hx').symm + · have hx' : x ≤ 0 := le_of_not_ge hx + have hx'' : dyadicToReal x ≤ 0 := by + simpa [dyadicToReal_zero] using dyadicToReal_le_of_le hx' + calc + dyadicToReal |x| = dyadicToReal (-x) := by simp [abs_of_nonpos hx'] + _ = -dyadicToReal x := dyadicToReal_neg x + _ = |dyadicToReal x| := (abs_of_nonpos hx'').symm + +theorem dyadicToReal_abs_le_of_le {x y : Dyadic} (h : |x| ≤ y) : + |dyadicToReal x| ≤ dyadicToReal y := by + have h' : dyadicToReal |x| ≤ dyadicToReal y := + dyadicToReal_le_of_le h + simpa [dyadicToReal_abs] using h' + +@[simp] theorem dyadicToReal_max (x y : Dyadic) : + dyadicToReal (max x y) = max (dyadicToReal x) (dyadicToReal y) := by + by_cases hxy : x ≤ y + · have hxy' : dyadicToReal x ≤ 
dyadicToReal y := + dyadicToReal_le_of_le hxy + calc + dyadicToReal (max x y) = dyadicToReal y := by simp [max_eq_right hxy] + _ = max (dyadicToReal x) (dyadicToReal y) := by + symm + exact max_eq_right hxy' + · have hyx : y ≤ x := le_of_not_ge hxy + have hyx' : dyadicToReal y ≤ dyadicToReal x := + dyadicToReal_le_of_le hyx + calc + dyadicToReal (max x y) = dyadicToReal x := by simp [max_eq_left hyx] + _ = max (dyadicToReal x) (dyadicToReal y) := by + exact (max_eq_left hyx').symm + +@[simp] theorem dyadicToReal_min (x y : Dyadic) : + dyadicToReal (min x y) = min (dyadicToReal x) (dyadicToReal y) := by + by_cases hxy : x ≤ y + · have hxy' : dyadicToReal x ≤ dyadicToReal y := + dyadicToReal_le_of_le hxy + calc + dyadicToReal (min x y) = dyadicToReal x := by simp [min_eq_left hxy] + _ = min (dyadicToReal x) (dyadicToReal y) := by + symm + exact min_eq_left hxy' + · have hyx : y ≤ x := le_of_not_ge hxy + have hyx' : dyadicToReal y ≤ dyadicToReal x := + dyadicToReal_le_of_le hyx + calc + dyadicToReal (min x y) = dyadicToReal y := by simp [min_eq_right hyx] + _ = min (dyadicToReal x) (dyadicToReal y) := by + exact (min_eq_right hyx').symm + end Nfp diff --git a/Nfp/IO.lean b/Nfp/IO.lean index bffd260..1f2fe1e 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -1,8 +1,11 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.List.Range -import Nfp.IO.Pure +import Nfp.IO.Checks +import Nfp.IO.Derive +import Nfp.IO.Loaders import Nfp.IO.NfptPure +import Nfp.IO.HeadScore +import Nfp.IO.InductionHead +import Nfp.IO.Util import Nfp.Circuit.Cert.LogitDiff import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Cert.ResidualBound @@ -10,183 +13,22 @@ import Nfp.Circuit.Cert.ResidualInterval import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Bounds.Transformer import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadBounds import Nfp.Sound.Induction.LogitDiff - -/-! -IO wrappers for loading and checking induction certificates. 
--/ - +import Nfp.Sound.Linear.FinFold +import Nfp.IO.Timing namespace Nfp - namespace IO - open Nfp.Circuit -/-- Load a softmax-margin certificate from disk. -/ -def loadSoftmaxMarginCert (path : System.FilePath) : - IO (Except String (Sigma SoftmaxMarginCert)) := do - let data ← IO.FS.readFile path - return Pure.parseSoftmaxMarginCert data - -/-- Load raw softmax-margin inputs from disk. -/ -def loadSoftmaxMarginRaw (path : System.FilePath) : - IO (Except String (Sigma Pure.SoftmaxMarginRaw)) := do - let data ← IO.FS.readFile path - return Pure.parseSoftmaxMarginRaw data - -/-- Load a value-range certificate from disk. -/ -def loadValueRangeCert (path : System.FilePath) : - IO (Except String (Sigma ValueRangeCert)) := do - let data ← IO.FS.readFile path - return Pure.parseValueRangeCert data - -/-- Load a downstream linear certificate from disk. -/ -def loadDownstreamLinearCert (path : System.FilePath) : - IO (Except String DownstreamLinearCert) := do - let data ← IO.FS.readFile path - return Pure.parseDownstreamLinearCert data - -/-- Load a downstream matrix payload from disk. -/ -def loadDownstreamMatrixRaw (path : System.FilePath) : - IO (Except String (Sigma (fun rows => - Sigma (fun cols => Pure.DownstreamMatrixRaw rows cols)))) := do - let data ← IO.FS.readFile path - return Pure.parseDownstreamMatrixRaw data - -/-- Load a residual-bound certificate from disk. -/ -def loadResidualBoundCert (path : System.FilePath) : - IO (Except String (Sigma (fun n => ResidualBoundCert n))) := do - let data ← IO.FS.readFile path - return Pure.parseResidualBoundCert data - -/-- Load a residual-interval certificate from disk. -/ -def loadResidualIntervalCert (path : System.FilePath) : - IO (Except String (Sigma (fun n => ResidualIntervalCert n))) := do - let data ← IO.FS.readFile path - return Pure.parseResidualIntervalCert data - -/-- Load raw value-range inputs from disk. 
-/ -def loadValueRangeRaw (path : System.FilePath) : - IO (Except String (Sigma Pure.ValueRangeRaw)) := do - let data ← IO.FS.readFile path - return Pure.parseValueRangeRaw data - -/-- Load induction head inputs from disk. -/ -def loadInductionHeadInputs (path : System.FilePath) : - IO (Except String (Sigma (fun seq => - Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead))))) := do - let data ← IO.FS.readFile path - return Pure.parseInductionHeadInputs data - -private def renderResidualIntervalCert {n : Nat} (c : ResidualIntervalCert n) : String := - let header := s!"dim {n}" - let lines := - (List.finRange n).foldr (fun i acc => - s!"lo {i.val} {c.lo i}" :: s!"hi {i.val} {c.hi i}" :: acc) [] - String.intercalate "\n" (header :: lines) - -private def emitResidualIntervalCert {n : Nat} (c : ResidualIntervalCert n) - (outPath? : Option System.FilePath) : IO Unit := do - let payload := renderResidualIntervalCert c - match outPath? with - | some path => IO.FS.writeFile path (payload ++ "\n") - | none => IO.println payload - -/-! Derived residual intervals from model binaries. -/ - -/-- Derive residual-interval bounds from a model binary via interval propagation. 
-/ -private def deriveResidualIntervalFromModel (data : ByteArray) (start : Nat) - (header : NfptPure.NfptHeader) : - Except String (ResidualIntervalCert header.modelDim) := do - if hseq : header.seqLen = 0 then - throw "seq must be positive" - else - have _ : NeZero header.seqLen := ⟨hseq⟩ - if header.modelDim = 0 then - throw "model dim must be positive" - else if 0 < header.layerNormEps then - let embed ← NfptPure.readEmbeddings data start header - let layerSlices ← NfptPure.readLayerSlices data start header - let headLayers ← NfptPure.readLayerHeads data start header - let finalLn ← NfptPure.readFinalLayerNorm data start header - let layers : Fin header.numLayers → Model.Gpt2LayerSlice header.modelDim header.hiddenDim := - fun l => NfptPure.SizedArray.get layerSlices l - let heads : - Fin header.numLayers → Fin header.numHeads → - Model.Gpt2HeadWeights header.modelDim header.headDim := fun l h => - NfptPure.SizedArray.get (NfptPure.SizedArray.get headLayers l) h - let bounds := - Sound.Bounds.gpt2ResidualIntervalBounds (eps := header.layerNormEps) - layers heads finalLn embed - return { lo := bounds.1, hi := bounds.2 } - else - throw s!"layer norm epsilon {header.layerNormEps} must be positive" - -private def buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (outPath? : Option System.FilePath) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildHeadOutputIntervalFromHead? inputs with - | none => - IO.eprintln "error: head output interval rejected" - return 2 - | some result => - emitResidualIntervalCert result.cert outPath? 
- if outPath?.isSome then - let activeCount := result.active.card - IO.println - s!"ok: head output interval built (seq={seq}, dim={dModel}, active={activeCount})" - return 0 - -private def checkSoftmaxMargin (seq : Nat) (cert : SoftmaxMarginCert seq) : - IO (Except String Unit) := - match seq with - | 0 => return Except.error "seq must be positive" - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - let ok := Circuit.checkSoftmaxMarginCert cert - if ok then - return Except.ok () - else - return Except.error "softmax-margin certificate rejected" - -private def checkValueRange (seq : Nat) (cert : ValueRangeCert seq) : - IO (Except String Unit) := - match seq with - | 0 => return Except.error "seq must be positive" - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - let ok := Circuit.checkValueRangeCert cert - if ok then - return Except.ok () - else - return Except.error "value-range certificate rejected" - -private def parseRatOpt (label : String) (raw? : Option String) : - Except String (Option Rat) := - match raw? with - | none => Except.ok none - | some raw => - match Pure.parseRat raw with - | Except.ok v => Except.ok (some v) - | Except.error msg => Except.error s!"invalid {label}: {msg}" - /-- Check induction certificates and print a short status line. -/ def runInductionCertify (scoresPath : System.FilePath) (valuesPath? : Option System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? 
match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -198,18 +40,20 @@ def runInductionCertify (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) if minLogitDiff?.isSome && valuesPath?.isNone then IO.eprintln "error: min-logit-diff requires --values" return 2 - let parsedScores ← loadSoftmaxMarginCert scoresPath + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath match parsedScores with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨seq, cert⟩ => - let scoresOk ← checkSoftmaxMargin seq cert + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert match scoresOk with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -264,14 +108,14 @@ def runInductionCertify (scoresPath : System.FilePath) let effectiveMinLogitDiff := match minLogitDiff?, certVals'.direction with | some v, _ => some v - | none, some _ => some (0 : Rat) + | none, some _ => some (0 : Dyadic) | none, none => none match logitDiffLB? with | none => IO.eprintln "error: empty active set for logit-diff bound" return (2 : UInt32) | some logitDiffLB => - let violation? : Option Rat := + let violation? : Option Dyadic := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -291,15 +135,14 @@ def runInductionCertify (scoresPath : System.FilePath) (seq={seq}, active={activeCount}, tol={tol}, \ logitDiffLB={logitDiffLB})" return 0 - /-- Build and check induction certificates from raw scores/values. -/ def runInductionCertifySound (scoresPath : System.FilePath) (valuesPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? 
: Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -311,8 +154,8 @@ def runInductionCertifySound (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) let parsedScores ← loadSoftmaxMarginRaw scoresPath match parsedScores with | Except.error msg => @@ -373,14 +216,14 @@ def runInductionCertifySound (scoresPath : System.FilePath) let effectiveMinLogitDiff := match minLogitDiff?, certVals.direction with | some v, _ => some v - | none, some _ => some (0 : Rat) + | none, some _ => some (0 : Dyadic) | none, none => none match logitDiffLB? with | none => IO.eprintln "error: empty active set for logit-diff bound" return 2 | some logitDiffLB => - let violation? : Option Rat := + let violation? : Option Dyadic := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -400,15 +243,14 @@ def runInductionCertifySound (scoresPath : System.FilePath) (seq={seq}, active={activeCount}, \ tol={tol}, logitDiffLB={logitDiffLB})" return 0 - /-- Check end-to-end induction certificates with a downstream error bound. -/ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) (valuesPath : System.FilePath) (downstreamPath : System.FilePath) (minActive? 
: Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -420,15 +262,17 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) - let parsedScores ← loadSoftmaxMarginCert scoresPath + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath match parsedScores with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨seq, cert⟩ => - let scoresOk ← checkSoftmaxMargin seq cert + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert match scoresOk with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -449,7 +293,8 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) IO.eprintln s!"error: eps {cert.eps} above maximum {maxEps}" return 2 - let parsedValues ← loadValueRangeCert valuesPath + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath match parsedValues with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -463,19 +308,20 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) exact (not_ne_iff).1 hseq let certVals' : ValueRangeCert seq := by simpa [hseq'] using certVals - 
let valuesOk ← checkValueRange seq certVals' + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' match valuesOk with | Except.error msg => IO.eprintln s!"error: {msg}" return 2 | Except.ok () => - let logitDiffLB? := + let logitDiffLB? ← timePure "logit-diff lower bound" (fun () => Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals + certVals'.lo certVals'.hi certVals'.vals) let effectiveMinLogitDiff := match minLogitDiff?, certVals'.direction with | some v, _ => some v - | none, some _ => some (0 : Rat) + | none, some _ => some (0 : Dyadic) | none, none => none match logitDiffLB? with | none => @@ -491,7 +337,7 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) let downstreamOk := Circuit.checkDownstreamLinearCert downstream if downstreamOk then let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := + let violation? : Option Dyadic := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -516,15 +362,14 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) else IO.eprintln "error: downstream certificate rejected" return 2 - /-- Check end-to-end induction certificates with a downstream matrix. -/ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) (valuesPath : System.FilePath) (matrixPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? 
match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -536,15 +381,17 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) - let parsedScores ← loadSoftmaxMarginCert scoresPath + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath match parsedScores with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨seq, cert⟩ => - let scoresOk ← checkSoftmaxMargin seq cert + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert match scoresOk with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -565,7 +412,8 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) IO.eprintln s!"error: eps {cert.eps} above maximum {maxEps}" return 2 - let parsedValues ← loadValueRangeCert valuesPath + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath match parsedValues with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -579,19 +427,20 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) exact (not_ne_iff).1 hseq let certVals' : ValueRangeCert seq := by simpa [hseq'] using certVals - let valuesOk ← checkValueRange seq certVals' + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' match valuesOk with | Except.error msg => IO.eprintln s!"error: {msg}" return 2 | Except.ok () => - let logitDiffLB? := + let logitDiffLB? 
← timePure "logit-diff lower bound" (fun () => Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals + certVals'.lo certVals'.hi certVals'.vals) let effectiveMinLogitDiff := match minLogitDiff?, certVals'.direction with | some v, _ => some v - | none, some _ => some (0 : Rat) + | none, some _ => some (0 : Dyadic) | none, none => none match logitDiffLB? with | none => @@ -612,11 +461,11 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) else have hinput : 0 ≤ inputBound := by exact le_of_not_gt hneg - let W : Matrix (Fin rows) (Fin cols) Rat := raw.entries + let W : Matrix (Fin rows) (Fin cols) Dyadic := raw.entries let downstream := (Sound.Bounds.buildDownstreamLinearCert W inputBound hinput).1 let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := + let violation? : Option Dyadic := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -638,7 +487,6 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) downstreamError={downstream.error}, \ finalLB={finalLB})" return 0 - /-- Check end-to-end induction certificates using a model file and residual bounds (loaded from disk or derived from the model). -/ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) @@ -646,9 +494,9 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) (residualIntervalPath? : Option System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? 
match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -660,15 +508,17 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) - let parsedScores ← loadSoftmaxMarginCert scoresPath + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath match parsedScores with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨seq, cert⟩ => - let scoresOk ← checkSoftmaxMargin seq cert + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert match scoresOk with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -689,7 +539,8 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) IO.eprintln s!"error: eps {cert.eps} above maximum {maxEps}" return 2 - let parsedValues ← loadValueRangeCert valuesPath + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath match parsedValues with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -703,19 +554,20 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) exact (not_ne_iff).1 hseq let certVals' : ValueRangeCert seq := by simpa [hseq'] using certVals - let valuesOk ← checkValueRange seq certVals' + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' match valuesOk with | Except.error msg => IO.eprintln s!"error: {msg}" return 2 | Except.ok () => - let logitDiffLB? := + let logitDiffLB? 
← timePure "logit-diff lower bound" (fun () => Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals + certVals'.lo certVals'.hi certVals'.vals) let effectiveMinLogitDiff := match minLogitDiff?, certVals'.direction with | some v, _ => some v - | none, some _ => some (0 : Rat) + | none, some _ => some (0 : Dyadic) | none, none => none match logitDiffLB? with | none => @@ -729,19 +581,28 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) metadata" return 2 | some dirSpec => - let data ← IO.FS.readBinFile modelPath - match NfptPure.parseHeader data with + let data ← timePhase "read model file" <| + IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨header, start⟩ => - if header.seqLen = seq then + if hseq : header.seqLen = seq then + let active? : Option (Finset (Fin header.seqLen)) := + if hactive : cert.active.Nonempty then + some (by simpa [hseq] using cert.active) + else + none let residualCertE : Except String (ResidualIntervalCert header.modelDim) ← match residualIntervalPath? with | some residualIntervalPath => do let parsedResidual ← - loadResidualIntervalCert residualIntervalPath + timePhase "load residual interval" <| + loadResidualIntervalCert residualIntervalPath match parsedResidual with | Except.error msg => pure (Except.error msg) | Except.ok ⟨dim, residualCert⟩ => @@ -755,40 +616,48 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) s!"residual interval dim {dim} \ does not match model dim {header.modelDim}") | none => - pure (deriveResidualIntervalFromModel data start header) + deriveResidualIntervalFromModel data start header + active? 
match residualCertE with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok residualCert' => - let residualOk := - Circuit.checkResidualIntervalCert residualCert' + let residualOk ← + timePure "check residual interval" (fun () => + Circuit.checkResidualIntervalCert residualCert') if residualOk then let dirPos := dirSpec.target let dirNeg := dirSpec.negative - match - NfptPure.readUnembedColumn data start header dirPos - with + let colTargetE ← + timePure "read unembed column target" (fun () => + NfptPure.readUnembedColumn + data start header dirPos) + match colTargetE with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok colTarget => - match - NfptPure.readUnembedColumn - data start header dirNeg - with + let colNegE ← + timePure "read unembed column negative" (fun () => + NfptPure.readUnembedColumn + data start header dirNeg) + match colNegE with | Except.error msg => IO.eprintln s!"error: {msg}" return 1 | Except.ok colNeg => let dirVec : - Fin header.modelDim → Rat := + Fin header.modelDim → Dyadic := fun i => colTarget i - colNeg i - let downstreamError := - Sound.Bounds.dotIntervalAbsBound - dirVec residualCert'.lo residualCert'.hi + let downstreamError ← + timePure "downstream error" (fun () => + Sound.Bounds.dotIntervalAbsBound + dirVec + residualCert'.lo + residualCert'.hi) let finalLB := logitDiffLB - downstreamError - let violation? : Option Rat := + let violation? : Option Dyadic := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -821,292 +690,5 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) s!"error: model seq {header.seqLen} \ does not match cert seq {seq}" return 2 - -private def checkInductionHeadInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (minActive? : Option Nat) (minLogitDiff? 
: Option Rat) - (minMargin maxEps : Rat) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildInductionCertFromHead? inputs with - | none => - IO.eprintln "error: head inputs rejected" - return 2 - | some ⟨cert, _hcert⟩ => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let tol := cert.eps * (cert.values.hi - cert.values.lo) - let logitDiffLB? := - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - cert.values.lo cert.values.hi cert.values.valsLo - let effectiveMinLogitDiff := - match minLogitDiff? with - | some v => some v - | none => some (0 : Rat) - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return 2 - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={tol}, logitDiffLB={logitDiffLB})" - return 0 - -private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (minActive? : Option Nat) (minLogitDiff? 
: Option Rat) - (minMargin maxEps : Rat) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildInductionLogitLowerBoundNonvacuous? inputs with - | none => - IO.eprintln "error: nonvacuous logit-diff bound unavailable" - return 2 - | some result => - let cert := result.base.cert - let logitDiffLB := result.base.lb - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let tol := cert.eps * (cert.values.hi - cert.values.lo) - let effectiveMinLogitDiff := - match minLogitDiff? with - | some v => some v - | none => some (0 : Rat) - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return 2 - | none => - IO.println - s!"ok: nonvacuous induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={tol}, logitDiffLB={logitDiffLB})" - return 0 - -/-- Build and check induction certificates from exact head inputs. -/ -def runInductionCertifyHead (inputsPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? 
- let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) - let parsedInputs ← loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps - -/-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ -def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) - let parsedInputs ← loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? 
minMargin maxEps - -/-- Build and check induction certificates from a model binary. -/ -def runInductionCertifyHeadModel (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) - let data ← IO.FS.readBinFile modelPath - match NfptPure.parseHeader data with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - match - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? - with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps - -/-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ -def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
- match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (1 / 2 : Rat) - let data ← IO.FS.readBinFile modelPath - match NfptPure.parseHeader data with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - match - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? - with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? minMargin maxEps - -/-- Build head-output interval bounds from exact head inputs. -/ -def runInductionHeadInterval (inputsPath : System.FilePath) - (outPath? : Option System.FilePath) : IO UInt32 := do - let parsedInputs ← loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - buildHeadOutputIntervalFromInputs inputs outPath? - -/-- Build head-output interval bounds from a model binary. -/ -def runInductionHeadIntervalModel (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) - (outPath? : Option System.FilePath) : IO UInt32 := do - let data ← IO.FS.readBinFile modelPath - match NfptPure.parseHeader data with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - match - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? 
- with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - buildHeadOutputIntervalFromInputs inputs outPath? - end IO - end Nfp diff --git a/Nfp/IO/Bench/InductionCore.lean b/Nfp/IO/Bench/InductionCore.lean new file mode 100644 index 0000000..b6050c3 --- /dev/null +++ b/Nfp/IO/Bench/InductionCore.lean @@ -0,0 +1,229 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Timing +import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadBounds + +/-! +Benchmark helpers for induction-head core certification. +-/ + +namespace Nfp + +namespace IO + +open Sound +open scoped BigOperators + +private def benchPhasePure {α : Type} (label : String) (act : Unit → α) : IO α := do + IO.println s!"bench: {label} start" + flushStdout + timePhase label (pure (act ())) + +private def forceScore {seq dModel dHead : Nat} + (score : Sound.HeadScoreBounds seq dModel dHead) : Dyadic := + score.margin + score.eps + +private def forceValues {seq dModel dHead : Nat} + (vals : Sound.HeadValueBounds seq dModel dHead) : Dyadic := + vals.lo + vals.hi + +private def forceQAbs {seq dHead : Nat} + (qAbs : Fin seq → Fin dHead → Dyadic) : Dyadic := + (Finset.univ : Finset (Fin seq)).sum (fun q => + (Finset.univ : Finset (Fin dHead)).sum (fun d => qAbs q d)) + +private def forceLn {seq dModel : Nat} + (ln : Fin seq → Fin dModel → Dyadic) : Dyadic := + (Finset.univ : Finset (Fin seq)).sum (fun q => + (Finset.univ : Finset (Fin dModel)).sum (fun i => ln q i)) + +private def forceKAbs {seq dHead : Nat} + (kAbs : Fin seq → Fin dHead → Dyadic) : Dyadic := + (Finset.univ : Finset (Fin seq)).sum (fun q => + (Finset.univ : Finset (Fin dHead)).sum (fun d => kAbs q d)) + +private def forceDotAbs {seq dHead : Nat} + (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : Dyadic := + (Finset.univ : Finset (Fin seq)).sum (fun q => + (Finset.univ : Finset (Fin seq)).sum (fun k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d))) + +private def 
forceDotAbsTasksReduce {seq dHead : Nat} + (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : Dyadic := + let tasks : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + (Finset.univ : Finset (Fin seq)).sum (fun k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)))) + (Finset.univ : Finset (Fin seq)).sum (fun q => + (tasks[q.1]'(by simp [tasks, q.isLt])).get) + +private def isPow2 (n : Nat) : Bool := + if n = 0 then + false + else + decide (Nat.pow 2 (Nat.log2 n) = n) + +private def isDyadic (q : Dyadic) : Bool := + isPow2 q.toRat.den + +private def countDyadic {seq dHead : Nat} + (qs : List (Fin seq)) (ds : List (Fin dHead)) + (f : Fin seq → Fin dHead → Dyadic) : Nat := + qs.foldl (fun acc q => + ds.foldl (fun acc' d => acc' + (if isDyadic (f q d) then 1 else 0)) acc) 0 + +private def dyadicSampleReport {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qkv : Sound.HeadQKVBounds seq dModel dHead) : String := + let qs := (List.finRange seq).take (min seq 2) + let ds := (List.finRange dHead).take (min dHead 8) + let total := qs.length * ds.length + let qLoDy := countDyadic qs ds qkv.qLo + let qHiDy := countDyadic qs ds qkv.qHi + let qAbsDy := countDyadic qs ds qkv.qAbs + let kAbsDy := countDyadic qs ds qkv.kAbs + let epsDy := if isDyadic inputs.lnEps then 1 else 0 + s!"dyadic sample: total={total} qLo={qLoDy} qHi={qHiDy} qAbs={qAbsDy} " ++ + s!"kAbs={kAbsDy} lnEps={epsDy}" + +private def dyadicSanityReport : String := + let rat := dyadicOfRatDown (Rat.divInt 1 8) + let powChecks := + s!"pow2(1)={isPow2 1} pow2(2)={isPow2 2} pow2(3)={isPow2 3} " ++ + s!"pow2(4)={isPow2 4} pow2(8)={isPow2 8}" + let ratCheck := s!"rat(1/8).den={rat.toRat.den} dyadic={isDyadic rat}" + s!"dyadic sanity: {powChecks} {ratCheck}" + +private def forceQRowTasks {seq dHead : Nat} + (q0 : Fin seq) (qLo : Fin seq → Fin dHead → Dyadic) : Int := + let tasks : Array (Task Dyadic) := + Array.ofFn (fun d : Fin dHead => + 
Task.spawn (fun _ => qLo q0 d)) + let total := + (Finset.univ : Finset (Fin dHead)).sum (fun d => + (tasks[d.1]'(by simp [tasks, d.isLt])).get) + total.toRat.num + +private def qAbsRowChunk {seq dHead : Nat} + (q0 : Fin seq) (qAbs : Fin seq → Fin dHead → Dyadic) (start stop : Nat) : Dyadic := + let chunk : Finset (Fin dHead) := + (Finset.univ : Finset (Fin dHead)).filter (fun d => start ≤ d.1 ∧ d.1 < stop) + chunk.sum (fun d => qAbs q0 d) + +private def forceQAbsRowTasksReduce {seq dHead : Nat} + (q0 : Fin seq) (qAbs : Fin seq → Fin dHead → Dyadic) (chunkSize : Nat) : Dyadic := + if chunkSize = 0 then + (0 : Dyadic) + else + let chunks : Nat := (dHead + chunkSize - 1) / chunkSize + let tasks : Array (Task Dyadic) := + Array.ofFn (fun i : Fin chunks => + Task.spawn (fun _ => + let start := i.1 * chunkSize + let stop := min dHead (start + chunkSize) + qAbsRowChunk q0 qAbs start stop)) + (Finset.univ : Finset (Fin chunks)).sum (fun i => + (tasks[i.1]'(by simp [tasks, i.isLt])).get) + +private def forceQAbsAllTasksReduce {seq dHead : Nat} + (qAbs : Fin seq → Fin dHead → Dyadic) (chunkSize : Nat) : Dyadic := + let tasks : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => forceQAbsRowTasksReduce q qAbs chunkSize)) + (Finset.univ : Finset (Fin seq)).sum (fun q => + (tasks[q.1]'(by simp [tasks, q.isLt])).get) + +private def forceQAbsActiveTasksReduce {seq dHead : Nat} + (active : Finset (Fin seq)) (qAbs : Fin seq → Fin dHead → Dyadic) (chunkSize : Nat) : Dyadic := + if hactive : active.Nonempty then + let tasks : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if q ∈ active then + forceQAbsRowTasksReduce q qAbs chunkSize + else + (0 : Dyadic))) + active.sum (fun q => + (tasks[q.1]'(by simp [tasks, q.isLt])).get) + else + (0 : Dyadic) + +/-- Run a core benchmark from already-parsed head inputs. 
-/ +def runCoreBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do + let lnBounds ← benchPhasePure "ln bounds" (fun () => Sound.headLnBounds inputs) + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + let _ ← benchPhasePure "lnLo force" (fun () => forceLn lnLo) + let _ ← benchPhasePure "lnHi force" (fun () => forceLn lnHi) + let qkv ← benchPhasePure "qkv bounds" (fun () => Sound.headQKVBounds inputs lnLo lnHi) + let _ ← timePhase "dyadic sample" (do + IO.println (dyadicSampleReport inputs qkv) + IO.println dyadicSanityReport + pure ()) + let _ ← benchPhasePure "qLo single" (fun () => + match h : dHead with + | 0 => (0 : Dyadic) + | Nat.succ _ => + let q0 : Fin seq := + ⟨0, Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ + let d0 : Fin dHead := ⟨0, by simp [h]⟩ + qkv.qLo q0 d0) + let _ ← benchPhasePure "qLo row tasks" (fun () => + let q0 : Fin seq := + ⟨0, by + exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ + forceQRowTasks q0 qkv.qLo) + let _ ← benchPhasePure "qLo row" (fun () => + let q0 : Fin seq := + ⟨0, by + exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ + let total := + (Finset.univ : Finset (Fin dHead)).sum (fun d => qkv.qLo q0 d) + total.toRat.num) + let _ ← benchPhasePure "qLo force" (fun () => forceQAbs qkv.qLo) + let _ ← benchPhasePure "qHi force" (fun () => forceQAbs qkv.qHi) + let _ ← benchPhasePure "qAbs single" (fun () => + let q0 : Fin seq := + ⟨0, by + exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ + match h : dHead with + | 0 => (0 : Dyadic) + | Nat.succ _ => + let d0 : Fin dHead := ⟨0, by simp [h]⟩ + qkv.qAbs q0 d0) + let _ ← benchPhasePure "qAbs row tasks reduce" (fun () => + let q0 : Fin seq := + ⟨0, by + exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ + forceQAbsRowTasksReduce q0 qkv.qAbs 1) + let _ ← benchPhasePure "qAbs force tasks reduce active" (fun () => + forceQAbsActiveTasksReduce inputs.active qkv.qAbs 1) + let _ ← benchPhasePure "qAbs force tasks reduce" (fun 
() => + forceQAbsAllTasksReduce qkv.qAbs 1) + let _ ← benchPhasePure "kAbs force tasks reduce active" (fun () => + forceQAbsActiveTasksReduce inputs.active qkv.kAbs 1) + let _ ← benchPhasePure "kAbs force tasks reduce" (fun () => + forceQAbsAllTasksReduce qkv.kAbs 1) + let _ ← benchPhasePure "kAbs force tasks reduce (bench)" (fun () => + forceQAbsAllTasksReduce qkv.kAbs 1) + let _ ← benchPhasePure "dotAbs force tasks reduce" (fun () => + forceDotAbsTasksReduce qkv.qAbs qkv.kAbs) + let _ ← benchPhasePure "dotAbs force" (fun () => forceDotAbs qkv.qAbs qkv.kAbs) + let score ← benchPhasePure "score bounds" (fun () => + Sound.headScoreBounds inputs qkv.qAbs qkv.kAbs) + let _ ← benchPhasePure "score force" (fun () => forceScore score) + let vals ← benchPhasePure "value bounds" (fun () => + Sound.headValueBounds inputs qkv.vLo qkv.vHi) + let _ ← benchPhasePure "value force" (fun () => forceValues vals) + let cert ← benchPhasePure "core cert" (fun () => + Sound.buildInductionCertFromHeadCore? inputs) + match cert with + | none => IO.println "bench: core cert none" + | some _ => IO.println "bench: core cert some" + +end IO + +end Nfp diff --git a/Nfp/IO/Bench/InductionCounts.lean b/Nfp/IO/Bench/InductionCounts.lean new file mode 100644 index 0000000..3571f18 --- /dev/null +++ b/Nfp/IO/Bench/InductionCounts.lean @@ -0,0 +1,72 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Std.Data.HashMap +import Nfp.Model.InductionHead + +/-! +Call-count instrumentation for induction-head computations. + +This is a placeholder-only benchmark: it records how often key functions would be +called in a score-bound pass without performing heavy arithmetic. +-/ + +namespace Nfp + +namespace IO + +open scoped BigOperators + +private def bumpCount (ref : IO.Ref (Std.HashMap String Nat)) (key : String) (n : Nat) : + IO Unit := do + ref.modify (fun m => + let cur := (m.get? 
key).getD 0 + m.insert key (cur + n)) + +private def printCounts (ref : IO.Ref (Std.HashMap String Nat)) : IO Unit := do + let m ← ref.get + let entries := m.toList + IO.println "counts:" + for (k, v) in entries do + IO.println s!" {k}: {v}" + +private def countScoreCalls {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (ref : IO.Ref (Std.HashMap String Nat)) : IO Unit := do + let activeCount := inputs.active.card + let otherCount := seq - 1 + let rowCount := activeCount * otherCount + let elemCount := rowCount * dHead + bumpCount ref "scoreBounds:scoreLo" rowCount + bumpCount ref "scoreBounds:scoreHi" rowCount + bumpCount ref "scoreBounds:qAbs" elemCount + bumpCount ref "scoreBounds:qLo" elemCount + bumpCount ref "scoreBounds:qHi" elemCount + bumpCount ref "scoreBounds:kAbs" elemCount + bumpCount ref "scoreBounds:kLo" elemCount + bumpCount ref "scoreBounds:kHi" elemCount + +private def countQKVCalls {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (ref : IO.Ref (Std.HashMap String Nat)) : IO Unit := do + let activeCount := inputs.active.card + let elemCount := activeCount * dHead + bumpCount ref "qkvBounds:qLo" elemCount + bumpCount ref "qkvBounds:qHi" elemCount + bumpCount ref "qkvBounds:kLo" elemCount + bumpCount ref "qkvBounds:kHi" elemCount + bumpCount ref "qkvBounds:vLo" elemCount + bumpCount ref "qkvBounds:vHi" elemCount + bumpCount ref "qkvBounds:qAbs" elemCount + bumpCount ref "qkvBounds:kAbs" elemCount + +/-- Count calls used by score/QKV bounds on the active set. 
-/ +def countInductionCalls {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do + let ref ← IO.mkRef (∅ : Std.HashMap String Nat) + countQKVCalls inputs ref + countScoreCalls inputs ref + printCounts ref + +end IO + +end Nfp diff --git a/Nfp/IO/Bench/Rational.lean b/Nfp/IO/Bench/Rational.lean new file mode 100644 index 0000000..c48853f --- /dev/null +++ b/Nfp/IO/Bench/Rational.lean @@ -0,0 +1,362 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.List.Range +import Nfp.IO.Timing +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Linear.FinFold + +/-! +Microbenchmarks for dyadic arithmetic and caching strategies. +-/ + +namespace Nfp + +namespace IO + +open Sound + +private def benchItersFor (base : Nat) (n : Nat) : Nat := + let scale := max 1 (n / 64) + max 1 (base / scale) + +private def mkDyadic (num den : Nat) (neg : Bool) : Dyadic := + let n : Int := Int.ofNat (num + 1) + let d : Int := Int.ofNat (den + 1) + let q : Rat := Rat.divInt (if neg then -n else n) d + dyadicOfRatDown q + +private def mkVecDyadic (n : Nat) (seed : Nat) (salt : Nat) (negEvery : Nat) : Fin n → Dyadic := fun i => + let idx := i.1 + seed + salt + let neg := (idx % negEvery) = 0 + mkDyadic (idx % 97) (idx % 89) neg + +private def mkInterval (n : Nat) (seed : Nat) : + (Fin n → Dyadic) × (Fin n → Dyadic) × (Fin n → Dyadic) := + let v : Fin n → Dyadic := mkVecDyadic n seed 0 2 + let base : Fin n → Dyadic := mkVecDyadic n seed 13 3 + let lo : Fin n → Dyadic := fun i => base i - 1 + let hi : Fin n → Dyadic := fun i => base i + 1 + (v, lo, hi) + +private def benchLoop (label : String) (iters : Nat) (act : Unit → Dyadic) : IO Unit := do + let t0 ← monoUsNow + let mut last : Dyadic := 0 + for _ in List.range iters do + last := act () + let t1 ← monoUsNow + let total := t1 - t0 + let avg := total / max 1 iters + IO.println s!"bench: {label} iters={iters} total={total} us avg={avg} us last={last}" + +private def benchDotInterval (n 
iters seed : Nat) : IO Unit := do + let (v, lo, hi) := mkInterval n seed + let labelBase := s!"n={n}" + benchLoop s!"dotIntervalLower {labelBase}" iters (fun () => + Sound.Bounds.dotIntervalLower v lo hi) + benchLoop s!"dotIntervalLowerCommonDen {labelBase}" iters (fun () => + Sound.Bounds.dotIntervalLowerCommonDen v lo hi) + benchLoop s!"dotIntervalLowerCachedRat {labelBase}" iters (fun () => + Sound.Bounds.dotIntervalLowerCachedRat v lo hi) + benchLoop s!"dotIntervalUpper {labelBase}" iters (fun () => + Sound.Bounds.dotIntervalUpper v lo hi) + benchLoop s!"dotIntervalUpperCommonDen {labelBase}" iters (fun () => + Sound.Bounds.dotIntervalUpperCommonDen v lo hi) + benchLoop s!"dotIntervalUpperCachedRat {labelBase}" iters (fun () => + Sound.Bounds.dotIntervalUpperCachedRat v lo hi) + +private def dotIntervalLowerCachedCore {n : Nat} + (vArr loArr hiArr : Array Dyadic) + (hv : vArr.size = n) (hlo : loArr.size = n) (hhi : hiArr.size = n) : Dyadic := + let term : Fin n → Dyadic := fun j => + let vj := vArr[j.1]'(by + simp [hv, j.isLt]) + let loj := loArr[j.1]'(by + simp [hlo, j.isLt]) + let hij := hiArr[j.1]'(by + simp [hhi, j.isLt]) + if 0 ≤ vj then + vj * loj + else + vj * hij + Sound.Linear.sumFin n term + +private def dotIntervalUpperCachedCore {n : Nat} + (vArr loArr hiArr : Array Dyadic) + (hv : vArr.size = n) (hlo : loArr.size = n) (hhi : hiArr.size = n) : Dyadic := + let term : Fin n → Dyadic := fun j => + let vj := vArr[j.1]'(by + simp [hv, j.isLt]) + let loj := loArr[j.1]'(by + simp [hlo, j.isLt]) + let hij := hiArr[j.1]'(by + simp [hhi, j.isLt]) + if 0 ≤ vj then + vj * hij + else + vj * loj + Sound.Linear.sumFin n term + +private def benchDotIntervalCachedParts (n iters seed : Nat) : IO Unit := do + let (v, lo, hi) := mkInterval n seed + let vArr := Array.ofFn v + let loArr := Array.ofFn lo + let hiArr := Array.ofFn hi + have hv : vArr.size = n := by simp [vArr] + have hlo : loArr.size = n := by simp [loArr] + have hhi : hiArr.size = n := by simp [hiArr] + 
let labelBase := s!"n={n}" + benchLoop s!"dotIntervalLowerCachedRat arrays {labelBase}" iters (fun () => + let vArr' := Array.ofFn v + let loArr' := Array.ofFn lo + let hiArr' := Array.ofFn hi + vArr'.size + loArr'.size + hiArr'.size) + benchLoop s!"dotIntervalLowerCachedRat sum {labelBase}" iters (fun () => + dotIntervalLowerCachedCore vArr loArr hiArr hv hlo hhi) + benchLoop s!"dotIntervalUpperCachedRat sum {labelBase}" iters (fun () => + dotIntervalUpperCachedCore vArr loArr hiArr hv hlo hhi) + +private def benchDotFin (n iters seed : Nat) : IO Unit := do + let x : Fin n → Dyadic := mkVecDyadic n seed 7 4 + let y : Fin n → Dyadic := mkVecDyadic n seed 19 5 + let labelBase := s!"n={n}" + benchLoop s!"dotFin {labelBase}" iters (fun () => + Sound.Linear.dotFin n x y) + +private def headShapeIters (base : Nat) : Nat := + max 1 (base / 10) + +private def mkHeadAbs (seq dHead : Nat) (seed : Nat) (salt : Nat) : Fin seq → Fin dHead → Dyadic := + fun q d => + mkDyadic (q.1 * 31 + d.1 + seed + salt) + (q.1 + d.1 + 7 + seed + salt) (((q.1 + d.1) % 3) = 0) + +private def mkHeadVal (seq dHead : Nat) (seed : Nat) (salt : Nat) : Fin seq → Fin dHead → Dyadic := + fun q d => + mkDyadic (q.1 * 17 + d.1 + seed + salt) + (q.1 + d.1 + 11 + seed + salt) (((q.1 + d.1) % 5) = 0) + +private def mkHeadDir (dHead : Nat) (seed : Nat) (salt : Nat) : Fin dHead → Dyadic := fun d => + mkDyadic (d.1 + seed + salt) (d.1 + 3 + seed + salt) ((d.1 % 2) = 0) + +private def benchHeadDotAbs (iters seed : Nat) : IO Unit := do + let seq := 8 + let dHead := 64 + let qAbs : Fin seq → Fin dHead → Dyadic := mkHeadAbs seq dHead seed 3 + let kAbs : Fin seq → Fin dHead → Dyadic := mkHeadAbs seq dHead seed 19 + benchLoop "head dotAbs dotFin" iters (fun () => + (List.finRange seq).foldl (fun acc q => + (List.finRange seq).foldl (fun acc' k => + acc' + Sound.Linear.dotFin dHead (qAbs q) (kAbs k)) acc) 0) + +private def benchHeadValueBounds (iters seed : Nat) : IO Unit := do + let seq := 8 + let dHead := 64 + let 
dirHead : Fin dHead → Dyadic := mkHeadDir dHead seed 5 + let vLo : Fin seq → Fin dHead → Dyadic := mkHeadVal seq dHead seed 11 + let vHi : Fin seq → Fin dHead → Dyadic := mkHeadVal seq dHead seed 23 + let dirArr := Array.ofFn dirHead + have hdir : dirArr.size = dHead := by simp [dirArr] + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by + simp [univ] + benchLoop "head value bounds (cached)" iters (fun () => + let valsLo : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalLowerCachedRat dirHead (vLo k) (vHi k) + let valsHi : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalUpperCachedRat dirHead (vLo k) (vHi k) + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + lo + hi) + benchLoop "head value bounds (common den)" iters (fun () => + let valsLo : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k) + let valsHi : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k) + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + lo + hi) + benchLoop "head value bounds (direct)" iters (fun () => + let valsLo : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalLower dirHead (vLo k) (vHi k) + let valsHi : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalUpper dirHead (vLo k) (vHi k) + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + lo + hi) + benchLoop "head value bounds (cached, reuse dir)" iters (fun () => + let valsLo : Fin seq → Dyadic := fun k => + let loArr := Array.ofFn (vLo k) + let hiArr := Array.ofFn (vHi k) + have hlo : loArr.size = dHead := by simp [loArr] + have hhi : hiArr.size = dHead := by simp [hiArr] + dotIntervalLowerCachedCore dirArr loArr hiArr hdir hlo hhi + let valsHi : Fin seq → Dyadic := fun k => + let loArr := Array.ofFn (vLo k) + let hiArr := Array.ofFn (vHi k) + have hlo : loArr.size = dHead := by simp [loArr] + have hhi : hiArr.size = 
dHead := by simp [hiArr] + dotIntervalUpperCachedCore dirArr loArr hiArr hdir hlo hhi + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + lo + hi) + +private def benchDyadicDivInt (iters seed : Nat) : IO Unit := do + let bigNum : Int := + Int.ofNat (2 ^ 200) * Int.ofNat (3 ^ 120) + Int.ofNat (5 ^ 90) + Int.ofNat seed + let bigDen : Int := + Int.ofNat (2 ^ 150) * Int.ofNat (3 ^ 80) + (Int.ofNat seed) + 1 + benchLoop "dyadicOfRatDown divInt big" iters (fun () => + dyadicOfRatDown (Rat.divInt bigNum bigDen)) + +private def forceQkvSumLimited {seq dModel dHead : Nat} + (qkv : Sound.HeadQKVBounds seq dModel dHead) (qLimit dLimit : Nat) : Dyadic := + let qs := (List.finRange seq).take qLimit + let ds := (List.finRange dHead).take dLimit + qs.foldl (fun acc q => + ds.foldl (fun acc' d => + acc' + qkv.qLo q d + qkv.qHi q d + + qkv.kLo q d + qkv.kHi q d + + qkv.vLo q d + qkv.vHi q d + + qkv.qAbs q d + qkv.kAbs q d) acc) 0 + +private def forceQkvSumDirect {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (lnLo lnHi : Fin seq → Fin dModel → Dyadic) + (qLimit dLimit : Nat) + (dotLower dotUpper : (Fin dModel → Dyadic) → (Fin dModel → Dyadic) → (Fin dModel → Dyadic) → Dyadic) : + Dyadic := + let qs := (List.finRange seq).take qLimit + let ds := (List.finRange dHead).take dLimit + qs.foldl (fun acc q => + ds.foldl (fun acc' d => + let qLo := dotLower (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d + let qHi := dotUpper (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d + let kLo := dotLower (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d + let kHi := dotUpper (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d + let vLo := dotLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d + let vHi := dotUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d + let qAbs := max |qLo| |qHi| + let kAbs := max |kLo| |kHi| + acc' + qLo + qHi + kLo + kHi + vLo + vHi + qAbs + kAbs) acc) 0 + 
+private def benchHeadInputs {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) (iters : Nat) : IO Unit := do + IO.println "bench: head ln bounds start" + (← IO.getStdout).flush + let lnBounds ← Nfp.IO.timePure "bench: head ln bounds" (fun () => + Sound.headLnBounds inputs) + IO.println "bench: head qkv bounds start" + (← IO.getStdout).flush + let qLimit := + match (← IO.getEnv "NFP_BENCH_QKV_Q") with + | some raw => raw.toNat?.getD seq + | none => seq + let dLimit := + match (← IO.getEnv "NFP_BENCH_QKV_D") with + | some raw => raw.toNat?.getD dHead + | none => dHead + let skipCache := (← IO.getEnv "NFP_BENCH_SKIP_QKV_CACHE").isSome + if !skipCache then + IO.println s!"bench: head qkv bounds (cachedDyadic) start q={qLimit} d={dLimit}" + (← IO.getStdout).flush + let _sumDyadic ← Nfp.IO.timePure "bench: head qkv bounds (cachedDyadic)" (fun () => + forceQkvSumLimited (Sound.headQKVBounds inputs lnBounds.1 lnBounds.2) qLimit dLimit) + pure () + IO.println s!"bench: head qkv bounds (directDyadic) start q={qLimit} d={dLimit}" + (← IO.getStdout).flush + let _sumDirectDyadic ← Nfp.IO.timePure "bench: head qkv bounds (directDyadic)" (fun () => + forceQkvSumDirect inputs lnBounds.1 lnBounds.2 qLimit dLimit + Sound.Bounds.dotIntervalLowerCachedRat Sound.Bounds.dotIntervalUpperCachedRat) + IO.println s!"bench: head qkv bounds (directDyadicNoCache) start q={qLimit} d={dLimit}" + (← IO.getStdout).flush + let _sumDirectDyadicNoCache ← Nfp.IO.timePure "bench: head qkv bounds (directDyadicNoCache)" (fun () => + forceQkvSumDirect inputs lnBounds.1 lnBounds.2 qLimit dLimit + Sound.Bounds.dotIntervalLower Sound.Bounds.dotIntervalUpper) + let qkv := Sound.headQKVBounds inputs lnBounds.1 lnBounds.2 + let qAbs := qkv.qAbs + let kAbs := qkv.kAbs + benchLoop "head inputs dotAbs dotFin" iters (fun () => + (List.finRange seq).foldl (fun acc q => + (List.finRange seq).foldl (fun acc' k => + acc' + Sound.Linear.dotFin dHead (qAbs q) (kAbs k)) acc) 0) 
+ IO.println "bench: head value dir start" + (← IO.getStdout).flush + let dirHead ← Nfp.IO.timePure "bench: head value dir" (fun () => + Sound.headValueDirHead inputs) + let dirArr := Array.ofFn dirHead + have hdir : dirArr.size = dHead := by simp [dirArr] + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + benchLoop "head inputs value bounds (cached)" iters (fun () => + let valsLo : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalLowerCachedRat dirHead (qkv.vLo k) (qkv.vHi k) + let valsHi : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalUpperCachedRat dirHead (qkv.vLo k) (qkv.vHi k) + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + lo + hi) + benchLoop "head inputs value bounds (direct)" iters (fun () => + let valsLo : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalLower dirHead (qkv.vLo k) (qkv.vHi k) + let valsHi : Fin seq → Dyadic := fun k => + Sound.Bounds.dotIntervalUpper dirHead (qkv.vLo k) (qkv.vHi k) + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + lo + hi) + benchLoop "head inputs value bounds (cached, reuse dir)" iters (fun () => + let valsLo : Fin seq → Dyadic := fun k => + let loArr := Array.ofFn (qkv.vLo k) + let hiArr := Array.ofFn (qkv.vHi k) + have hlo : loArr.size = dHead := by simp [loArr] + have hhi : hiArr.size = dHead := by simp [hiArr] + dotIntervalLowerCachedCore dirArr loArr hiArr hdir hlo hhi + let valsHi : Fin seq → Dyadic := fun k => + let loArr := Array.ofFn (qkv.vLo k) + let hiArr := Array.ofFn (qkv.vHi k) + have hlo : loArr.size = dHead := by simp [loArr] + have hhi : hiArr.size = dHead := by simp [hiArr] + dotIntervalUpperCachedCore dirArr loArr hiArr hdir hlo hhi + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + lo + hi) + +/-- Run rational microbenchmarks for several vector sizes. 
-/ +def runDyadicBench (seed : Nat) : IO Unit := do + let baseIters := + match (← IO.getEnv "NFP_BENCH_ITERS") with + | some raw => raw.toNat?.getD 200 + | none => 200 + let sizes : List Nat := [8, 64, 256, 768] + for n in sizes do + let iters := benchItersFor baseIters n + IO.println s!"bench: start n={n} iters={iters}" + benchDotInterval n iters seed + benchDotIntervalCachedParts n iters seed + benchDotFin n iters seed + let headIters := headShapeIters baseIters + IO.println s!"bench: start head-shape iters={headIters}" + benchHeadDotAbs headIters seed + benchHeadValueBounds headIters seed + benchDyadicDivInt headIters seed + +/-- Run benchmarks using a real induction-head input payload. -/ +def runDyadicBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] + (seed : Nat) (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do + let skipSynth := (← IO.getEnv "NFP_BENCH_SKIP_SYNTH").isSome + if !skipSynth then + runDyadicBench seed + let baseIters := + match (← IO.getEnv "NFP_BENCH_ITERS") with + | some raw => raw.toNat?.getD 200 + | none => 200 + let headIters := + match (← IO.getEnv "NFP_BENCH_HEAD_ITERS") with + | some raw => raw.toNat?.getD (headShapeIters baseIters) + | none => headShapeIters baseIters + IO.println s!"bench: start head-inputs iters={headIters}" + (← IO.getStdout).flush + benchHeadInputs inputs headIters + +end IO + +end Nfp diff --git a/Nfp/IO/Checks.lean b/Nfp/IO/Checks.lean new file mode 100644 index 0000000..4b373ab --- /dev/null +++ b/Nfp/IO/Checks.lean @@ -0,0 +1,44 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.Circuit.Cert.ValueRange + +/-! +IO checks for certificates. 
+-/ + +namespace Nfp + +namespace IO + +open Nfp.Circuit + +def checkSoftmaxMargin (seq : Nat) (cert : SoftmaxMarginCert seq) : + IO (Except String Unit) := + match seq with + | 0 => return Except.error "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + let ok := Circuit.checkSoftmaxMarginCert cert + if ok then + return Except.ok () + else + return Except.error "softmax-margin certificate rejected" + +def checkValueRange (seq : Nat) (cert : ValueRangeCert seq) : + IO (Except String Unit) := + match seq with + | 0 => return Except.error "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + let ok := Circuit.checkValueRangeCert cert + if ok then + return Except.ok () + else + return Except.error "value-range certificate rejected" + +end IO + +end Nfp diff --git a/Nfp/IO/Derive.lean b/Nfp/IO/Derive.lean new file mode 100644 index 0000000..639f511 --- /dev/null +++ b/Nfp/IO/Derive.lean @@ -0,0 +1,136 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.List.Range +import Mathlib.Data.Matrix.Mul +import Mathlib.Data.Vector.Defs +import Nfp.IO.NfptPure +import Nfp.IO.Timing +import Nfp.Model.Gpt2 +import Nfp.Sound.Bounds.Transformer +import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadBounds + +/-! +IO derivations that build certificates from model binaries. +-/ + +namespace Nfp + +namespace IO + +open Nfp.Circuit + +def deriveResidualIntervalFromModel (data : ByteArray) (start : Nat) + (header : NfptPure.NfptHeader) (active? 
: Option (Finset (Fin header.seqLen))) : + IO (Except String (ResidualIntervalCert header.modelDim)) := do + if hseq : header.seqLen = 0 then + return Except.error "seq must be positive" + else + have _ : NeZero header.seqLen := ⟨hseq⟩ + if header.modelDim = 0 then + return Except.error "model dim must be positive" + else if 0 < header.layerNormEps then + let embedE ← timePure "read embeddings" (fun () => + NfptPure.readEmbeddings data start header) + match embedE with + | Except.error msg => return Except.error msg + | Except.ok embed => + let layerSlicesE ← timePure "read layer slices" (fun () => + NfptPure.readLayerSlices data start header) + match layerSlicesE with + | Except.error msg => return Except.error msg + | Except.ok layerSlices => + let headLayersE ← timePure "read layer heads" (fun () => + NfptPure.readLayerHeads data start header) + match headLayersE with + | Except.error msg => return Except.error msg + | Except.ok headLayers => + let finalLnE ← timePure "read final layer norm" (fun () => + NfptPure.readFinalLayerNorm data start header) + match finalLnE with + | Except.error msg => return Except.error msg + | Except.ok finalLn => + let layers : + Fin header.numLayers → + Model.Gpt2LayerSlice header.modelDim header.hiddenDim := + fun l => NfptPure.SizedArray.get layerSlices l + let heads : + Fin header.numLayers → Fin header.numHeads → + Model.Gpt2HeadWeights header.modelDim header.headDim := fun l h => + NfptPure.SizedArray.get (NfptPure.SizedArray.get headLayers l) h + let strict? ← IO.getEnv "NFP_TIMING_STRICT" + match strict? with + | some _ => + logTiming "timing strict enabled" + | none => + logTiming "timing strict disabled" + match active? 
with + | some active => + if hactive : active.Nonempty then + logTiming "before transformer stack bounds (active)" + let bounds ← timePhaseThunk "transformer stack bounds (active)" + (fun () => do + let bounds := Sound.Bounds.gpt2ResidualIntervalBoundsActive + active hactive header.layerNormEps layers heads finalLn embed + match strict? with + | some _ => + let forced := + (List.finRange header.modelDim).foldl + (fun acc i => acc + bounds.1 i + bounds.2 i) (0 : Dyadic) + logTiming s!"forced transformer stack sum {forced}" + | none => pure () + return bounds) + logTiming "after transformer stack bounds (active)" + return Except.ok { lo := bounds.1, hi := bounds.2 } + else + logTiming "active set empty; falling back to global bounds" + let base ← timePure "embedding interval bounds" (fun () => + Sound.Bounds.embeddingIntervalBounds embed) + logTiming "before transformer stack bounds" + let stack ← timePhaseThunk "transformer stack bounds" (fun () => do + let stack := Sound.Bounds.transformerStackBounds + (eps := header.layerNormEps) layers heads base.1 base.2 + match strict? 
with + | some _ => + let forced := + (List.finRange header.modelDim).foldl + (fun acc i => acc + stack.1 i + stack.2 i) (0 : Dyadic) + logTiming s!"forced transformer stack sum {forced}" + | none => pure () + return stack) + logTiming "after transformer stack bounds" + logTiming "enter final layer norm bounds" + let bounds ← timePure "final layer norm bounds" (fun () => + Sound.Bounds.layerNormIntervalBounds (eps := header.layerNormEps) + finalLn.gamma finalLn.beta stack.1 stack.2) + logTiming "exit final layer norm bounds" + return Except.ok { lo := bounds.1, hi := bounds.2 } + | none => + let base ← timePure "embedding interval bounds" (fun () => + Sound.Bounds.embeddingIntervalBounds embed) + logTiming "before transformer stack bounds" + let stack ← timePhaseThunk "transformer stack bounds" (fun () => do + let stack := Sound.Bounds.transformerStackBounds + (eps := header.layerNormEps) layers heads base.1 base.2 + match strict? with + | some _ => + let forced := + (List.finRange header.modelDim).foldl + (fun acc i => acc + stack.1 i + stack.2 i) (0 : Dyadic) + logTiming s!"forced transformer stack sum {forced}" + | none => pure () + return stack) + logTiming "after transformer stack bounds" + logTiming "enter final layer norm bounds" + let bounds ← timePure "final layer norm bounds" (fun () => + Sound.Bounds.layerNormIntervalBounds (eps := header.layerNormEps) + finalLn.gamma finalLn.beta stack.1 stack.2) + logTiming "exit final layer norm bounds" + return Except.ok { lo := bounds.1, hi := bounds.2 } + else + return Except.error + s!"layer norm epsilon {header.layerNormEps} must be positive" + +end IO + +end Nfp diff --git a/Nfp/IO/HeadScore.lean b/Nfp/IO/HeadScore.lean new file mode 100644 index 0000000..c7dab88 --- /dev/null +++ b/Nfp/IO/HeadScore.lean @@ -0,0 +1,56 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Core.Basic +import Nfp.Sound.Linear.FinFold + +/-! +Pure helpers for building cached dot-abs functions for head scoring. 
+-/ + +namespace Nfp + +namespace IO + +/-- Build a cached dot-abs function from Q/K absolute bounds using tasks. -/ +def dotAbsFromQKV {seq dHead : Nat} + (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : Fin seq → Fin seq → Dyadic := + let rowTasks : Array (Task (Array Dyadic)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + Array.ofFn (fun k : Fin seq => + Sound.Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)))) + let cache : Array (Array Dyadic) := + Array.ofFn (fun q : Fin seq => + (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get) + fun q k => + let row := cache[q.1]'(by + simp [cache, q.isLt]) + row[k.1]'(by + have hrow : row.size = seq := by + simp [row, cache, rowTasks, Task.spawn] + simp [hrow, k.isLt]) + +theorem dotAbsFromQKV_spec {seq dHead : Nat} + (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : + dotAbsFromQKV qAbs kAbs = + let rowTasks : Array (Task (Array Dyadic)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + Array.ofFn (fun k : Fin seq => + Sound.Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)))) + let cache : Array (Array Dyadic) := + Array.ofFn (fun q : Fin seq => + (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get) + fun q k => + let row := cache[q.1]'(by + simp [cache, q.isLt]) + row[k.1]'(by + have hrow : row.size = seq := by + simp [row, cache, rowTasks, Task.spawn] + simp [hrow, k.isLt]) := rfl + +end IO + +end Nfp diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean new file mode 100644 index 0000000..d18a3e8 --- /dev/null +++ b/Nfp/IO/InductionHead.lean @@ -0,0 +1,821 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.List.Range +import Nfp.IO.Pure +import Nfp.IO.NfptPure +import Nfp.IO.HeadScore +import Nfp.IO.Timing +import Nfp.IO.Util +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadBounds +import Nfp.Sound.Induction.LogitDiff +import Nfp.Sound.Linear.FinFold + +/-! 
+IO helpers for induction-head certificate construction. +-/ + +namespace Nfp + +namespace IO + +open Nfp.Circuit + +private def valueBoundsModeFromEnv : IO (Option Bool) := do + match (← IO.getEnv "NFP_VALUE_BOUNDS_MODE") with + | some "common" => return some true + | some "cached" => return some false + | _ => return none + +/-- Load induction head inputs from disk. -/ +def loadInductionHeadInputs (path : System.FilePath) : + IO (Except String (Sigma (fun seq => + Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead))))) := do + let t0 ← monoUsNow + let data ← IO.FS.readFile path + let t1 ← monoUsNow + IO.println s!"timing: read head input file {t1 - t0} us" + let t2 ← monoUsNow + let parsed := + match Pure.parseInductionHeadInputs data with + | Except.error msg => Except.error msg + | Except.ok v => Except.ok v + let t3 ← monoUsNow + IO.println s!"timing: parse head input file {t3 - t2} us" + return parsed + +private def dyadicToString (x : Dyadic) : String := + toString x.toRat + +private def renderResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) : String := + let header := s!"dim {n}" + let lines := + (List.finRange n).foldr (fun i acc => + s!"lo {i.val} {dyadicToString (c.lo i)}" :: + s!"hi {i.val} {dyadicToString (c.hi i)}" :: acc) [] + String.intercalate "\n" (header :: lines) + +private def emitResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) + (outPath? : Option System.FilePath) : IO Unit := do + let payload := renderResidualIntervalCert c + match outPath? with + | some path => IO.FS.writeFile path (payload ++ "\n") + | none => IO.println payload + +private def buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (outPath? 
: Option System.FilePath) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildHeadOutputIntervalFromHead? inputs with + | none => + IO.eprintln "error: head output interval rejected" + return 2 + | some result => + emitResidualIntervalCert result.cert outPath? + if outPath?.isSome then + let activeCount := result.active.card + IO.println + s!"ok: head output interval built (seq={seq}, dim={dModel}, active={activeCount})" + return 0 + +private def headScoreBoundsFromDotAbsTimed {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dotAbs : Fin seq → Fin seq → Dyadic) : + IO (Sound.HeadScoreBounds seq dModel dHead) := do + let headScoreBoundsFromCachesTimed + (scoreLo scoreHi : Fin seq → Fin seq → Dyadic) : + IO (Sound.HeadScoreBounds seq dModel dHead) := do + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + |inputs.scale| * dotAbs q k + let scoreAbs : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then |inputs.maskValue| else scoreBaseAbs q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let maskedKeys : Fin seq → Finset (Fin seq) := fun q => + if inputs.maskCausal = true then + (otherKeys q).filter (fun k => q < k) + else + (∅ : Finset (Fin seq)) + let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => + (otherKeys q) \ (maskedKeys q) + let maskedGap : Fin seq → Dyadic := fun q => + scoreLo q (inputs.prev q) - inputs.maskValue + let marginTasks : { arr : Array (Task Dyadic) // arr.size = seq } ← + timePhase "head: score margin tasks" <| do + let arr : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if q ∈ inputs.active then + let other := unmaskedKeys q + let masked := 
maskedKeys q + let prev := inputs.prev q + let gapTasks : Array (Task Dyadic) := + Array.ofFn (fun k : Fin seq => + Task.spawn (fun _ => scoreLo q prev - scoreHi q k)) + let gap : Fin seq → Dyadic := fun k => + let row := gapTasks[k.1]'(by + simp [gapTasks, k.isLt]) + row.get + if hunmasked : other.Nonempty then + let unmaskedMin := other.inf' hunmasked gap + if hmasked : masked.Nonempty then + min unmaskedMin (maskedGap q) + else + unmaskedMin + else + if hmasked : masked.Nonempty then + maskedGap q + else + (0 : Dyadic) + else + (0 : Dyadic))) + let hsize : arr.size = seq := by simp [arr] + pure ⟨arr, hsize⟩ + have hmargin : marginTasks.1.size = seq := marginTasks.2 + let marginAt : Fin seq → Dyadic := fun q => + let q' : Fin marginTasks.1.size := Fin.cast hmargin.symm q + (marginTasks.1[q'.1]'(by exact q'.isLt)).get + let epsTasks : { arr : Array (Task Dyadic) // arr.size = seq } ← + timePhase "head: score eps tasks" <| do + let arr : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + let q' : Fin marginTasks.1.size := Fin.cast hmargin.symm q + (marginTasks.1[q'.1]'(by exact q'.isLt)).map (fun m => + if m < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + m))) + let hsize : arr.size = seq := by simp [arr] + pure ⟨arr, hsize⟩ + have heps : epsTasks.1.size = seq := epsTasks.2 + let epsAt : Fin seq → Dyadic := fun q => + let q' : Fin epsTasks.1.size := Fin.cast heps.symm q + (epsTasks.1[q'.1]'(by exact q'.isLt)).get + let margin ← timePhase "head: score margin reduction" <| + pure (if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Dyadic)) + let eps ← timePhase "head: score eps reduction" <| + pure (if margin < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + margin)) + let result : Sound.HeadScoreBounds seq dModel dHead := + { dotAbs := dotAbs + scoreBaseAbs := scoreBaseAbs + scoreAbs := scoreAbs + scoreLo := scoreLo + scoreHi := scoreHi + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps 
} + return result + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + |inputs.scale| * dotAbs q k + let scoreLoRaw : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then + inputs.maskValue + else + -scoreBaseAbs q k + let scoreHiRaw : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then + inputs.maskValue + else + scoreBaseAbs q k + IO.println "timing: head score caches skipped (direct score functions)" + flushStdout + let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then inputs.maskValue else -(|inputs.scale| * dotAbs q k) + let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then inputs.maskValue else |inputs.scale| * dotAbs q k + headScoreBoundsFromCachesTimed scoreLo scoreHi + +private def headScoreBoundsFromQAbsKAbsTimed {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qAbs kAbs : Fin seq → Fin dHead → Dyadic) + (dotAbs : Fin seq → Fin seq → Dyadic) : + IO (Sound.HeadScoreBounds seq dModel dHead) := do + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let maskedKeys : Fin seq → Finset (Fin seq) := fun q => + if inputs.maskCausal = true then + (otherKeys q).filter (fun k => q < k) + else + (∅ : Finset (Fin seq)) + let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => + (otherKeys q) \ (maskedKeys q) + let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then inputs.maskValue else -scoreBaseAbs q k + let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then inputs.maskValue else scoreBaseAbs q k + let kAbsMax : Fin dHead → Dyadic := fun d => + let univ : Finset (Fin seq) := 
Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d) + let dotAbsUpper : Fin seq → Dyadic := fun q => + Sound.Linear.dotFin dHead (fun d => qAbs q d) kAbsMax + let scoreHiUpper : Fin seq → Dyadic := fun q => + max inputs.maskValue (|inputs.scale| * dotAbsUpper q) + let fastGap : Fin seq → Dyadic := fun q => + let prev := inputs.prev q + scoreLo q prev - scoreHiUpper q + let marginTasks : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if q ∈ inputs.active then + let fast := fastGap q + if fast < 0 then + let other := unmaskedKeys q + let maskedSet := maskedKeys q + let exact := + if hunmasked : other.Nonempty then + let unmaskedMin := + other.inf' hunmasked (fun k => scoreLo q (inputs.prev q) - scoreHi q k) + if maskedSet.Nonempty then + min unmaskedMin (scoreLo q (inputs.prev q) - inputs.maskValue) + else + unmaskedMin + else + if maskedSet.Nonempty then + scoreLo q (inputs.prev q) - inputs.maskValue + else + (0 : Dyadic) + exact + else + fast + else + (0 : Dyadic))) + let marginAt : Fin seq → Dyadic := fun q => + (marginTasks[q.1]'(by + simp [marginTasks, q.isLt])).get + let epsTasks : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + (marginTasks[q.1]'(by + simp [marginTasks, q.isLt])).map (fun m => + if m < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + m))) + let epsAt : Fin seq → Dyadic := fun q => + (epsTasks[q.1]'(by + simp [epsTasks, q.isLt])).get + let margin : Dyadic := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Dyadic) + let eps : Dyadic := + if margin < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + margin) + let result : Sound.HeadScoreBounds seq dModel dHead := + { dotAbs := dotAbs + scoreBaseAbs := scoreBaseAbs + scoreAbs := fun q k => if masked q k then |inputs.maskValue| else scoreBaseAbs q k + scoreLo := scoreLo + scoreHi := scoreHi + marginAt := marginAt + epsAt := epsAt + margin := 
margin + eps := eps } + return result + +private def checkInductionHeadInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (minActive? : Option Nat) (minLogitDiff? : Option Dyadic) + (minMargin maxEps : Dyadic) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + logTiming "start: head build induction cert" + IO.println "timing: head build induction cert start" + flushStdout + let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" + let taskBenchEnv ← IO.getEnv "NFP_TASK_BENCH" + if taskBenchEnv.isSome then + let n := taskBenchEnv.bind String.toNat? |>.getD 1000 + Nfp.IO.taskBench n + if verboseTiming.isSome then + IO.println s!"timing: head dims seq={seq} dModel={dModel} dHead={dHead}" + IO.println s!"timing: head active card={inputs.active.card}" + flushStdout + let precompute := (← IO.getEnv "NFP_TIMING_PRECOMPUTE").isSome + if precompute then + IO.println "timing: head ln bounds start" + flushStdout + let lnBounds ← timePure "head: ln bounds" (fun () => + Sound.headLnBounds inputs) + IO.println "timing: head ln bounds done" + flushStdout + IO.println "timing: head qkv bounds start" + flushStdout + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + let qkv ← timePure "head: qkv bounds" (fun () => + Sound.headQKVBounds inputs lnLo lnHi) + IO.println "timing: head qkv bounds done" + flushStdout + if verboseTiming.isSome then + IO.println "timing: head qkv abs force start" + flushStdout + let tAbs0 ← monoUsNow + for q in List.finRange seq do + for d in List.finRange dHead do + let _ := qkv.qAbs q d + let _ := qkv.kAbs q d + pure () + let tAbs1 ← monoUsNow + IO.println s!"timing: head qkv abs force {tAbs1 - tAbs0} us" + flushStdout + IO.println "timing: head score/value bounds spawn start" + flushStdout + let tSpawn0 ← monoUsNow + if verboseTiming.isSome then + IO.println "timing: head score dotAbs tasks start" 
+ flushStdout + let dotAbs ← timePure "head: score dotAbs tasks" (fun () => + dotAbsFromQKV qkv.qAbs qkv.kAbs) + if verboseTiming.isSome then + IO.println "timing: head score dotAbs tasks done" + flushStdout + if verboseTiming.isSome then + IO.println "timing: head score dotAbs force start" + flushStdout + let tForce0 ← monoUsNow + match List.finRange seq with + | [] => + IO.println "timing: head score dotAbs force skipped (empty seq)" + | q :: _ => + match List.finRange seq with + | [] => + IO.println "timing: head score dotAbs force skipped (empty seq)" + | k :: _ => + let _ := dotAbs q k + pure () + let tForce1 ← monoUsNow + IO.println s!"timing: head score dotAbs force {tForce1 - tForce0} us" + flushStdout + let inlineVals := (← IO.getEnv "NFP_TIMING_VALUE_INLINE").isSome + let valueMode? ← valueBoundsModeFromEnv + let useCommon := valueMode?.getD false + let (valsInline?, valsTask?) := + if inlineVals then + let vals := + if useCommon then + Sound.headValueBoundsCommonDen inputs qkv.vLo qkv.vHi + else + Sound.headValueBounds inputs qkv.vLo qkv.vHi + (some vals, none) + else + let task := Task.spawn (fun _ => + if useCommon then + Sound.headValueBoundsCommonDen inputs qkv.vLo qkv.vHi + else + Sound.headValueBounds inputs qkv.vLo qkv.vHi) + (none, some task) + let activeList := (List.finRange seq).filter (fun q => q ∈ inputs.active) + if verboseTiming.isSome then + timeHeadScoreMarginRaw inputs dotAbs activeList + let tSpawn1 ← monoUsNow + IO.println s!"timing: head score/value bounds spawn {tSpawn1 - tSpawn0} us" + flushStdout + let skipScoreBounds := (← IO.getEnv "NFP_TIMING_SKIP_SCORE_BOUNDS").isSome + let scoreOpt ← + if skipScoreBounds then + IO.println "timing: head score bounds skipped" + pure none + else + IO.println "timing: head score bounds from dotAbs start" + flushStdout + let fastMargin := (← IO.getEnv "NFP_TIMING_FAST_MARGIN").isSome + let score ← + if fastMargin then + headScoreBoundsFromQAbsKAbsTimed inputs qkv.qAbs qkv.kAbs dotAbs + else + 
headScoreBoundsFromDotAbsTimed inputs dotAbs + IO.println "timing: head score bounds from dotAbs done" + flushStdout + pure (some score) + match scoreOpt with + | none => pure () + | some score => + if verboseTiming.isSome then + timeHeadScoreSampleGap inputs score + if verboseTiming.isSome then + timeHeadScoreMarginList activeList score + if verboseTiming.isSome then + timeHeadScoreFieldForces score + if verboseTiming.isSome then + IO.println "timing: head score bounds force start" + flushStdout + let tScore0 ← monoUsNow + let _ := score.margin + let _ := score.eps + let tScore1 ← monoUsNow + IO.println s!"timing: head score bounds force {tScore1 - tScore0} us" + flushStdout + if verboseTiming.isSome then + IO.println "timing: head value parts start" + flushStdout + IO.println "timing: head value dirHead start" + flushStdout + let tDir0 ← monoUsNow + let dirHead := Sound.headValueDirHead inputs + match List.finRange dHead with + | [] => + IO.println "timing: head value dirHead forced skipped (empty dHead)" + | d :: _ => + let _ := dirHead d + pure () + let tDir1 ← monoUsNow + IO.println s!"timing: head value dirHead {tDir1 - tDir0} us" + flushStdout + IO.println "timing: head value valsLo start" + flushStdout + let tLo0 ← monoUsNow + let valsLo := Sound.headValueValsLo inputs qkv.vLo qkv.vHi + match List.finRange seq with + | [] => + IO.println "timing: head value valsLo forced skipped (empty seq)" + | k :: _ => + let _ := valsLo k + pure () + let tLo1 ← monoUsNow + IO.println s!"timing: head value valsLo {tLo1 - tLo0} us" + flushStdout + IO.println "timing: head value valsHi start" + flushStdout + let tHi0 ← monoUsNow + let valsHi := Sound.headValueValsHi inputs qkv.vLo qkv.vHi + match List.finRange seq with + | [] => + IO.println "timing: head value valsHi forced skipped (empty seq)" + | k :: _ => + let _ := valsHi k + pure () + let tHi1 ← monoUsNow + IO.println s!"timing: head value valsHi {tHi1 - tHi0} us" + flushStdout + IO.println "timing: head value lo 
start" + flushStdout + let tLo2 ← monoUsNow + let _ := Sound.headValueLo valsLo + let tLo3 ← monoUsNow + IO.println s!"timing: head value lo {tLo3 - tLo2} us" + flushStdout + IO.println "timing: head value hi start" + flushStdout + let tHi2 ← monoUsNow + let _ := Sound.headValueHi valsHi + let tHi3 ← monoUsNow + IO.println s!"timing: head value hi {tHi3 - tHi2} us" + flushStdout + IO.println "timing: head value parts done" + flushStdout + IO.println "timing: head value bounds start" + flushStdout + let tVals0 ← monoUsNow + let vals ← + match valsInline?, valsTask? with + | some vals, _ => + timePure "head: value bounds inline" (fun () => vals) + | none, some valsTask => + timePure "head: value bounds wait" (fun () => valsTask.get) + | none, none => + timePure "head: value bounds inline" (fun () => + Sound.headValueBounds inputs qkv.vLo qkv.vHi) + let tVals1 ← monoUsNow + IO.println s!"timing: head value bounds {tVals1 - tVals0} us" + flushStdout + let certOpt : + Option { c : Sound.InductionHeadCert seq // Sound.InductionHeadCertSound inputs c } ← + timePure "head: build induction cert" (fun () => + match Sound.buildInductionCertFromHead? 
inputs with + | none => none + | some ⟨cert, hcert⟩ => + let _ := cert.active.card + some ⟨cert, hcert⟩) + IO.println "timing: head build induction cert returned" + flushStdout + logTiming "done: head build induction cert" + match certOpt with + | none => + IO.eprintln "error: head inputs rejected" + return 2 + | some ⟨cert, _hcert⟩ => + IO.println "timing: head active count start" + flushStdout + let activeCount := cert.active.card + IO.println "timing: head active count done" + flushStdout + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {dyadicToString cert.margin} \ + below minimum {dyadicToString minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {dyadicToString cert.eps} \ + above maximum {dyadicToString maxEps}" + return 2 + IO.println "timing: head tol start" + flushStdout + let tol := cert.eps * (cert.values.hi - cert.values.lo) + IO.println "timing: head tol done" + flushStdout + logTiming "start: head logit-diff lower bound" + IO.println "timing: head logit-diff lower bound start" + flushStdout + let logitDiffLB? ← timePure "head: logit-diff lower bound" (fun () => + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + cert.values.lo cert.values.hi cert.values.valsLo) + logTiming "done: head logit-diff lower bound" + let effectiveMinLogitDiff := + match minLogitDiff? with + | some v => some v + | none => some (0 : Dyadic) + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + let violation? : Option Dyadic := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? 
with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {dyadicToString logitDiffLB} \ + below minimum {dyadicToString minLogitDiff}" + return 2 + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={dyadicToString tol}, logitDiffLB={dyadicToString logitDiffLB})" + return 0 + +private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (minActive? : Option Nat) (minLogitDiff? : Option Dyadic) + (minMargin maxEps : Dyadic) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + logTiming "start: head build nonvacuous logit-diff" + let res : Option (Sound.InductionLogitLowerBoundNonvacuous inputs) ← + timePure "head: build nonvacuous logit-diff" (fun () => + Sound.buildInductionLogitLowerBoundNonvacuous? inputs) + logTiming "done: head build nonvacuous logit-diff" + match res with + | none => + IO.eprintln "error: nonvacuous logit-diff construction failed" + return 2 + | some result => + let cert := result.base.cert + let logitDiffLB := result.base.lb + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {dyadicToString cert.margin} \ + below minimum {dyadicToString minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {dyadicToString cert.eps} above maximum {dyadicToString maxEps}" + return 2 + match minLogitDiff? 
with + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + IO.eprintln + s!"error: logitDiffLB {dyadicToString logitDiffLB} \ + below minimum {dyadicToString minLogitDiff}" + return 2 + | none => pure () + let tol := cert.eps * (cert.values.hi - cert.values.lo) + IO.println + s!"ok: nonvacuous induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={dyadicToString tol}, logitDiffLB={dyadicToString logitDiffLB})" + return 0 + +/-- Build and check induction certificates from exact head inputs. -/ +def runInductionCertifyHead (inputsPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let parsedInputs ← timePhase "load head inputs" <| + loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps + +/-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ +def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? 
: Option String) : IO UInt32 := do + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let parsedInputs ← timePhase "load head inputs" <| + loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? minMargin maxEps + +/-- Build and check induction certificates from a model binary. -/ +def runInductionCertifyHeadModel (modelPath : System.FilePath) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + logTiming "start: read model file" + IO.println "timing: read model file start" + flushStdout + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps + +/-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ +def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? + let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Dyadic) + let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + logTiming "start: read model file" + IO.println "timing: read model file start" + flushStdout + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? minMargin maxEps + +/-- Build head-output interval bounds from exact head inputs. -/ +def runInductionHeadInterval (inputsPath : System.FilePath) + (outPath? : Option System.FilePath) : IO UInt32 := do + let parsedInputs ← loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + buildHeadOutputIntervalFromInputs inputs outPath? + +/-- Build head-output interval bounds from a model binary. -/ +def runInductionHeadIntervalModel (modelPath : System.FilePath) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (outPath? : Option System.FilePath) : IO UInt32 := do + let data ← IO.FS.readBinFile modelPath + match NfptPure.parseHeader data with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + match + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period? + with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + buildHeadOutputIntervalFromInputs inputs outPath? 
+ +end IO + +end Nfp diff --git a/Nfp/IO/Loaders.lean b/Nfp/IO/Loaders.lean new file mode 100644 index 0000000..7efc4ac --- /dev/null +++ b/Nfp/IO/Loaders.lean @@ -0,0 +1,70 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Pure +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.ResidualBound +import Nfp.Circuit.Cert.ResidualInterval + +/-! +IO loaders for certificates and raw inputs. +-/ + +namespace Nfp + +namespace IO + +open Nfp.Circuit + +/-- Load a softmax-margin certificate from disk. -/ +def loadSoftmaxMarginCert (path : System.FilePath) : + IO (Except String (Sigma SoftmaxMarginCert)) := do + let data ← IO.FS.readFile path + return Pure.parseSoftmaxMarginCert data + +/-- Load raw softmax-margin inputs from disk. -/ +def loadSoftmaxMarginRaw (path : System.FilePath) : + IO (Except String (Sigma Pure.SoftmaxMarginRaw)) := do + let data ← IO.FS.readFile path + return Pure.parseSoftmaxMarginRaw data + +/-- Load a value-range certificate from disk. -/ +def loadValueRangeCert (path : System.FilePath) : + IO (Except String (Sigma ValueRangeCert)) := do + let data ← IO.FS.readFile path + return Pure.parseValueRangeCert data + +/-- Load a downstream linear certificate from disk. -/ +def loadDownstreamLinearCert (path : System.FilePath) : + IO (Except String DownstreamLinearCert) := do + let data ← IO.FS.readFile path + return Pure.parseDownstreamLinearCert data + +/-- Load a downstream matrix payload from disk. -/ +def loadDownstreamMatrixRaw (path : System.FilePath) : + IO (Except String (Sigma (fun rows => + Sigma (fun cols => Pure.DownstreamMatrixRaw rows cols)))) := do + let data ← IO.FS.readFile path + return Pure.parseDownstreamMatrixRaw data + +/-- Load a residual-bound certificate from disk. 
-/ +def loadResidualBoundCert (path : System.FilePath) : + IO (Except String (Sigma (fun n => ResidualBoundCert n))) := do + let data ← IO.FS.readFile path + return Pure.parseResidualBoundCert data + +/-- Load a residual-interval certificate from disk. -/ +def loadResidualIntervalCert (path : System.FilePath) : + IO (Except String (Sigma (fun n => ResidualIntervalCert n))) := do + let data ← IO.FS.readFile path + return Pure.parseResidualIntervalCert data + +/-- Load raw value-range inputs from disk. -/ +def loadValueRangeRaw (path : System.FilePath) : + IO (Except String (Sigma Pure.ValueRangeRaw)) := do + let data ← IO.FS.readFile path + return Pure.parseValueRangeRaw data + +end IO + +end Nfp diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index ad9cea3..96a52ae 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.List.Range +import Nfp.Core.Basic import Nfp.Model.Gpt2 import Nfp.Model.InductionHead import Nfp.Model.InductionPrompt @@ -9,7 +9,7 @@ import Nfp.Model.InductionPrompt /-! Pure parsing utilities for `NFP_BINARY_V1` model files. -These helpers parse headers and extract selected weight slices as exact rationals. +These helpers parse headers and extract selected weight slices as dyadic values. -/ namespace Nfp @@ -35,7 +35,7 @@ structure NfptHeader where /-- Sequence length used in the binary. -/ seqLen : Nat /-- LayerNorm epsilon parameter. -/ - layerNormEps : Rat + layerNormEps : Dyadic /-- Array with a fixed size proof. 
-/ structure SizedArray (n : Nat) (α : Type) where @@ -72,7 +72,7 @@ private def parseInt (s : String) : Except String Int := private def pow10 (k : Nat) : Nat := Nat.pow 10 k -private def parseRatScientific (s : String) : Except String Rat := do +private def parseDyadicScientific (s : String) : Except String Dyadic := do let s := s.trim let (sign, rest) := if s.startsWith "-" then @@ -105,13 +105,13 @@ private def parseRatScientific (s : String) : Except String Rat := do | some e => parseInt e if exp ≥ 0 then let k := Int.toNat exp - pure (base * Rat.ofInt (Int.ofNat (pow10 k))) + pure (dyadicOfRatDown (base * Rat.ofInt (Int.ofNat (pow10 k)))) else let k := Int.toNat (-exp) - pure (base / Rat.ofInt (Int.ofNat (pow10 k))) + pure (dyadicOfRatDown (base / Rat.ofInt (Int.ofNat (pow10 k)))) -private def readHeaderFieldRat (names : List String) (fields : List (String × String)) : - Except String Rat := do +private def readHeaderFieldDyadic (names : List String) (fields : List (String × String)) : + Except String Dyadic := do let rec loop : List String → Option String | [] => none | name :: rest => @@ -119,7 +119,7 @@ private def readHeaderFieldRat (names : List String) (fields : List (String × S | some kv => some kv.2 | none => loop rest match loop names with - | some raw => parseRatScientific raw + | some raw => parseDyadicScientific raw | none => throw s!"missing header field '{String.intercalate "|" names}'" private def sentinelBytes : ByteArray := @@ -169,7 +169,7 @@ def parseHeader (data : ByteArray) : Except String (NfptHeader × Nat) := do let hiddenDim ← readHeaderField "hidden_dim" fields let vocabSize ← readHeaderField "vocab_size" fields let seqLen ← readHeaderField "seq_len" fields - let layerNormEps ← readHeaderFieldRat ["layer_norm_eps", "eps"] fields + let layerNormEps ← readHeaderFieldDyadic ["layer_norm_eps", "eps"] fields if numLayers = 0 then throw "num_layers must be positive" if numHeads = 0 then @@ -200,7 +200,7 @@ private def pow2 (k : Nat) : Nat 
:= private def getBits (n hi lo : Nat) : Nat := (n / pow2 lo) % pow2 (hi - lo + 1) -private def ratOfFloatBits (bits : Nat) : Option Rat := +private def dyadicOfFloatBits (bits : Nat) : Option Dyadic := let signBit := getBits bits 63 63 let expBits := getBits bits 62 52 let mantBits := getBits bits 51 0 @@ -212,19 +212,13 @@ private def ratOfFloatBits (bits : Nat) : Option Rat := some 0 else let num : Int := sign * Int.ofNat mantBits - let denom : Int := Int.ofNat (pow2 1074) - some (Rat.ofInt num / Rat.ofInt denom) + some (Dyadic.ofIntWithPrec num 1074) else let mant := mantBits + pow2 52 let exp := expBits - 1023 let shift : Int := Int.ofNat exp - 52 - let base : Rat := Rat.ofInt (sign * Int.ofNat mant) - if 0 ≤ shift then - let k : Nat := Int.toNat shift - some (base * Rat.ofInt (Int.ofNat (pow2 k))) - else - let k : Nat := Int.toNat (-shift) - some (base / Rat.ofInt (Int.ofNat (pow2 k))) + let prec : Int := -shift + some (Dyadic.ofIntWithPrec (sign * Int.ofNat mant) prec) private def readNatLE (data : ByteArray) (off : Nat) (count : Nat) : Option Nat := if off + count ≤ data.size then @@ -247,9 +241,9 @@ private def readI32 (data : ByteArray) (off : Nat) : Option Int := do else some (Int.ofNat bits - Int.ofNat two32) -private def readF64Rat (data : ByteArray) (off : Nat) : Option Rat := do +private def readF64Dyadic (data : ByteArray) (off : Nat) : Option Dyadic := do let bits ← readNatLE data off 8 - ratOfFloatBits bits + dyadicOfFloatBits bits private def bytesI32 (n : Nat) : Nat := n * 4 @@ -264,13 +258,13 @@ private def sqrtNat? (n : Nat) : Option Nat := else none -private def scaleOfHeadDim (dHead : Nat) : Except String Rat := do +private def scaleOfHeadDim (dHead : Nat) : Except String Dyadic := do match sqrtNat? 
dHead with | some k => if k = 0 then throw "head_dim must be positive" else - pure (Rat.ofInt 1 / Rat.ofInt (Int.ofNat k)) + pure (dyadicOfRatDown (Rat.ofInt 1 / Rat.ofInt (Int.ofNat k))) | none => throw "head_dim must be a perfect square to compute scale" @@ -286,46 +280,108 @@ private def matrixIndex {rows cols : Nat} (i : Fin rows) (j : Fin cols) : Fin (r Nat.mul_le_mul_right cols (Nat.succ_le_iff.mpr i.isLt) ⟨idx, lt_of_lt_of_le hstep hle⟩ -private def readF64List (data : ByteArray) (off : Nat) (count : Nat) : - Except String {xs : List Rat // xs.length = count} := do - match count with - | 0 => return ⟨[], rfl⟩ - | Nat.succ n => - match readF64Rat data off with +private def readF64ListAux (data : ByteArray) (off : Nat) : + Nat → List Dyadic → Except String (List Dyadic) + | 0, acc => Except.ok acc.reverse + | Nat.succ n, acc => + match readF64Dyadic data off with + | some v => readF64ListAux data (off + bytesF64 1) n (v :: acc) + | none => Except.error s!"invalid f64 at offset {off}" + +private theorem readF64ListAux_length (data : ByteArray) : + ∀ (off n : Nat) (acc xs : List Dyadic), + readF64ListAux data off n acc = Except.ok xs → + xs.length = acc.length + n := by + intro off n acc xs h + induction n generalizing off acc xs with + | zero => + have h' := h + simp only [readF64ListAux] at h' + cases h' + simp + | succ n ih => + cases hread : readF64Dyadic data off with + | none => + have h' := h + simp only [readF64ListAux, hread] at h' + cases h' | some v => - let rest ← readF64List data (off + bytesF64 1) n - return ⟨v :: rest.1, by simp [rest.2]⟩ - | none => throw s!"invalid f64 at offset {off}" + have h' := h + simp only [readF64ListAux, hread] at h' + have hlen := ih (off := off + bytesF64 1) (acc := v :: acc) (xs := xs) h' + simpa [List.length, Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hlen -private def readI32List (data : ByteArray) (off : Nat) (count : Nat) : - Except String {xs : List Int // xs.length = count} := do - match count with - 
| 0 => return ⟨[], rfl⟩ - | Nat.succ n => +private def readF64List (data : ByteArray) (off : Nat) (count : Nat) : + Except String {xs : List Dyadic // xs.length = count} := + match h : readF64ListAux data off count [] with + | Except.error msg => Except.error msg + | Except.ok xs => + have hlen : + xs.length = count := by + simpa using readF64ListAux_length (data := data) (off := off) + (n := count) (acc := []) (xs := xs) h + Except.ok ⟨xs, hlen⟩ + +private def readI32ListAux (data : ByteArray) (off : Nat) : + Nat → List Int → Except String (List Int) + | 0, acc => Except.ok acc.reverse + | Nat.succ n, acc => match readI32 data off with + | some v => readI32ListAux data (off + bytesI32 1) n (v :: acc) + | none => Except.error s!"invalid i32 at offset {off}" + +private theorem readI32ListAux_length (data : ByteArray) : + ∀ (off n : Nat) (acc xs : List Int), + readI32ListAux data off n acc = Except.ok xs → + xs.length = acc.length + n := by + intro off n acc xs h + induction n generalizing off acc xs with + | zero => + have h' := h + simp only [readI32ListAux] at h' + cases h' + simp + | succ n ih => + cases hread : readI32 data off with + | none => + have h' := h + simp only [readI32ListAux, hread] at h' + cases h' | some v => - let rest ← readI32List data (off + bytesI32 1) n - return ⟨v :: rest.1, by simp [rest.2]⟩ - | none => throw s!"invalid i32 at offset {off}" + have h' := h + simp only [readI32ListAux, hread] at h' + have hlen := ih (off := off + bytesI32 1) (acc := v :: acc) (xs := xs) h' + simpa [List.length, Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hlen + +private def readI32List (data : ByteArray) (off : Nat) (count : Nat) : + Except String {xs : List Int // xs.length = count} := + match h : readI32ListAux data off count [] with + | Except.error msg => Except.error msg + | Except.ok xs => + have hlen : + xs.length = count := by + simpa using readI32ListAux_length (data := data) (off := off) + (n := count) (acc := []) (xs := xs) h + Except.ok 
⟨xs, hlen⟩ private def readF64Matrix (data : ByteArray) (off : Nat) (rows cols : Nat) : - Except String (Fin rows → Fin cols → Rat) := do + Except String (Fin rows → Fin cols → Dyadic) := do let count := rows * cols let ⟨vals, hlen⟩ ← readF64List data off count let hlen' : vals.length = rows * cols := by simpa using hlen - let mat : Fin rows → Fin cols → Rat := fun i j => + let mat : Fin rows → Fin cols → Dyadic := fun i j => let idx := matrixIndex i j let hidx : idx.val < vals.length := lt_of_lt_of_eq idx.isLt hlen'.symm vals.get ⟨idx.val, hidx⟩ return mat private def readF64Vec (data : ByteArray) (off : Nat) (count : Nat) : - Except String (Fin count → Rat) := do + Except String (Fin count → Dyadic) := do let ⟨vals, hlen⟩ ← readF64List data off count let hlen' : vals.length = count := by simpa using hlen - let vec : Fin count → Rat := fun i => + let vec : Fin count → Dyadic := fun i => vals.get ⟨i.val, lt_of_lt_of_eq i.isLt hlen'.symm⟩ return vec @@ -355,7 +411,7 @@ private def finalLayerNormOffset (h : NfptHeader) : Nat := /-- Read input embeddings stored in the binary. 
-/ def readEmbeddings (data : ByteArray) (start : Nat) (h : NfptHeader) : - Except String (Fin h.seqLen → Fin h.modelDim → Rat) := do + Except String (Fin h.seqLen → Fin h.modelDim → Dyadic) := do let base := start + bytesI32 h.seqLen readF64Matrix data base h.seqLen h.modelDim @@ -404,7 +460,7 @@ def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) let bv ← readF64Vec data offbv h.headDim let offwo := offbv + bytesF64 h.headDim let woRaw ← readF64Matrix data offwo h.headDim h.modelDim - let wo : Fin h.modelDim → Fin h.headDim → Rat := fun i j => woRaw j i + let wo : Fin h.modelDim → Fin h.headDim → Dyadic := fun i j => woRaw j i return { wq := wq, bq := bq, wk := wk, bk := bk, wv := wv, bv := bv, wo := wo } else throw s!"head index out of range: {head}" @@ -413,8 +469,8 @@ def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) private def readLayerAttnBiasLn1 (data : ByteArray) (start : Nat) (h : NfptHeader) (layer : Nat) : - Except String ((Fin h.modelDim → Rat) × (Fin h.modelDim → Rat) × - (Fin h.modelDim → Rat)) := do + Except String ((Fin h.modelDim → Dyadic) × (Fin h.modelDim → Dyadic) × + (Fin h.modelDim → Dyadic)) := do if layer < h.numLayers then let base := start + layerExtrasOffset h layer let attnBias ← readF64Vec data base h.modelDim @@ -514,17 +570,17 @@ def readFinalLayerNorm (data : ByteArray) (start : Nat) (h : NfptHeader) : /-- Read a single unembedding column as exact rationals. 
-/ def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : Nat) : - Except String (Fin h.modelDim → Rat) := do + Except String (Fin h.modelDim → Dyadic) := do if col < h.vocabSize then let base := start + unembedOffset h let rows := List.range h.modelDim let vals ← rows.mapM (fun row => do let off := base + bytesF64 (row * h.vocabSize + col) - match readF64Rat data off with + match readF64Dyadic data off with | some v => pure v | none => throw s!"invalid f64 at offset {off}") if hlen : vals.length = h.modelDim then - let vec : Fin h.modelDim → Rat := fun i => + let vec : Fin h.modelDim → Dyadic := fun i => vals.get ⟨i.val, by simp [hlen]⟩ return vec else @@ -543,7 +599,7 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) let (attnBias, ln1Gamma, ln1Beta) ← readLayerAttnBiasLn1 data start h layer let colTarget ← readUnembedColumn data start h dirTarget let colNegative ← readUnembedColumn data start h dirNegative - let direction : Fin h.modelDim → Rat := fun i => colTarget i - colNegative i + let direction : Fin h.modelDim → Dyadic := fun i => colTarget i - colNegative i let directionSpec : Circuit.DirectionSpec := { target := dirTarget, negative := dirNegative } let active := @@ -571,7 +627,7 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) wo := weights.wo attnBias := attnBias maskCausal := true - maskValue := (-10000 : Rat) + maskValue := (-10000 : Dyadic) directionSpec := directionSpec direction := direction } diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index dfb5151..9bb0758 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -1,1106 +1,12 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Rat -import Mathlib.Data.Finset.Insert -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.Circuit.Cert.ValueRange -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.ResidualBound -import Nfp.Circuit.Cert.ResidualInterval -import 
Nfp.Model.InductionHead +import Nfp.IO.Pure.Basic +import Nfp.IO.Pure.Downstream +import Nfp.IO.Pure.InductionHead +import Nfp.IO.Pure.Residual +import Nfp.IO.Pure.SoftmaxMargin +import Nfp.IO.Pure.ValueRange /-! -Pure parsing helpers for softmax-margin, value-range, and downstream certificates. +Aggregator for pure CLI parsing helpers. -/ - -namespace Nfp - -namespace IO - -namespace Pure - -open Nfp.Circuit - -private def splitWords (line : String) : List String := - line.splitToList (fun c => c = ' ' || c = '\t') |>.filter (· ≠ "") - -private def cleanTokens (line : String) : Option (List String) := - let trimmed := line.trim - if trimmed.isEmpty then - none - else if trimmed.startsWith "#" then - none - else - some (splitWords trimmed) - -private def parseNat (s : String) : Except String Nat := - match s.toNat? with - | some n => Except.ok n - | none => Except.error s!"expected Nat, got '{s}'" - -private def parseInt (s : String) : Except String Int := - match s.toInt? with - | some n => Except.ok n - | none => Except.error s!"expected Int, got '{s}'" - -/-- Parse a rational literal of the form `a` or `a/b`. 
-/ -def parseRat (s : String) : Except String Rat := do - match s.splitOn "/" with - | [num] => - return Rat.ofInt (← parseInt num) - | [num, den] => - let n ← parseInt num - let d ← parseNat den - if d = 0 then - throw s!"invalid rational '{s}': zero denominator" - else - return Rat.ofInt n / Rat.ofInt (Int.ofNat d) - | _ => - throw s!"invalid rational '{s}'" - -private structure SoftmaxMarginParseState (seq : Nat) where - eps : Option Rat - margin : Option Rat - active : Finset (Fin seq) - activeSeen : Bool - prev : Fin seq → Option (Fin seq) - scores : Fin seq → Fin seq → Option Rat - weights : Fin seq → Fin seq → Option Rat - -private def initState (seq : Nat) : SoftmaxMarginParseState seq := - { eps := none - margin := none - active := ∅ - activeSeen := false - prev := fun _ => none - scores := fun _ _ => none - weights := fun _ _ => none } - -private def setPrev {seq : Nat} (st : SoftmaxMarginParseState seq) - (q k : Nat) : Except String (SoftmaxMarginParseState seq) := do - if hq : q < seq then - if hk : k < seq then - let qFin : Fin seq := ⟨q, hq⟩ - let kFin : Fin seq := ⟨k, hk⟩ - match st.prev qFin with - | some _ => - throw s!"duplicate prev entry for q={q}" - | none => - let prev' : Fin seq → Option (Fin seq) := fun q' => - if q' = qFin then - some kFin - else - st.prev q' - return { st with prev := prev' } - else - throw s!"prev index out of range: k={k}" - else - throw s!"prev index out of range: q={q}" - -private def setActive {seq : Nat} (st : SoftmaxMarginParseState seq) - (q : Nat) : Except String (SoftmaxMarginParseState seq) := do - if hq : q < seq then - let qFin : Fin seq := ⟨q, hq⟩ - if qFin ∈ st.active then - throw s!"duplicate active entry for q={q}" - else - return { st with active := insert qFin st.active, activeSeen := true } - else - throw s!"active index out of range: q={q}" - -private def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Rat) - (q k : Nat) (v : Rat) : Except String (Fin seq → Fin seq → Option Rat) := do - if 
hq : q < seq then - if hk : k < seq then - let qFin : Fin seq := ⟨q, hq⟩ - let kFin : Fin seq := ⟨k, hk⟩ - match mat qFin kFin with - | some _ => - throw s!"duplicate matrix entry at ({q}, {k})" - | none => - let mat' : Fin seq → Fin seq → Option Rat := fun q' k' => - if q' = qFin then - if k' = kFin then - some v - else - mat q' k' - else - mat q' k' - return mat' - else - throw s!"index out of range: k={k}" - else - throw s!"index out of range: q={q}" - -private def parseLine {seq : Nat} (st : SoftmaxMarginParseState seq) - (tokens : List String) : Except String (SoftmaxMarginParseState seq) := do - match tokens with - | ["eps", val] => - if st.eps.isSome then - throw "duplicate eps entry" - else - return { st with eps := some (← parseRat val) } - | ["margin", val] => - if st.margin.isSome then - throw "duplicate margin entry" - else - return { st with margin := some (← parseRat val) } - | ["active", q] => - setActive st (← parseNat q) - | ["prev", q, k] => - setPrev st (← parseNat q) (← parseNat k) - | ["score", q, k, val] => - let mat ← setMatrixEntry st.scores (← parseNat q) (← parseNat k) (← parseRat val) - return { st with scores := mat } - | ["weight", q, k, val] => - let mat ← setMatrixEntry st.weights (← parseNat q) (← parseNat k) (← parseRat val) - return { st with weights := mat } - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeState {seq : Nat} (hpos : 0 < seq) - (st : SoftmaxMarginParseState seq) : Except String (SoftmaxMarginCert seq) := do - let eps ← - match st.eps with - | some v => pure v - | none => throw "missing eps entry" - let margin ← - match st.margin with - | some v => pure v - | none => throw "missing margin entry" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then - throw "missing prev entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.scores q k).isSome)) then - throw "missing 
score entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.weights q k).isSome)) then - throw "missing weight entries" - let defaultPrev : Fin seq := ⟨0, hpos⟩ - let prevFun : Fin seq → Fin seq := fun q => - (st.prev q).getD defaultPrev - let scoresFun : Fin seq → Fin seq → Rat := fun q k => - (st.scores q k).getD 0 - let weightsFun : Fin seq → Fin seq → Rat := fun q k => - (st.weights q k).getD 0 - let active := - if st.activeSeen then - st.active - else - (Finset.univ : Finset (Fin seq)).erase defaultPrev - pure - { eps := eps - margin := margin - active := active - prev := prevFun - scores := scoresFun - weights := weightsFun } - -/-- Parse a softmax-margin certificate from a text payload. -/ -def parseSoftmaxMarginCert (input : String) : - Except String (Sigma SoftmaxMarginCert) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let mut seq? : Option Nat := none - for t in tokens do - match t with - | ["seq", n] => - if seq?.isSome then - throw "duplicate seq entry" - else - seq? := some (← parseNat n) - | _ => pure () - let seq ← - match seq? with - | some v => pure v - | none => throw "missing seq entry" - match seq with - | 0 => throw "seq must be positive" - | Nat.succ n => - let seq := Nat.succ n - let hpos : 0 < seq := Nat.succ_pos n - let st0 : SoftmaxMarginParseState seq := initState seq - let st ← tokens.foldlM (fun st t => - match t with - | ["seq", _] => pure st - | _ => parseLine st t) st0 - let cert ← finalizeState hpos st - return ⟨seq, cert⟩ - -/-- Raw softmax-margin payload without `eps`/`margin`. -/ -structure SoftmaxMarginRaw (seq : Nat) where - /-- Active queries for which bounds are required. -/ - active : Finset (Fin seq) - /-- `prev` selector for induction-style attention. -/ - prev : Fin seq → Fin seq - /-- Score matrix entries. -/ - scores : Fin seq → Fin seq → Rat - /-- Attention weight entries. 
-/ - weights : Fin seq → Fin seq → Rat - -private def finalizeRawState {seq : Nat} (hpos : 0 < seq) - (st : SoftmaxMarginParseState seq) : Except String (SoftmaxMarginRaw seq) := do - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then - throw "missing prev entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.scores q k).isSome)) then - throw "missing score entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.weights q k).isSome)) then - throw "missing weight entries" - let defaultPrev : Fin seq := ⟨0, hpos⟩ - let prevFun : Fin seq → Fin seq := fun q => - (st.prev q).getD defaultPrev - let scoresFun : Fin seq → Fin seq → Rat := fun q k => - (st.scores q k).getD 0 - let weightsFun : Fin seq → Fin seq → Rat := fun q k => - (st.weights q k).getD 0 - let active := - if st.activeSeen then - st.active - else - (Finset.univ : Finset (Fin seq)).erase defaultPrev - pure - { active := active - prev := prevFun - scores := scoresFun - weights := weightsFun } - -/-- Parse a raw softmax-margin payload from text (ignores any `eps`/`margin`). -/ -def parseSoftmaxMarginRaw (input : String) : - Except String (Sigma SoftmaxMarginRaw) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let mut seq? : Option Nat := none - for t in tokens do - match t with - | ["seq", n] => - if seq?.isSome then - throw "duplicate seq entry" - else - seq? := some (← parseNat n) - | _ => pure () - let seq ← - match seq? 
with - | some v => pure v - | none => throw "missing seq entry" - match seq with - | 0 => throw "seq must be positive" - | Nat.succ n => - let seq := Nat.succ n - let hpos : 0 < seq := Nat.succ_pos n - let st0 : SoftmaxMarginParseState seq := initState seq - let st ← tokens.foldlM (fun st t => - match t with - | ["seq", _] => pure st - | _ => parseLine st t) st0 - let raw ← finalizeRawState hpos st - return ⟨seq, raw⟩ - -private structure ValueRangeParseState (seq : Nat) where - lo : Option Rat - hi : Option Rat - vals : Fin seq → Option Rat - directionTarget : Option Nat - directionNegative : Option Nat - -private def initValueRangeState (seq : Nat) : ValueRangeParseState seq := - { lo := none - hi := none - vals := fun _ => none - directionTarget := none - directionNegative := none } - -private def setVal {seq : Nat} (st : ValueRangeParseState seq) - (k : Nat) (v : Rat) : Except String (ValueRangeParseState seq) := do - if hk : k < seq then - let kFin : Fin seq := ⟨k, hk⟩ - match st.vals kFin with - | some _ => - throw s!"duplicate value entry for k={k}" - | none => - let vals' : Fin seq → Option Rat := fun k' => - if k' = kFin then - some v - else - st.vals k' - return { st with vals := vals' } - else - throw s!"value index out of range: k={k}" - -private def parseValueLine {seq : Nat} (st : ValueRangeParseState seq) - (tokens : List String) : Except String (ValueRangeParseState seq) := do - match tokens with - | ["lo", val] => - if st.lo.isSome then - throw "duplicate lo entry" - else - return { st with lo := some (← parseRat val) } - | ["hi", val] => - if st.hi.isSome then - throw "duplicate hi entry" - else - return { st with hi := some (← parseRat val) } - | ["val", k, val] => - setVal st (← parseNat k) (← parseRat val) - | ["direction-target", tok] => - if st.directionTarget.isSome then - throw "duplicate direction-target entry" - else - return { st with directionTarget := some (← parseNat tok) } - | ["direction-negative", tok] => - if 
st.directionNegative.isSome then - throw "duplicate direction-negative entry" - else - return { st with directionNegative := some (← parseNat tok) } - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeValueState {seq : Nat} (st : ValueRangeParseState seq) : - Except String (ValueRangeCert seq) := do - let lo ← - match st.lo with - | some v => pure v - | none => throw "missing lo entry" - let hi ← - match st.hi with - | some v => pure v - | none => throw "missing hi entry" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then - throw "missing value entries" - let valsFun : Fin seq → Rat := fun k => - (st.vals k).getD 0 - let direction ← - match st.directionTarget, st.directionNegative with - | none, none => pure none - | some target, some negative => - pure (some { target := target, negative := negative }) - | _, _ => - throw "direction metadata requires both direction-target and direction-negative" - return { lo := lo, hi := hi, vals := valsFun, direction := direction } - -/-- Parse a value-range certificate from a text payload. -/ -def parseValueRangeCert (input : String) : - Except String (Sigma ValueRangeCert) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let mut seq? : Option Nat := none - for t in tokens do - match t with - | ["seq", n] => - if seq?.isSome then - throw "duplicate seq entry" - else - seq? := some (← parseNat n) - | _ => pure () - let seq ← - match seq? with - | some v => pure v - | none => throw "missing seq entry" - match seq with - | 0 => throw "seq must be positive" - | Nat.succ n => - let seq := Nat.succ n - let st0 : ValueRangeParseState seq := initValueRangeState seq - let st ← tokens.foldlM (fun st t => - match t with - | ["seq", _] => pure st - | _ => parseValueLine st t) st0 - let cert ← finalizeValueState st - return ⟨seq, cert⟩ - -/-- Raw value-range payload without `lo`/`hi` bounds. 
-/ -structure ValueRangeRaw (seq : Nat) where - /-- Value entries. -/ - vals : Fin seq → Rat - /-- Optional logit-diff direction metadata. -/ - direction : Option Circuit.DirectionSpec - -private def finalizeValueRawState {seq : Nat} (st : ValueRangeParseState seq) : - Except String (ValueRangeRaw seq) := do - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then - throw "missing value entries" - let valsFun : Fin seq → Rat := fun k => - (st.vals k).getD 0 - let direction ← - match st.directionTarget, st.directionNegative with - | none, none => pure none - | some target, some negative => - pure (some { target := target, negative := negative }) - | _, _ => - throw "direction metadata requires both direction-target and direction-negative" - return { vals := valsFun, direction := direction } - -/-- Parse a raw value-range payload from text (ignores any `lo`/`hi`). -/ -def parseValueRangeRaw (input : String) : - Except String (Sigma ValueRangeRaw) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let mut seq? : Option Nat := none - for t in tokens do - match t with - | ["seq", n] => - if seq?.isSome then - throw "duplicate seq entry" - else - seq? := some (← parseNat n) - | _ => pure () - let seq ← - match seq? 
with - | some v => pure v - | none => throw "missing seq entry" - match seq with - | 0 => throw "seq must be positive" - | Nat.succ n => - let seq := Nat.succ n - let st0 : ValueRangeParseState seq := initValueRangeState seq - let st ← tokens.foldlM (fun st t => - match t with - | ["seq", _] => pure st - | _ => parseValueLine st t) st0 - let raw ← finalizeValueRawState st - return ⟨seq, raw⟩ - -private structure DownstreamLinearParseState where - error : Option Rat - gain : Option Rat - inputBound : Option Rat - -private def initDownstreamLinearState : DownstreamLinearParseState := - { error := none, gain := none, inputBound := none } - -private def parseDownstreamLinearLine (st : DownstreamLinearParseState) - (tokens : List String) : Except String DownstreamLinearParseState := do - match tokens with - | ["error", val] => - if st.error.isSome then - throw "duplicate error entry" - else - return { st with error := some (← parseRat val) } - | ["gain", val] => - if st.gain.isSome then - throw "duplicate gain entry" - else - return { st with gain := some (← parseRat val) } - | ["input-bound", val] => - if st.inputBound.isSome then - throw "duplicate input-bound entry" - else - return { st with inputBound := some (← parseRat val) } - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeDownstreamLinearState (st : DownstreamLinearParseState) : - Except String Circuit.DownstreamLinearCert := do - let error ← - match st.error with - | some v => pure v - | none => throw "missing error entry" - let gain ← - match st.gain with - | some v => pure v - | none => throw "missing gain entry" - let inputBound ← - match st.inputBound with - | some v => pure v - | none => throw "missing input-bound entry" - return { error := error, gain := gain, inputBound := inputBound } - -/-- Parse a downstream linear certificate from a text payload. 
-/ -def parseDownstreamLinearCert (input : String) : - Except String Circuit.DownstreamLinearCert := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let st0 := initDownstreamLinearState - let st ← tokens.foldlM (fun st t => parseDownstreamLinearLine st t) st0 - finalizeDownstreamLinearState st - -private def setVecEntry {n : Nat} (vec : Fin n → Option Rat) - (i : Nat) (v : Rat) : Except String (Fin n → Option Rat) := do - if hi : i < n then - let iFin : Fin n := ⟨i, hi⟩ - match vec iFin with - | some _ => - throw s!"duplicate entry for index={i}" - | none => - let vec' : Fin n → Option Rat := fun i' => - if i' = iFin then - some v - else - vec i' - return vec' - else - throw s!"index out of range: i={i}" - -private def setMatEntry {m n : Nat} (mat : Fin m → Fin n → Option Rat) - (i j : Nat) (v : Rat) : Except String (Fin m → Fin n → Option Rat) := do - if hi : i < m then - if hj : j < n then - let iFin : Fin m := ⟨i, hi⟩ - let jFin : Fin n := ⟨j, hj⟩ - match mat iFin jFin with - | some _ => - throw s!"duplicate entry for indices={i},{j}" - | none => - let mat' : Fin m → Fin n → Option Rat := fun i' j' => - if i' = iFin then - if j' = jFin then - some v - else - mat i' j' - else - mat i' j' - return mat' - else - throw s!"index out of range: j={j}" - else - throw s!"index out of range: i={i}" - -private structure HeadParseState (seq dModel dHead : Nat) where - scale : Option Rat - active : Finset (Fin seq) - activeSeen : Bool - prev : Fin seq → Option (Fin seq) - embed : Fin seq → Fin dModel → Option Rat - lnEps : Option Rat - ln1Gamma : Fin dModel → Option Rat - ln1Beta : Fin dModel → Option Rat - wq : Fin dModel → Fin dHead → Option Rat - bq : Fin dHead → Option Rat - wk : Fin dModel → Fin dHead → Option Rat - bk : Fin dHead → Option Rat - wv : Fin dModel → Fin dHead → Option Rat - bv : Fin dHead → Option Rat - wo : Fin dModel → Fin dHead → Option Rat - attnBias : Fin dModel → Option Rat - maskCausal : Option Bool - maskValue : 
Option Rat - directionTarget : Option Nat - directionNegative : Option Nat - direction : Fin dModel → Option Rat - -private def initHeadState (seq dModel dHead : Nat) : HeadParseState seq dModel dHead := - { scale := none - active := ∅ - activeSeen := false - prev := fun _ => none - embed := fun _ _ => none - lnEps := none - ln1Gamma := fun _ => none - ln1Beta := fun _ => none - wq := fun _ _ => none - bq := fun _ => none - wk := fun _ _ => none - bk := fun _ => none - wv := fun _ _ => none - bv := fun _ => none - wo := fun _ _ => none - attnBias := fun _ => none - maskCausal := none - maskValue := none - directionTarget := none - directionNegative := none - direction := fun _ => none } - -private def setHeadActive {seq dModel dHead : Nat} - (st : HeadParseState seq dModel dHead) (q : Nat) : - Except String (HeadParseState seq dModel dHead) := do - if hq : q < seq then - let qFin : Fin seq := ⟨q, hq⟩ - return { st with active := st.active ∪ {qFin}, activeSeen := true } - else - throw s!"active index out of range: q={q}" - -private def setHeadPrev {seq dModel dHead : Nat} - (st : HeadParseState seq dModel dHead) (q k : Nat) : - Except String (HeadParseState seq dModel dHead) := do - if hq : q < seq then - if hk : k < seq then - let qFin : Fin seq := ⟨q, hq⟩ - let kFin : Fin seq := ⟨k, hk⟩ - match st.prev qFin with - | some _ => - throw s!"duplicate prev entry for q={q}" - | none => - let prev' : Fin seq → Option (Fin seq) := fun q' => - if q' = qFin then - some kFin - else - st.prev q' - return { st with prev := prev' } - else - throw s!"prev index out of range: k={k}" - else - throw s!"prev index out of range: q={q}" - -private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dModel dHead) - (tokens : List String) : Except String (HeadParseState seq dModel dHead) := do - match tokens with - | ["scale", val] => - if st.scale.isSome then - throw "duplicate scale entry" - else - return { st with scale := some (← parseRat val) } - | ["active", q] => 
- setHeadActive st (← parseNat q) - | ["prev", q, k] => - setHeadPrev st (← parseNat q) (← parseNat k) - | ["embed", q, d, val] => - let mat ← setMatEntry st.embed (← parseNat q) (← parseNat d) (← parseRat val) - return { st with embed := mat } - | ["ln_eps", val] => - if st.lnEps.isSome then - throw "duplicate ln_eps entry" - else - return { st with lnEps := some (← parseRat val) } - | ["ln1_gamma", d, val] => - let vec ← setVecEntry st.ln1Gamma (← parseNat d) (← parseRat val) - return { st with ln1Gamma := vec } - | ["ln1_beta", d, val] => - let vec ← setVecEntry st.ln1Beta (← parseNat d) (← parseRat val) - return { st with ln1Beta := vec } - | ["wq", i, j, val] => - let mat ← setMatEntry st.wq (← parseNat i) (← parseNat j) (← parseRat val) - return { st with wq := mat } - | ["bq", j, val] => - let vec ← setVecEntry st.bq (← parseNat j) (← parseRat val) - return { st with bq := vec } - | ["wk", i, j, val] => - let mat ← setMatEntry st.wk (← parseNat i) (← parseNat j) (← parseRat val) - return { st with wk := mat } - | ["bk", j, val] => - let vec ← setVecEntry st.bk (← parseNat j) (← parseRat val) - return { st with bk := vec } - | ["wv", i, j, val] => - let mat ← setMatEntry st.wv (← parseNat i) (← parseNat j) (← parseRat val) - return { st with wv := mat } - | ["bv", j, val] => - let vec ← setVecEntry st.bv (← parseNat j) (← parseRat val) - return { st with bv := vec } - | ["wo", i, j, val] => - let mat ← setMatEntry st.wo (← parseNat i) (← parseNat j) (← parseRat val) - return { st with wo := mat } - | ["attn_bias", d, val] => - let vec ← setVecEntry st.attnBias (← parseNat d) (← parseRat val) - return { st with attnBias := vec } - | ["mask", kind] => - if st.maskCausal.isSome then - throw "duplicate mask entry" - else - match kind with - | "causal" => return { st with maskCausal := some true } - | "none" => return { st with maskCausal := some false } - | _ => throw "mask must be 'causal' or 'none'" - | ["mask_value", val] => - if st.maskValue.isSome then - 
throw "duplicate mask_value entry" - else - return { st with maskValue := some (← parseRat val) } - | ["direction-target", tok] => - if st.directionTarget.isSome then - throw "duplicate direction-target entry" - else - return { st with directionTarget := some (← parseNat tok) } - | ["direction-negative", tok] => - if st.directionNegative.isSome then - throw "duplicate direction-negative entry" - else - return { st with directionNegative := some (← parseNat tok) } - | ["direction", d, val] => - let vec ← setVecEntry st.direction (← parseNat d) (← parseRat val) - return { st with direction := vec } - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) - (st : HeadParseState seq dModel dHead) : - Except String (Model.InductionHeadInputs seq dModel dHead) := do - let scale ← - match st.scale with - | some v => pure v - | none => throw "missing scale entry" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then - throw "missing prev entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.embed q d).isSome)) then - throw "missing embed entries" - let lnEps ← - match st.lnEps with - | some v => pure v - | none => throw "missing ln_eps entry" - if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.ln1Gamma d).isSome) then - throw "missing ln1_gamma entries" - if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.ln1Beta d).isSome) then - throw "missing ln1_beta entries" - if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => - finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wq i j).isSome)) then - throw "missing wq entries" - if !finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.bq j).isSome) then - throw "missing bq entries" - if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => - finsetAll (Finset.univ : Finset (Fin dHead)) 
(fun j => (st.wk i j).isSome)) then - throw "missing wk entries" - if !finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.bk j).isSome) then - throw "missing bk entries" - if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => - finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wv i j).isSome)) then - throw "missing wv entries" - if !finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.bv j).isSome) then - throw "missing bv entries" - if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun i => - finsetAll (Finset.univ : Finset (Fin dHead)) (fun j => (st.wo i j).isSome)) then - throw "missing wo entries" - if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.attnBias d).isSome) then - throw "missing attn_bias entries" - if !finsetAll (Finset.univ : Finset (Fin dModel)) (fun d => (st.direction d).isSome) then - throw "missing direction entries" - let directionSpec ← - match st.directionTarget, st.directionNegative with - | some target, some negative => pure { target := target, negative := negative } - | _, _ => - throw "direction metadata requires both direction-target and direction-negative" - let defaultPrev : Fin seq := ⟨0, hpos⟩ - let prevFun : Fin seq → Fin seq := fun q => - (st.prev q).getD defaultPrev - let embedFun : Fin seq → Fin dModel → Rat := fun q d => - (st.embed q d).getD 0 - let ln1GammaFun : Fin dModel → Rat := fun d => - (st.ln1Gamma d).getD 0 - let ln1BetaFun : Fin dModel → Rat := fun d => - (st.ln1Beta d).getD 0 - let wqFun : Fin dModel → Fin dHead → Rat := fun i j => - (st.wq i j).getD 0 - let bqFun : Fin dHead → Rat := fun j => - (st.bq j).getD 0 - let wkFun : Fin dModel → Fin dHead → Rat := fun i j => - (st.wk i j).getD 0 - let bkFun : Fin dHead → Rat := fun j => - (st.bk j).getD 0 - let wvFun : Fin dModel → Fin dHead → Rat := fun i j => - (st.wv i j).getD 0 - let bvFun : Fin dHead → Rat := fun j => - (st.bv j).getD 0 - let woFun : Fin dModel → Fin dHead → Rat := fun i j => - (st.wo i j).getD 0 - let 
attnBiasFun : Fin dModel → Rat := fun d => - (st.attnBias d).getD 0 - let maskCausal := st.maskCausal.getD false - let maskValue := - match st.maskValue with - | some v => v - | none => if maskCausal then (-10000 : Rat) else 0 - let directionFun : Fin dModel → Rat := fun d => - (st.direction d).getD 0 - let active := - if st.activeSeen then - st.active - else - (Finset.univ : Finset (Fin seq)).erase defaultPrev - pure - { scale := scale - active := active - prev := prevFun - embed := embedFun - lnEps := lnEps - ln1Gamma := ln1GammaFun - ln1Beta := ln1BetaFun - wq := wqFun - bq := bqFun - wk := wkFun - bk := bkFun - wv := wvFun - bv := bvFun - wo := woFun - attnBias := attnBiasFun - maskCausal := maskCausal - maskValue := maskValue - directionSpec := directionSpec - direction := directionFun } - -/-- Parse a raw induction head input payload from text. -/ -def parseInductionHeadInputs (input : String) : - Except String (Sigma (fun seq => - Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead)))) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let mut seq? : Option Nat := none - let mut dModel? : Option Nat := none - let mut dHead? : Option Nat := none - for t in tokens do - match t with - | ["seq", n] => - if seq?.isSome then - throw "duplicate seq entry" - else - seq? := some (← parseNat n) - | ["d_model", n] => - if dModel?.isSome then - throw "duplicate d_model entry" - else - dModel? := some (← parseNat n) - | ["d_head", n] => - if dHead?.isSome then - throw "duplicate d_head entry" - else - dHead? := some (← parseNat n) - | _ => pure () - let seq ← - match seq? with - | some v => pure v - | none => throw "missing seq entry" - let dModel ← - match dModel? with - | some v => pure v - | none => throw "missing d_model entry" - let dHead ← - match dHead? 
with - | some v => pure v - | none => throw "missing d_head entry" - match seq, dModel, dHead with - | 0, _, _ => throw "seq must be positive" - | _, 0, _ => throw "d_model must be positive" - | _, _, 0 => throw "d_head must be positive" - | Nat.succ n, Nat.succ m, Nat.succ h => - let seq := Nat.succ n - let dModel := Nat.succ m - let dHead := Nat.succ h - let hpos : 0 < seq := Nat.succ_pos n - let st0 : HeadParseState seq dModel dHead := initHeadState seq dModel dHead - let st ← tokens.foldlM (fun st t => - match t with - | ["seq", _] => pure st - | ["d_model", _] => pure st - | ["d_head", _] => pure st - | _ => parseHeadLine st t) st0 - let inputs ← finalizeHeadState hpos st - return ⟨seq, ⟨dModel, ⟨dHead, inputs⟩⟩⟩ - -/-- Raw downstream matrix payload with an input bound. -/ -structure DownstreamMatrixRaw (rows cols : Nat) where - /-- Input magnitude bound. -/ - inputBound : Rat - /-- Matrix entries. -/ - entries : Fin rows → Fin cols → Rat - -private structure DownstreamMatrixParseState (rows cols : Nat) where - inputBound : Option Rat - entries : Fin rows → Fin cols → Option Rat - -private def initDownstreamMatrixState (rows cols : Nat) : - DownstreamMatrixParseState rows cols := - { inputBound := none, entries := fun _ _ => none } - -private def setRectEntry {rows cols : Nat} (mat : Fin rows → Fin cols → Option Rat) - (i j : Nat) (v : Rat) : Except String (Fin rows → Fin cols → Option Rat) := do - if hi : i < rows then - if hj : j < cols then - let iFin : Fin rows := ⟨i, hi⟩ - let jFin : Fin cols := ⟨j, hj⟩ - match mat iFin jFin with - | some _ => - throw s!"duplicate matrix entry at ({i}, {j})" - | none => - let mat' : Fin rows → Fin cols → Option Rat := fun i' j' => - if i' = iFin then - if j' = jFin then - some v - else - mat i' j' - else - mat i' j' - return mat' - else - throw s!"index out of range: col={j}" - else - throw s!"index out of range: row={i}" - -private def parseDownstreamMatrixLine {rows cols : Nat} - (st : DownstreamMatrixParseState rows 
cols) (tokens : List String) : - Except String (DownstreamMatrixParseState rows cols) := do - match tokens with - | ["input-bound", val] => - if st.inputBound.isSome then - throw "duplicate input-bound entry" - else - return { st with inputBound := some (← parseRat val) } - | ["w", i, j, val] => - let mat ← setRectEntry st.entries (← parseNat i) (← parseNat j) (← parseRat val) - return { st with entries := mat } - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeDownstreamMatrixState {rows cols : Nat} - (st : DownstreamMatrixParseState rows cols) : - Except String (DownstreamMatrixRaw rows cols) := do - let inputBound ← - match st.inputBound with - | some v => pure v - | none => throw "missing input-bound entry" - if !finsetAll (Finset.univ : Finset (Fin rows)) (fun i => - finsetAll (Finset.univ : Finset (Fin cols)) (fun j => (st.entries i j).isSome)) then - throw "missing matrix entries" - let entries : Fin rows → Fin cols → Rat := fun i j => - (st.entries i j).getD 0 - return { inputBound := inputBound, entries := entries } - -/-- Parse a downstream matrix payload from text. -/ -def parseDownstreamMatrixRaw (input : String) : - Except String (Sigma (fun rows => Sigma (fun cols => DownstreamMatrixRaw rows cols))) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let mut rows? : Option Nat := none - let mut cols? : Option Nat := none - for t in tokens do - match t with - | ["rows", n] => - if rows?.isSome then - throw "duplicate rows entry" - else - rows? := some (← parseNat n) - | ["cols", n] => - if cols?.isSome then - throw "duplicate cols entry" - else - cols? := some (← parseNat n) - | _ => pure () - let rows ← - match rows? with - | some v => pure v - | none => throw "missing rows entry" - let cols ← - match cols? 
with - | some v => pure v - | none => throw "missing cols entry" - match rows, cols with - | 0, _ => throw "rows must be positive" - | _, 0 => throw "cols must be positive" - | Nat.succ r, Nat.succ c => - let rows := Nat.succ r - let cols := Nat.succ c - let st0 := initDownstreamMatrixState rows cols - let st ← tokens.foldlM (fun st t => - match t with - | ["rows", _] => pure st - | ["cols", _] => pure st - | _ => parseDownstreamMatrixLine st t) st0 - let raw ← finalizeDownstreamMatrixState st - return ⟨rows, ⟨cols, raw⟩⟩ - -private structure ResidualBoundParseState (n : Nat) where - bounds : Fin n → Option Rat - -private def initResidualBoundState (n : Nat) : ResidualBoundParseState n := - { bounds := fun _ => none } - -private def setVectorEntry {n : Nat} (bounds : Fin n → Option Rat) - (i : Nat) (v : Rat) : Except String (Fin n → Option Rat) := do - if hi : i < n then - let iFin : Fin n := ⟨i, hi⟩ - match bounds iFin with - | some _ => - throw s!"duplicate bound entry at index {i}" - | none => - let bounds' : Fin n → Option Rat := fun i' => - if i' = iFin then - some v - else - bounds i' - return bounds' - else - throw s!"index out of range: {i}" - -private def parseResidualBoundLine {n : Nat} (st : ResidualBoundParseState n) - (tokens : List String) : Except String (ResidualBoundParseState n) := do - match tokens with - | ["bound", i, val] => - let bounds ← setVectorEntry st.bounds (← parseNat i) (← parseRat val) - return { st with bounds := bounds } - | ["dim", _] => - throw "duplicate dim entry" - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeResidualBoundState {n : Nat} (st : ResidualBoundParseState n) : - Except String (Circuit.ResidualBoundCert n) := do - if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.bounds i).isSome) then - throw "missing bound entries" - let bound : Fin n → Rat := fun i => - (st.bounds i).getD 0 - return { bound := bound } - -/-- Parse a residual-bound payload from text. 
-/ -def parseResidualBoundCert (input : String) : - Except String (Sigma (fun n => Circuit.ResidualBoundCert n)) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - match tokens with - | [] => throw "empty residual-bound payload" - | ["dim", nStr] :: rest => - let n ← parseNat nStr - match n with - | 0 => throw "dim must be positive" - | Nat.succ n' => - let dim := Nat.succ n' - let st0 := initResidualBoundState dim - let st ← rest.foldlM (fun st t => parseResidualBoundLine st t) st0 - let cert ← finalizeResidualBoundState st - return ⟨dim, cert⟩ - | _ => throw "expected header 'dim '" - -private structure ResidualIntervalParseState (n : Nat) where - lo : Fin n → Option Rat - hi : Fin n → Option Rat - -private def initResidualIntervalState (n : Nat) : ResidualIntervalParseState n := - { lo := fun _ => none, hi := fun _ => none } - -private def parseResidualIntervalLine {n : Nat} (st : ResidualIntervalParseState n) - (tokens : List String) : Except String (ResidualIntervalParseState n) := do - match tokens with - | ["lo", i, val] => - let lo ← setVectorEntry st.lo (← parseNat i) (← parseRat val) - return { st with lo := lo } - | ["hi", i, val] => - let hi ← setVectorEntry st.hi (← parseNat i) (← parseRat val) - return { st with hi := hi } - | ["dim", _] => - throw "duplicate dim entry" - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeResidualIntervalState {n : Nat} (st : ResidualIntervalParseState n) : - Except String (Circuit.ResidualIntervalCert n) := do - if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.lo i).isSome) then - throw "missing lo entries" - if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.hi i).isSome) then - throw "missing hi entries" - let lo : Fin n → Rat := fun i => - (st.lo i).getD 0 - let hi : Fin n → Rat := fun i => - (st.hi i).getD 0 - return { lo := lo, hi := hi } - -/-- Parse a residual-interval payload from text. 
-/ -def parseResidualIntervalCert (input : String) : - Except String (Sigma (fun n => Circuit.ResidualIntervalCert n)) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - match tokens with - | [] => throw "empty residual-interval payload" - | ["dim", nStr] :: rest => - let n ← parseNat nStr - match n with - | 0 => throw "dim must be positive" - | Nat.succ n' => - let dim := Nat.succ n' - let st0 := initResidualIntervalState dim - let st ← rest.foldlM (fun st t => parseResidualIntervalLine st t) st0 - let cert ← finalizeResidualIntervalState st - return ⟨dim, cert⟩ - | _ => throw "expected header 'dim '" - -end Pure - -end IO - -end Nfp diff --git a/Nfp/IO/Pure/Basic.lean b/Nfp/IO/Pure/Basic.lean new file mode 100644 index 0000000..d0622cd --- /dev/null +++ b/Nfp/IO/Pure/Basic.lean @@ -0,0 +1,75 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Core.Basic + +/-! +Shared parsing helpers for CLI inputs. +-/ + +namespace Nfp + +namespace IO + +namespace Pure + +/-- Split a line into whitespace-separated tokens. -/ +def splitWords (line : String) : List String := + line.splitToList (fun c => c = ' ' || c = '\t') |>.filter (· ≠ "") + +/-- Drop empty/comment lines and return their whitespace tokens. -/ +def cleanTokens (line : String) : Option (List String) := + let trimmed := line.trim + if trimmed.isEmpty then + none + else if trimmed.startsWith "#" then + none + else + some (splitWords trimmed) + +/-- Parse a nonnegative decimal integer. -/ +def parseNat (s : String) : Except String Nat := do + if s.isEmpty then + throw s!"expected Nat, got '{s}'" + else + let mut acc : Nat := 0 + for c in s.toList do + if c.isDigit then + acc := acc * 10 + c.toNat - '0'.toNat + else + throw s!"expected Nat, got '{s}'" + return acc + +/-- Parse a signed decimal integer. -/ +def parseInt (s : String) : Except String Int := do + if s.isEmpty then + throw s!"expected Int, got '{s}'" + else + match s.toSlice.front? 
with + | some '-' => + let rest := s.drop 1 + let n ← parseNat rest + return -Int.ofNat n + | _ => + let n ← parseNat s + return Int.ofNat n + +/-- Parse a dyadic literal from `a` or `a/b`, rounding down if needed. -/ +def parseDyadic (s : String) : Except String Dyadic := do + match s.splitOn "/" with + | [num] => + return dyadicOfRatDown (Rat.ofInt (← parseInt num)) + | [num, den] => + let n ← parseInt num + let d ← parseNat den + if d = 0 then + throw s!"invalid rational '{s}': zero denominator" + else + return dyadicOfRatDown (Rat.divInt n (Int.ofNat d)) + | _ => + throw s!"invalid rational '{s}'" + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/Downstream.lean b/Nfp/IO/Pure/Downstream.lean new file mode 100644 index 0000000..82aa903 --- /dev/null +++ b/Nfp/IO/Pure/Downstream.lean @@ -0,0 +1,202 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.IO.Pure.Basic + +/-! +Pure parsing helpers for downstream linear and matrix payloads. 
+-/ + +namespace Nfp + +namespace IO + +namespace Pure + +open Nfp.Circuit + +private structure DownstreamLinearParseState where + error : Option Dyadic + gain : Option Dyadic + inputBound : Option Dyadic + +private def initDownstreamLinearState : DownstreamLinearParseState := + { error := none, gain := none, inputBound := none } + +private def parseDownstreamLinearLine (st : DownstreamLinearParseState) + (tokens : List String) : Except String DownstreamLinearParseState := do + match tokens with + | ["error", val] => + if st.error.isSome then + throw "duplicate error entry" + else + return { st with error := some (← parseDyadic val) } + | ["gain", val] => + if st.gain.isSome then + throw "duplicate gain entry" + else + return { st with gain := some (← parseDyadic val) } + | ["input-bound", val] => + if st.inputBound.isSome then + throw "duplicate input-bound entry" + else + return { st with inputBound := some (← parseDyadic val) } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeDownstreamLinearState (st : DownstreamLinearParseState) : + Except String Circuit.DownstreamLinearCert := do + let error ← + match st.error with + | some v => pure v + | none => throw "missing error entry" + let gain ← + match st.gain with + | some v => pure v + | none => throw "missing gain entry" + let inputBound ← + match st.inputBound with + | some v => pure v + | none => throw "missing input-bound entry" + return { error := error, gain := gain, inputBound := inputBound } + +/-- Parse a downstream linear certificate from a text payload. 
-/ +def parseDownstreamLinearCert (input : String) : + Except String Circuit.DownstreamLinearCert := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let st0 := initDownstreamLinearState + let st ← tokens.foldlM (fun st t => parseDownstreamLinearLine st t) st0 + finalizeDownstreamLinearState st + +private def initPrevOpt (n : Nat) : Array (Option (Fin n)) := + Array.replicate n none + +private def initActiveBits (n : Nat) : Array Bool := + Array.replicate n false + +private def activeFromBits {n : Nat} (bits : Array Bool) : Finset (Fin n) := + (Finset.univ : Finset (Fin n)).filter (fun i => bits.getD i.1 false) + +private def arrayAllSome {α : Type} (arr : Array (Option α)) : Bool := + (List.range arr.size).all (fun i => (arr.getD i none).isSome) + +private def matAllSome {α : Type} (mat : Array (Array (Option α))) : Bool := + (List.range mat.size).all (fun i => arrayAllSome (mat.getD i #[])) + +/-- Raw downstream matrix payload with an input bound. -/ +structure DownstreamMatrixRaw (rows cols : Nat) where + /-- Input magnitude bound. -/ + inputBound : Dyadic + /-- Matrix entries. 
-/ + entries : Fin rows → Fin cols → Dyadic + +private structure DownstreamMatrixParseState (rows cols : Nat) where + inputBound : Option Dyadic + entries : Fin rows → Fin cols → Option Dyadic + +private def initDownstreamMatrixState (rows cols : Nat) : + DownstreamMatrixParseState rows cols := + { inputBound := none, entries := fun _ _ => none } + +private def setRectEntry {rows cols : Nat} (mat : Fin rows → Fin cols → Option Dyadic) + (i j : Nat) (v : Dyadic) : Except String (Fin rows → Fin cols → Option Dyadic) := do + if hi : i < rows then + if hj : j < cols then + let iFin : Fin rows := ⟨i, hi⟩ + let jFin : Fin cols := ⟨j, hj⟩ + match mat iFin jFin with + | some _ => + throw s!"duplicate matrix entry at ({i}, {j})" + | none => + let mat' : Fin rows → Fin cols → Option Dyadic := fun i' j' => + if i' = iFin then + if j' = jFin then + some v + else + mat i' j' + else + mat i' j' + return mat' + else + throw s!"index out of range: col={j}" + else + throw s!"index out of range: row={i}" + +private def parseDownstreamMatrixLine {rows cols : Nat} + (st : DownstreamMatrixParseState rows cols) (tokens : List String) : + Except String (DownstreamMatrixParseState rows cols) := do + match tokens with + | ["input-bound", val] => + if st.inputBound.isSome then + throw "duplicate input-bound entry" + else + return { st with inputBound := some (← parseDyadic val) } + | ["w", i, j, val] => + let mat ← setRectEntry st.entries (← parseNat i) (← parseNat j) (← parseDyadic val) + return { st with entries := mat } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeDownstreamMatrixState {rows cols : Nat} + (st : DownstreamMatrixParseState rows cols) : + Except String (DownstreamMatrixRaw rows cols) := do + let inputBound ← + match st.inputBound with + | some v => pure v + | none => throw "missing input-bound entry" + if !finsetAll (Finset.univ : Finset (Fin rows)) (fun i => + finsetAll (Finset.univ : Finset (Fin cols)) (fun j => 
(st.entries i j).isSome)) then + throw "missing matrix entries" + let entries : Fin rows → Fin cols → Dyadic := fun i j => + (st.entries i j).getD 0 + return { inputBound := inputBound, entries := entries } + +/-- Parse a downstream matrix payload from text. -/ +def parseDownstreamMatrixRaw (input : String) : + Except String (Sigma (fun rows => Sigma (fun cols => DownstreamMatrixRaw rows cols))) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let mut rows? : Option Nat := none + let mut cols? : Option Nat := none + for t in tokens do + match t with + | ["rows", n] => + if rows?.isSome then + throw "duplicate rows entry" + else + rows? := some (← parseNat n) + | ["cols", n] => + if cols?.isSome then + throw "duplicate cols entry" + else + cols? := some (← parseNat n) + | _ => pure () + let rows ← + match rows? with + | some v => pure v + | none => throw "missing rows entry" + let cols ← + match cols? with + | some v => pure v + | none => throw "missing cols entry" + match rows, cols with + | 0, _ => throw "rows must be positive" + | _, 0 => throw "cols must be positive" + | Nat.succ r, Nat.succ c => + let rows := Nat.succ r + let cols := Nat.succ c + let st0 := initDownstreamMatrixState rows cols + let st ← tokens.foldlM (fun st t => + match t with + | ["rows", _] => pure st + | ["cols", _] => pure st + | _ => parseDownstreamMatrixLine st t) st0 + let raw ← finalizeDownstreamMatrixState st + return ⟨rows, ⟨cols, raw⟩⟩ + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/InductionHead.lean b/Nfp/IO/Pure/InductionHead.lean new file mode 100644 index 0000000..7881ed1 --- /dev/null +++ b/Nfp/IO/Pure/InductionHead.lean @@ -0,0 +1,25 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Pure.InductionHead.Bytes + +/-! +Parsing helpers for induction-head input payloads. +-/ + +namespace Nfp + +namespace IO + +namespace Pure + +/-- Parse a raw induction head input payload from text. 
-/ +def parseInductionHeadInputs (input : String) : + Except String (Sigma (fun seq => + Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead)))) := do + parseInductionHeadInputsBytes input.toUTF8 + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/InductionHead/Bytes.lean b/Nfp/IO/Pure/InductionHead/Bytes.lean new file mode 100644 index 0000000..c4f188d --- /dev/null +++ b/Nfp/IO/Pure/InductionHead/Bytes.lean @@ -0,0 +1,786 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Finset.Insert +import Nfp.IO.Pure.Basic +import Nfp.Model.InductionHead + +/-! +Parsing helpers for induction-head input payloads from UTF-8 bytes. +-/ + +namespace Nfp + +namespace IO + +namespace Pure + +private def kwSeq : ByteArray := "seq".toUTF8 +private def kwDModel : ByteArray := "d_model".toUTF8 +private def kwDHead : ByteArray := "d_head".toUTF8 +private def kwScale : ByteArray := "scale".toUTF8 +private def kwActive : ByteArray := "active".toUTF8 +private def kwPrev : ByteArray := "prev".toUTF8 +private def kwEmbed : ByteArray := "embed".toUTF8 +private def kwLnEps : ByteArray := "ln_eps".toUTF8 +private def kwLn1Gamma : ByteArray := "ln1_gamma".toUTF8 +private def kwLn1Beta : ByteArray := "ln1_beta".toUTF8 +private def kwAttnBias : ByteArray := "attn_bias".toUTF8 +private def kwMask : ByteArray := "mask".toUTF8 +private def kwMaskValue : ByteArray := "mask_value".toUTF8 +private def kwCausal : ByteArray := "causal".toUTF8 +private def kwNone : ByteArray := "none".toUTF8 +private def kwDirection : ByteArray := "direction".toUTF8 +private def kwDirectionTarget : ByteArray := "direction-target".toUTF8 +private def kwDirectionNegative : ByteArray := "direction-negative".toUTF8 + +private structure ByteToken where + start : Nat + stop : Nat + +private def tokenLen (t : ByteToken) : Nat := + t.stop - t.start + +private def tokenEq (data : ByteArray) (t : ByteToken) (kw : ByteArray) : Bool := Id.run do + if tokenLen t != kw.size 
then + return false + let mut i := 0 + while i < kw.size do + if data.get! (t.start + i) != kw.get! i then + return false + i := i + 1 + return true + +private def parseNatBytesCore (data : ByteArray) (i stop : Nat) (acc : Nat) : + Except String Nat := + if h : i < stop then + let b := data.get! i + if b >= 48 && b <= 57 then + parseNatBytesCore data (i + 1) stop (acc * 10 + (b.toNat - 48)) + else + Except.error "expected Nat" + else + Except.ok acc +termination_by stop - i + +private def parseNatBytesSpec (data : ByteArray) (t : ByteToken) : Except String Nat := + if tokenLen t = 0 then + throw "expected Nat" + else + parseNatBytesCore data t.start t.stop 0 + +private def parseNatBytes (data : ByteArray) (t : ByteToken) : Except String Nat := + parseNatBytesSpec data t + +theorem parseNatBytes_eq_spec (data : ByteArray) (t : ByteToken) : + parseNatBytes data t = parseNatBytesSpec data t := by + rfl + +private def parseIntBytesSpec (data : ByteArray) (t : ByteToken) : Except String Int := do + if tokenLen t = 0 then + throw "expected Int" + let first := data.get! t.start + if first = 45 then + let t' : ByteToken := { start := t.start + 1, stop := t.stop } + let n ← parseNatBytesSpec data t' + return -Int.ofNat n + else + let n ← parseNatBytesSpec data t + return Int.ofNat n + +private def parseIntBytes (data : ByteArray) (t : ByteToken) : Except String Int := + parseIntBytesSpec data t + +theorem parseIntBytes_eq_spec (data : ByteArray) (t : ByteToken) : + parseIntBytes data t = parseIntBytesSpec data t := by + rfl + +private def findSlash (data : ByteArray) (i stop : Nat) : Option Nat := + if h : i < stop then + if data.get! 
i = 47 then + some i + else + findSlash data (i + 1) stop + else + none +termination_by stop - i + +private def parseDyadicBytesSpec (data : ByteArray) (t : ByteToken) : Except String Dyadic := do + match findSlash data t.start t.stop with + | none => + return dyadicOfRatDown (Rat.ofInt (← parseIntBytesSpec data t)) + | some s => + let numTok : ByteToken := { start := t.start, stop := s } + let denTok : ByteToken := { start := s + 1, stop := t.stop } + let n ← parseIntBytesSpec data numTok + let d ← parseNatBytesSpec data denTok + if d = 0 then + throw "invalid rational: zero denominator" + else + return dyadicOfRatDown (Rat.divInt n (Int.ofNat d)) + +private def parseDyadicBytes (data : ByteArray) (t : ByteToken) : Except String Dyadic := + parseDyadicBytesSpec data t + +theorem parseDyadicBytes_eq_spec (data : ByteArray) (t : ByteToken) : + parseDyadicBytes data t = parseDyadicBytesSpec data t := by + rfl + +private def nextLineBounds (data : ByteArray) (start : Nat) : Nat × Nat × Nat := + Id.run do + let mut i := start + let lineStart := start + while i < data.size do + let b := data.get! i + if b == 10 || b == 13 then + let lineEnd := i + let mut j := i + 1 + if b == 13 && j < data.size && data.get! j == 10 then + j := j + 1 + return (j, lineStart, lineEnd) + i := i + 1 + return (data.size, lineStart, data.size) + +private def skipSpaces (data : ByteArray) (i lineEnd : Nat) : Nat := + Id.run do + let mut j := i + while j < lineEnd do + let b := data.get! j + if b == 32 || b == 9 then + j := j + 1 + else + break + return j + +private def readToken (data : ByteArray) (i lineEnd : Nat) : + Option (ByteToken × Nat) := + Id.run do + let j := skipSpaces data i lineEnd + if j >= lineEnd then + return none + let start := j + let mut k := j + while k < lineEnd do + let b := data.get! 
k + if b == 32 || b == 9 then + break + k := k + 1 + return some ({ start := start, stop := k }, k) + +private def expectToken (data : ByteArray) (i lineEnd : Nat) : + Except String (ByteToken × Nat) := do + match readToken data i lineEnd with + | some out => return out + | none => throw "expected token" + +private def ensureNoMoreTokens (data : ByteArray) (i lineEnd : Nat) : + Except String Unit := do + let j := skipSpaces data i lineEnd + if j < lineEnd then + throw "unrecognized line" + +private def parseNatAt (data : ByteArray) (i lineEnd : Nat) : + Except String (Nat × Nat) := do + let (tok, i') ← expectToken data i lineEnd + let n ← parseNatBytes data tok + return (n, i') + +private def parseDyadicAt (data : ByteArray) (i lineEnd : Nat) : + Except String (Dyadic × Nat) := do + let (tok, i') ← expectToken data i lineEnd + let r ← parseDyadicBytes data tok + return (r, i') + +private def setVecEntry (n : Nat) (vec : Array (Option Dyadic)) + (i : Nat) (v : Dyadic) : + Except String (Array (Option Dyadic)) := do + if i < n then + match vec.getD i none with + | some _ => + throw s!"duplicate entry for index={i}" + | none => + let vec' := vec.set! i (some v) + return vec' + else + throw s!"index out of range: i={i}" + +private def setMatEntry (rows cols : Nat) (mat : Array (Array (Option Dyadic))) + (i j : Nat) (v : Dyadic) : Except String (Array (Array (Option Dyadic))) := do + if i < rows then + if j < cols then + let row := mat.getD i #[] + match row.getD j none with + | some _ => + throw s!"duplicate entry for index=({i}, {j})" + | none => + let row' := row.set! j (some v) + let mat' := mat.set! 
i row' + return mat' + else + throw s!"index out of range: j={j}" + else + throw s!"index out of range: i={i}" + +private def initVecOpt (n : Nat) : Array (Option Dyadic) := + Array.replicate n none + +private def initMatOpt (rows cols : Nat) : Array (Array (Option Dyadic)) := + Array.replicate rows (initVecOpt cols) + +private def initPrevOpt (n : Nat) : Array (Option (Fin n)) := + Array.replicate n none + +private def initActiveBits (n : Nat) : Array Bool := + Array.replicate n false + +private def activeFromBits {n : Nat} (bits : Array Bool) : Finset (Fin n) := + (Finset.univ : Finset (Fin n)).filter (fun i => bits.getD i.1 false) + +private def arrayAllSome {α : Type} (arr : Array (Option α)) : Bool := + (List.range arr.size).all (fun i => (arr.getD i none).isSome) + +private def matAllSome {α : Type} (mat : Array (Array (Option α))) : Bool := + (List.range mat.size).all (fun i => arrayAllSome (mat.getD i #[])) + +private structure HeadParseState (seq dModel dHead : Nat) where + scale : Option Dyadic + activeBits : Array Bool + activeSeen : Bool + prev : Array (Option (Fin seq)) + embed : Array (Array (Option Dyadic)) + lnEps : Option Dyadic + ln1Gamma : Array (Option Dyadic) + ln1Beta : Array (Option Dyadic) + wq : Array (Array (Option Dyadic)) + bq : Array (Option Dyadic) + wk : Array (Array (Option Dyadic)) + bk : Array (Option Dyadic) + wv : Array (Array (Option Dyadic)) + bv : Array (Option Dyadic) + wo : Array (Array (Option Dyadic)) + attnBias : Array (Option Dyadic) + maskCausal : Option Bool + maskValue : Option Dyadic + directionTarget : Option Nat + directionNegative : Option Nat + direction : Array (Option Dyadic) + +private def initHeadState (seq dModel dHead : Nat) : HeadParseState seq dModel dHead := + { scale := none + activeBits := initActiveBits seq + activeSeen := false + prev := initPrevOpt seq + embed := initMatOpt seq dModel + lnEps := none + ln1Gamma := initVecOpt dModel + ln1Beta := initVecOpt dModel + wq := initMatOpt dModel dHead + bq 
:= initVecOpt dHead + wk := initMatOpt dModel dHead + bk := initVecOpt dHead + wv := initMatOpt dModel dHead + bv := initVecOpt dHead + wo := initMatOpt dModel dHead + attnBias := initVecOpt dModel + maskCausal := none + maskValue := none + directionTarget := none + directionNegative := none + direction := initVecOpt dModel } + +private def setHeadActive {seq dModel dHead : Nat} + (st : HeadParseState seq dModel dHead) (q : Nat) : + Except String (HeadParseState seq dModel dHead) := do + if q < seq then + return { st with activeBits := st.activeBits.set! q true, activeSeen := true } + else + throw s!"active index out of range: q={q}" + +private def setHeadPrev {seq dModel dHead : Nat} + (st : HeadParseState seq dModel dHead) (q k : Nat) : + Except String (HeadParseState seq dModel dHead) := do + if q < seq then + if hk : k < seq then + let kFin : Fin seq := ⟨k, hk⟩ + match st.prev.getD q none with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + return { st with prev := st.prev.set! 
q (some kFin) } + else + throw s!"prev index out of range: k={k}" + else + throw s!"prev index out of range: q={q}" + +private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dModel dHead) + (tokens : List String) : Except String (HeadParseState seq dModel dHead) := do + match tokens with + | ["scale", val] => + if st.scale.isSome then + throw "duplicate scale entry" + else + return { st with scale := some (← parseDyadic val) } + | ["active", q] => + setHeadActive st (← parseNat q) + | ["prev", q, k] => + setHeadPrev st (← parseNat q) (← parseNat k) + | ["embed", q, d, val] => + let mat ← + setMatEntry seq dModel st.embed (← parseNat q) (← parseNat d) (← parseDyadic val) + return { st with embed := mat } + | ["ln_eps", val] => + if st.lnEps.isSome then + throw "duplicate ln_eps entry" + else + return { st with lnEps := some (← parseDyadic val) } + | ["ln1_gamma", d, val] => + let vec ← setVecEntry dModel st.ln1Gamma (← parseNat d) (← parseDyadic val) + return { st with ln1Gamma := vec } + | ["ln1_beta", d, val] => + let vec ← setVecEntry dModel st.ln1Beta (← parseNat d) (← parseDyadic val) + return { st with ln1Beta := vec } + | ["wq", i, j, val] => + let mat ← + setMatEntry dModel dHead st.wq (← parseNat i) (← parseNat j) (← parseDyadic val) + return { st with wq := mat } + | ["bq", j, val] => + let vec ← setVecEntry dHead st.bq (← parseNat j) (← parseDyadic val) + return { st with bq := vec } + | ["wk", i, j, val] => + let mat ← + setMatEntry dModel dHead st.wk (← parseNat i) (← parseNat j) (← parseDyadic val) + return { st with wk := mat } + | ["bk", j, val] => + let vec ← setVecEntry dHead st.bk (← parseNat j) (← parseDyadic val) + return { st with bk := vec } + | ["wv", i, j, val] => + let mat ← + setMatEntry dModel dHead st.wv (← parseNat i) (← parseNat j) (← parseDyadic val) + return { st with wv := mat } + | ["bv", j, val] => + let vec ← setVecEntry dHead st.bv (← parseNat j) (← parseDyadic val) + return { st with bv := vec } + | ["wo", 
i, j, val] => + let mat ← + setMatEntry dModel dHead st.wo (← parseNat i) (← parseNat j) (← parseDyadic val) + return { st with wo := mat } + | ["attn_bias", d, val] => + let vec ← setVecEntry dModel st.attnBias (← parseNat d) (← parseDyadic val) + return { st with attnBias := vec } + | ["mask", kind] => + if st.maskCausal.isSome then + throw "duplicate mask entry" + else + match kind with + | "causal" => return { st with maskCausal := some true } + | "none" => return { st with maskCausal := some false } + | _ => throw "mask must be 'causal' or 'none'" + | ["mask_value", val] => + if st.maskValue.isSome then + throw "duplicate mask_value entry" + else + return { st with maskValue := some (← parseDyadic val) } + | ["direction-target", tok] => + if st.directionTarget.isSome then + throw "duplicate direction-target entry" + else + return { st with directionTarget := some (← parseNat tok) } + | ["direction-negative", tok] => + if st.directionNegative.isSome then + throw "duplicate direction-negative entry" + else + return { st with directionNegative := some (← parseNat tok) } + | ["direction", d, val] => + let vec ← setVecEntry dModel st.direction (← parseNat d) (← parseDyadic val) + return { st with direction := vec } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) + (st : HeadParseState seq dModel dHead) (lineStart lineEnd : Nat) : + Except String (HeadParseState seq dModel dHead) := do + let i0 := skipSpaces data lineStart lineEnd + if i0 >= lineEnd then + return st + if data.get! i0 = 35 then + return st + match readToken data i0 lineEnd with + | none => return st + | some (t0, i1) => + let len := tokenLen t0 + let b0 := data.get! 
t0.start + match b0 with + | 115 => -- s + if len = kwSeq.size && tokenEq data t0 kwSeq then + return st + else if len = kwScale.size && tokenEq data t0 kwScale then + if st.scale.isSome then + throw "duplicate scale entry" + else + let (t1, i2) ← expectToken data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + return { st with scale := some (← parseDyadicBytes data t1) } + else + throw "unrecognized line" + | 97 => -- a + if len = kwActive.size && tokenEq data t0 kwActive then + let (q, i2) ← parseNatAt data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + setHeadActive st q + else if len = kwAttnBias.size && tokenEq data t0 kwAttnBias then + let (d, i2) ← parseNatAt data i1 lineEnd + let (v, i3) ← parseDyadicAt data i2 lineEnd + ensureNoMoreTokens data i3 lineEnd + let vec ← setVecEntry dModel st.attnBias d v + return { st with attnBias := vec } + else + throw "unrecognized line" + | 112 => -- p + if len = kwPrev.size && tokenEq data t0 kwPrev then + let (q, i2) ← parseNatAt data i1 lineEnd + let (k, i3) ← parseNatAt data i2 lineEnd + ensureNoMoreTokens data i3 lineEnd + setHeadPrev st q k + else + throw "unrecognized line" + | 101 => -- e + if len = kwEmbed.size && tokenEq data t0 kwEmbed then + let (q, i2) ← parseNatAt data i1 lineEnd + let (d, i3) ← parseNatAt data i2 lineEnd + let (v, i4) ← parseDyadicAt data i3 lineEnd + ensureNoMoreTokens data i4 lineEnd + let mat ← setMatEntry seq dModel st.embed q d v + return { st with embed := mat } + else + throw "unrecognized line" + | 108 => -- l + if len = kwLnEps.size && tokenEq data t0 kwLnEps then + if st.lnEps.isSome then + throw "duplicate ln_eps entry" + else + let (v, i2) ← parseDyadicAt data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + return { st with lnEps := some v } + else if len = kwLn1Gamma.size && tokenEq data t0 kwLn1Gamma then + let (d, i2) ← parseNatAt data i1 lineEnd + let (v, i3) ← parseDyadicAt data i2 lineEnd + ensureNoMoreTokens data i3 lineEnd + let vec ← setVecEntry dModel 
st.ln1Gamma d v + return { st with ln1Gamma := vec } + else if len = kwLn1Beta.size && tokenEq data t0 kwLn1Beta then + let (d, i2) ← parseNatAt data i1 lineEnd + let (v, i3) ← parseDyadicAt data i2 lineEnd + ensureNoMoreTokens data i3 lineEnd + let vec ← setVecEntry dModel st.ln1Beta d v + return { st with ln1Beta := vec } + else + throw "unrecognized line" + | 119 => -- w + if len = 2 then + let b1 := data.get! (t0.start + 1) + let (i, i2) ← parseNatAt data i1 lineEnd + let (j, i3) ← parseNatAt data i2 lineEnd + let (v, i4) ← parseDyadicAt data i3 lineEnd + ensureNoMoreTokens data i4 lineEnd + if b1 = 113 then + let mat ← setMatEntry dModel dHead st.wq i j v + return { st with wq := mat } + else if b1 = 107 then + let mat ← setMatEntry dModel dHead st.wk i j v + return { st with wk := mat } + else if b1 = 118 then + let mat ← setMatEntry dModel dHead st.wv i j v + return { st with wv := mat } + else if b1 = 111 then + let mat ← setMatEntry dModel dHead st.wo i j v + return { st with wo := mat } + else + throw "unrecognized line" + else + throw "unrecognized line" + | 98 => -- b + if len = 2 then + let b1 := data.get! 
(t0.start + 1) + let (j, i2) ← parseNatAt data i1 lineEnd + let (v, i3) ← parseDyadicAt data i2 lineEnd + ensureNoMoreTokens data i3 lineEnd + if b1 = 113 then + let vec ← setVecEntry dHead st.bq j v + return { st with bq := vec } + else if b1 = 107 then + let vec ← setVecEntry dHead st.bk j v + return { st with bk := vec } + else if b1 = 118 then + let vec ← setVecEntry dHead st.bv j v + return { st with bv := vec } + else + throw "unrecognized line" + else + throw "unrecognized line" + | 109 => -- m + if len = kwMask.size && tokenEq data t0 kwMask then + if st.maskCausal.isSome then + throw "duplicate mask entry" + else + let (t1, i2) ← expectToken data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + if tokenEq data t1 kwCausal then + return { st with maskCausal := some true } + else if tokenEq data t1 kwNone then + return { st with maskCausal := some false } + else + throw "mask must be 'causal' or 'none'" + else if len = kwMaskValue.size && tokenEq data t0 kwMaskValue then + if st.maskValue.isSome then + throw "duplicate mask_value entry" + else + let (v, i2) ← parseDyadicAt data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + return { st with maskValue := some v } + else + throw "unrecognized line" + | 100 => -- d + if len = kwDModel.size && tokenEq data t0 kwDModel then + return st + else if len = kwDHead.size && tokenEq data t0 kwDHead then + return st + else if len = kwDirection.size && tokenEq data t0 kwDirection then + let (d, i2) ← parseNatAt data i1 lineEnd + let (v, i3) ← parseDyadicAt data i2 lineEnd + ensureNoMoreTokens data i3 lineEnd + let vec ← setVecEntry dModel st.direction d v + return { st with direction := vec } + else if len = kwDirectionTarget.size && tokenEq data t0 kwDirectionTarget then + if st.directionTarget.isSome then + throw "duplicate direction-target entry" + else + let (v, i2) ← parseNatAt data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + return { st with directionTarget := some v } + else if len = 
kwDirectionNegative.size && tokenEq data t0 kwDirectionNegative then + if st.directionNegative.isSome then + throw "duplicate direction-negative entry" + else + let (v, i2) ← parseNatAt data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + return { st with directionNegative := some v } + else + throw "unrecognized line" + | _ => + throw "unrecognized line" + +private def parseHeaderLineBytes (data : ByteArray) (lineStart lineEnd : Nat) + (seq? dModel? dHead? : Option Nat) : + Except String (Option Nat × Option Nat × Option Nat) := do + let i0 := skipSpaces data lineStart lineEnd + if i0 >= lineEnd then + return (seq?, dModel?, dHead?) + if data.get! i0 = 35 then + return (seq?, dModel?, dHead?) + match readToken data i0 lineEnd with + | none => return (seq?, dModel?, dHead?) + | some (t0, i1) => + if tokenEq data t0 kwSeq then + let (v, i2) ← parseNatAt data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + if seq?.isSome then + throw "duplicate seq entry" + else + return (some v, dModel?, dHead?) + else if tokenEq data t0 kwDModel then + let (v, i2) ← parseNatAt data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + if dModel?.isSome then + throw "duplicate d_model entry" + else + return (seq?, some v, dHead?) + else if tokenEq data t0 kwDHead then + let (v, i2) ← parseNatAt data i1 lineEnd + ensureNoMoreTokens data i2 lineEnd + if dHead?.isSome then + throw "duplicate d_head entry" + else + return (seq?, dModel?, some v) + else + return (seq?, dModel?, dHead?) 
+ +private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) + (st : HeadParseState seq dModel dHead) : + Except String (Model.InductionHeadInputs seq dModel dHead) := do + let scale ← + match st.scale with + | some v => pure v + | none => throw "missing scale entry" + if !arrayAllSome st.prev then + throw "missing prev entries" + if !matAllSome st.embed then + throw "missing embed entries" + let lnEps ← + match st.lnEps with + | some v => pure v + | none => throw "missing ln_eps entry" + if !arrayAllSome st.ln1Gamma then + throw "missing ln1_gamma entries" + if !arrayAllSome st.ln1Beta then + throw "missing ln1_beta entries" + if !matAllSome st.wq then + throw "missing wq entries" + if !arrayAllSome st.bq then + throw "missing bq entries" + if !matAllSome st.wk then + throw "missing wk entries" + if !arrayAllSome st.bk then + throw "missing bk entries" + if !matAllSome st.wv then + throw "missing wv entries" + if !arrayAllSome st.bv then + throw "missing bv entries" + if !matAllSome st.wo then + throw "missing wo entries" + if !arrayAllSome st.attnBias then + throw "missing attn_bias entries" + if !arrayAllSome st.direction then + throw "missing direction entries" + let directionSpec ← + match st.directionTarget, st.directionNegative with + | some target, some negative => pure { target := target, negative := negative } + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev.getD q.1 none).getD defaultPrev + let embedArr : Array (Array Dyadic) := + st.embed.map (fun row => row.map (fun v => v.getD 0)) + let ln1GammaArr : Array Dyadic := + st.ln1Gamma.map (fun v => v.getD 0) + let ln1BetaArr : Array Dyadic := + st.ln1Beta.map (fun v => v.getD 0) + let wqArr : Array (Array Dyadic) := + st.wq.map (fun row => row.map (fun v => v.getD 0)) + let bqArr : Array Dyadic := + st.bq.map (fun v => v.getD 0) + let wkArr : Array 
(Array Dyadic) := + st.wk.map (fun row => row.map (fun v => v.getD 0)) + let bkArr : Array Dyadic := + st.bk.map (fun v => v.getD 0) + let wvArr : Array (Array Dyadic) := + st.wv.map (fun row => row.map (fun v => v.getD 0)) + let bvArr : Array Dyadic := + st.bv.map (fun v => v.getD 0) + let woArr : Array (Array Dyadic) := + st.wo.map (fun row => row.map (fun v => v.getD 0)) + let attnBiasArr : Array Dyadic := + st.attnBias.map (fun v => v.getD 0) + let directionArr : Array Dyadic := + st.direction.map (fun v => v.getD 0) + let embedFun : Fin seq → Fin dModel → Dyadic := fun q d => + (embedArr.getD q.1 #[]).getD d.1 0 + let ln1GammaFun : Fin dModel → Dyadic := fun d => + ln1GammaArr.getD d.1 0 + let ln1BetaFun : Fin dModel → Dyadic := fun d => + ln1BetaArr.getD d.1 0 + let wqFun : Fin dModel → Fin dHead → Dyadic := fun i j => + (wqArr.getD i.1 #[]).getD j.1 0 + let bqFun : Fin dHead → Dyadic := fun j => + bqArr.getD j.1 0 + let wkFun : Fin dModel → Fin dHead → Dyadic := fun i j => + (wkArr.getD i.1 #[]).getD j.1 0 + let bkFun : Fin dHead → Dyadic := fun j => + bkArr.getD j.1 0 + let wvFun : Fin dModel → Fin dHead → Dyadic := fun i j => + (wvArr.getD i.1 #[]).getD j.1 0 + let bvFun : Fin dHead → Dyadic := fun j => + bvArr.getD j.1 0 + let woFun : Fin dModel → Fin dHead → Dyadic := fun i j => + (woArr.getD i.1 #[]).getD j.1 0 + let attnBiasFun : Fin dModel → Dyadic := fun d => + attnBiasArr.getD d.1 0 + let maskCausal := st.maskCausal.getD false + let maskValue := + match st.maskValue with + | some v => v + | none => if maskCausal then (-10000 : Dyadic) else 0 + let directionFun : Fin dModel → Dyadic := fun d => + directionArr.getD d.1 0 + let active := + if st.activeSeen then + activeFromBits st.activeBits + else + (Finset.univ : Finset (Fin seq)).erase defaultPrev + pure + { scale := scale + active := active + prev := prevFun + embed := embedFun + lnEps := lnEps + ln1Gamma := ln1GammaFun + ln1Beta := ln1BetaFun + wq := wqFun + bq := bqFun + wk := wkFun + bk := bkFun 
+ wv := wvFun + bv := bvFun + wo := woFun + attnBias := attnBiasFun + maskCausal := maskCausal + maskValue := maskValue + directionSpec := directionSpec + direction := directionFun } + +/-- Parse a raw induction head input payload from UTF-8 bytes. -/ +def parseInductionHeadInputsBytes (data : ByteArray) : + Except String (Sigma (fun seq => + Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead)))) := do + let mut seq? : Option Nat := none + let mut dModel? : Option Nat := none + let mut dHead? : Option Nat := none + let mut i := 0 + let mut afterDims := 0 + let mut haveDims := false + while i < data.size && !haveDims do + let (i', lineStart, lineEnd) := nextLineBounds data i + i := i' + afterDims := i' + let (seqNew, dModelNew, dHeadNew) ← + parseHeaderLineBytes data lineStart lineEnd seq? dModel? dHead? + seq? := seqNew + dModel? := dModelNew + dHead? := dHeadNew + if seq?.isSome && dModel?.isSome && dHead?.isSome then + haveDims := true + let seq ← + match seq? with + | some v => pure v + | none => throw "missing seq entry" + let dModel ← + match dModel? with + | some v => pure v + | none => throw "missing d_model entry" + let dHead ← + match dHead? 
with + | some v => pure v + | none => throw "missing d_head entry" + match seq, dModel, dHead with + | 0, _, _ => throw "seq must be positive" + | _, 0, _ => throw "d_model must be positive" + | _, _, 0 => throw "d_head must be positive" + | Nat.succ n, Nat.succ m, Nat.succ h => + let seq := Nat.succ n + let dModel := Nat.succ m + let dHead := Nat.succ h + let hpos : 0 < seq := Nat.succ_pos n + let st0 : HeadParseState seq dModel dHead := initHeadState seq dModel dHead + let mut st := st0 + let mut j := afterDims + while j < data.size do + let (j', lineStart, lineEnd) := nextLineBounds data j + j := j' + st ← parseHeadLineBytes data st lineStart lineEnd + let inputs ← finalizeHeadState hpos st + return ⟨seq, ⟨dModel, ⟨dHead, inputs⟩⟩⟩ + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/Residual.lean b/Nfp/IO/Pure/Residual.lean new file mode 100644 index 0000000..72fb5c6 --- /dev/null +++ b/Nfp/IO/Pure/Residual.lean @@ -0,0 +1,136 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.ResidualBound +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.IO.Pure.Basic + +/-! +Pure parsing helpers for residual-bound and residual-interval certificates. 
+-/ + +namespace Nfp + +namespace IO + +namespace Pure + +open Nfp.Circuit + +private structure ResidualBoundParseState (n : Nat) where + bounds : Fin n → Option Dyadic + +private def initResidualBoundState (n : Nat) : ResidualBoundParseState n := + { bounds := fun _ => none } + +private def setVectorEntry {n : Nat} (bounds : Fin n → Option Dyadic) + (i : Nat) (v : Dyadic) : Except String (Fin n → Option Dyadic) := do + if hi : i < n then + let iFin : Fin n := ⟨i, hi⟩ + match bounds iFin with + | some _ => + throw s!"duplicate bound entry at index {i}" + | none => + let bounds' : Fin n → Option Dyadic := fun i' => + if i' = iFin then + some v + else + bounds i' + return bounds' + else + throw s!"index out of range: {i}" + +private def parseResidualBoundLine {n : Nat} (st : ResidualBoundParseState n) + (tokens : List String) : Except String (ResidualBoundParseState n) := do + match tokens with + | ["bound", i, val] => + let bounds ← setVectorEntry st.bounds (← parseNat i) (← parseDyadic val) + return { st with bounds := bounds } + | ["dim", _] => + throw "duplicate dim entry" + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeResidualBoundState {n : Nat} (st : ResidualBoundParseState n) : + Except String (Circuit.ResidualBoundCert n) := do + if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.bounds i).isSome) then + throw "missing bound entries" + let bound : Fin n → Dyadic := fun i => + (st.bounds i).getD 0 + return { bound := bound } + +/-- Parse a residual-bound payload from text. 
-/ +def parseResidualBoundCert (input : String) : + Except String (Sigma (fun n => Circuit.ResidualBoundCert n)) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + match tokens with + | [] => throw "empty residual-bound payload" + | ["dim", nStr] :: rest => + let n ← parseNat nStr + match n with + | 0 => throw "dim must be positive" + | Nat.succ n' => + let dim := Nat.succ n' + let st0 := initResidualBoundState dim + let st ← rest.foldlM (fun st t => parseResidualBoundLine st t) st0 + let cert ← finalizeResidualBoundState st + return ⟨dim, cert⟩ + | _ => throw "expected header 'dim '" + +private structure ResidualIntervalParseState (n : Nat) where + lo : Fin n → Option Dyadic + hi : Fin n → Option Dyadic + +private def initResidualIntervalState (n : Nat) : ResidualIntervalParseState n := + { lo := fun _ => none, hi := fun _ => none } + +private def parseResidualIntervalLine {n : Nat} (st : ResidualIntervalParseState n) + (tokens : List String) : Except String (ResidualIntervalParseState n) := do + match tokens with + | ["lo", i, val] => + let lo ← setVectorEntry st.lo (← parseNat i) (← parseDyadic val) + return { st with lo := lo } + | ["hi", i, val] => + let hi ← setVectorEntry st.hi (← parseNat i) (← parseDyadic val) + return { st with hi := hi } + | ["dim", _] => + throw "duplicate dim entry" + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +private def finalizeResidualIntervalState {n : Nat} (st : ResidualIntervalParseState n) : + Except String (Circuit.ResidualIntervalCert n) := do + if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.lo i).isSome) then + throw "missing lo entries" + if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.hi i).isSome) then + throw "missing hi entries" + let lo : Fin n → Dyadic := fun i => + (st.lo i).getD 0 + let hi : Fin n → Dyadic := fun i => + (st.hi i).getD 0 + return { lo := lo, hi := hi } + +/-- Parse a residual-interval payload from text. 
-/ +def parseResidualIntervalCert (input : String) : + Except String (Sigma (fun n => Circuit.ResidualIntervalCert n)) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + match tokens with + | [] => throw "empty residual-interval payload" + | ["dim", nStr] :: rest => + let n ← parseNat nStr + match n with + | 0 => throw "dim must be positive" + | Nat.succ n' => + let dim := Nat.succ n' + let st0 := initResidualIntervalState dim + let st ← rest.foldlM (fun st t => parseResidualIntervalLine st t) st0 + let cert ← finalizeResidualIntervalState st + return ⟨dim, cert⟩ + | _ => throw "expected header 'dim '" + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/SoftmaxMargin.lean b/Nfp/IO/Pure/SoftmaxMargin.lean new file mode 100644 index 0000000..0b12370 --- /dev/null +++ b/Nfp/IO/Pure/SoftmaxMargin.lean @@ -0,0 +1,8 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Pure.SoftmaxMargin.Cert +import Nfp.IO.Pure.SoftmaxMargin.Raw + +/-! +Aggregator for softmax-margin parsing helpers. +-/ diff --git a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean b/Nfp/IO/Pure/SoftmaxMargin/Cert.lean new file mode 100644 index 0000000..29bd6aa --- /dev/null +++ b/Nfp/IO/Pure/SoftmaxMargin/Cert.lean @@ -0,0 +1,79 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.IO.Pure.SoftmaxMargin.Shared + +/-! +Pure parsing helpers for softmax-margin certificates. 
+-/ + +namespace Nfp + +namespace IO + +namespace Pure + +open Nfp.Circuit + +private def finalizeState {seq : Nat} (hpos : 0 < seq) + (st : SoftmaxMargin.ParseState seq) : Except String (SoftmaxMarginCert seq) := do + let eps ← + match st.eps with + | some v => pure v + | none => throw "missing eps entry" + let margin ← + match st.margin with + | some v => pure v + | none => throw "missing margin entry" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then + throw "missing prev entries" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.scores q k).isSome)) then + throw "missing score entries" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.weights q k).isSome)) then + throw "missing weight entries" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev q).getD defaultPrev + let scoresFun : Fin seq → Fin seq → Dyadic := fun q k => + (st.scores q k).getD 0 + let weightsFun : Fin seq → Fin seq → Dyadic := fun q k => + (st.weights q k).getD 0 + let active := + if st.activeSeen then + st.active + else + (Finset.univ : Finset (Fin seq)).erase defaultPrev + pure + { eps := eps + margin := margin + active := active + prev := prevFun + scores := scoresFun + weights := weightsFun } + +/-- Parse a softmax-margin certificate from a text payload. 
-/ +def parseSoftmaxMarginCert (input : String) : + Except String (Sigma SoftmaxMarginCert) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← SoftmaxMargin.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let hpos : 0 < seq := Nat.succ_pos n + let st0 : SoftmaxMargin.ParseState seq := SoftmaxMargin.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => SoftmaxMargin.parseLine st t) st0 + let cert ← finalizeState hpos st + return ⟨seq, cert⟩ + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean b/Nfp/IO/Pure/SoftmaxMargin/Raw.lean new file mode 100644 index 0000000..005d024 --- /dev/null +++ b/Nfp/IO/Pure/SoftmaxMargin/Raw.lean @@ -0,0 +1,80 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.IO.Pure.SoftmaxMargin.Shared + +/-! +Pure parsing helpers for raw softmax-margin inputs. +-/ + +namespace Nfp + +namespace IO + +namespace Pure + +open Nfp.Circuit + +/-- Raw softmax-margin payload without `eps`/`margin`. -/ +structure SoftmaxMarginRaw (seq : Nat) where + /-- Active queries for which bounds are required. -/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Score matrix entries. -/ + scores : Fin seq → Fin seq → Dyadic + /-- Attention weight entries. 
-/ + weights : Fin seq → Fin seq → Dyadic + +private def finalizeRawState {seq : Nat} (hpos : 0 < seq) + (st : SoftmaxMargin.ParseState seq) : Except String (SoftmaxMarginRaw seq) := do + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then + throw "missing prev entries" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.scores q k).isSome)) then + throw "missing score entries" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.weights q k).isSome)) then + throw "missing weight entries" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev q).getD defaultPrev + let scoresFun : Fin seq → Fin seq → Dyadic := fun q k => + (st.scores q k).getD 0 + let weightsFun : Fin seq → Fin seq → Dyadic := fun q k => + (st.weights q k).getD 0 + let active := + if st.activeSeen then + st.active + else + (Finset.univ : Finset (Fin seq)).erase defaultPrev + pure + { active := active + prev := prevFun + scores := scoresFun + weights := weightsFun } + +/-- Parse a raw softmax-margin payload from text (ignores any `eps`/`margin`). 
-/ +def parseSoftmaxMarginRaw (input : String) : + Except String (Sigma SoftmaxMarginRaw) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← SoftmaxMargin.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let hpos : 0 < seq := Nat.succ_pos n + let st0 : SoftmaxMargin.ParseState seq := SoftmaxMargin.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => SoftmaxMargin.parseLine st t) st0 + let raw ← finalizeRawState hpos st + return ⟨seq, raw⟩ + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean new file mode 100644 index 0000000..a316d83 --- /dev/null +++ b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean @@ -0,0 +1,138 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Finset.Insert +import Nfp.IO.Pure.Basic + +/-! +Shared parsing helpers for softmax-margin payloads. 
+-/ + +namespace Nfp + +namespace IO + +namespace Pure + +namespace SoftmaxMargin + +open Nfp.Circuit + +structure ParseState (seq : Nat) where + eps : Option Dyadic + margin : Option Dyadic + active : Finset (Fin seq) + activeSeen : Bool + prev : Fin seq → Option (Fin seq) + scores : Fin seq → Fin seq → Option Dyadic + weights : Fin seq → Fin seq → Option Dyadic + +def initState (seq : Nat) : ParseState seq := + { eps := none + margin := none + active := ∅ + activeSeen := false + prev := fun _ => none + scores := fun _ _ => none + weights := fun _ _ => none } + +def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : Except String (ParseState seq) := do + if hq : q < seq then + if hk : k < seq then + let qFin : Fin seq := ⟨q, hq⟩ + let kFin : Fin seq := ⟨k, hk⟩ + match st.prev qFin with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + let prev' : Fin seq → Option (Fin seq) := fun q' => + if q' = qFin then + some kFin + else + st.prev q' + return { st with prev := prev' } + else + throw s!"prev index out of range: k={k}" + else + throw s!"prev index out of range: q={q}" + +def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : Except String (ParseState seq) := do + if hq : q < seq then + let qFin : Fin seq := ⟨q, hq⟩ + if qFin ∈ st.active then + throw s!"duplicate active entry for q={q}" + else + return { st with active := insert qFin st.active, activeSeen := true } + else + throw s!"active index out of range: q={q}" + +def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Dyadic) + (q k : Nat) (v : Dyadic) : Except String (Fin seq → Fin seq → Option Dyadic) := do + if hq : q < seq then + if hk : k < seq then + let qFin : Fin seq := ⟨q, hq⟩ + let kFin : Fin seq := ⟨k, hk⟩ + match mat qFin kFin with + | some _ => + throw s!"duplicate matrix entry at ({q}, {k})" + | none => + let mat' : Fin seq → Fin seq → Option Dyadic := fun q' k' => + if q' = qFin then + if k' = kFin then + some v + else + mat q' k' + else + mat q' k' + 
return mat' + else + throw s!"index out of range: k={k}" + else + throw s!"index out of range: q={q}" + +def parseLine {seq : Nat} (st : ParseState seq) + (tokens : List String) : Except String (ParseState seq) := do + match tokens with + | ["eps", val] => + if st.eps.isSome then + throw "duplicate eps entry" + else + return { st with eps := some (← parseDyadic val) } + | ["margin", val] => + if st.margin.isSome then + throw "duplicate margin entry" + else + return { st with margin := some (← parseDyadic val) } + | ["active", q] => + setActive st (← parseNat q) + | ["prev", q, k] => + setPrev st (← parseNat q) (← parseNat k) + | ["score", q, k, val] => + let mat ← setMatrixEntry st.scores (← parseNat q) (← parseNat k) (← parseDyadic val) + return { st with scores := mat } + | ["weight", q, k, val] => + let mat ← setMatrixEntry st.weights (← parseNat q) (← parseNat k) (← parseDyadic val) + return { st with weights := mat } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +def parseSeq (tokens : List (List String)) : Except String Nat := do + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + match seq? with + | some v => pure v + | none => throw "missing seq entry" + +end SoftmaxMargin + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/ValueRange.lean b/Nfp/IO/Pure/ValueRange.lean new file mode 100644 index 0000000..a6053d4 --- /dev/null +++ b/Nfp/IO/Pure/ValueRange.lean @@ -0,0 +1,8 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Pure.ValueRange.Cert +import Nfp.IO.Pure.ValueRange.Raw + +/-! +Aggregator for value-range parsing helpers. 
+-/ diff --git a/Nfp/IO/Pure/ValueRange/Cert.lean b/Nfp/IO/Pure/ValueRange/Cert.lean new file mode 100644 index 0000000..87edb9b --- /dev/null +++ b/Nfp/IO/Pure/ValueRange/Cert.lean @@ -0,0 +1,63 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.ValueRange +import Nfp.IO.Pure.ValueRange.Shared + +/-! +Pure parsing helpers for value-range certificates. +-/ + +namespace Nfp + +namespace IO + +namespace Pure + +open Nfp.Circuit + +private def finalizeValueState {seq : Nat} (st : ValueRange.ParseState seq) : + Except String (ValueRangeCert seq) := do + let lo ← + match st.lo with + | some v => pure v + | none => throw "missing lo entry" + let hi ← + match st.hi with + | some v => pure v + | none => throw "missing hi entry" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then + throw "missing value entries" + let valsFun : Fin seq → Dyadic := fun k => + (st.vals k).getD 0 + let direction ← + match st.directionTarget, st.directionNegative with + | none, none => pure none + | some target, some negative => + pure (some { target := target, negative := negative }) + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + return { lo := lo, hi := hi, vals := valsFun, direction := direction } + +/-- Parse a value-range certificate from a text payload. 
-/ +def parseValueRangeCert (input : String) : + Except String (Sigma ValueRangeCert) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← ValueRange.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let st0 : ValueRange.ParseState seq := ValueRange.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => ValueRange.parseLine st t) st0 + let cert ← finalizeValueState st + return ⟨seq, cert⟩ + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/ValueRange/Raw.lean b/Nfp/IO/Pure/ValueRange/Raw.lean new file mode 100644 index 0000000..f7c74fc --- /dev/null +++ b/Nfp/IO/Pure/ValueRange/Raw.lean @@ -0,0 +1,62 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.ValueRange +import Nfp.IO.Pure.ValueRange.Shared + +/-! +Pure parsing helpers for raw value-range inputs. +-/ + +namespace Nfp + +namespace IO + +namespace Pure + +open Nfp.Circuit + +/-- Raw value-range payload without `lo`/`hi` bounds. -/ +structure ValueRangeRaw (seq : Nat) where + /-- Value entries. -/ + vals : Fin seq → Dyadic + /-- Optional logit-diff direction metadata. -/ + direction : Option Circuit.DirectionSpec + +private def finalizeValueRawState {seq : Nat} (st : ValueRange.ParseState seq) : + Except String (ValueRangeRaw seq) := do + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then + throw "missing value entries" + let valsFun : Fin seq → Dyadic := fun k => + (st.vals k).getD 0 + let direction ← + match st.directionTarget, st.directionNegative with + | none, none => pure none + | some target, some negative => + pure (some { target := target, negative := negative }) + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + return { vals := valsFun, direction := direction } + +/-- Parse a raw value-range payload from text (ignores any `lo`/`hi`). 
-/ +def parseValueRangeRaw (input : String) : + Except String (Sigma ValueRangeRaw) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← ValueRange.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let st0 : ValueRange.ParseState seq := ValueRange.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => ValueRange.parseLine st t) st0 + let raw ← finalizeValueRawState st + return ⟨seq, raw⟩ + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Pure/ValueRange/Shared.lean b/Nfp/IO/Pure/ValueRange/Shared.lean new file mode 100644 index 0000000..b653df2 --- /dev/null +++ b/Nfp/IO/Pure/ValueRange/Shared.lean @@ -0,0 +1,103 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Circuit.Cert.ValueRange +import Nfp.IO.Pure.Basic + +/-! +Shared parsing helpers for value-range payloads. +-/ + +namespace Nfp + +namespace IO + +namespace Pure + +namespace ValueRange + +open Nfp.Circuit + +structure ParseState (seq : Nat) where + lo : Option Dyadic + hi : Option Dyadic + vals : Fin seq → Option Dyadic + directionTarget : Option Nat + directionNegative : Option Nat + + +def initState (seq : Nat) : ParseState seq := + { lo := none + hi := none + vals := fun _ => none + directionTarget := none + directionNegative := none } + + +def setVal {seq : Nat} (st : ParseState seq) (k : Nat) (v : Dyadic) : + Except String (ParseState seq) := do + if hk : k < seq then + let kFin : Fin seq := ⟨k, hk⟩ + match st.vals kFin with + | some _ => + throw s!"duplicate value entry for k={k}" + | none => + let vals' : Fin seq → Option Dyadic := fun k' => + if k' = kFin then + some v + else + st.vals k' + return { st with vals := vals' } + else + throw s!"value index out of range: k={k}" + + +def parseLine {seq : Nat} (st : ParseState seq) + (tokens : List String) : Except String (ParseState seq) := do + match tokens with + | ["lo", val] => + if 
st.lo.isSome then + throw "duplicate lo entry" + else + return { st with lo := some (← parseDyadic val) } + | ["hi", val] => + if st.hi.isSome then + throw "duplicate hi entry" + else + return { st with hi := some (← parseDyadic val) } + | ["val", k, val] => + setVal st (← parseNat k) (← parseDyadic val) + | ["direction-target", tok] => + if st.directionTarget.isSome then + throw "duplicate direction-target entry" + else + return { st with directionTarget := some (← parseNat tok) } + | ["direction-negative", tok] => + if st.directionNegative.isSome then + throw "duplicate direction-negative entry" + else + return { st with directionNegative := some (← parseNat tok) } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + + +def parseSeq (tokens : List (List String)) : Except String Nat := do + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + match seq? with + | some v => pure v + | none => throw "missing seq entry" + +end ValueRange + +end Pure + +end IO + +end Nfp diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean new file mode 100644 index 0000000..b510e7e --- /dev/null +++ b/Nfp/IO/Timing.lean @@ -0,0 +1,204 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.List.Range +import Nfp.Model.InductionHead +import Nfp.Sound.Induction.HeadBounds + +/-! +Small IO helpers for benchmarking task overhead and profiling slow phases. +-/ + +namespace Nfp + +namespace IO + +open Sound + +/-- Current monotonic time in microseconds. -/ +def monoUsNow : IO Nat := do + let t ← IO.monoNanosNow + return t / 1000 + +/-- Append a timing log line to `NFP_TIMING_LOG` when set. 
-/ +def logTiming (line : String) : IO Unit := do + match (← IO.getEnv "NFP_TIMING_LOG") with + | some path => + let h ← IO.FS.Handle.mk (System.FilePath.mk path) IO.FS.Mode.append + h.putStr (line ++ "\n") + h.flush + | none => pure () + +/-- Time an IO phase and print the duration in microseconds. -/ +def timePhase {α : Type} (label : String) (act : IO α) : IO α := do + logTiming s!"start: {label}" + let t0 ← monoUsNow + let res ← act + let t1 ← monoUsNow + logTiming s!"done: {label} {t1 - t0} us" + IO.println s!"timing: {label} {t1 - t0} us" + return res + +/-- Time an IO phase supplied as a thunk and print the duration in microseconds. -/ +def timePhaseThunk {α : Type} (label : String) (act : Unit → IO α) : IO α := do + logTiming s!"start: {label}" + let t0 ← monoUsNow + let res ← act () + let t1 ← monoUsNow + logTiming s!"done: {label} {t1 - t0} us" + IO.println s!"timing: {label} {t1 - t0} us" + return res + +/-- Time a pure thunk and print the duration in microseconds. -/ +def timePure {α : Type} (label : String) (f : Unit → α) : IO α := do + logTiming s!"start: {label}" + let t0 ← monoUsNow + let res := f () + let t1 ← monoUsNow + logTiming s!"done: {label} {t1 - t0} us" + IO.println s!"timing: {label} {t1 - t0} us" + return res + +/-- Flush stdout immediately for interleaved timing output. -/ +def flushStdout : IO Unit := do + let h ← IO.getStdout + h.flush + +/-- Measure task spawn/get overhead on this machine. -/ +def taskBench (n : Nat) : IO Unit := do + if n = 0 then + IO.println "timing: task bench skipped (n=0)" + return + let t0 ← monoUsNow + let tasks := (List.range n).map (fun _ => Task.spawn (fun _ => ())) + for t in tasks do + let _ := t.get + pure () + let t1 ← monoUsNow + let total := t1 - t0 + let avg := total / n + IO.println s!"timing: task bench n={n} total={total} us avg={avg} us" + +/-- Force a sample score-gap computation for timing. 
-/ +def timeHeadScoreSampleGap {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do + IO.println "timing: head score sample gap start" + (← IO.getStdout).flush + let t0 ← monoUsNow + match List.finRange seq with + | [] => + IO.println "timing: head score sample gap skipped (empty seq)" + | q :: _ => + let _ := score.scoreLo q (inputs.prev q) + let _ := score.scoreHi q (inputs.prev q) + let _ := score.scoreLo q (inputs.prev q) - score.scoreHi q (inputs.prev q) + pure () + let t1 ← monoUsNow + IO.println s!"timing: head score sample gap {t1 - t0} us" + (← IO.getStdout).flush + +/-- Force marginAt evaluation over the active list for timing. -/ +def timeHeadScoreMarginList {seq dModel dHead : Nat} + (activeList : List (Fin seq)) + (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do + IO.println "timing: head score marginAt list start" + (← IO.getStdout).flush + let t0 ← monoUsNow + for q in activeList do + let _ := score.marginAt q + pure () + let t1 ← monoUsNow + IO.println s!"timing: head score marginAt list {t1 - t0} us" + (← IO.getStdout).flush + +/-- Force marginAt evaluation without constructing the full score bounds record. 
-/ +def timeHeadScoreMarginRaw {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dotAbs : Fin seq → Fin seq → Dyadic) + (activeList : List (Fin seq)) : IO Unit := do + IO.println "timing: head score marginRaw list start" + (← IO.getStdout).flush + let t0 ← monoUsNow + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then + inputs.maskValue + else + -scoreBaseAbs q k + let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then + inputs.maskValue + else + scoreBaseAbs q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let maskedKeys : Fin seq → Finset (Fin seq) := fun q => + if inputs.maskCausal = true then + (otherKeys q).filter (fun k => q < k) + else + (∅ : Finset (Fin seq)) + let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => + (otherKeys q) \ (maskedKeys q) + let maskedGap : Fin seq → Dyadic := fun q => + scoreLo q (inputs.prev q) - inputs.maskValue + let scoreGap : Fin seq → Fin seq → Dyadic := fun q k => + scoreLo q (inputs.prev q) - scoreHi q k + let marginAtRaw : Fin seq → Dyadic := fun q => + let other := unmaskedKeys q + let maskedSet := maskedKeys q + if hunmasked : other.Nonempty then + let unmaskedMin := other.inf' hunmasked (fun k => scoreGap q k) + if _hmasked : maskedSet.Nonempty then + min unmaskedMin (maskedGap q) + else + unmaskedMin + else + if _hmasked : maskedSet.Nonempty then + maskedGap q + else + (0 : Dyadic) + for q in activeList do + let _ := marginAtRaw q + pure () + let t1 ← monoUsNow + IO.println s!"timing: head score marginRaw list {t1 - t0} us" + (← IO.getStdout).flush + +/-- Force individual score-bound fields to locate slow evaluations. 
-/ +def timeHeadScoreFieldForces {seq dModel dHead : Nat} + (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do + IO.println "timing: head score field force start" + (← IO.getStdout).flush + let timeOne (label : String) (f : Unit → IO Unit) : IO Unit := do + let t0 ← monoUsNow + f () + let t1 ← monoUsNow + IO.println s!"timing: head score field {label} {t1 - t0} us" + (← IO.getStdout).flush + match List.finRange seq with + | [] => + IO.println "timing: head score field force skipped (empty seq)" + (← IO.getStdout).flush + | q :: _ => + match List.finRange seq with + | [] => + IO.println "timing: head score field force skipped (empty seq)" + (← IO.getStdout).flush + | k :: _ => + timeOne "scoreBaseAbs" (fun _ => do let _ := score.scoreBaseAbs q k; pure ()) + timeOne "scoreAbs" (fun _ => do let _ := score.scoreAbs q k; pure ()) + timeOne "scoreLo" (fun _ => do let _ := score.scoreLo q k; pure ()) + timeOne "scoreHi" (fun _ => do let _ := score.scoreHi q k; pure ()) + timeOne "marginAt" (fun _ => do let _ := score.marginAt q; pure ()) + timeOne "epsAt" (fun _ => do let _ := score.epsAt q; pure ()) + timeOne "margin" (fun _ => do let _ := score.margin; pure ()) + timeOne "eps" (fun _ => do let _ := score.eps; pure ()) + IO.println "timing: head score field force done" + (← IO.getStdout).flush + +end IO + +end Nfp diff --git a/Nfp/IO/Util.lean b/Nfp/IO/Util.lean new file mode 100644 index 0000000..b7fe877 --- /dev/null +++ b/Nfp/IO/Util.lean @@ -0,0 +1,25 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Pure + +/-! +Small shared helpers for IO parsing. +-/ + +namespace Nfp + +namespace IO + +/-- Parse an optional dyadic literal for CLI flags (rounded down if needed). -/ +def parseDyadicOpt (label : String) (raw? : Option String) : + Except String (Option Dyadic) := + match raw? 
with + | none => Except.ok none + | some raw => + match Pure.parseDyadic raw with + | Except.ok v => Except.ok (some v) + | Except.error msg => Except.error s!"invalid {label}: {msg}" + +end IO + +end Nfp diff --git a/Nfp/Mixer/Operations.lean b/Nfp/Mixer/Operations.lean index ee66010..8fc4a64 100644 --- a/Nfp/Mixer/Operations.lean +++ b/Nfp/Mixer/Operations.lean @@ -17,6 +17,22 @@ universe u variable {ι κ α : Type u} [Fintype ι] [Fintype κ] [Fintype α] +/-- Swap a double sum and factor the inner sum out by multiplication. -/ +private lemma sum_mul_sum (p : ι → Mass) (w : ι → κ → Mass) : + (∑ k, ∑ i, p i * w i k) = ∑ i, p i * ∑ k, w i k := by + classical + calc + ∑ k, ∑ i, p i * w i k = ∑ i, ∑ k, p i * w i k := by + simpa using + (Finset.sum_comm : + (∑ k : κ, ∑ i : ι, p i * w i k) = ∑ i : ι, ∑ k : κ, p i * w i k) + _ = ∑ i, p i * ∑ k, w i k := by + refine Finset.sum_congr rfl ?_ + intro i _ + simpa using + (Finset.mul_sum (a := p i) (s := (Finset.univ : Finset κ)) + (f := fun k => w i k)).symm + /-- Push a probability vector forward along a mixer. -/ def push (M : Mixer ι κ) (p : ProbVec ι) : ProbVec κ := { mass := fun k => ∑ i, p.mass i * M.weight i k @@ -24,18 +40,8 @@ def push (M : Mixer ι κ) (p : ProbVec ι) : ProbVec κ := classical calc ∑ k, ∑ i, p.mass i * M.weight i k - = ∑ i, ∑ k, p.mass i * M.weight i k := by - simpa using - (Finset.sum_comm : - (∑ k : κ, ∑ i : ι, p.mass i * M.weight i k) = - ∑ i : ι, ∑ k : κ, p.mass i * M.weight i k) - _ = ∑ i, p.mass i * ∑ k, M.weight i k := by - refine Finset.sum_congr rfl ?_ - intro i _ - simpa using - (Finset.mul_sum (a := p.mass i) (s := (Finset.univ : Finset κ)) - (f := fun k => M.weight i k)).symm - _ = ∑ i, p.mass i * 1 := by simp + = ∑ i, p.mass i * ∑ k, M.weight i k := by + simpa using sum_mul_sum (p := fun i => p.mass i) (w := fun i => M.weight i) _ = 1 := by simp } /-- Composition of two mixers. 
-/ @@ -46,18 +52,9 @@ def comp (M : Mixer ι κ) (N : Mixer κ α) : Mixer ι α := intro i calc ∑ a, ∑ k, M.weight i k * N.weight k a - = ∑ k, ∑ a, M.weight i k * N.weight k a := by + = ∑ k, M.weight i k * ∑ a, N.weight k a := by simpa using - (Finset.sum_comm : - (∑ a : α, ∑ k : κ, M.weight i k * N.weight k a) = - ∑ k : κ, ∑ a : α, M.weight i k * N.weight k a) - _ = ∑ k, M.weight i k * ∑ a, N.weight k a := by - refine Finset.sum_congr rfl ?_ - intro k _ - simpa using - (Finset.mul_sum (a := M.weight i k) (s := (Finset.univ : Finset α)) - (f := fun a => N.weight k a)).symm - _ = ∑ k, M.weight i k * 1 := by simp + sum_mul_sum (p := fun k => M.weight i k) (w := fun k => N.weight k) _ = 1 := by simp } /-- Identity mixer. -/ diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean index a360d92..975a2cd 100644 --- a/Nfp/Model/Gpt2.lean +++ b/Nfp/Model/Gpt2.lean @@ -1,6 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Nfp.Circuit.Cert.ValueRange /-! @@ -31,92 +31,92 @@ def DirectionTokens.spec {vocab : Nat} (dir : DirectionTokens vocab) : Direction /-- Exact GPT-2 head slice needed to build induction-head inputs. -/ structure Gpt2HeadSlice (seq dModel dHead vocab : Nat) where /-- Softmax scale factor (e.g. `1/8` for head dim 64). -/ - scale : Rat + scale : Dyadic /-- Token ids for the prompt. -/ tokens : Fin seq → Fin vocab /-- Token embedding matrix. -/ - wte : Fin vocab → Fin dModel → Rat + wte : Fin vocab → Fin dModel → Dyadic /-- Positional embedding matrix. -/ - wpe : Fin seq → Fin dModel → Rat + wpe : Fin seq → Fin dModel → Dyadic /-- Query projection weights. -/ - wq : Fin dModel → Fin dHead → Rat + wq : Fin dModel → Fin dHead → Dyadic /-- Query projection bias. -/ - bq : Fin dHead → Rat + bq : Fin dHead → Dyadic /-- Key projection weights. -/ - wk : Fin dModel → Fin dHead → Rat + wk : Fin dModel → Fin dHead → Dyadic /-- Key projection bias. 
-/ - bk : Fin dHead → Rat + bk : Fin dHead → Dyadic /-- Value projection weights. -/ - wv : Fin dModel → Fin dHead → Rat + wv : Fin dModel → Fin dHead → Dyadic /-- Value projection bias. -/ - bv : Fin dHead → Rat + bv : Fin dHead → Dyadic /-- Output projection weights for this head slice. -/ - wo : Fin dModel → Fin dHead → Rat + wo : Fin dModel → Fin dHead → Dyadic /-- Attention output bias (shared across heads). -/ - attnBias : Fin dModel → Rat + attnBias : Fin dModel → Dyadic /-- LayerNorm epsilon for the attention input. -/ - lnEps : Rat + lnEps : Dyadic /-- LayerNorm scale for the attention input. -/ - ln1Gamma : Fin dModel → Rat + ln1Gamma : Fin dModel → Dyadic /-- LayerNorm bias for the attention input. -/ - ln1Beta : Fin dModel → Rat + ln1Beta : Fin dModel → Dyadic /-- Direction tokens for logit-diff certification. -/ direction : DirectionTokens vocab /-- Exact per-head attention weights and biases. -/ structure Gpt2HeadWeights (dModel dHead : Nat) where /-- Query projection weights. -/ - wq : Fin dModel → Fin dHead → Rat + wq : Fin dModel → Fin dHead → Dyadic /-- Query projection bias. -/ - bq : Fin dHead → Rat + bq : Fin dHead → Dyadic /-- Key projection weights. -/ - wk : Fin dModel → Fin dHead → Rat + wk : Fin dModel → Fin dHead → Dyadic /-- Key projection bias. -/ - bk : Fin dHead → Rat + bk : Fin dHead → Dyadic /-- Value projection weights. -/ - wv : Fin dModel → Fin dHead → Rat + wv : Fin dModel → Fin dHead → Dyadic /-- Value projection bias. -/ - bv : Fin dHead → Rat + bv : Fin dHead → Dyadic /-- Output projection weights for this head slice. -/ - wo : Fin dModel → Fin dHead → Rat + wo : Fin dModel → Fin dHead → Dyadic /-- Exact GPT-2 layer slice with MLP and LayerNorm parameters. -/ structure Gpt2LayerSlice (dModel hidden : Nat) where /-- Attention output bias (shared across heads). -/ - attnBias : Fin dModel → Rat + attnBias : Fin dModel → Dyadic /-- MLP input projection weights. 
-/ - mlpWIn : Fin dModel → Fin hidden → Rat + mlpWIn : Fin dModel → Fin hidden → Dyadic /-- MLP input projection bias. -/ - mlpBIn : Fin hidden → Rat + mlpBIn : Fin hidden → Dyadic /-- MLP output projection weights. -/ - mlpWOut : Fin hidden → Fin dModel → Rat + mlpWOut : Fin hidden → Fin dModel → Dyadic /-- MLP output projection bias. -/ - mlpBOut : Fin dModel → Rat + mlpBOut : Fin dModel → Dyadic /-- LayerNorm scale for the attention input. -/ - ln1Gamma : Fin dModel → Rat + ln1Gamma : Fin dModel → Dyadic /-- LayerNorm bias for the attention input. -/ - ln1Beta : Fin dModel → Rat + ln1Beta : Fin dModel → Dyadic /-- LayerNorm scale for the MLP input. -/ - ln2Gamma : Fin dModel → Rat + ln2Gamma : Fin dModel → Dyadic /-- LayerNorm bias for the MLP input. -/ - ln2Beta : Fin dModel → Rat + ln2Beta : Fin dModel → Dyadic /-- Final LayerNorm parameters applied before unembedding. -/ structure Gpt2FinalLayerNorm (dModel : Nat) where /-- LayerNorm scale. -/ - gamma : Fin dModel → Rat + gamma : Fin dModel → Dyadic /-- LayerNorm bias. -/ - beta : Fin dModel → Rat + beta : Fin dModel → Dyadic /-- Token-plus-position embeddings for a GPT-2 head slice. -/ def Gpt2HeadSlice.embed {seq dModel dHead vocab : Nat} (slice : Gpt2HeadSlice seq dModel dHead vocab) : - Fin seq → Fin dModel → Rat := + Fin seq → Fin dModel → Dyadic := fun q d => slice.wte (slice.tokens q) d + slice.wpe q d /-- Direction vector in model space for a GPT-2 head slice. 
-/ def Gpt2HeadSlice.directionVec {seq dModel dHead vocab : Nat} - (slice : Gpt2HeadSlice seq dModel dHead vocab) : Fin dModel → Rat := + (slice : Gpt2HeadSlice seq dModel dHead vocab) : Fin dModel → Dyadic := fun d => slice.wte slice.direction.target d - slice.wte slice.direction.negative d end Model diff --git a/Nfp/Model/InductionHead.lean b/Nfp/Model/InductionHead.lean index f1e3eb2..10df33e 100644 --- a/Nfp/Model/InductionHead.lean +++ b/Nfp/Model/InductionHead.lean @@ -1,13 +1,13 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Data.Finset.Basic -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Nfp.Circuit.Cert.ValueRange /-! Exact inputs for induction-head scoring and value-direction computations. -These structures store exact rational inputs (embeddings and weights) for a +These structures store exact dyadic inputs (embeddings and weights) for a single attention head. They are intended to be consumed by sound builders. -/ @@ -20,43 +20,43 @@ open Nfp.Circuit /-- Exact head inputs for induction certification. -/ structure InductionHeadInputs (seq dModel dHead : Nat) where /-- Softmax scale factor (e.g. `1/8` for GPT-2-small head dim 64). -/ - scale : Rat + scale : Dyadic /-- Active queries for which bounds are required. -/ active : Finset (Fin seq) /-- `prev` selector for induction-style attention. -/ prev : Fin seq → Fin seq /-- Token embeddings for the sequence. -/ - embed : Fin seq → Fin dModel → Rat + embed : Fin seq → Fin dModel → Dyadic /-- LayerNorm epsilon used before attention. -/ - lnEps : Rat + lnEps : Dyadic /-- LayerNorm scale for pre-attention normalization. -/ - ln1Gamma : Fin dModel → Rat + ln1Gamma : Fin dModel → Dyadic /-- LayerNorm bias for pre-attention normalization. -/ - ln1Beta : Fin dModel → Rat + ln1Beta : Fin dModel → Dyadic /-- Query projection weights. -/ - wq : Fin dModel → Fin dHead → Rat + wq : Fin dModel → Fin dHead → Dyadic /-- Query projection bias. 
-/ - bq : Fin dHead → Rat + bq : Fin dHead → Dyadic /-- Key projection weights. -/ - wk : Fin dModel → Fin dHead → Rat + wk : Fin dModel → Fin dHead → Dyadic /-- Key projection bias. -/ - bk : Fin dHead → Rat + bk : Fin dHead → Dyadic /-- Value projection weights. -/ - wv : Fin dModel → Fin dHead → Rat + wv : Fin dModel → Fin dHead → Dyadic /-- Value projection bias. -/ - bv : Fin dHead → Rat + bv : Fin dHead → Dyadic /-- Output projection weights (head slice). -/ - wo : Fin dModel → Fin dHead → Rat + wo : Fin dModel → Fin dHead → Dyadic /-- Attention output bias (shared across heads). -/ - attnBias : Fin dModel → Rat + attnBias : Fin dModel → Dyadic /-- Whether to apply a causal mask to attention scores. -/ maskCausal : Bool /-- Score value for masked entries (e.g. `-10000` for GPT-2 causal masking). -/ - maskValue : Rat + maskValue : Dyadic /-- Logit-diff direction metadata. -/ directionSpec : DirectionSpec /-- Logit-diff direction vector in model space. -/ - direction : Fin dModel → Rat + direction : Fin dModel → Dyadic end Model diff --git a/Nfp/Prob/Operations.lean b/Nfp/Prob/Operations.lean index ef0d1f2..d32dd74 100644 --- a/Nfp/Prob/Operations.lean +++ b/Nfp/Prob/Operations.lean @@ -16,6 +16,13 @@ universe u variable {ι : Type u} [Fintype ι] +/-- Factor a constant out of a sum. -/ +private lemma sum_mul_const (a : Mass) (p : ι → Mass) : + (∑ i, a * p i) = a * ∑ i, p i := by + simpa using + (Finset.mul_sum (a := a) (s := (Finset.univ : Finset ι)) + (f := fun i => p i)).symm + /-- The pure distribution at a single point. 
-/ def pure (i0 : ι) [DecidableEq ι] : ProbVec ι := by refine @@ -40,15 +47,7 @@ def mix (a b : Mass) (h : a + b = 1) (p q : ProbVec ι) : ProbVec ι := = (∑ i, a * p.mass i) + (∑ i, b * q.mass i) := by simp [Finset.sum_add_distrib] _ = a * ∑ i, p.mass i + b * ∑ i, q.mass i := by - have ha : (∑ i, a * p.mass i) = a * ∑ i, p.mass i := by - simpa using - (Finset.mul_sum (a := a) (s := (Finset.univ : Finset ι)) - (f := fun i => p.mass i)).symm - have hb : (∑ i, b * q.mass i) = b * ∑ i, q.mass i := by - simpa using - (Finset.mul_sum (a := b) (s := (Finset.univ : Finset ι)) - (f := fun i => q.mass i)).symm - simp [ha, hb] + simp [sum_mul_const] _ = a * 1 + b * 1 := by simp _ = 1 := by simp [h] } diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index dee6446..dac2283 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -2,6 +2,7 @@ import Nfp.Sound.Gpt2.HeadInputs import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadBounds import Nfp.Sound.Induction.LogitDiff import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Bounds.Gelu diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index 0208bf8..a60b3d7 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -3,8 +3,7 @@ import Mathlib.Algebra.BigOperators.Field import Mathlib.Algebra.BigOperators.Ring.Finset import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Ring.Rat -import Mathlib.Data.Rat.Cast.Order +import Nfp.Core.Basic import Mathlib.Data.Real.Basic import Nfp.Circuit.Layers.Softmax import Nfp.Model.Gpt2 @@ -24,11 +23,134 @@ namespace Bounds open scoped BigOperators +/-! +Caching helpers for interval bounds. +-/ + +/-- Cache a bound function in an array-backed lookup to avoid repeated evaluation. 
-/ +def cacheBound {n : Nat} (f : Fin n → Dyadic) : Fin n → Dyadic := + let data : Thunk (Array Dyadic) := Thunk.mk (fun _ => Array.ofFn f) + fun i => (Thunk.get data)[i.1]'(by + have hsize : (Thunk.get data).size = n := by + simp [Thunk.get, data] + simp [hsize]) + +/-- `cacheBound` preserves pointwise values. -/ +theorem cacheBound_apply {n : Nat} (f : Fin n → Dyadic) (i : Fin n) : + cacheBound f i = f i := by + simp [cacheBound, Thunk.get, Array.getElem_ofFn] + +/-- Cache a bound function on two indices. -/ +def cacheBound2 {m n : Nat} (f : Fin m → Fin n → Dyadic) : Fin m → Fin n → Dyadic := + let data : Thunk (Array (Thunk (Array Dyadic))) := Thunk.mk (fun _ => + Array.ofFn (fun q => Thunk.mk (fun _ => Array.ofFn (f q)))) + fun q i => + let rowThunk := (Thunk.get data)[q.1]'(by + have hsize : (Thunk.get data).size = m := by + simp [Thunk.get, data] + rw [hsize] + exact q.isLt) + let row := Thunk.get rowThunk + row[i.1]'(by + have hsize : row.size = n := by + simp [row, rowThunk, Thunk.get, data, Array.getElem_ofFn] + rw [hsize] + exact i.isLt) + +/-- `cacheBound2` preserves pointwise values. -/ +theorem cacheBound2_apply {m n : Nat} (f : Fin m → Fin n → Dyadic) (q : Fin m) (i : Fin n) : + cacheBound2 f q i = f q i := by + simp [cacheBound2, Thunk.get, Array.getElem_ofFn] + +/-- Cache a bound function on two indices using row tasks for parallel evaluation. -/ +def cacheBound2Task {m n : Nat} (f : Fin m → Fin n → Dyadic) : Fin m → Fin n → Dyadic := + let rowTasks : Array (Task { row : Array Dyadic // row.size = n }) := + Array.ofFn (fun q : Fin m => + Task.spawn (fun _ => ⟨Array.ofFn (f q), by simp⟩)) + fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get + row.1[i.1]'(by + simp [row.2]) + +/-- `cacheBound2Task` preserves pointwise values. 
-/ +theorem cacheBound2Task_apply {m n : Nat} (f : Fin m → Fin n → Dyadic) (q : Fin m) (i : Fin n) : + cacheBound2Task f q i = f q i := by + classical + simp [cacheBound2Task, Task.spawn, Array.getElem_ofFn] + +/-- Cache a bound function on two indices using per-element tasks for parallel evaluation. -/ +def cacheBound2TaskElem {m n : Nat} (f : Fin m → Fin n → Dyadic) : Fin m → Fin n → Dyadic := + let rowTasks : Array (Array (Task Dyadic)) := + Array.ofFn (fun q : Fin m => + Array.ofFn (fun i : Fin n => + Task.spawn (fun _ => f q i))) + fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])) + let t := row[i.1]'(by + have hsize : row.size = n := by + simp [row, rowTasks] + simp [hsize, i.isLt]) + t.get + +/-- `cacheBound2TaskElem` preserves pointwise values. -/ +theorem cacheBound2TaskElem_apply {m n : Nat} (f : Fin m → Fin n → Dyadic) (q : Fin m) (i : Fin n) : + cacheBound2TaskElem f q i = f q i := by + classical + simp [cacheBound2TaskElem, Task.spawn, Array.getElem_ofFn] + +/-- Cache a pair of bound functions on two indices. 
-/ +def cacheBoundPair2 {m n : Nat} + (f : Fin m → (Fin n → Dyadic) × (Fin n → Dyadic)) : + (Fin m → Fin n → Dyadic) × (Fin m → Fin n → Dyadic) := + let data : Thunk (Array (Array Dyadic × Array Dyadic)) := Thunk.mk (fun _ => + Array.ofFn (fun q => + let row := f q + (Array.ofFn row.1, Array.ofFn row.2))) + let lo : Fin m → Fin n → Dyadic := fun q i => + let row := (Thunk.get data)[q.1]'(by + have hsize : (Thunk.get data).size = m := by + simp [Thunk.get, data] + rw [hsize] + exact q.isLt) + let loRow := row.1 + loRow[i.1]'(by + have hsize : loRow.size = n := by + simp [loRow, row, Thunk.get, data, Array.getElem_ofFn] + rw [hsize] + exact i.isLt) + let hi : Fin m → Fin n → Dyadic := fun q i => + let row := (Thunk.get data)[q.1]'(by + have hsize : (Thunk.get data).size = m := by + simp [Thunk.get, data] + rw [hsize] + exact q.isLt) + let hiRow := row.2 + hiRow[i.1]'(by + have hsize : hiRow.size = n := by + simp [hiRow, row, Thunk.get, data, Array.getElem_ofFn] + rw [hsize] + exact i.isLt) + (lo, hi) + +/-- `cacheBoundPair2` preserves pointwise values (first component). -/ +theorem cacheBoundPair2_apply_left {m n : Nat} + (f : Fin m → (Fin n → Dyadic) × (Fin n → Dyadic)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2 f).1 q i = (f q).1 i := by + simp [cacheBoundPair2, Thunk.get, Array.getElem_ofFn] + +/-- `cacheBoundPair2` preserves pointwise values (second component). -/ +theorem cacheBoundPair2_apply_right {m n : Nat} + (f : Fin m → (Fin n → Dyadic) × (Fin n → Dyadic)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2 f).2 q i = (f q).2 i := by + simp [cacheBoundPair2, Thunk.get, Array.getElem_ofFn] + /-- Real-valued attention output for a query token and model coordinate. 
-/ noncomputable def attentionOutputReal {seq dModel dHead numHeads : Nat} [NeZero seq] - (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Rat) + (attnBias : Fin dModel → Dyadic) (scores : Fin numHeads → Fin seq → Fin seq → Real) (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := @@ -46,9 +168,9 @@ noncomputable def attentionOutputReal {seq dModel dHead numHeads : Nat} [NeZero /-- Unfolding lemma for `attentionOutputReal`. -/ theorem attentionOutputReal_def {seq dModel dHead numHeads : Nat} [NeZero seq] - (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Rat) + (attnBias : Fin dModel → Dyadic) (scores : Fin numHeads → Fin seq → Fin seq → Real) (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : @@ -67,35 +189,35 @@ theorem attentionOutputReal_def {seq dModel dHead numHeads : Nat} [NeZero seq] /-- Interval bounds for multi-head attention outputs from interval inputs. 
-/ def attentionOutputBounds {dModel dHead numHeads : Nat} - (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Rat) - (lo hi : Fin dModel → Rat) : - (Fin dModel → Rat) × (Fin dModel → Rat) := + (attnBias : Fin dModel → Dyadic) + (lo hi : Fin dModel → Dyadic) : + (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := let absBound := intervalAbsBound lo hi let ln := layerNormAbsBounds eps ln1Gamma ln1Beta absBound let lnLo := ln.1 let lnHi := ln.2 - let vLo : Fin numHeads → Fin dHead → Rat := fun h d => + let vLo : Fin numHeads → Fin dHead → Dyadic := fun h d => dotIntervalLower (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d - let vHi : Fin numHeads → Fin dHead → Rat := fun h d => + let vHi : Fin numHeads → Fin dHead → Dyadic := fun h d => dotIntervalUpper (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d - let headLo : Fin numHeads → Fin dModel → Rat := fun h i => + let headLo : Fin numHeads → Fin dModel → Dyadic := fun h i => dotIntervalLower (fun d => (heads h).wo i d) (vLo h) (vHi h) - let headHi : Fin numHeads → Fin dModel → Rat := fun h i => + let headHi : Fin numHeads → Fin dModel → Dyadic := fun h i => dotIntervalUpper (fun d => (heads h).wo i d) (vLo h) (vHi h) - let sumLo : Fin dModel → Rat := fun i => ∑ h, headLo h i - let sumHi : Fin dModel → Rat := fun i => ∑ h, headHi h i + let sumLo : Fin dModel → Dyadic := fun i => ∑ h, headLo h i + let sumHi : Fin dModel → Dyadic := fun i => ∑ h, headHi h i (fun i => sumLo i + attnBias i, fun i => sumHi i + attnBias i) /-- `attentionOutputBounds` soundness for real attention outputs. 
-/ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] - (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Rat) + (attnBias : Fin dModel → Dyadic) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) + (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi ∀ q i, @@ -111,16 +233,16 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq let lnHi := lnBounds.2 let lnOut : Fin seq → Fin dModel → Real := fun k j => layerNormRealOfReal eps ln1Gamma ln1Beta (x k) j - let vLo : Fin numHeads → Fin dHead → Rat := fun h d => + let vLo : Fin numHeads → Fin dHead → Dyadic := fun h d => dotIntervalLower (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d - let vHi : Fin numHeads → Fin dHead → Rat := fun h d => + let vHi : Fin numHeads → Fin dHead → Dyadic := fun h d => dotIntervalUpper (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d - let headLo : Fin numHeads → Fin dModel → Rat := fun h j => + let headLo : Fin numHeads → Fin dModel → Dyadic := fun h j => dotIntervalLower (fun d => (heads h).wo j d) (vLo h) (vHi h) - let headHi : Fin numHeads → Fin dModel → Rat := fun h j => + let headHi : Fin numHeads → Fin dModel → Dyadic := fun h j => dotIntervalUpper (fun d => (heads h).wo j d) (vLo h) (vHi h) - let sumLo : Fin dModel → Rat := fun j => ∑ h, headLo h j - let sumHi : Fin dModel → Rat := fun j => ∑ h, headHi h j + let sumLo : Fin dModel → Dyadic := fun j => ∑ h, headLo h j + let sumHi : Fin dModel → Dyadic := fun j => ∑ h, headHi h j let 
headValue : Fin numHeads → Fin seq → Fin dHead → Real := fun h k d => dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d let headWeights : Fin numHeads → Fin seq → Fin seq → Real := fun h q k => @@ -134,25 +256,17 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq have hbound : |x q i| ≤ max |(lo i : Real)| |(hi i : Real)| := abs_le_max_abs_abs_of_interval_real (hlo q i) (hhi q i) - have hnonempty : (Finset.univ : Finset (Fin dModel)).Nonempty := ⟨i, by simp⟩ - have hsup : - max |lo i| |hi i| ≤ intervalAbsBound lo hi := by - have hsup' : - max |lo i| |hi i| ≤ - (Finset.univ).sup' hnonempty (fun k => max |lo k| |hi k|) := by - simpa using - (Finset.le_sup' - (s := (Finset.univ : Finset (Fin dModel))) - (f := fun k => max |lo k| |hi k|) - (by simp : i ∈ (Finset.univ : Finset (Fin dModel)))) - simpa [intervalAbsBound, hnonempty] using hsup' + have hsup : max |lo i| |hi i| ≤ intervalAbsBound lo hi := + max_abs_le_intervalAbsBound lo hi i have hsup_real : max |(lo i : Real)| |(hi i : Real)| ≤ (absBound : Real) := by - exact_mod_cast hsup + have hsup' : dyadicToReal (max |lo i| |hi i|) ≤ dyadicToReal absBound := + dyadicToReal_le_of_le hsup + simpa [dyadicToReal_abs, dyadicToReal_max] using hsup' exact le_trans hbound hsup_real have hln_bounds : ∀ q i, (lnLo i : Real) ≤ lnOut q i ∧ lnOut q i ≤ (lnHi i : Real) := by intro q i - have hln := layerNormAbsBounds_spec_real eps ln1Gamma ln1Beta absBound (x q) hne heps + have hln := layerNormAbsBounds_spec_real eps ln1Gamma ln1Beta absBound (x q) hne heps hsqrt (fun j => habs q j) simpa [lnBounds, lnLo, lnHi, lnOut] using hln i have hval_bounds : @@ -172,8 +286,8 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq have hlow' := add_le_add_right hlow ((heads h).bv d : Real) have hhigh' := add_le_add_right hhigh ((heads h).bv d : Real) constructor - · simpa [headValue, vLo, Rat.cast_add] using hlow' - · simpa [headValue, vHi, Rat.cast_add] 
using hhigh' + · simpa [headValue, vLo] using hlow' + · simpa [headValue, vHi] using hhigh' have weighted_bounds : ∀ {lo hi : Real} {vals : Fin seq → Real} {w : Fin seq → Real}, (∀ k, lo ≤ vals k) → (∀ k, vals k ≤ hi) → @@ -255,12 +369,12 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq have hsum := Finset.sum_le_sum (s := (Finset.univ : Finset (Fin numHeads))) (fun h _ => (hproj_bounds h q i).1) - simpa [sumLo] using hsum + simpa [sumLo, Linear.dyadicToReal_sum_univ] using hsum have hhigh : ∑ h, headProj h q i ≤ (sumHi i : Real) := by have hsum := Finset.sum_le_sum (s := (Finset.univ : Finset (Fin numHeads))) (fun h _ => (hproj_bounds h q i).2) - simpa [sumHi] using hsum + simpa [sumHi, Linear.dyadicToReal_sum_univ] using hsum exact ⟨hlow, hhigh⟩ have hlow : (sumLo i : Real) + (attnBias i : Real) ≤ @@ -290,22 +404,22 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq /-- Interval bounds for the attention residual path. -/ def attentionResidualBounds {dModel dHead numHeads : Nat} - (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Rat) - (lo hi : Fin dModel → Rat) : - (Fin dModel → Rat) × (Fin dModel → Rat) := + (attnBias : Fin dModel → Dyadic) + (lo hi : Fin dModel → Dyadic) : + (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := let attn := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi (fun i => lo i + attn.1 i, fun i => hi i + attn.2 i) /-- `attentionResidualBounds` soundness for attention residual outputs. 
-/ theorem attentionResidualBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] - (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Rat) + (attnBias : Fin dModel → Dyadic) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) + (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias lo hi ∀ q i, @@ -318,37 +432,45 @@ theorem attentionResidualBounds_spec {seq dModel dHead numHeads : Nat} [NeZero s let attn := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi have hattn := attentionOutputBounds_spec eps ln1Gamma ln1Beta heads attnBias scores lo hi x - hne heps hlo hhi q i + hne heps hsqrt hlo hhi q i have hlow := add_le_add (hlo q i) hattn.1 have hhigh := add_le_add (hhi q i) hattn.2 constructor - · simpa [bounds, attentionResidualBounds, attn, Rat.cast_add] using hlow - · simpa [bounds, attentionResidualBounds, attn, Rat.cast_add] using hhigh + · simpa [bounds, attentionResidualBounds, attn] using hlow + · simpa [bounds, attentionResidualBounds, attn] using hhigh /-- Interval bounds for a full transformer layer (attention + MLP). 
-/ def transformerLayerBounds {dModel dHead numHeads hidden : Nat} - (eps : Rat) - (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Rat) + (eps : Dyadic) + (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Dyadic) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Rat) - (mlpWIn : Fin dModel → Fin hidden → Rat) (mlpBIn : Fin hidden → Rat) - (mlpWOut : Fin hidden → Fin dModel → Rat) (mlpBOut : Fin dModel → Rat) - (lo hi : Fin dModel → Rat) : - (Fin dModel → Rat) × (Fin dModel → Rat) := - let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias lo hi - layerNormAbsMlpResidualBounds eps ln2Gamma ln2Beta mlpWIn mlpBIn mlpWOut mlpBOut attn.1 attn.2 + (attnBias : Fin dModel → Dyadic) + (mlpWIn : Fin dModel → Fin hidden → Dyadic) (mlpBIn : Fin hidden → Dyadic) + (mlpWOut : Fin hidden → Fin dModel → Dyadic) (mlpBOut : Fin dModel → Dyadic) + (lo hi : Fin dModel → Dyadic) : + (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + let loCached := cacheBound lo + let hiCached := cacheBound hi + let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias loCached hiCached + let attnLo := cacheBound attn.1 + let attnHi := cacheBound attn.2 + let out := layerNormAbsMlpResidualBounds eps ln2Gamma ln2Beta mlpWIn mlpBIn mlpWOut mlpBOut + attnLo attnHi + let outLo := cacheBound out.1 + let outHi := cacheBound out.2 + (outLo, outHi) /-- `transformerLayerBounds` soundness for full transformer-layer outputs. 
-/ theorem transformerLayerBounds_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) - (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Rat) + (eps : Dyadic) + (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Dyadic) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Rat) - (mlpWIn : Fin dModel → Fin hidden → Rat) (mlpBIn : Fin hidden → Rat) - (mlpWOut : Fin hidden → Fin dModel → Rat) (mlpBOut : Fin dModel → Rat) + (attnBias : Fin dModel → Dyadic) + (mlpWIn : Fin dModel → Fin hidden → Dyadic) (mlpBIn : Fin hidden → Dyadic) + (mlpWOut : Fin hidden → Fin dModel → Dyadic) (mlpBOut : Fin dModel → Dyadic) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) + (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := transformerLayerBounds eps ln1Gamma ln1Beta ln2Gamma ln2Beta heads attnBias mlpWIn mlpBIn mlpWOut mlpBOut lo hi @@ -367,15 +489,31 @@ theorem transformerLayerBounds_spec {seq dModel dHead numHeads hidden : Nat} [Ne (bounds.2 i : Real) := by classical intro bounds q i - let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias lo hi - have hattn := attentionResidualBounds_spec eps ln1Gamma ln1Beta heads attnBias scores lo hi x - hne heps hlo hhi q + let loCached := cacheBound lo + let hiCached := cacheBound hi + have hloCached : ∀ q i, (loCached i : Real) ≤ x q i := by + intro q i + simpa [loCached, cacheBound_apply] using hlo q i + have hhiCached : ∀ q i, x q i ≤ (hiCached i : Real) := by + intro q i + simpa [hiCached, cacheBound_apply] using hhi q i + let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias loCached hiCached + have hattn := attentionResidualBounds_spec eps ln1Gamma ln1Beta heads attnBias scores + 
loCached hiCached x hne heps hsqrt hloCached hhiCached q + let attnLo := cacheBound attn.1 + let attnHi := cacheBound attn.2 + let y := fun j => x q j + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q j + have hattnLo : ∀ j, (attnLo j : Real) ≤ y j := by + intro j + simpa [attnLo, cacheBound_apply, y] using (hattn j).1 + have hattnHi : ∀ j, y j ≤ (attnHi j : Real) := by + intro j + simpa [attnHi, cacheBound_apply, y] using (hattn j).2 have hmlp := layerNormAbsMlpResidualBounds_spec eps ln2Gamma ln2Beta mlpWIn mlpBIn mlpWOut - mlpBOut attn.1 attn.2 (fun j => x q j + attentionOutputReal eps ln1Gamma ln1Beta heads - attnBias scores x q j) hne heps - (fun j => (hattn j).1) (fun j => (hattn j).2) + mlpBOut attnLo attnHi y hne heps hsqrt hattnLo hattnHi have hmlp_i := hmlp i - simpa [bounds, transformerLayerBounds, attn] using hmlp_i + simpa [bounds, transformerLayerBounds, attn, loCached, hiCached, attnLo, attnHi, y, + cacheBound_apply] using hmlp_i end Bounds diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean index be5dc60..2bb5ab1 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -3,7 +3,7 @@ import Mathlib.Algebra.Order.Ring.Abs import Mathlib.Analysis.Complex.Trigonometric import Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic -import Mathlib.Data.Rat.Cast.Order +import Nfp.Core.Basic /-! Tanh-based GELU bounds for GPT-2 style MLPs. @@ -118,22 +118,61 @@ theorem geluTanh_bounds (x : Real) : simpa [min_eq_left hx', max_eq_right hx'] using And.intro h1 h0 /-- Interval bounds for GELU given input bounds. -/ -def geluInterval (lo hi : Rat) : Rat × Rat := - (min lo 0, max hi 0) +def geluInterval (lo hi : Dyadic) : Dyadic × Dyadic := + (if lo ≤ 0 then lo else 0, if 0 ≤ hi then hi else 0) /-- `geluInterval` soundly bounds `geluTanh` on a real interval. 
-/ -theorem geluInterval_bounds {lo hi : Rat} {x : Real} +theorem geluInterval_bounds {lo hi : Dyadic} {x : Real} (hlo : (lo : Real) ≤ x) (hhi : x ≤ (hi : Real)) : (geluInterval lo hi).1 ≤ (geluTanh x : Real) ∧ (geluTanh x : Real) ≤ (geluInterval lo hi).2 := by have hgelu := geluTanh_bounds x - have hmin : min (lo : Real) 0 ≤ min x 0 := min_le_min hlo le_rfl - have hmax : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl - have hlo' : min (lo : Real) 0 ≤ geluTanh x := le_trans hmin hgelu.1 - have hhi' : geluTanh x ≤ max (hi : Real) 0 := le_trans hgelu.2 hmax - constructor - · simpa [geluInterval, Rat.cast_min] using hlo' - · simpa [geluInterval, Rat.cast_max] using hhi' + by_cases hlo0 : lo ≤ 0 + · have hlo0r : (lo : Real) ≤ 0 := by + exact (dyadicToReal_nonpos_iff (x := lo)).2 hlo0 + have hmin : min (lo : Real) 0 ≤ min x 0 := min_le_min hlo le_rfl + have hlo' : (lo : Real) ≤ geluTanh x := by + have hmin' : (lo : Real) ≤ min x 0 := by + simpa [min_eq_left hlo0r] using hmin + exact le_trans hmin' hgelu.1 + have hmax : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl + have hhi' : geluTanh x ≤ max (hi : Real) 0 := le_trans hgelu.2 hmax + constructor + · simpa [geluInterval, hlo0] using hlo' + · by_cases hhi0 : 0 ≤ hi + · have hhi0r : 0 ≤ (hi : Real) := by + exact (dyadicToReal_nonneg_iff (x := hi)).2 hhi0 + have hmax' : max (hi : Real) 0 = (hi : Real) := max_eq_left hhi0r + simpa [geluInterval, hhi0, hmax'] using hhi' + · have hhi0r : (hi : Real) ≤ 0 := by + exact (dyadicToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) + have hx0 : x ≤ 0 := le_trans hhi hhi0r + have hmax' : max x 0 = 0 := max_eq_right hx0 + have hhi'' : geluTanh x ≤ (0 : Real) := by + simpa [hmax'] using hgelu.2 + simpa [geluInterval, hhi0, dyadicToReal_zero] using hhi'' + · have hlo0r : 0 ≤ (lo : Real) := by + exact (dyadicToReal_nonneg_iff (x := lo)).2 (le_of_not_ge hlo0) + have hx0 : 0 ≤ x := le_trans hlo0r hlo + have hmin' : min x 0 = 0 := min_eq_right hx0 + have hlo' : (0 : Real) ≤ 
geluTanh x := by + simpa [hmin'] using hgelu.1 + have hmax : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl + have hhi' : geluTanh x ≤ max (hi : Real) 0 := le_trans hgelu.2 hmax + constructor + · simpa [geluInterval, hlo0, dyadicToReal_zero] using hlo' + · by_cases hhi0 : 0 ≤ hi + · have hhi0r : 0 ≤ (hi : Real) := by + exact (dyadicToReal_nonneg_iff (x := hi)).2 hhi0 + have hmax' : max (hi : Real) 0 = (hi : Real) := max_eq_left hhi0r + simpa [geluInterval, hhi0, hmax'] using hhi' + · have hhi0r : (hi : Real) ≤ 0 := by + exact (dyadicToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) + have hx0' : x ≤ 0 := le_trans hhi hhi0r + have hmax' : max x 0 = 0 := max_eq_right hx0' + have hhi'' : geluTanh x ≤ (0 : Real) := by + simpa [hmax'] using hgelu.2 + simpa [geluInterval, hhi0, dyadicToReal_zero] using hhi'' end Bounds diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 139a985..37fae19 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -1,19 +1,22 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Algebra.BigOperators.Fin import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Algebra.Order.BigOperators.Group.Finset import Mathlib.Algebra.Order.Field.Basic import Mathlib.Algebra.Order.Ring.Basic -import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.Nat.Sqrt import Mathlib.Data.Real.Sqrt import Mathlib.Data.Rat.BigOperators import Mathlib.Data.Rat.Cast.Order +import Nfp.Core.Basic +import Nfp.Sound.Bounds.LayerNorm.MeanVariance +import Nfp.Sound.Linear.FinFold /-! -LayerNorm interval bounds for exact rational inputs. +LayerNorm interval bounds for dyadic inputs. -This module computes rational interval bounds for LayerNorm outputs and proves +This module computes dyadic interval bounds for LayerNorm outputs and proves those bounds sound for real-valued LayerNorm semantics. 
-/ @@ -25,377 +28,171 @@ namespace Bounds open scoped BigOperators -/-- Mean of a finite vector (defaults to `0` when `n = 0`). -/ -def mean {n : Nat} (x : Fin n → Rat) : Rat := - if n = 0 then - 0 - else - (∑ i, x i) / n - -/-- Unfold `mean` when `n ≠ 0`. -/ -theorem mean_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : - mean x = (∑ i, x i) / n := by - simp [mean, h] - -/-- Mean is monotone under pointwise order (rational inputs). -/ -theorem mean_le_mean {n : Nat} (x y : Fin n → Rat) (hne : n ≠ 0) - (hxy : ∀ i, x i ≤ y i) : mean x ≤ mean y := by - classical - have hsum : (∑ i, x i) ≤ ∑ i, y i := by - refine Finset.sum_le_sum ?_ - intro i _ - exact hxy i - have hden : 0 ≤ (n : Rat) := by - exact_mod_cast (Nat.zero_le n) - have hdiv : (∑ i, x i) / n ≤ (∑ i, y i) / n := - div_le_div_of_nonneg_right hsum hden - simpa [mean, hne] using hdiv - -/-- Variance of a finite vector (defaults to `0` when `n = 0`). -/ -def variance {n : Nat} (x : Fin n → Rat) : Rat := - if n = 0 then - 0 - else - let μ := mean x - (∑ i, (x i - μ) ^ 2) / n - -/-- Unfold `variance` when `n ≠ 0`. -/ -theorem variance_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : - variance x = - let μ := mean x - (∑ i, (x i - μ) ^ 2) / n := by - simp [variance, h] - -/-- Variance is nonnegative when `n ≠ 0`. -/ -theorem variance_nonneg {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : - 0 ≤ variance x := by - classical - have hsum : 0 ≤ ∑ i, (x i - mean x) ^ 2 := by - refine Finset.sum_nonneg ?_ - intro i _ - exact sq_nonneg (x i - mean x) - have hden : 0 ≤ (n : Rat) := by - exact_mod_cast (Nat.zero_le n) - have hdiv : 0 ≤ (∑ i, (x i - mean x) ^ 2) / n := - div_nonneg hsum hden - simpa [variance_def x h] using hdiv - -/-- Absolute mean bound from per-coordinate bounds. 
-/ -theorem mean_abs_le_bound {n : Nat} (x : Fin n → Rat) (bound : Rat) - (hne : n ≠ 0) (hbound : ∀ i, |x i| ≤ bound) : - |mean x| ≤ bound := by - classical - have hsum_abs : - |∑ i : Fin n, x i| ≤ ∑ i : Fin n, |x i| := by - simpa using - (Finset.abs_sum_le_sum_abs - (f := fun i : Fin n => x i) - (s := (Finset.univ : Finset (Fin n)))) - have hsum_bound : ∑ i : Fin n, |x i| ≤ ∑ i : Fin n, bound := by - refine Finset.sum_le_sum ?_ - intro i _ - exact hbound i - have hsum_le : |∑ i : Fin n, x i| ≤ (n : Rat) * bound := by - have hsum := le_trans hsum_abs hsum_bound - simpa [Finset.sum_const, Finset.card_univ] using hsum - have hpos : 0 < (n : Rat) := by - exact_mod_cast Nat.pos_of_ne_zero hne - have hsum_le' : |∑ i : Fin n, x i| ≤ bound * (n : Rat) := by - simpa [mul_comm] using hsum_le - have hdiv : |∑ i : Fin n, x i| / (n : Rat) ≤ bound := by - exact (div_le_iff₀ hpos).2 hsum_le' - have habs_mean : - |(∑ i : Fin n, x i) / (n : Rat)| ≤ bound := by - simpa [abs_div, abs_of_nonneg (le_of_lt hpos)] using hdiv - simpa [mean_def x hne] using habs_mean - -/-! Interval helpers. -/ - -/-- Absolute value bound from endpoint bounds. -/ -theorem abs_le_max_of_bounds {α : Type _} [Ring α] [LinearOrder α] [IsOrderedRing α] - {a b z : α} - (hlo : a ≤ z) (hhi : z ≤ b) : - |z| ≤ max |a| |b| := by - have hleft : -max |a| |b| ≤ z := by - have hneg : -max |a| |b| ≤ a := by - have hneg' : -max |a| |b| ≤ -|a| := by - exact neg_le_neg (le_max_left _ _) - have hneg'' : -|a| ≤ a := by - have h : -a ≤ |a| := neg_le_abs a - simpa using (neg_le_neg h) - exact le_trans hneg' hneg'' - exact le_trans hneg hlo - have hright : z ≤ max |a| |b| := by - have hb : b ≤ |b| := by - exact le_abs_self b - have hb' : b ≤ max |a| |b| := le_trans hb (le_max_right _ _) - exact le_trans hhi hb' - exact (abs_le.mpr ⟨hleft, hright⟩) - -/-! Real-valued mean and variance. -/ - -/-- Mean of a real vector (defaults to `0` when `n = 0`). 
-/ -noncomputable def meanReal {n : Nat} (x : Fin n → Real) : Real := - if n = 0 then - 0 - else - (∑ i, x i) / n - -/-- Unfold `meanReal` when `n ≠ 0`. -/ -theorem meanReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : - meanReal x = (∑ i, x i) / n := by - simp [meanReal, h] - -/-- `meanReal` agrees with `mean` after casting. -/ -theorem meanReal_ratCast {n : Nat} (x : Fin n → Rat) : - meanReal (fun i => (x i : Real)) = (mean x : Real) := by - by_cases h : n = 0 - · simp [meanReal, mean, h] - · simp [meanReal, mean, h, Rat.cast_sum, Rat.cast_div] - -/-- Mean is monotone under pointwise order (real inputs). -/ -theorem meanReal_le_meanReal {n : Nat} (x y : Fin n → Real) (hne : n ≠ 0) - (hxy : ∀ i, x i ≤ y i) : meanReal x ≤ meanReal y := by - classical - have hsum : (∑ i, x i) ≤ ∑ i, y i := by - refine Finset.sum_le_sum ?_ - intro i _ - exact hxy i - have hden : 0 ≤ (n : Real) := by - exact_mod_cast (Nat.zero_le n) - have hdiv : (∑ i, x i) / n ≤ (∑ i, y i) / n := - div_le_div_of_nonneg_right hsum hden - simpa [meanReal, hne] using hdiv - -/-- Variance of a real vector (defaults to `0` when `n = 0`). -/ -noncomputable def varianceReal {n : Nat} (x : Fin n → Real) : Real := - if n = 0 then - 0 - else - let μ := meanReal x - (∑ i, (x i - μ) ^ 2) / n - -/-- Unfold `varianceReal` when `n ≠ 0`. -/ -theorem varianceReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : - varianceReal x = - let μ := meanReal x - (∑ i, (x i - μ) ^ 2) / n := by - simp [varianceReal, h] - -/-- Variance is nonnegative when `n ≠ 0`. -/ -theorem varianceReal_nonneg {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : - 0 ≤ varianceReal x := by - classical - have hsum : 0 ≤ ∑ i, (x i - meanReal x) ^ 2 := by - refine Finset.sum_nonneg ?_ - intro i _ - exact sq_nonneg (x i - meanReal x) - have hden : 0 ≤ (n : Real) := by - exact_mod_cast (Nat.zero_le n) - have hdiv : 0 ≤ (∑ i, (x i - meanReal x) ^ 2) / n := - div_nonneg hsum hden - simpa [varianceReal_def x h] using hdiv +/-! Square-root bounds. 
-/ -/-- Absolute mean bound from per-coordinate bounds (real inputs). -/ -theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Rat) - (hne : n ≠ 0) (hbound : ∀ i, |x i| ≤ (bound : Real)) : - |meanReal x| ≤ (bound : Real) := by - classical - have hsum_abs : - |∑ i : Fin n, x i| ≤ ∑ i : Fin n, |x i| := by - simpa using - (Finset.abs_sum_le_sum_abs - (f := fun i : Fin n => x i) - (s := (Finset.univ : Finset (Fin n)))) - have hsum_bound : ∑ i : Fin n, |x i| ≤ ∑ i : Fin n, (bound : Real) := by - refine Finset.sum_le_sum ?_ - intro i _ - exact hbound i - have hsum_le : |∑ i : Fin n, x i| ≤ (n : Real) * (bound : Real) := by - have hsum := le_trans hsum_abs hsum_bound - simpa [Finset.sum_const, Finset.card_univ, mul_comm] using hsum - have hpos : 0 < (n : Real) := by - exact_mod_cast Nat.pos_of_ne_zero hne - have hsum_le' : |∑ i : Fin n, x i| ≤ (bound : Real) * (n : Real) := by - simpa [mul_comm] using hsum_le - have hdiv : |∑ i : Fin n, x i| / (n : Real) ≤ (bound : Real) := by - exact (div_le_iff₀ hpos).2 hsum_le' - have habs_mean : - |(∑ i : Fin n, x i) / (n : Real)| ≤ (bound : Real) := by - simpa [abs_div, abs_of_nonneg (le_of_lt hpos)] using hdiv - simpa [meanReal_def x hne] using habs_mean +lemma dyadic_nat_cast_nonneg (n : Nat) : (0 : Dyadic) ≤ (n : Dyadic) := by + simp -/-! Square-root bounds. -/ +lemma dyadic_nat_cast_pos {n : Nat} (h : 0 < n) : (0 : Dyadic) < (n : Dyadic) := by + exact (Nat.cast_pos (α := Dyadic)).2 h /-- Base rational lower bound for a square root. -/ -def sqrtLowerBase (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den +def sqrtLowerBase (q : Dyadic) : Dyadic := + let num := q.toRat.num.natAbs + let den := q.toRat.den let a := Nat.sqrt num let b := Nat.sqrt den - (a : Rat) / (b + 1 : Rat) + dyadicOfRatDown ((a : Rat) / (b + 1)) /-- Base rational upper bound for a square root. 
-/ def sqrtUpperBase (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den +def sqrtUpperBase (q : Dyadic) : Dyadic := + let num := q.toRat.num.natAbs + let den := q.toRat.den let a := Nat.sqrt num let b := Nat.sqrt den - (a + 1 : Rat) / (b : Rat) + dyadicOfRatUp ((a + 1 : Rat) / b) /-- Alternate rational lower bound for a square root. -/ -def sqrtLowerAlt (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den +def sqrtLowerAlt (q : Dyadic) : Dyadic := + let num := q.toRat.num.natAbs + let den := q.toRat.den let a := Nat.sqrt (num * den) - (a : Rat) / den + dyadicOfRatDown ((a : Rat) / den) /-- Alternate rational upper bound for a square root. -/ -def sqrtUpperAlt (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den +def sqrtUpperAlt (q : Dyadic) : Dyadic := + let num := q.toRat.num.natAbs + let den := q.toRat.den let a := Nat.sqrt (num * den) - (a + 1 : Rat) / den + dyadicOfRatUp ((a + 1 : Rat) / den) -/-- Rational lower bound for a square root (tighter of two bounds). -/ -def sqrtLower (q : Rat) : Rat := +/-- Dyadic lower bound for a square root (tighter of two bounds). -/ +def sqrtLower (q : Dyadic) : Dyadic := max (sqrtLowerBase q) (sqrtLowerAlt q) -/-- Rational upper bound for a square root (tighter of two bounds). -/ -def sqrtUpper (q : Rat) : Rat := +/-- Dyadic upper bound for a square root (tighter of two bounds). -/ +def sqrtUpper (q : Dyadic) : Dyadic := min (sqrtUpperBase q) (sqrtUpperAlt q) /-- `sqrtLowerBase` is nonnegative. 
-/ -theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by +theorem sqrtLowerBase_nonneg (q : Dyadic) : 0 ≤ sqrtLowerBase q := by classical unfold sqrtLowerBase - have hden : 0 ≤ (Nat.sqrt q.den + 1 : Rat) := by - exact_mod_cast (Nat.zero_le _) - have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) := by - exact_mod_cast (Nat.zero_le _) - exact div_nonneg hnum hden + have hnum : 0 ≤ (Nat.sqrt q.toRat.num.natAbs : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt q.toRat.num.natAbs)) + have hden : 0 ≤ (Nat.sqrt q.toRat.den + 1 : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt q.toRat.den + 1)) + have hrat : 0 ≤ (Nat.sqrt q.toRat.num.natAbs : Rat) / (Nat.sqrt q.toRat.den + 1) := by + exact div_nonneg hnum hden + exact dyadicOfRatDown_nonneg hrat /-! Strict positivity helpers. -/ /-! Base bounds. -/ -/-- `sqrtLowerBase` is positive when its input is positive. -/ -theorem sqrtLowerBase_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLowerBase q := by - classical - unfold sqrtLowerBase - have hnum_pos : 0 < (Nat.sqrt q.num.natAbs : Rat) := by - have hnum_pos' : 0 < q.num.natAbs := by - have hnum : 0 < q.num := (Rat.num_pos (a := q)).2 hq - exact Int.natAbs_pos.mpr hnum.ne' - exact_mod_cast (Nat.sqrt_pos.2 hnum_pos') - have hden_pos : 0 < (Nat.sqrt q.den + 1 : Rat) := by - exact_mod_cast (Nat.succ_pos _) - exact div_pos hnum_pos hden_pos /-- `sqrtUpperBase` is nonnegative. 
-/ -theorem sqrtUpperBase_nonneg (q : Rat) : 0 ≤ sqrtUpperBase q := by +theorem sqrtUpperBase_nonneg (q : Dyadic) : 0 ≤ sqrtUpperBase q := by classical unfold sqrtUpperBase - have hden : 0 ≤ (Nat.sqrt q.den : Rat) := by - exact_mod_cast (Nat.zero_le _) - have hnum : 0 ≤ (Nat.sqrt q.num.natAbs + 1 : Rat) := by - exact_mod_cast (Nat.zero_le _) - exact div_nonneg hnum hden + have hnum : 0 ≤ (Nat.sqrt q.toRat.num.natAbs + 1 : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt q.toRat.num.natAbs + 1)) + have hden : 0 ≤ (Nat.sqrt q.toRat.den : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt q.toRat.den)) + have hrat : + 0 ≤ (Nat.sqrt q.toRat.num.natAbs + 1 : Rat) / (Nat.sqrt q.toRat.den) := by + exact div_nonneg hnum hden + exact dyadicOfRatUp_nonneg hrat /-- `sqrtUpperBase` is always positive. -/ -theorem sqrtUpperBase_pos (q : Rat) : 0 < sqrtUpperBase q := by +theorem sqrtUpperBase_pos (q : Dyadic) : 0 < sqrtUpperBase q := by classical unfold sqrtUpperBase - have hnum_pos : 0 < (Nat.sqrt q.num.natAbs + 1 : Rat) := by - exact_mod_cast (Nat.succ_pos _) - have hden_pos : 0 < (Nat.sqrt q.den : Rat) := by - have hden : 0 < q.den := q.den_pos + have hnum_pos : (0 : Rat) < (Nat.sqrt q.toRat.num.natAbs + 1 : Rat) := by + exact_mod_cast (Nat.succ_pos (Nat.sqrt q.toRat.num.natAbs)) + have hden_pos : (0 : Rat) < (Nat.sqrt q.toRat.den : Rat) := by + have hden : 0 < q.toRat.den := q.toRat.den_pos exact_mod_cast (Nat.sqrt_pos.2 hden) - exact div_pos hnum_pos hden_pos + have hrat_pos : + (0 : Rat) < (Nat.sqrt q.toRat.num.natAbs + 1 : Rat) / (Nat.sqrt q.toRat.den) := by + exact div_pos hnum_pos hden_pos + exact dyadicOfRatUp_pos hrat_pos /-! Alternate bounds. -/ /-- `sqrtLowerAlt` is nonnegative. 
-/ -theorem sqrtLowerAlt_nonneg (q : Rat) : 0 ≤ sqrtLowerAlt q := by +theorem sqrtLowerAlt_nonneg (q : Dyadic) : 0 ≤ sqrtLowerAlt q := by classical unfold sqrtLowerAlt - have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) := by - exact_mod_cast (Nat.zero_le _) - have hden : 0 ≤ (q.den : Rat) := by - exact_mod_cast (Nat.zero_le _) - exact div_nonneg hnum hden + have hnum : 0 ≤ (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den))) + have hden : 0 ≤ (q.toRat.den : Rat) := by + exact_mod_cast (Nat.zero_le q.toRat.den) + have hrat : + 0 ≤ (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) : Rat) / q.toRat.den := by + exact div_nonneg hnum hden + exact dyadicOfRatDown_nonneg hrat -/-- `sqrtLowerAlt` is positive when its input is positive. -/ -theorem sqrtLowerAlt_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLowerAlt q := by - classical - unfold sqrtLowerAlt - have hnum_pos : 0 < (Nat.sqrt (q.num.natAbs * q.den) : Rat) := by - have hnum_pos' : 0 < q.num.natAbs := by - have hnum : 0 < q.num := (Rat.num_pos (a := q)).2 hq - exact Int.natAbs_pos.mpr hnum.ne' - have hden_pos : 0 < q.den := q.den_pos - have hmul_pos : 0 < q.num.natAbs * q.den := by - exact Nat.mul_pos hnum_pos' hden_pos - exact_mod_cast (Nat.sqrt_pos.2 hmul_pos) - have hden_pos : 0 < (q.den : Rat) := by - exact_mod_cast q.den_pos - exact div_pos hnum_pos hden_pos /-- `sqrtUpperAlt` is nonnegative. 
-/ -theorem sqrtUpperAlt_nonneg (q : Rat) : 0 ≤ sqrtUpperAlt q := by +theorem sqrtUpperAlt_nonneg (q : Dyadic) : 0 ≤ sqrtUpperAlt q := by classical unfold sqrtUpperAlt - have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) := by - exact_mod_cast (Nat.zero_le _) - have hden : 0 ≤ (q.den : Rat) := by - exact_mod_cast (Nat.zero_le _) - exact div_nonneg hnum hden + have hnum : 0 ≤ (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1 : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1)) + have hden : 0 ≤ (q.toRat.den : Rat) := by + exact_mod_cast (Nat.zero_le q.toRat.den) + have hrat : + 0 ≤ (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1 : Rat) / q.toRat.den := by + exact div_nonneg hnum hden + exact dyadicOfRatUp_nonneg hrat /-- `sqrtUpperAlt` is always positive. -/ -theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by +theorem sqrtUpperAlt_pos (q : Dyadic) : 0 < sqrtUpperAlt q := by classical unfold sqrtUpperAlt - have hnum_pos : 0 < (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) := by - exact_mod_cast (Nat.succ_pos _) - have hden_pos : 0 < (q.den : Rat) := by - exact_mod_cast q.den_pos - exact div_pos hnum_pos hden_pos + have hnum_pos : + (0 : Rat) < (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1 : Rat) := by + exact_mod_cast (Nat.succ_pos (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den))) + have hden_pos : (0 : Rat) < (q.toRat.den : Rat) := by + exact_mod_cast q.toRat.den_pos + have hrat_pos : + (0 : Rat) < + (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1 : Rat) / q.toRat.den := by + exact div_pos hnum_pos hden_pos + exact dyadicOfRatUp_pos hrat_pos /-! Combined bounds. -/ /-- `sqrtLower` is nonnegative. -/ -theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by +theorem sqrtLower_nonneg (q : Dyadic) : 0 ≤ sqrtLower q := by have hbase : 0 ≤ sqrtLowerBase q := sqrtLowerBase_nonneg q exact le_trans hbase (le_max_left _ _) -/-- `sqrtLower` is positive when its input is positive. 
-/ -theorem sqrtLower_pos {q : Rat} (hq : 0 < q) : 0 < sqrtLower q := by - have hbase : 0 < sqrtLowerBase q := sqrtLowerBase_pos hq - exact lt_of_lt_of_le hbase (le_max_left _ _) /-- `sqrtUpper` is nonnegative. -/ -theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by +theorem sqrtUpper_nonneg (q : Dyadic) : 0 ≤ sqrtUpper q := by have hbase : 0 ≤ sqrtUpperBase q := sqrtUpperBase_nonneg q have halt : 0 ≤ sqrtUpperAlt q := sqrtUpperAlt_nonneg q exact le_min hbase halt /-- `sqrtUpper` is always positive. -/ -theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by +theorem sqrtUpper_pos (q : Dyadic) : 0 < sqrtUpper q := by have hbase : 0 < sqrtUpperBase q := sqrtUpperBase_pos q have halt : 0 < sqrtUpperAlt q := sqrtUpperAlt_pos q exact lt_min hbase halt /-- Square-root lower bound in reals. -/ -theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : +theorem sqrtLowerBase_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : (sqrtLowerBase q : Real) ≤ Real.sqrt (q : Real) := by classical -- Set up numerator/denominator witnesses. 
- set num : Nat := q.num.natAbs - set den : Nat := q.den + set num : Nat := q.toRat.num.natAbs + set den : Nat := q.toRat.den set a : Nat := Nat.sqrt num set b : Nat := Nat.sqrt den have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos + exact_mod_cast q.toRat.den_pos have hbpos : 0 < (b + 1 : Real) := by exact_mod_cast (Nat.succ_pos b) have hnum_le : (a ^ 2 : Real) ≤ num := by @@ -413,14 +210,16 @@ theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : have hpow : ((a : Real) / (b + 1 : Real)) ^ 2 = (a ^ 2 : Real) / (b + 1) ^ 2 := by simp [pow_two, div_mul_div_comm] have hq_cast : (q : Real) = (num : Real) / den := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by + have hnum_nonneg : 0 ≤ q.toRat.num := by + have hq' : (0 : Rat) ≤ q.toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := q)).2 hq + exact (Rat.num_nonneg (q := q.toRat)).2 hq' + have hnum_eq : (num : Int) = q.toRat.num := by simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by + have hnum_cast : (q.toRat.num : Real) = (num : Real) := by exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den := by - simp [Rat.cast_def] + have hq_rat : (q : Real) = (q.toRat.num : Real) / q.toRat.den := by + simp [dyadicToReal, Rat.cast_def] simpa [hnum_cast, den] using hq_rat have hsq : ((a : Real) / (b + 1 : Real)) ^ 2 ≤ (q : Real) := by simpa [hpow, hq_cast, den, num] using hdiv @@ -428,24 +227,32 @@ theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) have hden_nonneg : 0 ≤ (b + 1 : Real) := by exact_mod_cast (Nat.zero_le (b + 1)) exact div_nonneg hnum_nonneg hden_nonneg - have hq_nonneg : 0 ≤ (q : Real) := by exact_mod_cast hq + have hq_nonneg : 0 ≤ (q : Real) := by + exact dyadicToReal_nonneg_of_nonneg hq have hle : (a : Real) / (b + 1 : Real) ≤ Real.sqrt (q : Real) := 
(Real.le_sqrt hnonneg hq_nonneg).2 hsq - simpa [sqrtLowerBase, num, den, a, b] using hle + have hdown : + (sqrtLowerBase q : Real) ≤ (a : Real) / (b + 1 : Real) := by + have hdown' : + dyadicToReal (dyadicOfRatDown ((a : Rat) / (b + 1))) ≤ + (a : Real) / (b + 1 : Real) := by + simpa using dyadicOfRatDown_le_real ((a : Rat) / (b + 1)) + simpa [sqrtLowerBase, num, den, a, b] using hdown' + exact le_trans hdown hle /-- Square-root upper bound in reals. -/ -theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : +theorem real_sqrt_le_sqrtUpperBase {q : Dyadic} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpperBase q : Real) := by classical - set num : Nat := q.num.natAbs - set den : Nat := q.den + set num : Nat := q.toRat.num.natAbs + set den : Nat := q.toRat.den set a : Nat := Nat.sqrt num set b : Nat := Nat.sqrt den have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos + exact_mod_cast q.toRat.den_pos have hbpos : 0 < (b : Real) := by have hb : 0 < b := by - have hden : 0 < den := q.den_pos + have hden : 0 < den := q.toRat.den_pos exact (Nat.sqrt_pos).2 hden exact_mod_cast hb have hnum_lt : (num : Real) < (a + 1) ^ 2 := by @@ -465,14 +272,16 @@ theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : have hpow : ((a + 1 : Real) / (b : Real)) ^ 2 = (a + 1) ^ 2 / (b : Real) ^ 2 := by simp [pow_two, div_mul_div_comm] have hq_cast : (q : Real) = (num : Real) / den := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by + have hnum_nonneg : 0 ≤ q.toRat.num := by + have hq' : (0 : Rat) ≤ q.toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := q)).2 hq + exact (Rat.num_nonneg (q := q.toRat)).2 hq' + have hnum_eq : (num : Int) = q.toRat.num := by simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by + have hnum_cast : (q.toRat.num : Real) = (num : Real) := by exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den 
:= by - simp [Rat.cast_def] + have hq_rat : (q : Real) = (q.toRat.num : Real) / q.toRat.den := by + simp [dyadicToReal, Rat.cast_def] simpa [hnum_cast, den] using hq_rat have hsq : (q : Real) ≤ ((a + 1 : Real) / (b : Real)) ^ 2 := by simpa [hpow, hq_cast, den, num] using hdiv @@ -482,17 +291,24 @@ theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : exact div_nonneg hnum_nonneg hden_nonneg have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (b : Real) := (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ - simpa [sqrtUpperBase, num, den, a, b] using hle + have hup : + (a + 1 : Real) / (b : Real) ≤ (sqrtUpperBase q : Real) := by + have hup' : + (a + 1 : Real) / (b : Real) ≤ + dyadicToReal (dyadicOfRatUp ((a + 1 : Rat) / b)) := by + simpa using real_le_dyadicOfRatUp ((a + 1 : Rat) / b) + simpa [sqrtUpperBase, num, den, a, b] using hup' + exact le_trans hle hup /-- Alternate square-root lower bound in reals. -/ -theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : +theorem sqrtLowerAlt_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : (sqrtLowerAlt q : Real) ≤ Real.sqrt (q : Real) := by classical - set num : Nat := q.num.natAbs - set den : Nat := q.den + set num : Nat := q.toRat.num.natAbs + set den : Nat := q.toRat.den set a : Nat := Nat.sqrt (num * den) have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos + exact_mod_cast q.toRat.den_pos have hnumden_le : (a ^ 2 : Real) ≤ (num * den : Nat) := by exact_mod_cast (Nat.sqrt_le' (num * den)) have hmul : (a ^ 2 : Real) ≤ (num : Real) * den := by @@ -508,16 +324,18 @@ theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' have hden_ne : (den : Real) ≠ 0 := by - exact_mod_cast q.den_pos.ne' + exact_mod_cast q.toRat.den_pos.ne' have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by + 
have hnum_nonneg : 0 ≤ q.toRat.num := by + have hq' : (0 : Rat) ≤ q.toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := q)).2 hq + exact (Rat.num_nonneg (q := q.toRat)).2 hq' + have hnum_eq : (num : Int) = q.toRat.num := by simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by + have hnum_cast : (q.toRat.num : Real) = (num : Real) := by exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den := by - simp [Rat.cast_def] + have hq_rat : (q : Real) = (q.toRat.num : Real) / q.toRat.den := by + simp [dyadicToReal, Rat.cast_def] have hq_eq : (num : Real) / den = (num : Real) * den / (den : Real) ^ 2 := by field_simp [hden_ne] @@ -528,20 +346,28 @@ theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) exact div_nonneg hnum_nonneg hden_nonneg - have hq_nonneg : 0 ≤ (q : Real) := by exact_mod_cast hq + have hq_nonneg : 0 ≤ (q : Real) := by + exact dyadicToReal_nonneg_of_nonneg hq have hle : (a : Real) / (den : Real) ≤ Real.sqrt (q : Real) := (Real.le_sqrt hnonneg hq_nonneg).2 hsq - simpa [sqrtLowerAlt, num, den, a] using hle + have hdown : + (sqrtLowerAlt q : Real) ≤ (a : Real) / (den : Real) := by + have hdown' : + dyadicToReal (dyadicOfRatDown ((a : Rat) / den)) ≤ + (a : Real) / (den : Real) := by + simpa using dyadicOfRatDown_le_real ((a : Rat) / den) + simpa [sqrtLowerAlt, num, den, a] using hdown' + exact le_trans hdown hle /-- Alternate square-root upper bound in reals. 
-/ -theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : +theorem real_sqrt_le_sqrtUpperAlt {q : Dyadic} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpperAlt q : Real) := by classical - set num : Nat := q.num.natAbs - set den : Nat := q.den + set num : Nat := q.toRat.num.natAbs + set den : Nat := q.toRat.den set a : Nat := Nat.sqrt (num * den) have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos + exact_mod_cast q.toRat.den_pos have hnumden_lt : (num * den : Real) < (a + 1) ^ 2 := by exact_mod_cast (Nat.lt_succ_sqrt' (num * den)) have hmul : (num : Real) * den ≤ (a + 1 : Real) ^ 2 := by @@ -557,16 +383,18 @@ theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' have hden_ne : (den : Real) ≠ 0 := by - exact_mod_cast q.den_pos.ne' + exact_mod_cast q.toRat.den_pos.ne' have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by + have hnum_nonneg : 0 ≤ q.toRat.num := by + have hq' : (0 : Rat) ≤ q.toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := q)).2 hq + exact (Rat.num_nonneg (q := q.toRat)).2 hq' + have hnum_eq : (num : Int) = q.toRat.num := by simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by + have hnum_cast : (q.toRat.num : Real) = (num : Real) := by exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den := by - simp [Rat.cast_def] + have hq_rat : (q : Real) = (q.toRat.num : Real) / q.toRat.den := by + simp [dyadicToReal, Rat.cast_def] have hq_eq : (num : Real) / den = (num : Real) * den / (den : Real) ^ 2 := by field_simp [hden_ne] @@ -583,31 +411,38 @@ theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : exact div_nonneg hnum_nonneg hden_nonneg have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (den : Real) := 
(Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ - simpa [sqrtUpperAlt, num, den, a] using hle + have hup : + (a + 1 : Real) / (den : Real) ≤ (sqrtUpperAlt q : Real) := by + have hup' : + (a + 1 : Real) / (den : Real) ≤ + dyadicToReal (dyadicOfRatUp ((a + 1 : Rat) / den)) := by + simpa using real_le_dyadicOfRatUp ((a + 1 : Rat) / den) + simpa [sqrtUpperAlt, num, den, a] using hup' + exact le_trans hle hup /-- Square-root lower bound in reals (tighter of two bounds). -/ -theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : +theorem sqrtLower_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq simpa [sqrtLower] using (max_le_iff).2 ⟨hbase, halt⟩ /-- Square-root upper bound in reals (tighter of two bounds). -/ -theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : +theorem real_sqrt_le_sqrtUpper {q : Dyadic} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq simpa [sqrtUpper] using (le_min_iff).2 ⟨hbase, halt⟩ /-- Bounds for multiplying a scalar by a bounded value. -/ -def scaleInterval (x lo hi : Rat) : Rat × Rat := +def scaleInterval (x lo hi : Dyadic) : Dyadic × Dyadic := if 0 ≤ x then (x * lo, x * hi) else (x * hi, x * lo) /-- `scaleInterval` bounds a product. -/ -theorem scaleInterval_bounds {x lo hi y : Rat} +theorem scaleInterval_bounds {x lo hi y : Dyadic} (hlo : lo ≤ y) (hhi : y ≤ hi) : let bounds := scaleInterval x lo hi bounds.1 ≤ x * y ∧ x * y ≤ bounds.2 := by @@ -625,37 +460,41 @@ theorem scaleInterval_bounds {x lo hi y : Rat} simp [scaleInterval, hx, h1, h2] /-- `scaleInterval` bounds interpreted in the reals. 
-/ -theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} +theorem scaleInterval_bounds_real {x lo hi : Dyadic} {y : Real} (hlo : (lo : Real) ≤ y) (hhi : y ≤ (hi : Real)) : let bounds := scaleInterval x lo hi (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by by_cases hx : 0 ≤ x · have h1 : (x : Real) * (lo : Real) ≤ (x : Real) * y := by - exact mul_le_mul_of_nonneg_left hlo (by exact_mod_cast hx) + have hx' : 0 ≤ (x : Real) := dyadicToReal_nonneg_of_nonneg hx + exact mul_le_mul_of_nonneg_left hlo hx' have h2 : (x : Real) * y ≤ (x : Real) * (hi : Real) := by - exact mul_le_mul_of_nonneg_left hhi (by exact_mod_cast hx) + have hx' : 0 ≤ (x : Real) := dyadicToReal_nonneg_of_nonneg hx + exact mul_le_mul_of_nonneg_left hhi hx' simp [scaleInterval, hx, h1, h2] · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) have h1 : (x : Real) * (hi : Real) ≤ (x : Real) * y := by - exact mul_le_mul_of_nonpos_left hhi (by exact_mod_cast hx') + have hx'' : (x : Real) ≤ 0 := (dyadicToReal_nonpos_iff (x := x)).2 hx' + exact mul_le_mul_of_nonpos_left hhi hx'' have h2 : (x : Real) * y ≤ (x : Real) * (lo : Real) := by - exact mul_le_mul_of_nonpos_left hlo (by exact_mod_cast hx') + have hx'' : (x : Real) ≤ 0 := (dyadicToReal_nonpos_iff (x := x)).2 hx' + exact mul_le_mul_of_nonpos_left hlo hx'' simp [scaleInterval, hx, h1, h2] /-- Real-valued LayerNorm output for a vector. -/ noncomputable def layerNormReal {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : Fin n → Real := + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (x : Fin n → Dyadic) : Fin n → Real := if n = 0 then fun _ => 0 else - let μ : Real := mean x - let varEps : Real := (variance x + eps : Rat) + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) let invStd : Real := (Real.sqrt varEps)⁻¹ fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) /-- Real-valued LayerNorm output for a real vector. 
-/ noncomputable def layerNormRealOfReal {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Real) : Fin n → Real := + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (x : Fin n → Real) : Fin n → Real := if n = 0 then fun _ => 0 else @@ -666,144 +505,151 @@ noncomputable def layerNormRealOfReal {n : Nat} /-- Interval bounds for LayerNorm outputs. -/ def layerNormBounds {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : - (Fin n → Rat) × (Fin n → Rat) := + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (x : Fin n → Dyadic) : + (Fin n → Dyadic) × (Fin n → Dyadic) := if n = 0 then (fun _ => 0, fun _ => 0) else - let μ := mean x - let var := variance x - let varEps := var + eps - let sLo := sqrtLower varEps - let sHi := sqrtUpper varEps - let invLo := sHi⁻¹ - let invHi := sLo⁻¹ - let normBounds : Fin n → Rat × Rat := fun i => - let centered := x i - μ - scaleInterval centered invLo invHi - let outBounds : Fin n → Rat × Rat := fun i => - let nb := normBounds i - let sb := scaleInterval (gamma i) nb.1 nb.2 - (sb.1 + beta i, sb.2 + beta i) - (fun i => (outBounds i).1, fun i => (outBounds i).2) + let μLo := mean x + let μHi := meanUpper x + let centeredBound : Fin n → Dyadic := fun i => + max |x i - μHi| |x i - μLo| + let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let radius : Fin n → Dyadic := fun i => |gamma i| * centeredBound i * invStdBound + (fun i => beta i - radius i, fun i => beta i + radius i) /-- `layerNormBounds` soundness for real LayerNorm outputs. 
-/ theorem layerNormBounds_spec {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) - (hne : n ≠ 0) (heps : 0 < eps) : + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (x : Fin n → Dyadic) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := layerNormBounds eps gamma beta x ∀ i, (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by classical intro bounds i - have hvar_nonneg : 0 ≤ variance x := variance_nonneg x hne - have hvarEps_pos : 0 < variance x + eps := by - exact add_pos_of_nonneg_of_pos hvar_nonneg heps - have hvarEps_nonneg : 0 ≤ variance x + eps := by - exact le_of_lt hvarEps_pos - let varEps : Rat := variance x + eps - let sLo : Rat := sqrtLower varEps - let sHi : Rat := sqrtUpper varEps - let invLo : Rat := sHi⁻¹ - let invHi : Rat := sLo⁻¹ - let invStd : Real := (Real.sqrt (varEps : Real))⁻¹ - have hsLo : (sLo : Real) ≤ Real.sqrt (varEps : Real) := by - have hsLo' := sqrtLower_le_real_sqrt (q := varEps) hvarEps_nonneg - simpa [sLo, varEps, Rat.cast_add] using hsLo' - have hsHi : Real.sqrt (varEps : Real) ≤ (sHi : Real) := by - have hsHi' := real_sqrt_le_sqrtUpper (q := varEps) hvarEps_nonneg - simpa [sHi, varEps, Rat.cast_add] using hsHi' - have hsqrt_pos : 0 < Real.sqrt (varEps : Real) := by - exact Real.sqrt_pos.2 (by exact_mod_cast hvarEps_pos) - have hsLo_pos : 0 < (sLo : Real) := by - exact_mod_cast (sqrtLower_pos (q := varEps) hvarEps_pos) - have hsHi_ne : (sHi : Rat) ≠ 0 := ne_of_gt (sqrtUpper_pos varEps) - have hsLo_ne : (sLo : Rat) ≠ 0 := ne_of_gt (sqrtLower_pos (q := varEps) hvarEps_pos) - have hcast_inv_hi : (invLo : Real) = (sHi : Real)⁻¹ := by - have hnum_ne : (sHi.num : Real) ≠ 0 := by - exact_mod_cast (Rat.num_ne_zero (q := sHi)).2 hsHi_ne - have hcast := Rat.cast_inv_of_ne_zero (q := sHi) hnum_ne - dsimp [invLo] - exact hcast - have hcast_inv_lo : (invHi : Real) = (sLo : Real)⁻¹ := by - have hnum_ne : (sLo.num : Real) ≠ 0 := by - 
exact_mod_cast (Rat.num_ne_zero (q := sLo)).2 hsLo_ne - have hcast := Rat.cast_inv_of_ne_zero (q := sLo) hnum_ne - dsimp [invHi] - exact hcast - have hinv_lo : (invLo : Real) ≤ invStd := by - have hcalc : (sHi : Real)⁻¹ ≤ invStd := by - have h := one_div_le_one_div_of_le hsqrt_pos hsHi - simpa [one_div, invStd] using h - simpa [hcast_inv_hi] using hcalc - have hinv_hi : invStd ≤ (invHi : Real) := by - have hcalc : invStd ≤ (sLo : Real)⁻¹ := by - have h := one_div_le_one_div_of_le hsLo_pos hsLo - simpa [one_div, invStd] using h - simpa [hcast_inv_lo] using hcalc - let μ : Rat := mean x - let centered : Rat := x i - μ - let nb : Rat × Rat := scaleInterval centered invLo invHi - have hnb : (nb.1 : Real) ≤ (centered : Real) * invStd ∧ - (centered : Real) * invStd ≤ (nb.2 : Real) := by - have hscale := scaleInterval_bounds_real (x := centered) - (lo := invLo) (hi := invHi) (y := invStd) hinv_lo hinv_hi - simpa [nb] using hscale - let sb : Rat × Rat := scaleInterval (gamma i) nb.1 nb.2 - have hsb : - (sb.1 : Real) ≤ (gamma i : Real) * ((centered : Real) * invStd) ∧ - (gamma i : Real) * ((centered : Real) * invStd) ≤ (sb.2 : Real) := by - have hscale := scaleInterval_bounds_real (x := gamma i) - (lo := nb.1) (hi := nb.2) (y := (centered : Real) * invStd) hnb.1 hnb.2 - simpa [sb] using hscale - let lo : Rat := sb.1 + beta i - let hi : Rat := sb.2 + beta i + let μLo : Dyadic := mean x + let μHi : Dyadic := meanUpper x + let centeredBound : Fin n → Dyadic := fun j => max |x j - μHi| |x j - μLo| + let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let μ : Real := meanRat x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by + have h0 : 0 ≤ centeredBound i := by + dsimp [centeredBound] + exact le_trans (abs_nonneg _) (le_max_left _ _) + exact dyadicToReal_nonneg_of_nonneg h0 + have hmean_lo_real : (μLo : Real) ≤ μ := by + have h := dyadicOfRatDown_le_real 
(meanRat x) + simpa [μLo, μ, mean_def x hne] using h + have hmean_hi_real : μ ≤ (μHi : Real) := by + have h := real_le_dyadicOfRatUp (meanRat x) + simpa [μHi, μ, meanUpper_def x hne] using h + have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by + have hlo : (x i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by + exact sub_le_sub_left hmean_hi_real (x i : Real) + have hhi : (x i : Real) - μ ≤ (x i : Real) - (μLo : Real) := by + exact sub_le_sub_left hmean_lo_real (x i : Real) + have hbound := abs_le_max_of_bounds hlo hhi + simpa [centeredBound, μLo, μHi, dyadicToReal_abs, dyadicToReal_sub, + dyadicToReal_max] using hbound + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + simpa [dyadicToReal_zero] using + (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ (invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound + have hinv_nonneg : 0 ≤ invStd := by + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + 
exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |(x i : Real) - μ| * invStd ≤ + (centeredBound i : Real) * (invStdBound : Real) := by + have hleft : + |(x i : Real) - μ| * invStd ≤ (centeredBound i : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound j * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h := add_le_add_left hbounds.1 (beta i : Real) + simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h := add_le_add_left hbounds.2 (beta i : Real) + simpa [add_comm, add_left_comm, add_assoc] using h have hreal : - layerNormReal eps gamma 
beta x i = - (gamma i : Real) * ((centered : Real) * invStd) + (beta i : Real) := by - calc - layerNormReal eps gamma beta x i = - (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) := by - simp [layerNormReal, hne, μ, invStd, varEps] - _ = (gamma i : Real) * (((x i : Real) - μ) * invStd) + (beta i : Real) := by - simp [mul_assoc] - _ = (gamma i : Real) * ((centered : Real) * invStd) + (beta i : Real) := by - simp [centered] - have hlo : (lo : Real) ≤ layerNormReal eps gamma beta x i := by - have hlo' : (sb.1 : Real) ≤ (gamma i : Real) * ((centered : Real) * invStd) := hsb.1 - have hlo'' : (lo : Real) ≤ - (gamma i : Real) * ((centered : Real) * invStd) + (beta i : Real) := by - simpa [lo] using add_le_add_right hlo' (beta i : Real) - simpa [hreal] using hlo'' - have hhi : layerNormReal eps gamma beta x i ≤ (hi : Real) := by - have hhi' : (gamma i : Real) * ((centered : Real) * invStd) ≤ (sb.2 : Real) := hsb.2 - have hhi'' : - (gamma i : Real) * ((centered : Real) * invStd) + (beta i : Real) ≤ (hi : Real) := by - simpa [hi] using add_le_add_right hhi' (beta i : Real) - simpa [hreal] using hhi'' - simpa [bounds, layerNormBounds, hne, μ, varEps, invLo, invHi, centered, nb, sb, lo, hi] using - And.intro hlo hhi + layerNormReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hlow + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hhigh + exact And.intro hlo hhi /-- Interval bounds for LayerNorm outputs from per-coordinate intervals. 
-/ def layerNormIntervalBounds {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) : - (Fin n → Rat) × (Fin n → Rat) := + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (lo hi : Fin n → Dyadic) : + (Fin n → Dyadic) × (Fin n → Dyadic) := if n = 0 then (fun _ => 0, fun _ => 0) else let μLo := mean lo - let μHi := mean hi - let centeredBound : Fin n → Rat := fun i => + let μHi := meanUpper hi + let centeredBound : Fin n → Dyadic := fun i => max |lo i - μHi| |hi i - μLo| - let invStdBound : Rat := (sqrtLower eps)⁻¹ - let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound + let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let radius : Fin n → Dyadic := fun i => |gamma i| * centeredBound i * invStdBound (fun i => beta i - radius i, fun i => beta i + radius i) /-- `layerNormIntervalBounds` soundness for real LayerNorm outputs. -/ theorem layerNormIntervalBounds_spec {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Rat) - (hne : n ≠ 0) (heps : 0 < eps) + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (lo hi : Fin n → Dyadic) (x : Fin n → Dyadic) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) : let bounds := layerNormIntervalBounds eps gamma beta lo hi ∀ i, @@ -811,63 +657,80 @@ theorem layerNormIntervalBounds_spec {n : Nat} layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by classical intro bounds i - have hmean_lo : mean lo ≤ mean x := mean_le_mean lo x hne hlo - have hmean_hi : mean x ≤ mean hi := mean_le_mean x hi hne hhi - let μLo : Rat := mean lo - let μHi : Rat := mean hi - let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| - let invStdBound : Rat := (sqrtLower eps)⁻¹ - let varEps : Rat := variance x + eps - let μ : Real := mean x - let invStd : Real := (Real.sqrt (varEps : Real))⁻¹ + let μLo : Dyadic := mean lo + let μHi : Dyadic := meanUpper hi + let centeredBound : Fin n → Dyadic := 
fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let μ : Real := meanRat x + let invStd : Real := (Real.sqrt varEps)⁻¹ have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by have h0 : 0 ≤ centeredBound i := by dsimp [centeredBound] exact le_trans (abs_nonneg _) (le_max_left _ _) - exact_mod_cast h0 + exact dyadicToReal_nonneg_of_nonneg h0 have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by have hmean_lo_real : (μLo : Real) ≤ μ := by - have h' : (mean lo : Real) ≤ (mean x : Real) := by - exact_mod_cast hmean_lo - simpa [μLo, μ] using h' + have hmean_rat : (meanRat lo : Real) ≤ (meanRat x : Real) := + meanRat_le_meanRat_real lo x hne hlo + have hdown : (μLo : Real) ≤ (meanRat lo : Real) := by + simpa [μLo, mean_def lo hne] using dyadicOfRatDown_le_real (meanRat lo) + exact le_trans hdown hmean_rat have hmean_hi_real : μ ≤ (μHi : Real) := by - have h' : (mean x : Real) ≤ (mean hi : Real) := by - exact_mod_cast hmean_hi - simpa [μHi, μ] using h' + have hmean_rat : (meanRat x : Real) ≤ (meanRat hi : Real) := + meanRat_le_meanRat_real x hi hne hhi + have hup : (meanRat hi : Real) ≤ (μHi : Real) := by + simpa [μHi, meanUpper_def hi hne] using real_le_dyadicOfRatUp (meanRat hi) + exact le_trans hmean_rat hup have hlo' : (lo i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by exact sub_le_sub_left hmean_hi_real (lo i : Real) have h2 : (lo i : Real) - μ ≤ (x i : Real) - μ := by - exact sub_le_sub_right (by exact_mod_cast (hlo i)) μ + exact sub_le_sub_right + (by + simpa using dyadicToReal_le_of_le (hlo i)) + μ exact le_trans h1 h2 have hhi' : (x i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by have h1 : (x i : Real) - μ ≤ (hi i : Real) - μ := by - exact sub_le_sub_right (by exact_mod_cast (hhi i)) μ + exact sub_le_sub_right + (by + simpa using dyadicToReal_le_of_le (hhi i)) + 
μ have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by exact sub_le_sub_left hmean_lo_real (hi i : Real) exact le_trans h1 h2 have hbound := abs_le_max_of_bounds hlo' hhi' - simpa [centeredBound, μLo, μHi, Rat.cast_abs, Rat.cast_sub, Rat.cast_max] using hbound - have hvar_nonneg : 0 ≤ variance x := variance_nonneg x hne + simpa [centeredBound, μLo, μHi, dyadicToReal_abs, dyadicToReal_sub, + dyadicToReal_max] using hbound have hsqrt_lower : - (sqrtLower eps : Real) ≤ Real.sqrt (varEps : Real) := by + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) simpa using h - have hle : (eps : Real) ≤ (varEps : Real) := by - have hle' : eps ≤ varEps := le_add_of_nonneg_left hvar_nonneg - exact_mod_cast hle' - have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt (varEps : Real) := by + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact_mod_cast (sqrtLower_pos (q := eps) heps) - have hinv : invStd ≤ (invStdBound : Real) := by + simpa [dyadicToReal_zero] using + (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd, invStdBound] using h + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ 
(invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound have hinv_nonneg : 0 ≤ invStd := by - have hsqrt_nonneg : 0 ≤ Real.sqrt (varEps : Real) := by + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by exact Real.sqrt_nonneg _ exact inv_nonneg.2 hsqrt_nonneg have hmul1 : |(x i : Real) - μ| * invStd ≤ @@ -893,7 +756,7 @@ theorem layerNormIntervalBounds_spec {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound + let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound j * invStdBound have ht_abs' : |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by @@ -919,74 +782,86 @@ theorem layerNormIntervalBounds_spec {n : Nat} /-- Interval bounds for LayerNorm outputs from an absolute input bound. -/ def layerNormAbsBounds {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) : - (Fin n → Rat) × (Fin n → Rat) := - let centeredBound : Rat := 2 * absBound - let invStdBound : Rat := (sqrtLower eps)⁻¹ - let radius : Fin n → Rat := fun i => |gamma i| * centeredBound * invStdBound + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (absBound : Dyadic) : + (Fin n → Dyadic) × (Fin n → Dyadic) := + let centeredBound : Dyadic := 2 * absBound + let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let radius : Fin n → Dyadic := fun i => |gamma i| * centeredBound * invStdBound (fun i => beta i - radius i, fun i => beta i + radius i) /-- `layerNormAbsBounds` soundness for real LayerNorm outputs under absolute input bounds. 
-/ theorem layerNormAbsBounds_spec {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Rat) - (hne : n ≠ 0) (heps : 0 < eps) (habs : ∀ i, |x i| ≤ absBound) : + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (absBound : Dyadic) (x : Fin n → Dyadic) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (habs : ∀ i, |x i| ≤ absBound) : let bounds := layerNormAbsBounds eps gamma beta absBound ∀ i, (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by classical intro bounds i - have hmean_abs : |mean x| ≤ absBound := - mean_abs_le_bound x absBound hne habs - have hmean_abs_real : |(mean x : Real)| ≤ (absBound : Real) := by - exact_mod_cast hmean_abs + have hmean_abs_real : |(meanRat x : Real)| ≤ (absBound : Real) := by + have h := + meanReal_abs_le_bound (x := fun j => (x j : Real)) (bound := absBound) hne + (by + intro j + exact dyadicToReal_abs_le_of_le (habs j)) + simpa [meanReal_eq_meanRat] using h have hbound_nonneg : 0 ≤ absBound := by have hposn : 0 < n := Nat.pos_of_ne_zero hne let i0 : Fin n := ⟨0, hposn⟩ have h0 : 0 ≤ |x i0| := abs_nonneg _ exact le_trans h0 (habs i0) - let centeredBound : Rat := 2 * absBound - let invStdBound : Rat := (sqrtLower eps)⁻¹ - let varEps : Rat := variance x + eps - let μ : Real := mean x - let invStd : Real := (Real.sqrt (varEps : Real))⁻¹ + let centeredBound : Dyadic := 2 * absBound + let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let μ : Real := meanRat x + let invStd : Real := (Real.sqrt varEps)⁻¹ have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound : Real) := by have h1 : |(x i : Real) - μ| ≤ |(x i : Real)| + |μ| := by simpa [sub_eq_add_neg, abs_neg] using abs_add_le (x i : Real) (-μ) have hx : |(x i : Real)| ≤ (absBound : Real) := by - exact_mod_cast (habs i) + exact dyadicToReal_abs_le_of_le (habs i) have hmu : |μ| ≤ (absBound : Real) := by - simpa 
using hmean_abs_real + simpa [μ] using hmean_abs_real have h2 : |(x i : Real)| + |μ| ≤ (absBound : Real) + (absBound : Real) := add_le_add hx hmu have h12 : |(x i : Real) - μ| ≤ (absBound : Real) + (absBound : Real) := le_trans h1 h2 simpa [centeredBound, two_mul] using h12 have hbound_nonneg_real : 0 ≤ (absBound : Real) := by - exact_mod_cast hbound_nonneg + exact dyadicToReal_nonneg_of_nonneg hbound_nonneg have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real simpa [centeredBound, two_mul] using hsum - have hvar_nonneg : 0 ≤ variance x := variance_nonneg x hne have hsqrt_lower : - (sqrtLower eps : Real) ≤ Real.sqrt (varEps : Real) := by + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) simpa using h - have hle : (eps : Real) ≤ (varEps : Real) := by - have hle' : eps ≤ varEps := le_add_of_nonneg_left hvar_nonneg - exact_mod_cast hle' - have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt (varEps : Real) := by + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact_mod_cast (sqrtLower_pos (q := eps) heps) - have hinv : invStd ≤ (invStdBound : Real) := by + simpa [dyadicToReal_zero] using + (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd, invStdBound] using h + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : 
Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ (invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound have hinv_nonneg : 0 ≤ invStd := by - have hsqrt_nonneg : 0 ≤ Real.sqrt (varEps : Real) := by + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by exact Real.sqrt_nonneg _ exact inv_nonneg.2 hsqrt_nonneg have hmul1 : |(x i : Real) - μ| * invStd ≤ @@ -1012,7 +887,7 @@ theorem layerNormAbsBounds_spec {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound + let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound * invStdBound have ht_abs' : |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by @@ -1036,8 +911,9 @@ theorem layerNormAbsBounds_spec {n : Nat} /-- `layerNormAbsBounds` soundness for real LayerNorm outputs on real inputs. 
-/ theorem layerNormAbsBounds_spec_real {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Real) - (hne : n ≠ 0) (heps : 0 < eps) (habs : ∀ i, |x i| ≤ (absBound : Real)) : + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (absBound : Dyadic) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (habs : ∀ i, |x i| ≤ (absBound : Real)) : let bounds := layerNormAbsBounds eps gamma beta absBound ∀ i, (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ @@ -1051,8 +927,8 @@ theorem layerNormAbsBounds_spec_real {n : Nat} let i0 : Fin n := ⟨0, hposn⟩ have h0 : 0 ≤ |x i0| := abs_nonneg _ exact le_trans h0 (habs i0) - let centeredBound : Rat := 2 * absBound - let invStdBound : Rat := (sqrtLower eps)⁻¹ + let centeredBound : Dyadic := 2 * absBound + let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) let varEps : Real := varianceReal x + (eps : Real) let μ : Real := meanReal x let invStd : Real := (Real.sqrt varEps)⁻¹ @@ -1085,10 +961,17 @@ theorem layerNormAbsBounds_spec_real {n : Nat} exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact_mod_cast (sqrtLower_pos (q := eps) heps) - have hinv : invStd ≤ (invStdBound : Real) := by + simpa [dyadicToReal_zero] using + (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd, invStdBound] using h + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ (invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound have hinv_nonneg : 0 ≤ invStd := by have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by exact Real.sqrt_nonneg _ @@ -1116,7 +999,7 @@ 
theorem layerNormAbsBounds_spec_real {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound + let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound * invStdBound have ht_abs' : |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by @@ -1140,8 +1023,8 @@ theorem layerNormAbsBounds_spec_real {n : Nat} /-- `layerNormIntervalBounds` soundness for real LayerNorm outputs on real inputs. -/ theorem layerNormIntervalBounds_spec_real {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Real) - (hne : n ≠ 0) (heps : 0 < eps) + (eps : Dyadic) (gamma beta : Fin n → Dyadic) (lo hi : Fin n → Dyadic) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : let bounds := layerNormIntervalBounds eps gamma beta lo hi ∀ i, @@ -1153,16 +1036,24 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have h := meanReal_le_meanReal (x := fun j => (lo j : Real)) (y := x) hne (fun j => hlo j) - simpa [meanReal_ratCast] using h - have hmean_hi : meanReal x ≤ (mean hi : Real) := by + have hrat : (meanRat lo : Real) ≤ meanReal x := by + simpa [meanReal_eq_meanRat] using h + have hdown : (mean lo : Real) ≤ (meanRat lo : Real) := by + simpa [mean_def lo hne] using dyadicOfRatDown_le_real (meanRat lo) + exact le_trans hdown hrat + have hmean_hi : meanReal x ≤ (meanUpper hi : Real) := by have h := meanReal_le_meanReal (x := x) (y := fun j => (hi j : Real)) hne (fun j => hhi j) - simpa [meanReal_ratCast] using h - let μLo : Rat := mean lo - let μHi : Rat := mean hi - let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| - let invStdBound : Rat := (sqrtLower eps)⁻¹ + have hrat : meanReal x ≤ 
(meanRat hi : Real) := by + simpa [meanReal_eq_meanRat] using h + have hup : (meanRat hi : Real) ≤ (meanUpper hi : Real) := by + simpa [meanUpper_def hi hne] using real_le_dyadicOfRatUp (meanRat hi) + exact le_trans hrat hup + let μLo : Dyadic := mean lo + let μHi : Dyadic := meanUpper hi + let centeredBound : Fin n → Dyadic := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) let varEps : Real := varianceReal x + (eps : Real) let μ : Real := meanReal x let invStd : Real := (Real.sqrt varEps)⁻¹ @@ -1170,7 +1061,7 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have h0 : 0 ≤ centeredBound i := by dsimp [centeredBound] exact le_trans (abs_nonneg _) (le_max_left _ _) - exact_mod_cast h0 + exact dyadicToReal_nonneg_of_nonneg h0 have hcentered_abs : |x i - μ| ≤ (centeredBound i : Real) := by have hmean_lo_real : (μLo : Real) ≤ μ := by simpa [μLo, μ] using hmean_lo @@ -1189,7 +1080,8 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} exact sub_le_sub_left hmean_lo_real (hi i : Real) exact le_trans h1 h2 have hbound := abs_le_max_of_bounds hlo' hhi' - simpa [centeredBound, μLo, μHi, Rat.cast_abs, Rat.cast_sub, Rat.cast_max] using hbound + simpa [centeredBound, μLo, μHi, dyadicToReal_abs, dyadicToReal_sub, + dyadicToReal_max] using hbound have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne have hsqrt_lower : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by @@ -1205,10 +1097,17 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact_mod_cast (sqrtLower_pos (q := eps) heps) - have hinv : invStd ≤ (invStdBound : Real) := by + simpa [dyadicToReal_zero] using + (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd, invStdBound] using h + simpa [invStd] 
using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ (invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound have hinv_nonneg : 0 ≤ invStd := by have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by exact Real.sqrt_nonneg _ @@ -1235,7 +1134,7 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound + let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound j * invStdBound have ht_abs' : |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by diff --git a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean new file mode 100644 index 0000000..1c4112e --- /dev/null +++ b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean @@ -0,0 +1,260 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Fin +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Field.Basic +import Mathlib.Algebra.Order.Ring.Basic +import Mathlib.Data.Rat.BigOperators +import Mathlib.Data.Rat.Cast.Order +import Nfp.Core.Basic + +/-! +Mean/variance helpers for LayerNorm bounds. + +This module isolates the dyadic and real mean/variance definitions and their +basic lemmas to keep `LayerNorm` bounds modular. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +/-- Sum as a rational, used for exact mean/variance computations. 
-/ +def sumRat {n : Nat} (x : Fin n → Dyadic) : Rat := + ∑ i, (x i : Rat) + +/-- Exact mean as a rational (defaults to `0` when `n = 0`). -/ +def meanRat {n : Nat} (x : Fin n → Dyadic) : Rat := + if n = 0 then + 0 + else + (sumRat x) / n + +/-- Mean rounded down to dyadic precision (defaults to `0` when `n = 0`). -/ +def mean {n : Nat} (x : Fin n → Dyadic) : Dyadic := + if n = 0 then + 0 + else + dyadicOfRatDown (meanRat x) + +/-- Mean rounded up to dyadic precision (defaults to `0` when `n = 0`). -/ +def meanUpper {n : Nat} (x : Fin n → Dyadic) : Dyadic := + if n = 0 then + 0 + else + dyadicOfRatUp (meanRat x) + +/-- Unfold `mean` when `n ≠ 0`. -/ +theorem mean_def {n : Nat} (x : Fin n → Dyadic) (h : n ≠ 0) : + mean x = dyadicOfRatDown (meanRat x) := by + simp [mean, h] + +/-- Unfold `meanUpper` when `n ≠ 0`. -/ +theorem meanUpper_def {n : Nat} (x : Fin n → Dyadic) (h : n ≠ 0) : + meanUpper x = dyadicOfRatUp (meanRat x) := by + simp [meanUpper, h] + +/-- Exact variance as a rational (defaults to `0` when `n = 0`). -/ +def varianceRat {n : Nat} (x : Fin n → Dyadic) : Rat := + if n = 0 then + 0 + else + let μ := meanRat x + (∑ i, ((x i : Rat) - μ) ^ 2) / n + +/-- Variance rounded down to dyadic precision (defaults to `0` when `n = 0`). -/ +def variance {n : Nat} (x : Fin n → Dyadic) : Dyadic := + if n = 0 then + 0 + else + dyadicOfRatDown (varianceRat x) + +/-- Variance rounded up to dyadic precision (defaults to `0` when `n = 0`). -/ +def varianceUpper {n : Nat} (x : Fin n → Dyadic) : Dyadic := + if n = 0 then + 0 + else + dyadicOfRatUp (varianceRat x) + +/-- Unfold `variance` when `n ≠ 0`. -/ +theorem variance_def {n : Nat} (x : Fin n → Dyadic) (h : n ≠ 0) : + variance x = dyadicOfRatDown (varianceRat x) := by + simp [variance, h] + +/-! Interval helpers. -/ + +/-- Absolute value bound from endpoint bounds. 
-/ +theorem abs_le_max_of_bounds {α : Type _} [Ring α] [LinearOrder α] [IsOrderedRing α] + {a b z : α} + (hlo : a ≤ z) (hhi : z ≤ b) : + |z| ≤ max |a| |b| := by + have hleft : -max |a| |b| ≤ z := by + have hneg : -max |a| |b| ≤ a := by + have hneg' : -max |a| |b| ≤ -|a| := by + exact neg_le_neg (le_max_left _ _) + have hneg'' : -|a| ≤ a := by + have h : -a ≤ |a| := neg_le_abs a + simpa using (neg_le_neg h) + exact le_trans hneg' hneg'' + exact le_trans hneg hlo + have hright : z ≤ max |a| |b| := by + have hb : b ≤ |b| := by + exact le_abs_self b + have hb' : b ≤ max |a| |b| := le_trans hb (le_max_right _ _) + exact le_trans hhi hb' + exact (abs_le.mpr ⟨hleft, hright⟩) + +/-! Real-valued mean and variance. -/ + +/-- Mean of a real vector (defaults to `0` when `n = 0`). -/ +noncomputable def meanReal {n : Nat} (x : Fin n → Real) : Real := + if n = 0 then + 0 + else + (∑ i, x i) / n + +/-- Unfold `meanReal` when `n ≠ 0`. -/ +theorem meanReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + meanReal x = (∑ i, x i) / n := by + simp [meanReal, h] + +/-- `meanReal` agrees with `mean` after casting. -/ +theorem meanReal_eq_meanRat {n : Nat} (x : Fin n → Dyadic) : + meanReal (fun i => (x i : Real)) = (meanRat x : Real) := by + by_cases h : n = 0 + · simp [meanReal, meanRat, h] + · have hsum : + (sumRat x : Real) = ∑ i, (x i : Real) := by + classical + unfold sumRat + simp [dyadicToReal, Rat.cast_sum] + have hmean : (meanRat x : Real) = (sumRat x : Real) / n := by + simp [meanRat, h] + have hreal : meanReal (fun i => (x i : Real)) = (∑ i, (x i : Real)) / n := by + simp [meanReal, h] + simpa [hmean, hsum] using hreal + +/-- Mean is monotone under pointwise order (real inputs). 
-/ +theorem meanReal_le_meanReal {n : Nat} (x y : Fin n → Real) (hne : n ≠ 0) + (hxy : ∀ i, x i ≤ y i) : meanReal x ≤ meanReal y := by + classical + have hsum : (∑ i, x i) ≤ ∑ i, y i := by + refine Finset.sum_le_sum ?_ + intro i _ + exact hxy i + have hden : 0 ≤ (n : Real) := by + simp + have hdiv : (∑ i, x i) / n ≤ (∑ i, y i) / n := + div_le_div_of_nonneg_right hsum hden + simpa [meanReal, hne] using hdiv + +/-- Mean monotonicity for dyadic inputs, interpreted in reals. -/ +theorem meanRat_le_meanRat_real {n : Nat} (x y : Fin n → Dyadic) (hne : n ≠ 0) + (hxy : ∀ i, x i ≤ y i) : + (meanRat x : Real) ≤ (meanRat y : Real) := by + have hreal : + meanReal (fun i => (x i : Real)) ≤ meanReal (fun i => (y i : Real)) := by + refine meanReal_le_meanReal (x := fun i => (x i : Real)) (y := fun i => (y i : Real)) hne ?_ + intro i + exact dyadicToReal_le_of_le (hxy i) + simpa [meanReal_eq_meanRat] using hreal + +/-- Variance of a real vector (defaults to `0` when `n = 0`). -/ +noncomputable def varianceReal {n : Nat} (x : Fin n → Real) : Real := + if n = 0 then + 0 + else + let μ := meanReal x + (∑ i, (x i - μ) ^ 2) / n + +/-- Unfold `varianceReal` when `n ≠ 0`. -/ +theorem varianceReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + varianceReal x = + let μ := meanReal x + (∑ i, (x i - μ) ^ 2) / n := by + simp [varianceReal, h] + +/-- Variance is nonnegative when `n ≠ 0`. 
-/ +theorem varianceReal_nonneg {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + 0 ≤ varianceReal x := by + classical + have hsum : 0 ≤ ∑ i, (x i - meanReal x) ^ 2 := by + refine Finset.sum_nonneg ?_ + intro i _ + exact sq_nonneg (x i - meanReal x) + have hden : 0 ≤ (n : Real) := by + simp + have hdiv : 0 ≤ (∑ i, (x i - meanReal x) ^ 2) / n := + div_nonneg hsum hden + simpa [varianceReal_def x h] using hdiv + +theorem varianceReal_eq_varianceRat {n : Nat} (x : Fin n → Dyadic) : + varianceReal (fun i => (x i : Real)) = (varianceRat x : Real) := by + by_cases h : n = 0 + · simp [varianceReal, varianceRat, h] + · have hmean := meanReal_eq_meanRat (n := n) x + have hsum : + (∑ i, ((x i : Real) - (meanRat x : Real)) ^ 2) = + (∑ i, ((x i : Rat) - meanRat x) ^ 2 : Rat) := by + classical + simp [dyadicToReal, Rat.cast_sum] + have hreal : varianceReal (fun i => (x i : Real)) = + (∑ i, ((x i : Real) - meanReal (fun j => (x j : Real))) ^ 2) / n := by + simp [varianceReal, h] + have hrat : (varianceRat x : Real) = + (∑ i, ((x i : Rat) - meanRat x) ^ 2 : Rat) / n := by + simp [varianceRat, h] + calc + varianceReal (fun i => (x i : Real)) + = (∑ i, ((x i : Real) - meanReal (fun j => (x j : Real))) ^ 2) / n := hreal + _ = (∑ i, ((x i : Real) - (meanRat x : Real)) ^ 2) / n := by + simp [hmean] + _ = (∑ i, ((x i : Rat) - meanRat x) ^ 2 : Rat) / n := by + simp [hsum] + _ = (varianceRat x : Real) := hrat.symm + +/-- Variance is nonnegative when `n ≠ 0`, interpreted in reals. -/ +theorem varianceRat_nonneg_real {n : Nat} (x : Fin n → Dyadic) (hne : n ≠ 0) : + 0 ≤ (varianceRat x : Real) := by + have hreal := varianceReal_nonneg (x := fun i => (x i : Real)) hne + simpa [varianceReal_eq_varianceRat] using hreal + +/-- Absolute mean bound from per-coordinate bounds (real inputs). 
-/ +theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Dyadic) + (hne : n ≠ 0) (hbound : ∀ i, |x i| ≤ (bound : Real)) : + |meanReal x| ≤ (bound : Real) := by + classical + have hsum_abs : + |∑ i : Fin n, x i| ≤ ∑ i : Fin n, |x i| := by + simpa using + (Finset.abs_sum_le_sum_abs + (f := fun i : Fin n => x i) + (s := (Finset.univ : Finset (Fin n)))) + have hsum_bound : ∑ i : Fin n, |x i| ≤ ∑ i : Fin n, (bound : Real) := by + refine Finset.sum_le_sum ?_ + intro i _ + exact hbound i + have hsum_le : |∑ i : Fin n, x i| ≤ (n : Real) * (bound : Real) := by + have hsum := le_trans hsum_abs hsum_bound + simpa [Finset.sum_const, Finset.card_univ, mul_comm] using hsum + have hpos : 0 < (n : Real) := by + exact (Nat.cast_pos (α := Real)).2 (Nat.pos_of_ne_zero hne) + have hsum_le' : |∑ i : Fin n, x i| ≤ (bound : Real) * (n : Real) := by + simpa [mul_comm] using hsum_le + have hdiv : |∑ i : Fin n, x i| / (n : Real) ≤ (bound : Real) := by + exact (div_le_iff₀ hpos).2 hsum_le' + have habs_mean : + |(∑ i : Fin n, x i) / (n : Real)| ≤ (bound : Real) := by + simpa [abs_div, abs_of_nonneg (le_of_lt hpos)] using hdiv + simpa [meanReal_def x hne] using habs_mean + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index 9f68ecd..0b0cab1 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -3,20 +3,20 @@ import Mathlib.Algebra.BigOperators.Fin import Mathlib.Algebra.Order.BigOperators.Group.Finset import Mathlib.Algebra.Order.Ring.Abs -import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Data.Fintype.Basic import Mathlib.Data.Matrix.Mul -import Mathlib.Data.Rat.BigOperators -import Mathlib.Data.Rat.Cast.Order import Mathlib.Data.Real.Basic import Nfp.Circuit.Cert.DownstreamLinear import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Core.Basic +import Nfp.Sound.Bounds.MatrixNorm.Interval import Nfp.Sound.Linear.FinFold /-! 
Row-sum matrix norms for downstream linear certificates. These bounds are used to compute verified downstream error certificates -from explicit Rat matrices. +from explicit Dyadic matrices. -/ namespace Nfp @@ -27,85 +27,29 @@ namespace Bounds open scoped BigOperators -private theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Rat) : - Linear.sumFin n f = ∑ i, f i := by - classical - have hfold : - Linear.sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by - simpa using Linear.sumFin_eq_list_foldl n f - have hmap : - ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = - (List.finRange n).foldl (fun acc i => acc + f i) 0 := by - have hmap' : - ∀ l : List (Fin n), ∀ init : Rat, - (l.map f).foldl (fun acc x : Rat => acc + x) init = - l.foldl (fun acc i => acc + f i) init := by - intro l - induction l with - | nil => - intro init - simp - | cons a l ih => - intro init - simp [ih] - exact hmap' (List.finRange n) 0 - let _ : Std.Commutative (fun a b : Rat => a + b) := - ⟨by intro a b; exact add_comm _ _⟩ - let _ : Std.Associative (fun a b : Rat => a + b) := - ⟨by intro a b c; exact add_assoc _ _ _⟩ - have hfoldr : - ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = - ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by - simpa using - (List.foldl_eq_foldr (f := fun acc x : Rat => acc + x) - (a := 0) (l := (List.finRange n).map f)) - have hsum_list : - ((List.finRange n).map f).sum = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by - calc - ((List.finRange n).map f).sum - = ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by - rfl - _ = ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 := by - exact hfoldr.symm - _ = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by - exact hmap - have hsum_univ : ((List.finRange n).map f).sum = ∑ i, f i := by - exact (Fin.sum_univ_def f).symm - calc - Linear.sumFin n f - = (List.finRange n).foldl (fun acc i => acc + f i) 0 := hfold - _ = 
((List.finRange n).map f).sum := hsum_list.symm - _ = ∑ i, f i := hsum_univ - /-- Row-sum of absolute values for a matrix row. -/ -def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : Rat := +def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) (i : Fin m) : Dyadic := Linear.sumFin n (fun j => |W i j|) /-- Weighted row-sum using per-coordinate bounds. -/ -def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) : Rat := +def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (bound : Fin n → Dyadic) (i : Fin m) : Dyadic := Linear.sumFin n (fun j => |W i j| * bound j) /-- Maximum row-sum norm (defaults to `0` on empty matrices). -/ -def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := - if h : (Finset.univ : Finset (Fin m)).Nonempty then - (Finset.univ).sup' h (fun i => rowSum W i) - else - 0 +def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) : Dyadic := + Linear.foldlFin m (fun acc i => max acc (rowSum W i)) 0 /-- Maximum weighted row-sum (defaults to `0` on empty matrices). -/ -def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : Rat := - if h : (Finset.univ : Finset (Fin m)).Nonempty then - (Finset.univ).sup' h (fun i => rowSumWeighted W bound i) - else - 0 +def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (bound : Fin n → Dyadic) : Dyadic := + Linear.foldlFin m (fun acc i => max acc (rowSumWeighted W bound i)) 0 /-- Row-sums are nonnegative. 
-/ -theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : +theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) (i : Fin m) : 0 ≤ rowSum W i := by have hsum : rowSum W i = ∑ j, |W i j| := by - simp [rowSum, sumFin_eq_sum_univ] + simp [rowSum, Linear.sumFin_eq_sum_univ] have hnonneg : 0 ≤ ∑ j, |W i j| := by refine Finset.sum_nonneg ?_ intro j _ @@ -113,11 +57,11 @@ theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : simpa [hsum] using hnonneg /-- Weighted row-sums are nonnegative under nonnegative bounds. -/ -theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : +theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (bound : Fin n → Dyadic) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : 0 ≤ rowSumWeighted W bound i := by have hsum : rowSumWeighted W bound i = ∑ j, |W i j| * bound j := by - simp [rowSumWeighted, sumFin_eq_sum_univ] + simp [rowSumWeighted, Linear.sumFin_eq_sum_univ] have hnonneg : 0 ≤ ∑ j, |W i j| * bound j := by refine Finset.sum_nonneg ?_ intro j _ @@ -125,348 +69,45 @@ theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) simpa [hsum] using hnonneg /-- Each row-sum is bounded by the row-sum norm. 
-/ -theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : +theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) (i : Fin m) : rowSum W i ≤ rowSumNorm W := by - classical - have h : (Finset.univ : Finset (Fin m)).Nonempty := ⟨i, by simp⟩ - have hle : - rowSum W i ≤ (Finset.univ).sup' h (fun i => rowSum W i) := by - simpa using - (Finset.le_sup' - (s := (Finset.univ : Finset (Fin m))) - (f := fun i => rowSum W i) - (by simp : i ∈ (Finset.univ : Finset (Fin m)))) - simpa [rowSumNorm, h] using hle + simpa [rowSumNorm] using + (foldlFin_max_ge (f := fun j => rowSum W j) i) /-- Each weighted row-sum is bounded by the weighted row-sum norm. -/ -theorem rowSumWeighted_le_rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) : +theorem rowSumWeighted_le_rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (bound : Fin n → Dyadic) (i : Fin m) : rowSumWeighted W bound i ≤ rowSumWeightedNorm W bound := by - classical - have h : (Finset.univ : Finset (Fin m)).Nonempty := ⟨i, by simp⟩ - have hle : - rowSumWeighted W bound i ≤ - (Finset.univ).sup' h (fun i => rowSumWeighted W bound i) := by - simpa using - (Finset.le_sup' - (s := (Finset.univ : Finset (Fin m))) - (f := fun i => rowSumWeighted W bound i) - (by simp : i ∈ (Finset.univ : Finset (Fin m)))) - simpa [rowSumWeightedNorm, h] using hle + simpa [rowSumWeightedNorm] using + (foldlFin_max_ge (f := fun j => rowSumWeighted W bound j) i) /-- The row-sum norm is nonnegative. 
-/ -theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : +theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) : 0 ≤ rowSumNorm W := by - classical - by_cases h : (Finset.univ : Finset (Fin m)).Nonempty - · rcases h with ⟨i, hi⟩ - have hrow : 0 ≤ rowSum W i := rowSum_nonneg W i - have hle : rowSum W i ≤ rowSumNorm W := rowSum_le_rowSumNorm W i - exact le_trans hrow hle - · simp [rowSumNorm, h] + simpa [rowSumNorm] using + (foldlFin_max_ge_init (f := fun i => rowSum W i) (init := (0 : Dyadic))) -/-- Weighted row-sum norm is nonnegative under nonnegative bounds. -/ -theorem rowSumWeightedNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (hbound : ∀ j, 0 ≤ bound j) : +/-- Weighted row-sum norm is nonnegative. -/ +theorem rowSumWeightedNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (bound : Fin n → Dyadic) : 0 ≤ rowSumWeightedNorm W bound := by - classical - by_cases h : (Finset.univ : Finset (Fin m)).Nonempty - · rcases h with ⟨i, hi⟩ - have hrow : 0 ≤ rowSumWeighted W bound i := - rowSumWeighted_nonneg W bound i hbound - have hle : rowSumWeighted W bound i ≤ rowSumWeightedNorm W bound := - rowSumWeighted_le_rowSumWeightedNorm W bound i - exact le_trans hrow hle - · simp [rowSumWeightedNorm, h] + simpa [rowSumWeightedNorm] using + (foldlFin_max_ge_init (f := fun i => rowSumWeighted W bound i) (init := (0 : Dyadic))) /-- Downstream error from per-coordinate residual bounds. -/ -def downstreamErrorFromBounds {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : Rat := +def downstreamErrorFromBounds {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (bound : Fin n → Dyadic) : Dyadic := rowSumWeightedNorm W bound /-- `downstreamErrorFromBounds` is nonnegative. 
-/ -theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (hbound : ∀ j, 0 ≤ bound j) : +theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (bound : Fin n → Dyadic) : 0 ≤ downstreamErrorFromBounds W bound := by - simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound hbound - -/-- Lower interval endpoint for a dot product with per-coordinate bounds. -/ -def dotIntervalLower {n : Nat} (v lo hi : Fin n → Rat) : Rat := - Linear.sumFin n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) - -/-- Upper interval endpoint for a dot product with per-coordinate bounds. -/ -def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Rat) : Rat := - Linear.sumFin n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - -/-- Absolute bound from interval endpoints for a dot product. -/ -def dotIntervalAbsBound {n : Nat} (v lo hi : Fin n → Rat) : Rat := - max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| - -/-- Lower interval endpoint for a matrix-vector product under input intervals. -/ -def mulVecIntervalLower {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) : Fin m → Rat := - fun i => dotIntervalLower (fun j => W i j) lo hi - -/-- Upper interval endpoint for a matrix-vector product under input intervals. 
-/ -def mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) : Fin m → Rat := - fun i => dotIntervalUpper (fun j => W i j) lo hi - -theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Rat) - (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - dotIntervalLower v lo hi ≤ dotProduct v x := by - classical - simp only [dotIntervalLower, sumFin_eq_sum_univ, dotProduct] - refine Finset.sum_le_sum ?_ - intro j _ - by_cases hv : 0 ≤ v j - · have h1 : v j * lo j ≤ v j * x j := - mul_le_mul_of_nonneg_left (hlo j) hv - simpa [hv] using h1 - · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) - have h1 : v j * hi j ≤ v j * x j := - mul_le_mul_of_nonpos_left (hhi j) hv' - simpa [hv] using h1 - -theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) - (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - dotProduct v x ≤ dotIntervalUpper v lo hi := by - classical - simp only [dotIntervalUpper, sumFin_eq_sum_univ, dotProduct] - refine Finset.sum_le_sum ?_ - intro j _ - by_cases hv : 0 ≤ v j - · have h1 : v j * x j ≤ v j * hi j := - mul_le_mul_of_nonneg_left (hhi j) hv - simpa [hv] using h1 - · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) - have h1 : v j * x j ≤ v j * lo j := - mul_le_mul_of_nonpos_left (hlo j) hv' - simpa [hv] using h1 - -theorem abs_le_max_abs_abs_of_interval {a b x : Rat} (hlo : a ≤ x) (hhi : x ≤ b) : - |x| ≤ max |a| |b| := by - by_cases hx : 0 ≤ x - · have hb : 0 ≤ b := le_trans hx hhi - have hx' : |x| = x := abs_of_nonneg hx - have hb' : |b| = b := abs_of_nonneg hb - calc - |x| = x := hx' - _ ≤ b := hhi - _ = |b| := hb'.symm - _ ≤ max |a| |b| := le_max_right _ _ - · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have ha : a ≤ 0 := le_trans hlo hx' - have hxabs : |x| = -x := abs_of_nonpos hx' - have haabs : |a| = -a := abs_of_nonpos ha - calc - |x| = -x := hxabs - _ ≤ -a := neg_le_neg hlo - _ = |a| := by simp [haabs] - _ ≤ max |a| |b| := le_max_left _ _ - -/-- Global absolute bound 
from interval endpoints. -/ -def intervalAbsBound {n : Nat} (lo hi : Fin n → Rat) : Rat := - if h : (Finset.univ : Finset (Fin n)).Nonempty then - (Finset.univ).sup' h (fun i => max |lo i| |hi i|) - else - 0 - -/-- `intervalAbsBound` bounds any element inside the interval. -/ -theorem abs_le_intervalAbsBound {n : Nat} (lo hi x : Fin n → Rat) - (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) (i : Fin n) : - |x i| ≤ intervalAbsBound lo hi := by - classical - have hbound : |x i| ≤ max |lo i| |hi i| := - abs_le_max_abs_abs_of_interval (hlo i) (hhi i) - have hnonempty : (Finset.univ : Finset (Fin n)).Nonempty := ⟨i, by simp⟩ - have hsup : - max |lo i| |hi i| ≤ - (Finset.univ).sup' hnonempty (fun j => max |lo j| |hi j|) := by - simpa using - (Finset.le_sup' - (s := (Finset.univ : Finset (Fin n))) - (f := fun j => max |lo j| |hi j|) - (by simp : i ∈ (Finset.univ : Finset (Fin n)))) - have hfinal : |x i| ≤ (Finset.univ).sup' hnonempty (fun j => max |lo j| |hi j|) := - le_trans hbound hsup - simpa [intervalAbsBound, hnonempty] using hfinal - -theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → Rat) - (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - |dotProduct v x| ≤ dotIntervalAbsBound v lo hi := by - have hlow : dotIntervalLower v lo hi ≤ dotProduct v x := - dotIntervalLower_le_dotProduct v lo hi x hlo hhi - have hhigh : dotProduct v x ≤ dotIntervalUpper v lo hi := - dotProduct_le_dotIntervalUpper v lo hi x hlo hhi - have habs : |dotProduct v x| ≤ - max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| := - abs_le_max_abs_abs_of_interval hlow hhigh - unfold dotIntervalAbsBound - exact habs - -/-! Real-valued bounds from rational intervals. 
-/ - -theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) - (x : Fin n → Real) - (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : - (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := by - classical - have hcast : - (dotIntervalLower v lo hi : Real) = - ∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real) := by - conv_lhs => simp [dotIntervalLower, sumFin_eq_sum_univ] - refine Finset.sum_congr rfl ?_ - intro j _ - by_cases hv : 0 ≤ v j - · simp [hv] - · simp [hv] - have hsum : - (∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real)) ≤ - ∑ j, (v j : Real) * x j := by - refine Finset.sum_le_sum ?_ - intro j _ - by_cases hv : 0 ≤ v j - · have h1 : (v j : Real) * (lo j : Real) ≤ (v j : Real) * x j := by - exact mul_le_mul_of_nonneg_left (hlo j) (by exact_mod_cast hv) - simpa [hv] using h1 - · have hv' : (v j : Real) ≤ 0 := by - exact_mod_cast (le_of_lt (lt_of_not_ge hv)) - have h1 : (v j : Real) * (hi j : Real) ≤ (v j : Real) * x j := by - exact mul_le_mul_of_nonpos_left (hhi j) hv' - simpa [hv] using h1 - simpa [hcast, dotProduct] using hsum - -theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) - (x : Fin n → Real) - (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : - dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := by - classical - have hcast : - (dotIntervalUpper v lo hi : Real) = - ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by - conv_lhs => simp [dotIntervalUpper, sumFin_eq_sum_univ] - refine Finset.sum_congr rfl ?_ - intro j _ - by_cases hv : 0 ≤ v j - · simp [hv] - · simp [hv] - have hsum : - ∑ j, (v j : Real) * x j ≤ - ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by - refine Finset.sum_le_sum ?_ - intro j _ - by_cases hv : 0 ≤ v j - · have h1 : (v j : Real) * x j ≤ (v j : Real) 
* (hi j : Real) := by - exact mul_le_mul_of_nonneg_left (hhi j) (by exact_mod_cast hv) - simpa [hv] using h1 - · have hv' : (v j : Real) ≤ 0 := by - exact_mod_cast (le_of_lt (lt_of_not_ge hv)) - have h1 : (v j : Real) * x j ≤ (v j : Real) * (lo j : Real) := by - exact mul_le_mul_of_nonpos_left (hlo j) hv' - simpa [hv] using h1 - simpa [hcast, dotProduct] using hsum - -theorem abs_le_max_abs_abs_of_interval_real {a b x : Real} (hlo : a ≤ x) (hhi : x ≤ b) : - |x| ≤ max |a| |b| := by - by_cases hx : 0 ≤ x - · have hb : 0 ≤ b := le_trans hx hhi - have hx' : |x| = x := abs_of_nonneg hx - have hb' : |b| = b := abs_of_nonneg hb - calc - |x| = x := hx' - _ ≤ b := hhi - _ = |b| := hb'.symm - _ ≤ max |a| |b| := le_max_right _ _ - · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have ha : a ≤ 0 := le_trans hlo hx' - have hxabs : |x| = -x := abs_of_nonpos hx' - have haabs : |a| = -a := abs_of_nonpos ha - calc - |x| = -x := hxabs - _ ≤ -a := neg_le_neg hlo - _ = |a| := by simp [haabs] - _ ≤ max |a| |b| := le_max_left _ _ - -/-- `intervalAbsBound` controls real-valued coordinates inside a rational interval. 
-/ -theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Rat) (x : Fin n → Real) - (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) (i : Fin n) : - |x i| ≤ (intervalAbsBound lo hi : Real) := by - classical - have hbound : |x i| ≤ max |(lo i : Real)| |(hi i : Real)| := - abs_le_max_abs_abs_of_interval_real (hlo i) (hhi i) - have hnonempty : (Finset.univ : Finset (Fin n)).Nonempty := ⟨i, by simp⟩ - have hsup : - max |lo i| |hi i| ≤ - (Finset.univ).sup' hnonempty (fun j => max |lo j| |hi j|) := by - simpa using - (Finset.le_sup' - (s := (Finset.univ : Finset (Fin n))) - (f := fun j => max |lo j| |hi j|) - (by simp : i ∈ (Finset.univ : Finset (Fin n)))) - have hsup' : max |lo i| |hi i| ≤ intervalAbsBound lo hi := by - simpa [intervalAbsBound, hnonempty] using hsup - have hsup_real : - max |(lo i : Real)| |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by - exact_mod_cast hsup' - exact le_trans hbound hsup_real - -theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Rat) - (x : Fin n → Real) - (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : - |dotProduct (fun j => (v j : Real)) x| ≤ (dotIntervalAbsBound v lo hi : Real) := by - have hlow : - (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := - dotIntervalLower_le_dotProduct_real v lo hi x hlo hhi - have hhigh : - dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := - dotProduct_le_dotIntervalUpper_real v lo hi x hlo hhi - have habs : - |dotProduct (fun j => (v j : Real)) x| ≤ - max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := - abs_le_max_abs_abs_of_interval_real hlow hhigh - have hcast : - (dotIntervalAbsBound v lo hi : Real) = - max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := by - simp [dotIntervalAbsBound] - simpa [hcast] using habs - -/-- Matrix-interval lower bounds dominate matrix-vector products. 
-/ -theorem mulVecIntervalLower_le_mulVec {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - ∀ i, mulVecIntervalLower W lo hi i ≤ Matrix.mulVec W x i := by - intro i - have h := - dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi x hlo hhi - simpa [mulVecIntervalLower, Matrix.mulVec, dotProduct] using h - -/-- Matrix-interval upper bounds dominate matrix-vector products. -/ -theorem mulVec_le_mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - ∀ i, Matrix.mulVec W x i ≤ mulVecIntervalUpper W lo hi i := by - intro i - have h := - dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi x hlo hhi - simpa [mulVecIntervalUpper, Matrix.mulVec, dotProduct] using h - -/-- Interval endpoints for `mulVec` are ordered when the input interval is ordered. -/ -theorem mulVecIntervalLower_le_upper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : - ∀ i, mulVecIntervalLower W lo hi i ≤ mulVecIntervalUpper W lo hi i := by - intro i - have hlow : - dotIntervalLower (fun j => W i j) lo hi ≤ dotProduct (fun j => W i j) lo := - dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi lo - (fun j => le_rfl) hlohi - have hhigh : - dotProduct (fun j => W i j) lo ≤ dotIntervalUpper (fun j => W i j) lo hi := - dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi lo - (fun j => le_rfl) hlohi - exact le_trans hlow hhigh + simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound /-- Build a residual-interval certificate by applying a matrix to an input interval. 
-/ -def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : +def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (lo hi : Fin n → Dyadic) (hlohi : ∀ j, lo j ≤ hi j) : {c : Circuit.ResidualIntervalCert m // Circuit.ResidualIntervalBounds c} := by let lo' := mulVecIntervalLower W lo hi let hi' := mulVecIntervalUpper W lo hi @@ -476,8 +117,8 @@ def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) exact mulVecIntervalLower_le_upper W lo hi hlohi i /-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ -theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (x : Fin n → Rat) (inputBound : Rat) +theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (x : Fin n → Dyadic) (inputBound : Dyadic) (hx : ∀ j, |x j| ≤ inputBound) (hinput : 0 ≤ inputBound) : ∀ i, |Matrix.mulVec W x i| ≤ rowSumNorm W * inputBound := by intro i @@ -505,7 +146,7 @@ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (s := (Finset.univ : Finset (Fin n))) (f := fun j => |W i j|) (a := inputBound)) - simpa [rowSum, sumFin_eq_sum_univ] using hsum.symm + simpa [rowSum, Linear.sumFin_eq_sum_univ] using hsum.symm have hmul : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) exact hmul @@ -515,8 +156,8 @@ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) exact hrow.trans hmul /-- Build a downstream linear certificate from a matrix and input bound. 
-/ -def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (inputBound : Rat) (hinput : 0 ≤ inputBound) : +def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (inputBound : Dyadic) (hinput : 0 ≤ inputBound) : {c : Circuit.DownstreamLinearCert // Circuit.DownstreamLinearBounds c} := by let gain := rowSumNorm W let error := gain * inputBound diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean new file mode 100644 index 0000000..db10a41 --- /dev/null +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -0,0 +1,436 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Fin +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Ring.Abs +import Mathlib.Data.Matrix.Mul +import Mathlib.Data.Real.Basic +import Nfp.Core.Basic +import Nfp.Sound.Linear.FinFold + +/-! +Interval bounds for dot products and matrix-vector products. + +This module isolates interval-bound helpers used across downstream certificates. 
+-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +lemma foldl_max_ge_init {α : Type _} (f : α → Dyadic) : + ∀ (l : List α) (init : Dyadic), + init ≤ l.foldl (fun acc x => max acc (f x)) init := by + intro l init + induction l generalizing init with + | nil => + simp + | cons a l ih => + have hinit : init ≤ max init (f a) := le_max_left _ _ + have hrest : max init (f a) ≤ l.foldl (fun acc x => max acc (f x)) (max init (f a)) := + ih (max init (f a)) + simpa [List.foldl] using le_trans hinit hrest + +lemma foldl_max_ge_mem {α : Type _} (f : α → Dyadic) : + ∀ (l : List α) (a : α) (init : Dyadic), + a ∈ l → f a ≤ l.foldl (fun acc x => max acc (f x)) init := by + intro l a init hmem + induction l generalizing init with + | nil => + cases hmem + | cons b l ih => + have hmem' : a = b ∨ a ∈ l := by + simpa using hmem + cases hmem' with + | inl h => + subst h + have hstep : f a ≤ max init (f a) := le_max_right _ _ + have hrest : + max init (f a) ≤ l.foldl (fun acc x => max acc (f x)) (max init (f a)) := + foldl_max_ge_init (f := f) l (max init (f a)) + simpa [List.foldl] using le_trans hstep hrest + | inr h => + have h' := ih (init := max init (f b)) h + simpa [List.foldl] using h' + +lemma foldlFin_max_ge_init {n : Nat} (f : Fin n → Dyadic) (init : Dyadic) : + init ≤ Linear.foldlFin n (fun acc j => max acc (f j)) init := by + classical + have hlist : + init ≤ (List.finRange n).foldl (fun acc j => max acc (f j)) init := + foldl_max_ge_init (f := f) (List.finRange n) init + have hfold : + Linear.foldlFin n (fun acc j => max acc (f j)) init = + (List.finRange n).foldl (fun acc j => max acc (f j)) init := by + simpa [Linear.foldlFin_eq_foldl] using + (Fin.foldl_eq_foldl_finRange + (f := fun acc j => max acc (f j)) (x := init) (n := n)) + simpa [hfold] using hlist + +lemma foldlFin_max_ge {n : Nat} (f : Fin n → Dyadic) (i : Fin n) : + f i ≤ Linear.foldlFin n (fun acc j => max acc (f j)) 0 := by + classical + have hmem : i ∈ List.finRange 
n := by + simp + have hlist : + f i ≤ (List.finRange n).foldl (fun acc j => max acc (f j)) 0 := + foldl_max_ge_mem (f := f) (List.finRange n) i 0 hmem + have hfold : + Linear.foldlFin n (fun acc j => max acc (f j)) 0 = + (List.finRange n).foldl (fun acc j => max acc (f j)) 0 := by + simpa [Linear.foldlFin_eq_foldl] using + (Fin.foldl_eq_foldl_finRange + (f := fun acc j => max acc (f j)) (x := (0 : Dyadic)) (n := n)) + simpa [hfold] using hlist + +/-- Lower interval endpoint for a dot product with per-coordinate bounds. -/ +def dotIntervalLower {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + Linear.sumFin n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) + +/-- Upper interval endpoint for a dot product with per-coordinate bounds. -/ +def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + Linear.sumFin n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) + +/-- Lower interval endpoint using a shared-denominator accumulator. -/ +def dotIntervalLowerCommonDen {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + Linear.sumFinCommonDen n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) + +/-- Upper interval endpoint using a shared-denominator accumulator. -/ +def dotIntervalUpperCommonDen {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + Linear.sumFinCommonDen n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) + +/-- Lower interval endpoint using unnormalized accumulation. -/ +def dotIntervalLowerUnnorm {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + dotIntervalLower v lo hi + +/-- Upper interval endpoint using unnormalized accumulation. 
-/ +def dotIntervalUpperUnnorm {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + dotIntervalUpper v lo hi + +theorem dotIntervalLowerCommonDen_eq {n : Nat} (v lo hi : Fin n → Dyadic) : + dotIntervalLowerCommonDen v lo hi = dotIntervalLower v lo hi := by + simp [dotIntervalLowerCommonDen, dotIntervalLower, Linear.sumFinCommonDen_eq_sumFin] + +theorem dotIntervalUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Dyadic) : + dotIntervalUpperCommonDen v lo hi = dotIntervalUpper v lo hi := by + simp [dotIntervalUpperCommonDen, dotIntervalUpper, Linear.sumFinCommonDen_eq_sumFin] + +theorem dotIntervalLowerUnnorm_eq {n : Nat} (v lo hi : Fin n → Dyadic) : + dotIntervalLowerUnnorm v lo hi = dotIntervalLower v lo hi := rfl + +theorem dotIntervalUpperUnnorm_eq {n : Nat} (v lo hi : Fin n → Dyadic) : + dotIntervalUpperUnnorm v lo hi = dotIntervalUpper v lo hi := rfl + +/-! Cached endpoints. -/ + +/-- Cached-array lower interval endpoint for a dot product using normalized dyadic sums. -/ +def dotIntervalLowerCachedDyadic {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + let vArr := Array.ofFn v + let loArr := Array.ofFn lo + let hiArr := Array.ofFn hi + Linear.sumFin n (fun j => + let vj := vArr[j.1]'(by + have hsize : vArr.size = n := by simp [vArr] + simp [hsize, j.isLt]) + let loj := loArr[j.1]'(by + have hsize : loArr.size = n := by simp [loArr] + simp [hsize, j.isLt]) + let hij := hiArr[j.1]'(by + have hsize : hiArr.size = n := by simp [hiArr] + simp [hsize, j.isLt]) + if 0 ≤ vj then vj * loj else vj * hij) + +/-- Cached-array upper interval endpoint for a dot product using normalized dyadic sums. 
-/ +def dotIntervalUpperCachedDyadic {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + let vArr := Array.ofFn v + let loArr := Array.ofFn lo + let hiArr := Array.ofFn hi + Linear.sumFin n (fun j => + let vj := vArr[j.1]'(by + have hsize : vArr.size = n := by simp [vArr] + simp [hsize, j.isLt]) + let loj := loArr[j.1]'(by + have hsize : loArr.size = n := by simp [loArr] + simp [hsize, j.isLt]) + let hij := hiArr[j.1]'(by + have hsize : hiArr.size = n := by simp [hiArr] + simp [hsize, j.isLt]) + if 0 ≤ vj then vj * hij else vj * loj) + +theorem dotIntervalLowerCachedRat_eq {n : Nat} (v lo hi : Fin n → Dyadic) : + dotIntervalLowerCachedDyadic v lo hi = dotIntervalLower v lo hi := by + classical + simp [dotIntervalLowerCachedDyadic, dotIntervalLower, Linear.sumFin_eq_list_foldl, + Array.getElem_ofFn] + +theorem dotIntervalUpperCachedRat_eq {n : Nat} (v lo hi : Fin n → Dyadic) : + dotIntervalUpperCachedDyadic v lo hi = dotIntervalUpper v lo hi := by + classical + simp [dotIntervalUpperCachedDyadic, dotIntervalUpper, Linear.sumFin_eq_list_foldl, + Array.getElem_ofFn] + +/-! Absolute bounds. -/ + +/-- Absolute bound from interval endpoints for a dot product. -/ +def dotIntervalAbsBound {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := + max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| + +/-- Lower interval endpoint for a matrix-vector product under input intervals. -/ +def mulVecIntervalLower {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (lo hi : Fin n → Dyadic) : Fin m → Dyadic := + fun i => dotIntervalLower (fun j => W i j) lo hi + +/-- Upper interval endpoint for a matrix-vector product under input intervals. 
-/ +def mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (lo hi : Fin n → Dyadic) : Fin m → Dyadic := + fun i => dotIntervalUpper (fun j => W i j) lo hi + +theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Dyadic) + (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + dotIntervalLower v lo hi ≤ dotProduct v x := by + classical + simp only [dotIntervalLower, Linear.sumFin_eq_sum_univ, dotProduct] + refine Finset.sum_le_sum ?_ + intro j _ + by_cases hv : 0 ≤ v j + · have h1 : v j * lo j ≤ v j * x j := + mul_le_mul_of_nonneg_left (hlo j) hv + simpa [hv] using h1 + · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) + have h1 : v j * hi j ≤ v j * x j := + mul_le_mul_of_nonpos_left (hhi j) hv' + simpa [hv] using h1 + +theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Dyadic) + (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + dotProduct v x ≤ dotIntervalUpper v lo hi := by + classical + simp only [dotIntervalUpper, Linear.sumFin_eq_sum_univ, dotProduct] + refine Finset.sum_le_sum ?_ + intro j _ + by_cases hv : 0 ≤ v j + · have h1 : v j * x j ≤ v j * hi j := + mul_le_mul_of_nonneg_left (hhi j) hv + simpa [hv] using h1 + · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) + have h1 : v j * x j ≤ v j * lo j := + mul_le_mul_of_nonpos_left (hlo j) hv' + simpa [hv] using h1 + +theorem abs_le_max_abs_abs_of_interval {a b x : Dyadic} (hlo : a ≤ x) (hhi : x ≤ b) : + |x| ≤ max |a| |b| := by + by_cases hx : 0 ≤ x + · have hb : 0 ≤ b := le_trans hx hhi + have hx' : |x| = x := abs_of_nonneg hx + have hb' : |b| = b := abs_of_nonneg hb + calc + |x| = x := hx' + _ ≤ b := hhi + _ = |b| := hb'.symm + _ ≤ max |a| |b| := le_max_right _ _ + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have ha : a ≤ 0 := le_trans hlo hx' + have hxabs : |x| = -x := abs_of_nonpos hx' + have haabs : |a| = -a := abs_of_nonpos ha + calc + |x| = -x := hxabs + _ ≤ -a := neg_le_neg hlo + _ = |a| := by simp [haabs] + _ ≤ max |a| |b| := le_max_left _ _ 
+ +/-- Global absolute bound from interval endpoints. -/ +def intervalAbsBound {n : Nat} (lo hi : Fin n → Dyadic) : Dyadic := + Linear.foldlFin n (fun acc i => max acc (max |lo i| |hi i|)) 0 + +/-- `intervalAbsBound` dominates each endpoint absolute value. -/ +theorem max_abs_le_intervalAbsBound {n : Nat} (lo hi : Fin n → Dyadic) (i : Fin n) : + max |lo i| |hi i| ≤ intervalAbsBound lo hi := by + simpa [intervalAbsBound] using + (foldlFin_max_ge (f := fun j => max |lo j| |hi j|) i) + +/-- `intervalAbsBound` bounds any element inside the interval. -/ +theorem abs_le_intervalAbsBound {n : Nat} (lo hi x : Fin n → Dyadic) + (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) (i : Fin n) : + |x i| ≤ intervalAbsBound lo hi := by + have hbound : |x i| ≤ max |lo i| |hi i| := + abs_le_max_abs_abs_of_interval (hlo i) (hhi i) + have hsup : max |lo i| |hi i| ≤ intervalAbsBound lo hi := + max_abs_le_intervalAbsBound lo hi i + exact le_trans hbound hsup + +theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → Dyadic) + (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + |dotProduct v x| ≤ dotIntervalAbsBound v lo hi := by + have hlow : dotIntervalLower v lo hi ≤ dotProduct v x := + dotIntervalLower_le_dotProduct v lo hi x hlo hhi + have hhigh : dotProduct v x ≤ dotIntervalUpper v lo hi := + dotProduct_le_dotIntervalUpper v lo hi x hlo hhi + have habs : |dotProduct v x| ≤ + max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| := + abs_le_max_abs_abs_of_interval hlow hhigh + unfold dotIntervalAbsBound + exact habs + +/-! Real-valued bounds from rational intervals. 
-/ + +theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Dyadic) + (x : Fin n → Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := by + classical + have hcast : + (dotIntervalLower v lo hi : Real) = + ∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real) := by + simpa [dotIntervalLower, dyadicToReal_mul, dyadicToReal_if] using + (Linear.dyadicToReal_sumFin + (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j)) + have hsum : + (∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real)) ≤ + ∑ j, (v j : Real) * x j := by + refine Finset.sum_le_sum ?_ + intro j _ + by_cases hv : 0 ≤ v j + · have h1 : (v j : Real) * (lo j : Real) ≤ (v j : Real) * x j := by + have hv' : (0 : Real) ≤ (v j : Real) := dyadicToReal_nonneg_of_nonneg hv + exact mul_le_mul_of_nonneg_left (hlo j) hv' + simpa [hv] using h1 + · have hv' : (v j : Real) ≤ 0 := by + exact (dyadicToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) + have h1 : (v j : Real) * (hi j : Real) ≤ (v j : Real) * x j := by + exact mul_le_mul_of_nonpos_left (hhi j) hv' + simpa [hv] using h1 + simpa [hcast, dotProduct] using hsum + +theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Dyadic) + (x : Fin n → Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := by + classical + have hcast : + (dotIntervalUpper v lo hi : Real) = + ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by + simpa [dotIntervalUpper, dyadicToReal_mul, dyadicToReal_if] using + (Linear.dyadicToReal_sumFin + (f := fun j => if 0 ≤ v j then v j * hi j else v j * lo j)) + have hsum : + ∑ j, (v j : Real) * x j ≤ + ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by + refine 
Finset.sum_le_sum ?_ + intro j _ + by_cases hv : 0 ≤ v j + · have h1 : (v j : Real) * x j ≤ (v j : Real) * (hi j : Real) := by + have hv' : (0 : Real) ≤ (v j : Real) := dyadicToReal_nonneg_of_nonneg hv + exact mul_le_mul_of_nonneg_left (hhi j) hv' + simpa [hv] using h1 + · have hv' : (v j : Real) ≤ 0 := by + exact (dyadicToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) + have h1 : (v j : Real) * x j ≤ (v j : Real) * (lo j : Real) := by + exact mul_le_mul_of_nonpos_left (hlo j) hv' + simpa [hv] using h1 + simpa [hcast, dotProduct] using hsum + +theorem abs_le_max_abs_abs_of_interval_real {a b x : Real} (hlo : a ≤ x) (hhi : x ≤ b) : + |x| ≤ max |a| |b| := by + by_cases hx : 0 ≤ x + · have hb : 0 ≤ b := le_trans hx hhi + have hx' : |x| = x := abs_of_nonneg hx + have hb' : |b| = b := abs_of_nonneg hb + calc + |x| = x := hx' + _ ≤ b := hhi + _ = |b| := hb'.symm + _ ≤ max |a| |b| := le_max_right _ _ + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have ha : a ≤ 0 := le_trans hlo hx' + have hxabs : |x| = -x := abs_of_nonpos hx' + have haabs : |a| = -a := abs_of_nonpos ha + calc + |x| = -x := hxabs + _ ≤ -a := neg_le_neg hlo + _ = |a| := by simp [haabs] + _ ≤ max |a| |b| := le_max_left _ _ + +/-- `intervalAbsBound` controls real-valued coordinates inside a rational interval. 
-/ +theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Dyadic) (x : Fin n → Real) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) (i : Fin n) : + |x i| ≤ (intervalAbsBound lo hi : Real) := by + have hbound : |x i| ≤ max |(lo i : Real)| |(hi i : Real)| := + abs_le_max_abs_abs_of_interval_real (hlo i) (hhi i) + have hsup_real : + max |(lo i : Real)| |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by + have hsup : max |lo i| |hi i| ≤ intervalAbsBound lo hi := + max_abs_le_intervalAbsBound lo hi i + have hlo : |lo i| ≤ intervalAbsBound lo hi := + le_trans (le_max_left _ _) hsup + have hhi : |hi i| ≤ intervalAbsBound lo hi := + le_trans (le_max_right _ _) hsup + have hlo_real : + |(lo i : Real)| ≤ (intervalAbsBound lo hi : Real) := by + exact dyadicToReal_abs_le_of_le hlo + have hhi_real : + |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by + exact dyadicToReal_abs_le_of_le hhi + exact max_le_iff.mpr ⟨hlo_real, hhi_real⟩ + exact le_trans hbound hsup_real + +theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Dyadic) + (x : Fin n → Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + |dotProduct (fun j => (v j : Real)) x| ≤ (dotIntervalAbsBound v lo hi : Real) := by + have hlow : + (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := + dotIntervalLower_le_dotProduct_real v lo hi x hlo hhi + have hhigh : + dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := + dotProduct_le_dotIntervalUpper_real v lo hi x hlo hhi + have habs : + |dotProduct (fun j => (v j : Real)) x| ≤ + max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := + abs_le_max_abs_abs_of_interval_real hlow hhigh + have hcast : + (dotIntervalAbsBound v lo hi : Real) = + max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := by + simp [dotIntervalAbsBound, dyadicToReal_abs, dyadicToReal_max] + simpa [hcast] using habs + +/-! 
Matrix-vector interval bounds. -/ + +theorem mulVecIntervalLower_le_mulVec {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (lo hi x : Fin n → Dyadic) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + ∀ i, mulVecIntervalLower W lo hi i ≤ Matrix.mulVec W x i := by + intro i + have h := + dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi x hlo hhi + simpa [mulVecIntervalLower, Matrix.mulVec, dotProduct] using h + +theorem mulVec_le_mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (lo hi x : Fin n → Dyadic) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : + ∀ i, Matrix.mulVec W x i ≤ mulVecIntervalUpper W lo hi i := by + intro i + have h := + dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi x hlo hhi + simpa [mulVecIntervalUpper, Matrix.mulVec, dotProduct] using h + +theorem mulVecIntervalLower_le_upper {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) + (lo hi : Fin n → Dyadic) (hlohi : ∀ j, lo j ≤ hi j) : + ∀ i, mulVecIntervalLower W lo hi i ≤ mulVecIntervalUpper W lo hi i := by + intro i + have hlow : + dotIntervalLower (fun j => W i j) lo hi ≤ dotProduct (fun j => W i j) lo := + dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi lo + (fun j => le_rfl) hlohi + have hhigh : + dotProduct (fun j => W i j) lo ≤ dotIntervalUpper (fun j => W i j) lo hi := + dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi lo + (fun j => le_rfl) hlohi + exact le_trans hlow hhigh + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/Mlp.lean b/Nfp/Sound/Bounds/Mlp.lean index e699fb1..079c7e2 100644 --- a/Nfp/Sound/Bounds/Mlp.lean +++ b/Nfp/Sound/Bounds/Mlp.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.Ring.Rat +import Nfp.Core.Basic import Nfp.Sound.Bounds.Gelu import Nfp.Sound.Bounds.LayerNorm import Nfp.Sound.Bounds.MatrixNorm @@ -20,8 +20,8 @@ open scoped BigOperators /-- Real-valued MLP with tanh-based GELU 
activations. -/ noncomputable def mlpReal {dModel hidden : Nat} - (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (wIn : Fin dModel → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin dModel → Dyadic) (bOut : Fin dModel → Dyadic) (x : Fin dModel → Real) : Fin dModel → Real := fun i => let hidden : Fin hidden → Real := fun h => @@ -30,77 +30,78 @@ noncomputable def mlpReal {dModel hidden : Nat} /-- Interval bounds for a tanh-GELU MLP given input intervals. -/ def mlpBounds {dModel hidden : Nat} - (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) - (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let preLo : Fin hidden → Rat := fun h => + (wIn : Fin dModel → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin dModel → Dyadic) (bOut : Fin dModel → Dyadic) + (lo hi : Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + let preLo : Fin hidden → Dyadic := fun h => dotIntervalLower (fun j => wIn j h) lo hi + bIn h - let preHi : Fin hidden → Rat := fun h => + let preHi : Fin hidden → Dyadic := fun h => dotIntervalUpper (fun j => wIn j h) lo hi + bIn h - let geluLo : Fin hidden → Rat := fun h => min (preLo h) 0 - let geluHi : Fin hidden → Rat := fun h => max (preHi h) 0 - let outLo : Fin dModel → Rat := fun i => + let geluBounds : Fin hidden → Dyadic × Dyadic := fun h => geluInterval (preLo h) (preHi h) + let geluLo : Fin hidden → Dyadic := fun h => (geluBounds h).1 + let geluHi : Fin hidden → Dyadic := fun h => (geluBounds h).2 + let outLo : Fin dModel → Dyadic := fun i => dotIntervalLower (fun h => wOut h i) geluLo geluHi + bOut i - let outHi : Fin dModel → Rat := fun i => + let outHi : Fin dModel → Dyadic := fun i => dotIntervalUpper (fun h => wOut h i) geluLo geluHi + bOut i (outLo, outHi) /-- `mlpBounds` soundness 
for real MLP outputs. -/ theorem mlpBounds_spec {dModel hidden : Nat} - (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) - (lo hi : Fin dModel → Rat) (x : Fin dModel → Real) + (wIn : Fin dModel → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin dModel → Dyadic) (bOut : Fin dModel → Dyadic) + (lo hi : Fin dModel → Dyadic) (x : Fin dModel → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : let bounds := mlpBounds wIn bIn wOut bOut lo hi ∀ i, (bounds.1 i : Real) ≤ mlpReal wIn bIn wOut bOut x i ∧ mlpReal wIn bIn wOut bOut x i ≤ (bounds.2 i : Real) := by classical intro bounds i - let preLo : Fin hidden → Rat := fun h => + let preLo : Fin hidden → Dyadic := fun h => dotIntervalLower (fun j => wIn j h) lo hi + bIn h - let preHi : Fin hidden → Rat := fun h => + let preHi : Fin hidden → Dyadic := fun h => dotIntervalUpper (fun j => wIn j h) lo hi + bIn h let pre : Fin hidden → Real := fun h => dotProduct (fun j => (wIn j h : Real)) x + (bIn h : Real) have hpre_lower : ∀ h, (preLo h : Real) ≤ pre h := by intro h - have hdot := - dotIntervalLower_le_dotProduct_real (v := fun j => wIn j h) lo hi x hlo hhi - have hdot' := add_le_add_right hdot (bIn h : Real) - simpa [pre, preLo, Rat.cast_add] using hdot' + simpa [pre, preLo] using + add_le_add_right + (dotIntervalLower_le_dotProduct_real (v := fun j => wIn j h) lo hi x hlo hhi) + (bIn h : Real) have hpre_upper : ∀ h, pre h ≤ (preHi h : Real) := by intro h - have hdot := - dotProduct_le_dotIntervalUpper_real (v := fun j => wIn j h) lo hi x hlo hhi - have hdot' := add_le_add_right hdot (bIn h : Real) - simpa [pre, preHi, Rat.cast_add] using hdot' - let geluLo : Fin hidden → Rat := fun h => min (preLo h) 0 - let geluHi : Fin hidden → Rat := fun h => max (preHi h) 0 + simpa [pre, preHi] using + add_le_add_right + (dotProduct_le_dotIntervalUpper_real (v := fun j => wIn j h) lo hi x hlo hhi) + (bIn h : 
Real) + let geluBounds : Fin hidden → Dyadic × Dyadic := fun h => geluInterval (preLo h) (preHi h) + let geluLo : Fin hidden → Dyadic := fun h => (geluBounds h).1 + let geluHi : Fin hidden → Dyadic := fun h => (geluBounds h).2 let hidden : Fin hidden → Real := fun h => geluTanh (pre h) have hgelu : ∀ h, (geluLo h : Real) ≤ hidden h ∧ hidden h ≤ (geluHi h : Real) := by intro h have hbounds := geluInterval_bounds (lo := preLo h) (hi := preHi h) (hpre_lower h) (hpre_upper h) - dsimp [geluLo, geluHi, hidden, geluInterval] - exact hbounds - let outLo : Fin dModel → Rat := fun i => + simpa [geluLo, geluHi, geluBounds, hidden] using hbounds + let outLo : Fin dModel → Dyadic := fun i => dotIntervalLower (fun h => wOut h i) geluLo geluHi + bOut i - let outHi : Fin dModel → Rat := fun i => + let outHi : Fin dModel → Dyadic := fun i => dotIntervalUpper (fun h => wOut h i) geluLo geluHi + bOut i have hout_lower : (outLo i : Real) ≤ dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) := by - have hdot := - dotIntervalLower_le_dotProduct_real (v := fun h => wOut h i) geluLo geluHi hidden - (fun h => (hgelu h).1) (fun h => (hgelu h).2) - have hdot' := add_le_add_right hdot (bOut i : Real) - simpa [outLo, Rat.cast_add] using hdot' + simpa [outLo] using + add_le_add_right + (dotIntervalLower_le_dotProduct_real (v := fun h => wOut h i) geluLo geluHi hidden + (fun h => (hgelu h).1) (fun h => (hgelu h).2)) + (bOut i : Real) have hout_upper : dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) ≤ (outHi i : Real) := by - have hdot := - dotProduct_le_dotIntervalUpper_real (v := fun h => wOut h i) geluLo geluHi hidden - (fun h => (hgelu h).1) (fun h => (hgelu h).2) - have hdot' := add_le_add_right hdot (bOut i : Real) - simpa [outHi, Rat.cast_add] using hdot' + simpa [outHi] using + add_le_add_right + (dotProduct_le_dotIntervalUpper_real (v := fun h => wOut h i) geluLo geluHi hidden + (fun h => (hgelu h).1) (fun h => (hgelu h).2)) + (bOut i : Real) have hlo' : 
(outLo i : Real) ≤ mlpReal wIn bIn wOut bOut x i := by simpa [mlpReal, hidden, pre] using hout_lower have hhi' : mlpReal wIn bIn wOut bOut x i ≤ (outHi i : Real) := by @@ -110,19 +111,19 @@ theorem mlpBounds_spec {dModel hidden : Nat} /-- Interval bounds for a LayerNorm + MLP sublayer from exact inputs. -/ def layerNormMlpBounds {n hidden : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) - (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) - (x : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + (eps : Dyadic) (gamma beta : Fin n → Dyadic) + (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) + (x : Fin n → Dyadic) : (Fin n → Dyadic) × (Fin n → Dyadic) := let ln := layerNormBounds eps gamma beta x mlpBounds wIn bIn wOut bOut ln.1 ln.2 /-- `layerNormMlpBounds` soundness for real LayerNorm + MLP outputs. -/ theorem layerNormMlpBounds_spec {n hidden : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) - (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) - (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) : + (eps : Dyadic) (gamma beta : Fin n → Dyadic) + (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) + (x : Fin n → Dyadic) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x ∀ i, (bounds.1 i : Real) ≤ mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ∧ @@ -130,7 +131,7 @@ theorem layerNormMlpBounds_spec {n hidden : Nat} classical intro bounds i let ln := layerNormBounds eps gamma beta x - have hln := layerNormBounds_spec eps gamma beta x hne heps + have hln := layerNormBounds_spec eps gamma beta x hne heps hsqrt have hlo : ∀ j, (ln.1 j : Real) ≤ layerNormReal eps gamma beta x j := fun j => (hln j).1 have hhi : ∀ 
j, layerNormReal eps gamma beta x j ≤ (ln.2 j : Real) := fun j => (hln j).2 have hmlp := mlpBounds_spec wIn bIn wOut bOut ln.1 ln.2 @@ -139,21 +140,21 @@ theorem layerNormMlpBounds_spec {n hidden : Nat} /-- Interval bounds for LayerNorm + MLP sublayer from interval inputs. -/ def layerNormAbsMlpBounds {n hidden : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) - (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) - (lo hi : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + (eps : Dyadic) (gamma beta : Fin n → Dyadic) + (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) + (lo hi : Fin n → Dyadic) : (Fin n → Dyadic) × (Fin n → Dyadic) := let absBound := intervalAbsBound lo hi let ln := layerNormAbsBounds eps gamma beta absBound mlpBounds wIn bIn wOut bOut ln.1 ln.2 /-- `layerNormAbsMlpBounds` soundness for real LayerNorm + MLP outputs. -/ theorem layerNormAbsMlpBounds_spec {n hidden : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) - (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) - (lo hi : Fin n → Rat) (x : Fin n → Real) - (hne : n ≠ 0) (heps : 0 < eps) + (eps : Dyadic) (gamma beta : Fin n → Dyadic) + (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) + (lo hi : Fin n → Dyadic) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : let bounds := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi ∀ i, (bounds.1 i : Real) ≤ @@ -168,24 +169,17 @@ theorem layerNormAbsMlpBounds_spec {n hidden : Nat} have hbound : |x j| ≤ max |(lo j : Real)| |(hi j : Real)| := abs_le_max_abs_abs_of_interval_real (hlo j) (hhi j) - have hnonempty : (Finset.univ : Finset (Fin n)).Nonempty := ⟨j, by simp⟩ - have hsup 
: - max |lo j| |hi j| ≤ intervalAbsBound lo hi := by - have hsup' : - max |lo j| |hi j| ≤ - (Finset.univ).sup' hnonempty (fun k => max |lo k| |hi k|) := by - simpa using - (Finset.le_sup' - (s := (Finset.univ : Finset (Fin n))) - (f := fun k => max |lo k| |hi k|) - (by simp : j ∈ (Finset.univ : Finset (Fin n)))) - simpa [intervalAbsBound, hnonempty] using hsup' + have hsup : max |lo j| |hi j| ≤ intervalAbsBound lo hi := + max_abs_le_intervalAbsBound lo hi j have hsup_real : max |(lo j : Real)| |(hi j : Real)| ≤ (absBound : Real) := by - exact_mod_cast hsup + have hsup' : + dyadicToReal (max |lo j| |hi j|) ≤ dyadicToReal absBound := + dyadicToReal_le_of_le hsup + simpa [dyadicToReal_abs, dyadicToReal_max] using hsup' exact le_trans hbound hsup_real have hln := - layerNormAbsBounds_spec_real eps gamma beta absBound x hne heps habs + layerNormAbsBounds_spec_real eps gamma beta absBound x hne heps hsqrt habs have hlo_ln : ∀ j, (ln.1 j : Real) ≤ layerNormRealOfReal eps gamma beta x j := fun j => (hln j).1 have hhi_ln : ∀ j, layerNormRealOfReal eps gamma beta x j ≤ (ln.2 j : Real) := fun j => (hln j).2 have hmlp := mlpBounds_spec wIn bIn wOut bOut ln.1 ln.2 @@ -193,13 +187,13 @@ theorem layerNormAbsMlpBounds_spec {n hidden : Nat} simpa [bounds, layerNormAbsMlpBounds, absBound, ln] using hmlp i /-- Add residual inputs to interval bounds. -/ -def residualAddBounds {n : Nat} (x : Fin n → Rat) (lo hi : Fin n → Rat) : - (Fin n → Rat) × (Fin n → Rat) := +def residualAddBounds {n : Nat} (x : Fin n → Dyadic) (lo hi : Fin n → Dyadic) : + (Fin n → Dyadic) × (Fin n → Dyadic) := (fun i => x i + lo i, fun i => x i + hi i) /-- `residualAddBounds` soundness for residual addition. 
-/ -theorem residualAddBounds_spec {n : Nat} (x : Fin n → Rat) - (lo hi : Fin n → Rat) (y : Fin n → Real) +theorem residualAddBounds_spec {n : Nat} (x : Fin n → Dyadic) + (lo hi : Fin n → Dyadic) (y : Fin n → Real) (hlo : ∀ i, (lo i : Real) ≤ y i) (hhi : ∀ i, y i ≤ (hi i : Real)) : let bounds := residualAddBounds x lo hi ∀ i, (bounds.1 i : Real) ≤ (x i : Real) + y i ∧ @@ -208,24 +202,24 @@ theorem residualAddBounds_spec {n : Nat} (x : Fin n → Rat) have hlow := add_le_add_left (hlo i) (x i : Real) have hhigh := add_le_add_left (hhi i) (x i : Real) constructor - · simpa [bounds, residualAddBounds, Rat.cast_add] using hlow - · simpa [bounds, residualAddBounds, Rat.cast_add] using hhigh + · simpa [bounds, residualAddBounds] using hlow + · simpa [bounds, residualAddBounds] using hhigh /-- Interval bounds for a full MLP residual path (LayerNorm + MLP + residual add). -/ def layerNormMlpResidualBounds {n hidden : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) - (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) - (x : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + (eps : Dyadic) (gamma beta : Fin n → Dyadic) + (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) + (x : Fin n → Dyadic) : (Fin n → Dyadic) × (Fin n → Dyadic) := let mlp := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x residualAddBounds x mlp.1 mlp.2 /-- `layerNormMlpResidualBounds` soundness for the MLP residual path. 
-/ theorem layerNormMlpResidualBounds_spec {n hidden : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) - (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) - (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) : + (eps : Dyadic) (gamma beta : Fin n → Dyadic) + (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) + (x : Fin n → Dyadic) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := layerNormMlpResidualBounds eps gamma beta wIn bIn wOut bOut x ∀ i, (bounds.1 i : Real) ≤ @@ -237,7 +231,7 @@ theorem layerNormMlpResidualBounds_spec {n hidden : Nat} classical intro bounds i let mlp := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x - have hmlp := layerNormMlpBounds_spec eps gamma beta wIn bIn wOut bOut x hne heps + have hmlp := layerNormMlpBounds_spec eps gamma beta wIn bIn wOut bOut x hne heps hsqrt have hres := residualAddBounds_spec x mlp.1 mlp.2 (mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x)) (fun j => (hmlp j).1) (fun j => (hmlp j).2) @@ -245,20 +239,20 @@ theorem layerNormMlpResidualBounds_spec {n hidden : Nat} /-- Interval bounds for a full MLP residual path (LayerNorm + MLP + residual add) from intervals. 
-/ def layerNormAbsMlpResidualBounds {n hidden : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) - (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) - (lo hi : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + (eps : Dyadic) (gamma beta : Fin n → Dyadic) + (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) + (lo hi : Fin n → Dyadic) : (Fin n → Dyadic) × (Fin n → Dyadic) := let mlp := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi (fun i => lo i + mlp.1 i, fun i => hi i + mlp.2 i) /-- `layerNormAbsMlpResidualBounds` soundness for the MLP residual path. -/ theorem layerNormAbsMlpResidualBounds_spec {n hidden : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) - (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) - (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) - (lo hi : Fin n → Rat) (x : Fin n → Real) - (hne : n ≠ 0) (heps : 0 < eps) + (eps : Dyadic) (gamma beta : Fin n → Dyadic) + (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) + (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) + (lo hi : Fin n → Dyadic) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : let bounds := layerNormAbsMlpResidualBounds eps gamma beta wIn bIn wOut bOut lo hi ∀ i, @@ -269,14 +263,15 @@ theorem layerNormAbsMlpResidualBounds_spec {n hidden : Nat} classical intro bounds i let mlp := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi - have hmlp := layerNormAbsMlpBounds_spec eps gamma beta wIn bIn wOut bOut lo hi x hne heps hlo hhi + have hmlp := + layerNormAbsMlpBounds_spec eps gamma beta wIn bIn wOut bOut lo hi x hne heps hsqrt hlo hhi have hlo' := (hmlp i).1 have hhi' := (hmlp i).2 have hlow := add_le_add (hlo i) hlo' have hhigh := add_le_add (hhi i) hhi' constructor - · simpa [bounds, 
layerNormAbsMlpResidualBounds, mlp, Rat.cast_add] using hlow - · simpa [bounds, layerNormAbsMlpResidualBounds, mlp, Rat.cast_add] using hhigh + · simpa [bounds, layerNormAbsMlpResidualBounds, mlp] using hlow + · simpa [bounds, layerNormAbsMlpResidualBounds, mlp] using hhigh end Bounds diff --git a/Nfp/Sound/Bounds/Transformer.lean b/Nfp/Sound/Bounds/Transformer.lean index 6a57aeb..c4c3687 100644 --- a/Nfp/Sound/Bounds/Transformer.lean +++ b/Nfp/Sound/Bounds/Transformer.lean @@ -1,12 +1,12 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.Ring.Rat import Mathlib.Data.List.Range import Mathlib.Data.Real.Basic import Nfp.Model.Gpt2 import Nfp.Sound.Bounds.Attention import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.Transformer.Embedding import Nfp.Sound.Linear.FinFold /-! @@ -21,45 +21,10 @@ namespace Bounds open scoped BigOperators -private lemma fin_univ_nonempty (seq : Nat) [NeZero seq] : - (Finset.univ : Finset (Fin seq)).Nonempty := by - classical - refine ⟨⟨0, ?_⟩, by simp⟩ - exact Nat.pos_of_ne_zero (NeZero.ne (n := seq)) - -/-- Interval bounds across tokens for an embedding map. -/ -def embeddingIntervalBounds {seq dModel : Nat} [NeZero seq] - (x : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let h : (Finset.univ : Finset (Fin seq)).Nonempty := fin_univ_nonempty (seq := seq) - (fun i => (Finset.univ).inf' h (fun q => x q i), - fun i => (Finset.univ).sup' h (fun q => x q i)) - -/-- `embeddingIntervalBounds` bounds embeddings coordinatewise. 
-/ -theorem embeddingIntervalBounds_spec {seq dModel : Nat} [NeZero seq] - (x : Fin seq → Fin dModel → Rat) : - let bounds := embeddingIntervalBounds x - ∀ q i, - (bounds.1 i : Real) ≤ (x q i : Real) ∧ - (x q i : Real) ≤ (bounds.2 i : Real) := by - classical - intro bounds q i - have hloRat : bounds.1 i ≤ x q i := by - have h := - Finset.inf'_le (s := (Finset.univ : Finset (Fin seq))) - (f := fun k => x k i) (b := q) (by simp) - simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h - have hhiRat : x q i ≤ bounds.2 i := by - have h := - Finset.le_sup' (s := (Finset.univ : Finset (Fin seq))) - (f := fun k => x k i) (b := q) (by simp) - simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h - constructor - · exact_mod_cast hloRat - · exact_mod_cast hhiRat /-- Real-valued output of a transformer layer. -/ noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) + (eps : Dyadic) (layer : Model.Gpt2LayerSlice dModel hidden) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numHeads → Fin seq → Fin seq → Real) (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := @@ -73,11 +38,11 @@ noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} /-- `transformerLayerBounds` soundness for `transformerLayerReal`. 
-/ theorem transformerLayerBounds_spec_real {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) + (eps : Dyadic) (layer : Model.Gpt2LayerSlice dModel hidden) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) + (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := transformerLayerBounds eps layer.ln1Gamma layer.ln1Beta layer.ln2Gamma layer.ln2Beta heads layer.attnBias layer.mlpWIn layer.mlpBIn layer.mlpWOut @@ -94,11 +59,120 @@ theorem transformerLayerBounds_spec_real {seq dModel dHead numHeads hidden : Nat (mlpWIn := layer.mlpWIn) (mlpBIn := layer.mlpBIn) (mlpWOut := layer.mlpWOut) (mlpBOut := layer.mlpBOut) (scores := scores) (lo := lo) (hi := hi) (x := x) - hne heps hlo hhi) + hne heps hsqrt hlo hhi) + +/-- Interval bounds for a transformer layer from per-position bounds. 
-/ +def transformerLayerBoundsPos {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Dyadic) (layer : Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin seq → Fin dModel → Dyadic) : + (Fin seq → Fin dModel → Dyadic) × (Fin seq → Fin dModel → Dyadic) := + let positions := (Finset.univ : Finset (Fin seq)) + let hpos : positions.Nonempty := by + classical + have h : Nonempty (Fin seq) := + ⟨⟨0, Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩⟩ + exact (Finset.univ_nonempty_iff.mpr h) + let loCached := cacheBound2 lo + let hiCached := cacheBound2 hi + let base := intervalBoundsOn positions hpos loCached hiCached + let baseLo := cacheBound base.1 + let baseHi := cacheBound base.2 + let attn := attentionOutputBounds eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias + baseLo baseHi + let attnLo := cacheBound attn.1 + let attnHi := cacheBound attn.2 + let yLo : Fin seq → Fin dModel → Dyadic := fun q i => loCached q i + attnLo i + let yHi : Fin seq → Fin dModel → Dyadic := fun q i => hiCached q i + attnHi i + let yLoCached := cacheBound2 yLo + let yHiCached := cacheBound2 yHi + let out := cacheBoundPair2 (fun q => + layerNormAbsMlpResidualBounds eps layer.ln2Gamma layer.ln2Beta + layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut + (yLoCached q) (yHiCached q)) + out + +/-- `transformerLayerBoundsPos` soundness for `transformerLayerReal`. 
-/ +theorem transformerLayerBoundsPos_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Dyadic) (layer : Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin seq → Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo q i : Real) ≤ x q i) + (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : + let bounds := transformerLayerBoundsPos eps layer heads lo hi + ∀ q i, + (bounds.1 q i : Real) ≤ transformerLayerReal eps layer heads scores x q i ∧ + transformerLayerReal eps layer heads scores x q i ≤ (bounds.2 q i : Real) := by + classical + intro bounds q i + let positions := (Finset.univ : Finset (Fin seq)) + have hpos : positions.Nonempty := by + classical + have h : Nonempty (Fin seq) := + ⟨⟨0, Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩⟩ + exact (Finset.univ_nonempty_iff.mpr h) + let loCached := cacheBound2 lo + let hiCached := cacheBound2 hi + have hloCached : ∀ q i, (loCached q i : Real) ≤ x q i := by + intro q i + simpa [loCached, cacheBound2_apply] using hlo q i + have hhiCached : ∀ q i, x q i ≤ (hiCached q i : Real) := by + intro q i + simpa [hiCached, cacheBound2_apply] using hhi q i + let base := intervalBoundsOn positions hpos loCached hiCached + have hbase := intervalBoundsOn_spec positions hpos loCached hiCached x + (fun q _ i => hloCached q i) (fun q _ i => hhiCached q i) + have hloBase : ∀ q i, (base.1 i : Real) ≤ x q i := fun q i => + (hbase q (by simp [positions]) i).1 + have hhiBase : ∀ q i, x q i ≤ (base.2 i : Real) := fun q i => + (hbase q (by simp [positions]) i).2 + let baseLo := cacheBound base.1 + let baseHi := cacheBound base.2 + have hloBaseCached : ∀ q i, (baseLo i : Real) ≤ x q i := by + intro q i + simpa [baseLo, cacheBound_apply] using hloBase q i + have hhiBaseCached : ∀ q i, x q i ≤ (baseHi i : Real) := by + intro q i + simpa [baseHi, 
cacheBound_apply] using hhiBase q i + let attn := attentionOutputBounds eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias + baseLo baseHi + have hattn := attentionOutputBounds_spec eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores baseLo baseHi x hne heps hsqrt hloBaseCached hhiBaseCached q + let attnLo := cacheBound attn.1 + let attnHi := cacheBound attn.2 + let y := fun j => + x q j + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores x q j + have yLo : ∀ j, (loCached q j : Real) + (attn.1 j : Real) ≤ y j := by + intro j + have hlow := add_le_add (hloCached q j) (hattn j).1 + simpa [y] using hlow + have yHi : ∀ j, y j ≤ (hiCached q j : Real) + (attn.2 j : Real) := by + intro j + have hhigh := add_le_add (hhiCached q j) (hattn j).2 + simpa [y] using hhigh + let yLoCached := cacheBound2 (fun q i => loCached q i + attnLo i) + let yHiCached := cacheBound2 (fun q i => hiCached q i + attnHi i) + have yLoCached_bound : ∀ j, (yLoCached q j : Real) ≤ y j := by + intro j + simpa [yLoCached, attnLo, cacheBound_apply, cacheBound2_apply] using (yLo j) + have yHiCached_bound : ∀ j, y j ≤ (yHiCached q j : Real) := by + intro j + simpa [yHiCached, attnHi, cacheBound_apply, cacheBound2_apply] using (yHi j) + have hmlp := + layerNormAbsMlpResidualBounds_spec eps layer.ln2Gamma layer.ln2Beta + layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut + (yLoCached q) (yHiCached q) y hne heps hsqrt yLoCached_bound yHiCached_bound + have hmlp_i := hmlp i + simpa [bounds, transformerLayerBoundsPos, positions, base, loCached, hiCached, baseLo, baseHi, + attn, attnLo, attnHi, y, yLoCached, yHiCached, cacheBound2_apply, cacheBoundPair2_apply_left, + cacheBoundPair2_apply_right, transformerLayerReal, cacheBound_apply] using hmlp_i /-- Real-valued transformer stack output (folded left over layers). 
-/ noncomputable def transformerStackReal - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) @@ -109,10 +183,10 @@ noncomputable def transformerStackReal /-- Interval bounds for a transformer stack (folded left over layers). -/ def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} - (eps : Rat) + (eps : Dyadic) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + (lo hi : Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := let step := fun bounds layerIdx => transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta (layers layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) @@ -120,14 +194,82 @@ def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2 Linear.foldlFin numLayers step (lo, hi) +/-- Interval bounds for a transformer stack from per-position bounds. 
-/ +def transformerStackBoundsPos {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (eps : Dyadic) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin seq → Fin dModel → Dyadic) : + (Fin seq → Fin dModel → Dyadic) × (Fin seq → Fin dModel → Dyadic) := + let step := fun bounds layerIdx => + transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2 + Linear.foldlFin numLayers step (lo, hi) + +private theorem transformerStackBoundsPos_spec_list + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (eps : Dyadic) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + ∀ (ls : List (Fin numLayers)) (lo hi : Fin seq → Fin dModel → Dyadic) + (x : Fin seq → Fin dModel → Real), + (∀ q i, (lo q i : Real) ≤ x q i) → + (∀ q i, x q i ≤ (hi q i : Real)) → + let bounds := (ls.foldl + (fun bounds layerIdx => + transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2) + (lo, hi)) + let x' := (ls.foldl + (fun x layerIdx => + transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x) + x) + ∀ q i, + (bounds.1 q i : Real) ≤ x' q i ∧ + x' q i ≤ (bounds.2 q i : Real) := by + intro ls lo hi x hlo hhi + induction ls generalizing lo hi x hlo hhi with + | nil => + simpa using fun q i => And.intro (hlo q i) (hhi q i) + | cons l ls ih => + have hstep := + transformerLayerBoundsPos_spec eps (layers l) (heads l) (scores l) lo hi x + hne heps hsqrt hlo hhi + let bounds1 := transformerLayerBoundsPos eps (layers l) (heads l) lo hi + let x1 := transformerLayerReal eps (layers l) (heads l) (scores l) x + have hlo1 : ∀ q i, (bounds1.1 q i : Real) ≤ x1 q i := fun q i => (hstep 
q i).1 + have hhi1 : ∀ q i, x1 q i ≤ (bounds1.2 q i : Real) := fun q i => (hstep q i).2 + have ih' := ih bounds1.1 bounds1.2 x1 hlo1 hhi1 + simpa [bounds1, x1] using ih' + +/-- `transformerStackBoundsPos` soundness for real transformer-stack outputs. -/ +theorem transformerStackBoundsPos_spec {seq dModel dHead numHeads hidden numLayers : Nat} + [NeZero seq] (eps : Dyadic) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin seq → Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo q i : Real) ≤ x q i) + (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : + let bounds := transformerStackBoundsPos eps layers heads lo hi + ∀ q i, + (bounds.1 q i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ + transformerStackReal eps layers heads scores x q i ≤ (bounds.2 q i : Real) := by + classical + simpa [transformerStackBoundsPos, transformerStackReal, Linear.foldlFin_eq_foldl, + Fin.foldl_eq_foldl_finRange] using + transformerStackBoundsPos_spec_list eps layers heads scores hne heps hsqrt + (List.finRange numLayers) lo hi x hlo hhi + private theorem transformerStackBounds_spec_list {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) + (eps : Dyadic) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) : - ∀ (ls : List (Fin numLayers)) (lo hi : Fin dModel → Rat) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + ∀ (ls : List (Fin numLayers)) (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real), (∀ q i, (lo i : Real) ≤ x q i) → (∀ q i, x q i ≤ (hi i : Real)) → @@ -152,7 +294,7 
@@ private theorem transformerStackBounds_spec_list | cons l ls ih => have hstep := transformerLayerBounds_spec_real eps (layers l) (heads l) (scores l) lo hi x - hne heps hlo hhi + hne heps hsqrt hlo hhi let bounds1 := transformerLayerBounds eps (layers l).ln1Gamma (layers l).ln1Beta (layers l).ln2Gamma (layers l).ln2Beta (heads l) (layers l).attnBias (layers l).mlpWIn (layers l).mlpBIn @@ -165,12 +307,12 @@ private theorem transformerStackBounds_spec_list /-- `transformerStackBounds` soundness for real transformer-stack outputs. -/ theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) + (eps : Dyadic) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) + (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := transformerStackBounds eps layers heads lo hi ∀ q i, @@ -179,12 +321,12 @@ theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers classical simpa [transformerStackBounds, transformerStackReal, Linear.foldlFin_eq_foldl, Fin.foldl_eq_foldl_finRange] using - transformerStackBounds_spec_list eps layers heads scores hne heps + transformerStackBounds_spec_list eps layers heads scores hne heps hsqrt (List.finRange numLayers) lo hi x hlo hhi /-- Real-valued transformer stack output after the final LayerNorm. 
-/ noncomputable def transformerStackFinalReal {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + [NeZero seq] (eps : Dyadic) (finalLn : Model.Gpt2FinalLayerNorm dModel) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) @@ -194,21 +336,21 @@ noncomputable def transformerStackFinalReal {seq dModel dHead numHeads hidden nu /-- Interval bounds for transformer stack outputs after the final LayerNorm. -/ def transformerStackFinalBounds {dModel dHead numHeads hidden numLayers : Nat} - (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + (eps : Dyadic) (finalLn : Model.Gpt2FinalLayerNorm dModel) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + (lo hi : Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := let stack := transformerStackBounds eps layers heads lo hi layerNormIntervalBounds eps finalLn.gamma finalLn.beta stack.1 stack.2 /-- `transformerStackFinalBounds` soundness for real outputs. 
-/ theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + [NeZero seq] (eps : Dyadic) (finalLn : Model.Gpt2FinalLayerNorm dModel) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) + (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := transformerStackFinalBounds eps finalLn layers heads lo hi ∀ q i, @@ -218,35 +360,79 @@ theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLa intro bounds q i let stack := transformerStackBounds eps layers heads lo hi have hstack := - transformerStackBounds_spec eps layers heads scores lo hi x hne heps hlo hhi q + transformerStackBounds_spec eps layers heads scores lo hi x hne heps hsqrt hlo hhi q have hlo' : ∀ k, (stack.1 k : Real) ≤ transformerStackReal eps layers heads scores x q k := fun k => (hstack k).1 have hhi' : ∀ k, transformerStackReal eps layers heads scores x q k ≤ (stack.2 k : Real) := fun k => (hstack k).2 have hln := layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta stack.1 stack.2 - (fun j => transformerStackReal eps layers heads scores x q j) hne heps hlo' hhi' + (fun j => transformerStackReal eps layers heads scores x q j) hne heps hsqrt hlo' hhi' simpa [bounds, transformerStackFinalBounds, stack, transformerStackFinalReal] using hln i +/-- Interval bounds for transformer stack outputs after the final LayerNorm (per-position). 
-/ +def transformerStackFinalBoundsPos + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin seq → Fin dModel → Dyadic) : + (Fin seq → Fin dModel → Dyadic) × (Fin seq → Fin dModel → Dyadic) := + let stack := transformerStackBoundsPos eps layers heads lo hi + let ln := fun q => + layerNormIntervalBounds eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) + (fun q i => (ln q).1 i, fun q i => (ln q).2 i) + +/-- `transformerStackFinalBoundsPos` soundness for real outputs. -/ +theorem transformerStackFinalBoundsPos_spec + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin seq → Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo q i : Real) ≤ x q i) + (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : + let bounds := transformerStackFinalBoundsPos eps finalLn layers heads lo hi + ∀ q i, + (bounds.1 q i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores x q i ∧ + transformerStackFinalReal eps finalLn layers heads scores x q i ≤ + (bounds.2 q i : Real) := by + classical + intro bounds q i + let stack := transformerStackBoundsPos eps layers heads lo hi + have hstack := + transformerStackBoundsPos_spec eps layers heads scores lo hi x hne heps hsqrt hlo hhi q + have hlo' : ∀ j, (stack.1 q j : Real) ≤ transformerStackReal eps layers heads scores x q j := + fun j => (hstack j).1 + have hhi' : ∀ j, transformerStackReal eps layers heads scores x q j ≤ (stack.2 
q j : Real) := + fun j => (hstack j).2 + have hln := + layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) + (fun j => transformerStackReal eps layers heads scores x q j) hne heps hsqrt hlo' hhi' + simpa [bounds, transformerStackFinalBoundsPos, stack, transformerStackFinalReal] using hln i + /-- Residual interval bounds for a GPT-2 stack from exact embeddings. -/ def gpt2ResidualIntervalBounds - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + (embed : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := let base := embeddingIntervalBounds embed transformerStackFinalBounds eps finalLn layers heads base.1 base.2 /-- `gpt2ResidualIntervalBounds` soundness for real GPT-2 outputs. 
-/ theorem gpt2ResidualIntervalBounds_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (finalLn : Model.Gpt2FinalLayerNorm dModel) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Rat) - (hne : dModel ≠ 0) (heps : 0 < eps) : + (embed : Fin seq → Fin dModel → Dyadic) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := gpt2ResidualIntervalBounds eps layers heads finalLn embed ∀ q i, (bounds.1 i : Real) ≤ @@ -262,9 +448,67 @@ theorem gpt2ResidualIntervalBounds_spec have hhi : ∀ q i, (embed q i : Real) ≤ (base.2 i : Real) := fun q i => (hbase q i).2 have hstack := transformerStackFinalBounds_spec eps finalLn layers heads scores base.1 base.2 - (fun q i => (embed q i : Real)) hne heps hlo hhi q i + (fun q i => (embed q i : Real)) hne heps hsqrt hlo hhi q i simpa [bounds, gpt2ResidualIntervalBounds, base] using hstack +/-- Residual interval bounds over an active set from exact embeddings. 
-/ +def gpt2ResidualIntervalBoundsActive + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Dyadic) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (embed : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + let baseLo : Fin seq → Fin dModel → Dyadic := embed + let baseHi : Fin seq → Fin dModel → Dyadic := embed + let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi + intervalBoundsOn active hactive final.1 final.2 + +/-- `gpt2ResidualIntervalBoundsActive` soundness for real GPT-2 outputs. -/ +theorem gpt2ResidualIntervalBoundsActive_spec + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Dyadic) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (embed : Fin seq → Fin dModel → Dyadic) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed + ∀ q, q ∈ active → ∀ i, + (bounds.1 i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ∧ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by + classical + intro bounds q hq i + let baseLo : Fin seq → Fin dModel → Dyadic := embed + let baseHi : Fin seq → Fin dModel → Dyadic := embed + let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi + have hfinal := + transformerStackFinalBoundsPos_spec eps finalLn 
layers heads scores baseLo baseHi + (fun q i => (embed q i : Real)) hne heps hsqrt + (fun q i => by simp [baseLo]) + (fun q i => by simp [baseHi]) + have hlo : ∀ q, q ∈ active → ∀ i, + (final.1 q i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i := by + intro q hq i + simpa [final] using (hfinal q i).1 + have hhi : ∀ q, q ∈ active → ∀ i, + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (final.2 q i : Real) := by + intro q hq i + simpa [final] using (hfinal q i).2 + have hbounds := intervalBoundsOn_spec active hactive final.1 final.2 + (fun q i => transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i) + hlo hhi + simpa [bounds, gpt2ResidualIntervalBoundsActive, final, baseLo, baseHi] using + hbounds q hq i + end Bounds end Sound diff --git a/Nfp/Sound/Bounds/Transformer/Embedding.lean b/Nfp/Sound/Bounds/Transformer/Embedding.lean new file mode 100644 index 0000000..2c80bf2 --- /dev/null +++ b/Nfp/Sound/Bounds/Transformer/Embedding.lean @@ -0,0 +1,132 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Nfp.Core.Basic + +/-! +Embedding interval bounds for transformer stacks. + +This module isolates per-position and per-set embedding bounds. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +private lemma fin_univ_nonempty (seq : Nat) [NeZero seq] : + (Finset.univ : Finset (Fin seq)).Nonempty := by + classical + refine ⟨⟨0, ?_⟩, by simp⟩ + exact Nat.pos_of_ne_zero (NeZero.ne (n := seq)) + +/-- Interval bounds across tokens for an embedding map. 
-/ +def embeddingIntervalBounds {seq dModel : Nat} [NeZero seq] + (x : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + let h : (Finset.univ : Finset (Fin seq)).Nonempty := fin_univ_nonempty (seq := seq) + (fun i => (Finset.univ).inf' h (fun q => x q i), + fun i => (Finset.univ).sup' h (fun q => x q i)) + +/-- `embeddingIntervalBounds` bounds embeddings coordinatewise. -/ +theorem embeddingIntervalBounds_spec {seq dModel : Nat} [NeZero seq] + (x : Fin seq → Fin dModel → Dyadic) : + let bounds := embeddingIntervalBounds x + ∀ q i, + (bounds.1 i : Real) ≤ (x q i : Real) ∧ + (x q i : Real) ≤ (bounds.2 i : Real) := by + classical + intro bounds q i + have hloDyadic : bounds.1 i ≤ x q i := by + have h := + Finset.inf'_le (s := (Finset.univ : Finset (Fin seq))) + (f := fun k => x k i) (b := q) (by simp) + simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h + have hhiDyadic : x q i ≤ bounds.2 i := by + have h := + Finset.le_sup' (s := (Finset.univ : Finset (Fin seq))) + (f := fun k => x k i) (b := q) (by simp) + simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h + constructor + · simpa using (dyadicToReal_le_of_le hloDyadic) + · simpa using (dyadicToReal_le_of_le hhiDyadic) + +/-- Interval bounds across a finite set of positions for an embedding map. -/ +def embeddingIntervalBoundsOn {seq dModel : Nat} [NeZero seq] + (positions : Finset (Fin seq)) (hpos : positions.Nonempty) + (x : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (fun i => positions.inf' hpos (fun q => x q i), + fun i => positions.sup' hpos (fun q => x q i)) + +/-- `embeddingIntervalBoundsOn` bounds embeddings on the chosen positions. 
-/ +theorem embeddingIntervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] + (positions : Finset (Fin seq)) (hpos : positions.Nonempty) + (x : Fin seq → Fin dModel → Dyadic) : + let bounds := embeddingIntervalBoundsOn positions hpos x + ∀ q, q ∈ positions → ∀ i, + (bounds.1 i : Real) ≤ (x q i : Real) ∧ + (x q i : Real) ≤ (bounds.2 i : Real) := by + classical + intro bounds q hq i + have hloDyadic : bounds.1 i ≤ x q i := by + have h := + Finset.inf'_le (s := positions) + (f := fun k => x k i) (b := q) hq + simpa [bounds, embeddingIntervalBoundsOn] using h + have hhiDyadic : x q i ≤ bounds.2 i := by + have h := + Finset.le_sup' (s := positions) + (f := fun k => x k i) (b := q) hq + simpa [bounds, embeddingIntervalBoundsOn] using h + constructor + · simpa using (dyadicToReal_le_of_le hloDyadic) + · simpa using (dyadicToReal_le_of_le hhiDyadic) + +/-- Collapse per-position interval bounds over a finite set of positions. -/ +def intervalBoundsOn {seq dModel : Nat} [NeZero seq] + (positions : Finset (Fin seq)) (hpos : positions.Nonempty) + (lo hi : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (fun i => positions.inf' hpos (fun q => lo q i), + fun i => positions.sup' hpos (fun q => hi q i)) + +/-- `intervalBoundsOn` soundness for bounds on the chosen positions. 
-/ +theorem intervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] + (positions : Finset (Fin seq)) (hpos : positions.Nonempty) + (lo hi : Fin seq → Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (hlo : ∀ q, q ∈ positions → ∀ i, (lo q i : Real) ≤ x q i) + (hhi : ∀ q, q ∈ positions → ∀ i, x q i ≤ (hi q i : Real)) : + let bounds := intervalBoundsOn positions hpos lo hi + ∀ q, q ∈ positions → ∀ i, + (bounds.1 i : Real) ≤ x q i ∧ + x q i ≤ (bounds.2 i : Real) := by + classical + intro bounds q hq i + have hmin : bounds.1 i ≤ lo q i := by + have h := + Finset.inf'_le (s := positions) + (f := fun k => lo k i) (b := q) hq + simpa [bounds, intervalBoundsOn] using h + have hmax : hi q i ≤ bounds.2 i := by + have h := + Finset.le_sup' (s := positions) + (f := fun k => hi k i) (b := q) hq + simpa [bounds, intervalBoundsOn] using h + have hlo' := hlo q hq i + have hhi' := hhi q hq i + constructor + · have hmin_real : + (bounds.1 i : Real) ≤ (lo q i : Real) := by + simpa using (dyadicToReal_le_of_le hmin) + exact le_trans hmin_real hlo' + · have hmax_real : + (hi q i : Real) ≤ (bounds.2 i : Real) := by + simpa using (dyadicToReal_le_of_le hmax) + exact le_trans hhi' hmax_real + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/UnnormRat.lean b/Nfp/Sound/Bounds/UnnormRat.lean new file mode 100644 index 0000000..e2c0273 --- /dev/null +++ b/Nfp/Sound/Bounds/UnnormRat.lean @@ -0,0 +1,61 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Core.Basic +import Nfp.Sound.Linear.FinFold + +/-! +Unnormalized dyadic arithmetic. + +Dyadic values already avoid gcd normalization, so this module provides a +lightweight alias and helper API used by older code paths. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +/-- Unnormalized dyadic value (alias). -/ +abbrev UnnormDyadic := Dyadic + +/-- Interpret an unnormalized dyadic as a dyadic. 
-/ +def UnnormDyadic.toDyadic (q : UnnormDyadic) : Dyadic := + q + +/-- Embed a dyadic as an unnormalized dyadic. -/ +def UnnormDyadic.ofDyadic (q : Dyadic) : UnnormDyadic := + q + +/-- Unnormalized zero. -/ +def UnnormDyadic.zero : UnnormDyadic := 0 + +/-- Unnormalized addition. -/ +def UnnormDyadic.add (a b : UnnormDyadic) : UnnormDyadic := + a + b + +/-- Unnormalized multiplication. -/ +def UnnormDyadic.mul (a b : UnnormDyadic) : UnnormDyadic := + a * b + +/-- `toDyadic` respects multiplication. -/ +theorem UnnormDyadic.toDyadic_mul_ofDyadic (a b : Dyadic) : + UnnormDyadic.toDyadic (UnnormDyadic.mul (UnnormDyadic.ofDyadic a) + (UnnormDyadic.ofDyadic b)) = a * b := by + rfl + +/-- Tail-recursive sum of unnormalized dyadics. -/ +def UnnormDyadic.sumFin (n : Nat) (f : Fin n → UnnormDyadic) : UnnormDyadic := + Linear.sumFin n f + +/-- `toDyadic` commutes with `sumFin`. -/ +theorem UnnormDyadic.toDyadic_sumFin (n : Nat) (f : Fin n → UnnormDyadic) : + UnnormDyadic.toDyadic (UnnormDyadic.sumFin n f) = + Linear.sumFin n (fun i => UnnormDyadic.toDyadic (f i)) := by + rfl + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index 7d66d41..054cd55 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -40,7 +40,7 @@ def buildInductionHeadInputs {seq dModel dHead vocab : Nat} wo := slice.wo attnBias := slice.attnBias maskCausal := true - maskValue := (-10000 : Rat) + maskValue := (-10000 : Dyadic) directionSpec := slice.direction.spec direction := slice.directionVec } @@ -64,7 +64,7 @@ theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} wo := slice.wo attnBias := slice.attnBias maskCausal := true - maskValue := (-10000 : Rat) + maskValue := (-10000 : Dyadic) directionSpec := slice.direction.spec direction := slice.directionVec } := rfl diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index 77c91b2..6c0b295 100644 --- a/Nfp/Sound/Induction.lean +++ 
b/Nfp/Sound/Induction.lean @@ -1,1223 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Ring.Rat -import Mathlib.Data.Finset.Lattice.Fold -import Mathlib.Data.Rat.Cast.Order -import Mathlib.Data.Vector.Defs -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.Circuit.Cert.ValueRange -import Nfp.Circuit.Layers.Softmax -import Nfp.Model.InductionHead -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Induction.OneHot -import Nfp.Sound.Linear.FinFold +import Nfp.Sound.Induction.Core +import Nfp.Sound.Induction.HeadOutput /-! Sound builders for induction certificates. -These builders recompute certificate bounds inside Lean from exact inputs and -return proof-carrying results. The head-input path derives softmax tolerances -from score margins rather than trusting external weight dumps. +This module re-exports the core constructions and head-output interval bounds. -/ - -namespace Nfp - -namespace Sound - -open scoped BigOperators - -open Nfp.Circuit -open Nfp.Sound.Bounds - -variable {seq : Nat} - -/-- Cached direction head for head inputs. -/ -private def dirHeadVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := - Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) - -/-- Real-valued LayerNorm outputs for head inputs. -/ -private noncomputable def lnRealOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := - fun q => - Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) - -/-- Real-valued query projections for head inputs. 
-/ -private noncomputable def qRealOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := - fun q d => - dotProduct (fun j => (inputs.wq j d : Real)) (lnRealOfInputs inputs q) + (inputs.bq d : Real) - -/-- Real-valued key projections for head inputs. -/ -private noncomputable def kRealOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := - fun q d => - dotProduct (fun j => (inputs.wk j d : Real)) (lnRealOfInputs inputs q) + (inputs.bk d : Real) - -/-- Real-valued value projections for head inputs. -/ -private noncomputable def vRealOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := - fun q d => - dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + (inputs.bv d : Real) - -/-- Real-valued attention scores for head inputs. -/ -noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin seq → Real := - fun q k => - let base := - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - if inputs.maskCausal then - if k ≤ q then - base - else - (inputs.maskValue : Real) - else - base - -/-- Real-valued per-key head outputs in model space. -/ -private noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := - fun k i => - dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) - -/-- Real-valued direction scores for head inputs. 
-/ -noncomputable def valsRealOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Real := - let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d - fun k => dotProduct dirHead (fun d => vRealOfInputs inputs k d) - -/-- Interval data for direction values. -/ -structure ValueInterval (seq : Nat) where - /-- Lower bound for values. -/ - lo : Rat - /-- Upper bound for values. -/ - hi : Rat - /-- Lower bounds on per-key values. -/ - valsLo : Fin seq → Rat - /-- Upper bounds on per-key values. -/ - valsHi : Fin seq → Rat - /-- Optional logit-diff direction metadata (ignored by the checker). -/ - direction : Option DirectionSpec - -/-- Soundness predicate for direction-value interval data. -/ -structure ValueIntervalBounds {seq : Nat} (vals : Fin seq → Real) - (c : ValueInterval seq) : Prop where - /-- Interval endpoints are ordered. -/ - lo_le_hi : c.lo ≤ c.hi - /-- `lo` is below every lower bound. -/ - lo_le_valsLo : ∀ k, (c.lo : Real) ≤ (c.valsLo k : Real) - /-- Bounds sandwich the real values. -/ - vals_bounds : - ∀ k, (c.valsLo k : Real) ≤ vals k ∧ vals k ≤ (c.valsHi k : Real) - /-- `hi` is above every upper bound. -/ - valsHi_le_hi : ∀ k, (c.valsHi k : Real) ≤ (c.hi : Real) - -/-- Sound induction-certificate payload built from exact head inputs. -/ -structure InductionHeadCert (seq : Nat) where - /-- Weight tolerance. -/ - eps : Rat - /-- Per-query weight tolerance derived from local margins. -/ - epsAt : Fin seq → Rat - /-- Score margin used to justify the weight tolerance. -/ - margin : Rat - /-- Active queries for which bounds are required. -/ - active : Finset (Fin seq) - /-- `prev` selector for induction-style attention. -/ - prev : Fin seq → Fin seq - /-- Value-interval certificate for the direction values. -/ - values : ValueInterval seq - -/-- Soundness predicate for `InductionHeadCert`. 
-/ -structure InductionHeadCertSound [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) : Prop where - /-- Softmax weights respect the derived margin bounds. -/ - softmax_bounds : - Layers.SoftmaxMarginBoundsOn (Val := Real) (c.eps : Real) (c.margin : Real) - (fun q => q ∈ c.active) c.prev - (scoresRealOfInputs inputs) - (fun q k => Circuit.softmax (scoresRealOfInputs inputs q) k) - /-- Per-query one-hot bounds derived from local margins. -/ - oneHot_bounds_at : - ∀ q, q ∈ c.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) (c.epsAt q : Real) - (fun q' => q' = q) c.prev - (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) - /-- Interval bounds hold for the direction values. -/ - value_bounds : - ValueIntervalBounds (vals := valsRealOfInputs inputs) c.values - -/-- Build and certify a softmax-margin certificate from exact scores/weights. -/ -def buildSoftmaxMarginCert? [NeZero seq] - (active : Finset (Fin seq)) - (prev : Fin seq → Fin seq) - (scores : Fin seq → Fin seq → Rat) - (weights : Fin seq → Fin seq → Rat) : - Option {c : SoftmaxMarginCert seq // checkSoftmaxMarginCert c = true} := by - classical - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (prev q) - let epsAt : Fin seq → Rat := fun q => - let other := otherKeys q - let maxOther := - if h : other.Nonempty then - other.sup' h (fun k => weights q k) - else - (0 : Rat) - let deficit := (1 : Rat) - weights q (prev q) - max maxOther deficit - let marginAt : Fin seq → Rat := fun q => - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scores q (prev q) - scores q k) - else - (0 : Rat) - let eps := - if h : active.Nonempty then - active.sup' h epsAt - else - (0 : Rat) - let margin := - if h : active.Nonempty then - active.inf' h marginAt - else - (0 : Rat) - let cert : SoftmaxMarginCert seq := - { eps := eps - margin := margin - active := active - 
prev := prev - scores := scores - weights := weights } - if h : checkSoftmaxMarginCert cert = true then - exact some ⟨cert, h⟩ - else - exact none - -/-- Build and certify a value-range certificate from exact values. -/ -def buildValueRangeCert? [NeZero seq] - (vals : Fin seq → Rat) - (direction : Option DirectionSpec) : - Option {c : ValueRangeCert seq // checkValueRangeCert c = true} := by - classical - let _ : Nonempty (Fin seq) := by - refine ⟨⟨0, ?_⟩⟩ - exact Nat.pos_of_ne_zero (NeZero.ne seq) - let univ : Finset (Fin seq) := Finset.univ - let hnonempty : univ.Nonempty := Finset.univ_nonempty - let lo := univ.inf' hnonempty vals - let hi := univ.sup' hnonempty vals - let cert : ValueRangeCert seq := - { lo := lo - hi := hi - vals := vals - direction := direction } - if h : checkValueRangeCert cert = true then - exact some ⟨cert, h⟩ - else - exact none - -/-- Build and certify induction certificates from exact head inputs. -/ -def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option {c : InductionHeadCert seq // InductionHeadCertSound inputs c} := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hmodel : dModel = 0 - · exact none - · let lnBounds : Fin seq → (Fin dModel → Rat) × (Fin dModel → Rat) := fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) - let lnLo : Fin seq → Fin dModel → Rat := fun q => (lnBounds q).1 - let lnHi : Fin seq → Fin dModel → Rat := fun q => (lnBounds q).2 - let qLo : Fin seq → Fin dHead → Rat := fun q d => - dotIntervalLower (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d - let qHi : Fin seq → Fin dHead → Rat := fun q d => - dotIntervalUpper (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d - let kLo : Fin seq → Fin dHead → Rat := fun q d => - dotIntervalLower (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d - let kHi : Fin seq → Fin dHead → Rat := fun q d => - dotIntervalUpper 
(fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d - let vLo : Fin seq → Fin dHead → Rat := fun q d => - dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let vHi : Fin seq → Fin dHead → Rat := fun q d => - dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let qAbs : Fin seq → Fin dHead → Rat := fun q d => - max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => - max |kLo q d| |kHi q d| - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal ∧ q < k - let dotAbs : Fin seq → Fin seq → Rat := fun q k => - dotProduct (fun d => qAbs q d) (fun d => kAbs k d) - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreAbs : Fin seq → Fin seq → Rat := fun q k => - if masked q k then |inputs.maskValue| else scoreBaseAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else scoreBaseAbs q k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let marginAt : Fin seq → Rat := fun q => - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreLo q (inputs.prev q) - scoreHi q k) - else - (0 : Rat) - let epsAt : Fin seq → Rat := fun q => - if marginAt q < 0 then - (1 : Rat) - else - (seq - 1 : Rat) / (1 + marginAt q) - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - (seq - 1 : Rat) / (1 + margin) - have hseq : (1 : Nat) ≤ seq := - Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) - let dirHeadVec := dirHeadVecOfInputs inputs - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let valsLo : Fin seq → Rat := fun k => - dotIntervalLower dirHead (vLo k) (vHi k) - 
let valsHi : Fin seq → Rat := fun k => - dotIntervalUpper dirHead (vLo k) (vHi k) - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec } - let cert : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } - have hln_bounds : - ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ - lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by - intro q i - have hln := - Bounds.layerNormBounds_spec (eps := inputs.lnEps) - (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) - (x := inputs.embed q) hmodel hEps - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs] using hln i - have hq_bounds : - ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by - intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wq j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wq j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hlow' := add_le_add_right hlow (inputs.bq d : Real) - have hhigh' := add_le_add_right hhigh (inputs.bq d : Real) - constructor - · simpa [qLo, qRealOfInputs, Rat.cast_add] using hlow' - · simpa [qHi, qRealOfInputs, Rat.cast_add] using hhigh' - have hk_bounds : - ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ (kHi q d : Real) := by - intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo 
q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wk j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wk j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hlow' := add_le_add_right hlow (inputs.bk d : Real) - have hhigh' := add_le_add_right hhigh (inputs.bk d : Real) - constructor - · simpa [kLo, kRealOfInputs, Rat.cast_add] using hlow' - · simpa [kHi, kRealOfInputs, Rat.cast_add] using hhigh' - have hv_bounds : - ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ - vRealOfInputs inputs q d ≤ (vHi q d : Real) := by - intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hlow' := add_le_add_right hlow (inputs.bv d : Real) - have hhigh' := add_le_add_right hhigh (inputs.bv d : Real) - constructor - · simpa [vLo, vRealOfInputs, Rat.cast_add] using hlow' - · simpa [vHi, vRealOfInputs, Rat.cast_add] using hhigh' - have hscore_bounds : - ∀ q k, (scoreLo q k : Real) ≤ scoresRealOfInputs inputs q k ∧ - scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by - intro q k - let scoresReal := scoresRealOfInputs inputs - let base := - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - have hq_abs : ∀ d, |qRealOfInputs inputs q d| ≤ (qAbs q d : Real) 
:= by - intro d - have hq := hq_bounds q d - have h := abs_le_max_abs_abs_of_interval_real hq.1 hq.2 - simpa [qAbs] using h - have hk_abs : ∀ d, |kRealOfInputs inputs k d| ≤ (kAbs k d : Real) := by - intro d - have hk := hk_bounds k d - have h := abs_le_max_abs_abs_of_interval_real hk.1 hk.2 - simpa [kAbs] using h - have hdot_abs : - |dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d)| ≤ - (dotAbs q k : Real) := by - have hsum : - |∑ d, qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ - ∑ d, |qRealOfInputs inputs q d * kRealOfInputs inputs k d| := by - simpa [dotProduct] using - (Finset.abs_sum_le_sum_abs (s := (Finset.univ : Finset (Fin dHead))) - (f := fun d => qRealOfInputs inputs q d * kRealOfInputs inputs k d)) - have hterm : - ∀ d, - |qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ - (qAbs q d : Real) * (kAbs k d : Real) := by - intro d - have hq := hq_abs d - have hk := hk_abs d - have hqnonneg : 0 ≤ (qAbs q d : Real) := by - have hqnonneg' : 0 ≤ qAbs q d := by - have h1 : 0 ≤ |qLo q d| := abs_nonneg (qLo q d) - exact le_trans h1 (le_max_left _ _) - exact_mod_cast hqnonneg' - calc - |qRealOfInputs inputs q d * kRealOfInputs inputs k d| = - |qRealOfInputs inputs q d| * |kRealOfInputs inputs k d| := by - simp [abs_mul] - _ ≤ (qAbs q d : Real) * (kAbs k d : Real) := - mul_le_mul hq hk (abs_nonneg _) hqnonneg - have hsum_le : - ∑ d, |qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ - ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by - refine Finset.sum_le_sum ?_ - intro d _ - exact hterm d - have hcast : - (dotAbs q k : Real) = - ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by - simp [dotAbs, dotProduct] - have hfinal := hsum.trans (hsum_le.trans_eq hcast.symm) - simpa [dotProduct] using hfinal - have hscale_abs : 0 ≤ (|inputs.scale| : Real) := by - exact_mod_cast (abs_nonneg (inputs.scale)) - have hbase_abs : - |base| ≤ (scoreBaseAbs q k : Real) := by - have hdot_abs' := hdot_abs - have hmul : - |base| = - 
(|inputs.scale| : Real) * - |dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)| := by - simp [base, abs_mul] - have hmul_le : - (|inputs.scale| : Real) * - |dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)| ≤ - (|inputs.scale| : Real) * (dotAbs q k : Real) := by - exact mul_le_mul_of_nonneg_left hdot_abs' hscale_abs - simpa [scoreBaseAbs, hmul] using hmul_le - by_cases hcausal : inputs.maskCausal - · by_cases hle : k ≤ q - · have hnot : ¬ q < k := not_lt_of_ge hle - have hscore_eq : scoresReal q k = base := by - simp [scoresReal, scoresRealOfInputs, hcausal, hle, base] - have hscore_abs' : |scoresReal q k| ≤ (scoreBaseAbs q k : Real) := by - simpa [hscore_eq] using hbase_abs - have hscore_abs : - |scoresReal q k| ≤ (scoreAbs q k : Real) := by - simpa [scoreAbs, masked, hcausal, hnot] using hscore_abs' - have hscore_bounds := (abs_le).1 hscore_abs - constructor - · simpa [scoresReal, scoreLo, scoreAbs, masked, hcausal, hnot] - using hscore_bounds.1 - · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal, hnot] - using hscore_bounds.2 - · have hlt : q < k := lt_of_not_ge hle - constructor - · simp [scoresRealOfInputs, scoreLo, masked, hcausal, hle, hlt] - · simp [scoresRealOfInputs, scoreHi, masked, hcausal, hle, hlt] - · have hscore_eq : scoresReal q k = base := by - simp [scoresReal, scoresRealOfInputs, hcausal, base] - have hscore_abs' : |scoresReal q k| ≤ (scoreBaseAbs q k : Real) := by - simpa [hscore_eq] using hbase_abs - have hscore_abs : - |scoresReal q k| ≤ (scoreAbs q k : Real) := by - simpa [scoreAbs, masked, hcausal] using hscore_abs' - have hscore_bounds := (abs_le).1 hscore_abs - constructor - · simpa [scoresReal, scoreLo, scoreAbs, masked, hcausal] - using hscore_bounds.1 - · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal] - using hscore_bounds.2 - let scoresReal := scoresRealOfInputs inputs - let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax 
(scoresReal q) k - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - have hscore_margin_real : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - by_cases hactive : inputs.active.Nonempty - · have hmargin_le : margin ≤ marginAt q := by - have hle : margin ≤ inputs.active.inf' hactive marginAt := by - simp [margin, hactive] - have hle_all := - (Finset.le_inf'_iff (s := inputs.active) (H := hactive) (f := marginAt) - (a := margin)).1 hle - exact hle_all q hq - have hother : (otherKeys q).Nonempty := ⟨k, by simp [otherKeys, hk]⟩ - have hgap_le : - marginAt q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by - have hle : marginAt q ≤ - (otherKeys q).inf' hother - (fun k => scoreLo q (inputs.prev q) - scoreHi q k) := by - simp [marginAt, hother] - have hle_all := - (Finset.le_inf'_iff (s := otherKeys q) (H := hother) - (f := fun k => scoreLo q (inputs.prev q) - scoreHi q k) - (a := marginAt q)).1 hle - exact hle_all k (by simp [otherKeys, hk]) - have hgap : margin ≤ scoreLo q (inputs.prev q) - scoreHi q k := - le_trans hmargin_le hgap_le - have hgap_real : (margin : Real) ≤ - (scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real) := by - have hgap_real' : - (margin : Real) ≤ ((scoreLo q (inputs.prev q) - scoreHi q k : Rat) : Real) := - (Rat.cast_le (K := Real)).2 hgap - simpa [Rat.cast_sub] using hgap_real' - have hk_bounds := hscore_bounds q k - have hprev_bounds := hscore_bounds q (inputs.prev q) - have h1 : - scoresReal q k + (margin : Real) ≤ - scoresReal q k + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by - exact (add_le_add_iff_left (scoresReal q k)).2 hgap_real - have h2 : - scoresReal q k + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ - (scoreLo q (inputs.prev q) : Real) := by - have hscore_le' : - scoresReal q k + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) 
≤ - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by - exact (add_le_add_iff_right - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real))).2 hk_bounds.2 - calc - scoresReal q k + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by - exact hscore_le' - _ = (scoreLo q (inputs.prev q) : Real) := by - exact add_sub_cancel (scoreHi q k : Real) (scoreLo q (inputs.prev q) : Real) - have h3 : - scoresReal q k + (margin : Real) ≤ (scoreLo q (inputs.prev q) : Real) := - h1.trans h2 - exact h3.trans hprev_bounds.1 - · exact (hactive ⟨q, hq⟩).elim - have hsoftmax_bounds : - Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) - (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by - classical - refine - { score_margin := ?_ - nonneg := ?_ - sum_one := ?_ - prev_large := ?_ - other_le := ?_ } - · intro q hq k hk - exact hscore_margin_real q hq k hk - · intro q _ k - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) - · intro q _ - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - · intro q hq - have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (eps : Real) := by - by_cases hneg : margin < 0 - · have heps : (eps : Real) = 1 := by - simp [eps, hneg] - have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by - intro k hk - simp - have hnonneg : - ∀ k ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q k := by - intro k _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) - have hsum_le : - (∑ k ∈ others q, weights q k) ≤ - ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := - Finset.sum_le_sum_of_subset_of_nonneg hsubset (by - intro k hk _; exact hnonneg k hk) - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - have hsum_le' : (∑ k ∈ 
others q, weights q k) ≤ 1 := by - simpa [hsum_one] using hsum_le - simpa [heps] using hsum_le' - · have hnonneg : 0 ≤ margin := le_of_not_gt hneg - have hnonneg_real : 0 ≤ (margin : Real) := by - exact (Rat.cast_nonneg (K := Real)).2 hnonneg - have hbound : - ∀ k ∈ others q, - weights q k ≤ (1 + (margin : Real))⁻¹ := by - intro k hk - have hkne : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 - have hscore := hscore_margin_real q hq k hkne - simpa [weights] using - (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) - (prev := inputs.prev q) (k := k) (m := (margin : Real)) - hnonneg_real hscore) - have hsum_le : - (∑ k ∈ others q, weights q k) ≤ - ∑ k ∈ others q, (1 + (margin : Real))⁻¹ := - Finset.sum_le_sum hbound - have hsum_const : - (∑ k ∈ others q, (1 + (margin : Real))⁻¹) = - (others q).card * (1 + (margin : Real))⁻¹ := by - simp - have hcard : (others q).card = seq - 1 := by - simp [others, Finset.card_erase_of_mem] - have hsum_le' : - (∑ k ∈ others q, weights q k) ≤ - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - have hsum_le'' := hsum_le.trans_eq hsum_const - have hsum_le''' := hsum_le'' - simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' - exact hsum_le''' - have heps : - (eps : Real) = (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - simp [eps, hneg, Rat.cast_add, div_eq_mul_inv] - simpa [heps] using hsum_le' - have hsum_eq : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = 1 := by - have hsum' : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = - ∑ k, weights q k := by - simp [others] - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - calc - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = - ∑ k, weights q k := hsum' - _ = 1 := hsum_one - have hsum_le' : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k ≤ - weights q (inputs.prev q) + (eps : Real) := by - have hsum_le'' := add_le_add_left hsum_others_le 
(weights q (inputs.prev q)) - simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' - have hprev : - 1 ≤ weights q (inputs.prev q) + (eps : Real) := by - simpa [hsum_eq] using hsum_le' - exact hprev - · intro q hq k hk - have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (eps : Real) := by - by_cases hneg : margin < 0 - · have heps : (eps : Real) = 1 := by - simp [eps, hneg] - have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by - intro j hj - simp - have hnonneg : - ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q j := by - intro j _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) j) - have hsum_le : - (∑ j ∈ others q, weights q j) ≤ - ∑ j ∈ (Finset.univ : Finset (Fin seq)), weights q j := - Finset.sum_le_sum_of_subset_of_nonneg hsubset (by - intro j hj _; exact hnonneg j hj) - have hsum_one : (∑ j, weights q j) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - have hsum_le' : (∑ j ∈ others q, weights q j) ≤ 1 := by - simpa [hsum_one] using hsum_le - simpa [heps] using hsum_le' - · have hnonneg : 0 ≤ margin := le_of_not_gt hneg - have hnonneg_real : 0 ≤ (margin : Real) := by - exact (Rat.cast_nonneg (K := Real)).2 hnonneg - have hbound : - ∀ j ∈ others q, - weights q j ≤ (1 + (margin : Real))⁻¹ := by - intro j hj - have hjne : j ≠ inputs.prev q := (Finset.mem_erase.mp hj).1 - have hscore := hscore_margin_real q hq j hjne - simpa [weights] using - (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) - (prev := inputs.prev q) (k := j) (m := (margin : Real)) - hnonneg_real hscore) - have hsum_le : - (∑ j ∈ others q, weights q j) ≤ - ∑ j ∈ others q, (1 + (margin : Real))⁻¹ := - Finset.sum_le_sum hbound - have hsum_const : - (∑ j ∈ others q, (1 + (margin : Real))⁻¹) = - (others q).card * (1 + (margin : Real))⁻¹ := by - simp - have hcard : (others q).card = seq - 1 := by - simp [others, Finset.card_erase_of_mem] - have hsum_le' : - (∑ j ∈ others q, weights q j) ≤ - (seq 
- 1 : Real) * (1 + (margin : Real))⁻¹ := by - have hsum_le'' := hsum_le.trans_eq hsum_const - have hsum_le''' := hsum_le'' - simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' - exact hsum_le''' - have heps : - (eps : Real) = (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - simp [eps, hneg, Rat.cast_add, div_eq_mul_inv] - simpa [heps] using hsum_le' - have hk' : k ∈ others q := by - simp [others, hk] - have hnonneg : - ∀ j ∈ others q, 0 ≤ weights q j := by - intro j _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) j) - have hle : - weights q k ≤ ∑ j ∈ others q, weights q j := by - have h := Finset.single_le_sum hnonneg hk' - simpa using h - exact hle.trans hsum_others_le - have hscore_margin_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - let other := otherKeys q - have hother : other.Nonempty := by - refine ⟨k, ?_⟩ - simp [other, otherKeys, hk] - have hgap_le : - marginAt q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by - have hkmem : k ∈ other := by - simp [other, otherKeys, hk] - have hle : - other.inf' hother (fun k => scoreLo q (inputs.prev q) - scoreHi q k) ≤ - scoreLo q (inputs.prev q) - scoreHi q k := by - exact (Finset.inf'_le (s := other) (f := fun k => - scoreLo q (inputs.prev q) - scoreHi q k) (b := k) hkmem) - have hmarginAt : - marginAt q = - other.inf' hother (fun k => scoreLo q (inputs.prev q) - scoreHi q k) := by - simp [marginAt, hother, other] - simpa [hmarginAt] using hle - have hgap_real : - (marginAt q : Real) ≤ - (scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real) := by - have hgap_real' : - (marginAt q : Real) ≤ - ((scoreLo q (inputs.prev q) - scoreHi q k : Rat) : Real) := - (Rat.cast_le (K := Real)).2 hgap_le - simpa [Rat.cast_sub] using hgap_real' - have hk_bounds := hscore_bounds q k - have hprev_bounds := hscore_bounds q (inputs.prev q) - have h1 : - scoresReal q k + (marginAt q 
: Real) ≤ - (scoreHi q k : Real) + (marginAt q : Real) := by - have h1' := add_le_add_right hk_bounds.2 (marginAt q : Real) - simpa [scoresReal] using h1' - have h2 : - (scoreHi q k : Real) + (marginAt q : Real) ≤ - (scoreLo q (inputs.prev q) : Real) := by - have hgap_real' : - (scoreHi q k : Real) + (marginAt q : Real) ≤ - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by - exact add_le_add_right hgap_real (scoreHi q k : Real) - have hgap_real'' : - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) = - (scoreLo q (inputs.prev q) : Real) := by - calc - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) = - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) + - (scoreHi q k : Real) := by - exact add_comm _ _ - _ = (scoreLo q (inputs.prev q) : Real) := by - exact sub_add_cancel (scoreLo q (inputs.prev q) : Real) (scoreHi q k : Real) - exact hgap_real'.trans (le_of_eq hgap_real'') - have h3 : - scoresReal q k + (marginAt q : Real) ≤ - (scoreLo q (inputs.prev q) : Real) := h1.trans h2 - exact h3.trans hprev_bounds.1 - have hepsAt : - ∀ q, epsAt q = - if marginAt q < 0 then (1 : Rat) else (seq - 1 : Rat) / (1 + marginAt q) := by - intro q - simp [epsAt] - have oneHot_bounds_at : - ∀ q, q ∈ inputs.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) - (fun q' => q' = q) inputs.prev weights := by - intro q hq - exact - Sound.oneHot_bounds_at_of_marginAt - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (marginAt := marginAt) - (epsAt := epsAt) - (hepsAt := hepsAt) - (hseq := hseq) - (hscore_margin_real_at := hscore_margin_real_at) - q hq - have hvals_bounds : - ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by - refine - { lo_le_hi := ?_ - lo_le_valsLo := ?_ - vals_bounds := ?_ - valsHi_le_hi := ?_ } - · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ - have hmem0 : k0 
∈ univ := hk0 - have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by - have hloRat : valCert.lo ≤ valCert.valsLo k0 := by - change lo ≤ valsLo k0 - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k0)).2 ?_ - refine ⟨k0, hmem0, ?_⟩ - exact le_rfl - exact (Rat.cast_le (K := Real)).2 hloRat - have hvals : - (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ - valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by - have hv := hv_bounds k0 - have hlo' : ∀ d, (vLo k0 d : Real) ≤ vRealOfInputs inputs k0 d := fun d => (hv d).1 - have hhi' : ∀ d, vRealOfInputs inputs k0 d ≤ (vHi k0 d : Real) := fun d => (hv d).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := dirHead) - (lo := vLo k0) (hi := vHi k0) - (x := fun d => vRealOfInputs inputs k0 d) hlo' hhi' - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := dirHead) - (lo := vLo k0) (hi := vHi k0) - (x := fun d => vRealOfInputs inputs k0 d) hlo' hhi' - have hlow' : - (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 := by - simpa [valsLo, valCert, dirHead, valsRealOfInputs] using hlow - have hhigh' : - valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by - simpa [valsHi, valCert, dirHead, valsRealOfInputs] using hhigh - exact ⟨hlow', hhigh'⟩ - have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by - have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by - change valsHi k0 ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k0)).2 ?_ - exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ - exact (Rat.cast_le (K := Real)).2 hhiRat - have hreal : - (valCert.lo : Real) ≤ (valCert.hi : Real) := - le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) - exact (Rat.cast_le (K := Real)).1 hreal - · intro k - have hmem : k ∈ univ := by simp [univ] - have hloRat : valCert.lo ≤ valCert.valsLo k := by - change lo ≤ valsLo k - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - 
(f := valsLo) (a := valsLo k)).2 ?_ - refine ⟨k, hmem, ?_⟩ - exact le_rfl - exact (Rat.cast_le (K := Real)).2 hloRat - · intro k - have hv := hv_bounds k - have hlo' : ∀ d, (vLo k d : Real) ≤ vRealOfInputs inputs k d := fun d => (hv d).1 - have hhi' : ∀ d, vRealOfInputs inputs k d ≤ (vHi k d : Real) := fun d => (hv d).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := dirHead) - (lo := vLo k) (hi := vHi k) - (x := fun d => vRealOfInputs inputs k d) hlo' hhi' - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := dirHead) - (lo := vLo k) (hi := vHi k) - (x := fun d => vRealOfInputs inputs k d) hlo' hhi' - have hlow' : - (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by - simpa [valsLo, valCert, dirHead, valsRealOfInputs] using hlow - have hhigh' : - valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - simpa [valsHi, valCert, dirHead, valsRealOfInputs] using hhigh - exact ⟨hlow', hhigh'⟩ - · intro k - have hmem : k ∈ univ := by simp [univ] - have hhiRat : valCert.valsHi k ≤ valCert.hi := by - change valsHi k ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k)).2 ?_ - refine ⟨k, hmem, ?_⟩ - exact le_rfl - exact (Rat.cast_le (K := Real)).2 hhiRat - exact some ⟨cert, - { softmax_bounds := hsoftmax_bounds - oneHot_bounds_at := oneHot_bounds_at - value_bounds := hvals_bounds }⟩ - · exact none - -section HeadOutputInterval - -variable {seq dModel dHead : Nat} - -noncomputable section - -/-- Real-valued head output using explicit score inputs. -/ -def headOutputWithScores (scores : Fin seq → Fin seq → Real) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (q : Fin seq) (i : Fin dModel) : Real := - let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax (scores q) k - let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i - dotProduct (weights q) vals - -/-- Unfolding lemma for `headOutputWithScores`. 
-/ -theorem headOutputWithScores_def (scores : Fin seq → Fin seq → Real) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (q : Fin seq) (i : Fin dModel) : - headOutputWithScores scores inputs q i = - let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax (scores q) k - let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i - dotProduct (weights q) vals := rfl - -/-- Real-valued head output for a query and model dimension. -/ -def headOutput (inputs : Model.InductionHeadInputs seq dModel dHead) - (q : Fin seq) (i : Fin dModel) : Real := - headOutputWithScores (scoresRealOfInputs inputs) inputs q i - -/-- Unfolding lemma for `headOutput`. -/ -theorem headOutput_def (inputs : Model.InductionHeadInputs seq dModel dHead) - (q : Fin seq) (i : Fin dModel) : - headOutput inputs q i = - headOutputWithScores (scoresRealOfInputs inputs) inputs q i := rfl - -/-- Soundness predicate for head-output interval bounds. -/ -structure HeadOutputIntervalSound [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (active : Finset (Fin seq)) - (c : Circuit.ResidualIntervalCert dModel) : Prop where - /-- Interval bounds are ordered coordinatewise. -/ - bounds : Circuit.ResidualIntervalBounds c - /-- Active-query outputs lie inside the interval bounds. -/ - output_mem : - ∀ q, q ∈ active → ∀ i, - (c.lo i : Real) ≤ headOutput inputs q i ∧ - headOutput inputs q i ≤ (c.hi i : Real) - -/-- Certified head-output interval data for a specific active set. -/ -structure HeadOutputIntervalResult [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) where - /-- Active queries covered by the interval bounds. -/ - active : Finset (Fin seq) - /-- Residual-interval certificate for head outputs. -/ - cert : Circuit.ResidualIntervalCert dModel - /-- Soundness proof for the interval bounds. -/ - sound : HeadOutputIntervalSound inputs active cert - -/-- Build residual-interval bounds for head outputs on active queries. 
-/ -def buildHeadOutputIntervalFromHead? [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option (HeadOutputIntervalResult inputs) := by - classical - cases seq with - | zero => - cases (NeZero.ne (n := (0 : Nat)) rfl) - | succ n => - by_cases hEps : 0 < inputs.lnEps - · by_cases hmodel : dModel = 0 - · exact none - · cases hbuild : buildInductionCertFromHead? inputs with - | none => exact none - | some certWithProof => - rcases certWithProof with ⟨cert, hcert⟩ - let lnBounds : Fin (Nat.succ n) → (Fin dModel → Rat) × (Fin dModel → Rat) := fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) - let lnLo : Fin (Nat.succ n) → Fin dModel → Rat := fun q => (lnBounds q).1 - let lnHi : Fin (Nat.succ n) → Fin dModel → Rat := fun q => (lnBounds q).2 - let vLo : Fin (Nat.succ n) → Fin dHead → Rat := fun q d => - dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let vHi : Fin (Nat.succ n) → Fin dHead → Rat := fun q d => - dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let headValueLo : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => - dotIntervalLower (fun d => inputs.wo i d) (vLo k) (vHi k) - let headValueHi : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => - dotIntervalUpper (fun d => inputs.wo i d) (vLo k) (vHi k) - have hln_bounds : - ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ - lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by - intro q i - have hln := - Bounds.layerNormBounds_spec (eps := inputs.lnEps) - (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) - (x := inputs.embed q) hmodel hEps - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs] using hln i - have hv_bounds : - ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ - vRealOfInputs inputs q d ≤ (vHi q d : Real) := by - intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs 
q j ≤ (lnHi q j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hlow' := add_le_add_right hlow (inputs.bv d : Real) - have hhigh' := add_le_add_right hhigh (inputs.bv d : Real) - constructor - · simpa [vLo, vRealOfInputs, Rat.cast_add] using hlow' - · simpa [vHi, vRealOfInputs, Rat.cast_add] using hhigh' - have hhead_bounds : - ∀ k i, (headValueLo k i : Real) ≤ headValueRealOfInputs inputs k i ∧ - headValueRealOfInputs inputs k i ≤ (headValueHi k i : Real) := by - intro k i - have hv := hv_bounds k - have hlo : ∀ d, (vLo k d : Real) ≤ vRealOfInputs inputs k d := fun d => (hv d).1 - have hhi : ∀ d, vRealOfInputs inputs k d ≤ (vHi k d : Real) := fun d => (hv d).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun d => inputs.wo i d) - (lo := vLo k) (hi := vHi k) - (x := fun d => vRealOfInputs inputs k d) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun d => inputs.wo i d) - (lo := vLo k) (hi := vHi k) - (x := fun d => vRealOfInputs inputs k d) hlo hhi - constructor - · simpa [headValueLo, headValueRealOfInputs] using hlow - · simpa [headValueHi, headValueRealOfInputs] using hhigh - let scoresReal : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := - scoresRealOfInputs inputs - let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => - Circuit.softmax (scoresReal q) k - let activeSet : Finset (Fin (Nat.succ n)) := cert.active - let univ : Finset (Fin (Nat.succ n)) := Finset.univ - have huniv : univ.Nonempty := by simp [univ] - let loVal : Fin dModel → Rat := fun i => - univ.inf' huniv (fun k => headValueLo k i) - let hiVal : Fin dModel → Rat := fun i => - univ.sup' huniv (fun k => headValueHi k i) - have hvalsBoundsReal : - ∀ i, 
Layers.ValueRangeBounds (Val := Real) - (loVal i : Real) (hiVal i : Real) - (fun k => headValueRealOfInputs inputs k i) := by - intro i - refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } - · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ - have hmem0 : k0 ∈ univ := hk0 - have hloRat : loVal i ≤ headValueLo k0 i := by - change loVal i ≤ headValueLo k0 i - dsimp [loVal] - refine (Finset.inf'_le_iff (s := univ) (H := huniv) - (f := fun k => headValueLo k i) (a := headValueLo k0 i)).2 ?_ - refine ⟨k0, hmem0, ?_⟩ - exact le_rfl - have hhiRat : headValueHi k0 i ≤ hiVal i := by - change headValueHi k0 i ≤ hiVal i - dsimp [hiVal] - refine (Finset.le_sup'_iff (s := univ) (H := huniv) - (f := fun k => headValueHi k i) (a := headValueHi k0 i)).2 ?_ - exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ - have hbounds := hhead_bounds k0 i - have hreal : - (loVal i : Real) ≤ (hiVal i : Real) := - le_trans ((Rat.cast_le (K := Real)).2 hloRat) - (le_trans hbounds.1 (le_trans hbounds.2 ((Rat.cast_le (K := Real)).2 hhiRat))) - exact hreal - · intro k - have hmem : k ∈ univ := by simp [univ] - have hloRat : loVal i ≤ headValueLo k i := by - change loVal i ≤ headValueLo k i - dsimp [loVal] - refine (Finset.inf'_le_iff (s := univ) (H := huniv) - (f := fun k => headValueLo k i) (a := headValueLo k i)).2 ?_ - refine ⟨k, hmem, ?_⟩ - exact le_rfl - have hbounds := hhead_bounds k i - exact (Rat.cast_le (K := Real)).2 hloRat |>.trans hbounds.1 - · intro k - have hmem : k ∈ univ := by simp [univ] - have hhiRat : headValueHi k i ≤ hiVal i := by - change headValueHi k i ≤ hiVal i - dsimp [hiVal] - refine (Finset.le_sup'_iff (s := univ) (H := huniv) - (f := fun k => headValueHi k i) (a := headValueHi k i)).2 ?_ - exact ⟨k, ⟨hmem, le_rfl⟩⟩ - have hbounds := hhead_bounds k i - exact hbounds.2.trans ((Rat.cast_le (K := Real)).2 hhiRat) - have hsoftmax : - Layers.SoftmaxMarginBoundsOn (Val := Real) (cert.eps : Real) (cert.margin : Real) - (fun q => q ∈ activeSet) cert.prev scoresReal weights := by - simpa 
[scoresReal, weights, activeSet] using hcert.softmax_bounds - have hweights : - Layers.OneHotApproxBoundsOnActive (Val := Real) (cert.eps : Real) - (fun q => q ∈ activeSet) cert.prev weights := - Layers.oneHotApproxBoundsOnActive_of_softmaxMargin - (Val := Real) - (ε := (cert.eps : Real)) - (margin := (cert.margin : Real)) - (active := fun q => q ∈ activeSet) - (prev := cert.prev) - (scores := scoresReal) - (weights := weights) - hsoftmax - have happrox : - ∀ i, Layers.InductionSpecApproxOn (Val := Real) (n := n) - ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real))) - (fun q => q ∈ activeSet) cert.prev - (fun q => dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i)) - (fun k => headValueRealOfInputs inputs k i) := by - intro i - exact - Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange - (Val := Real) - (n := n) - (ε := (cert.eps : Real)) - (lo := (loVal i : Real)) - (hi := (hiVal i : Real)) - (active := fun q => q ∈ activeSet) - (prev := cert.prev) - (weights := weights) - (vals := fun k => headValueRealOfInputs inputs k i) - (hweights := hweights) - (hvals := hvalsBoundsReal i) - let delta : Fin dModel → Rat := fun i => hiVal i - loVal i - let boundLoRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => - headValueLo (cert.prev q) i - cert.eps * delta i - let boundHiRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => - headValueHi (cert.prev q) i + cert.eps * delta i - let loOut : Fin dModel → Rat := fun i => - if h : activeSet.Nonempty then - activeSet.inf' h (fun q => boundLoRat q i) - else - 0 - let hiOut : Fin dModel → Rat := fun i => - if h : activeSet.Nonempty then - activeSet.sup' h (fun q => boundHiRat q i) - else - 0 - have hout : - ∀ q, q ∈ activeSet → ∀ i, - (loOut i : Real) ≤ headOutput inputs q i ∧ - headOutput inputs q i ≤ (hiOut i : Real) := by - intro q hq i - have hactive : activeSet.Nonempty := ⟨q, hq⟩ - have hspec := (happrox i) q hq - have hout_def : - headOutput inputs q i = - dotProduct (weights q) (fun k 
=> headValueRealOfInputs inputs k i) := by - simp [headOutput, headOutputWithScores, scoresReal, weights] - have hprev_bounds := hhead_bounds (cert.prev q) i - have hupper : - headOutput inputs q i ≤ (boundHiRat q i : Real) := by - have hupper' : - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ - headValueRealOfInputs inputs (cert.prev q) i + - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by - exact hspec.1 - have hupper'' : - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ - (headValueHi (cert.prev q) i : Real) + - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by - have hprev_bounds' := - (add_le_add_iff_right - ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))).2 - hprev_bounds.2 - exact le_trans hupper' hprev_bounds' - simpa - [hout_def, boundHiRat, delta, Rat.cast_add, Rat.cast_mul, Rat.cast_sub] using - hupper'' - have hlower : - (boundLoRat q i : Real) ≤ headOutput inputs q i := by - have hlower' : - (headValueRealOfInputs inputs (cert.prev q) i : Real) - - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by - exact (sub_le_iff_le_add).2 hspec.2 - have hlower'' : - (headValueLo (cert.prev q) i : Real) - - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by - exact le_trans (sub_le_sub_right hprev_bounds.1 - ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))) hlower' - simpa [hout_def, boundLoRat, delta, Rat.cast_mul, Rat.cast_sub] using - hlower'' - have hlo : - (loOut i : Real) ≤ (boundLoRat q i : Real) := by - have hloRat : loOut i ≤ boundLoRat q i := by - simpa [loOut, hactive] using - (Finset.inf'_le (s := activeSet) (f := fun q => boundLoRat q i) (b := q) hq) - exact (Rat.cast_le (K := Real)).2 hloRat - have hhi : - (boundHiRat q i : Real) ≤ (hiOut i : Real) := by - have hhiRat : boundHiRat q i ≤ 
hiOut i := by - simpa [hiOut, hactive] using - (Finset.le_sup' (s := activeSet) (f := fun q => boundHiRat q i) (b := q) hq) - exact (Rat.cast_le (K := Real)).2 hhiRat - exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ - have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by - refine { lo_le_hi := ?_ } - intro i - by_cases hactive : activeSet.Nonempty - · rcases hactive with ⟨q, hq⟩ - have hout_i := hout q hq i - have hleReal : (loOut i : Real) ≤ (hiOut i : Real) := - le_trans hout_i.1 hout_i.2 - exact (Rat.cast_le (K := Real)).1 hleReal - · simp [loOut, hiOut, hactive] - let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } - exact some - { active := activeSet - cert := certOut - sound := - { bounds := hbounds - output_mem := by - intro q hq i - exact hout q hq i } } - · exact none - -end - -end HeadOutputInterval - -end Sound - -end Nfp diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean new file mode 100644 index 0000000..f80d53c --- /dev/null +++ b/Nfp/Sound/Induction/Core.lean @@ -0,0 +1,1191 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Nfp.Core.Basic +import Mathlib.Data.Finset.Lattice.Fold +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.Circuit.Cert.ValueRange +import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Induction.CoreDefs +import Nfp.Sound.Induction.OneHot +import Nfp.Sound.Linear.FinFold + +/-! +Sound builders for induction certificates. + +These builders recompute certificate bounds inside Lean from exact inputs and +return proof-carrying results. The head-input path derives softmax tolerances +from score margins rather than trusting external weight dumps. 
+-/ + +namespace Nfp + +namespace Sound + +open scoped BigOperators + +open Nfp.Circuit +open Nfp.Sound.Bounds + +variable {seq : Nat} + +/-- Build and certify a softmax-margin certificate from exact scores/weights. -/ +def buildSoftmaxMarginCert? [NeZero seq] + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scores : Fin seq → Fin seq → Dyadic) + (weights : Fin seq → Fin seq → Dyadic) : + Option {c : SoftmaxMarginCert seq // checkSoftmaxMarginCert c = true} := by + classical + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + let epsAt : Fin seq → Dyadic := fun q => + let other := otherKeys q + let maxOther := + if h : other.Nonempty then + other.sup' h (fun k => weights q k) + else + (0 : Dyadic) + let deficit := (1 : Dyadic) - weights q (prev q) + max maxOther deficit + let marginAt : Fin seq → Dyadic := fun q => + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scores q (prev q) - scores q k) + else + (0 : Dyadic) + let eps := + if h : active.Nonempty then + active.sup' h epsAt + else + (0 : Dyadic) + let margin := + if h : active.Nonempty then + active.inf' h marginAt + else + (0 : Dyadic) + let cert : SoftmaxMarginCert seq := + { eps := eps + margin := margin + active := active + prev := prev + scores := scores + weights := weights } + if h : checkSoftmaxMarginCert cert = true then + exact some ⟨cert, h⟩ + else + exact none + +/-- Build and certify a value-range certificate from exact values. -/ +def buildValueRangeCert? 
[NeZero seq] + (vals : Fin seq → Dyadic) + (direction : Option DirectionSpec) : + Option {c : ValueRangeCert seq // checkValueRangeCert c = true} := by + classical + let _ : Nonempty (Fin seq) := by + refine ⟨⟨0, ?_⟩⟩ + exact Nat.pos_of_ne_zero (NeZero.ne seq) + let univ : Finset (Fin seq) := Finset.univ + let hnonempty : univ.Nonempty := Finset.univ_nonempty + let lo := univ.inf' hnonempty vals + let hi := univ.sup' hnonempty vals + let cert : ValueRangeCert seq := + { lo := lo + hi := hi + vals := vals + direction := direction } + if h : checkValueRangeCert cert = true then + exact some ⟨cert, h⟩ + else + exact none + +/-- Build induction certificates from exact head inputs (core computation). -/ +def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option (InductionHeadCert seq) := by + classical + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : 0 < sqrtLower inputs.lnEps + · by_cases hmodel : dModel = 0 + · exact none + · by_cases hactive : inputs.active.Nonempty + · let lnBounds := Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + let lnLo : Fin seq → Fin dModel → Dyadic := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Dyadic := lnBounds.2 + let qLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let qHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let kLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let kHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let vLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wv 
j d) (lnLo q) (lnHi q) + + inputs.bv d) + let vHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let qAbs := Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) + let kAbs := Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) + let kAbsMax : Fin dHead → Dyadic := fun d => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotAbs : Fin seq → Fin seq → Dyadic := fun q k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d) + let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + |inputs.scale| * dotAbs q k + let scoreAbs : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then |inputs.maskValue| else scoreBaseAbs q k + let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then inputs.maskValue else -scoreBaseAbs q k + let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then inputs.maskValue else scoreBaseAbs q k + let dotAbsUpper : Fin seq → Dyadic := fun q => + Linear.dotFin dHead (fun d => qAbs q d) kAbsMax + let scoreHiUpper : Fin seq → Dyadic := fun q => + max inputs.maskValue (|inputs.scale| * dotAbsUpper q) + let fastGap : Fin seq → Dyadic := fun q => + let prev := inputs.prev q + scoreLo q prev - scoreHiUpper q + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let maskedKeys : Fin seq → Finset (Fin seq) := fun q => + if inputs.maskCausal = true then + (otherKeys q).filter (fun k => q < k) + else + (∅ : Finset (Fin seq)) + let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => + (otherKeys q) \ (maskedKeys q) + let maskedGap : Fin seq → Dyadic := fun q => + scoreLo q (inputs.prev q) - inputs.maskValue + let marginAtRaw : Fin seq → Dyadic := fun q => + 
let fast := fastGap q + if fast < 0 then + let other := unmaskedKeys q + let maskedSet := maskedKeys q + if hunmasked : other.Nonempty then + let unmaskedMin := other.inf' hunmasked (fun k => + scoreLo q (inputs.prev q) - scoreHi q k) + if maskedSet.Nonempty then + min unmaskedMin (maskedGap q) + else + unmaskedMin + else + if maskedSet.Nonempty then + maskedGap q + else + (0 : Dyadic) + else + fast + let marginAt : Fin seq → Dyadic := fun q => + if q ∈ inputs.active then + marginAtRaw q + else + (0 : Dyadic) + let epsAt : Fin seq → Dyadic := fun q => + if marginAt q < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + marginAt q) + let margin : Dyadic := inputs.active.inf' hactive marginAt + let eps : Dyadic := + if margin < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + margin) + let dirHeadVec := dirHeadVecOfInputs inputs + let dirHead : Fin dHead → Dyadic := fun d => dirHeadVec.get d + let valsLo := + Bounds.cacheBound (fun k => + Bounds.dotIntervalLowerCachedDyadic dirHead (vLo k) (vHi k)) + let valsHi := + Bounds.cacheBound (fun k => + Bounds.dotIntervalUpperCachedDyadic dirHead (vLo k) (vHi k)) + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + let valCert : ValueInterval seq := + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec } + let cert : InductionHeadCert seq := + { eps := eps + epsAt := epsAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } + exact some cert + · exact none + · exact none + · exact none + +set_option maxHeartbeats 1000000 in +-- Large softmax/interval proof expands many bounds; bump heartbeats to avoid timeouts. +/-- Soundness for `buildInductionCertFromHeadCore?`. 
-/ +theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) + (hcore : buildInductionCertFromHeadCore? inputs = some c) : + InductionHeadCertSound inputs c := by + classical + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : 0 < sqrtLower inputs.lnEps + · by_cases hmodel : dModel = 0 + · have : False := by + simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel] at hcore + exact this.elim + · by_cases hactive : inputs.active.Nonempty + · let lnBounds := Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + let lnLo : Fin seq → Fin dModel → Dyadic := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Dyadic := lnBounds.2 + let qLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let qHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let kLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let kHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let vLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let vHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let qAbs := + Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) + let kAbs := + Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) + let kAbsMax : Fin dHead → Dyadic := fun d => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => 
kAbs k d) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotAbs : Fin seq → Fin seq → Dyadic := fun q k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d) + let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + |inputs.scale| * dotAbs q k + let scoreAbs : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then |inputs.maskValue| else scoreBaseAbs q k + let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then inputs.maskValue else -scoreBaseAbs q k + let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then inputs.maskValue else scoreBaseAbs q k + let dotAbsUpper : Fin seq → Dyadic := fun q => + Linear.dotFin dHead (fun d => qAbs q d) kAbsMax + let scoreHiUpper : Fin seq → Dyadic := fun q => + max inputs.maskValue (|inputs.scale| * dotAbsUpper q) + let fastGap : Fin seq → Dyadic := fun q => + let prev := inputs.prev q + scoreLo q prev - scoreHiUpper q + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let maskedKeys : Fin seq → Finset (Fin seq) := fun q => + if inputs.maskCausal = true then + (otherKeys q).filter (fun k => q < k) + else + (∅ : Finset (Fin seq)) + let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => + (otherKeys q) \ (maskedKeys q) + let maskedGap : Fin seq → Dyadic := fun q => + scoreLo q (inputs.prev q) - inputs.maskValue + let marginAtRaw : Fin seq → Dyadic := fun q => + let fast := fastGap q + if fast < 0 then + let other := unmaskedKeys q + let maskedSet := maskedKeys q + if hunmasked : other.Nonempty then + let unmaskedMin := other.inf' hunmasked (fun k => + scoreLo q (inputs.prev q) - scoreHi q k) + if maskedSet.Nonempty then + min unmaskedMin (maskedGap q) + else + unmaskedMin + else + if maskedSet.Nonempty then + maskedGap q + else + (0 : Dyadic) + else + fast + let marginAt : Fin seq → Dyadic := fun q => + if q ∈ inputs.active then + marginAtRaw q + else + (0 : 
Dyadic) + let epsAt : Fin seq → Dyadic := fun q => + if marginAt q < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + marginAt q) + let margin : Dyadic := inputs.active.inf' hactive marginAt + let eps : Dyadic := + if margin < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + margin) + have hseq : (1 : Nat) ≤ seq := + Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) + let dirHeadVec := dirHeadVecOfInputs inputs + let dirHead : Fin dHead → Dyadic := fun d => dirHeadVec.get d + let valsLo := + Bounds.cacheBound (fun k => + Bounds.dotIntervalLowerCachedDyadic dirHead (vLo k) (vHi k)) + let valsHi := + Bounds.cacheBound (fun k => + Bounds.dotIntervalUpperCachedDyadic dirHead (vLo k) (vHi k)) + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + let valCert : ValueInterval seq := + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec } + let cert : InductionHeadCert seq := + { eps := eps + epsAt := epsAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } + have hcore' : some cert = some c := by + simpa + [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, lnBounds, lnLo, + lnHi, qLo, qHi, kLo, kHi, vLo, vHi, qAbs, kAbs, kAbsMax, masked, dotAbs, + scoreBaseAbs, scoreAbs, scoreLo, scoreHi, dotAbsUpper, scoreHiUpper, fastGap, + otherKeys, maskedKeys, unmaskedKeys, maskedGap, marginAt, marginAtRaw, epsAt, + margin, eps, dirHeadVec, dirHead, valsLo, valsHi, univ, lo, hi, valCert, cert] + using hcore + have hc : c = cert := by + simpa using (Option.some.inj hcore').symm + subst hc + have hln_bounds : + ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ + lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by + intro q i + have hln := + Bounds.layerNormBounds_spec (eps := inputs.lnEps) + (gamma := inputs.ln1Gamma) (beta := 
inputs.ln1Beta) + (x := inputs.embed q) hmodel hEps hSqrt + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs, + Bounds.cacheBoundPair2_apply_left, Bounds.cacheBoundPair2_apply_right] + using hln i + have hq_bounds : + ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + intro q d + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wq j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wq j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hlow' := add_le_add_right hlow (inputs.bq d : Real) + have hhigh' := add_le_add_right hhigh (inputs.bq d : Real) + constructor + · simpa [qLo, qRealOfInputs, Bounds.cacheBound2_apply, + Bounds.dotIntervalLowerCachedRat_eq] using hlow' + · simpa [qHi, qRealOfInputs, Bounds.cacheBound2_apply, + Bounds.dotIntervalUpperCachedRat_eq] using hhigh' + have hk_bounds : + ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ (kHi q d : Real) := by + intro q d + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wk j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wk j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hlow' := add_le_add_right hlow (inputs.bk d : Real) + have hhigh' := add_le_add_right hhigh (inputs.bk d : Real) + constructor + · simpa [kLo, kRealOfInputs, 
Bounds.cacheBound2_apply, + Bounds.dotIntervalLowerCachedRat_eq] using hlow' + · simpa [kHi, kRealOfInputs, Bounds.cacheBound2_apply, + Bounds.dotIntervalUpperCachedRat_eq] using hhigh' + have hv_bounds : + ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ + vRealOfInputs inputs q d ≤ (vHi q d : Real) := by + intro q d + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hlow' := add_le_add_right hlow (inputs.bv d : Real) + have hhigh' := add_le_add_right hhigh (inputs.bv d : Real) + constructor + · simpa [vLo, vRealOfInputs, Bounds.cacheBound2_apply, + Bounds.dotIntervalLowerCachedRat_eq] using hlow' + · simpa [vHi, vRealOfInputs, Bounds.cacheBound2_apply, + Bounds.dotIntervalUpperCachedRat_eq] using hhigh' + have hscore_bounds : + ∀ q k, (scoreLo q k : Real) ≤ scoresRealOfInputs inputs q k ∧ + scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by + intro q k + let scoresReal := scoresRealOfInputs inputs + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + have hq_abs : ∀ d, |qRealOfInputs inputs q d| ≤ (qAbs q d : Real) := by + intro d + have hq := hq_bounds q d + have h := abs_le_max_abs_abs_of_interval_real hq.1 hq.2 + simpa [qAbs, Bounds.cacheBound2_apply] using h + have hk_abs : ∀ d, |kRealOfInputs inputs k d| ≤ (kAbs k d : Real) := by + intro d + have hk := hk_bounds k d + have h := abs_le_max_abs_abs_of_interval_real hk.1 hk.2 + simpa [kAbs, Bounds.cacheBound2_apply] using h + have hdot_abs : + |dotProduct (fun d => 
qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)| ≤ + (dotAbs q k : Real) := by + have hsum : + |∑ d, qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ + ∑ d, |qRealOfInputs inputs q d * kRealOfInputs inputs k d| := by + simpa [dotProduct] using + (Finset.abs_sum_le_sum_abs (s := (Finset.univ : Finset (Fin dHead))) + (f := fun d => qRealOfInputs inputs q d * kRealOfInputs inputs k d)) + have hterm : + ∀ d, + |qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ + (qAbs q d : Real) * (kAbs k d : Real) := by + intro d + have hq := hq_abs d + have hk := hk_abs d + have hqnonneg : 0 ≤ (qAbs q d : Real) := by + have hqnonneg' : 0 ≤ qAbs q d := by + have hmax : |qLo q d| ≤ qAbs q d := by + simp [qAbs, Bounds.cacheBound2_apply] + exact le_trans (abs_nonneg _) hmax + exact dyadicToReal_nonneg_of_nonneg hqnonneg' + calc + |qRealOfInputs inputs q d * kRealOfInputs inputs k d| = + |qRealOfInputs inputs q d| * |kRealOfInputs inputs k d| := by + simp [abs_mul] + _ ≤ (qAbs q d : Real) * (kAbs k d : Real) := + mul_le_mul hq hk (abs_nonneg _) hqnonneg + have hsum_le : + ∑ d, |qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ + ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by + refine Finset.sum_le_sum ?_ + intro d _ + exact hterm d + have hcast : + (dotAbs q k : Real) = + ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by + have hsum : + ((∑ d, qAbs q d * kAbs k d : Dyadic) : Real) = + ∑ d, ((qAbs q d * kAbs k d : Dyadic) : Real) := + Linear.dyadicToReal_sum_univ (f := fun d => qAbs q d * kAbs k d) + have hsum' : + ∑ d, ((qAbs q d * kAbs k d : Dyadic) : Real) = + ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by + refine Finset.sum_congr rfl ?_ + intro d _ + simp [dyadicToReal_mul] + have hfinal := hsum.trans hsum' + simpa [dotAbs, Linear.dotFin_eq_dotProduct, dotProduct] using hfinal + have hfinal := hsum.trans (hsum_le.trans_eq hcast.symm) + simpa [dotProduct] using hfinal + have hscale_abs : 0 ≤ (|inputs.scale| : Real) := by + exact abs_nonneg 
(dyadicToReal inputs.scale) + have hbase_abs : + |base| ≤ (scoreBaseAbs q k : Real) := by + have hdot_abs' := hdot_abs + have hmul : + |base| = + (|inputs.scale| : Real) * + |dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)| := by + simp [base, abs_mul] + have hmul_le : + (|inputs.scale| : Real) * + |dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)| ≤ + (|inputs.scale| : Real) * (dotAbs q k : Real) := by + exact mul_le_mul_of_nonneg_left hdot_abs' hscale_abs + simpa [scoreBaseAbs, hmul] using hmul_le + by_cases hcausal : inputs.maskCausal + · by_cases hle : k ≤ q + · have hnot : ¬ q < k := not_lt_of_ge hle + have hscore_eq : scoresReal q k = base := by + simp [scoresReal, scoresRealOfInputs, hcausal, hle, base] + have hscore_abs' : |scoresReal q k| ≤ (scoreBaseAbs q k : Real) := by + simpa [hscore_eq] using hbase_abs + have hscore_abs : + |scoresReal q k| ≤ (scoreAbs q k : Real) := by + simpa [scoreAbs, masked, hcausal, hnot] + using hscore_abs' + have hscore_bounds := (abs_le).1 hscore_abs + constructor + · simpa [scoresReal, scoreLo, scoreAbs, masked, hcausal, hnot] + using hscore_bounds.1 + · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal, hnot] + using hscore_bounds.2 + · have hlt : q < k := lt_of_not_ge hle + constructor + · simp [scoresRealOfInputs, scoreLo, masked, hcausal, hle, hlt] + · simp [scoresRealOfInputs, scoreHi, masked, hcausal, hle, hlt] + · have hscore_eq : scoresReal q k = base := by + simp [scoresReal, scoresRealOfInputs, hcausal, base] + have hscore_abs' : |scoresReal q k| ≤ (scoreBaseAbs q k : Real) := by + simpa [hscore_eq] using hbase_abs + have hscore_abs : + |scoresReal q k| ≤ (scoreAbs q k : Real) := by + simpa [scoreAbs, masked, hcausal] using hscore_abs' + have hscore_bounds := (abs_le).1 hscore_abs + constructor + · simpa [scoresReal, scoreLo, scoreAbs, masked, hcausal] + using hscore_bounds.1 + · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal] + 
using hscore_bounds.2 + let scoresReal := scoresRealOfInputs inputs + have hdotAbs_le : ∀ q k, dotAbs q k ≤ dotAbsUpper q := by + intro q k + classical + have hnonneg : ∀ d, 0 ≤ qAbs q d := by + intro d + have h0 : 0 ≤ |qLo q d| := abs_nonneg _ + have hle : |qLo q d| ≤ qAbs q d := by + simp [qAbs, Bounds.cacheBound2_apply] + exact le_trans h0 hle + have hterm : ∀ d, qAbs q d * kAbs k d ≤ qAbs q d * kAbsMax d := by + intro d + have hmem : k ∈ (Finset.univ : Finset (Fin seq)) := by simp + have hkabs : kAbs k d ≤ kAbsMax d := by + simpa [kAbsMax] using + (Finset.le_sup' (s := (Finset.univ : Finset (Fin seq))) + (f := fun k => kAbs k d) (b := k) hmem) + exact mul_le_mul_of_nonneg_left hkabs (hnonneg d) + have hsum : (∑ d, qAbs q d * kAbs k d) ≤ ∑ d, qAbs q d * kAbsMax d := by + refine Finset.sum_le_sum ?_ + intro d _ + exact hterm d + simpa [dotAbs, dotAbsUpper, Linear.dotFin_eq_dotProduct, dotProduct] using hsum + have hscoreHi_le : ∀ q k, scoreHi q k ≤ scoreHiUpper q := by + intro q k + by_cases hmask : masked q k + · simp [scoreHi, scoreHiUpper, hmask] + · have hdot := hdotAbs_le q k + have hmul : |inputs.scale| * dotAbs q k ≤ |inputs.scale| * dotAbsUpper q := by + exact mul_le_mul_of_nonneg_left hdot (abs_nonneg _) + calc + scoreHi q k = |inputs.scale| * dotAbs q k := by + simp [scoreHi, scoreBaseAbs, hmask] + _ ≤ |inputs.scale| * dotAbsUpper q := hmul + _ ≤ max inputs.maskValue (|inputs.scale| * dotAbsUpper q) := by + exact le_max_right _ _ + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + have hscore_margin_real : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + have hmargin_le : margin ≤ marginAt q := by + have hle : margin ≤ inputs.active.inf' hactive marginAt := by + simp [margin] + have hle_all := + 
(Finset.le_inf'_iff (s := inputs.active) (H := hactive) (f := marginAt) + (a := margin)).1 hle + exact hle_all q hq + have hgap_le : + marginAt q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by + by_cases hfast : fastGap q < 0 + · by_cases hmask : k ∈ maskedKeys q + · have hmask_nonempty : (maskedKeys q).Nonempty := ⟨k, hmask⟩ + have hmargin_eq : marginAt q = marginAtRaw q := by + simp [marginAt, hq] + have hraw_le : marginAtRaw q ≤ maskedGap q := by + by_cases hunmasked : (unmaskedKeys q).Nonempty + · have hraw_eq : + marginAtRaw q = + let unmaskedMin := (unmaskedKeys q).inf' hunmasked + (fun k => scoreLo q (inputs.prev q) - scoreHi q k) + min unmaskedMin (maskedGap q) := by + simp [marginAtRaw, hfast, hunmasked, hmask_nonempty] + simp [hraw_eq] + · have hraw_eq : marginAtRaw q = maskedGap q := by + simp [marginAtRaw, hfast, hunmasked, hmask_nonempty] + simp [hraw_eq] + have hcausal : inputs.maskCausal = true := by + by_contra hcausal + simp [maskedKeys, hcausal] at hmask + have hmem : + k ∈ (otherKeys q).filter (fun k => q < k) := by + simpa [maskedKeys, hcausal] using hmask + have hlt : q < k := (Finset.mem_filter.mp hmem).2 + have hmask_prop : masked q k := ⟨hcausal, hlt⟩ + have hmask_score : scoreHi q k = inputs.maskValue := by + simp [scoreHi, hmask_prop] + have hgap : marginAt q ≤ scoreLo q (inputs.prev q) - inputs.maskValue := by + simpa [hmargin_eq] using hraw_le + simpa [maskedGap, hmask_score] using hgap + · have hmem : k ∈ unmaskedKeys q := by + have hother_mem : k ∈ otherKeys q := by simp [otherKeys, hk] + simp [unmaskedKeys, hother_mem, hmask] + have hunmasked : (unmaskedKeys q).Nonempty := ⟨k, hmem⟩ + have hmargin_eq : marginAt q = marginAtRaw q := by + simp [marginAt, hq] + have hraw_le : marginAtRaw q ≤ + (unmaskedKeys q).inf' hunmasked + (fun k => scoreLo q (inputs.prev q) - scoreHi q k) := by + let unmaskedMin := + (unmaskedKeys q).inf' hunmasked + (fun k => scoreLo q (inputs.prev q) - scoreHi q k) + by_cases hmask_nonempty : (maskedKeys 
q).Nonempty + · have hraw_eq : marginAtRaw q = min unmaskedMin (maskedGap q) := by + simp [marginAtRaw, hfast, hunmasked, hmask_nonempty, unmaskedMin, + unmaskedKeys, maskedKeys] + have hmin_le : marginAtRaw q ≤ unmaskedMin := by + rw [hraw_eq] + exact min_le_left _ _ + exact hmin_le + · simp [marginAtRaw, hfast, hunmasked, hmask_nonempty, unmaskedKeys, + maskedKeys] + have hle_all := + (Finset.le_inf'_iff (s := unmaskedKeys q) (H := hunmasked) + (f := fun k => scoreLo q (inputs.prev q) - scoreHi q k) + (a := marginAtRaw q)).1 hraw_le + have hle := hle_all k hmem + simpa [hmargin_eq] using hle + · have hgap_fast : fastGap q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by + have hle_score : scoreHi q k ≤ scoreHiUpper q := hscoreHi_le q k + have hle_sub : + scoreLo q (inputs.prev q) - scoreHiUpper q ≤ + scoreLo q (inputs.prev q) - scoreHi q k := + sub_le_sub_left hle_score (scoreLo q (inputs.prev q)) + simpa [fastGap] using hle_sub + have hmargin_eq : marginAt q = fastGap q := by + simp [marginAt, marginAtRaw, hq, hfast] + simpa [hmargin_eq] using hgap_fast + have hgap : margin ≤ scoreLo q (inputs.prev q) - scoreHi q k := + le_trans hmargin_le hgap_le + have hgap_real : (margin : Real) ≤ + (scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real) := by + have hgap_real' : + (margin : Real) ≤ ((scoreLo q (inputs.prev q) - scoreHi q k : Dyadic) : Real) := + dyadicToReal_le_of_le hgap + simpa [dyadicToReal_sub] using hgap_real' + have hk_bounds := hscore_bounds q k + have hprev_bounds := hscore_bounds q (inputs.prev q) + have h1 : + scoresReal q k + (margin : Real) ≤ + scoresReal q k + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by + exact (add_le_add_iff_left (scoresReal q k)).2 hgap_real + have h2 : + scoresReal q k + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ + (scoreLo q (inputs.prev q) : Real) := by + have hscore_le' : + scoresReal q k + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ + (scoreHi q k : Real) 
+ + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by + exact (add_le_add_iff_right + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real))).2 hk_bounds.2 + calc + scoresReal q k + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ + (scoreHi q k : Real) + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by + exact hscore_le' + _ = (scoreLo q (inputs.prev q) : Real) := by + exact add_sub_cancel (scoreHi q k : Real) (scoreLo q (inputs.prev q) : Real) + have h3 : + scoresReal q k + (margin : Real) ≤ (scoreLo q (inputs.prev q) : Real) := + h1.trans h2 + exact h3.trans hprev_bounds.1 + have hsoftmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) + (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by + classical + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + exact hscore_margin_real q hq k hk + · intro q _ k + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + · intro q _ + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + · intro q hq + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (eps : Real) := by + by_cases hneg : margin < 0 + · have heps : (eps : Real) = 1 := by + simp [eps, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hnonneg : + ∀ k ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q k := by + intro k _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k hk _; exact hnonneg k hk) + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by + simpa [hsum_one] using 
hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ margin := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (margin : Real) := by + exact dyadicToReal_nonneg_of_nonneg hnonneg + have hbound : + ∀ k ∈ others q, + weights q k ≤ (1 + (margin : Real))⁻¹ := by + intro k hk + have hkne : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 + have hscore := hscore_margin_real q hq k hkne + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := inputs.prev q) (k := k) (m := (margin : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (1 + (margin : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ k ∈ others q, (1 + (margin : Real))⁻¹) = + (others q).card * (1 + (margin : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ k ∈ others q, weights q k) ≤ + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have hden : (1 + margin) ≠ 0 := by + have hnonneg_real : 0 ≤ (margin : Real) := + dyadicToReal_nonneg_of_nonneg hnonneg + have hpos_real : (0 : Real) < 1 + (margin : Real) := by + linarith + have hpos_real' : dyadicToReal 0 < dyadicToReal (1 + margin) := by + simpa [dyadicToReal_add] using hpos_real + have hpos : (0 : Dyadic) < 1 + margin := + (dyadicToReal_lt_iff).1 hpos_real' + exact ne_of_gt hpos + have heps : + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : Real) := by + have hrat' := dyadicDivUp_ge_real (seq - 1) (1 + margin) hden + simpa [eps, hneg, dyadicToReal, Dyadic.toRat_add, Dyadic.toRat_natCast, + Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' + exact le_trans hsum_le' heps + have hsum_eq : + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights 
q (inputs.prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + calc + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (inputs.prev q) + (eps : Real) := by + have hsum_le'' := add_le_add_left hsum_others_le (weights q (inputs.prev q)) + simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' + have hprev : + 1 ≤ weights q (inputs.prev q) + (eps : Real) := by + simpa [hsum_eq] using hsum_le' + exact hprev + · intro q hq k hk + have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (eps : Real) := by + by_cases hneg : margin < 0 + · have heps : (eps : Real) = 1 := by + simp [eps, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro j hj + simp + have hnonneg : + ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q j := by + intro j _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) j) + have hsum_le : + (∑ j ∈ others q, weights q j) ≤ + ∑ j ∈ (Finset.univ : Finset (Fin seq)), weights q j := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro j hj _; exact hnonneg j hj) + have hsum_one : (∑ j, weights q j) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_le' : (∑ j ∈ others q, weights q j) ≤ 1 := by + simpa [hsum_one] using hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ margin := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (margin : Real) := by + exact dyadicToReal_nonneg_of_nonneg hnonneg + have hbound : + ∀ j ∈ others q, + weights q j ≤ (1 + (margin : Real))⁻¹ := by + intro j hj + have hjne : j ≠ inputs.prev q := (Finset.mem_erase.mp hj).1 + have hscore := hscore_margin_real q hq j hjne + simpa [weights] using + 
(Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := inputs.prev q) (k := j) (m := (margin : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ j ∈ others q, weights q j) ≤ + ∑ j ∈ others q, (1 + (margin : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ j ∈ others q, (1 + (margin : Real))⁻¹) = + (others q).card * (1 + (margin : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ j ∈ others q, weights q j) ≤ + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have hden : (1 + margin) ≠ 0 := by + have hnonneg_real : 0 ≤ (margin : Real) := + dyadicToReal_nonneg_of_nonneg hnonneg + have hpos_real : (0 : Real) < 1 + (margin : Real) := by + linarith + have hpos_real' : dyadicToReal 0 < dyadicToReal (1 + margin) := by + simpa [dyadicToReal_add] using hpos_real + have hpos : (0 : Dyadic) < 1 + margin := + (dyadicToReal_lt_iff).1 hpos_real' + exact ne_of_gt hpos + have heps : + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : Real) := by + have hrat' := dyadicDivUp_ge_real (seq - 1) (1 + margin) hden + simpa [eps, hneg, dyadicToReal, Dyadic.toRat_add, Dyadic.toRat_natCast, + Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' + exact le_trans hsum_le' heps + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ j ∈ others q, 0 ≤ weights q j := by + intro j _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) j) + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + have hscore_margin_real_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by + intro 
q hq k hk + have hgap_le : + marginAt q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by + by_cases hfast : fastGap q < 0 + · by_cases hmask : k ∈ maskedKeys q + · have hmask_nonempty : (maskedKeys q).Nonempty := ⟨k, hmask⟩ + have hmargin_eq : marginAt q = marginAtRaw q := by + simp [marginAt, hq] + have hraw_le : marginAtRaw q ≤ maskedGap q := by + by_cases hunmasked : (unmaskedKeys q).Nonempty + · have hraw_eq : + marginAtRaw q = + let unmaskedMin := (unmaskedKeys q).inf' hunmasked + (fun k => scoreLo q (inputs.prev q) - scoreHi q k) + min unmaskedMin (maskedGap q) := by + simp [marginAtRaw, hfast, hunmasked, hmask_nonempty] + simp [hraw_eq] + · have hraw_eq : marginAtRaw q = maskedGap q := by + simp [marginAtRaw, hfast, hunmasked, hmask_nonempty] + simp [hraw_eq] + have hcausal : inputs.maskCausal = true := by + by_contra hcausal + simp [maskedKeys, hcausal] at hmask + have hmem : + k ∈ (otherKeys q).filter (fun k => q < k) := by + simpa [maskedKeys, hcausal] using hmask + have hlt : q < k := (Finset.mem_filter.mp hmem).2 + have hmask_prop : masked q k := ⟨hcausal, hlt⟩ + have hmask_score : scoreHi q k = inputs.maskValue := by + simp [scoreHi, hmask_prop] + have hgap : marginAt q ≤ scoreLo q (inputs.prev q) - inputs.maskValue := by + simpa [hmargin_eq] using hraw_le + simpa [maskedGap, hmask_score] using hgap + · have hmem : k ∈ unmaskedKeys q := by + have hother_mem : k ∈ otherKeys q := by simp [otherKeys, hk] + simp [unmaskedKeys, hother_mem, hmask] + have hunmasked : (unmaskedKeys q).Nonempty := ⟨k, hmem⟩ + have hmargin_eq : marginAt q = marginAtRaw q := by + simp [marginAt, hq] + have hraw_le : marginAtRaw q ≤ + (unmaskedKeys q).inf' hunmasked + (fun k => scoreLo q (inputs.prev q) - scoreHi q k) := by + let unmaskedMin := + (unmaskedKeys q).inf' hunmasked + (fun k => scoreLo q (inputs.prev q) - scoreHi q k) + by_cases hmask_nonempty : (maskedKeys q).Nonempty + · have hraw_eq : marginAtRaw q = min unmaskedMin (maskedGap q) := by + simp [marginAtRaw, hfast, 
hunmasked, hmask_nonempty, unmaskedMin, + unmaskedKeys, maskedKeys] + rw [hraw_eq] + exact min_le_left _ _ + · simp [marginAtRaw, hfast, hunmasked, hmask_nonempty, unmaskedKeys, + maskedKeys] + have hle_all := + (Finset.le_inf'_iff (s := unmaskedKeys q) (H := hunmasked) + (f := fun k => scoreLo q (inputs.prev q) - scoreHi q k) + (a := marginAtRaw q)).1 hraw_le + have hle := hle_all k hmem + simpa [hmargin_eq] using hle + · have hgap_fast : fastGap q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by + have hle_score : scoreHi q k ≤ scoreHiUpper q := hscoreHi_le q k + have hle_sub : + scoreLo q (inputs.prev q) - scoreHiUpper q ≤ + scoreLo q (inputs.prev q) - scoreHi q k := + sub_le_sub_left hle_score (scoreLo q (inputs.prev q)) + simpa [fastGap] using hle_sub + have hmargin_eq : marginAt q = fastGap q := by + simp [marginAt, marginAtRaw, hq, hfast] + simpa [hmargin_eq] using hgap_fast + have hgap_real : + (marginAt q : Real) ≤ + (scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real) := by + have hgap_real' : + (marginAt q : Real) ≤ + ((scoreLo q (inputs.prev q) - scoreHi q k : Dyadic) : Real) := + dyadicToReal_le_of_le hgap_le + simpa [dyadicToReal_sub] using hgap_real' + have hk_bounds := hscore_bounds q k + have hprev_bounds := hscore_bounds q (inputs.prev q) + have h1 : + scoresReal q k + (marginAt q : Real) ≤ + (scoreHi q k : Real) + (marginAt q : Real) := by + have h1' := add_le_add_right hk_bounds.2 (marginAt q : Real) + simpa [scoresReal] using h1' + have h2 : + (scoreHi q k : Real) + (marginAt q : Real) ≤ + (scoreLo q (inputs.prev q) : Real) := by + have hgap_real' : + (scoreHi q k : Real) + (marginAt q : Real) ≤ + (scoreHi q k : Real) + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by + exact add_le_add_right hgap_real (scoreHi q k : Real) + have hgap_real'' : + (scoreHi q k : Real) + + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) = + (scoreLo q (inputs.prev q) : Real) := by + calc + (scoreHi q k : Real) + + ((scoreLo q 
(inputs.prev q) : Real) - (scoreHi q k : Real)) = + ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) + + (scoreHi q k : Real) := by + exact add_comm _ _ + _ = (scoreLo q (inputs.prev q) : Real) := by + exact sub_add_cancel (scoreLo q (inputs.prev q) : Real) (scoreHi q k : Real) + exact hgap_real'.trans (le_of_eq hgap_real'') + have h3 : + scoresReal q k + (marginAt q : Real) ≤ + (scoreLo q (inputs.prev q) : Real) := h1.trans h2 + exact h3.trans hprev_bounds.1 + have hepsAt : + ∀ q, epsAt q = + if marginAt q < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + marginAt q) := by + intro q + simp [epsAt] + have oneHot_bounds_at : + ∀ q, q ∈ inputs.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) inputs.prev weights := by + intro q hq + exact + Sound.oneHot_bounds_at_of_marginAt + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (marginAt := marginAt) + (epsAt := epsAt) + (hepsAt := hepsAt) + (hseq := hseq) + (hscore_margin_real_at := hscore_margin_real_at) + q hq + have hvals_bounds : + ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by + refine + { lo_le_hi := ?_ + lo_le_valsLo := ?_ + vals_bounds := ?_ + valsHi_le_hi := ?_ } + · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ + have hmem0 : k0 ∈ univ := hk0 + have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by + have hloDyadic : valCert.lo ≤ valCert.valsLo k0 := by + change lo ≤ valsLo k0 + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k0)).2 ?_ + refine ⟨k0, hmem0, ?_⟩ + exact le_rfl + exact dyadicToReal_le_of_le hloDyadic + have hvals : + (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ + valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by + have hv := hv_bounds k0 + have hlo' : ∀ d, (vLo k0 d : Real) ≤ vRealOfInputs inputs k0 d := fun d => + (hv d).1 + have hhi' : ∀ d, vRealOfInputs inputs k0 d ≤ (vHi k0 
d : Real) := fun d => + (hv d).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := dirHead) + (lo := vLo k0) (hi := vHi k0) + (x := fun d => vRealOfInputs inputs k0 d) hlo' hhi' + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := dirHead) + (lo := vLo k0) (hi := vHi k0) + (x := fun d => vRealOfInputs inputs k0 d) hlo' hhi' + have hlow' : + (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 := by + simpa [valsLo, valCert, dirHead, valsRealOfInputs, + Bounds.cacheBound_apply, Bounds.dotIntervalLowerCachedRat_eq] using hlow + have hhigh' : + valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by + simpa [valsHi, valCert, dirHead, valsRealOfInputs, + Bounds.cacheBound_apply, Bounds.dotIntervalUpperCachedRat_eq] using hhigh + exact ⟨hlow', hhigh'⟩ + have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by + have hhiDyadic : valCert.valsHi k0 ≤ valCert.hi := by + change valsHi k0 ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k0)).2 ?_ + exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ + exact dyadicToReal_le_of_le hhiDyadic + have hreal : + (valCert.lo : Real) ≤ (valCert.hi : Real) := + le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) + exact (dyadicToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal + · intro k + have hmem : k ∈ univ := by simp [univ] + have hloDyadic : valCert.lo ≤ valCert.valsLo k := by + change lo ≤ valsLo k + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k)).2 ?_ + refine ⟨k, hmem, ?_⟩ + exact le_rfl + exact dyadicToReal_le_of_le hloDyadic + · intro k + have hv := hv_bounds k + have hlo' : ∀ d, (vLo k d : Real) ≤ vRealOfInputs inputs k d := fun d => (hv d).1 + have hhi' : ∀ d, vRealOfInputs inputs k d ≤ (vHi k d : Real) := fun d => (hv d).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := dirHead) + (lo := vLo k) (hi := vHi k) + (x := fun d => vRealOfInputs inputs k d) hlo' hhi' + have hhigh := + 
dotProduct_le_dotIntervalUpper_real (v := dirHead) + (lo := vLo k) (hi := vHi k) + (x := fun d => vRealOfInputs inputs k d) hlo' hhi' + have hlow' : + (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by + simpa [valsLo, valCert, dirHead, valsRealOfInputs, + Bounds.cacheBound_apply, Bounds.dotIntervalLowerCachedRat_eq] using hlow + have hhigh' : + valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by + simpa [valsHi, valCert, dirHead, valsRealOfInputs, + Bounds.cacheBound_apply, Bounds.dotIntervalUpperCachedRat_eq] using hhigh + exact ⟨hlow', hhigh'⟩ + · intro k + have hmem : k ∈ univ := by simp [univ] + have hhiDyadic : valCert.valsHi k ≤ valCert.hi := by + change valsHi k ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k)).2 ?_ + refine ⟨k, hmem, ?_⟩ + exact le_rfl + exact dyadicToReal_le_of_le hhiDyadic + exact + { softmax_bounds := hsoftmax_bounds + oneHot_bounds_at := oneHot_bounds_at + value_bounds := hvals_bounds } + · have : False := by + simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] at hcore + exact this.elim + · have : False := by + simp [buildInductionCertFromHeadCore?, hEps, hSqrt] at hcore + exact this.elim + · have : False := by + simp [buildInductionCertFromHeadCore?, hEps] at hcore + exact this.elim + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean new file mode 100644 index 0000000..e8115c6 --- /dev/null +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -0,0 +1,149 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Data.Vector.Defs +import Nfp.Circuit.Layers.Induction +import Nfp.Circuit.Layers.Softmax +import Nfp.Core.Basic +import Nfp.Model.InductionHead +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Linear.FinFold + +/-! +Core definitions for induction-head certificates. 
+ +These definitions are shared across induction certificate builders and checkers. +-/ + +namespace Nfp + +namespace Sound + +open scoped BigOperators + +open Nfp.Circuit +open Nfp.Sound.Bounds + +variable {seq : Nat} + +/-- Cached direction head for head inputs. -/ +def dirHeadVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Dyadic dHead := + Vector.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) + +/-- Real-valued LayerNorm outputs for head inputs. -/ +noncomputable def lnRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := + fun q => + Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) + +/-- Real-valued query projections for head inputs. -/ +noncomputable def qRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wq j d : Real)) (lnRealOfInputs inputs q) + (inputs.bq d : Real) + +/-- Real-valued key projections for head inputs. -/ +noncomputable def kRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wk j d : Real)) (lnRealOfInputs inputs q) + (inputs.bk d : Real) + +/-- Real-valued value projections for head inputs. -/ +noncomputable def vRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + (inputs.bv d : Real) + +/-- Real-valued attention scores for head inputs. 
-/ +noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin seq → Real := + fun q k => + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + if inputs.maskCausal then + if k ≤ q then + base + else + (inputs.maskValue : Real) + else + base + +/-- Real-valued per-key head outputs in model space. -/ +noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := + fun k i => + dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) + +/-- Real-valued direction scores for head inputs. -/ +noncomputable def valsRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Real := + let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d + fun k => dotProduct dirHead (fun d => vRealOfInputs inputs k d) + +/-- Interval data for direction values. -/ +structure ValueInterval (seq : Nat) where + /-- Lower bound for values. -/ + lo : Dyadic + /-- Upper bound for values. -/ + hi : Dyadic + /-- Lower bounds on per-key values. -/ + valsLo : Fin seq → Dyadic + /-- Upper bounds on per-key values. -/ + valsHi : Fin seq → Dyadic + /-- Optional logit-diff direction metadata (ignored by the checker). -/ + direction : Option DirectionSpec + +/-- Soundness predicate for direction-value interval data. -/ +structure ValueIntervalBounds {seq : Nat} (vals : Fin seq → Real) + (c : ValueInterval seq) : Prop where + /-- Interval endpoints are ordered. -/ + lo_le_hi : c.lo ≤ c.hi + /-- `lo` is below every lower bound. -/ + lo_le_valsLo : ∀ k, (c.lo : Real) ≤ (c.valsLo k : Real) + /-- Bounds sandwich the real values. -/ + vals_bounds : + ∀ k, (c.valsLo k : Real) ≤ vals k ∧ vals k ≤ (c.valsHi k : Real) + /-- `hi` is above every upper bound. 
-/ + valsHi_le_hi : ∀ k, (c.valsHi k : Real) ≤ (c.hi : Real) + +/-- Sound induction-certificate payload built from exact head inputs. -/ +structure InductionHeadCert (seq : Nat) where + /-- Weight tolerance. -/ + eps : Dyadic + /-- Per-query weight tolerance derived from local margins. -/ + epsAt : Fin seq → Dyadic + /-- Score margin used to justify the weight tolerance. -/ + margin : Dyadic + /-- Active queries for which bounds are required. -/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Value-interval certificate for the direction values. -/ + values : ValueInterval seq + +/-- Soundness predicate for `InductionHeadCert`. -/ +structure InductionHeadCertSound [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) : Prop where + /-- Softmax weights respect the derived margin bounds. -/ + softmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Real) (c.eps : Real) (c.margin : Real) + (fun q => q ∈ c.active) c.prev + (scoresRealOfInputs inputs) + (fun q k => Circuit.softmax (scoresRealOfInputs inputs q) k) + /-- Per-query one-hot bounds derived from local margins. -/ + oneHot_bounds_at : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (c.epsAt q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) + /-- Interval bounds hold for the direction values. 
-/ + value_bounds : + ValueIntervalBounds (vals := valsRealOfInputs inputs) c.values + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean new file mode 100644 index 0000000..747e054 --- /dev/null +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -0,0 +1,532 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Core.Basic +import Mathlib.Data.Finset.Basic +import Mathlib.Data.Vector.Defs +import Nfp.Model.InductionHead +import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Linear.FinFold + +/-! +Helper bounds for head-induction certificate construction. + +These are pure precomputations that are useful for profiling and staging. +-/ + +namespace Nfp + +namespace Sound + +open Nfp.Sound.Bounds + +variable {seq : Nat} + +/-- Cached direction head for head inputs. -/ +private def dirHeadVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Dyadic dHead := + Vector.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) + +/-- LayerNorm bounds used by the induction-head builder. -/ +def headLnBounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + (Fin seq → Fin dModel → Dyadic) × (Fin seq → Fin dModel → Dyadic) := + Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + +theorem headLnBounds_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + headLnBounds inputs = + Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) := rfl + +/-- Q/K/V bounds used by the induction-head builder. -/ +structure HeadQKVBounds (seq dModel dHead : Nat) where + /-- Q lower bounds. -/ + qLo : Fin seq → Fin dHead → Dyadic + /-- Q upper bounds. 
-/ + qHi : Fin seq → Fin dHead → Dyadic + /-- K lower bounds. -/ + kLo : Fin seq → Fin dHead → Dyadic + /-- K upper bounds. -/ + kHi : Fin seq → Fin dHead → Dyadic + /-- V lower bounds. -/ + vLo : Fin seq → Fin dHead → Dyadic + /-- V upper bounds. -/ + vHi : Fin seq → Fin dHead → Dyadic + /-- Q absolute bounds. -/ + qAbs : Fin seq → Fin dHead → Dyadic + /-- K absolute bounds. -/ + kAbs : Fin seq → Fin dHead → Dyadic + +/-- Compute Q/K/V bounds from LayerNorm bounds. -/ +def headQKVBounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (lnLo lnHi : Fin seq → Fin dModel → Dyadic) : + HeadQKVBounds seq dModel dHead := + let qLo := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let qHi := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let kLo := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let kHi := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let vLo := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let vHi := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let qAbs := + Bounds.cacheBound2TaskElem (fun q d => max |qLo q d| |qHi q d|) + let kAbs := + Bounds.cacheBound2TaskElem (fun q d => max |kLo q d| |kHi q d|) + { qLo := qLo + qHi := qHi + kLo := kLo + kHi := kHi + vLo := vLo + vHi := vHi + qAbs := qAbs + kAbs := kAbs } + +theorem headQKVBounds_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (lnLo lnHi : Fin seq → Fin dModel → Dyadic) : + headQKVBounds 
inputs lnLo lnHi = + let qLo := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let qHi := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let kLo := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let kHi := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let vLo := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let vHi := + Bounds.cacheBound2TaskElem (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let qAbs := + Bounds.cacheBound2TaskElem (fun q d => max |qLo q d| |qHi q d|) + let kAbs := + Bounds.cacheBound2TaskElem (fun q d => max |kLo q d| |kHi q d|) + { qLo := qLo + qHi := qHi + kLo := kLo + kHi := kHi + vLo := vLo + vHi := vHi + qAbs := qAbs + kAbs := kAbs } := rfl + +/-- Score and margin bounds used by the induction-head builder. -/ +structure HeadScoreBounds (seq dModel dHead : Nat) where + /-- Absolute dot-product bound. -/ + dotAbs : Fin seq → Fin seq → Dyadic + /-- Base score absolute bound. -/ + scoreBaseAbs : Fin seq → Fin seq → Dyadic + /-- Score absolute bound with causal masking. -/ + scoreAbs : Fin seq → Fin seq → Dyadic + /-- Score lower bound. -/ + scoreLo : Fin seq → Fin seq → Dyadic + /-- Score upper bound. -/ + scoreHi : Fin seq → Fin seq → Dyadic + /-- Margin per query. -/ + marginAt : Fin seq → Dyadic + /-- Epsilon per query. -/ + epsAt : Fin seq → Dyadic + /-- Global margin. -/ + margin : Dyadic + /-- Global epsilon. -/ + eps : Dyadic + +/-- Compute score and margin bounds from cached score lower/upper bounds. 
-/ +def headScoreBoundsFromCaches [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dotAbs : Fin seq → Fin seq → Dyadic) + (scoreLo scoreHi : Fin seq → Fin seq → Dyadic) : + HeadScoreBounds seq dModel dHead := + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + |inputs.scale| * dotAbs q k + let scoreAbs : Fin seq → Fin seq → Dyadic := fun q k => + if masked q k then |inputs.maskValue| else scoreBaseAbs q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let maskedKeys : Fin seq → Finset (Fin seq) := fun q => + if inputs.maskCausal = true then + (otherKeys q).filter (fun k => q < k) + else + (∅ : Finset (Fin seq)) + let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => + (otherKeys q) \ (maskedKeys q) + let maskedGap : Fin seq → Dyadic := fun q => + scoreLo q (inputs.prev q) - inputs.maskValue + let marginTasks : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if q ∈ inputs.active then + let other := unmaskedKeys q + let masked := maskedKeys q + if hunmasked : other.Nonempty then + let unmaskedMin := other.inf' hunmasked (fun k => + scoreLo q (inputs.prev q) - scoreHi q k) + if hmasked : masked.Nonempty then + min unmaskedMin (maskedGap q) + else + unmaskedMin + else + if hmasked : masked.Nonempty then + maskedGap q + else + (0 : Dyadic) + else + (0 : Dyadic))) + let marginAt : Fin seq → Dyadic := fun q => + (marginTasks[q.1]'(by + simp [marginTasks, q.isLt])).get + let epsTasks : Array (Task Dyadic) := + Array.ofFn (fun q : Fin seq => + (marginTasks[q.1]'(by + simp [marginTasks, q.isLt])).map (fun m => + if m < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + m))) + let epsAt : Fin seq → Dyadic := fun q => + (epsTasks[q.1]'(by + simp [epsTasks, q.isLt])).get + let margin : Dyadic := + if h : 
inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Dyadic) + let eps : Dyadic := + if margin < 0 then + (1 : Dyadic) + else + dyadicDivUp (seq - 1) (1 + margin) + { dotAbs := dotAbs + scoreBaseAbs := scoreBaseAbs + scoreAbs := scoreAbs + scoreLo := scoreLo + scoreHi := scoreHi + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } + +/-- Compute score and margin bounds from dot-product absolute bounds. -/ +def headScoreBoundsFromDotAbs [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dotAbs : Fin seq → Fin seq → Dyadic) : HeadScoreBounds seq dModel dHead := + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotAbsRowTasks : Array (Task { row : Array Dyadic // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) + let scoreRowTasks : Array (Task (Array Dyadic × Array Dyadic)) := + Array.ofFn (fun q : Fin seq => + (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).map (fun row => + let rowArr := row.1 + let scoreBaseAt : Fin seq → Dyadic := fun k => + |inputs.scale| * rowArr.getD k.1 0 + let loRow := Array.ofFn (fun k : Fin seq => + if masked q k then inputs.maskValue else -scoreBaseAt k) + let hiRow := Array.ofFn (fun k : Fin seq => + if masked q k then inputs.maskValue else scoreBaseAt k) + (loRow, hiRow))) + let scoreLoCached : Fin seq → Fin seq → Dyadic := fun q k => + let rowPair := (scoreRowTasks[q.1]'(by + simp [scoreRowTasks, q.isLt])).get + rowPair.1.getD k.1 0 + let scoreHiCached : Fin seq → Fin seq → Dyadic := fun q k => + let rowPair := (scoreRowTasks[q.1]'(by + simp [scoreRowTasks, q.isLt])).get + rowPair.2.getD k.1 0 + headScoreBoundsFromCaches inputs dotAbs scoreLoCached scoreHiCached + +theorem headScoreBoundsFromDotAbs_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dotAbs : Fin seq → 
Fin seq → Dyadic) : + headScoreBoundsFromDotAbs inputs dotAbs = + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotAbsRowTasks : Array (Task { row : Array Dyadic // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) + let scoreRowTasks : Array (Task (Array Dyadic × Array Dyadic)) := + Array.ofFn (fun q : Fin seq => + (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).map (fun row => + let rowArr := row.1 + let scoreBaseAt : Fin seq → Dyadic := fun k => + |inputs.scale| * rowArr.getD k.1 0 + let loRow := Array.ofFn (fun k : Fin seq => + if masked q k then inputs.maskValue else -scoreBaseAt k) + let hiRow := Array.ofFn (fun k : Fin seq => + if masked q k then inputs.maskValue else scoreBaseAt k) + (loRow, hiRow))) + let scoreLoCached : Fin seq → Fin seq → Dyadic := fun q k => + let rowPair := (scoreRowTasks[q.1]'(by + simp [scoreRowTasks, q.isLt])).get + rowPair.1.getD k.1 0 + let scoreHiCached : Fin seq → Fin seq → Dyadic := fun q k => + let rowPair := (scoreRowTasks[q.1]'(by + simp [scoreRowTasks, q.isLt])).get + rowPair.2.getD k.1 0 + headScoreBoundsFromCaches inputs dotAbs scoreLoCached scoreHiCached := rfl + +/-- Compute score and margin bounds from Q/K absolute bounds. 
-/ +def headScoreBounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : + HeadScoreBounds seq dModel dHead := + headScoreBoundsFromDotAbs inputs (fun q k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) + +theorem headScoreBounds_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : + headScoreBounds inputs qAbs kAbs = + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotAbs : Fin seq → Fin seq → Dyadic := fun q k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d) + let dotAbsRowTasks : Array (Task { row : Array Dyadic // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) + let scoreRowTasks : Array (Task (Array Dyadic × Array Dyadic)) := + Array.ofFn (fun q : Fin seq => + (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).map (fun row => + let rowArr := row.1 + let scoreBaseAt : Fin seq → Dyadic := fun k => + |inputs.scale| * rowArr.getD k.1 0 + let loRow := Array.ofFn (fun k : Fin seq => + if masked q k then inputs.maskValue else -scoreBaseAt k) + let hiRow := Array.ofFn (fun k : Fin seq => + if masked q k then inputs.maskValue else scoreBaseAt k) + (loRow, hiRow))) + let scoreLoCached : Fin seq → Fin seq → Dyadic := fun q k => + let rowPair := (scoreRowTasks[q.1]'(by + simp [scoreRowTasks, q.isLt])).get + rowPair.1.getD k.1 0 + let scoreHiCached : Fin seq → Fin seq → Dyadic := fun q k => + let rowPair := (scoreRowTasks[q.1]'(by + simp [scoreRowTasks, q.isLt])).get + rowPair.2.getD k.1 0 + headScoreBoundsFromCaches inputs dotAbs scoreLoCached scoreHiCached := rfl + +/-- Value bounds used by the induction-head builder. -/ +structure HeadValueBounds (seq dModel dHead : Nat) where + /-- Value lower bounds. 
-/ + valsLo : Fin seq → Dyadic + /-- Value upper bounds. -/ + valsHi : Fin seq → Dyadic + /-- Global value lower bound. -/ + lo : Dyadic + /-- Global value upper bound. -/ + hi : Dyadic + +/-- Cached direction vector for value bounds. -/ +def headValueDirHead {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin dHead → Dyadic := + let dirHeadVec := dirHeadVecOfInputs inputs + fun d => dirHeadVec.get d + +theorem headValueDirHead_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + headValueDirHead inputs = + let dirHeadVec := dirHeadVecOfInputs inputs + fun d => dirHeadVec.get d := rfl + +/-- Cached lower value bounds from V intervals. -/ +def headValueValsLo {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : Fin seq → Dyadic := + let dirHead := headValueDirHead inputs + Bounds.cacheBound (fun k => + Bounds.dotIntervalLowerCachedDyadic dirHead (vLo k) (vHi k)) + +theorem headValueValsLo_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueValsLo inputs vLo vHi = + let dirHead := headValueDirHead inputs + Bounds.cacheBound (fun k => + Bounds.dotIntervalLowerCachedDyadic dirHead (vLo k) (vHi k)) := rfl + +/-- Cached lower value bounds from V intervals using a common-denominator sum. 
-/ +def headValueValsLoCommonDen {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : Fin seq → Dyadic := + let dirHead := headValueDirHead inputs + Bounds.cacheBound (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + +theorem headValueValsLoCommonDen_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueValsLoCommonDen inputs vLo vHi = + let dirHead := headValueDirHead inputs + Bounds.cacheBound (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl + +theorem headValueValsLoCommonDen_eq {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueValsLoCommonDen inputs vLo vHi = headValueValsLo inputs vLo vHi := by + classical + funext k + simp [headValueValsLoCommonDen, headValueValsLo, Bounds.cacheBound_apply, + Bounds.dotIntervalLowerCommonDen_eq, Bounds.dotIntervalLowerCachedRat_eq] + +/-- Cached upper value bounds from V intervals. -/ +def headValueValsHi {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : Fin seq → Dyadic := + let dirHead := headValueDirHead inputs + Bounds.cacheBound (fun k => + Bounds.dotIntervalUpperCachedDyadic dirHead (vLo k) (vHi k)) + +theorem headValueValsHi_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueValsHi inputs vLo vHi = + let dirHead := headValueDirHead inputs + Bounds.cacheBound (fun k => + Bounds.dotIntervalUpperCachedDyadic dirHead (vLo k) (vHi k)) := rfl + +/-- Cached upper value bounds from V intervals using a common-denominator sum. 
-/ +def headValueValsHiCommonDen {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : Fin seq → Dyadic := + let dirHead := headValueDirHead inputs + Bounds.cacheBound (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + +theorem headValueValsHiCommonDen_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueValsHiCommonDen inputs vLo vHi = + let dirHead := headValueDirHead inputs + Bounds.cacheBound (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl + +theorem headValueValsHiCommonDen_eq {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueValsHiCommonDen inputs vLo vHi = headValueValsHi inputs vLo vHi := by + classical + funext k + simp [headValueValsHiCommonDen, headValueValsHi, Bounds.cacheBound_apply, + Bounds.dotIntervalUpperCommonDen_eq, Bounds.dotIntervalUpperCachedRat_eq] + +/-- Global lower value bound from cached per-key values. -/ +def headValueLo [NeZero seq] (valsLo : Fin seq → Dyadic) : Dyadic := + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + univ.inf' hnonempty valsLo + +theorem headValueLo_spec [NeZero seq] (valsLo : Fin seq → Dyadic) : + headValueLo valsLo = + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + univ.inf' hnonempty valsLo := rfl + +/-- Global upper value bound from cached per-key values. 
-/ +def headValueHi [NeZero seq] (valsHi : Fin seq → Dyadic) : Dyadic := + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + univ.sup' hnonempty valsHi + +theorem headValueHi_spec [NeZero seq] (valsHi : Fin seq → Dyadic) : + headValueHi valsHi = + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + univ.sup' hnonempty valsHi := rfl + +/-- Compute value bounds from V interval bounds. -/ +def headValueBounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + HeadValueBounds seq dModel dHead := + let valsLo := headValueValsLo inputs vLo vHi + let valsHi := headValueValsHi inputs vLo vHi + let lo := headValueLo valsLo + let hi := headValueHi valsHi + { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } + +theorem headValueBounds_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueBounds inputs vLo vHi = + let valsLo := headValueValsLo inputs vLo vHi + let valsHi := headValueValsHi inputs vLo vHi + let lo := headValueLo valsLo + let hi := headValueHi valsHi + { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } := rfl + +/-- Compute value bounds from V interval bounds using a common-denominator sum. 
-/ +def headValueBoundsCommonDen [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + HeadValueBounds seq dModel dHead := + let valsLo := headValueValsLoCommonDen inputs vLo vHi + let valsHi := headValueValsHiCommonDen inputs vLo vHi + let lo := headValueLo valsLo + let hi := headValueHi valsHi + { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } + +theorem headValueBoundsCommonDen_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueBoundsCommonDen inputs vLo vHi = + let valsLo := headValueValsLoCommonDen inputs vLo vHi + let valsHi := headValueValsHiCommonDen inputs vLo vHi + let lo := headValueLo valsLo + let hi := headValueHi valsHi + { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } := rfl + +theorem headValueBoundsCommonDen_eq [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Dyadic) : + headValueBoundsCommonDen inputs vLo vHi = headValueBounds inputs vLo vHi := by + classical + simp [headValueBoundsCommonDen, headValueBounds, headValueValsLoCommonDen_eq, + headValueValsHiCommonDen_eq] + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean new file mode 100644 index 0000000..e7dbf8e --- /dev/null +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -0,0 +1,376 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Sound.Induction.Core + +/-! +Head-output interval certificates for induction heads. +-/ + +namespace Nfp + +namespace Sound + +open scoped BigOperators + +open Nfp.Circuit +open Nfp.Sound.Bounds + +variable {seq : Nat} + +/-- Build and certify induction certificates from exact head inputs. -/ +def buildInductionCertFromHead? 
[NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option {c : InductionHeadCert seq // InductionHeadCertSound inputs c} := by + classical + cases hcore : buildInductionCertFromHeadCore? inputs with + | none => exact none + | some c => + exact some ⟨c, buildInductionCertFromHeadCore?_sound inputs c hcore⟩ + +section HeadOutputInterval + +variable {seq dModel dHead : Nat} + +noncomputable section + +/-- Real-valued head output using explicit score inputs. -/ +def headOutputWithScores (scores : Fin seq → Fin seq → Real) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : Real := + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scores q) k + let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i + dotProduct (weights q) vals + +/-- Unfolding lemma for `headOutputWithScores`. -/ +theorem headOutputWithScores_def (scores : Fin seq → Fin seq → Real) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : + headOutputWithScores scores inputs q i = + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scores q) k + let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i + dotProduct (weights q) vals := rfl + +/-- Real-valued head output for a query and model dimension. -/ +def headOutput (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : Real := + headOutputWithScores (scoresRealOfInputs inputs) inputs q i + +/-- Unfolding lemma for `headOutput`. -/ +theorem headOutput_def (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : + headOutput inputs q i = + headOutputWithScores (scoresRealOfInputs inputs) inputs q i := rfl + +/-- Soundness predicate for head-output interval bounds. 
-/ +structure HeadOutputIntervalSound [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (active : Finset (Fin seq)) + (c : Circuit.ResidualIntervalCert dModel) : Prop where + /-- Interval bounds are ordered coordinatewise. -/ + bounds : Circuit.ResidualIntervalBounds c + /-- Active-query outputs lie inside the interval bounds. -/ + output_mem : + ∀ q, q ∈ active → ∀ i, + (c.lo i : Real) ≤ headOutput inputs q i ∧ + headOutput inputs q i ≤ (c.hi i : Real) + +/-- Certified head-output interval data for a specific active set. -/ +structure HeadOutputIntervalResult [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) where + /-- Active queries covered by the interval bounds. -/ + active : Finset (Fin seq) + /-- Residual-interval certificate for head outputs. -/ + cert : Circuit.ResidualIntervalCert dModel + /-- Soundness proof for the interval bounds. -/ + sound : HeadOutputIntervalSound inputs active cert + +/-- Build residual-interval bounds for head outputs on active queries. -/ +def buildHeadOutputIntervalFromHead? [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option (HeadOutputIntervalResult inputs) := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : 0 < sqrtLower inputs.lnEps + · by_cases hmodel : dModel = 0 + · exact none + · cases hbuild : buildInductionCertFromHead? 
inputs with + | none => exact none + | some certWithProof => + rcases certWithProof with ⟨cert, hcert⟩ + let lnBounds : Fin (Nat.succ n) → (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta + (inputs.embed q) + let lnLo : Fin (Nat.succ n) → Fin dModel → Dyadic := fun q => (lnBounds q).1 + let lnHi : Fin (Nat.succ n) → Fin dModel → Dyadic := fun q => (lnBounds q).2 + let vLo : Fin (Nat.succ n) → Fin dHead → Dyadic := fun q d => + dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d + let vHi : Fin (Nat.succ n) → Fin dHead → Dyadic := fun q d => + dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d + let headValueLo : Fin (Nat.succ n) → Fin dModel → Dyadic := fun k i => + dotIntervalLower (fun d => inputs.wo i d) (vLo k) (vHi k) + let headValueHi : Fin (Nat.succ n) → Fin dModel → Dyadic := fun k i => + dotIntervalUpper (fun d => inputs.wo i d) (vLo k) (vHi k) + have hln_bounds : + ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ + lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by + intro q i + have hln := + Bounds.layerNormBounds_spec (eps := inputs.lnEps) + (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) + (x := inputs.embed q) hmodel hEps hSqrt + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs] using hln i + have hv_bounds : + ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ + vRealOfInputs inputs q d ≤ (vHi q d : Real) := by + intro q d + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => + (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => + (hln j).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) 
hlo hhi + constructor + · simpa [vLo, vRealOfInputs, Bounds.cacheBound2_apply, + Bounds.dotIntervalLowerCachedRat_eq, dyadicToReal_add] using + add_le_add_right hlow (inputs.bv d : Real) + · simpa [vHi, vRealOfInputs, Bounds.cacheBound2_apply, + Bounds.dotIntervalUpperCachedRat_eq, dyadicToReal_add] using + add_le_add_right hhigh (inputs.bv d : Real) + have hhead_bounds : + ∀ k i, (headValueLo k i : Real) ≤ headValueRealOfInputs inputs k i ∧ + headValueRealOfInputs inputs k i ≤ (headValueHi k i : Real) := by + intro k i + have hv := hv_bounds k + have hlo : ∀ d, (vLo k d : Real) ≤ vRealOfInputs inputs k d := fun d => (hv d).1 + have hhi : ∀ d, vRealOfInputs inputs k d ≤ (vHi k d : Real) := fun d => (hv d).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun d => inputs.wo i d) + (lo := vLo k) (hi := vHi k) + (x := fun d => vRealOfInputs inputs k d) hlo hhi + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun d => inputs.wo i d) + (lo := vLo k) (hi := vHi k) + (x := fun d => vRealOfInputs inputs k d) hlo hhi + constructor + · simpa [headValueLo, headValueRealOfInputs] using hlow + · simpa [headValueHi, headValueRealOfInputs] using hhigh + let scoresReal : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := + scoresRealOfInputs inputs + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresReal q) k + let activeSet : Finset (Fin (Nat.succ n)) := cert.active + let univ : Finset (Fin (Nat.succ n)) := Finset.univ + have huniv : univ.Nonempty := by simp [univ] + let loVal : Fin dModel → Dyadic := fun i => + univ.inf' huniv (fun k => headValueLo k i) + let hiVal : Fin dModel → Dyadic := fun i => + univ.sup' huniv (fun k => headValueHi k i) + have hvalsBoundsReal : + ∀ i, Layers.ValueRangeBounds (Val := Real) + (loVal i : Real) (hiVal i : Real) + (fun k => headValueRealOfInputs inputs k i) := by + intro i + refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } + · rcases (Finset.univ_nonempty : univ.Nonempty) with 
⟨k0, hk0⟩ + have hmem0 : k0 ∈ univ := hk0 + have hloDyadic : loVal i ≤ headValueLo k0 i := by + change loVal i ≤ headValueLo k0 i + dsimp [loVal] + refine (Finset.inf'_le_iff (s := univ) (H := huniv) + (f := fun k => headValueLo k i) (a := headValueLo k0 i)).2 ?_ + refine ⟨k0, hmem0, ?_⟩ + exact le_rfl + have hhiDyadic : headValueHi k0 i ≤ hiVal i := by + change headValueHi k0 i ≤ hiVal i + dsimp [hiVal] + refine (Finset.le_sup'_iff (s := univ) (H := huniv) + (f := fun k => headValueHi k i) (a := headValueHi k0 i)).2 ?_ + exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ + have hbounds := hhead_bounds k0 i + have hreal : (loVal i : Real) ≤ (hiVal i : Real) := by + refine le_trans (dyadicToReal_le_of_le hloDyadic) ?_ + refine le_trans hbounds.1 ?_ + exact le_trans hbounds.2 (dyadicToReal_le_of_le hhiDyadic) + exact hreal + · intro k + have hmem : k ∈ univ := by simp [univ] + have hloDyadic : loVal i ≤ headValueLo k i := by + change loVal i ≤ headValueLo k i + dsimp [loVal] + refine (Finset.inf'_le_iff (s := univ) (H := huniv) + (f := fun k => headValueLo k i) (a := headValueLo k i)).2 ?_ + refine ⟨k, hmem, ?_⟩ + exact le_rfl + have hbounds := hhead_bounds k i + exact (dyadicToReal_le_of_le hloDyadic) |>.trans hbounds.1 + · intro k + have hmem : k ∈ univ := by simp [univ] + have hhiDyadic : headValueHi k i ≤ hiVal i := by + change headValueHi k i ≤ hiVal i + dsimp [hiVal] + refine (Finset.le_sup'_iff (s := univ) (H := huniv) + (f := fun k => headValueHi k i) (a := headValueHi k i)).2 ?_ + exact ⟨k, ⟨hmem, le_rfl⟩⟩ + have hbounds := hhead_bounds k i + exact hbounds.2.trans + (dyadicToReal_le_of_le hhiDyadic) + have hsoftmax : + Layers.SoftmaxMarginBoundsOn (Val := Real) + (cert.eps : Real) (cert.margin : Real) + (fun q => q ∈ activeSet) cert.prev scoresReal weights := by + simpa [scoresReal, weights, activeSet] using hcert.softmax_bounds + have hweights : + Layers.OneHotApproxBoundsOnActive (Val := Real) (cert.eps : Real) + (fun q => q ∈ activeSet) cert.prev weights := + 
Layers.oneHotApproxBoundsOnActive_of_softmaxMargin + (Val := Real) + (ε := (cert.eps : Real)) + (margin := (cert.margin : Real)) + (active := fun q => q ∈ activeSet) + (prev := cert.prev) + (scores := scoresReal) + (weights := weights) + hsoftmax + have happrox : + ∀ i, Layers.InductionSpecApproxOn (Val := Real) (n := n) + ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real))) + (fun q => q ∈ activeSet) cert.prev + (fun q => dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i)) + (fun k => headValueRealOfInputs inputs k i) := by + intro i + exact + Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange + (Val := Real) + (n := n) + (ε := (cert.eps : Real)) + (lo := (loVal i : Real)) + (hi := (hiVal i : Real)) + (active := fun q => q ∈ activeSet) + (prev := cert.prev) + (weights := weights) + (vals := fun k => headValueRealOfInputs inputs k i) + (hweights := hweights) + (hvals := hvalsBoundsReal i) + let delta : Fin dModel → Dyadic := fun i => hiVal i - loVal i + let boundLoDyadic : Fin (Nat.succ n) → Fin dModel → Dyadic := fun q i => + headValueLo (cert.prev q) i - cert.eps * delta i + let boundHiDyadic : Fin (Nat.succ n) → Fin dModel → Dyadic := fun q i => + headValueHi (cert.prev q) i + cert.eps * delta i + let loOut : Fin dModel → Dyadic := fun i => + if h : activeSet.Nonempty then + activeSet.inf' h (fun q => boundLoDyadic q i) + else + 0 + let hiOut : Fin dModel → Dyadic := fun i => + if h : activeSet.Nonempty then + activeSet.sup' h (fun q => boundHiDyadic q i) + else + 0 + have hout : + ∀ q, q ∈ activeSet → ∀ i, + (loOut i : Real) ≤ headOutput inputs q i ∧ + headOutput inputs q i ≤ (hiOut i : Real) := by + intro q hq i + have hactive : activeSet.Nonempty := ⟨q, hq⟩ + have hspec := (happrox i) q hq + have hout_def : + headOutput inputs q i = + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by + simp [headOutput, headOutputWithScores, scoresReal, weights] + have hprev_bounds := hhead_bounds (cert.prev q) i + have 
hupper : + headOutput inputs q i ≤ (boundHiDyadic q i : Real) := by + have hupper' : + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ + headValueRealOfInputs inputs (cert.prev q) i + + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by + exact hspec.1 + have hupper'' : + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ + (headValueHi (cert.prev q) i : Real) + + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by + have hprev_bounds' := + (add_le_add_iff_right + ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))).2 + hprev_bounds.2 + exact le_trans hupper' hprev_bounds' + simpa + [hout_def, boundHiDyadic, delta, dyadicToReal_add, dyadicToReal_mul, + dyadicToReal_sub] using + hupper'' + have hlower : + (boundLoDyadic q i : Real) ≤ headOutput inputs q i := by + have hlower' : + (headValueRealOfInputs inputs (cert.prev q) i : Real) - + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by + exact (sub_le_iff_le_add).2 hspec.2 + have hlower'' : + (headValueLo (cert.prev q) i : Real) - + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by + refine le_trans (sub_le_sub_right hprev_bounds.1 + ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))) ?_ + exact hlower' + simpa [hout_def, boundLoDyadic, delta, dyadicToReal_mul, dyadicToReal_sub] using + hlower'' + have hlo : + (loOut i : Real) ≤ (boundLoDyadic q i : Real) := by + have hloDyadic : loOut i ≤ boundLoDyadic q i := by + simpa [loOut, hactive] using + (Finset.inf'_le + (s := activeSet) + (f := fun q => boundLoDyadic q i) + (b := q) hq) + exact dyadicToReal_le_of_le hloDyadic + have hhi : + (boundHiDyadic q i : Real) ≤ (hiOut i : Real) := by + have hhiDyadic : boundHiDyadic q i ≤ hiOut i := by + simpa [hiOut, hactive] using + (Finset.le_sup' + (s := activeSet) + (f := fun q => 
boundHiDyadic q i) + (b := q) hq) + exact dyadicToReal_le_of_le hhiDyadic + exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ + have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by + refine { lo_le_hi := ?_ } + intro i + by_cases hactive : activeSet.Nonempty + · rcases hactive with ⟨q, hq⟩ + have hout_i := hout q hq i + have hleReal : (loOut i : Real) ≤ (hiOut i : Real) := + le_trans hout_i.1 hout_i.2 + exact (dyadicToReal_le_iff (x := loOut i) (y := hiOut i)).1 hleReal + · simp [loOut, hiOut, hactive] + let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } + exact some + { active := activeSet + cert := certOut + sound := + { bounds := hbounds + output_mem := by + intro q hq i + exact hout q hq i } } + · exact none + · exact none + +end + +end HeadOutputInterval + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index bdbd8d5..27c363c 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -25,14 +25,14 @@ noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel d dotProduct (weights q) (valsRealOfInputs inputs) /-- Lower bound computed from the per-key lower bounds in an induction certificate. 
-/ -def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := +def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Dyadic := Circuit.logitDiffLowerBoundAt c.active c.prev c.epsAt c.values.lo c.values.hi c.values.valsLo theorem logitDiffLowerBoundFromCert_le (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) - {lb : Rat} (hbound : logitDiffLowerBoundFromCert c = some lb) + {lb : Dyadic} (hbound : logitDiffLowerBoundFromCert c = some lb) {q : Fin seq} (hq : q ∈ c.active) : (lb : Real) ≤ headLogitDiff inputs q := by classical @@ -50,7 +50,7 @@ theorem logitDiffLowerBoundFromCert_le Layers.ValueRangeBounds (Val := Real) (c.values.lo : Real) (c.values.hi : Real) (valsRealOfInputs inputs) := by refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } - · exact (Rat.cast_le (K := Real)).2 hsound.value_bounds.lo_le_hi + · exact dyadicToReal_le_of_le hsound.value_bounds.lo_le_hi · intro k exact le_trans (hsound.value_bounds.lo_le_valsLo k) @@ -71,7 +71,7 @@ theorem logitDiffLowerBoundFromCert_le (weights := weights) (vals := valsRealOfInputs inputs) hweights hvalsRange - have hboundRat : + have hboundDyadic : lb ≤ c.values.valsLo (c.prev q) - c.epsAt q * (c.values.hi - c.values.lo) := by refine @@ -90,9 +90,9 @@ theorem logitDiffLowerBoundFromCert_le (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) := by have hboundReal' : (lb : Real) ≤ - (c.values.valsLo (c.prev q) - c.epsAt q * (c.values.hi - c.values.lo) : Rat) := by - exact (Rat.cast_le (K := Real)).2 hboundRat - simpa [Rat.cast_sub, Rat.cast_mul] using hboundReal' + (c.values.valsLo (c.prev q) - c.epsAt q * (c.values.hi - c.values.lo) : Dyadic) := by + exact dyadicToReal_le_of_le hboundDyadic + simpa [dyadicToReal_sub, dyadicToReal_mul] using hboundReal' have hvalsLo : (c.values.valsLo (c.prev q) : Real) ≤ valsRealOfInputs inputs (c.prev q) := by @@ -123,7 +123,7 @@ structure 
InductionLogitLowerBoundResult /-- Soundness proof for the induction certificate. -/ sound : InductionHeadCertSound inputs cert /-- Reported lower bound on logit diff. -/ - lb : Rat + lb : Dyadic /-- `lb` is computed from `logitDiffLowerBoundFromCert`. -/ lb_def : logitDiffLowerBoundFromCert cert = some lb /-- The lower bound is sound on active queries. -/ diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index fd2cee8..42200eb 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -2,8 +2,7 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Ring.Rat -import Mathlib.Data.Rat.Cast.Order +import Nfp.Core.Basic import Nfp.Circuit.Layers.Induction import Nfp.Circuit.Layers.Softmax @@ -26,11 +25,12 @@ theorem oneHot_bounds_at_of_marginAt (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) (scoresReal : Fin seq → Fin seq → Real) - (marginAt : Fin seq → Rat) - (epsAt : Fin seq → Rat) + (marginAt : Fin seq → Dyadic) + (epsAt : Fin seq → Dyadic) (hepsAt : ∀ q, epsAt q = - if marginAt q < 0 then (1 : Rat) else (seq - 1 : Rat) / (1 + marginAt q)) + if marginAt q < 0 then (1 : Dyadic) else + dyadicDivUp (seq - 1) (1 + marginAt q)) (hseq : (1 : Nat) ≤ seq) (hscore_margin_real_at : ∀ q, q ∈ active → ∀ k, k ≠ prev q → @@ -63,7 +63,7 @@ theorem oneHot_bounds_at_of_marginAt have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by by_cases hneg : marginAt q < 0 · have heps : (epsAt q : Real) = 1 := by - simp [hepsAt, hneg] + simp [hepsAt, hneg, dyadicToReal_one] have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by intro k hk simp @@ -85,7 +85,7 @@ theorem oneHot_bounds_at_of_marginAt simpa [heps] using hsum_le' · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg have hnonneg_real : 0 ≤ (marginAt q : Real) := by - exact (Rat.cast_nonneg (K := Real)).2 hnonneg + exact dyadicToReal_nonneg_of_nonneg 
hnonneg have hbound : ∀ k ∈ others q, weights q k ≤ (1 + (marginAt q : Real))⁻¹ := by @@ -114,9 +114,23 @@ theorem oneHot_bounds_at_of_marginAt simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' exact hsum_le''' have heps : - (epsAt q : Real) = (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by - simp [hepsAt, hneg, Rat.cast_add, div_eq_mul_inv] - simpa [heps] using hsum_le' + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ (epsAt q : Real) := by + have hden : (1 + marginAt q) ≠ 0 := by + intro hzero + have hrat : (1 : Rat) + (marginAt q).toRat = 0 := by + have := congrArg Dyadic.toRat hzero + simpa [Dyadic.toRat_add, Dyadic.toRat_natCast] using this + have hnonneg_rat : (0 : Rat) ≤ (marginAt q).toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := marginAt q)).2 hnonneg + linarith + have hrat : + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ + (dyadicDivUp (seq - 1) (1 + marginAt q) : Real) := by + have hrat' := dyadicDivUp_ge_real (seq - 1) (1 + marginAt q) hden + simpa [dyadicToReal, Dyadic.toRat_add, Dyadic.toRat_natCast, + Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' + simpa [hepsAt, hneg] using hrat + exact le_trans hsum_le' heps have hsum_eq : weights q (prev q) + ∑ k ∈ others q, weights q k = 1 := by have hsum' : @@ -166,7 +180,7 @@ theorem oneHot_bounds_at_of_marginAt simpa [heps] using hsum_le' · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg have hnonneg_real : 0 ≤ (marginAt q : Real) := by - exact (Rat.cast_nonneg (K := Real)).2 hnonneg + exact dyadicToReal_nonneg_of_nonneg hnonneg have hbound : ∀ j ∈ others q, weights q j ≤ (1 + (marginAt q : Real))⁻¹ := by @@ -195,9 +209,23 @@ theorem oneHot_bounds_at_of_marginAt simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' exact hsum_le''' have heps : - (epsAt q : Real) = (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by - simp [hepsAt, hneg, Rat.cast_add, div_eq_mul_inv] - simpa [heps] using hsum_le' + (seq - 1 : Real) * (1 + (marginAt q : 
Real))⁻¹ ≤ (epsAt q : Real) := by + have hden : (1 + marginAt q) ≠ 0 := by + intro hzero + have hrat : (1 : Rat) + (marginAt q).toRat = 0 := by + have := congrArg Dyadic.toRat hzero + simpa [Dyadic.toRat_add, Dyadic.toRat_natCast] using this + have hnonneg_rat : (0 : Rat) ≤ (marginAt q).toRat := + (Dyadic.toRat_le_toRat_iff (x := 0) (y := marginAt q)).2 hnonneg + linarith + have hrat : + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ + (dyadicDivUp (seq - 1) (1 + marginAt q) : Real) := by + have hrat' := dyadicDivUp_ge_real (seq - 1) (1 + marginAt q) hden + simpa [dyadicToReal, Dyadic.toRat_add, Dyadic.toRat_natCast, + Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' + simpa [hepsAt, hneg] using hrat + exact le_trans hsum_le' heps have hk' : k ∈ others q := by simp [others, hk] have hnonneg : diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean index d764309..c722133 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Sound/Linear/FinFold.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Rat +import Mathlib.Algebra.BigOperators.Fin +import Mathlib.Data.Matrix.Mul import Batteries.Data.Fin.Fold +import Nfp.Core.Basic /-! Tail-recursive folds and sums over `Fin`. @@ -27,24 +29,96 @@ theorem foldlFin_eq_foldl (n : Nat) (f : α → Fin n → α) (init : α) : simpa [foldlFin] using (Fin.dfoldl_eq_foldl (n := n) (f := fun i acc => f acc i) (x := init)) -/-- Tail-recursive sum over `Fin n` (Rat-valued). -/ -def sumFin (n : Nat) (f : Fin n → Rat) : Rat := +/-- Tail-recursive sum over `Fin n` (Dyadic-valued). -/ +def sumFin (n : Nat) (f : Fin n → Dyadic) : Dyadic := foldlFin n (fun acc i => acc + f i) 0 +/-- Tail-recursive sum over `Fin n` (alias for `sumFin`). -/ +def sumFinCommonDen (n : Nat) (f : Fin n → Dyadic) : Dyadic := + sumFin n f + /-- `sumFin` as a left fold over the finite range list. 
-/ -theorem sumFin_eq_list_foldl (n : Nat) (f : Fin n → Rat) : +theorem sumFin_eq_list_foldl (n : Nat) (f : Fin n → Dyadic) : sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by simpa [sumFin, foldlFin_eq_foldl] using - (Fin.foldl_eq_foldl_finRange (f := fun acc i => acc + f i) (x := (0 : Rat)) (n := n)) - -/-- Dot product over `Fin n` (Rat-valued). -/ -def dotFin (n : Nat) (x y : Fin n → Rat) : Rat := + (Fin.foldl_eq_foldl_finRange (f := fun acc i => acc + f i) (x := (0 : Dyadic)) (n := n)) + +/-- `sumFin` agrees with the `Finset.univ` sum. -/ +theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Dyadic) : + sumFin n f = ∑ i, f i := by + classical + have hfold : + sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa using sumFin_eq_list_foldl n f + have hmap : + ((List.finRange n).map f).foldl (fun acc x : Dyadic => acc + x) 0 = + (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa using + (List.foldl_map (f := f) (g := fun acc x : Dyadic => acc + x) + (l := List.finRange n) (init := (0 : Dyadic))) + let _ : Std.Commutative (fun a b : Dyadic => a + b) := + ⟨by intro a b; exact add_comm _ _⟩ + let _ : Std.Associative (fun a b : Dyadic => a + b) := + ⟨by intro a b c; exact add_assoc _ _ _⟩ + have hfoldr : + ((List.finRange n).map f).foldl (fun acc x : Dyadic => acc + x) 0 = + ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by + simpa using + (List.foldl_eq_foldr (f := fun acc x : Dyadic => acc + x) + (a := 0) (l := (List.finRange n).map f)) + have hsum_list : + ((List.finRange n).map f).sum = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + calc + ((List.finRange n).map f).sum + = ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by + rfl + _ = ((List.finRange n).map f).foldl (fun acc x : Dyadic => acc + x) 0 := by + exact hfoldr.symm + _ = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + exact hmap + have hsum_univ : ((List.finRange n).map f).sum = ∑ i, f i := by + exact 
(Fin.sum_univ_def f).symm + calc + sumFin n f + = (List.finRange n).foldl (fun acc i => acc + f i) 0 := hfold + _ = ((List.finRange n).map f).sum := hsum_list.symm + _ = ∑ i, f i := hsum_univ + +/-- Casting a `Finset.univ` dyadic sum to `Real` commutes with summation. -/ +theorem dyadicToReal_sum_univ {n : Nat} (f : Fin n → Dyadic) : + ((∑ i, f i : Dyadic) : Real) = ∑ i, (f i : Real) := by + classical + refine Finset.induction_on (Finset.univ : Finset (Fin n)) ?_ ?_ + · simp + · intro a s ha hs + simp [Finset.sum_insert, ha, hs, dyadicToReal_add] + +/-- Casting a dyadic `sumFin` to `Real` commutes with summation. -/ +theorem dyadicToReal_sumFin {n : Nat} (f : Fin n → Dyadic) : + (sumFin n f : Real) = ∑ i, (f i : Real) := by + classical + have hsum : sumFin n f = ∑ i, f i := sumFin_eq_sum_univ (f := f) + have hcast : ((∑ i, f i : Dyadic) : Real) = ∑ i, (f i : Real) := + dyadicToReal_sum_univ (f := f) + simpa [hsum] using hcast + +/-- `sumFinCommonDen` agrees with `sumFin`. -/ +theorem sumFinCommonDen_eq_sumFin (n : Nat) (f : Fin n → Dyadic) : + sumFinCommonDen n f = sumFin n f := rfl + +/-- Dot product over `Fin n` (Dyadic-valued). -/ +def dotFin (n : Nat) (x y : Fin n → Dyadic) : Dyadic := sumFin n (fun i => x i * y i) /-- Unfolding lemma for `dotFin`. -/ -theorem dotFin_def (n : Nat) (x y : Fin n → Rat) : +theorem dotFin_def (n : Nat) (x y : Fin n → Dyadic) : dotFin n x y = sumFin n (fun i => x i * y i) := rfl +/-- `dotFin` matches `dotProduct`. -/ +theorem dotFin_eq_dotProduct (n : Nat) (x y : Fin n → Dyadic) : + dotFin n x y = dotProduct x y := by + simp [dotFin_def, sumFin_eq_sum_univ, dotProduct] + end Linear end Sound diff --git a/TheoremAxioms.lean b/TheoremAxioms.lean new file mode 100644 index 0000000..b390259 --- /dev/null +++ b/TheoremAxioms.lean @@ -0,0 +1,28 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp + +/-! +Axioms used by key definitions/lemmas. 
+These `#print axioms` lines help ensure we only depend on a small set of axioms +(ideally a subset of: `propext`, `Classical.choice`, `Quot.sound`). +-/ + +#print axioms Nfp.ProbVec.sum_mass +#print axioms Nfp.ProbVec.pure +#print axioms Nfp.ProbVec.mix +#print axioms Nfp.Mixer.push +#print axioms Nfp.Mixer.comp +#print axioms Nfp.Mixer.id +#print axioms Nfp.Dag.parents +#print axioms Nfp.LocalSystem.toMixer +#print axioms Nfp.LocalSystem.eval +#print axioms Nfp.LocalSystem.eval_eq +#print axioms Nfp.Circuit.eval +#print axioms Nfp.Circuit.evalInput +#print axioms Nfp.Circuit.Interface.eval +#print axioms Nfp.Circuit.checkEquiv +#print axioms Nfp.Circuit.checkEquivOnInterface + +/-- Entrypoint for the axiom report build target. -/ +def main : IO Unit := pure () diff --git a/lakefile.toml b/lakefile.toml index 3ecdd1e..e627101 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -36,3 +36,19 @@ roots = ["Nfp"] [[lean_exe]] name = "nfp" root = "Main" + +[[lean_exe]] +name = "bench-rational" +root = "BenchRational" + +[[lean_exe]] +name = "bench-induction-core" +root = "BenchInductionCore" + +[[lean_exe]] +name = "bench-induction-counts" +root = "BenchInductionCounts" + +[[lean_exe]] +name = "theorem-axioms" +root = "TheoremAxioms" diff --git a/scripts/build_residual_interval_cert.py b/scripts/build_residual_interval_cert.py index c70b755..accfd61 100644 --- a/scripts/build_residual_interval_cert.py +++ b/scripts/build_residual_interval_cert.py @@ -18,6 +18,7 @@ Optional: --tokens tokens.txt # whitespace-separated token ids + --nfpt model.nfpt # read tokens from binary model --random-pattern --seed 0 --decimals 6 --safety 1e-6 """ @@ -62,6 +63,32 @@ def parse_tokens(path: Path) -> np.ndarray: return np.array(tokens, dtype=np.int64) +def parse_tokens_from_nfpt(path: Path) -> np.ndarray: + header: dict[str, str] = {} + with path.open("rb") as f: + while True: + line = f.readline() + if not line: + raise SystemExit("unexpected EOF while reading header") + text = 
line.decode("ascii").strip() + if text == "BINARY_START": + break + if "=" in text: + key, value = text.split("=", 1) + header[key.strip()] = value.strip() + seq_len_raw = header.get("seq_len") + if seq_len_raw is None: + raise SystemExit("header missing seq_len") + seq_len = int(seq_len_raw) + token_bytes = f.read(seq_len * 4) + if len(token_bytes) != seq_len * 4: + raise SystemExit("unexpected EOF while reading tokens") + tokens = np.frombuffer(token_bytes, dtype=" tuple[int | None, list[int]]: seq = None active: list[int] = [] @@ -106,6 +133,7 @@ def main() -> None: help="Use random token pattern") parser.add_argument("--seed", type=int, default=0, help="RNG seed for random pattern") parser.add_argument("--tokens", help="Optional path to whitespace-separated tokens") + parser.add_argument("--nfpt", help="Optional .nfpt file to read tokens from") parser.add_argument("--scores", help="Optional softmax-margin certificate for active queries") parser.add_argument("--model", default="gpt2", help="HuggingFace model name") parser.add_argument("--device", default="cpu", help="Torch device") @@ -122,7 +150,10 @@ def main() -> None: if args.safety < 0: raise SystemExit("safety must be nonnegative") - if args.tokens: + if args.nfpt: + tokens = parse_tokens_from_nfpt(Path(args.nfpt)) + seq = len(tokens) + elif args.tokens: tokens = parse_tokens(Path(args.tokens)) seq = len(tokens) else: diff --git a/scripts/scan_gpt2_induction_sound.py b/scripts/scan_gpt2_induction_sound.py index 0d33efe..b9e8ec4 100755 --- a/scripts/scan_gpt2_induction_sound.py +++ b/scripts/scan_gpt2_induction_sound.py @@ -6,15 +6,15 @@ This script: 1) Ensures a GPT-2 "rigorous induction" binary model exists. -2) Uses the heuristic `induction` command to list candidate head pairs. -3) Runs `induction_cert` for each pair and ranks by logitDiffLB. +2) Uses the untrusted discovery helper to propose head/direction candidates. +3) Runs `nfp induction certify_head_model_nonvacuous` to check each candidate. 
""" from __future__ import annotations import argparse +import json import os -import re import shutil import struct import subprocess @@ -24,9 +24,6 @@ from pathlib import Path -PAIR_RE = re.compile(r"L(\d+)H(\d+)\s+->\s+L(\d+)H(\d+)") - - def run_cmd(cmd: list[str]) -> str: proc = subprocess.run(cmd, check=True, capture_output=True, text=True) return proc.stdout @@ -97,31 +94,17 @@ def derive_target_negative(tokens: list[int]) -> tuple[int, int]: return target, negative -def parse_candidates(output: str, top: int) -> list[tuple[int, int, int, int]]: - pairs: list[tuple[int, int, int, int]] = [] - seen: set[tuple[int, int, int, int]] = set() - for line in output.splitlines(): - match = PAIR_RE.search(line) - if match is None: - continue - pair = tuple(int(x) for x in match.groups()) - if pair in seen: - continue - seen.add(pair) - pairs.append(pair) - if len(pairs) >= top: - break - return pairs - - def parse_logit_lb(output: str) -> Fraction | None: for line in output.splitlines(): - if line.startswith("logitDiffLB="): - token = line.split("=", 1)[1].split()[0] - try: - return Fraction(token) - except ValueError: - return None + if "logitDiffLB=" not in line: + continue + for token in line.split(): + if token.startswith("logitDiffLB="): + value = token.split("=", 1)[1].strip("),") + try: + return Fraction(value) + except ValueError: + return None return None @@ -129,21 +112,16 @@ def main() -> int: parser = argparse.ArgumentParser() parser.add_argument("--model", default="models/gpt2_rigorous.nfpt") parser.add_argument("--top", type=int, default=8) - parser.add_argument("--delta", default="0.01") - parser.add_argument("--coord", type=int, default=0) - parser.add_argument("--offset1", type=int, default=-1) - parser.add_argument("--offset2", type=int, default=0) - parser.add_argument("--keyOffset1", type=int, default=0) - parser.add_argument("--keyOffset2", type=int, default=-1) parser.add_argument("--maxSeqLen", type=int, default=256) parser.add_argument("--jobs", 
type=int, default=1) parser.add_argument("--fast", action="store_true") parser.add_argument("--nfp-bin", help="Path to nfp binary (defaults to .lake/build/bin/nfp)") - parser.add_argument("--tightPattern", action="store_true") - parser.add_argument("--tightPatternLayers", type=int) - parser.add_argument("--perRowPatternLayers", type=int) - parser.add_argument("--bestMatch", action="store_true") - parser.add_argument("--queryPos", type=int) + parser.add_argument("--min-eps", type=float, default=0.5) + parser.add_argument("--min-margin", type=float, default=0.0) + parser.add_argument("--min-logit-lb", type=float, default=0.0) + parser.add_argument("--layers", help="Comma-separated layer list or 'all'") + parser.add_argument("--heads", help="Comma-separated head list or 'all'") + parser.add_argument("--period", type=int) parser.add_argument("--output", default="reports/gpt2_induction_sound_scan.txt") args = parser.parse_args() args.jobs = max(1, args.jobs) @@ -169,84 +147,82 @@ def main() -> int: return 1 target, negative = derive_target_negative(tokens) - induction_out = run_cmd( - nfp_cmd - + [ - "induction", - str(model_path), - "--threshold", - "0.0", - ] - ) - pairs = parse_candidates(induction_out, args.top) - if not pairs: + discover_json = Path(args.output).with_suffix(".json") + discover_txt = Path(args.output).with_suffix(".discover.txt") + discover_cmd = [ + sys.executable, + "scripts/discover_gpt2_induction_targets.py", + "--model", + str(model_path), + "--top", + str(args.top), + "--min-eps", + str(args.min_eps), + "--min-margin", + str(args.min_margin), + "--min-logit-lb", + str(args.min_logit_lb), + "--output", + str(discover_txt), + "--json-out", + str(discover_json), + ] + if args.layers is not None: + discover_cmd += ["--layers", args.layers] + if args.heads is not None: + discover_cmd += ["--heads", args.heads] + if args.period is not None: + discover_cmd += ["--period", str(args.period)] + run_cmd(discover_cmd) + payload = 
json.loads(discover_json.read_text(encoding="ascii")) + candidates = payload.get("results", []) + if not candidates: print("No induction candidates found.", file=sys.stderr) return 1 - results: list[tuple[Fraction, tuple[int, int, int, int]]] = [] + results: list[tuple[Fraction, dict[str, int]]] = [] - def run_cert(pair: tuple[int, int, int, int]) -> tuple[tuple[int, int, int, int], Fraction | None]: - l1, h1, l2, h2 = pair + def run_cert(candidate: dict[str, int]) -> tuple[dict[str, int], Fraction | None]: + layer = int(candidate["layer"]) + head = int(candidate["head"]) + target_id = int(candidate.get("target", target)) + negative_id = int(candidate.get("negative", negative)) cmd = nfp_cmd + [ - "induction_cert", + "induction", + "certify_head_model_nonvacuous", + "--model", str(model_path), - "--layer1", - str(l1), - "--head1", - str(h1), - "--layer2", - str(l2), - "--head2", - str(h2), - "--coord", - str(args.coord), - "--offset1", - str(args.offset1), - "--offset2", - str(args.offset2), - "--keyOffset1", - str(args.keyOffset1), - "--keyOffset2", - str(args.keyOffset2), - "--delta", - args.delta, - "--maxSeqLen", - str(args.maxSeqLen), - "--target", - str(target), - "--negative", - str(negative), + "--layer", + str(layer), + "--head", + str(head), + "--direction-target", + str(target_id), + "--direction-negative", + str(negative_id), ] - if args.tightPattern: - cmd.append("--tightPattern") - if args.tightPatternLayers is not None: - cmd.extend(["--tightPatternLayers", str(args.tightPatternLayers)]) - if args.perRowPatternLayers is not None: - cmd.extend(["--perRowPatternLayers", str(args.perRowPatternLayers)]) - if args.bestMatch: - cmd.append("--bestMatch") - if args.queryPos is not None: - cmd.extend(["--queryPos", str(args.queryPos)]) + if args.period is not None: + cmd += ["--period", str(args.period)] try: cert_out = run_cmd(cmd) except subprocess.CalledProcessError: - return pair, None - return pair, parse_logit_lb(cert_out) + return candidate, None + 
return candidate, parse_logit_lb(cert_out) if args.jobs == 1: - for pair in pairs: - pair_out, logit_lb = run_cert(pair) + for candidate in candidates: + candidate_out, logit_lb = run_cert(candidate) if logit_lb is None: continue - results.append((logit_lb, pair_out)) + results.append((logit_lb, candidate_out)) else: with ThreadPoolExecutor(max_workers=args.jobs) as executor: - futures = {executor.submit(run_cert, pair): pair for pair in pairs} + futures = {executor.submit(run_cert, candidate): candidate for candidate in candidates} for future in as_completed(futures): - pair_out, logit_lb = future.result() + candidate_out, logit_lb = future.result() if logit_lb is None: continue - results.append((logit_lb, pair_out)) + results.append((logit_lb, candidate_out)) if not results: print("No sound logit bounds produced.", file=sys.stderr) @@ -259,16 +235,16 @@ def run_cert(pair: tuple[int, int, int, int]) -> tuple[tuple[int, int, int, int] f.write("SOUND induction scan (logitDiffLB ranking)\n") f.write(f"model={model_path}\n") f.write(f"target={target} negative={negative}\n") - f.write( - f"bestMatch={args.bestMatch} queryPos={args.queryPos} " - f"tightPatternLayers={args.tightPatternLayers} " - f"perRowPatternLayers={args.perRowPatternLayers}\n" - ) eps_header = header.get("layer_norm_eps") or header.get("eps") or "unknown" - f.write(f"top={args.top} delta={args.delta} eps={eps_header}\n") - for rank, (lb, (l1, h1, l2, h2)) in enumerate(results, start=1): + f.write(f"top={args.top} eps={eps_header}\n") + for rank, (lb, candidate) in enumerate(results, start=1): + layer = int(candidate["layer"]) + head = int(candidate["head"]) + target_id = int(candidate.get("target", target)) + negative_id = int(candidate.get("negative", negative)) f.write( - f"{rank:02d} L{l1}H{h1} -> L{l2}H{h2} logitDiffLB={lb}\n" + f"{rank:02d} L{layer}H{head} " + f"target={target_id} negative={negative_id} logitDiffLB={lb}\n" ) print(f"Report written to {out_path}") From 
aef5a5fd3e6b1b98b4c39434b8ce4568886b9926 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 7 Jan 2026 02:27:24 +0100 Subject: [PATCH 109/244] Add missing docstrings and reuse global dyadic-real lemmas --- Nfp/IO/Checks.lean | 2 ++ Nfp/IO/Derive.lean | 1 + Nfp/IO/Pure/SoftmaxMargin/Shared.lean | 16 ++++++++++++++-- Nfp/IO/Pure/ValueRange/Shared.lean | 10 ++++++++++ Nfp/Sound/Bounds/Gelu.lean | 6 +++--- 5 files changed, 30 insertions(+), 5 deletions(-) diff --git a/Nfp/IO/Checks.lean b/Nfp/IO/Checks.lean index 4b373ab..224b3ac 100644 --- a/Nfp/IO/Checks.lean +++ b/Nfp/IO/Checks.lean @@ -13,6 +13,7 @@ namespace IO open Nfp.Circuit +/-- Check a softmax-margin certificate for a positive sequence length. -/ def checkSoftmaxMargin (seq : Nat) (cert : SoftmaxMarginCert seq) : IO (Except String Unit) := match seq with @@ -26,6 +27,7 @@ def checkSoftmaxMargin (seq : Nat) (cert : SoftmaxMarginCert seq) : else return Except.error "softmax-margin certificate rejected" +/-- Check a value-range certificate for a positive sequence length. -/ def checkValueRange (seq : Nat) (cert : ValueRangeCert seq) : IO (Except String Unit) := match seq with diff --git a/Nfp/IO/Derive.lean b/Nfp/IO/Derive.lean index 639f511..b18fff1 100644 --- a/Nfp/IO/Derive.lean +++ b/Nfp/IO/Derive.lean @@ -20,6 +20,7 @@ namespace IO open Nfp.Circuit +/-- Build a residual-interval certificate from an on-disk model payload. -/ def deriveResidualIntervalFromModel (data : ByteArray) (start : Nat) (header : NfptPure.NfptHeader) (active? : Option (Finset (Fin header.seqLen))) : IO (Except String (ResidualIntervalCert header.modelDim)) := do diff --git a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean index a316d83..7430cff 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean @@ -15,17 +15,24 @@ namespace Pure namespace SoftmaxMargin -open Nfp.Circuit - +/-- State for parsing softmax-margin payloads. 
-/ structure ParseState (seq : Nat) where + /-- Optional epsilon bound. -/ eps : Option Dyadic + /-- Optional margin bound. -/ margin : Option Dyadic + /-- Active query set. -/ active : Finset (Fin seq) + /-- Whether any active entries were parsed. -/ activeSeen : Bool + /-- Optional predecessor pointer per query. -/ prev : Fin seq → Option (Fin seq) + /-- Optional score matrix entries. -/ scores : Fin seq → Fin seq → Option Dyadic + /-- Optional weight matrix entries. -/ weights : Fin seq → Fin seq → Option Dyadic +/-- Initialize a softmax-margin parse state. -/ def initState (seq : Nat) : ParseState seq := { eps := none margin := none @@ -35,6 +42,7 @@ def initState (seq : Nat) : ParseState seq := scores := fun _ _ => none weights := fun _ _ => none } +/-- Set a predecessor entry from `(q, k)` tokens. -/ def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : Except String (ParseState seq) := do if hq : q < seq then if hk : k < seq then @@ -55,6 +63,7 @@ def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : Except String (Parse else throw s!"prev index out of range: q={q}" +/-- Mark an active query index. -/ def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : Except String (ParseState seq) := do if hq : q < seq then let qFin : Fin seq := ⟨q, hq⟩ @@ -65,6 +74,7 @@ def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : Except String (Parse else throw s!"active index out of range: q={q}" +/-- Insert a matrix entry for scores/weights. -/ def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Dyadic) (q k : Nat) (v : Dyadic) : Except String (Fin seq → Fin seq → Option Dyadic) := do if hq : q < seq then @@ -89,6 +99,7 @@ def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Dyadic) else throw s!"index out of range: q={q}" +/-- Parse a tokenized line into the softmax-margin parse state. 
-/ def parseLine {seq : Nat} (st : ParseState seq) (tokens : List String) : Except String (ParseState seq) := do match tokens with @@ -115,6 +126,7 @@ def parseLine {seq : Nat} (st : ParseState seq) | _ => throw s!"unrecognized line: '{String.intercalate " " tokens}'" +/-- Extract the `seq` header from tokenized lines. -/ def parseSeq (tokens : List (List String)) : Except String Nat := do let mut seq? : Option Nat := none for t in tokens do diff --git a/Nfp/IO/Pure/ValueRange/Shared.lean b/Nfp/IO/Pure/ValueRange/Shared.lean index b653df2..4a4c621 100644 --- a/Nfp/IO/Pure/ValueRange/Shared.lean +++ b/Nfp/IO/Pure/ValueRange/Shared.lean @@ -17,14 +17,21 @@ namespace ValueRange open Nfp.Circuit +/-- State for parsing value-range payloads. -/ structure ParseState (seq : Nat) where + /-- Optional lower bound. -/ lo : Option Dyadic + /-- Optional upper bound. -/ hi : Option Dyadic + /-- Optional per-position values. -/ vals : Fin seq → Option Dyadic + /-- Optional direction target index. -/ directionTarget : Option Nat + /-- Optional direction negative index. -/ directionNegative : Option Nat +/-- Initialize a value-range parse state. -/ def initState (seq : Nat) : ParseState seq := { lo := none hi := none @@ -33,6 +40,7 @@ def initState (seq : Nat) : ParseState seq := directionNegative := none } +/-- Set a value entry from `(k, v)` tokens. -/ def setVal {seq : Nat} (st : ParseState seq) (k : Nat) (v : Dyadic) : Except String (ParseState seq) := do if hk : k < seq then @@ -51,6 +59,7 @@ def setVal {seq : Nat} (st : ParseState seq) (k : Nat) (v : Dyadic) : throw s!"value index out of range: k={k}" +/-- Parse a tokenized line into the value-range parse state. -/ def parseLine {seq : Nat} (st : ParseState seq) (tokens : List String) : Except String (ParseState seq) := do match tokens with @@ -80,6 +89,7 @@ def parseLine {seq : Nat} (st : ParseState seq) throw s!"unrecognized line: '{String.intercalate " " tokens}'" +/-- Extract the `seq` header from tokenized lines. 
-/ def parseSeq (tokens : List (List String)) : Except String Nat := do let mut seq? : Option Nat := none for t in tokens do diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean index 2bb5ab1..8190bdb 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -141,7 +141,7 @@ theorem geluInterval_bounds {lo hi : Dyadic} {x : Real} · simpa [geluInterval, hlo0] using hlo' · by_cases hhi0 : 0 ≤ hi · have hhi0r : 0 ≤ (hi : Real) := by - exact (dyadicToReal_nonneg_iff (x := hi)).2 hhi0 + exact dyadicToReal_nonneg_of_nonneg hhi0 have hmax' : max (hi : Real) 0 = (hi : Real) := max_eq_left hhi0r simpa [geluInterval, hhi0, hmax'] using hhi' · have hhi0r : (hi : Real) ≤ 0 := by @@ -152,7 +152,7 @@ theorem geluInterval_bounds {lo hi : Dyadic} {x : Real} simpa [hmax'] using hgelu.2 simpa [geluInterval, hhi0, dyadicToReal_zero] using hhi'' · have hlo0r : 0 ≤ (lo : Real) := by - exact (dyadicToReal_nonneg_iff (x := lo)).2 (le_of_not_ge hlo0) + exact dyadicToReal_nonneg_of_nonneg (le_of_not_ge hlo0) have hx0 : 0 ≤ x := le_trans hlo0r hlo have hmin' : min x 0 = 0 := min_eq_right hx0 have hlo' : (0 : Real) ≤ geluTanh x := by @@ -163,7 +163,7 @@ theorem geluInterval_bounds {lo hi : Dyadic} {x : Real} · simpa [geluInterval, hlo0, dyadicToReal_zero] using hlo' · by_cases hhi0 : 0 ≤ hi · have hhi0r : 0 ≤ (hi : Real) := by - exact (dyadicToReal_nonneg_iff (x := hi)).2 hhi0 + exact dyadicToReal_nonneg_of_nonneg hhi0 have hmax' : max (hi : Real) 0 = (hi : Real) := max_eq_left hhi0r simpa [geluInterval, hhi0, hmax'] using hhi' · have hhi0r : (hi : Real) ≤ 0 := by From 3cb190ba4c9e58a1a276c43f1788a6c7cabf10ae Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 8 Jan 2026 13:02:36 +0100 Subject: [PATCH 110/244] Refine induction Q/K abs bounds caching --- Nfp/IO/InductionHead.lean | 838 ++++++++++++++------- Nfp/Sound/Induction/Core.lean | 1296 ++++++++++++++++++--------------- 2 files changed, 1291 insertions(+), 843 deletions(-) diff --git 
a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index d18a3e8..97abeeb 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -21,6 +21,11 @@ namespace Nfp namespace IO +private def unwrapTaskResult {α : Type} (res : Except IO.Error α) : IO α := + match res with + | .ok a => pure a + | .error e => throw e + open Nfp.Circuit private def valueBoundsModeFromEnv : IO (Option Bool) := do @@ -29,6 +34,163 @@ private def valueBoundsModeFromEnv : IO (Option Bool) := do | some "cached" => return some false | _ => return none +/-- Read the heartbeat interval (ms) for long-running induction cert builds. -/ +private def heartbeatMsFromEnv : IO UInt32 := do + let defaultMs : Nat := 10000 + let ms := (← IO.getEnv "NFP_TIMING_HEARTBEAT_MS").bind String.toNat? |>.getD defaultMs + return UInt32.ofNat ms + +private def timePureWithHeartbeat {α : Type} (label : String) (f : Unit → α) : IO α := do + let t0 ← monoUsNow + IO.println s!"timing: {label} start" + flushStdout + let task : Task α := Task.spawn (fun _ => f ()) + let heartbeatMs ← heartbeatMsFromEnv + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished task) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished task) + if !finished then + let now ← monoUsNow + IO.println s!"timing: {label} running {now - t0} us" + flushStdout + let res ← IO.wait task + let t1 ← monoUsNow + IO.println s!"timing: {label} {t1 - t0} us" + return res + +private def forceRat (x : Rat) : IO Unit := do + if x = x then + pure () + else + pure () + +/-- Profile the core induction-head bounds used by the sound certificate builder. 
-/ +private def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do + IO.println "timing: core stages start" + flushStdout + let lnBounds ← timePureWithHeartbeat "core: ln bounds" (fun () => + Sound.headLnBounds inputs) + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Sound.Bounds.cacheBoundTask (fun q => + Sound.Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr ← timePureWithHeartbeat "core: lnAbsMax force" (fun () => + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q)) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr.getD q.1 (0 : Rat) + let lnAbsMaxMax : Rat := + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by + simp [univ] + univ.sup' hnonempty (fun q => lnAbsMax q) + let qAbsRowTasks : Array (Task (Array Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + Array.ofFn (fun d : Fin dHead => + Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + |inputs.bq d|))) + let qAbsBaseArr ← timePureWithHeartbeat "core: qAbs base force" (fun () => + let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) + Array.ofFn (fun q : Fin seq => + (qAbsRowTasks.getD q.1 defaultTask).get)) + let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := qAbsBaseArr.getD q.1 #[] + row.getD d.1 (0 : Rat) + let kAbsRowTasks : Array (Task (Array Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + Array.ofFn (fun d : Fin dHead => + Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + |inputs.bk d|))) + let kAbsBaseArr ← timePureWithHeartbeat "core: kAbs base force" (fun () => + let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) + Array.ofFn (fun q : Fin seq => + (kAbsRowTasks.getD q.1 defaultTask).get)) + let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := kAbsBaseArr.getD q.1 #[] + row.getD 
d.1 (0 : Rat) + let dotAbs ← timePureWithHeartbeat "core: dotAbs tasks" (fun () => + dotAbsFromQKV qAbsBase kAbsBase) + let _ ← timePureWithHeartbeat "core: dotAbs force" (fun () => + match List.finRange seq with + | [] => (0 : Rat) + | q :: _ => + match List.finRange seq with + | [] => (0 : Rat) + | k :: _ => dotAbs q k) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then inputs.maskValue else -scoreBaseAbs q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then inputs.maskValue else scoreBaseAbs q k + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreLoPrev q - scoreHi q k) + else + (0 : Rat) + else + (0 : Rat) + let margin ← timePureWithHeartbeat "core: margin" (fun () => + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat)) + let marginNeg ← timePureWithHeartbeat "core: margin < 0" (fun () => + decide (margin < 0)) + let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" + if verboseTiming.isSome then + IO.println s!"timing: core: margin neg={marginNeg}" + let tEps0 ← monoUsNow + IO.println "timing: core: eps start" + flushStdout + let eps := + if marginNeg then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + let tEps1 ← monoUsNow + IO.println s!"timing: core: eps {tEps1 - tEps0} us" + flushStdout + let _ := marginAt + let dirHeadVec ← timePureWithHeartbeat "core: dir head vec" (fun () => + Sound.dirHeadVecOfInputs inputs) + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin dModel → 
Rat := + Sound.Bounds.cacheBoundTask (fun j => + Sound.Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let _ ← timePureWithHeartbeat "core: wvDir force" (fun () => + Array.ofFn (fun j : Fin dModel => wvDir j)) + let bDir ← timePureWithHeartbeat "core: bDir" (fun () => + Sound.Linear.dotFin dHead dirHead (fun d => inputs.bv d)) + let valsAbsBase ← timePureWithHeartbeat "core: valsAbsBase" (fun () => + Sound.Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax) + let valsLoBase := bDir - valsAbsBase + let valsHiBase := bDir + valsAbsBase + let valsLo : Fin seq → Rat := fun _ => valsLoBase + let valsHi : Fin seq → Rat := fun _ => valsHiBase + let _ ← timePureWithHeartbeat "core: value bounds" (fun () => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by + simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + (lo, hi)) + IO.println "timing: core stages done" + flushStdout + /-- Load induction head inputs from disk. 
-/ def loadInductionHeadInputs (path : System.FilePath) : IO (Except String (Sigma (fun seq => @@ -46,15 +208,15 @@ def loadInductionHeadInputs (path : System.FilePath) : IO.println s!"timing: parse head input file {t3 - t2} us" return parsed -private def dyadicToString (x : Dyadic) : String := - toString x.toRat +private def ratToString (x : Rat) : String := + toString x private def renderResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) : String := let header := s!"dim {n}" let lines := (List.finRange n).foldr (fun i acc => - s!"lo {i.val} {dyadicToString (c.lo i)}" :: - s!"hi {i.val} {dyadicToString (c.hi i)}" :: acc) [] + s!"lo {i.val} {ratToString (c.lo i)}" :: + s!"hi {i.val} {ratToString (c.hi i)}" :: acc) [] String.intercalate "\n" (header :: lines) private def emitResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) @@ -88,205 +250,65 @@ private def buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} private def headScoreBoundsFromDotAbsTimed {seq dModel dHead : Nat} [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Dyadic) : + (dotAbs : Fin seq → Fin seq → Rat) : IO (Sound.HeadScoreBounds seq dModel dHead) := do - let headScoreBoundsFromCachesTimed - (scoreLo scoreHi : Fin seq → Fin seq → Dyadic) : - IO (Sound.HeadScoreBounds seq dModel dHead) := do - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => - |inputs.scale| * dotAbs q k - let scoreAbs : Fin seq → Fin seq → Dyadic := fun q k => - if masked q k then |inputs.maskValue| else scoreBaseAbs q k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let maskedKeys : Fin seq → Finset (Fin seq) := fun q => - if inputs.maskCausal = true then - (otherKeys q).filter (fun k => q < k) - else - (∅ : Finset (Fin seq)) - let unmaskedKeys : Fin seq → Finset (Fin seq) := 
fun q => - (otherKeys q) \ (maskedKeys q) - let maskedGap : Fin seq → Dyadic := fun q => - scoreLo q (inputs.prev q) - inputs.maskValue - let marginTasks : { arr : Array (Task Dyadic) // arr.size = seq } ← - timePhase "head: score margin tasks" <| do - let arr : Array (Task Dyadic) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if q ∈ inputs.active then - let other := unmaskedKeys q - let masked := maskedKeys q - let prev := inputs.prev q - let gapTasks : Array (Task Dyadic) := - Array.ofFn (fun k : Fin seq => - Task.spawn (fun _ => scoreLo q prev - scoreHi q k)) - let gap : Fin seq → Dyadic := fun k => - let row := gapTasks[k.1]'(by - simp [gapTasks, k.isLt]) - row.get - if hunmasked : other.Nonempty then - let unmaskedMin := other.inf' hunmasked gap - if hmasked : masked.Nonempty then - min unmaskedMin (maskedGap q) - else - unmaskedMin - else - if hmasked : masked.Nonempty then - maskedGap q - else - (0 : Dyadic) - else - (0 : Dyadic))) - let hsize : arr.size = seq := by simp [arr] - pure ⟨arr, hsize⟩ - have hmargin : marginTasks.1.size = seq := marginTasks.2 - let marginAt : Fin seq → Dyadic := fun q => - let q' : Fin marginTasks.1.size := Fin.cast hmargin.symm q - (marginTasks.1[q'.1]'(by exact q'.isLt)).get - let epsTasks : { arr : Array (Task Dyadic) // arr.size = seq } ← - timePhase "head: score eps tasks" <| do - let arr : Array (Task Dyadic) := - Array.ofFn (fun q : Fin seq => - let q' : Fin marginTasks.1.size := Fin.cast hmargin.symm q - (marginTasks.1[q'.1]'(by exact q'.isLt)).map (fun m => - if m < 0 then - (1 : Dyadic) - else - dyadicDivUp (seq - 1) (1 + m))) - let hsize : arr.size = seq := by simp [arr] - pure ⟨arr, hsize⟩ - have heps : epsTasks.1.size = seq := epsTasks.2 - let epsAt : Fin seq → Dyadic := fun q => - let q' : Fin epsTasks.1.size := Fin.cast heps.symm q - (epsTasks.1[q'.1]'(by exact q'.isLt)).get - let margin ← timePhase "head: score margin reduction" <| - pure (if h : inputs.active.Nonempty then - inputs.active.inf' h 
marginAt - else - (0 : Dyadic)) - let eps ← timePhase "head: score eps reduction" <| - pure (if margin < 0 then - (1 : Dyadic) - else - dyadicDivUp (seq - 1) (1 + margin)) - let result : Sound.HeadScoreBounds seq dModel dHead := - { dotAbs := dotAbs - scoreBaseAbs := scoreBaseAbs - scoreAbs := scoreAbs - scoreLo := scoreLo - scoreHi := scoreHi - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } - return result - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => - |inputs.scale| * dotAbs q k - let scoreLoRaw : Fin seq → Fin seq → Dyadic := fun q k => - if masked q k then - inputs.maskValue - else - -scoreBaseAbs q k - let scoreHiRaw : Fin seq → Fin seq → Dyadic := fun q k => - if masked q k then - inputs.maskValue - else - scoreBaseAbs q k - IO.println "timing: head score caches skipped (direct score functions)" - flushStdout - let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => - if masked q k then inputs.maskValue else -(|inputs.scale| * dotAbs q k) - let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => - if masked q k then inputs.maskValue else |inputs.scale| * dotAbs q k - headScoreBoundsFromCachesTimed scoreLo scoreHi + timePure "head: score bounds" (fun () => + Sound.headScoreBoundsFromDotAbs inputs dotAbs) private def headScoreBoundsFromQAbsKAbsTimed {seq dModel dHead : Nat} [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Dyadic) - (dotAbs : Fin seq → Fin seq → Dyadic) : + (qAbs kAbs : Fin seq → Fin dHead → Rat) + (dotAbs : Fin seq → Fin seq → Rat) : IO (Sound.HeadScoreBounds seq dModel dHead) := do let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let maskedKeys : Fin seq → Finset (Fin seq) := fun q => - if inputs.maskCausal = true 
then - (otherKeys q).filter (fun k => q < k) - else - (∅ : Finset (Fin seq)) - let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => - (otherKeys q) \ (maskedKeys q) - let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + let scoreLo : Fin seq → Fin seq → Rat := fun q k => if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + let scoreHi : Fin seq → Fin seq → Rat := fun q k => if masked q k then inputs.maskValue else scoreBaseAbs q k - let kAbsMax : Fin dHead → Dyadic := fun d => + let kAbsMax : Fin dHead → Rat := fun d => let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := Finset.univ_nonempty univ.sup' hnonempty (fun k => kAbs k d) - let dotAbsUpper : Fin seq → Dyadic := fun q => + let dotAbsUpper : Fin seq → Rat := fun q => Sound.Linear.dotFin dHead (fun d => qAbs q d) kAbsMax - let scoreHiUpper : Fin seq → Dyadic := fun q => + let scoreHiUpper : Fin seq → Rat := fun q => max inputs.maskValue (|inputs.scale| * dotAbsUpper q) - let fastGap : Fin seq → Dyadic := fun q => - let prev := inputs.prev q - scoreLo q prev - scoreHiUpper q - let marginTasks : Array (Task Dyadic) := + let marginTasks : Array (Task Rat) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => if q ∈ inputs.active then - let fast := fastGap q - if fast < 0 then - let other := unmaskedKeys q - let maskedSet := maskedKeys q - let exact := - if hunmasked : other.Nonempty then - let unmaskedMin := - other.inf' hunmasked (fun k => scoreLo q (inputs.prev q) - scoreHi q k) - if maskedSet.Nonempty then - min unmaskedMin (scoreLo q (inputs.prev q) - inputs.maskValue) - else - unmaskedMin - else - if maskedSet.Nonempty then - scoreLo q (inputs.prev q) - inputs.maskValue - else - (0 : Dyadic) - exact - else - fast + let prev := inputs.prev q + let scoreLoPrev := 
scoreLo q prev + scoreLoPrev - scoreHiUpper q else - (0 : Dyadic))) - let marginAt : Fin seq → Dyadic := fun q => + (0 : Rat))) + let marginAt : Fin seq → Rat := fun q => (marginTasks[q.1]'(by simp [marginTasks, q.isLt])).get - let epsTasks : Array (Task Dyadic) := + let epsTasks : Array (Task Rat) := Array.ofFn (fun q : Fin seq => (marginTasks[q.1]'(by simp [marginTasks, q.isLt])).map (fun m => if m < 0 then - (1 : Dyadic) + (1 : Rat) else - dyadicDivUp (seq - 1) (1 + m))) - let epsAt : Fin seq → Dyadic := fun q => + ratDivUp (seq - 1) (1 + m))) + let epsAt : Fin seq → Rat := fun q => (epsTasks[q.1]'(by simp [epsTasks, q.isLt])).get - let margin : Dyadic := + let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt else - (0 : Dyadic) - let eps : Dyadic := + (0 : Rat) + let eps : Rat := if margin < 0 then - (1 : Dyadic) + (1 : Rat) else - dyadicDivUp (seq - 1) (1 + margin) + ratDivUp (seq - 1) (1 + margin) let result : Sound.HeadScoreBounds seq dModel dHead := { dotAbs := dotAbs scoreBaseAbs := scoreBaseAbs @@ -301,8 +323,8 @@ private def headScoreBoundsFromQAbsKAbsTimed {seq dModel dHead : Nat} [NeZero se private def checkInductionHeadInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (minActive? : Option Nat) (minLogitDiff? : Option Dyadic) - (minMargin maxEps : Dyadic) : IO UInt32 := do + (minActive? : Option Nat) (minLogitDiff? 
: Option Rat) + (minMargin maxEps : Rat) : IO UInt32 := do match seq with | 0 => IO.eprintln "error: seq must be positive" @@ -390,11 +412,11 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} Sound.headValueBounds inputs qkv.vLo qkv.vHi (some vals, none) else - let task := Task.spawn (fun _ => + let task := if useCommon then - Sound.headValueBoundsCommonDen inputs qkv.vLo qkv.vHi + Sound.headValueBoundsCommonDenTask inputs qkv.vLo qkv.vHi else - Sound.headValueBounds inputs qkv.vLo qkv.vHi) + Sound.headValueBoundsTask inputs qkv.vLo qkv.vHi (none, some task) let activeList := (List.finRange seq).filter (fun q => q ∈ inputs.active) if verboseTiming.isSome then @@ -403,40 +425,21 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} IO.println s!"timing: head score/value bounds spawn {tSpawn1 - tSpawn0} us" flushStdout let skipScoreBounds := (← IO.getEnv "NFP_TIMING_SKIP_SCORE_BOUNDS").isSome - let scoreOpt ← + let scoreTaskOpt ← if skipScoreBounds then IO.println "timing: head score bounds skipped" pure none else IO.println "timing: head score bounds from dotAbs start" flushStdout - let fastMargin := (← IO.getEnv "NFP_TIMING_FAST_MARGIN").isSome - let score ← - if fastMargin then - headScoreBoundsFromQAbsKAbsTimed inputs qkv.qAbs qkv.kAbs dotAbs - else + let exactMargin := (← IO.getEnv "NFP_TIMING_EXACT_MARGIN").isSome + let action := + if exactMargin then headScoreBoundsFromDotAbsTimed inputs dotAbs - IO.println "timing: head score bounds from dotAbs done" - flushStdout - pure (some score) - match scoreOpt with - | none => pure () - | some score => - if verboseTiming.isSome then - timeHeadScoreSampleGap inputs score - if verboseTiming.isSome then - timeHeadScoreMarginList activeList score - if verboseTiming.isSome then - timeHeadScoreFieldForces score - if verboseTiming.isSome then - IO.println "timing: head score bounds force start" - flushStdout - let tScore0 ← monoUsNow - let _ := score.margin - let _ := score.eps - let tScore1 ← 
monoUsNow - IO.println s!"timing: head score bounds force {tScore1 - tScore0} us" - flushStdout + else + headScoreBoundsFromQAbsKAbsTimed inputs qkv.qAbs qkv.kAbs dotAbs + let t ← action.asTask + pure (some t) if verboseTiming.isSome then IO.println "timing: head value parts start" flushStdout @@ -510,17 +513,366 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let tVals1 ← monoUsNow IO.println s!"timing: head value bounds {tVals1 - tVals0} us" flushStdout - let certOpt : - Option { c : Sound.InductionHeadCert seq // Sound.InductionHeadCertSound inputs c } ← - timePure "head: build induction cert" (fun () => + let scoreOpt ← + match scoreTaskOpt with + | none => pure none + | some scoreTask => do + let res ← IO.wait scoreTask + let score ← unwrapTaskResult res + IO.println "timing: head score bounds from dotAbs done" + flushStdout + pure (some score) + match scoreOpt with + | none => pure () + | some score => + if verboseTiming.isSome then + timeHeadScoreSampleGap inputs score + if verboseTiming.isSome then + timeHeadScoreMarginList activeList score + if verboseTiming.isSome then + timeHeadScoreFieldForces score + if verboseTiming.isSome then + IO.println "timing: head score bounds force start" + flushStdout + let tScore0 ← monoUsNow + let _ := score.margin + let _ := score.eps + let tScore1 ← monoUsNow + IO.println s!"timing: head score bounds force {tScore1 - tScore0} us" + flushStdout + let coreStages := (← IO.getEnv "NFP_TIMING_CORE_STAGES").isSome + let coreStagesOnly := (← IO.getEnv "NFP_TIMING_CORE_STAGES_ONLY").isSome + if coreStages then + timeInductionHeadCoreStages inputs + if coreStagesOnly then + return 0 + let breakdown := (← IO.getEnv "NFP_TIMING_BREAKDOWN").isSome + if breakdown then + let lnBounds ← timePureWithHeartbeat "breakdown: ln bounds" (fun () => + Sound.headLnBounds inputs) + IO.println "timing: breakdown ln bounds force start" + flushStdout + let tLn0 ← monoUsNow + for q in List.finRange seq do + for i in List.finRange 
dModel do + let _ := lnBounds.1 q i + let _ := lnBounds.2 q i + pure () + let tLn1 ← monoUsNow + IO.println s!"timing: breakdown ln bounds force {tLn1 - tLn0} us" + flushStdout + let qkv ← timePureWithHeartbeat "breakdown: qkv bounds" (fun () => + Sound.headQKVBounds inputs lnBounds.1 lnBounds.2) + IO.println "timing: breakdown qkv bounds force start" + flushStdout + let tQkv0 ← monoUsNow + for q in List.finRange seq do + for d in List.finRange dHead do + let _ := qkv.qLo q d + let _ := qkv.qHi q d + let _ := qkv.kLo q d + let _ := qkv.kHi q d + let _ := qkv.vLo q d + let _ := qkv.vHi q d + let _ := qkv.qAbs q d + let _ := qkv.kAbs q d + pure () + let tQkv1 ← monoUsNow + IO.println s!"timing: breakdown qkv bounds force {tQkv1 - tQkv0} us" + flushStdout + let dotAbs : Fin seq → Fin seq → Rat := fun q k => + Sound.Linear.dotFin dHead (fun d => qkv.qAbs q d) (fun d => qkv.kAbs k d) + let dotAbsRowTasks : + Array (Task { row : Array Rat // row.size = seq }) ← + timePureWithHeartbeat "breakdown: score dotAbs rows" (fun () => + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩))) + let dotAbsRowDefault : Task { row : Array Rat // row.size = seq } := + Task.spawn (fun _ => ⟨Array.ofFn (fun _ : Fin seq => (0 : Rat)), by simp⟩) + IO.println "timing: breakdown score dotAbs force start" + flushStdout + let tDot0 ← monoUsNow + for q in List.finRange seq do + let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get + let _ := row + pure () + let tDot1 ← monoUsNow + IO.println s!"timing: breakdown score dotAbs force {tDot1 - tDot0} us" + flushStdout + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scaleAbs : Rat := |inputs.scale| + let marginAtRaw : Fin seq → Rat := fun q => + let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get + if q ∈ inputs.active then + let rowArr := row.1 + let prev := inputs.prev q + let dotAbsPrev := rowArr.getD prev.1 0 + if masked q prev then 
+ let scoreLoPrev := inputs.maskValue + let scoreHiAt : Fin seq → Rat := fun k => + if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let v := scoreLoPrev - scoreHiAt k + match acc.1 with + | none => (some v, acc.2) + | some cur => (some (min cur v), acc.2) + let acc := Sound.Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min unmaskedMin maskedGap + | some unmaskedMin, false => unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + let scoreLoPrev := -(scaleAbs * dotAbsPrev) + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let raw := -(dotAbsPrev + rowArr.getD k.1 0) + match acc.1 with + | none => (some raw, acc.2) + | some cur => (some (min cur raw), acc.2) + let acc := Sound.Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap + | some unmaskedMin, false => scaleAbs * unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + (0 : Rat) + let marginAtCached : Fin seq → Rat ← + timePureWithHeartbeat "breakdown: score margin cache" (fun () => + Sound.Bounds.cacheBoundThunk marginAtRaw) + IO.println "timing: breakdown score margin force start" + flushStdout + let tMargin0 ← monoUsNow + for q in List.finRange seq do + let m := marginAtCached q + forceRat m + pure () + let tMargin1 ← monoUsNow + IO.println s!"timing: breakdown score margin force {tMargin1 - tMargin0} us" + flushStdout + let epsAtRaw : Fin seq → Rat := fun q => + let m := marginAtCached q + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + 
m) + let epsAtCached : Fin seq → Rat ← + timePureWithHeartbeat "breakdown: score eps cache" (fun () => + Sound.Bounds.cacheBoundThunk epsAtRaw) + IO.println "timing: breakdown score eps force start" + flushStdout + let tEps0 ← monoUsNow + for q in List.finRange seq do + let e := epsAtCached q + forceRat e + pure () + let tEps1 ← monoUsNow + IO.println s!"timing: breakdown score eps force {tEps1 - tEps0} us" + flushStdout + let valsLo ← timePureWithHeartbeat "breakdown: value valsLo" (fun () => + Sound.headValueValsLo inputs qkv.vLo qkv.vHi) + IO.println "timing: breakdown value valsLo force start" + flushStdout + let tValsLo0 ← monoUsNow + for k in List.finRange seq do + let v := valsLo k + forceRat v + pure () + let tValsLo1 ← monoUsNow + IO.println s!"timing: breakdown value valsLo force {tValsLo1 - tValsLo0} us" + flushStdout + let valsHi ← timePureWithHeartbeat "breakdown: value valsHi" (fun () => + Sound.headValueValsHi inputs qkv.vLo qkv.vHi) + IO.println "timing: breakdown value valsHi force start" + flushStdout + let tValsHi0 ← monoUsNow + for k in List.finRange seq do + let v := valsHi k + forceRat v + pure () + let tValsHi1 ← monoUsNow + IO.println s!"timing: breakdown value valsHi force {tValsHi1 - tValsHi0} us" + flushStdout + let heartbeatMs ← heartbeatMsFromEnv + let taskMin (t1 t2 : Task Rat) : Task Rat := + Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) + let taskMax (t1 t2 : Task Rat) : Task Rat := + Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) + let reduceMinTasksWithProgress (tasks : Array (Task Rat)) : + IO Rat := do + let n := tasks.size + if n = 0 then + pure (0 : Rat) + else + let chunkSize : Nat := 16 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := tasks.getD start defaultTask + if stop ≤ start + 
1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => taskMin acc (tasks.getD i defaultTask)) init) + if heartbeatMs ≠ 0 then + let mut finished := 0 + let mut remaining := chunkTasks.size + while finished < remaining do + IO.sleep heartbeatMs + let mut count := 0 + for t in chunkTasks do + if (← IO.hasFinished t) then + count := count + 1 + finished := count + remaining := chunkTasks.size + if finished < remaining then + IO.println s!"timing: breakdown value lo progress {finished}/{remaining}" + flushStdout + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + pure ((rest.foldl (fun acc i => taskMin acc (chunkTasks.getD i defaultTask)) init).get) + let reduceMaxTasksWithProgress (tasks : Array (Task Rat)) : + IO Rat := do + let n := tasks.size + if n = 0 then + pure (0 : Rat) + else + let chunkSize : Nat := 16 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := tasks.getD start defaultTask + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => taskMax acc (tasks.getD i defaultTask)) init) + if heartbeatMs ≠ 0 then + let mut finished := 0 + let mut remaining := chunkTasks.size + while finished < remaining do + IO.sleep heartbeatMs + let mut count := 0 + for t in chunkTasks do + if (← IO.hasFinished t) then + count := count + 1 + finished := count + remaining := chunkTasks.size + if finished < remaining then + IO.println s!"timing: breakdown value hi progress {finished}/{remaining}" + flushStdout + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + pure ((rest.foldl (fun 
acc i => taskMax acc (chunkTasks.getD i defaultTask)) init).get) + if (← IO.getEnv "NFP_TIMING_TASK_PROGRESS").isSome then + let tasksLo := + (List.finRange seq).map (fun k => Task.spawn (fun _ => valsLo k)) + let tasksHi := + (List.finRange seq).map (fun k => Task.spawn (fun _ => valsHi k)) + let _ ← timePureWithHeartbeat "breakdown: value lo progress" (fun () => + reduceMinTasksWithProgress tasksLo.toArray) + let _ ← timePureWithHeartbeat "breakdown: value hi progress" (fun () => + reduceMaxTasksWithProgress tasksHi.toArray) + else + let loTask := Sound.headValueLoTask valsLo + let hiTask := Sound.headValueHiTask valsHi + let heartbeatMs ← heartbeatMsFromEnv + let tLo0 ← monoUsNow + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished loTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished loTask) + if !finished then + let now ← monoUsNow + IO.println s!"timing: breakdown: value lo running {now - tLo0} us" + flushStdout + let lo := loTask.get + let tLo1 ← monoUsNow + IO.println s!"timing: breakdown: value lo {tLo1 - tLo0} us" + flushStdout + let tHi0 ← monoUsNow + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished hiTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished hiTask) + if !finished then + let now ← monoUsNow + IO.println s!"timing: breakdown: value hi running {now - tHi0} us" + flushStdout + let hi := hiTask.get + let tHi1 ← monoUsNow + IO.println s!"timing: breakdown: value hi {tHi1 - tHi0} us" + flushStdout + let _ := lo + let _ := hi + if (← IO.getEnv "NFP_TIMING_SEQ_REDUCE").isSome then + let loSeq ← timePureWithHeartbeat "breakdown: value lo seq" (fun () => + match List.finRange seq with + | [] => (0 : Rat) + | k :: ks => + let init := valsLo k + ks.foldl (fun acc k => min acc (valsLo k)) init) + let hiSeq ← timePureWithHeartbeat "breakdown: value hi seq" (fun () => + match List.finRange seq with + | [] => (0 : Rat) + | k :: ks => + let init := valsHi k + ks.foldl (fun 
acc k => max acc (valsHi k)) init) + let _ := loSeq + let _ := hiSeq + let tCert0 ← monoUsNow + let certTask : + Task + (Option { c : Sound.InductionHeadCert seq // + Sound.InductionHeadCertSound inputs c }) := + Task.spawn (prio := Task.Priority.dedicated) (fun _ => match Sound.buildInductionCertFromHead? inputs with | none => none | some ⟨cert, hcert⟩ => let _ := cert.active.card some ⟨cert, hcert⟩) + let heartbeatMs ← heartbeatMsFromEnv + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished certTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished certTask) + if !finished then + let now ← monoUsNow + IO.println s!"timing: head build induction cert running {now - tCert0} us" + flushStdout + let certOpt ← IO.wait certTask + let tCert1 ← monoUsNow + logTiming s!"done: head build induction cert {tCert1 - tCert0} us" + IO.println s!"timing: head build induction cert {tCert1 - tCert0} us" IO.println "timing: head build induction cert returned" flushStdout - logTiming "done: head build induction cert" match certOpt with | none => IO.eprintln "error: head inputs rejected" @@ -539,13 +891,13 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} return 2 if cert.margin < minMargin then IO.eprintln - s!"error: margin {dyadicToString cert.margin} \ - below minimum {dyadicToString minMargin}" + s!"error: margin {ratToString cert.margin} \ + below minimum {ratToString minMargin}" return 2 if maxEps < cert.eps then IO.eprintln - s!"error: eps {dyadicToString cert.eps} \ - above maximum {dyadicToString maxEps}" + s!"error: eps {ratToString cert.eps} \ + above maximum {ratToString maxEps}" return 2 IO.println "timing: head tol start" flushStdout @@ -562,13 +914,13 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let effectiveMinLogitDiff := match minLogitDiff? with | some v => some v - | none => some (0 : Dyadic) + | none => some (0 : Rat) match logitDiffLB? 
with | none => IO.eprintln "error: empty active set for logit-diff bound" return 2 | some logitDiffLB => - let violation? : Option Dyadic := + let violation? : Option Rat := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -579,20 +931,20 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} match violation? with | some minLogitDiff => IO.eprintln - s!"error: logitDiffLB {dyadicToString logitDiffLB} \ - below minimum {dyadicToString minLogitDiff}" + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" return 2 | none => IO.println s!"ok: induction bound certified \ (seq={seq}, active={activeCount}, \ - tol={dyadicToString tol}, logitDiffLB={dyadicToString logitDiffLB})" + tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB})" return 0 private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (minActive? : Option Nat) (minLogitDiff? : Option Dyadic) - (minMargin maxEps : Dyadic) : IO UInt32 := do + (minActive? : Option Nat) (minLogitDiff? : Option Rat) + (minMargin maxEps : Rat) : IO UInt32 := do match seq with | 0 => IO.eprintln "error: seq must be positive" @@ -621,35 +973,35 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} return 2 if cert.margin < minMargin then IO.eprintln - s!"error: margin {dyadicToString cert.margin} \ - below minimum {dyadicToString minMargin}" + s!"error: margin {ratToString cert.margin} \ + below minimum {ratToString minMargin}" return 2 if maxEps < cert.eps then IO.eprintln - s!"error: eps {dyadicToString cert.eps} above maximum {dyadicToString maxEps}" + s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" return 2 match minLogitDiff? 
with | some minLogitDiff => if logitDiffLB < minLogitDiff then IO.eprintln - s!"error: logitDiffLB {dyadicToString logitDiffLB} \ - below minimum {dyadicToString minLogitDiff}" + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" return 2 | none => pure () let tol := cert.eps * (cert.values.hi - cert.values.lo) IO.println s!"ok: nonvacuous induction bound certified \ (seq={seq}, active={activeCount}, \ - tol={dyadicToString tol}, logitDiffLB={dyadicToString logitDiffLB})" + tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB})" return 0 /-- Build and check induction certificates from exact head inputs. -/ def runInductionCertifyHead (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? - let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -661,8 +1013,8 @@ def runInductionCertifyHead (inputsPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) let parsedInputs ← timePhase "load head inputs" <| loadInductionHeadInputs inputsPath match parsedInputs with @@ -676,9 +1028,9 @@ def runInductionCertifyHead (inputsPath : System.FilePath) def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? - let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -690,8 +1042,8 @@ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) let parsedInputs ← timePhase "load head inputs" <| loadInductionHeadInputs inputsPath match parsedInputs with @@ -706,9 +1058,9 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? 
- let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? - let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -720,8 +1072,8 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) logTiming "start: read model file" IO.println "timing: read model file start" flushStdout @@ -748,9 +1100,9 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? - let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -762,8 +1114,8 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) logTiming "start: read model file" IO.println "timing: read model file start" flushStdout diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index f80d53c..266ffe0 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -2,12 +2,14 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Field.Basic import Nfp.Core.Basic import Mathlib.Data.Finset.Lattice.Fold import Nfp.Circuit.Cert.ResidualInterval import Nfp.Circuit.Cert.SoftmaxMargin import Nfp.Circuit.Cert.ValueRange import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.Cache import Nfp.Sound.Bounds.LayerNorm import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Induction.CoreDefs @@ -37,37 +39,37 @@ variable {seq : Nat} def buildSoftmaxMarginCert? 
[NeZero seq] (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (scores : Fin seq → Fin seq → Dyadic) - (weights : Fin seq → Fin seq → Dyadic) : + (scores : Fin seq → Fin seq → Rat) + (weights : Fin seq → Fin seq → Rat) : Option {c : SoftmaxMarginCert seq // checkSoftmaxMarginCert c = true} := by classical let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (prev q) - let epsAt : Fin seq → Dyadic := fun q => + let epsAt : Fin seq → Rat := fun q => let other := otherKeys q let maxOther := if h : other.Nonempty then other.sup' h (fun k => weights q k) else - (0 : Dyadic) - let deficit := (1 : Dyadic) - weights q (prev q) + (0 : Rat) + let deficit := (1 : Rat) - weights q (prev q) max maxOther deficit - let marginAt : Fin seq → Dyadic := fun q => + let marginAt : Fin seq → Rat := fun q => let other := otherKeys q if h : other.Nonempty then other.inf' h (fun k => scores q (prev q) - scores q k) else - (0 : Dyadic) + (0 : Rat) let eps := if h : active.Nonempty then active.sup' h epsAt else - (0 : Dyadic) + (0 : Rat) let margin := if h : active.Nonempty then active.inf' h marginAt else - (0 : Dyadic) + (0 : Rat) let cert : SoftmaxMarginCert seq := { eps := eps margin := margin @@ -82,7 +84,7 @@ def buildSoftmaxMarginCert? [NeZero seq] /-- Build and certify a value-range certificate from exact values. -/ def buildValueRangeCert? [NeZero seq] - (vals : Fin seq → Dyadic) + (vals : Fin seq → Rat) (direction : Option DirectionSpec) : Option {c : ValueRangeCert seq // checkValueRangeCert c = true} := by classical @@ -115,111 +117,120 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} · by_cases hactive : inputs.active.Nonempty · let lnBounds := Bounds.cacheBoundPair2 (fun q => Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Dyadic := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Dyadic := lnBounds.2 - let qLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let qHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let kLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let kHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let vLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let vHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let qAbs := Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) - let kAbs := Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) - let kAbsMax : Fin dHead → Dyadic := fun d => + let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Bounds.cacheBoundTask (fun q => + Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr : Array Rat := + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr[q.1]'(by + have hsize : lnAbsMaxArr.size = seq := by + simp [lnAbsMaxArr] + simp [hsize]) + let lnAbsMaxMax : Rat := let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d) 
+ univ.sup' hnonempty (fun q => lnAbsMax q) + let qAbsRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun d : Fin dHead => + Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + |inputs.bq d|), + by simp⟩)) + let qAbsBaseArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qAbsRowTasks[q.1]'(by + have hsize : qAbsRowTasks.size = seq := by + simp [qAbsRowTasks] + simp [hsize])).get) + let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := qAbsBaseArr[q.1]'(by + have hsize : qAbsBaseArr.size = seq := by + simp [qAbsBaseArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let kAbsRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun d : Fin dHead => + Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + |inputs.bk d|), + by simp⟩)) + let kAbsBaseArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kAbsRowTasks[q.1]'(by + have hsize : kAbsRowTasks.size = seq := by + simp [kAbsRowTasks] + simp [hsize])).get) + let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := kAbsBaseArr[q.1]'(by + have hsize : kAbsBaseArr.size = seq := by + simp [kAbsBaseArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => -qAbsBase q d + let qHi : Fin seq → Fin dHead → Rat := fun q d => qAbsBase q d + let kLo : Fin seq → Fin dHead → Rat := fun q d => -kAbsBase q d + let kHi : Fin seq → Fin dHead → Rat := fun q d => kAbsBase q d + let qAbs : Fin seq → Fin dHead → Rat := qAbsBase + let kAbs : Fin seq → Fin dHead → Rat := kAbsBase let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let dotAbs : Fin seq → Fin seq → Dyadic := fun q k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k 
d) - let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + let dotAbs := + Bounds.cacheBound2Task (fun q k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k - let scoreAbs : Fin seq → Fin seq → Dyadic := fun q k => - if masked q k then |inputs.maskValue| else scoreBaseAbs q k - let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + let scoreLo : Fin seq → Fin seq → Rat := fun q k => if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + let scoreHi : Fin seq → Fin seq → Rat := fun q k => if masked q k then inputs.maskValue else scoreBaseAbs q k - let dotAbsUpper : Fin seq → Dyadic := fun q => - Linear.dotFin dHead (fun d => qAbs q d) kAbsMax - let scoreHiUpper : Fin seq → Dyadic := fun q => - max inputs.maskValue (|inputs.scale| * dotAbsUpper q) - let fastGap : Fin seq → Dyadic := fun q => - let prev := inputs.prev q - scoreLo q prev - scoreHiUpper q + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let maskedKeys : Fin seq → Finset (Fin seq) := fun q => - if inputs.maskCausal = true then - (otherKeys q).filter (fun k => q < k) - else - (∅ : Finset (Fin seq)) - let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => - (otherKeys q) \ (maskedKeys q) - let maskedGap : Fin seq → Dyadic := fun q => - scoreLo q (inputs.prev q) - inputs.maskValue - let marginAtRaw : Fin seq → Dyadic := fun q => - let fast := fastGap q - if fast < 0 then - let other := unmaskedKeys q - let maskedSet := maskedKeys q - if hunmasked : other.Nonempty then - let unmaskedMin := other.inf' hunmasked (fun k => - scoreLo q (inputs.prev q) - scoreHi q k) - if maskedSet.Nonempty then - min unmaskedMin (maskedGap q) - else - unmaskedMin + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ 
inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreLoPrev q - scoreHi q k) else - if maskedSet.Nonempty then - maskedGap q - else - (0 : Dyadic) - else - fast - let marginAt : Fin seq → Dyadic := fun q => - if q ∈ inputs.active then - marginAtRaw q + (0 : Rat) else - (0 : Dyadic) - let epsAt : Fin seq → Dyadic := fun q => + (0 : Rat) + let epsAt : Fin seq → Rat := fun q => if marginAt q < 0 then - (1 : Dyadic) + (1 : Rat) else - dyadicDivUp (seq - 1) (1 + marginAt q) - let margin : Dyadic := inputs.active.inf' hactive marginAt - let eps : Dyadic := + ratDivUp (seq - 1) (1 + marginAt q) + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := if margin < 0 then - (1 : Dyadic) + (1 : Rat) else - dyadicDivUp (seq - 1) (1 + margin) + ratDivUp (seq - 1) (1 + margin) let dirHeadVec := dirHeadVecOfInputs inputs - let dirHead : Fin dHead → Dyadic := fun d => dirHeadVec.get d - let valsLo := - Bounds.cacheBound (fun k => - Bounds.dotIntervalLowerCachedDyadic dirHead (vLo k) (vHi k)) - let valsHi := - Bounds.cacheBound (fun k => - Bounds.dotIntervalUpperCachedDyadic dirHead (vLo k) (vHi k)) + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin dModel → Rat := + Bounds.cacheBoundTask (fun j => + Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv d) + let valsAbsBase : Rat := + Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax + let valsLoBase := bDir - valsAbsBase + let valsHiBase := bDir + valsAbsBase + let valsLo : Fin seq → Rat := fun _ => valsLoBase + let valsHi : Fin seq → Rat := fun _ => valsHiBase let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := by simp [univ] let lo := univ.inf' hnonempty valsLo @@ -259,115 +270,124 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} · by_cases 
hactive : inputs.active.Nonempty · let lnBounds := Bounds.cacheBoundPair2 (fun q => Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Dyadic := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Dyadic := lnBounds.2 - let qLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let qHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let kLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let kHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let vLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerCachedDyadic (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let vHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperCachedDyadic (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let qAbs := - Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) - let kAbs := - Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) - let kAbsMax : Fin dHead → Dyadic := fun d => + let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Bounds.cacheBoundTask (fun q => + Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr : Array Rat := + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr[q.1]'(by + have hsize : lnAbsMaxArr.size = seq := by + simp [lnAbsMaxArr] + simp [hsize]) + let lnAbsMaxMax : Rat := let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d) + univ.sup' hnonempty (fun q => lnAbsMax 
q) + let qAbsRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun d : Fin dHead => + Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + |inputs.bq d|), + by simp⟩)) + let qAbsBaseArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qAbsRowTasks[q.1]'(by + have hsize : qAbsRowTasks.size = seq := by + simp [qAbsRowTasks] + simp [hsize])).get) + let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := qAbsBaseArr[q.1]'(by + have hsize : qAbsBaseArr.size = seq := by + simp [qAbsBaseArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let kAbsRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun d : Fin dHead => + Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + |inputs.bk d|), + by simp⟩)) + let kAbsBaseArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kAbsRowTasks[q.1]'(by + have hsize : kAbsRowTasks.size = seq := by + simp [kAbsRowTasks] + simp [hsize])).get) + let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := kAbsBaseArr[q.1]'(by + have hsize : kAbsBaseArr.size = seq := by + simp [kAbsBaseArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => -qAbsBase q d + let qHi : Fin seq → Fin dHead → Rat := fun q d => qAbsBase q d + let kLo : Fin seq → Fin dHead → Rat := fun q d => -kAbsBase q d + let kHi : Fin seq → Fin dHead → Rat := fun q d => kAbsBase q d + let qAbs : Fin seq → Fin dHead → Rat := qAbsBase + let kAbs : Fin seq → Fin dHead → Rat := kAbsBase let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let dotAbs : Fin seq → Fin seq → Dyadic := fun q k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d) - let scoreBaseAbs : Fin seq → Fin seq 
→ Dyadic := fun q k => + let dotAbs := + Bounds.cacheBound2Task (fun q k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k - let scoreAbs : Fin seq → Fin seq → Dyadic := fun q k => + let scoreAbs : Fin seq → Fin seq → Rat := fun q k => if masked q k then |inputs.maskValue| else scoreBaseAbs q k - let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + let scoreLo : Fin seq → Fin seq → Rat := fun q k => if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + let scoreHi : Fin seq → Fin seq → Rat := fun q k => if masked q k then inputs.maskValue else scoreBaseAbs q k - let dotAbsUpper : Fin seq → Dyadic := fun q => - Linear.dotFin dHead (fun d => qAbs q d) kAbsMax - let scoreHiUpper : Fin seq → Dyadic := fun q => - max inputs.maskValue (|inputs.scale| * dotAbsUpper q) - let fastGap : Fin seq → Dyadic := fun q => - let prev := inputs.prev q - scoreLo q prev - scoreHiUpper q + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let maskedKeys : Fin seq → Finset (Fin seq) := fun q => - if inputs.maskCausal = true then - (otherKeys q).filter (fun k => q < k) - else - (∅ : Finset (Fin seq)) - let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => - (otherKeys q) \ (maskedKeys q) - let maskedGap : Fin seq → Dyadic := fun q => - scoreLo q (inputs.prev q) - inputs.maskValue - let marginAtRaw : Fin seq → Dyadic := fun q => - let fast := fastGap q - if fast < 0 then - let other := unmaskedKeys q - let maskedSet := maskedKeys q - if hunmasked : other.Nonempty then - let unmaskedMin := other.inf' hunmasked (fun k => - scoreLo q (inputs.prev q) - scoreHi q k) - if maskedSet.Nonempty then - min unmaskedMin (maskedGap q) - else - unmaskedMin + let marginAt : Fin seq → Rat := fun q => + 
if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreLoPrev q - scoreHi q k) else - if maskedSet.Nonempty then - maskedGap q - else - (0 : Dyadic) + (0 : Rat) else - fast - let marginAt : Fin seq → Dyadic := fun q => - if q ∈ inputs.active then - marginAtRaw q - else - (0 : Dyadic) - let epsAt : Fin seq → Dyadic := fun q => + (0 : Rat) + let epsAt : Fin seq → Rat := fun q => if marginAt q < 0 then - (1 : Dyadic) + (1 : Rat) + else + ratDivUp (seq - 1) (1 + marginAt q) + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt else - dyadicDivUp (seq - 1) (1 + marginAt q) - let margin : Dyadic := inputs.active.inf' hactive marginAt - let eps : Dyadic := + (0 : Rat) + let eps : Rat := if margin < 0 then - (1 : Dyadic) + (1 : Rat) else - dyadicDivUp (seq - 1) (1 + margin) + ratDivUp (seq - 1) (1 + margin) have hseq : (1 : Nat) ≤ seq := Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) let dirHeadVec := dirHeadVecOfInputs inputs - let dirHead : Fin dHead → Dyadic := fun d => dirHeadVec.get d - let valsLo := - Bounds.cacheBound (fun k => - Bounds.dotIntervalLowerCachedDyadic dirHead (vLo k) (vHi k)) - let valsHi := - Bounds.cacheBound (fun k => - Bounds.dotIntervalUpperCachedDyadic dirHead (vLo k) (vHi k)) + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin dModel → Rat := + Bounds.cacheBoundTask (fun j => + Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv d) + let valsAbsBase : Rat := + Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax + let valsLoBase := bDir - valsAbsBase + let valsHiBase := bDir + valsAbsBase + let valsLo : Fin seq → Rat := fun _ => valsLoBase + let valsHi : Fin seq → Rat := fun _ => valsHiBase let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := by simp [univ] let lo := univ.inf' hnonempty valsLo @@ -387,11 +407,14 @@ 
theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} values := valCert } have hcore' : some cert = some c := by simpa - [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, lnBounds, lnLo, - lnHi, qLo, qHi, kLo, kHi, vLo, vHi, qAbs, kAbs, kAbsMax, masked, dotAbs, - scoreBaseAbs, scoreAbs, scoreLo, scoreHi, dotAbsUpper, scoreHiUpper, fastGap, - otherKeys, maskedKeys, unmaskedKeys, maskedGap, marginAt, marginAtRaw, epsAt, - margin, eps, dirHeadVec, dirHead, valsLo, valsHi, univ, lo, hi, valCert, cert] + [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, lnBounds, + lnLo, lnHi, lnAbsMaxTask, lnAbsMaxArr, lnAbsMax, lnAbsMaxMax, + qAbsRowTasks, qAbsBaseArr, qAbsBase, kAbsRowTasks, kAbsBaseArr, kAbsBase, + qLo, qHi, kLo, kHi, qAbs, kAbs, masked, dotAbs, scoreBaseAbs, scoreLo, + scoreHi, scoreLoPrev, otherKeys, marginAt, epsAt, margin, eps, dirHeadVec, + dirHead, wvDir, bDir, valsAbsBase, valsLoBase, valsHiBase, valsLo, + valsHi, univ, lo, hi, valCert, cert, Task.spawn, Bounds.cacheBoundTask_apply, + Bounds.cacheBound2Task_apply, Array.getElem_ofFn] using hcore have hc : c = cert := by simpa using (Option.some.inj hcore').symm @@ -407,66 +430,157 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simpa [lnBounds, lnLo, lnHi, lnRealOfInputs, Bounds.cacheBoundPair2_apply_left, Bounds.cacheBoundPair2_apply_right] using hln i + have hln_abs : ∀ q j, |lnRealOfInputs inputs q j| ≤ (lnAbsMax q : Real) := by + intro q j + have hln := hln_bounds q + have h := + Bounds.abs_le_intervalAbsBound_real (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) (hlo := fun j => (hln j).1) + (hhi := fun j => (hln j).2) j + simpa [lnAbsMax, lnAbsMaxArr, lnAbsMaxTask, Bounds.cacheBoundTask_apply, + Array.getElem_ofFn] using h + have hln_abs_max : ∀ q, lnAbsMax q ≤ lnAbsMaxMax := by + intro q + have hnonempty : (Finset.univ : Finset (Fin seq)).Nonempty := + Finset.univ_nonempty + have hmem : q ∈ 
(Finset.univ : Finset (Fin seq)) := by simp + simpa [lnAbsMaxMax] using + (Finset.le_sup'_iff (s := (Finset.univ : Finset (Fin seq))) + (H := hnonempty) (f := fun q => lnAbsMax q) (a := lnAbsMax q)).2 + ⟨q, hmem, le_rfl⟩ + have hdot_abs_bound : + ∀ (v : Fin dModel → Rat) (q : Fin seq), + |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ + (Bounds.dotIntervalAbsBound v (lnLo q) (lnHi q) : Real) := by + intro v q + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => + (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => + (hln j).2 + simpa using + (Bounds.abs_dotProduct_le_dotIntervalAbsBound_real + (v := v) (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) hlo hhi) + have hdot_abs_bound_sum : + ∀ (v : Fin dModel → Rat) (q : Fin seq), + |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ + (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by + intro v q + have hsum : + |∑ j, (v j : Real) * lnRealOfInputs inputs q j| ≤ + ∑ j, |(v j : Real) * lnRealOfInputs inputs q j| := by + simpa [dotProduct] using + (Finset.abs_sum_le_sum_abs (s := (Finset.univ : Finset (Fin dModel))) + (f := fun j => (v j : Real) * lnRealOfInputs inputs q j)) + have hterm : + ∀ j, |(v j : Real) * lnRealOfInputs inputs q j| ≤ + (|v j| : Real) * (lnAbsMax q : Real) := by + intro j + have hln := hln_abs q j + have hnonneg : 0 ≤ (|v j| : Real) := by + exact abs_nonneg _ + calc + |(v j : Real) * lnRealOfInputs inputs q j| = + |(v j : Real)| * |lnRealOfInputs inputs q j| := by + simp [abs_mul] + _ ≤ (|v j| : Real) * (lnAbsMax q : Real) := + mul_le_mul_of_nonneg_left hln hnonneg + have hsum_le : + ∑ j, |(v j : Real) * lnRealOfInputs inputs q j| ≤ + ∑ j, (|v j| : Real) * (lnAbsMax q : Real) := by + refine Finset.sum_le_sum ?_ + intro j _ + exact hterm j + have hsum_mul : + ∑ j, (|v j| : Real) * (lnAbsMax q : Real) = + (∑ j, (|v j| : Real)) * (lnAbsMax q : Real) 
:= by + symm + simpa using + (Finset.sum_mul (s := (Finset.univ : Finset (Fin dModel))) + (f := fun j => (|v j| : Real)) (a := (lnAbsMax q : Real))) + have hsum_cast : + (Linear.sumFin dModel (fun j => |v j|) : Real) = ∑ j, (|v j| : Real) := by + simpa [ratToReal] using + (Linear.ratToReal_sumFin (f := fun j => |v j|)) + have hsum_eq : + ∑ j, (|v j| : Real) * (lnAbsMax q : Real) = + (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by + calc + ∑ j, (|v j| : Real) * (lnAbsMax q : Real) + = (∑ j, (|v j| : Real)) * (lnAbsMax q : Real) := hsum_mul + _ = (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by + simp [hsum_cast] + have hfinal := hsum.trans (hsum_le.trans_eq hsum_eq) + simpa [dotProduct] using hfinal have hq_bounds : ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ qRealOfInputs inputs q d ≤ (qHi q d : Real) := by intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wq j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wq j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hlow' := add_le_add_right hlow (inputs.bq d : Real) - have hhigh' := add_le_add_right hhigh (inputs.bq d : Real) + have hdot := hdot_abs_bound (fun j => inputs.wq j d) q + have hq_abs : + |qRealOfInputs inputs q d| ≤ (qAbsBase q d : Real) := by + have hsum : + |qRealOfInputs inputs q d| ≤ + (Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) : + Real) + |(inputs.bq d : Real)| := by + calc + |qRealOfInputs inputs q d| + = |dotProduct (fun j => (inputs.wq j d : Real)) + (lnRealOfInputs inputs q) + (inputs.bq d : Real)| := by + simp [qRealOfInputs] + _ ≤ |dotProduct (fun 
j => (inputs.wq j d : Real)) + (lnRealOfInputs inputs q)| + |(inputs.bq d : Real)| := by + exact + (abs_add_le (a := dotProduct (fun j => (inputs.wq j d : Real)) + (lnRealOfInputs inputs q)) (b := (inputs.bq d : Real))) + _ ≤ (Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) : + Real) + |(inputs.bq d : Real)| := by + exact add_le_add hdot (le_rfl) + have hsum' : + |qRealOfInputs inputs q d| ≤ (qAbsBase q d : Real) := by + simpa [qAbsBase, qAbsBaseArr, qAbsRowTasks, lnLo, lnHi, Task.spawn, + Bounds.dotIntervalAbsBound, ratToReal_add, ratToReal_abs] + using hsum + exact hsum' + have hq_bounds := (abs_le).1 hq_abs constructor - · simpa [qLo, qRealOfInputs, Bounds.cacheBound2_apply, - Bounds.dotIntervalLowerCachedRat_eq] using hlow' - · simpa [qHi, qRealOfInputs, Bounds.cacheBound2_apply, - Bounds.dotIntervalUpperCachedRat_eq] using hhigh' + · simpa [qLo] using hq_bounds.1 + · simpa [qHi] using hq_bounds.2 have hk_bounds : ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ kRealOfInputs inputs q d ≤ (kHi q d : Real) := by intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wk j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wk j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hlow' := add_le_add_right hlow (inputs.bk d : Real) - have hhigh' := add_le_add_right hhigh (inputs.bk d : Real) - constructor - · simpa [kLo, kRealOfInputs, Bounds.cacheBound2_apply, - Bounds.dotIntervalLowerCachedRat_eq] using hlow' - · simpa [kHi, kRealOfInputs, Bounds.cacheBound2_apply, - Bounds.dotIntervalUpperCachedRat_eq] using hhigh' - have hv_bounds : - ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ 
- vRealOfInputs inputs q d ≤ (vHi q d : Real) := by - intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hlow' := add_le_add_right hlow (inputs.bv d : Real) - have hhigh' := add_le_add_right hhigh (inputs.bv d : Real) + have hdot := hdot_abs_bound (fun j => inputs.wk j d) q + have hk_abs : + |kRealOfInputs inputs q d| ≤ (kAbsBase q d : Real) := by + have hsum : + |kRealOfInputs inputs q d| ≤ + (Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) : + Real) + |(inputs.bk d : Real)| := by + calc + |kRealOfInputs inputs q d| + = |dotProduct (fun j => (inputs.wk j d : Real)) + (lnRealOfInputs inputs q) + (inputs.bk d : Real)| := by + simp [kRealOfInputs] + _ ≤ |dotProduct (fun j => (inputs.wk j d : Real)) + (lnRealOfInputs inputs q)| + |(inputs.bk d : Real)| := by + exact + (abs_add_le (a := dotProduct (fun j => (inputs.wk j d : Real)) + (lnRealOfInputs inputs q)) (b := (inputs.bk d : Real))) + _ ≤ (Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) : + Real) + |(inputs.bk d : Real)| := by + exact add_le_add hdot (le_rfl) + have hsum' : + |kRealOfInputs inputs q d| ≤ (kAbsBase q d : Real) := by + simpa [kAbsBase, kAbsBaseArr, kAbsRowTasks, lnLo, lnHi, Task.spawn, + Bounds.dotIntervalAbsBound, ratToReal_add, ratToReal_abs] + using hsum + exact hsum' + have hk_bounds := (abs_le).1 hk_abs constructor - · simpa [vLo, vRealOfInputs, Bounds.cacheBound2_apply, - Bounds.dotIntervalLowerCachedRat_eq] using hlow' - · simpa [vHi, vRealOfInputs, Bounds.cacheBound2_apply, - 
Bounds.dotIntervalUpperCachedRat_eq] using hhigh' + · simpa [kLo] using hk_bounds.1 + · simpa [kHi] using hk_bounds.2 have hscore_bounds : ∀ q k, (scoreLo q k : Real) ≤ scoresRealOfInputs inputs q k ∧ scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by @@ -478,13 +592,21 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hq_abs : ∀ d, |qRealOfInputs inputs q d| ≤ (qAbs q d : Real) := by intro d have hq := hq_bounds q d - have h := abs_le_max_abs_abs_of_interval_real hq.1 hq.2 - simpa [qAbs, Bounds.cacheBound2_apply] using h + have hq' : + -(qAbsBase q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qAbsBase q d : Real) := by + simpa [qLo, qHi] using hq + have h := (abs_le).2 hq' + simpa [qAbs, qAbsBase] using h have hk_abs : ∀ d, |kRealOfInputs inputs k d| ≤ (kAbs k d : Real) := by intro d have hk := hk_bounds k d - have h := abs_le_max_abs_abs_of_interval_real hk.1 hk.2 - simpa [kAbs, Bounds.cacheBound2_apply] using h + have hk' : + -(kAbsBase k d : Real) ≤ kRealOfInputs inputs k d ∧ + kRealOfInputs inputs k d ≤ (kAbsBase k d : Real) := by + simpa [kLo, kHi] using hk + have h := (abs_le).2 hk' + simpa [kAbs, kAbsBase] using h have hdot_abs : |dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d)| ≤ @@ -503,11 +625,23 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hq := hq_abs d have hk := hk_abs d have hqnonneg : 0 ≤ (qAbs q d : Real) := by + have hdot_nonneg : + 0 ≤ Bounds.dotIntervalAbsBound + (fun j => inputs.wq j d) (lnLo q) (lnHi q) := by + have hleft : + 0 ≤ |Bounds.dotIntervalLower (fun j => inputs.wq j d) + (lnLo q) (lnHi q)| := by + exact abs_nonneg _ + exact le_trans hleft (le_max_left _ _) + have hbq_nonneg : 0 ≤ |inputs.bq d| := abs_nonneg _ + have hsum_nonneg : + 0 ≤ Bounds.dotIntervalAbsBound + (fun j => inputs.wq j d) (lnLo q) (lnHi q) + |inputs.bq d| := by + exact add_nonneg hdot_nonneg hbq_nonneg have hqnonneg' : 0 
≤ qAbs q d := by - have hmax : |qLo q d| ≤ qAbs q d := by - simp [qAbs, Bounds.cacheBound2_apply] - exact le_trans (abs_nonneg _) hmax - exact dyadicToReal_nonneg_of_nonneg hqnonneg' + simpa [qAbs, qAbsBase, qAbsBaseArr, qAbsRowTasks, lnLo, lnHi, + Task.spawn, Bounds.dotIntervalAbsBound] using hsum_nonneg + exact ratToReal_nonneg_of_nonneg hqnonneg' calc |qRealOfInputs inputs q d * kRealOfInputs inputs k d| = |qRealOfInputs inputs q d| * |kRealOfInputs inputs k d| := by @@ -524,21 +658,28 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (dotAbs q k : Real) = ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by have hsum : - ((∑ d, qAbs q d * kAbs k d : Dyadic) : Real) = - ∑ d, ((qAbs q d * kAbs k d : Dyadic) : Real) := - Linear.dyadicToReal_sum_univ (f := fun d => qAbs q d * kAbs k d) + ((∑ d, qAbs q d * kAbs k d : Rat) : Real) = + ∑ d, ((qAbs q d * kAbs k d : Rat) : Real) := by + have h := Linear.ratToReal_sum_univ (f := fun d => qAbs q d * kAbs k d) + dsimp [ratToReal] at h + exact h have hsum' : - ∑ d, ((qAbs q d * kAbs k d : Dyadic) : Real) = + ∑ d, ((qAbs q d * kAbs k d : Rat) : Real) = ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by refine Finset.sum_congr rfl ?_ intro d _ - simp [dyadicToReal_mul] + simp have hfinal := hsum.trans hsum' - simpa [dotAbs, Linear.dotFin_eq_dotProduct, dotProduct] using hfinal + calc + (dotAbs q k : Real) + = ((∑ d, qAbs q d * kAbs k d : Rat) : Real) := by + simp [dotAbs, Bounds.cacheBound2Task_apply, + Linear.dotFin_eq_dotProduct, dotProduct] + _ = ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := hfinal have hfinal := hsum.trans (hsum_le.trans_eq hcast.symm) simpa [dotProduct] using hfinal have hscale_abs : 0 ≤ (|inputs.scale| : Real) := by - exact abs_nonneg (dyadicToReal inputs.scale) + exact abs_nonneg (ratToReal inputs.scale) have hbase_abs : |base| ≤ (scoreBaseAbs q k : Real) := by have hdot_abs' := hdot_abs @@ -590,164 +731,92 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel 
dHead : Nat} · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal] using hscore_bounds.2 let scoresReal := scoresRealOfInputs inputs - have hdotAbs_le : ∀ q k, dotAbs q k ≤ dotAbsUpper q := by - intro q k - classical - have hnonneg : ∀ d, 0 ≤ qAbs q d := by - intro d - have h0 : 0 ≤ |qLo q d| := abs_nonneg _ - have hle : |qLo q d| ≤ qAbs q d := by - simp [qAbs, Bounds.cacheBound2_apply] - exact le_trans h0 hle - have hterm : ∀ d, qAbs q d * kAbs k d ≤ qAbs q d * kAbsMax d := by - intro d - have hmem : k ∈ (Finset.univ : Finset (Fin seq)) := by simp - have hkabs : kAbs k d ≤ kAbsMax d := by - simpa [kAbsMax] using - (Finset.le_sup' (s := (Finset.univ : Finset (Fin seq))) - (f := fun k => kAbs k d) (b := k) hmem) - exact mul_le_mul_of_nonneg_left hkabs (hnonneg d) - have hsum : (∑ d, qAbs q d * kAbs k d) ≤ ∑ d, qAbs q d * kAbsMax d := by - refine Finset.sum_le_sum ?_ - intro d _ - exact hterm d - simpa [dotAbs, dotAbsUpper, Linear.dotFin_eq_dotProduct, dotProduct] using hsum - have hscoreHi_le : ∀ q k, scoreHi q k ≤ scoreHiUpper q := by - intro q k - by_cases hmask : masked q k - · simp [scoreHi, scoreHiUpper, hmask] - · have hdot := hdotAbs_le q k - have hmul : |inputs.scale| * dotAbs q k ≤ |inputs.scale| * dotAbsUpper q := by - exact mul_le_mul_of_nonneg_left hdot (abs_nonneg _) - calc - scoreHi q k = |inputs.scale| * dotAbs q k := by - simp [scoreHi, scoreBaseAbs, hmask] - _ ≤ |inputs.scale| * dotAbsUpper q := hmul - _ ≤ max inputs.maskValue (|inputs.scale| * dotAbsUpper q) := by - exact le_max_right _ _ + have hmarginAt_le : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + marginAt q ≤ scoreLoPrev q - scoreHi q k := by + intro q hq k hk + have hmem : k ∈ otherKeys q := by + simp [otherKeys, hk] + have hnonempty : (otherKeys q).Nonempty := ⟨k, hmem⟩ + have hle : + (otherKeys q).inf' hnonempty (fun k => scoreLoPrev q - scoreHi q k) ≤ + scoreLoPrev q - scoreHi q k := by + exact + (Finset.inf'_le_iff (s := otherKeys q) (H := hnonempty) + (f := fun k => 
scoreLoPrev q - scoreHi q k) + (a := scoreLoPrev q - scoreHi q k)).2 + ⟨k, hmem, le_rfl⟩ + simpa [marginAt, hq, hnonempty] using hle let weights : Fin seq → Fin seq → Real := fun q k => Circuit.softmax (scoresReal q) k let others : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + have hscore_margin_real_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + have hmargin_le : marginAt q ≤ scoreLoPrev q - scoreHi q k := + hmarginAt_le q hq k hk + have hmargin_le_real : + (marginAt q : Real) ≤ (scoreLoPrev q : Real) - (scoreHi q k : Real) := + by + simpa [ratToReal_sub] using (ratToReal_le_of_le hmargin_le) + have hscore_hi : scoresReal q k ≤ (scoreHi q k : Real) := + (hscore_bounds q k).2 + have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by + have hprev_bounds := hscore_bounds q (inputs.prev q) + simpa [scoreLoPrev] using hprev_bounds.1 + have hscore_diff : scoresReal q k - (scoreHi q k : Real) ≤ 0 := by + have h := sub_le_sub_right hscore_hi (scoreHi q k : Real) + simpa using h + have hsum_le' : + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ + (scoreLoPrev q : Real) := by + have hsub : + (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ + (scoreLoPrev q : Real) - scoresReal q k := + sub_le_sub_left hscore_hi (scoreLoPrev q : Real) + have hsum_le'' := add_le_add_left hsub (scoresReal q k) + have hsum_le''' : + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ + (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by + simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' + calc + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k + ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := hsum_le''' + _ = (scoreLoPrev q : Real) := by + simp [sub_add_cancel] + have hgap : + scoresReal q k + (marginAt q : Real) ≤ (scoreLoPrev q : Real) := 
by + have hstep := add_le_add_left hmargin_le_real (scoresReal q k) + have hstep' : + scoresReal q k + (marginAt q : Real) ≤ + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by + simpa [add_comm, add_left_comm, add_assoc] using hstep + exact hstep'.trans hsum_le' + exact hgap.trans hscore_prev have hscore_margin_real : ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by intro q hq k hk have hmargin_le : margin ≤ marginAt q := by - have hle : margin ≤ inputs.active.inf' hactive marginAt := by - simp [margin] - have hle_all := - (Finset.le_inf'_iff (s := inputs.active) (H := hactive) (f := marginAt) - (a := margin)).1 hle - exact hle_all q hq - have hgap_le : - marginAt q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by - by_cases hfast : fastGap q < 0 - · by_cases hmask : k ∈ maskedKeys q - · have hmask_nonempty : (maskedKeys q).Nonempty := ⟨k, hmask⟩ - have hmargin_eq : marginAt q = marginAtRaw q := by - simp [marginAt, hq] - have hraw_le : marginAtRaw q ≤ maskedGap q := by - by_cases hunmasked : (unmaskedKeys q).Nonempty - · have hraw_eq : - marginAtRaw q = - let unmaskedMin := (unmaskedKeys q).inf' hunmasked - (fun k => scoreLo q (inputs.prev q) - scoreHi q k) - min unmaskedMin (maskedGap q) := by - simp [marginAtRaw, hfast, hunmasked, hmask_nonempty] - simp [hraw_eq] - · have hraw_eq : marginAtRaw q = maskedGap q := by - simp [marginAtRaw, hfast, hunmasked, hmask_nonempty] - simp [hraw_eq] - have hcausal : inputs.maskCausal = true := by - by_contra hcausal - simp [maskedKeys, hcausal] at hmask - have hmem : - k ∈ (otherKeys q).filter (fun k => q < k) := by - simpa [maskedKeys, hcausal] using hmask - have hlt : q < k := (Finset.mem_filter.mp hmem).2 - have hmask_prop : masked q k := ⟨hcausal, hlt⟩ - have hmask_score : scoreHi q k = inputs.maskValue := by - simp [scoreHi, hmask_prop] - have hgap : marginAt q ≤ scoreLo q (inputs.prev q) - inputs.maskValue := by - simpa 
[hmargin_eq] using hraw_le - simpa [maskedGap, hmask_score] using hgap - · have hmem : k ∈ unmaskedKeys q := by - have hother_mem : k ∈ otherKeys q := by simp [otherKeys, hk] - simp [unmaskedKeys, hother_mem, hmask] - have hunmasked : (unmaskedKeys q).Nonempty := ⟨k, hmem⟩ - have hmargin_eq : marginAt q = marginAtRaw q := by - simp [marginAt, hq] - have hraw_le : marginAtRaw q ≤ - (unmaskedKeys q).inf' hunmasked - (fun k => scoreLo q (inputs.prev q) - scoreHi q k) := by - let unmaskedMin := - (unmaskedKeys q).inf' hunmasked - (fun k => scoreLo q (inputs.prev q) - scoreHi q k) - by_cases hmask_nonempty : (maskedKeys q).Nonempty - · have hraw_eq : marginAtRaw q = min unmaskedMin (maskedGap q) := by - simp [marginAtRaw, hfast, hunmasked, hmask_nonempty, unmaskedMin, - unmaskedKeys, maskedKeys] - have hmin_le : marginAtRaw q ≤ unmaskedMin := by - rw [hraw_eq] - exact min_le_left _ _ - exact hmin_le - · simp [marginAtRaw, hfast, hunmasked, hmask_nonempty, unmaskedKeys, - maskedKeys] - have hle_all := - (Finset.le_inf'_iff (s := unmaskedKeys q) (H := hunmasked) - (f := fun k => scoreLo q (inputs.prev q) - scoreHi q k) - (a := marginAtRaw q)).1 hraw_le - have hle := hle_all k hmem - simpa [hmargin_eq] using hle - · have hgap_fast : fastGap q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by - have hle_score : scoreHi q k ≤ scoreHiUpper q := hscoreHi_le q k - have hle_sub : - scoreLo q (inputs.prev q) - scoreHiUpper q ≤ - scoreLo q (inputs.prev q) - scoreHi q k := - sub_le_sub_left hle_score (scoreLo q (inputs.prev q)) - simpa [fastGap] using hle_sub - have hmargin_eq : marginAt q = fastGap q := by - simp [marginAt, marginAtRaw, hq, hfast] - simpa [hmargin_eq] using hgap_fast - have hgap : margin ≤ scoreLo q (inputs.prev q) - scoreHi q k := - le_trans hmargin_le hgap_le - have hgap_real : (margin : Real) ≤ - (scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real) := by - have hgap_real' : - (margin : Real) ≤ ((scoreLo q (inputs.prev q) - scoreHi q k : Dyadic) : Real) := - 
dyadicToReal_le_of_le hgap - simpa [dyadicToReal_sub] using hgap_real' - have hk_bounds := hscore_bounds q k - have hprev_bounds := hscore_bounds q (inputs.prev q) - have h1 : - scoresReal q k + (margin : Real) ≤ - scoresReal q k + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by - exact (add_le_add_iff_left (scoresReal q k)).2 hgap_real - have h2 : - scoresReal q k + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ - (scoreLo q (inputs.prev q) : Real) := by - have hscore_le' : - scoresReal q k + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by - exact (add_le_add_iff_right - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real))).2 hk_bounds.2 - calc - scoresReal q k + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) ≤ - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by - exact hscore_le' - _ = (scoreLo q (inputs.prev q) : Real) := by - exact add_sub_cancel (scoreHi q k : Real) (scoreLo q (inputs.prev q) : Real) - have h3 : - scoresReal q k + (margin : Real) ≤ (scoreLo q (inputs.prev q) : Real) := - h1.trans h2 - exact h3.trans hprev_bounds.1 + have hmem : q ∈ inputs.active := hq + have hnonempty : inputs.active.Nonempty := hactive + have hle := + (Finset.inf'_le_iff (s := inputs.active) (H := hnonempty) + (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ + simpa [margin, hnonempty] using hle + have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := + ratToReal_le_of_le hmargin_le + have hscore := hscore_margin_real_at q hq k hk + have hscore' : + (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by + simpa [add_comm, add_left_comm, add_assoc] using hscore + have hstep := add_le_add_left hmargin_le_real (scoresReal q k) + have hstep' : + scoresReal q k + (margin : Real) ≤ (marginAt q : Real) + scoresReal q k := by + simpa [add_comm, 
add_left_comm, add_assoc] using hstep + exact hstep'.trans hscore' have hsoftmax_bounds : Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by @@ -792,7 +861,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simpa [heps] using hsum_le' · have hnonneg : 0 ≤ margin := le_of_not_gt hneg have hnonneg_real : 0 ≤ (margin : Real) := by - exact dyadicToReal_nonneg_of_nonneg hnonneg + exact ratToReal_nonneg_of_nonneg hnonneg have hbound : ∀ k ∈ others q, weights q k ≤ (1 + (margin : Real))⁻¹ := by @@ -820,21 +889,19 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hsum_le''' := hsum_le'' simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' exact hsum_le''' + have hpos : (0 : Rat) < 1 + margin := by + have hone : (0 : Rat) < 1 := by + exact zero_lt_one + have hle : (1 : Rat) ≤ 1 + margin := by + exact le_add_of_nonneg_right hnonneg + exact lt_of_lt_of_le hone hle have hden : (1 + margin) ≠ 0 := by - have hnonneg_real : 0 ≤ (margin : Real) := - dyadicToReal_nonneg_of_nonneg hnonneg - have hpos_real : (0 : Real) < 1 + (margin : Real) := by - linarith - have hpos_real' : dyadicToReal 0 < dyadicToReal (1 + margin) := by - simpa [dyadicToReal_add] using hpos_real - have hpos : (0 : Dyadic) < 1 + margin := - (dyadicToReal_lt_iff).1 hpos_real' exact ne_of_gt hpos + have hrat' := ratDivUp_ge_real (seq - 1) (1 + margin) hden have heps : (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : Real) := by - have hrat' := dyadicDivUp_ge_real (seq - 1) (1 + margin) hden - simpa [eps, hneg, dyadicToReal, Dyadic.toRat_add, Dyadic.toRat_natCast, - Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' + simpa [eps, hneg, ratToReal, Rat.cast_div, Rat.cast_add, + Rat.cast_natCast, div_eq_mul_inv] using hrat' exact le_trans hsum_le' heps have hsum_eq : weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = 1 
:= by @@ -853,10 +920,16 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} weights q (inputs.prev q) + ∑ k ∈ others q, weights q k ≤ weights q (inputs.prev q) + (eps : Real) := by have hsum_le'' := add_le_add_left hsum_others_le (weights q (inputs.prev q)) - simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' + have hsum_le''' := hsum_le'' + rw [add_comm (∑ k ∈ others q, weights q k) + (weights q (inputs.prev q))] at hsum_le''' + rw [add_comm (eps : Real) (weights q (inputs.prev q))] at hsum_le''' + exact hsum_le''' have hprev : 1 ≤ weights q (inputs.prev q) + (eps : Real) := by - simpa [hsum_eq] using hsum_le' + have hsum_le'' := hsum_le' + rw [hsum_eq] at hsum_le'' + exact hsum_le'' exact hprev · intro q hq k hk have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (eps : Real) := by @@ -884,7 +957,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simpa [heps] using hsum_le' · have hnonneg : 0 ≤ margin := le_of_not_gt hneg have hnonneg_real : 0 ≤ (margin : Real) := by - exact dyadicToReal_nonneg_of_nonneg hnonneg + exact ratToReal_nonneg_of_nonneg hnonneg have hbound : ∀ j ∈ others q, weights q j ≤ (1 + (margin : Real))⁻¹ := by @@ -912,21 +985,19 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hsum_le''' := hsum_le'' simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' exact hsum_le''' + have hpos : (0 : Rat) < 1 + margin := by + have hone : (0 : Rat) < 1 := by + exact zero_lt_one + have hle : (1 : Rat) ≤ 1 + margin := by + exact le_add_of_nonneg_right hnonneg + exact lt_of_lt_of_le hone hle have hden : (1 + margin) ≠ 0 := by - have hnonneg_real : 0 ≤ (margin : Real) := - dyadicToReal_nonneg_of_nonneg hnonneg - have hpos_real : (0 : Real) < 1 + (margin : Real) := by - linarith - have hpos_real' : dyadicToReal 0 < dyadicToReal (1 + margin) := by - simpa [dyadicToReal_add] using hpos_real - have hpos : (0 : Dyadic) < 1 + margin := - 
(dyadicToReal_lt_iff).1 hpos_real' exact ne_of_gt hpos + have hrat' := ratDivUp_ge_real (seq - 1) (1 + margin) hden have heps : (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : Real) := by - have hrat' := dyadicDivUp_ge_real (seq - 1) (1 + margin) hden - simpa [eps, hneg, dyadicToReal, Dyadic.toRat_add, Dyadic.toRat_natCast, - Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' + simpa [eps, hneg, ratToReal, Rat.cast_div, Rat.cast_add, + Rat.cast_natCast, div_eq_mul_inv] using hrat' exact le_trans hsum_le' heps have hk' : k ∈ others q := by simp [others, hk] @@ -940,126 +1011,14 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have h := Finset.single_le_sum hnonneg hk' simpa using h exact hle.trans hsum_others_le - have hscore_margin_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hgap_le : - marginAt q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by - by_cases hfast : fastGap q < 0 - · by_cases hmask : k ∈ maskedKeys q - · have hmask_nonempty : (maskedKeys q).Nonempty := ⟨k, hmask⟩ - have hmargin_eq : marginAt q = marginAtRaw q := by - simp [marginAt, hq] - have hraw_le : marginAtRaw q ≤ maskedGap q := by - by_cases hunmasked : (unmaskedKeys q).Nonempty - · have hraw_eq : - marginAtRaw q = - let unmaskedMin := (unmaskedKeys q).inf' hunmasked - (fun k => scoreLo q (inputs.prev q) - scoreHi q k) - min unmaskedMin (maskedGap q) := by - simp [marginAtRaw, hfast, hunmasked, hmask_nonempty] - simp [hraw_eq] - · have hraw_eq : marginAtRaw q = maskedGap q := by - simp [marginAtRaw, hfast, hunmasked, hmask_nonempty] - simp [hraw_eq] - have hcausal : inputs.maskCausal = true := by - by_contra hcausal - simp [maskedKeys, hcausal] at hmask - have hmem : - k ∈ (otherKeys q).filter (fun k => q < k) := by - simpa [maskedKeys, hcausal] using hmask - have hlt : q < k := (Finset.mem_filter.mp hmem).2 - have 
hmask_prop : masked q k := ⟨hcausal, hlt⟩ - have hmask_score : scoreHi q k = inputs.maskValue := by - simp [scoreHi, hmask_prop] - have hgap : marginAt q ≤ scoreLo q (inputs.prev q) - inputs.maskValue := by - simpa [hmargin_eq] using hraw_le - simpa [maskedGap, hmask_score] using hgap - · have hmem : k ∈ unmaskedKeys q := by - have hother_mem : k ∈ otherKeys q := by simp [otherKeys, hk] - simp [unmaskedKeys, hother_mem, hmask] - have hunmasked : (unmaskedKeys q).Nonempty := ⟨k, hmem⟩ - have hmargin_eq : marginAt q = marginAtRaw q := by - simp [marginAt, hq] - have hraw_le : marginAtRaw q ≤ - (unmaskedKeys q).inf' hunmasked - (fun k => scoreLo q (inputs.prev q) - scoreHi q k) := by - let unmaskedMin := - (unmaskedKeys q).inf' hunmasked - (fun k => scoreLo q (inputs.prev q) - scoreHi q k) - by_cases hmask_nonempty : (maskedKeys q).Nonempty - · have hraw_eq : marginAtRaw q = min unmaskedMin (maskedGap q) := by - simp [marginAtRaw, hfast, hunmasked, hmask_nonempty, unmaskedMin, - unmaskedKeys, maskedKeys] - rw [hraw_eq] - exact min_le_left _ _ - · simp [marginAtRaw, hfast, hunmasked, hmask_nonempty, unmaskedKeys, - maskedKeys] - have hle_all := - (Finset.le_inf'_iff (s := unmaskedKeys q) (H := hunmasked) - (f := fun k => scoreLo q (inputs.prev q) - scoreHi q k) - (a := marginAtRaw q)).1 hraw_le - have hle := hle_all k hmem - simpa [hmargin_eq] using hle - · have hgap_fast : fastGap q ≤ scoreLo q (inputs.prev q) - scoreHi q k := by - have hle_score : scoreHi q k ≤ scoreHiUpper q := hscoreHi_le q k - have hle_sub : - scoreLo q (inputs.prev q) - scoreHiUpper q ≤ - scoreLo q (inputs.prev q) - scoreHi q k := - sub_le_sub_left hle_score (scoreLo q (inputs.prev q)) - simpa [fastGap] using hle_sub - have hmargin_eq : marginAt q = fastGap q := by - simp [marginAt, marginAtRaw, hq, hfast] - simpa [hmargin_eq] using hgap_fast - have hgap_real : - (marginAt q : Real) ≤ - (scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real) := by - have hgap_real' : - (marginAt q : Real) ≤ - 
((scoreLo q (inputs.prev q) - scoreHi q k : Dyadic) : Real) := - dyadicToReal_le_of_le hgap_le - simpa [dyadicToReal_sub] using hgap_real' - have hk_bounds := hscore_bounds q k - have hprev_bounds := hscore_bounds q (inputs.prev q) - have h1 : - scoresReal q k + (marginAt q : Real) ≤ - (scoreHi q k : Real) + (marginAt q : Real) := by - have h1' := add_le_add_right hk_bounds.2 (marginAt q : Real) - simpa [scoresReal] using h1' - have h2 : - (scoreHi q k : Real) + (marginAt q : Real) ≤ - (scoreLo q (inputs.prev q) : Real) := by - have hgap_real' : - (scoreHi q k : Real) + (marginAt q : Real) ≤ - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) := by - exact add_le_add_right hgap_real (scoreHi q k : Real) - have hgap_real'' : - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) = - (scoreLo q (inputs.prev q) : Real) := by - calc - (scoreHi q k : Real) + - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) = - ((scoreLo q (inputs.prev q) : Real) - (scoreHi q k : Real)) + - (scoreHi q k : Real) := by - exact add_comm _ _ - _ = (scoreLo q (inputs.prev q) : Real) := by - exact sub_add_cancel (scoreLo q (inputs.prev q) : Real) (scoreHi q k : Real) - exact hgap_real'.trans (le_of_eq hgap_real'') - have h3 : - scoresReal q k + (marginAt q : Real) ≤ - (scoreLo q (inputs.prev q) : Real) := h1.trans h2 - exact h3.trans hprev_bounds.1 have hepsAt : ∀ q, epsAt q = if marginAt q < 0 then - (1 : Dyadic) + (1 : Rat) else - dyadicDivUp (seq - 1) (1 + marginAt q) := by + ratDivUp (seq - 1) (1 + marginAt q) := by intro q - simp [epsAt] + rfl have oneHot_bounds_at : ∀ q, q ∈ inputs.active → Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) @@ -1076,6 +1035,183 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (hseq := hseq) (hscore_margin_real_at := hscore_margin_real_at) q hq + have hdir_wv : + ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : 
Real) := by + intro j + have hsum : + ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) = + ∑ d, ((dirHead d * inputs.wv j d : Rat) : Real) := by + have h := + Linear.ratToReal_sum_univ (f := fun d => dirHead d * inputs.wv j d) + dsimp [ratToReal] at h + exact h + have hsum' : + ∑ d, ((dirHead d * inputs.wv j d : Rat) : Real) = + ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by + refine Finset.sum_congr rfl ?_ + intro d _ + simp + have hfinal := hsum.trans hsum' + calc + (wvDir j : Real) + = ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) := by + simp [wvDir, Bounds.cacheBoundTask_apply, Linear.dotFin_eq_dotProduct, + dotProduct] + _ = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := hfinal + have hdir_bv : + (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by + have hsum : + ((∑ d, dirHead d * inputs.bv d : Rat) : Real) = + ∑ d, ((dirHead d * inputs.bv d : Rat) : Real) := by + have h := + Linear.ratToReal_sum_univ (f := fun d => dirHead d * inputs.bv d) + dsimp [ratToReal] at h + exact h + have hsum' : + ∑ d, ((dirHead d * inputs.bv d : Rat) : Real) = + ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by + refine Finset.sum_congr rfl ?_ + intro d _ + simp + have hfinal := hsum.trans hsum' + calc + (bDir : Real) + = ((∑ d, dirHead d * inputs.bv d : Rat) : Real) := by + simp [bDir, Linear.dotFin_eq_dotProduct, dotProduct] + _ = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := hfinal + have hvals_eq : + ∀ k, + valsRealOfInputs inputs k = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + (bDir : Real) := by + intro k + classical + have hdot_add : + dotProduct (fun d => (dirHead d : Real)) + (fun d => + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k) + + (inputs.bv d : Real)) = + dotProduct (fun d => (dirHead d : Real)) + (fun d => dotProduct (fun j => (inputs.wv j d : Real)) + (lnRealOfInputs inputs k)) + + dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) := by + simp 
[dotProduct, mul_add, Finset.sum_add_distrib] + have hdot_wv : + dotProduct (fun d => (dirHead d : Real)) + (fun d => dotProduct (fun j => (inputs.wv j d : Real)) + (lnRealOfInputs inputs k)) = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) := by + classical + calc + dotProduct (fun d => (dirHead d : Real)) + (fun d => dotProduct (fun j => (inputs.wv j d : Real)) + (lnRealOfInputs inputs k)) = + ∑ d, (dirHead d : Real) * ∑ j, + (inputs.wv j d : Real) * lnRealOfInputs inputs k j := by + simp [dotProduct] + _ = ∑ d, ∑ j, + (dirHead d : Real) * + ((inputs.wv j d : Real) * lnRealOfInputs inputs k j) := by + simp [Finset.mul_sum] + _ = ∑ j, ∑ d, + (dirHead d : Real) * + ((inputs.wv j d : Real) * lnRealOfInputs inputs k j) := by + simpa using + (Finset.sum_comm (s := (Finset.univ : Finset (Fin dHead))) + (t := (Finset.univ : Finset (Fin dModel))) + (f := fun d j => + (dirHead d : Real) * + ((inputs.wv j d : Real) * lnRealOfInputs inputs k j))) + _ = ∑ j, (∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) * + lnRealOfInputs inputs k j := by + refine Finset.sum_congr rfl ?_ + intro j _ + have hsum : + (∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) * + lnRealOfInputs inputs k j = + ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) * + lnRealOfInputs inputs k j := by + simp [Finset.sum_mul, mul_assoc] + simpa [mul_assoc] using hsum.symm + _ = ∑ j, (wvDir j : Real) * lnRealOfInputs inputs k j := by + refine Finset.sum_congr rfl ?_ + intro j _ + simp [hdir_wv j] + _ = dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) := by + simp [dotProduct] + calc + valsRealOfInputs inputs k = + dotProduct (fun d => (dirHead d : Real)) + (fun d => + dotProduct (fun j => (inputs.wv j d : Real)) + (lnRealOfInputs inputs k) + + (inputs.bv d : Real)) := by + simp [valsRealOfInputs, vRealOfInputs, dirHeadVec, dirHead] + _ = + dotProduct (fun d => (dirHead d : Real)) + (fun d => dotProduct (fun j => (inputs.wv j d : Real)) + (lnRealOfInputs inputs k)) + + 
dotProduct (fun d => (dirHead d : Real)) + (fun d => (inputs.bv d : Real)) := hdot_add + _ = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + dotProduct (fun d => (dirHead d : Real)) + (fun d => (inputs.bv d : Real)) := by + simp [hdot_wv] + _ = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + (bDir : Real) := by + have hb : + dotProduct (fun d => (dirHead d : Real)) + (fun d => (inputs.bv d : Real)) = + (bDir : Real) := by + have hb : (dotProduct (fun d => (dirHead d : Real)) + (fun d => (inputs.bv d : Real)) : Real) = (bDir : Real) := by + calc + dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) + = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by + simp [dotProduct] + _ = (bDir : Real) := hdir_bv.symm + exact hb + simp [hb] + have hvals_bounds_at : + ∀ k, + (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k ∧ + valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by + intro k + have hdot_abs : + |dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k)| ≤ + (valsAbsBase : Real) := by + have hdot := hdot_abs_bound_sum (fun j => wvDir j) k + have hln_max_real : + (lnAbsMax k : Real) ≤ (lnAbsMaxMax : Real) := + ratToReal_le_of_le (hln_abs_max k) + have hsum_nonneg : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Real) := by + have hsum_nonneg' : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Rat) := by + have hsum_nonneg'' : 0 ≤ ∑ j, |wvDir j| := by + refine Finset.sum_nonneg ?_ + intro j _ + exact abs_nonneg _ + simpa [Linear.sumFin_eq_sum_univ] using hsum_nonneg'' + exact ratToReal_nonneg_of_nonneg hsum_nonneg' + have hmul : + (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMax k : Real) ≤ + (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMaxMax : Real) := + mul_le_mul_of_nonneg_left hln_max_real hsum_nonneg + have hfinal := hdot.trans hmul + simpa [valsAbsBase, ratToReal_mul] using hfinal + have hdot_bounds := (abs_le).1 hdot_abs + have hlow' := 
add_le_add_right hdot_bounds.1 (bDir : Real) + have hhigh' := add_le_add_right hdot_bounds.2 (bDir : Real) + have hlow : + (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by + simpa [valCert, valsLo, valsLoBase, valsAbsBase, hvals_eq k, ratToReal_sub, + sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using hlow' + have hhigh : + valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by + simpa [valCert, valsHi, valsHiBase, valsAbsBase, hvals_eq k, ratToReal_add, + add_comm, add_left_comm, add_assoc] using hhigh' + exact ⟨hlow, hhigh⟩ have hvals_bounds : ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by refine @@ -1086,92 +1222,52 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ have hmem0 : k0 ∈ univ := hk0 have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by - have hloDyadic : valCert.lo ≤ valCert.valsLo k0 := by + have hloRat : valCert.lo ≤ valCert.valsLo k0 := by change lo ≤ valsLo k0 dsimp [lo] refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) (f := valsLo) (a := valsLo k0)).2 ?_ refine ⟨k0, hmem0, ?_⟩ exact le_rfl - exact dyadicToReal_le_of_le hloDyadic + exact ratToReal_le_of_le hloRat have hvals : (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by - have hv := hv_bounds k0 - have hlo' : ∀ d, (vLo k0 d : Real) ≤ vRealOfInputs inputs k0 d := fun d => - (hv d).1 - have hhi' : ∀ d, vRealOfInputs inputs k0 d ≤ (vHi k0 d : Real) := fun d => - (hv d).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := dirHead) - (lo := vLo k0) (hi := vHi k0) - (x := fun d => vRealOfInputs inputs k0 d) hlo' hhi' - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := dirHead) - (lo := vLo k0) (hi := vHi k0) - (x := fun d => vRealOfInputs inputs k0 d) hlo' hhi' - have hlow' : - (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 := by - simpa [valsLo, 
valCert, dirHead, valsRealOfInputs, - Bounds.cacheBound_apply, Bounds.dotIntervalLowerCachedRat_eq] using hlow - have hhigh' : - valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by - simpa [valsHi, valCert, dirHead, valsRealOfInputs, - Bounds.cacheBound_apply, Bounds.dotIntervalUpperCachedRat_eq] using hhigh - exact ⟨hlow', hhigh'⟩ + exact hvals_bounds_at k0 have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by - have hhiDyadic : valCert.valsHi k0 ≤ valCert.hi := by + have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by change valsHi k0 ≤ hi dsimp [hi] refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) (f := valsHi) (a := valsHi k0)).2 ?_ exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ - exact dyadicToReal_le_of_le hhiDyadic + exact ratToReal_le_of_le hhiRat have hreal : (valCert.lo : Real) ≤ (valCert.hi : Real) := le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) - exact (dyadicToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal + exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal · intro k have hmem : k ∈ univ := by simp [univ] - have hloDyadic : valCert.lo ≤ valCert.valsLo k := by + have hloRat : valCert.lo ≤ valCert.valsLo k := by change lo ≤ valsLo k dsimp [lo] refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) (f := valsLo) (a := valsLo k)).2 ?_ refine ⟨k, hmem, ?_⟩ exact le_rfl - exact dyadicToReal_le_of_le hloDyadic + exact ratToReal_le_of_le hloRat · intro k - have hv := hv_bounds k - have hlo' : ∀ d, (vLo k d : Real) ≤ vRealOfInputs inputs k d := fun d => (hv d).1 - have hhi' : ∀ d, vRealOfInputs inputs k d ≤ (vHi k d : Real) := fun d => (hv d).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := dirHead) - (lo := vLo k) (hi := vHi k) - (x := fun d => vRealOfInputs inputs k d) hlo' hhi' - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := dirHead) - (lo := vLo k) (hi := vHi k) - (x := fun d => vRealOfInputs inputs k d) hlo' hhi' - have hlow' : - (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := 
by - simpa [valsLo, valCert, dirHead, valsRealOfInputs, - Bounds.cacheBound_apply, Bounds.dotIntervalLowerCachedRat_eq] using hlow - have hhigh' : - valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - simpa [valsHi, valCert, dirHead, valsRealOfInputs, - Bounds.cacheBound_apply, Bounds.dotIntervalUpperCachedRat_eq] using hhigh - exact ⟨hlow', hhigh'⟩ + exact hvals_bounds_at k · intro k have hmem : k ∈ univ := by simp [univ] - have hhiDyadic : valCert.valsHi k ≤ valCert.hi := by + have hhiRat : valCert.valsHi k ≤ valCert.hi := by change valsHi k ≤ hi dsimp [hi] refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) (f := valsHi) (a := valsHi k)).2 ?_ refine ⟨k, hmem, ?_⟩ exact le_rfl - exact dyadicToReal_le_of_le hhiDyadic + exact ratToReal_le_of_le hhiRat exact { softmax_bounds := hsoftmax_bounds oneHot_bounds_at := oneHot_bounds_at From 44440784b4cc2f6e9b13a1b9503b8d3b55bef123 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 8 Jan 2026 13:11:21 +0100 Subject: [PATCH 111/244] Replace Dyadic usage with Rat --- Nfp/Circuit/Cert/DownstreamLinear.lean | 6 +- Nfp/Circuit/Cert/LogitDiff.lean | 20 +- Nfp/Circuit/Cert/ResidualBound.lean | 2 +- Nfp/Circuit/Cert/ResidualInterval.lean | 4 +- Nfp/Circuit/Cert/SoftmaxMargin.lean | 12 +- Nfp/Circuit/Cert/ValueRange.lean | 10 +- Nfp/Core/Basic.lean | 706 +++------------ Nfp/IO.lean | 74 +- Nfp/IO/Bench/InductionCore.lean | 86 +- Nfp/IO/Bench/Rational.lean | 124 +-- Nfp/IO/Derive.lean | 6 +- Nfp/IO/HeadScore.lean | 12 +- Nfp/IO/NfptPure.lean | 68 +- Nfp/IO/Pure/Basic.lean | 8 +- Nfp/IO/Pure/Downstream.lean | 32 +- Nfp/IO/Pure/InductionHead/Bytes.lean | 164 ++-- Nfp/IO/Pure/Residual.lean | 24 +- Nfp/IO/Pure/SoftmaxMargin/Cert.lean | 4 +- Nfp/IO/Pure/SoftmaxMargin/Raw.lean | 8 +- Nfp/IO/Pure/SoftmaxMargin/Shared.lean | 22 +- Nfp/IO/Pure/ValueRange/Cert.lean | 2 +- Nfp/IO/Pure/ValueRange/Raw.lean | 4 +- Nfp/IO/Pure/ValueRange/Shared.lean | 16 +- Nfp/IO/Timing.lean | 16 +- Nfp/IO/Util.lean | 8 +- Nfp/Model/Gpt2.lean | 
68 +- Nfp/Model/InductionHead.lean | 32 +- Nfp/Sound/Bounds/Attention.lean | 222 ++--- Nfp/Sound/Bounds/Cache.lean | 245 +++++ Nfp/Sound/Bounds/Gelu.lean | 22 +- Nfp/Sound/Bounds/LayerNorm.lean | 423 +++++---- Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean | 64 +- Nfp/Sound/Bounds/MatrixNorm.lean | 56 +- Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 204 +++-- Nfp/Sound/Bounds/Mlp.lean | 122 +-- Nfp/Sound/Bounds/Transformer.lean | 92 +- Nfp/Sound/Bounds/Transformer/Embedding.lean | 32 +- Nfp/Sound/Bounds/UnnormRat.lean | 41 +- Nfp/Sound/Gpt2/HeadInputs.lean | 4 +- Nfp/Sound/Induction/CoreDefs.lean | 16 +- Nfp/Sound/Induction/HeadBounds.lean | 888 ++++++++++++++----- Nfp/Sound/Induction/HeadOutput.lean | 80 +- Nfp/Sound/Induction/LogitDiff.lean | 16 +- Nfp/Sound/Induction/OneHot.lean | 46 +- Nfp/Sound/Linear/FinFold.lean | 56 +- 45 files changed, 2190 insertions(+), 1977 deletions(-) create mode 100644 Nfp/Sound/Bounds/Cache.lean diff --git a/Nfp/Circuit/Cert/DownstreamLinear.lean b/Nfp/Circuit/Cert/DownstreamLinear.lean index 41a61a9..8e01f23 100644 --- a/Nfp/Circuit/Cert/DownstreamLinear.lean +++ b/Nfp/Circuit/Cert/DownstreamLinear.lean @@ -18,11 +18,11 @@ namespace Circuit /-- Certificate payload for downstream linear error bounds. -/ structure DownstreamLinearCert where /-- Upper bound on the downstream logit-diff error. -/ - error : Dyadic + error : Rat /-- Operator gain bound used to justify the error. -/ - gain : Dyadic + gain : Rat /-- Input magnitude bound used to justify the error. -/ - inputBound : Dyadic + inputBound : Rat /-- Arithmetic properties enforced by `checkDownstreamLinearCert`. 
-/ structure DownstreamLinearBounds (c : DownstreamLinearCert) : Prop where diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index f51d463..46a3992 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -17,11 +17,11 @@ variable {seq : Nat} /-- Compute a lower bound on the logit-diff contribution over active queries. -/ def logitDiffLowerBound (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (eps lo hi : Dyadic) (vals : Fin seq → Dyadic) : Option Dyadic := by + (eps lo hi : Rat) (vals : Fin seq → Rat) : Option Rat := by classical if h : active.Nonempty then let gap := eps * (hi - lo) - let f : Fin seq → Dyadic := fun q => vals (prev q) - gap + let f : Fin seq → Rat := fun q => vals (prev q) - gap let img := active.image f have himg : img.Nonempty := h.image f exact some (Finset.min' img himg) @@ -31,11 +31,11 @@ def logitDiffLowerBound (active : Finset (Fin seq)) /-- Compute a lower bound on the logit-diff contribution with per-query eps. -/ def logitDiffLowerBoundAt (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (epsAt : Fin seq → Dyadic) (lo hi : Dyadic) (vals : Fin seq → Dyadic) : Option Dyadic := by + (epsAt : Fin seq → Rat) (lo hi : Rat) (vals : Fin seq → Rat) : Option Rat := by classical if h : active.Nonempty then - let gap : Fin seq → Dyadic := fun q => epsAt q * (hi - lo) - let f : Fin seq → Dyadic := fun q => vals (prev q) - gap q + let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap q let img := active.image f have himg : img.Nonempty := h.image f exact some (Finset.min' img himg) @@ -45,7 +45,7 @@ def logitDiffLowerBoundAt (active : Finset (Fin seq)) /-- The computed lower bound is below every active `prev` value minus the tolerance gap. 
-/ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (eps lo hi : Dyadic) (vals : Fin seq → Dyadic) + (eps lo hi : Rat) (vals : Fin seq → Rat) (q : Fin seq) (hq : q ∈ active) : ∀ lb, logitDiffLowerBound active prev eps lo hi vals = some lb → lb ≤ vals (prev q) - eps * (hi - lo) := by @@ -57,7 +57,7 @@ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) (hnonempty.image (fun q => vals (prev q) - eps * (hi - lo))) = lb := by simpa [logitDiffLowerBound, hnonempty] using hbound let gap := eps * (hi - lo) - let f : Fin seq → Dyadic := fun q => vals (prev q) - gap + let f : Fin seq → Rat := fun q => vals (prev q) - gap have hmem : f q ∈ (active.image f) := by refine Finset.mem_image.2 ?_ exact ⟨q, hq, rfl⟩ @@ -70,7 +70,7 @@ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) /-- The per-query lower bound is below every active `prev` value minus the local gap. -/ theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) - (epsAt : Fin seq → Dyadic) (lo hi : Dyadic) (vals : Fin seq → Dyadic) + (epsAt : Fin seq → Rat) (lo hi : Rat) (vals : Fin seq → Rat) (q : Fin seq) (hq : q ∈ active) : ∀ lb, logitDiffLowerBoundAt active prev epsAt lo hi vals = some lb → lb ≤ vals (prev q) - epsAt q * (hi - lo) := by @@ -81,8 +81,8 @@ theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) (active.image (fun q => vals (prev q) - epsAt q * (hi - lo))).min' (hnonempty.image (fun q => vals (prev q) - epsAt q * (hi - lo))) = lb := by simpa [logitDiffLowerBoundAt, hnonempty] using hbound - let gap : Fin seq → Dyadic := fun q => epsAt q * (hi - lo) - let f : Fin seq → Dyadic := fun q => vals (prev q) - gap q + let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap q have hmem : f q ∈ (active.image f) := by refine Finset.mem_image.2 ?_ exact ⟨q, hq, rfl⟩ diff --git a/Nfp/Circuit/Cert/ResidualBound.lean b/Nfp/Circuit/Cert/ResidualBound.lean index 
2a85061..9adc684 100644 --- a/Nfp/Circuit/Cert/ResidualBound.lean +++ b/Nfp/Circuit/Cert/ResidualBound.lean @@ -16,7 +16,7 @@ namespace Circuit /-- Certificate payload for per-coordinate residual absolute bounds. -/ structure ResidualBoundCert (n : Nat) where /-- Absolute bound per coordinate. -/ - bound : Fin n → Dyadic + bound : Fin n → Rat /-- Properties enforced by `checkResidualBoundCert`. -/ structure ResidualBoundBounds {n : Nat} (c : ResidualBoundCert n) : Prop where diff --git a/Nfp/Circuit/Cert/ResidualInterval.lean b/Nfp/Circuit/Cert/ResidualInterval.lean index 9360359..88370cd 100644 --- a/Nfp/Circuit/Cert/ResidualInterval.lean +++ b/Nfp/Circuit/Cert/ResidualInterval.lean @@ -16,9 +16,9 @@ namespace Circuit /-- Certificate payload for per-coordinate residual intervals. -/ structure ResidualIntervalCert (n : Nat) where /-- Lower bound per coordinate. -/ - lo : Fin n → Dyadic + lo : Fin n → Rat /-- Upper bound per coordinate. -/ - hi : Fin n → Dyadic + hi : Fin n → Rat /-- Properties enforced by `checkResidualIntervalCert`. -/ structure ResidualIntervalBounds {n : Nat} (c : ResidualIntervalCert n) : Prop where diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean b/Nfp/Circuit/Cert/SoftmaxMargin.lean index 0d0a1a2..0e9de16 100644 --- a/Nfp/Circuit/Cert/SoftmaxMargin.lean +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -17,20 +17,20 @@ open scoped BigOperators variable {seq : Nat} -/-- Certificate payload for softmax-margin bounds (Dyadic-valued). -/ +/-- Certificate payload for softmax-margin bounds (Rat-valued). -/ structure SoftmaxMarginCert (seq : Nat) where /-- Weight tolerance. -/ - eps : Dyadic + eps : Rat /-- Score margin used to justify weight bounds. -/ - margin : Dyadic + margin : Rat /-- Active queries for which bounds are checked. -/ active : Finset (Fin seq) /-- `prev` selector for induction-style attention. -/ prev : Fin seq → Fin seq /-- Score matrix entries. 
-/ - scores : Fin seq → Fin seq → Dyadic + scores : Fin seq → Fin seq → Rat /-- Attention weight entries. -/ - weights : Fin seq → Fin seq → Dyadic + weights : Fin seq → Fin seq → Rat /-- Boolean checker for softmax-margin certificates. -/ def checkSoftmaxMarginCert [NeZero seq] (c : SoftmaxMarginCert seq) : Bool := @@ -55,7 +55,7 @@ def checkSoftmaxMarginCert [NeZero seq] (c : SoftmaxMarginCert seq) : Bool := /-- `checkSoftmaxMarginCert` is sound for `SoftmaxMarginBoundsOn`. -/ theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : checkSoftmaxMarginCert c = true → - Layers.SoftmaxMarginBoundsOn (Val := Dyadic) c.eps c.margin (fun q => q ∈ c.active) + Layers.SoftmaxMarginBoundsOn (Val := Rat) c.eps c.margin (fun q => q ∈ c.active) c.prev c.scores c.weights := by classical intro hcheck diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean index 7e5f5df..fb70c55 100644 --- a/Nfp/Circuit/Cert/ValueRange.lean +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -24,14 +24,14 @@ structure DirectionSpec where /-- Negative token id for the logit-diff direction. -/ negative : Nat -/-- Certificate payload for value-range bounds (Dyadic-valued). -/ +/-- Certificate payload for value-range bounds (Rat-valued). -/ structure ValueRangeCert (seq : Nat) where /-- Lower bound for values. -/ - lo : Dyadic + lo : Rat /-- Upper bound for values. -/ - hi : Dyadic + hi : Rat /-- Value entries. -/ - vals : Fin seq → Dyadic + vals : Fin seq → Rat /-- Optional logit-diff direction metadata (ignored by the checker). -/ direction : Option DirectionSpec @@ -44,7 +44,7 @@ def checkValueRangeCert [NeZero seq] (c : ValueRangeCert seq) : Bool := /-- `checkValueRangeCert` is sound for `ValueRangeBounds`. 
-/ theorem checkValueRangeCert_sound [NeZero seq] (c : ValueRangeCert seq) : checkValueRangeCert c = true → - Layers.ValueRangeBounds (Val := Dyadic) c.lo c.hi c.vals := by + Layers.ValueRangeBounds (Val := Rat) c.lo c.hi c.vals := by classical intro hcheck have hcheck' : diff --git a/Nfp/Core/Basic.lean b/Nfp/Core/Basic.lean index 7e3fb48..f9589c2 100644 --- a/Nfp/Core/Basic.lean +++ b/Nfp/Core/Basic.lean @@ -1,12 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Mathlib.Data.NNReal.Basic -import Mathlib.Algebra.Order.Floor.Defs import Mathlib.Data.Rat.Cast.Lemmas import Mathlib.Data.Rat.Cast.Order -import Init.Data.Dyadic -import Init.Data.Dyadic.Inv -import Init.Data.Dyadic.Round /-! Basic shared definitions for the NFP rewrite. @@ -17,579 +13,147 @@ namespace Nfp /-- Nonnegative mass used for probabilities and weights. -/ abbrev Mass := NNReal -instance : ToString Dyadic := - ⟨fun x => toString x.toRat⟩ - -/-- Default dyadic precision (binary digits after the point). -/ -def defaultDyadicPrec : Int := 48 - -/-- One ulp at the given dyadic precision. -/ -def dyadicUlp (prec : Int := defaultDyadicPrec) : Dyadic := - Dyadic.ofIntWithPrec 1 prec - -/-- Round a rational down to dyadic precision. -/ -def dyadicOfRatDown (q : Rat) (prec : Int := defaultDyadicPrec) : Dyadic := - Rat.toDyadic q prec - -/-- Round a rational up to dyadic precision. -/ -def dyadicOfRatUp (q : Rat) (prec : Int := defaultDyadicPrec) : Dyadic := - Rat.toDyadic q prec + dyadicUlp prec - -instance : Coe Dyadic Rat := ⟨Dyadic.toRat⟩ - -/-- Real cast of a dyadic value via `Rat`. 
-/ -def dyadicToReal (x : Dyadic) : Real := - (x.toRat : Real) - -instance : Coe Dyadic Real := ⟨dyadicToReal⟩ - -@[simp] theorem dyadicToReal_zero : dyadicToReal 0 = 0 := by - simp [dyadicToReal] - -@[simp] theorem dyadicToReal_one : dyadicToReal 1 = 1 := by - change ((1 : Dyadic).toRat : Real) = 1 - have h : (1 : Dyadic).toRat = (1 : Rat) := Dyadic.toRat_natCast 1 - simp [h] - -@[simp] theorem dyadicToRat_one : (Dyadic.toRat 1 : Rat) = 1 := by - exact Dyadic.toRat_natCast 1 - -@[simp] theorem dyadicOfInt_zero : Dyadic.ofInt 0 = 0 := by - simp [Dyadic.ofInt, Dyadic.ofIntWithPrec] - -@[simp] theorem dyadicOfInt_toRat (i : Int) : (Dyadic.ofInt i).toRat = i := by - change ((i : Dyadic).toRat = i) - exact Dyadic.toRat_intCast (x := i) - -@[simp] theorem dyadicOfInt_succ (n : Nat) : Dyadic.ofInt (n + 1) = Dyadic.ofInt n + 1 := by - apply (Dyadic.toRat_inj).1 - calc - (Dyadic.ofInt (n + 1)).toRat = (n + 1 : Int) := by - simp - _ = (n : Int) + 1 := by - simp - _ = (Dyadic.ofInt n).toRat + 1 := by - simp - _ = (Dyadic.ofInt n + 1).toRat := by - simp [Dyadic.toRat_add] - -theorem dyadicOfRatDown_le (q : Rat) (prec : Int := defaultDyadicPrec) : - (dyadicOfRatDown q prec : Rat) ≤ q := by - simpa [dyadicOfRatDown] using (Rat.toRat_toDyadic_le (x := q) (prec := prec)) - -theorem dyadicOfRatUp_ge (q : Rat) (prec : Int := defaultDyadicPrec) : - q ≤ (dyadicOfRatUp q prec : Rat) := by - have hlt := Rat.lt_toRat_toDyadic_add (x := q) (prec := prec) - exact le_of_lt (by simpa [dyadicOfRatUp, dyadicUlp] using hlt) - -theorem dyadicOfRatDown_le_real (q : Rat) (prec : Int := defaultDyadicPrec) : - (dyadicOfRatDown q prec : Real) ≤ (q : Real) := by - have h := - (Rat.cast_le (K := Real) (p := (dyadicOfRatDown q prec : Rat)) (q := q)).2 - (dyadicOfRatDown_le q prec) - simpa [dyadicToReal] using h - -theorem real_le_dyadicOfRatUp (q : Rat) (prec : Int := defaultDyadicPrec) : - (q : Real) ≤ (dyadicOfRatUp q prec : Real) := by - have h := - (Rat.cast_le (K := Real) (p := q) (q := 
(dyadicOfRatUp q prec : Rat))).2 - (dyadicOfRatUp_ge q prec) - simpa [dyadicToReal] using h - -theorem dyadicOfRatDown_lt_add_ulp (q : Rat) (prec : Int := defaultDyadicPrec) : - q < (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat) := by - have h := Rat.lt_toRat_toDyadic_add (x := q) (prec := prec) - simpa [dyadicOfRatDown, dyadicUlp, Dyadic.toRat_add] using h - -theorem dyadicOfRatDown_sub_ulp_lt (q : Rat) (prec : Int := defaultDyadicPrec) : - q - (dyadicUlp prec : Rat) < (dyadicOfRatDown q prec : Rat) := by - exact (sub_lt_iff_lt_add).2 (dyadicOfRatDown_lt_add_ulp q prec) - -theorem dyadicOfRatUp_le_add_ulp (q : Rat) (prec : Int := defaultDyadicPrec) : - (dyadicOfRatUp q prec : Rat) ≤ q + (dyadicUlp prec : Rat) := by - have hdown : (dyadicOfRatDown q prec : Rat) ≤ q := dyadicOfRatDown_le q prec - have hsum : (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat) ≤ - q + (dyadicUlp prec : Rat) := by - exact Rat.add_le_add_right.2 hdown - simpa [dyadicOfRatUp, dyadicUlp, dyadicOfRatDown, Dyadic.toRat_add] using hsum - -theorem dyadicOfRatDown_lt_add_ulp_real (q : Rat) (prec : Int := defaultDyadicPrec) : - (q : Real) < (dyadicOfRatDown q prec : Real) + (dyadicUlp prec : Real) := by - have h := - (Rat.cast_lt (K := Real) (p := q) - (q := (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat))).2 - (dyadicOfRatDown_lt_add_ulp q prec) - simpa [dyadicToReal] using h - -theorem dyadicOfRatUp_le_add_ulp_real (q : Rat) (prec : Int := defaultDyadicPrec) : - (dyadicOfRatUp q prec : Real) ≤ (q : Real) + (dyadicUlp prec : Real) := by - have h := - (Rat.cast_le (K := Real) - (p := (dyadicOfRatUp q prec : Rat)) - (q := q + (dyadicUlp prec : Rat))).2 - (dyadicOfRatUp_le_add_ulp q prec) - simpa [dyadicToReal] using h - -theorem dyadicOfRatDown_nonneg {q : Rat} (hq : 0 ≤ q) (prec : Int := defaultDyadicPrec) : - 0 ≤ dyadicOfRatDown q prec := by - apply (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicOfRatDown q prec)).1 - have hrat : - (dyadicOfRatDown q prec : Rat) = - (q * 
2 ^ prec).floor / 2 ^ prec := by - simpa [dyadicOfRatDown] using (Rat.toRat_toDyadic (x := q) (prec := prec)) - have hpow_pos : (0 : Rat) < (2 : Rat) ^ prec := by - exact Rat.zpow_pos (by decide : (0 : Rat) < 2) - have hmul_nonneg : (0 : Rat) ≤ q * 2 ^ prec := by - exact mul_nonneg hq (le_of_lt hpow_pos) - have hfloor_nonneg : 0 ≤ (q * 2 ^ prec).floor := by - exact (Int.floor_nonneg (a := q * 2 ^ prec)).2 hmul_nonneg - have hfloor_nonneg_rat : (0 : Rat) ≤ ((q * 2 ^ prec).floor : Rat) := by - exact_mod_cast hfloor_nonneg - have hdiv_nonneg : - (0 : Rat) ≤ ((q * 2 ^ prec).floor : Rat) / (2 : Rat) ^ prec := by - exact div_nonneg hfloor_nonneg_rat (le_of_lt hpow_pos) - simpa [hrat] using hdiv_nonneg - -private lemma dyadicUlp_rat (prec : Int) : - (dyadicUlp prec : Rat) = (1 : Rat) * 2 ^ (-prec) := by - simp [dyadicUlp, Dyadic.toRat_ofIntWithPrec_eq_mul_two_pow] - -theorem dyadicUlp_nonneg (prec : Int := defaultDyadicPrec) : 0 ≤ dyadicUlp prec := by - apply (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicUlp prec)).1 - have hrat : (dyadicUlp prec : Rat) = (1 : Rat) * 2 ^ (-prec) := dyadicUlp_rat prec - have hpow_pos : (0 : Rat) < (2 : Rat) ^ (-prec) := by - exact Rat.zpow_pos (by decide : (0 : Rat) < 2) - have hnonneg : (0 : Rat) ≤ (1 : Rat) * 2 ^ (-prec) := by - exact mul_nonneg (by decide : (0 : Rat) ≤ 1) (le_of_lt hpow_pos) - simpa [hrat] using hnonneg - -theorem dyadicUlp_pos (prec : Int := defaultDyadicPrec) : 0 < dyadicUlp prec := by - apply (Dyadic.toRat_lt_toRat_iff (x := 0) (y := dyadicUlp prec)).1 - have hrat : (dyadicUlp prec : Rat) = (1 : Rat) * 2 ^ (-prec) := dyadicUlp_rat prec - have hpow_pos : (0 : Rat) < (2 : Rat) ^ (-prec) := by - exact Rat.zpow_pos (by decide : (0 : Rat) < 2) - have hpos : (0 : Rat) < (1 : Rat) * 2 ^ (-prec) := by - exact mul_pos (by decide : (0 : Rat) < 1) hpow_pos - simpa [hrat] using hpos - -theorem dyadicOfRatUp_nonneg {q : Rat} (hq : 0 ≤ q) (prec : Int := defaultDyadicPrec) : - 0 ≤ dyadicOfRatUp q prec := by - have hdown : 0 ≤ 
dyadicOfRatDown q prec := dyadicOfRatDown_nonneg hq prec - have hulp : 0 ≤ dyadicUlp prec := dyadicUlp_nonneg prec - apply (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicOfRatUp q prec)).1 - have hdown_rat : (0 : Rat) ≤ (dyadicOfRatDown q prec : Rat) := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicOfRatDown q prec)).2 hdown - have hulp_rat : (0 : Rat) ≤ (dyadicUlp prec : Rat) := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicUlp prec)).2 hulp - have hsum : (0 : Rat) ≤ (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat) := by - exact Rat.add_nonneg hdown_rat hulp_rat - simpa [dyadicOfRatUp, dyadicUlp, dyadicOfRatDown, Dyadic.toRat_add] using hsum - -theorem dyadicOfRatUp_pos {q : Rat} (hq : 0 < q) (prec : Int := defaultDyadicPrec) : - 0 < dyadicOfRatUp q prec := by - apply (Dyadic.toRat_lt_toRat_iff (x := 0) (y := dyadicOfRatUp q prec)).1 - have hdown_nonneg : (0 : Rat) ≤ (dyadicOfRatDown q prec : Rat) := by - have hdown : 0 ≤ dyadicOfRatDown q prec := dyadicOfRatDown_nonneg hq.le prec - exact (Dyadic.toRat_le_toRat_iff (x := 0) (y := dyadicOfRatDown q prec)).2 hdown - have hulp_pos : (0 : Rat) < (dyadicUlp prec : Rat) := by - exact (Dyadic.toRat_lt_toRat_iff (x := 0) (y := dyadicUlp prec)).2 (dyadicUlp_pos prec) - have hsum : (0 : Rat) < (dyadicOfRatDown q prec : Rat) + (dyadicUlp prec : Rat) := by - nlinarith - simpa [dyadicOfRatUp, dyadicUlp, dyadicOfRatDown, Dyadic.toRat_add] using hsum - --- TODO: use a tighter upper rounding when needed. - -/-- Dyadic division with downward rounding at the chosen precision. -/ -def dyadicDivDown (x y : Dyadic) (prec : Int := defaultDyadicPrec) : Dyadic := +/-- Default precision placeholder retained for API compatibility. -/ +def defaultRatPrec : Int := 48 + +/-- Round a rational down (identity in the exact-rational refactor). -/ +def ratRoundDown (q : Rat) (_prec : Int := defaultRatPrec) : Rat := + q + +/-- Round a rational up (identity in the exact-rational refactor). 
-/ +def ratRoundUp (q : Rat) (_prec : Int := defaultRatPrec) : Rat := + q + +/-- Real cast of a rational value. -/ +def ratToReal (x : Rat) : Real := + (x : Real) + +@[simp] theorem ratToReal_zero : ratToReal 0 = 0 := by + simp [ratToReal] + +@[simp] theorem ratToReal_one : ratToReal 1 = 1 := by + simp [ratToReal] + +theorem ratRoundDown_le (q : Rat) (_prec : Int := defaultRatPrec) : + (ratRoundDown q _prec : Rat) ≤ q := by + simp [ratRoundDown] + +theorem ratRoundUp_ge (q : Rat) (_prec : Int := defaultRatPrec) : + q ≤ (ratRoundUp q _prec : Rat) := by + simp [ratRoundUp] + +theorem ratRoundDown_le_real (q : Rat) (_prec : Int := defaultRatPrec) : + (ratRoundDown q _prec : Real) ≤ (q : Real) := by + simp [ratRoundDown] + +theorem real_le_ratRoundUp (q : Rat) (_prec : Int := defaultRatPrec) : + (q : Real) ≤ (ratRoundUp q _prec : Real) := by + simp [ratRoundUp] + +theorem ratRoundDown_nonneg {q : Rat} (hq : 0 ≤ q) (_prec : Int := defaultRatPrec) : + 0 ≤ ratRoundDown q _prec := by + simpa [ratRoundDown] using hq + +theorem ratRoundUp_nonneg {q : Rat} (hq : 0 ≤ q) (_prec : Int := defaultRatPrec) : + 0 ≤ ratRoundUp q _prec := by + simpa [ratRoundUp] using hq + +theorem ratRoundUp_pos {q : Rat} (hq : 0 < q) (_prec : Int := defaultRatPrec) : + 0 < ratRoundUp q _prec := by + simpa [ratRoundUp] using hq + +/-- Interpret `n * 2^{-prec}` as a rational (matching power-of-two scaling). -/ +def ratOfIntWithPrec (n : Int) (prec : Int) : Rat := + let pow2 (k : Nat) : Rat := Rat.ofInt (Int.ofNat (Nat.pow 2 k)) + if _h : 0 ≤ prec then + Rat.ofInt n / pow2 (Int.toNat prec) + else + Rat.ofInt n * pow2 (Int.toNat (-prec)) + +/-- Rational division with downward rounding (exact for rationals). -/ +def ratDivDown (x y : Rat) (_prec : Int := defaultRatPrec) : Rat := if y = 0 then 0 else - dyadicOfRatDown (x.toRat / y.toRat) prec + x / y -/-- Dyadic division with upward rounding at the chosen precision. 
-/ -def dyadicDivUp (x y : Dyadic) (prec : Int := defaultDyadicPrec) : Dyadic := +/-- Rational division with upward rounding (exact for rationals). -/ +def ratDivUp (x y : Rat) (_prec : Int := defaultRatPrec) : Rat := if y = 0 then 0 else - dyadicOfRatUp (x.toRat / y.toRat) prec - -theorem dyadicDivUp_ge (x y : Dyadic) (hy : y ≠ 0) : - (x.toRat / y.toRat : Rat) ≤ (dyadicDivUp x y : Rat) := by - have hrat : (x.toRat / y.toRat : Rat) ≤ (dyadicOfRatUp (x.toRat / y.toRat) : Rat) := - dyadicOfRatUp_ge (x.toRat / y.toRat) - simpa [dyadicDivUp, hy] using hrat - -theorem dyadicDivUp_ge_real (x y : Dyadic) (hy : y ≠ 0) : - (x.toRat : Real) / (y.toRat : Real) ≤ dyadicToReal (dyadicDivUp x y) := by - have hrat : (x.toRat / y.toRat : Rat) ≤ (dyadicDivUp x y : Rat) := - dyadicDivUp_ge x y hy - have hrat' : ((x.toRat / y.toRat : Rat) : Real) ≤ ((dyadicDivUp x y : Rat) : Real) := - (Rat.cast_le (K := Real) (p := _) (q := _)).2 hrat - simpa [dyadicToReal] using hrat' - -@[simp] theorem dyadicToReal_add (x y : Dyadic) : - dyadicToReal (x + y) = dyadicToReal x + dyadicToReal y := by - simp [dyadicToReal, Dyadic.toRat_add] - -@[simp] theorem dyadicToReal_sub (x y : Dyadic) : - dyadicToReal (x - y) = dyadicToReal x - dyadicToReal y := by - simp [dyadicToReal, Dyadic.toRat_sub] - -@[simp] theorem dyadicToReal_mul (x y : Dyadic) : - dyadicToReal (x * y) = dyadicToReal x * dyadicToReal y := by - simp [dyadicToReal, Dyadic.toRat_mul] - -@[simp] theorem dyadicToReal_neg (x : Dyadic) : - dyadicToReal (-x) = -dyadicToReal x := by - simp [dyadicToReal, Dyadic.toRat_neg] - -@[simp] theorem dyadicToReal_if {p : Prop} [Decidable p] (a b : Dyadic) : - dyadicToReal (if p then a else b) = - if p then dyadicToReal a else dyadicToReal b := by - by_cases hp : p <;> simp [hp] - -theorem dyadicToReal_le_iff {x y : Dyadic} : - dyadicToReal x ≤ dyadicToReal y ↔ x ≤ y := by - constructor - · intro h - have h' : x.toRat ≤ y.toRat := by - have h'' : (x.toRat : Real) ≤ (y.toRat : Real) := by - simpa 
[dyadicToReal] using h - exact (Rat.cast_le (K := Real) (p := x.toRat) (q := y.toRat)).1 h'' - exact (Dyadic.toRat_le_toRat_iff).1 h' - · intro h - have h' : x.toRat ≤ y.toRat := (Dyadic.toRat_le_toRat_iff).2 h - have h'' : (x.toRat : Real) ≤ (y.toRat : Real) := - (Rat.cast_le (K := Real) (p := x.toRat) (q := y.toRat)).2 h' - simpa [dyadicToReal] using h'' - -/-- Dyadic order implies real order after casting. -/ -theorem dyadicToReal_le_of_le {x y : Dyadic} (h : x ≤ y) : - dyadicToReal x ≤ dyadicToReal y := - (dyadicToReal_le_iff (x := x) (y := y)).2 h - -theorem dyadicToReal_lt_iff {x y : Dyadic} : - dyadicToReal x < dyadicToReal y ↔ x < y := by - constructor - · intro h - have h' : x.toRat < y.toRat := by - have h'' : (x.toRat : Real) < (y.toRat : Real) := by - simpa [dyadicToReal] using h - exact (Rat.cast_lt (K := Real) (p := x.toRat) (q := y.toRat)).1 h'' - exact (Dyadic.toRat_lt_toRat_iff).1 h' - · intro h - have h' : x.toRat < y.toRat := (Dyadic.toRat_lt_toRat_iff).2 h - have h'' : (x.toRat : Real) < (y.toRat : Real) := - (Rat.cast_lt (K := Real) (p := x.toRat) (q := y.toRat)).2 h' - simpa [dyadicToReal] using h'' - -theorem dyadicToReal_nonneg_iff {x : Dyadic} : - 0 ≤ dyadicToReal x ↔ 0 ≤ x := by - simpa [dyadicToReal] using (dyadicToReal_le_iff (x := 0) (y := x)) - -theorem dyadicToReal_nonneg_of_nonneg {x : Dyadic} (h : 0 ≤ x) : - 0 ≤ dyadicToReal x := - (dyadicToReal_nonneg_iff (x := x)).2 h - -theorem dyadicToReal_nonpos_iff {x : Dyadic} : - dyadicToReal x ≤ 0 ↔ x ≤ 0 := by - simpa [dyadicToReal_zero] using (dyadicToReal_le_iff (x := x) (y := 0)) - -instance : LinearOrder Dyadic where - le := (· ≤ ·) - lt := (· < ·) - le_refl := Dyadic.le_refl - le_trans := by intro a b c hab hbc; exact Dyadic.le_trans hab hbc - le_antisymm := by intro a b hab hba; exact Dyadic.le_antisymm hab hba - le_total := Dyadic.le_total - toDecidableLE := inferInstance - toDecidableEq := inferInstance - toDecidableLT := inferInstance - lt_iff_le_not_ge := by - intro a b - have 
hlt : a < b ↔ a.toRat < b.toRat := - (Dyadic.toRat_lt_toRat_iff (x := a) (y := b)).symm - have hle : a ≤ b ↔ a.toRat ≤ b.toRat := - (Dyadic.toRat_le_toRat_iff (x := a) (y := b)).symm - have hge : b ≤ a ↔ b.toRat ≤ a.toRat := - (Dyadic.toRat_le_toRat_iff (x := b) (y := a)).symm - have hrat : a.toRat < b.toRat ↔ a.toRat ≤ b.toRat ∧ ¬ b.toRat ≤ a.toRat := by - simpa using (Rat.lt_iff_le_not_ge (a := a.toRat) (b := b.toRat)) - calc - a < b ↔ a.toRat < b.toRat := hlt - _ ↔ a.toRat ≤ b.toRat ∧ ¬ b.toRat ≤ a.toRat := hrat - _ ↔ a ≤ b ∧ ¬ b ≤ a := by - constructor - · intro h - refine ⟨(hle.mpr h.1), ?_⟩ - intro hba - exact h.2 (hge.mp hba) - · intro h - refine ⟨(hle.mp h.1), ?_⟩ - intro hba - exact h.2 (hge.mpr hba) - min_def := by intro a b; rfl - max_def := by intro a b; rfl - compare_eq_compareOfLessAndEq := by intro a b; rfl - -instance : AddMonoid Dyadic where - add := (· + ·) - zero := 0 - add_assoc := Dyadic.add_assoc - zero_add := Dyadic.zero_add - add_zero := Dyadic.add_zero - nsmul := nsmulRec - nsmul_zero := by intro x; rfl - nsmul_succ := by intro n x; rfl - -instance : AddCommMonoid Dyadic := - { (inferInstance : AddMonoid Dyadic) with - add_comm := Dyadic.add_comm } - -instance : AddMonoidWithOne Dyadic where - add := (· + ·) - zero := 0 - add_assoc := Dyadic.add_assoc - zero_add := Dyadic.zero_add - add_zero := Dyadic.add_zero - nsmul := nsmulRec - nsmul_zero := by intro x; rfl - nsmul_succ := by intro n x; rfl - one := 1 - natCast := fun n => Dyadic.ofInt n - natCast_zero := by - simp [dyadicOfInt_zero] - natCast_succ := by - intro n - simp [dyadicOfInt_succ n] - -instance : AddCommMonoidWithOne Dyadic where - add := (· + ·) - zero := 0 - add_assoc := Dyadic.add_assoc - zero_add := Dyadic.zero_add - add_zero := Dyadic.add_zero - nsmul := nsmulRec - nsmul_zero := by intro x; rfl - nsmul_succ := by intro n x; rfl - one := 1 - natCast := fun n => Dyadic.ofInt n - natCast_zero := by - simp [dyadicOfInt_zero] - natCast_succ := by - intro n - simp 
[dyadicOfInt_succ n] - add_comm := Dyadic.add_comm - -instance : AddGroup Dyadic where - add := (· + ·) - zero := 0 - add_assoc := Dyadic.add_assoc - zero_add := Dyadic.zero_add - add_zero := Dyadic.add_zero - nsmul := nsmulRec - nsmul_zero := by intro x; rfl - nsmul_succ := by intro n x; rfl - neg := Neg.neg - sub := fun a b => a + -b - sub_eq_add_neg := by intro a b; rfl - zsmul := zsmulRec - zsmul_zero' := by intro a; rfl - zsmul_succ' := by intro n a; rfl - zsmul_neg' := by intro n a; rfl - neg_add_cancel := Dyadic.neg_add_cancel - -instance : AddCommGroup Dyadic := - { (inferInstance : AddGroup Dyadic) with - add_comm := Dyadic.add_comm } - -instance : IsOrderedAddMonoid Dyadic where - add_le_add_left a b h c := by - have hrat : a.toRat ≤ b.toRat := - (Dyadic.toRat_le_toRat_iff (x := a) (y := b)).2 h - have hrat' : a.toRat + c.toRat ≤ b.toRat + c.toRat := by - exact Rat.add_le_add_right.2 hrat - have hrat'' : - (a + c).toRat ≤ (b + c).toRat := by - simpa [Dyadic.toRat_add] using hrat' - exact (Dyadic.toRat_le_toRat_iff (x := a + c) (y := b + c)).1 hrat'' - -instance : ExistsAddOfLE Dyadic where - exists_add_of_le {a b} h := by - refine ⟨b - a, ?_⟩ - simp [sub_eq_add_neg] - -instance : Monoid Dyadic where - mul := (· * ·) - one := 1 - mul_assoc := Dyadic.mul_assoc - one_mul := Dyadic.one_mul - mul_one := Dyadic.mul_one - -instance : MonoidWithZero Dyadic := - { (inferInstance : Monoid Dyadic) with - zero := 0 - zero_mul := Dyadic.zero_mul - mul_zero := Dyadic.mul_zero } - -instance : Distrib Dyadic where - left_distrib := Dyadic.mul_add - right_distrib := Dyadic.add_mul - -instance : Semiring Dyadic where - add := (· + ·) - zero := 0 - add_assoc := Dyadic.add_assoc - zero_add := Dyadic.zero_add - add_zero := Dyadic.add_zero - add_comm := Dyadic.add_comm - nsmul := nsmulRec - nsmul_zero := by intro x; rfl - nsmul_succ := by intro n x; rfl - one := 1 - natCast := fun n => Dyadic.ofInt n - natCast_zero := by - simp [dyadicOfInt_zero] - natCast_succ := by - intro n 
- simp [dyadicOfInt_succ n] - mul := (· * ·) - mul_assoc := Dyadic.mul_assoc - one_mul := Dyadic.one_mul - mul_one := Dyadic.mul_one - left_distrib := Dyadic.mul_add - right_distrib := Dyadic.add_mul - zero_mul := Dyadic.zero_mul - mul_zero := Dyadic.mul_zero - -instance : CommSemiring Dyadic where - toSemiring := (inferInstance : Semiring Dyadic) - mul_comm := by intro a b; exact Dyadic.mul_comm a b - -instance : AddGroupWithOne Dyadic := - { (inferInstance : AddMonoidWithOne Dyadic), - (inferInstance : AddGroup Dyadic) with - intCast := Int.castDef - intCast_ofNat := by - intro n - apply (Dyadic.toRat_inj).1 - have hleft : (Int.castDef (R := Dyadic) (n : ℤ)).toRat = (n : Rat) := by - simp [Int.castDef, Dyadic.toRat_natCast] - have hright : (n : Dyadic).toRat = (n : Rat) := Dyadic.toRat_natCast n - exact hleft.trans hright.symm - intCast_negSucc := by - intro n - apply (Dyadic.toRat_inj).1 - simp [Int.castDef, Dyadic.toRat_natCast, Dyadic.toRat_neg, Dyadic.toRat_add] } - -instance : Ring Dyadic := - { (inferInstance : Semiring Dyadic), - (inferInstance : AddCommGroup Dyadic), - (inferInstance : AddGroupWithOne Dyadic) with } - -instance : CommRing Dyadic := - { (inferInstance : Ring Dyadic), (inferInstance : CommSemiring Dyadic) with } - -instance : Nontrivial Dyadic := - ⟨0, 1, by - intro h - have hrat : (0 : Rat) = (1 : Rat) := by - simpa [Dyadic.toRat_zero, dyadicToRat_one] using congrArg Dyadic.toRat h - exact (zero_ne_one (α := Rat)) hrat⟩ - -instance : ZeroLEOneClass Dyadic where - zero_le_one := by - have hrat : (0 : Rat) ≤ (1 : Rat) := by decide - exact (Dyadic.toRat_le_toRat_iff (x := (0 : Dyadic)) (y := (1 : Dyadic))).1 hrat - -instance : PosMulMono Dyadic where - mul_le_mul_of_nonneg_left {a} ha {b c} hbc := by - have ha' : (0 : Dyadic) ≤ a := by simpa using ha - have ha'' : (0 : Rat) ≤ a.toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := a)).2 ha' - have hbc' : b.toRat ≤ c.toRat := - (Dyadic.toRat_le_toRat_iff (x := b) (y := c)).2 hbc - have hrat : 
a.toRat * b.toRat ≤ a.toRat * c.toRat := by - exact Rat.mul_le_mul_of_nonneg_left hbc' ha'' - have hrat' : - (a * b).toRat ≤ (a * c).toRat := by - simpa [Dyadic.toRat_mul] using hrat - exact (Dyadic.toRat_le_toRat_iff (x := a * b) (y := a * c)).1 hrat' - -instance : MulPosMono Dyadic where - mul_le_mul_of_nonneg_right {a} ha {b c} hbc := by - have h := (PosMulMono.mul_le_mul_of_nonneg_left (a := a) ha hbc) - simpa [mul_comm, mul_left_comm, mul_assoc] using h - -instance : IsOrderedRing Dyadic := - IsOrderedRing.of_mul_nonneg (R := Dyadic) (mul_nonneg := by - intro a b ha hb - have ha' : (0 : Dyadic) ≤ a := by simpa using ha - have hb' : (0 : Dyadic) ≤ b := by simpa using hb - have ha'' : (0 : Rat) ≤ a.toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := a)).2 ha' - have hb'' : (0 : Rat) ≤ b.toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := b)).2 hb' - have hrat : (0 : Rat) ≤ a.toRat * b.toRat := Rat.mul_nonneg ha'' hb'' - have hrat' : (0 : Rat) ≤ (a * b).toRat := by - simpa [Dyadic.toRat_mul] using hrat - exact (Dyadic.toRat_le_toRat_iff (x := 0) (y := a * b)).1 hrat') - -@[simp] theorem dyadicToReal_abs (x : Dyadic) : - dyadicToReal |x| = |dyadicToReal x| := by - by_cases hx : 0 ≤ x - · have hx' : 0 ≤ dyadicToReal x := (dyadicToReal_nonneg_iff).2 hx - calc - dyadicToReal |x| = dyadicToReal x := by simp [abs_of_nonneg hx] - _ = |dyadicToReal x| := (abs_of_nonneg hx').symm - · have hx' : x ≤ 0 := le_of_not_ge hx - have hx'' : dyadicToReal x ≤ 0 := by - simpa [dyadicToReal_zero] using dyadicToReal_le_of_le hx' - calc - dyadicToReal |x| = dyadicToReal (-x) := by simp [abs_of_nonpos hx'] - _ = -dyadicToReal x := dyadicToReal_neg x - _ = |dyadicToReal x| := (abs_of_nonpos hx'').symm - -theorem dyadicToReal_abs_le_of_le {x y : Dyadic} (h : |x| ≤ y) : - |dyadicToReal x| ≤ dyadicToReal y := by - have h' : dyadicToReal |x| ≤ dyadicToReal y := - dyadicToReal_le_of_le h - simpa [dyadicToReal_abs] using h' - -@[simp] theorem dyadicToReal_max (x y : Dyadic) : - dyadicToReal 
(max x y) = max (dyadicToReal x) (dyadicToReal y) := by - by_cases hxy : x ≤ y - · have hxy' : dyadicToReal x ≤ dyadicToReal y := - dyadicToReal_le_of_le hxy - calc - dyadicToReal (max x y) = dyadicToReal y := by simp [max_eq_right hxy] - _ = max (dyadicToReal x) (dyadicToReal y) := by - symm - exact max_eq_right hxy' - · have hyx : y ≤ x := le_of_not_ge hxy - have hyx' : dyadicToReal y ≤ dyadicToReal x := - dyadicToReal_le_of_le hyx - calc - dyadicToReal (max x y) = dyadicToReal x := by simp [max_eq_left hyx] - _ = max (dyadicToReal x) (dyadicToReal y) := by - exact (max_eq_left hyx').symm - -@[simp] theorem dyadicToReal_min (x y : Dyadic) : - dyadicToReal (min x y) = min (dyadicToReal x) (dyadicToReal y) := by - by_cases hxy : x ≤ y - · have hxy' : dyadicToReal x ≤ dyadicToReal y := - dyadicToReal_le_of_le hxy - calc - dyadicToReal (min x y) = dyadicToReal x := by simp [min_eq_left hxy] - _ = min (dyadicToReal x) (dyadicToReal y) := by - symm - exact min_eq_left hxy' - · have hyx : y ≤ x := le_of_not_ge hxy - have hyx' : dyadicToReal y ≤ dyadicToReal x := - dyadicToReal_le_of_le hyx - calc - dyadicToReal (min x y) = dyadicToReal y := by simp [min_eq_right hyx] - _ = min (dyadicToReal x) (dyadicToReal y) := by - exact (min_eq_right hyx').symm + x / y + +theorem ratDivUp_ge (x y : Rat) (hy : y ≠ 0) : + (x / y : Rat) ≤ (ratDivUp x y : Rat) := by + simp [ratDivUp, hy] + +theorem ratDivUp_ge_real (x y : Rat) (hy : y ≠ 0) : + (x : Real) / (y : Real) ≤ ratToReal (ratDivUp x y) := by + simp [ratDivUp, ratToReal, hy] + +@[simp] theorem ratToReal_add (x y : Rat) : + ratToReal (x + y) = ratToReal x + ratToReal y := by + simp [ratToReal] + +@[simp] theorem ratToReal_sub (x y : Rat) : + ratToReal (x - y) = ratToReal x - ratToReal y := by + simp [ratToReal] + +@[simp] theorem ratToReal_mul (x y : Rat) : + ratToReal (x * y) = ratToReal x * ratToReal y := by + simp [ratToReal] + +@[simp] theorem ratToReal_neg (x : Rat) : + ratToReal (-x) = -ratToReal x := by + simp [ratToReal] + 
+@[simp] theorem ratToReal_if {p : Prop} [Decidable p] (a b : Rat) : + ratToReal (if p then a else b) = + if p then ratToReal a else ratToReal b := by + by_cases hp : p <;> simp [hp, ratToReal] + +theorem ratToReal_le_iff {x y : Rat} : + ratToReal x ≤ ratToReal y ↔ x ≤ y := by + simp [ratToReal] + +/-- Rational order implies real order after casting. -/ +theorem ratToReal_le_of_le {x y : Rat} (h : x ≤ y) : + ratToReal x ≤ ratToReal y := + (ratToReal_le_iff (x := x) (y := y)).2 h + +theorem ratToReal_lt_iff {x y : Rat} : + ratToReal x < ratToReal y ↔ x < y := by + simp [ratToReal] + +theorem ratToReal_nonneg_iff {x : Rat} : + 0 ≤ ratToReal x ↔ 0 ≤ x := by + simp [ratToReal] + +theorem ratToReal_nonneg_of_nonneg {x : Rat} (h : 0 ≤ x) : + 0 ≤ ratToReal x := + (ratToReal_nonneg_iff (x := x)).2 h + +theorem ratToReal_nonpos_iff {x : Rat} : + ratToReal x ≤ 0 ↔ x ≤ 0 := by + simp [ratToReal] + +@[simp] theorem ratToReal_abs (x : Rat) : + ratToReal |x| = |ratToReal x| := by + simp [ratToReal] + +theorem ratToReal_abs_le_of_le {x y : Rat} (h : |x| ≤ y) : + |ratToReal x| ≤ ratToReal y := by + have h' : ratToReal |x| ≤ ratToReal y := + ratToReal_le_of_le h + simpa [ratToReal_abs] using h' + +@[simp] theorem ratToReal_max (x y : Rat) : + ratToReal (max x y) = max (ratToReal x) (ratToReal y) := by + simp [ratToReal] + +@[simp] theorem ratToReal_min (x y : Rat) : + ratToReal (min x y) = min (ratToReal x) (ratToReal y) := by + simp [ratToReal] end Nfp diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 1f2fe1e..2278526 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -26,9 +26,9 @@ def runInductionCertify (scoresPath : System.FilePath) (valuesPath? : Option System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? 
- let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -40,8 +40,8 @@ def runInductionCertify (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) if minLogitDiff?.isSome && valuesPath?.isNone then IO.eprintln "error: min-logit-diff requires --values" return 2 @@ -108,14 +108,14 @@ def runInductionCertify (scoresPath : System.FilePath) let effectiveMinLogitDiff := match minLogitDiff?, certVals'.direction with | some v, _ => some v - | none, some _ => some (0 : Dyadic) + | none, some _ => some (0 : Rat) | none, none => none match logitDiffLB? with | none => IO.eprintln "error: empty active set for logit-diff bound" return (2 : UInt32) | some logitDiffLB => - let violation? : Option Dyadic := + let violation? : Option Rat := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -140,9 +140,9 @@ def runInductionCertifySound (scoresPath : System.FilePath) (valuesPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? - let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? 
+ let maxEps?E := parseRatOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -154,8 +154,8 @@ def runInductionCertifySound (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) let parsedScores ← loadSoftmaxMarginRaw scoresPath match parsedScores with | Except.error msg => @@ -216,14 +216,14 @@ def runInductionCertifySound (scoresPath : System.FilePath) let effectiveMinLogitDiff := match minLogitDiff?, certVals.direction with | some v, _ => some v - | none, some _ => some (0 : Dyadic) + | none, some _ => some (0 : Rat) | none, none => none match logitDiffLB? with | none => IO.eprintln "error: empty active set for logit-diff bound" return 2 | some logitDiffLB => - let violation? : Option Dyadic := + let violation? : Option Rat := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -248,9 +248,9 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) (valuesPath : System.FilePath) (downstreamPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? - let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -262,8 +262,8 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) let parsedScores ← timePhase "load softmax cert" <| loadSoftmaxMarginCert scoresPath match parsedScores with @@ -321,7 +321,7 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) let effectiveMinLogitDiff := match minLogitDiff?, certVals'.direction with | some v, _ => some v - | none, some _ => some (0 : Dyadic) + | none, some _ => some (0 : Rat) | none, none => none match logitDiffLB? with | none => @@ -337,7 +337,7 @@ def runInductionCertifyEndToEnd (scoresPath : System.FilePath) let downstreamOk := Circuit.checkDownstreamLinearCert downstream if downstreamOk then let finalLB := logitDiffLB - downstream.error - let violation? : Option Dyadic := + let violation? : Option Rat := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -367,9 +367,9 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) (valuesPath : System.FilePath) (matrixPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? - let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -381,8 +381,8 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) let parsedScores ← timePhase "load softmax cert" <| loadSoftmaxMarginCert scoresPath match parsedScores with @@ -440,7 +440,7 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) let effectiveMinLogitDiff := match minLogitDiff?, certVals'.direction with | some v, _ => some v - | none, some _ => some (0 : Dyadic) + | none, some _ => some (0 : Rat) | none, none => none match logitDiffLB? with | none => @@ -461,11 +461,11 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) else have hinput : 0 ≤ inputBound := by exact le_of_not_gt hneg - let W : Matrix (Fin rows) (Fin cols) Dyadic := raw.entries + let W : Matrix (Fin rows) (Fin cols) Rat := raw.entries let downstream := (Sound.Bounds.buildDownstreamLinearCert W inputBound hinput).1 let finalLB := logitDiffLB - downstream.error - let violation? : Option Dyadic := + let violation? : Option Rat := match effectiveMinLogitDiff with | none => none | some minLogitDiff => @@ -494,9 +494,9 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) (residualIntervalPath? : Option System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseDyadicOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseDyadicOpt "min-margin" minMarginStr? - let maxEps?E := parseDyadicOpt "max-eps" maxEpsStr? 
+ let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? match minLogitDiff?E, minMargin?E, maxEps?E with | Except.error msg, _, _ => IO.eprintln s!"error: {msg}" @@ -508,8 +508,8 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Dyadic) - let maxEps := maxEps?.getD (dyadicOfRatDown (Rat.divInt 1 2)) + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) let parsedScores ← timePhase "load softmax cert" <| loadSoftmaxMarginCert scoresPath match parsedScores with @@ -567,7 +567,7 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) let effectiveMinLogitDiff := match minLogitDiff?, certVals'.direction with | some v, _ => some v - | none, some _ => some (0 : Dyadic) + | none, some _ => some (0 : Rat) | none, none => none match logitDiffLB? with | none => @@ -648,7 +648,7 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) return 1 | Except.ok colNeg => let dirVec : - Fin header.modelDim → Dyadic := + Fin header.modelDim → Rat := fun i => colTarget i - colNeg i let downstreamError ← timePure "downstream error" (fun () => @@ -657,7 +657,7 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) residualCert'.lo residualCert'.hi) let finalLB := logitDiffLB - downstreamError - let violation? : Option Dyadic := + let violation? 
: Option Rat := match effectiveMinLogitDiff with | none => none | some minLogitDiff => diff --git a/Nfp/IO/Bench/InductionCore.lean b/Nfp/IO/Bench/InductionCore.lean index b6050c3..7ba8c29 100644 --- a/Nfp/IO/Bench/InductionCore.lean +++ b/Nfp/IO/Bench/InductionCore.lean @@ -21,37 +21,37 @@ private def benchPhasePure {α : Type} (label : String) (act : Unit → α) : IO timePhase label (pure (act ())) private def forceScore {seq dModel dHead : Nat} - (score : Sound.HeadScoreBounds seq dModel dHead) : Dyadic := + (score : Sound.HeadScoreBounds seq dModel dHead) : Rat := score.margin + score.eps private def forceValues {seq dModel dHead : Nat} - (vals : Sound.HeadValueBounds seq dModel dHead) : Dyadic := + (vals : Sound.HeadValueBounds seq dModel dHead) : Rat := vals.lo + vals.hi private def forceQAbs {seq dHead : Nat} - (qAbs : Fin seq → Fin dHead → Dyadic) : Dyadic := + (qAbs : Fin seq → Fin dHead → Rat) : Rat := (Finset.univ : Finset (Fin seq)).sum (fun q => (Finset.univ : Finset (Fin dHead)).sum (fun d => qAbs q d)) private def forceLn {seq dModel : Nat} - (ln : Fin seq → Fin dModel → Dyadic) : Dyadic := + (ln : Fin seq → Fin dModel → Rat) : Rat := (Finset.univ : Finset (Fin seq)).sum (fun q => (Finset.univ : Finset (Fin dModel)).sum (fun i => ln q i)) private def forceKAbs {seq dHead : Nat} - (kAbs : Fin seq → Fin dHead → Dyadic) : Dyadic := + (kAbs : Fin seq → Fin dHead → Rat) : Rat := (Finset.univ : Finset (Fin seq)).sum (fun q => (Finset.univ : Finset (Fin dHead)).sum (fun d => kAbs q d)) private def forceDotAbs {seq dHead : Nat} - (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : Dyadic := + (qAbs kAbs : Fin seq → Fin dHead → Rat) : Rat := (Finset.univ : Finset (Fin seq)).sum (fun q => (Finset.univ : Finset (Fin seq)).sum (fun k => Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d))) private def forceDotAbsTasksReduce {seq dHead : Nat} - (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : Dyadic := - let tasks : Array (Task Dyadic) := + (qAbs kAbs : Fin seq → Fin 
dHead → Rat) : Rat := + let tasks : Array (Task Rat) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => (Finset.univ : Finset (Fin seq)).sum (fun k => @@ -65,60 +65,60 @@ private def isPow2 (n : Nat) : Bool := else decide (Nat.pow 2 (Nat.log2 n) = n) -private def isDyadic (q : Dyadic) : Bool := - isPow2 q.toRat.den +private def isPow2Den (q : Rat) : Bool := + isPow2 q.den -private def countDyadic {seq dHead : Nat} +private def countPow2Den {seq dHead : Nat} (qs : List (Fin seq)) (ds : List (Fin dHead)) - (f : Fin seq → Fin dHead → Dyadic) : Nat := + (f : Fin seq → Fin dHead → Rat) : Nat := qs.foldl (fun acc q => - ds.foldl (fun acc' d => acc' + (if isDyadic (f q d) then 1 else 0)) acc) 0 + ds.foldl (fun acc' d => acc' + (if isPow2Den (f q d) then 1 else 0)) acc) 0 -private def dyadicSampleReport {seq dModel dHead : Nat} +private def pow2DenSampleReport {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (qkv : Sound.HeadQKVBounds seq dModel dHead) : String := let qs := (List.finRange seq).take (min seq 2) let ds := (List.finRange dHead).take (min dHead 8) let total := qs.length * ds.length - let qLoDy := countDyadic qs ds qkv.qLo - let qHiDy := countDyadic qs ds qkv.qHi - let qAbsDy := countDyadic qs ds qkv.qAbs - let kAbsDy := countDyadic qs ds qkv.kAbs - let epsDy := if isDyadic inputs.lnEps then 1 else 0 - s!"dyadic sample: total={total} qLo={qLoDy} qHi={qHiDy} qAbs={qAbsDy} " ++ + let qLoDy := countPow2Den qs ds qkv.qLo + let qHiDy := countPow2Den qs ds qkv.qHi + let qAbsDy := countPow2Den qs ds qkv.qAbs + let kAbsDy := countPow2Den qs ds qkv.kAbs + let epsDy := if isPow2Den inputs.lnEps then 1 else 0 + s!"pow2-den sample: total={total} qLo={qLoDy} qHi={qHiDy} qAbs={qAbsDy} " ++ s!"kAbs={kAbsDy} lnEps={epsDy}" -private def dyadicSanityReport : String := - let rat := dyadicOfRatDown (Rat.divInt 1 8) +private def pow2DenSanityReport : String := + let rat := ratRoundDown (Rat.divInt 1 8) let powChecks := s!"pow2(1)={isPow2 1} 
pow2(2)={isPow2 2} pow2(3)={isPow2 3} " ++ s!"pow2(4)={isPow2 4} pow2(8)={isPow2 8}" - let ratCheck := s!"rat(1/8).den={rat.toRat.den} dyadic={isDyadic rat}" - s!"dyadic sanity: {powChecks} {ratCheck}" + let ratCheck := s!"rat(1/8).den={rat.den} pow2den={isPow2Den rat}" + s!"pow2-den sanity: {powChecks} {ratCheck}" private def forceQRowTasks {seq dHead : Nat} - (q0 : Fin seq) (qLo : Fin seq → Fin dHead → Dyadic) : Int := - let tasks : Array (Task Dyadic) := + (q0 : Fin seq) (qLo : Fin seq → Fin dHead → Rat) : Int := + let tasks : Array (Task Rat) := Array.ofFn (fun d : Fin dHead => Task.spawn (fun _ => qLo q0 d)) let total := (Finset.univ : Finset (Fin dHead)).sum (fun d => (tasks[d.1]'(by simp [tasks, d.isLt])).get) - total.toRat.num + total.num private def qAbsRowChunk {seq dHead : Nat} - (q0 : Fin seq) (qAbs : Fin seq → Fin dHead → Dyadic) (start stop : Nat) : Dyadic := + (q0 : Fin seq) (qAbs : Fin seq → Fin dHead → Rat) (start stop : Nat) : Rat := let chunk : Finset (Fin dHead) := (Finset.univ : Finset (Fin dHead)).filter (fun d => start ≤ d.1 ∧ d.1 < stop) chunk.sum (fun d => qAbs q0 d) private def forceQAbsRowTasksReduce {seq dHead : Nat} - (q0 : Fin seq) (qAbs : Fin seq → Fin dHead → Dyadic) (chunkSize : Nat) : Dyadic := + (q0 : Fin seq) (qAbs : Fin seq → Fin dHead → Rat) (chunkSize : Nat) : Rat := if chunkSize = 0 then - (0 : Dyadic) + (0 : Rat) else let chunks : Nat := (dHead + chunkSize - 1) / chunkSize - let tasks : Array (Task Dyadic) := + let tasks : Array (Task Rat) := Array.ofFn (fun i : Fin chunks => Task.spawn (fun _ => let start := i.1 * chunkSize @@ -128,27 +128,27 @@ private def forceQAbsRowTasksReduce {seq dHead : Nat} (tasks[i.1]'(by simp [tasks, i.isLt])).get) private def forceQAbsAllTasksReduce {seq dHead : Nat} - (qAbs : Fin seq → Fin dHead → Dyadic) (chunkSize : Nat) : Dyadic := - let tasks : Array (Task Dyadic) := + (qAbs : Fin seq → Fin dHead → Rat) (chunkSize : Nat) : Rat := + let tasks : Array (Task Rat) := Array.ofFn (fun q : Fin seq 
=> Task.spawn (fun _ => forceQAbsRowTasksReduce q qAbs chunkSize)) (Finset.univ : Finset (Fin seq)).sum (fun q => (tasks[q.1]'(by simp [tasks, q.isLt])).get) private def forceQAbsActiveTasksReduce {seq dHead : Nat} - (active : Finset (Fin seq)) (qAbs : Fin seq → Fin dHead → Dyadic) (chunkSize : Nat) : Dyadic := + (active : Finset (Fin seq)) (qAbs : Fin seq → Fin dHead → Rat) (chunkSize : Nat) : Rat := if hactive : active.Nonempty then - let tasks : Array (Task Dyadic) := + let tasks : Array (Task Rat) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => if q ∈ active then forceQAbsRowTasksReduce q qAbs chunkSize else - (0 : Dyadic))) + (0 : Rat))) active.sum (fun q => (tasks[q.1]'(by simp [tasks, q.isLt])).get) else - (0 : Dyadic) + (0 : Rat) /-- Run a core benchmark from already-parsed head inputs. -/ def runCoreBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] @@ -159,13 +159,13 @@ def runCoreBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] let _ ← benchPhasePure "lnLo force" (fun () => forceLn lnLo) let _ ← benchPhasePure "lnHi force" (fun () => forceLn lnHi) let qkv ← benchPhasePure "qkv bounds" (fun () => Sound.headQKVBounds inputs lnLo lnHi) - let _ ← timePhase "dyadic sample" (do - IO.println (dyadicSampleReport inputs qkv) - IO.println dyadicSanityReport + let _ ← timePhase "pow2-den sample" (do + IO.println (pow2DenSampleReport inputs qkv) + IO.println pow2DenSanityReport pure ()) let _ ← benchPhasePure "qLo single" (fun () => match h : dHead with - | 0 => (0 : Dyadic) + | 0 => (0 : Rat) | Nat.succ _ => let q0 : Fin seq := ⟨0, Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ @@ -182,7 +182,7 @@ def runCoreBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ let total := (Finset.univ : Finset (Fin dHead)).sum (fun d => qkv.qLo q0 d) - total.toRat.num) + total.num) let _ ← benchPhasePure "qLo force" (fun () => forceQAbs qkv.qLo) let _ ← benchPhasePure "qHi force" (fun () => forceQAbs qkv.qHi) let _ ← 
benchPhasePure "qAbs single" (fun () => @@ -190,7 +190,7 @@ def runCoreBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] ⟨0, by exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ match h : dHead with - | 0 => (0 : Dyadic) + | 0 => (0 : Rat) | Nat.succ _ => let d0 : Fin dHead := ⟨0, by simp [h]⟩ qkv.qAbs q0 d0) diff --git a/Nfp/IO/Bench/Rational.lean b/Nfp/IO/Bench/Rational.lean index c48853f..657d56a 100644 --- a/Nfp/IO/Bench/Rational.lean +++ b/Nfp/IO/Bench/Rational.lean @@ -6,7 +6,7 @@ import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Linear.FinFold /-! -Microbenchmarks for dyadic arithmetic and caching strategies. +Microbenchmarks for rational arithmetic and caching strategies. -/ namespace Nfp @@ -19,28 +19,28 @@ private def benchItersFor (base : Nat) (n : Nat) : Nat := let scale := max 1 (n / 64) max 1 (base / scale) -private def mkDyadic (num den : Nat) (neg : Bool) : Dyadic := +private def mkRat (num den : Nat) (neg : Bool) : Rat := let n : Int := Int.ofNat (num + 1) let d : Int := Int.ofNat (den + 1) let q : Rat := Rat.divInt (if neg then -n else n) d - dyadicOfRatDown q + ratRoundDown q -private def mkVecDyadic (n : Nat) (seed : Nat) (salt : Nat) (negEvery : Nat) : Fin n → Dyadic := fun i => +private def mkVecRat (n : Nat) (seed : Nat) (salt : Nat) (negEvery : Nat) : Fin n → Rat := fun i => let idx := i.1 + seed + salt let neg := (idx % negEvery) = 0 - mkDyadic (idx % 97) (idx % 89) neg + mkRat (idx % 97) (idx % 89) neg private def mkInterval (n : Nat) (seed : Nat) : - (Fin n → Dyadic) × (Fin n → Dyadic) × (Fin n → Dyadic) := - let v : Fin n → Dyadic := mkVecDyadic n seed 0 2 - let base : Fin n → Dyadic := mkVecDyadic n seed 13 3 - let lo : Fin n → Dyadic := fun i => base i - 1 - let hi : Fin n → Dyadic := fun i => base i + 1 + (Fin n → Rat) × (Fin n → Rat) × (Fin n → Rat) := + let v : Fin n → Rat := mkVecRat n seed 0 2 + let base : Fin n → Rat := mkVecRat n seed 13 3 + let lo : Fin n → Rat := fun i => base i - 1 + let hi : Fin n → Rat := fun i => base 
i + 1 (v, lo, hi) -private def benchLoop (label : String) (iters : Nat) (act : Unit → Dyadic) : IO Unit := do +private def benchLoop (label : String) (iters : Nat) (act : Unit → Rat) : IO Unit := do let t0 ← monoUsNow - let mut last : Dyadic := 0 + let mut last : Rat := 0 for _ in List.range iters do last := act () let t1 ← monoUsNow @@ -65,9 +65,9 @@ private def benchDotInterval (n iters seed : Nat) : IO Unit := do Sound.Bounds.dotIntervalUpperCachedRat v lo hi) private def dotIntervalLowerCachedCore {n : Nat} - (vArr loArr hiArr : Array Dyadic) - (hv : vArr.size = n) (hlo : loArr.size = n) (hhi : hiArr.size = n) : Dyadic := - let term : Fin n → Dyadic := fun j => + (vArr loArr hiArr : Array Rat) + (hv : vArr.size = n) (hlo : loArr.size = n) (hhi : hiArr.size = n) : Rat := + let term : Fin n → Rat := fun j => let vj := vArr[j.1]'(by simp [hv, j.isLt]) let loj := loArr[j.1]'(by @@ -81,9 +81,9 @@ private def dotIntervalLowerCachedCore {n : Nat} Sound.Linear.sumFin n term private def dotIntervalUpperCachedCore {n : Nat} - (vArr loArr hiArr : Array Dyadic) - (hv : vArr.size = n) (hlo : loArr.size = n) (hhi : hiArr.size = n) : Dyadic := - let term : Fin n → Dyadic := fun j => + (vArr loArr hiArr : Array Rat) + (hv : vArr.size = n) (hlo : loArr.size = n) (hhi : hiArr.size = n) : Rat := + let term : Fin n → Rat := fun j => let vj := vArr[j.1]'(by simp [hv, j.isLt]) let loj := loArr[j.1]'(by @@ -116,8 +116,8 @@ private def benchDotIntervalCachedParts (n iters seed : Nat) : IO Unit := do dotIntervalUpperCachedCore vArr loArr hiArr hv hlo hhi) private def benchDotFin (n iters seed : Nat) : IO Unit := do - let x : Fin n → Dyadic := mkVecDyadic n seed 7 4 - let y : Fin n → Dyadic := mkVecDyadic n seed 19 5 + let x : Fin n → Rat := mkVecRat n seed 7 4 + let y : Fin n → Rat := mkVecRat n seed 19 5 let labelBase := s!"n={n}" benchLoop s!"dotFin {labelBase}" iters (fun () => Sound.Linear.dotFin n x y) @@ -125,24 +125,24 @@ private def benchDotFin (n iters seed : Nat) : IO Unit := 
do private def headShapeIters (base : Nat) : Nat := max 1 (base / 10) -private def mkHeadAbs (seq dHead : Nat) (seed : Nat) (salt : Nat) : Fin seq → Fin dHead → Dyadic := +private def mkHeadAbs (seq dHead : Nat) (seed : Nat) (salt : Nat) : Fin seq → Fin dHead → Rat := fun q d => - mkDyadic (q.1 * 31 + d.1 + seed + salt) + mkRat (q.1 * 31 + d.1 + seed + salt) (q.1 + d.1 + 7 + seed + salt) (((q.1 + d.1) % 3) = 0) -private def mkHeadVal (seq dHead : Nat) (seed : Nat) (salt : Nat) : Fin seq → Fin dHead → Dyadic := +private def mkHeadVal (seq dHead : Nat) (seed : Nat) (salt : Nat) : Fin seq → Fin dHead → Rat := fun q d => - mkDyadic (q.1 * 17 + d.1 + seed + salt) + mkRat (q.1 * 17 + d.1 + seed + salt) (q.1 + d.1 + 11 + seed + salt) (((q.1 + d.1) % 5) = 0) -private def mkHeadDir (dHead : Nat) (seed : Nat) (salt : Nat) : Fin dHead → Dyadic := fun d => - mkDyadic (d.1 + seed + salt) (d.1 + 3 + seed + salt) ((d.1 % 2) = 0) +private def mkHeadDir (dHead : Nat) (seed : Nat) (salt : Nat) : Fin dHead → Rat := fun d => + mkRat (d.1 + seed + salt) (d.1 + 3 + seed + salt) ((d.1 % 2) = 0) private def benchHeadDotAbs (iters seed : Nat) : IO Unit := do let seq := 8 let dHead := 64 - let qAbs : Fin seq → Fin dHead → Dyadic := mkHeadAbs seq dHead seed 3 - let kAbs : Fin seq → Fin dHead → Dyadic := mkHeadAbs seq dHead seed 19 + let qAbs : Fin seq → Fin dHead → Rat := mkHeadAbs seq dHead seed 3 + let kAbs : Fin seq → Fin dHead → Rat := mkHeadAbs seq dHead seed 19 benchLoop "head dotAbs dotFin" iters (fun () => (List.finRange seq).foldl (fun acc q => (List.finRange seq).foldl (fun acc' k => @@ -151,46 +151,46 @@ private def benchHeadDotAbs (iters seed : Nat) : IO Unit := do private def benchHeadValueBounds (iters seed : Nat) : IO Unit := do let seq := 8 let dHead := 64 - let dirHead : Fin dHead → Dyadic := mkHeadDir dHead seed 5 - let vLo : Fin seq → Fin dHead → Dyadic := mkHeadVal seq dHead seed 11 - let vHi : Fin seq → Fin dHead → Dyadic := mkHeadVal seq dHead seed 23 + let dirHead : 
Fin dHead → Rat := mkHeadDir dHead seed 5 + let vLo : Fin seq → Fin dHead → Rat := mkHeadVal seq dHead seed 11 + let vHi : Fin seq → Fin dHead → Rat := mkHeadVal seq dHead seed 23 let dirArr := Array.ofFn dirHead have hdir : dirArr.size = dHead := by simp [dirArr] let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := by simp [univ] benchLoop "head value bounds (cached)" iters (fun () => - let valsLo : Fin seq → Dyadic := fun k => + let valsLo : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalLowerCachedRat dirHead (vLo k) (vHi k) - let valsHi : Fin seq → Dyadic := fun k => + let valsHi : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalUpperCachedRat dirHead (vLo k) (vHi k) let lo := univ.inf' hnonempty valsLo let hi := univ.sup' hnonempty valsHi lo + hi) benchLoop "head value bounds (common den)" iters (fun () => - let valsLo : Fin seq → Dyadic := fun k => + let valsLo : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k) - let valsHi : Fin seq → Dyadic := fun k => + let valsHi : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k) let lo := univ.inf' hnonempty valsLo let hi := univ.sup' hnonempty valsHi lo + hi) benchLoop "head value bounds (direct)" iters (fun () => - let valsLo : Fin seq → Dyadic := fun k => + let valsLo : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalLower dirHead (vLo k) (vHi k) - let valsHi : Fin seq → Dyadic := fun k => + let valsHi : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalUpper dirHead (vLo k) (vHi k) let lo := univ.inf' hnonempty valsLo let hi := univ.sup' hnonempty valsHi lo + hi) benchLoop "head value bounds (cached, reuse dir)" iters (fun () => - let valsLo : Fin seq → Dyadic := fun k => + let valsLo : Fin seq → Rat := fun k => let loArr := Array.ofFn (vLo k) let hiArr := Array.ofFn (vHi k) have hlo : loArr.size = dHead := by simp [loArr] have hhi : hiArr.size = dHead := by simp [hiArr] dotIntervalLowerCachedCore dirArr 
loArr hiArr hdir hlo hhi - let valsHi : Fin seq → Dyadic := fun k => + let valsHi : Fin seq → Rat := fun k => let loArr := Array.ofFn (vLo k) let hiArr := Array.ofFn (vHi k) have hlo : loArr.size = dHead := by simp [loArr] @@ -200,16 +200,16 @@ private def benchHeadValueBounds (iters seed : Nat) : IO Unit := do let hi := univ.sup' hnonempty valsHi lo + hi) -private def benchDyadicDivInt (iters seed : Nat) : IO Unit := do +private def benchRatDivInt (iters seed : Nat) : IO Unit := do let bigNum : Int := Int.ofNat (2 ^ 200) * Int.ofNat (3 ^ 120) + Int.ofNat (5 ^ 90) + Int.ofNat seed let bigDen : Int := Int.ofNat (2 ^ 150) * Int.ofNat (3 ^ 80) + (Int.ofNat seed) + 1 - benchLoop "dyadicOfRatDown divInt big" iters (fun () => - dyadicOfRatDown (Rat.divInt bigNum bigDen)) + benchLoop "ratRoundDown divInt big" iters (fun () => + ratRoundDown (Rat.divInt bigNum bigDen)) private def forceQkvSumLimited {seq dModel dHead : Nat} - (qkv : Sound.HeadQKVBounds seq dModel dHead) (qLimit dLimit : Nat) : Dyadic := + (qkv : Sound.HeadQKVBounds seq dModel dHead) (qLimit dLimit : Nat) : Rat := let qs := (List.finRange seq).take qLimit let ds := (List.finRange dHead).take dLimit qs.foldl (fun acc q => @@ -221,10 +221,10 @@ private def forceQkvSumLimited {seq dModel dHead : Nat} private def forceQkvSumDirect {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (lnLo lnHi : Fin seq → Fin dModel → Dyadic) + (lnLo lnHi : Fin seq → Fin dModel → Rat) (qLimit dLimit : Nat) - (dotLower dotUpper : (Fin dModel → Dyadic) → (Fin dModel → Dyadic) → (Fin dModel → Dyadic) → Dyadic) : - Dyadic := + (dotLower dotUpper : (Fin dModel → Rat) → (Fin dModel → Rat) → (Fin dModel → Rat) → Rat) : + Rat := let qs := (List.finRange seq).take qLimit let ds := (List.finRange dHead).take dLimit qs.foldl (fun acc q => @@ -257,19 +257,19 @@ private def benchHeadInputs {seq dModel dHead : Nat} [NeZero seq] | none => dHead let skipCache := (← IO.getEnv "NFP_BENCH_SKIP_QKV_CACHE").isSome if 
!skipCache then - IO.println s!"bench: head qkv bounds (cachedDyadic) start q={qLimit} d={dLimit}" + IO.println s!"bench: head qkv bounds (cachedRat) start q={qLimit} d={dLimit}" (← IO.getStdout).flush - let _sumDyadic ← Nfp.IO.timePure "bench: head qkv bounds (cachedDyadic)" (fun () => + let _sumRat ← Nfp.IO.timePure "bench: head qkv bounds (cachedRat)" (fun () => forceQkvSumLimited (Sound.headQKVBounds inputs lnBounds.1 lnBounds.2) qLimit dLimit) pure () - IO.println s!"bench: head qkv bounds (directDyadic) start q={qLimit} d={dLimit}" + IO.println s!"bench: head qkv bounds (directRat) start q={qLimit} d={dLimit}" (← IO.getStdout).flush - let _sumDirectDyadic ← Nfp.IO.timePure "bench: head qkv bounds (directDyadic)" (fun () => + let _sumDirectRat ← Nfp.IO.timePure "bench: head qkv bounds (directRat)" (fun () => forceQkvSumDirect inputs lnBounds.1 lnBounds.2 qLimit dLimit Sound.Bounds.dotIntervalLowerCachedRat Sound.Bounds.dotIntervalUpperCachedRat) - IO.println s!"bench: head qkv bounds (directDyadicNoCache) start q={qLimit} d={dLimit}" + IO.println s!"bench: head qkv bounds (directRatNoCache) start q={qLimit} d={dLimit}" (← IO.getStdout).flush - let _sumDirectDyadicNoCache ← Nfp.IO.timePure "bench: head qkv bounds (directDyadicNoCache)" (fun () => + let _sumDirectRatNoCache ← Nfp.IO.timePure "bench: head qkv bounds (directRatNoCache)" (fun () => forceQkvSumDirect inputs lnBounds.1 lnBounds.2 qLimit dLimit Sound.Bounds.dotIntervalLower Sound.Bounds.dotIntervalUpper) let qkv := Sound.headQKVBounds inputs lnBounds.1 lnBounds.2 @@ -288,29 +288,29 @@ private def benchHeadInputs {seq dModel dHead : Nat} [NeZero seq] let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := by simp [univ] benchLoop "head inputs value bounds (cached)" iters (fun () => - let valsLo : Fin seq → Dyadic := fun k => + let valsLo : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalLowerCachedRat dirHead (qkv.vLo k) (qkv.vHi k) - let valsHi : Fin seq → Dyadic := fun k => + 
let valsHi : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalUpperCachedRat dirHead (qkv.vLo k) (qkv.vHi k) let lo := univ.inf' hnonempty valsLo let hi := univ.sup' hnonempty valsHi lo + hi) benchLoop "head inputs value bounds (direct)" iters (fun () => - let valsLo : Fin seq → Dyadic := fun k => + let valsLo : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalLower dirHead (qkv.vLo k) (qkv.vHi k) - let valsHi : Fin seq → Dyadic := fun k => + let valsHi : Fin seq → Rat := fun k => Sound.Bounds.dotIntervalUpper dirHead (qkv.vLo k) (qkv.vHi k) let lo := univ.inf' hnonempty valsLo let hi := univ.sup' hnonempty valsHi lo + hi) benchLoop "head inputs value bounds (cached, reuse dir)" iters (fun () => - let valsLo : Fin seq → Dyadic := fun k => + let valsLo : Fin seq → Rat := fun k => let loArr := Array.ofFn (qkv.vLo k) let hiArr := Array.ofFn (qkv.vHi k) have hlo : loArr.size = dHead := by simp [loArr] have hhi : hiArr.size = dHead := by simp [hiArr] dotIntervalLowerCachedCore dirArr loArr hiArr hdir hlo hhi - let valsHi : Fin seq → Dyadic := fun k => + let valsHi : Fin seq → Rat := fun k => let loArr := Array.ofFn (qkv.vLo k) let hiArr := Array.ofFn (qkv.vHi k) have hlo : loArr.size = dHead := by simp [loArr] @@ -321,7 +321,7 @@ private def benchHeadInputs {seq dModel dHead : Nat} [NeZero seq] lo + hi) /-- Run rational microbenchmarks for several vector sizes. -/ -def runDyadicBench (seed : Nat) : IO Unit := do +def runRatBench (seed : Nat) : IO Unit := do let baseIters := match (← IO.getEnv "NFP_BENCH_ITERS") with | some raw => raw.toNat?.getD 200 @@ -337,14 +337,14 @@ def runDyadicBench (seed : Nat) : IO Unit := do IO.println s!"bench: start head-shape iters={headIters}" benchHeadDotAbs headIters seed benchHeadValueBounds headIters seed - benchDyadicDivInt headIters seed + benchRatDivInt headIters seed /-- Run benchmarks using a real induction-head input payload. 
-/ -def runDyadicBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] +def runRatBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] (seed : Nat) (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do let skipSynth := (← IO.getEnv "NFP_BENCH_SKIP_SYNTH").isSome if !skipSynth then - runDyadicBench seed + runRatBench seed let baseIters := match (← IO.getEnv "NFP_BENCH_ITERS") with | some raw => raw.toNat?.getD 200 diff --git a/Nfp/IO/Derive.lean b/Nfp/IO/Derive.lean index b18fff1..c092908 100644 --- a/Nfp/IO/Derive.lean +++ b/Nfp/IO/Derive.lean @@ -77,7 +77,7 @@ def deriveResidualIntervalFromModel (data : ByteArray) (start : Nat) | some _ => let forced := (List.finRange header.modelDim).foldl - (fun acc i => acc + bounds.1 i + bounds.2 i) (0 : Dyadic) + (fun acc i => acc + bounds.1 i + bounds.2 i) (0 : Rat) logTiming s!"forced transformer stack sum {forced}" | none => pure () return bounds) @@ -95,7 +95,7 @@ def deriveResidualIntervalFromModel (data : ByteArray) (start : Nat) | some _ => let forced := (List.finRange header.modelDim).foldl - (fun acc i => acc + stack.1 i + stack.2 i) (0 : Dyadic) + (fun acc i => acc + stack.1 i + stack.2 i) (0 : Rat) logTiming s!"forced transformer stack sum {forced}" | none => pure () return stack) @@ -117,7 +117,7 @@ def deriveResidualIntervalFromModel (data : ByteArray) (start : Nat) | some _ => let forced := (List.finRange header.modelDim).foldl - (fun acc i => acc + stack.1 i + stack.2 i) (0 : Dyadic) + (fun acc i => acc + stack.1 i + stack.2 i) (0 : Rat) logTiming s!"forced transformer stack sum {forced}" | none => pure () return stack) diff --git a/Nfp/IO/HeadScore.lean b/Nfp/IO/HeadScore.lean index c7dab88..b5fd6a5 100644 --- a/Nfp/IO/HeadScore.lean +++ b/Nfp/IO/HeadScore.lean @@ -13,13 +13,13 @@ namespace IO /-- Build a cached dot-abs function from Q/K absolute bounds using tasks. 
-/ def dotAbsFromQKV {seq dHead : Nat} - (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : Fin seq → Fin seq → Dyadic := - let rowTasks : Array (Task (Array Dyadic)) := + (qAbs kAbs : Fin seq → Fin dHead → Rat) : Fin seq → Fin seq → Rat := + let rowTasks : Array (Task (Array Rat)) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => Array.ofFn (fun k : Fin seq => Sound.Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)))) - let cache : Array (Array Dyadic) := + let cache : Array (Array Rat) := Array.ofFn (fun q : Fin seq => (rowTasks[q.1]'(by simp [rowTasks, q.isLt])).get) @@ -32,14 +32,14 @@ def dotAbsFromQKV {seq dHead : Nat} simp [hrow, k.isLt]) theorem dotAbsFromQKV_spec {seq dHead : Nat} - (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : + (qAbs kAbs : Fin seq → Fin dHead → Rat) : dotAbsFromQKV qAbs kAbs = - let rowTasks : Array (Task (Array Dyadic)) := + let rowTasks : Array (Task (Array Rat)) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => Array.ofFn (fun k : Fin seq => Sound.Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)))) - let cache : Array (Array Dyadic) := + let cache : Array (Array Rat) := Array.ofFn (fun q : Fin seq => (rowTasks[q.1]'(by simp [rowTasks, q.isLt])).get) diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index 96a52ae..0164c6e 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -9,7 +9,7 @@ import Nfp.Model.InductionPrompt /-! Pure parsing utilities for `NFP_BINARY_V1` model files. -These helpers parse headers and extract selected weight slices as dyadic values. +These helpers parse headers and extract selected weight slices as rational values. -/ namespace Nfp @@ -35,7 +35,7 @@ structure NfptHeader where /-- Sequence length used in the binary. -/ seqLen : Nat /-- LayerNorm epsilon parameter. -/ - layerNormEps : Dyadic + layerNormEps : Rat /-- Array with a fixed size proof. 
-/ structure SizedArray (n : Nat) (α : Type) where @@ -72,7 +72,7 @@ private def parseInt (s : String) : Except String Int := private def pow10 (k : Nat) : Nat := Nat.pow 10 k -private def parseDyadicScientific (s : String) : Except String Dyadic := do +private def parseRatScientific (s : String) : Except String Rat := do let s := s.trim let (sign, rest) := if s.startsWith "-" then @@ -105,13 +105,13 @@ private def parseDyadicScientific (s : String) : Except String Dyadic := do | some e => parseInt e if exp ≥ 0 then let k := Int.toNat exp - pure (dyadicOfRatDown (base * Rat.ofInt (Int.ofNat (pow10 k)))) + pure (ratRoundDown (base * Rat.ofInt (Int.ofNat (pow10 k)))) else let k := Int.toNat (-exp) - pure (dyadicOfRatDown (base / Rat.ofInt (Int.ofNat (pow10 k)))) + pure (ratRoundDown (base / Rat.ofInt (Int.ofNat (pow10 k)))) -private def readHeaderFieldDyadic (names : List String) (fields : List (String × String)) : - Except String Dyadic := do +private def readHeaderFieldRat (names : List String) (fields : List (String × String)) : + Except String Rat := do let rec loop : List String → Option String | [] => none | name :: rest => @@ -119,7 +119,7 @@ private def readHeaderFieldDyadic (names : List String) (fields : List (String | some kv => some kv.2 | none => loop rest match loop names with - | some raw => parseDyadicScientific raw + | some raw => parseRatScientific raw | none => throw s!"missing header field '{String.intercalate "|" names}'" private def sentinelBytes : ByteArray := @@ -169,7 +169,7 @@ def parseHeader (data : ByteArray) : Except String (NfptHeader × Nat) := do let hiddenDim ← readHeaderField "hidden_dim" fields let vocabSize ← readHeaderField "vocab_size" fields let seqLen ← readHeaderField "seq_len" fields - let layerNormEps ← readHeaderFieldDyadic ["layer_norm_eps", "eps"] fields + let layerNormEps ← readHeaderFieldRat ["layer_norm_eps", "eps"] fields if numLayers = 0 then throw "num_layers must be positive" if numHeads = 0 then @@ -200,7 +200,7 @@ 
private def pow2 (k : Nat) : Nat := private def getBits (n hi lo : Nat) : Nat := (n / pow2 lo) % pow2 (hi - lo + 1) -private def dyadicOfFloatBits (bits : Nat) : Option Dyadic := +private def ratOfFloatBits (bits : Nat) : Option Rat := let signBit := getBits bits 63 63 let expBits := getBits bits 62 52 let mantBits := getBits bits 51 0 @@ -212,13 +212,13 @@ private def dyadicOfFloatBits (bits : Nat) : Option Dyadic := some 0 else let num : Int := sign * Int.ofNat mantBits - some (Dyadic.ofIntWithPrec num 1074) + some (ratOfIntWithPrec num 1074) else let mant := mantBits + pow2 52 let exp := expBits - 1023 let shift : Int := Int.ofNat exp - 52 let prec : Int := -shift - some (Dyadic.ofIntWithPrec (sign * Int.ofNat mant) prec) + some (ratOfIntWithPrec (sign * Int.ofNat mant) prec) private def readNatLE (data : ByteArray) (off : Nat) (count : Nat) : Option Nat := if off + count ≤ data.size then @@ -241,9 +241,9 @@ private def readI32 (data : ByteArray) (off : Nat) : Option Int := do else some (Int.ofNat bits - Int.ofNat two32) -private def readF64Dyadic (data : ByteArray) (off : Nat) : Option Dyadic := do +private def readF64Rat (data : ByteArray) (off : Nat) : Option Rat := do let bits ← readNatLE data off 8 - dyadicOfFloatBits bits + ratOfFloatBits bits private def bytesI32 (n : Nat) : Nat := n * 4 @@ -258,13 +258,13 @@ private def sqrtNat? (n : Nat) : Option Nat := else none -private def scaleOfHeadDim (dHead : Nat) : Except String Dyadic := do +private def scaleOfHeadDim (dHead : Nat) : Except String Rat := do match sqrtNat? 
dHead with | some k => if k = 0 then throw "head_dim must be positive" else - pure (dyadicOfRatDown (Rat.ofInt 1 / Rat.ofInt (Int.ofNat k))) + pure (ratRoundDown (Rat.ofInt 1 / Rat.ofInt (Int.ofNat k))) | none => throw "head_dim must be a perfect square to compute scale" @@ -281,15 +281,15 @@ private def matrixIndex {rows cols : Nat} (i : Fin rows) (j : Fin cols) : Fin (r ⟨idx, lt_of_lt_of_le hstep hle⟩ private def readF64ListAux (data : ByteArray) (off : Nat) : - Nat → List Dyadic → Except String (List Dyadic) + Nat → List Rat → Except String (List Rat) | 0, acc => Except.ok acc.reverse | Nat.succ n, acc => - match readF64Dyadic data off with + match readF64Rat data off with | some v => readF64ListAux data (off + bytesF64 1) n (v :: acc) | none => Except.error s!"invalid f64 at offset {off}" private theorem readF64ListAux_length (data : ByteArray) : - ∀ (off n : Nat) (acc xs : List Dyadic), + ∀ (off n : Nat) (acc xs : List Rat), readF64ListAux data off n acc = Except.ok xs → xs.length = acc.length + n := by intro off n acc xs h @@ -300,7 +300,7 @@ private theorem readF64ListAux_length (data : ByteArray) : cases h' simp | succ n ih => - cases hread : readF64Dyadic data off with + cases hread : readF64Rat data off with | none => have h' := h simp only [readF64ListAux, hread] at h' @@ -312,7 +312,7 @@ private theorem readF64ListAux_length (data : ByteArray) : simpa [List.length, Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hlen private def readF64List (data : ByteArray) (off : Nat) (count : Nat) : - Except String {xs : List Dyadic // xs.length = count} := + Except String {xs : List Rat // xs.length = count} := match h : readF64ListAux data off count [] with | Except.error msg => Except.error msg | Except.ok xs => @@ -365,23 +365,23 @@ private def readI32List (data : ByteArray) (off : Nat) (count : Nat) : Except.ok ⟨xs, hlen⟩ private def readF64Matrix (data : ByteArray) (off : Nat) (rows cols : Nat) : - Except String (Fin rows → Fin cols → Dyadic) := do + 
Except String (Fin rows → Fin cols → Rat) := do let count := rows * cols let ⟨vals, hlen⟩ ← readF64List data off count let hlen' : vals.length = rows * cols := by simpa using hlen - let mat : Fin rows → Fin cols → Dyadic := fun i j => + let mat : Fin rows → Fin cols → Rat := fun i j => let idx := matrixIndex i j let hidx : idx.val < vals.length := lt_of_lt_of_eq idx.isLt hlen'.symm vals.get ⟨idx.val, hidx⟩ return mat private def readF64Vec (data : ByteArray) (off : Nat) (count : Nat) : - Except String (Fin count → Dyadic) := do + Except String (Fin count → Rat) := do let ⟨vals, hlen⟩ ← readF64List data off count let hlen' : vals.length = count := by simpa using hlen - let vec : Fin count → Dyadic := fun i => + let vec : Fin count → Rat := fun i => vals.get ⟨i.val, lt_of_lt_of_eq i.isLt hlen'.symm⟩ return vec @@ -411,7 +411,7 @@ private def finalLayerNormOffset (h : NfptHeader) : Nat := /-- Read input embeddings stored in the binary. -/ def readEmbeddings (data : ByteArray) (start : Nat) (h : NfptHeader) : - Except String (Fin h.seqLen → Fin h.modelDim → Dyadic) := do + Except String (Fin h.seqLen → Fin h.modelDim → Rat) := do let base := start + bytesI32 h.seqLen readF64Matrix data base h.seqLen h.modelDim @@ -460,7 +460,7 @@ def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) let bv ← readF64Vec data offbv h.headDim let offwo := offbv + bytesF64 h.headDim let woRaw ← readF64Matrix data offwo h.headDim h.modelDim - let wo : Fin h.modelDim → Fin h.headDim → Dyadic := fun i j => woRaw j i + let wo : Fin h.modelDim → Fin h.headDim → Rat := fun i j => woRaw j i return { wq := wq, bq := bq, wk := wk, bk := bk, wv := wv, bv := bv, wo := wo } else throw s!"head index out of range: {head}" @@ -469,8 +469,8 @@ def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) private def readLayerAttnBiasLn1 (data : ByteArray) (start : Nat) (h : NfptHeader) (layer : Nat) : - Except String ((Fin h.modelDim → Dyadic) × (Fin h.modelDim → Dyadic) × - (Fin 
h.modelDim → Dyadic)) := do + Except String ((Fin h.modelDim → Rat) × (Fin h.modelDim → Rat) × + (Fin h.modelDim → Rat)) := do if layer < h.numLayers then let base := start + layerExtrasOffset h layer let attnBias ← readF64Vec data base h.modelDim @@ -570,17 +570,17 @@ def readFinalLayerNorm (data : ByteArray) (start : Nat) (h : NfptHeader) : /-- Read a single unembedding column as exact rationals. -/ def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : Nat) : - Except String (Fin h.modelDim → Dyadic) := do + Except String (Fin h.modelDim → Rat) := do if col < h.vocabSize then let base := start + unembedOffset h let rows := List.range h.modelDim let vals ← rows.mapM (fun row => do let off := base + bytesF64 (row * h.vocabSize + col) - match readF64Dyadic data off with + match readF64Rat data off with | some v => pure v | none => throw s!"invalid f64 at offset {off}") if hlen : vals.length = h.modelDim then - let vec : Fin h.modelDim → Dyadic := fun i => + let vec : Fin h.modelDim → Rat := fun i => vals.get ⟨i.val, by simp [hlen]⟩ return vec else @@ -599,7 +599,7 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) let (attnBias, ln1Gamma, ln1Beta) ← readLayerAttnBiasLn1 data start h layer let colTarget ← readUnembedColumn data start h dirTarget let colNegative ← readUnembedColumn data start h dirNegative - let direction : Fin h.modelDim → Dyadic := fun i => colTarget i - colNegative i + let direction : Fin h.modelDim → Rat := fun i => colTarget i - colNegative i let directionSpec : Circuit.DirectionSpec := { target := dirTarget, negative := dirNegative } let active := @@ -627,7 +627,7 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) wo := weights.wo attnBias := attnBias maskCausal := true - maskValue := (-10000 : Dyadic) + maskValue := (-10000 : Rat) directionSpec := directionSpec direction := direction } diff --git a/Nfp/IO/Pure/Basic.lean b/Nfp/IO/Pure/Basic.lean index 
d0622cd..48481e9 100644 --- a/Nfp/IO/Pure/Basic.lean +++ b/Nfp/IO/Pure/Basic.lean @@ -53,18 +53,18 @@ def parseInt (s : String) : Except String Int := do let n ← parseNat s return Int.ofNat n -/-- Parse a dyadic literal from `a` or `a/b`, rounding down if needed. -/ -def parseDyadic (s : String) : Except String Dyadic := do +/-- Parse a rational literal from `a` or `a/b`, rounding down if needed. -/ +def parseRat (s : String) : Except String Rat := do match s.splitOn "/" with | [num] => - return dyadicOfRatDown (Rat.ofInt (← parseInt num)) + return ratRoundDown (Rat.ofInt (← parseInt num)) | [num, den] => let n ← parseInt num let d ← parseNat den if d = 0 then throw s!"invalid rational '{s}': zero denominator" else - return dyadicOfRatDown (Rat.divInt n (Int.ofNat d)) + return ratRoundDown (Rat.divInt n (Int.ofNat d)) | _ => throw s!"invalid rational '{s}'" diff --git a/Nfp/IO/Pure/Downstream.lean b/Nfp/IO/Pure/Downstream.lean index 82aa903..353d3ca 100644 --- a/Nfp/IO/Pure/Downstream.lean +++ b/Nfp/IO/Pure/Downstream.lean @@ -16,9 +16,9 @@ namespace Pure open Nfp.Circuit private structure DownstreamLinearParseState where - error : Option Dyadic - gain : Option Dyadic - inputBound : Option Dyadic + error : Option Rat + gain : Option Rat + inputBound : Option Rat private def initDownstreamLinearState : DownstreamLinearParseState := { error := none, gain := none, inputBound := none } @@ -30,17 +30,17 @@ private def parseDownstreamLinearLine (st : DownstreamLinearParseState) if st.error.isSome then throw "duplicate error entry" else - return { st with error := some (← parseDyadic val) } + return { st with error := some (← parseRat val) } | ["gain", val] => if st.gain.isSome then throw "duplicate gain entry" else - return { st with gain := some (← parseDyadic val) } + return { st with gain := some (← parseRat val) } | ["input-bound", val] => if st.inputBound.isSome then throw "duplicate input-bound entry" else - return { st with inputBound := some (← parseDyadic val) } 
+ return { st with inputBound := some (← parseRat val) } | _ => throw s!"unrecognized line: '{String.intercalate " " tokens}'" @@ -87,20 +87,20 @@ private def matAllSome {α : Type} (mat : Array (Array (Option α))) : Bool := /-- Raw downstream matrix payload with an input bound. -/ structure DownstreamMatrixRaw (rows cols : Nat) where /-- Input magnitude bound. -/ - inputBound : Dyadic + inputBound : Rat /-- Matrix entries. -/ - entries : Fin rows → Fin cols → Dyadic + entries : Fin rows → Fin cols → Rat private structure DownstreamMatrixParseState (rows cols : Nat) where - inputBound : Option Dyadic - entries : Fin rows → Fin cols → Option Dyadic + inputBound : Option Rat + entries : Fin rows → Fin cols → Option Rat private def initDownstreamMatrixState (rows cols : Nat) : DownstreamMatrixParseState rows cols := { inputBound := none, entries := fun _ _ => none } -private def setRectEntry {rows cols : Nat} (mat : Fin rows → Fin cols → Option Dyadic) - (i j : Nat) (v : Dyadic) : Except String (Fin rows → Fin cols → Option Dyadic) := do +private def setRectEntry {rows cols : Nat} (mat : Fin rows → Fin cols → Option Rat) + (i j : Nat) (v : Rat) : Except String (Fin rows → Fin cols → Option Rat) := do if hi : i < rows then if hj : j < cols then let iFin : Fin rows := ⟨i, hi⟩ @@ -109,7 +109,7 @@ private def setRectEntry {rows cols : Nat} (mat : Fin rows → Fin cols → Opti | some _ => throw s!"duplicate matrix entry at ({i}, {j})" | none => - let mat' : Fin rows → Fin cols → Option Dyadic := fun i' j' => + let mat' : Fin rows → Fin cols → Option Rat := fun i' j' => if i' = iFin then if j' = jFin then some v @@ -131,9 +131,9 @@ private def parseDownstreamMatrixLine {rows cols : Nat} if st.inputBound.isSome then throw "duplicate input-bound entry" else - return { st with inputBound := some (← parseDyadic val) } + return { st with inputBound := some (← parseRat val) } | ["w", i, j, val] => - let mat ← setRectEntry st.entries (← parseNat i) (← parseNat j) (← parseDyadic val) + 
let mat ← setRectEntry st.entries (← parseNat i) (← parseNat j) (← parseRat val) return { st with entries := mat } | _ => throw s!"unrecognized line: '{String.intercalate " " tokens}'" @@ -148,7 +148,7 @@ private def finalizeDownstreamMatrixState {rows cols : Nat} if !finsetAll (Finset.univ : Finset (Fin rows)) (fun i => finsetAll (Finset.univ : Finset (Fin cols)) (fun j => (st.entries i j).isSome)) then throw "missing matrix entries" - let entries : Fin rows → Fin cols → Dyadic := fun i j => + let entries : Fin rows → Fin cols → Rat := fun i j => (st.entries i j).getD 0 return { inputBound := inputBound, entries := entries } diff --git a/Nfp/IO/Pure/InductionHead/Bytes.lean b/Nfp/IO/Pure/InductionHead/Bytes.lean index c4f188d..2333f70 100644 --- a/Nfp/IO/Pure/InductionHead/Bytes.lean +++ b/Nfp/IO/Pure/InductionHead/Bytes.lean @@ -104,10 +104,10 @@ private def findSlash (data : ByteArray) (i stop : Nat) : Option Nat := none termination_by stop - i -private def parseDyadicBytesSpec (data : ByteArray) (t : ByteToken) : Except String Dyadic := do +private def parseRatBytesSpec (data : ByteArray) (t : ByteToken) : Except String Rat := do match findSlash data t.start t.stop with | none => - return dyadicOfRatDown (Rat.ofInt (← parseIntBytesSpec data t)) + return ratRoundDown (Rat.ofInt (← parseIntBytesSpec data t)) | some s => let numTok : ByteToken := { start := t.start, stop := s } let denTok : ByteToken := { start := s + 1, stop := t.stop } @@ -116,13 +116,13 @@ private def parseDyadicBytesSpec (data : ByteArray) (t : ByteToken) : Except Str if d = 0 then throw "invalid rational: zero denominator" else - return dyadicOfRatDown (Rat.divInt n (Int.ofNat d)) + return ratRoundDown (Rat.divInt n (Int.ofNat d)) -private def parseDyadicBytes (data : ByteArray) (t : ByteToken) : Except String Dyadic := - parseDyadicBytesSpec data t +private def parseRatBytes (data : ByteArray) (t : ByteToken) : Except String Rat := + parseRatBytesSpec data t -theorem parseDyadicBytes_eq_spec 
(data : ByteArray) (t : ByteToken) : - parseDyadicBytes data t = parseDyadicBytesSpec data t := by +theorem parseRatBytes_eq_spec (data : ByteArray) (t : ByteToken) : + parseRatBytes data t = parseRatBytesSpec data t := by rfl private def nextLineBounds (data : ByteArray) (start : Nat) : Nat × Nat × Nat := @@ -184,15 +184,15 @@ private def parseNatAt (data : ByteArray) (i lineEnd : Nat) : let n ← parseNatBytes data tok return (n, i') -private def parseDyadicAt (data : ByteArray) (i lineEnd : Nat) : - Except String (Dyadic × Nat) := do +private def parseRatAt (data : ByteArray) (i lineEnd : Nat) : + Except String (Rat × Nat) := do let (tok, i') ← expectToken data i lineEnd - let r ← parseDyadicBytes data tok + let r ← parseRatBytes data tok return (r, i') -private def setVecEntry (n : Nat) (vec : Array (Option Dyadic)) - (i : Nat) (v : Dyadic) : - Except String (Array (Option Dyadic)) := do +private def setVecEntry (n : Nat) (vec : Array (Option Rat)) + (i : Nat) (v : Rat) : + Except String (Array (Option Rat)) := do if i < n then match vec.getD i none with | some _ => @@ -203,8 +203,8 @@ private def setVecEntry (n : Nat) (vec : Array (Option Dyadic)) else throw s!"index out of range: i={i}" -private def setMatEntry (rows cols : Nat) (mat : Array (Array (Option Dyadic))) - (i j : Nat) (v : Dyadic) : Except String (Array (Array (Option Dyadic))) := do +private def setMatEntry (rows cols : Nat) (mat : Array (Array (Option Rat))) + (i j : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do if i < rows then if j < cols then let row := mat.getD i #[] @@ -220,10 +220,10 @@ private def setMatEntry (rows cols : Nat) (mat : Array (Array (Option Dyadic))) else throw s!"index out of range: i={i}" -private def initVecOpt (n : Nat) : Array (Option Dyadic) := +private def initVecOpt (n : Nat) : Array (Option Rat) := Array.replicate n none -private def initMatOpt (rows cols : Nat) : Array (Array (Option Dyadic)) := +private def initMatOpt (rows cols : Nat) : Array 
(Array (Option Rat)) := Array.replicate rows (initVecOpt cols) private def initPrevOpt (n : Nat) : Array (Option (Fin n)) := @@ -242,27 +242,27 @@ private def matAllSome {α : Type} (mat : Array (Array (Option α))) : Bool := (List.range mat.size).all (fun i => arrayAllSome (mat.getD i #[])) private structure HeadParseState (seq dModel dHead : Nat) where - scale : Option Dyadic + scale : Option Rat activeBits : Array Bool activeSeen : Bool prev : Array (Option (Fin seq)) - embed : Array (Array (Option Dyadic)) - lnEps : Option Dyadic - ln1Gamma : Array (Option Dyadic) - ln1Beta : Array (Option Dyadic) - wq : Array (Array (Option Dyadic)) - bq : Array (Option Dyadic) - wk : Array (Array (Option Dyadic)) - bk : Array (Option Dyadic) - wv : Array (Array (Option Dyadic)) - bv : Array (Option Dyadic) - wo : Array (Array (Option Dyadic)) - attnBias : Array (Option Dyadic) + embed : Array (Array (Option Rat)) + lnEps : Option Rat + ln1Gamma : Array (Option Rat) + ln1Beta : Array (Option Rat) + wq : Array (Array (Option Rat)) + bq : Array (Option Rat) + wk : Array (Array (Option Rat)) + bk : Array (Option Rat) + wv : Array (Array (Option Rat)) + bv : Array (Option Rat) + wo : Array (Array (Option Rat)) + attnBias : Array (Option Rat) maskCausal : Option Bool - maskValue : Option Dyadic + maskValue : Option Rat directionTarget : Option Nat directionNegative : Option Nat - direction : Array (Option Dyadic) + direction : Array (Option Rat) private def initHeadState (seq dModel dHead : Nat) : HeadParseState seq dModel dHead := { scale := none @@ -318,53 +318,53 @@ private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dMod if st.scale.isSome then throw "duplicate scale entry" else - return { st with scale := some (← parseDyadic val) } + return { st with scale := some (← parseRat val) } | ["active", q] => setHeadActive st (← parseNat q) | ["prev", q, k] => setHeadPrev st (← parseNat q) (← parseNat k) | ["embed", q, d, val] => let mat ← - setMatEntry seq 
dModel st.embed (← parseNat q) (← parseNat d) (← parseDyadic val) + setMatEntry seq dModel st.embed (← parseNat q) (← parseNat d) (← parseRat val) return { st with embed := mat } | ["ln_eps", val] => if st.lnEps.isSome then throw "duplicate ln_eps entry" else - return { st with lnEps := some (← parseDyadic val) } + return { st with lnEps := some (← parseRat val) } | ["ln1_gamma", d, val] => - let vec ← setVecEntry dModel st.ln1Gamma (← parseNat d) (← parseDyadic val) + let vec ← setVecEntry dModel st.ln1Gamma (← parseNat d) (← parseRat val) return { st with ln1Gamma := vec } | ["ln1_beta", d, val] => - let vec ← setVecEntry dModel st.ln1Beta (← parseNat d) (← parseDyadic val) + let vec ← setVecEntry dModel st.ln1Beta (← parseNat d) (← parseRat val) return { st with ln1Beta := vec } | ["wq", i, j, val] => let mat ← - setMatEntry dModel dHead st.wq (← parseNat i) (← parseNat j) (← parseDyadic val) + setMatEntry dModel dHead st.wq (← parseNat i) (← parseNat j) (← parseRat val) return { st with wq := mat } | ["bq", j, val] => - let vec ← setVecEntry dHead st.bq (← parseNat j) (← parseDyadic val) + let vec ← setVecEntry dHead st.bq (← parseNat j) (← parseRat val) return { st with bq := vec } | ["wk", i, j, val] => let mat ← - setMatEntry dModel dHead st.wk (← parseNat i) (← parseNat j) (← parseDyadic val) + setMatEntry dModel dHead st.wk (← parseNat i) (← parseNat j) (← parseRat val) return { st with wk := mat } | ["bk", j, val] => - let vec ← setVecEntry dHead st.bk (← parseNat j) (← parseDyadic val) + let vec ← setVecEntry dHead st.bk (← parseNat j) (← parseRat val) return { st with bk := vec } | ["wv", i, j, val] => let mat ← - setMatEntry dModel dHead st.wv (← parseNat i) (← parseNat j) (← parseDyadic val) + setMatEntry dModel dHead st.wv (← parseNat i) (← parseNat j) (← parseRat val) return { st with wv := mat } | ["bv", j, val] => - let vec ← setVecEntry dHead st.bv (← parseNat j) (← parseDyadic val) + let vec ← setVecEntry dHead st.bv (← parseNat j) (← parseRat 
val) return { st with bv := vec } | ["wo", i, j, val] => let mat ← - setMatEntry dModel dHead st.wo (← parseNat i) (← parseNat j) (← parseDyadic val) + setMatEntry dModel dHead st.wo (← parseNat i) (← parseNat j) (← parseRat val) return { st with wo := mat } | ["attn_bias", d, val] => - let vec ← setVecEntry dModel st.attnBias (← parseNat d) (← parseDyadic val) + let vec ← setVecEntry dModel st.attnBias (← parseNat d) (← parseRat val) return { st with attnBias := vec } | ["mask", kind] => if st.maskCausal.isSome then @@ -378,7 +378,7 @@ private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dMod if st.maskValue.isSome then throw "duplicate mask_value entry" else - return { st with maskValue := some (← parseDyadic val) } + return { st with maskValue := some (← parseRat val) } | ["direction-target", tok] => if st.directionTarget.isSome then throw "duplicate direction-target entry" @@ -390,7 +390,7 @@ private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dMod else return { st with directionNegative := some (← parseNat tok) } | ["direction", d, val] => - let vec ← setVecEntry dModel st.direction (← parseNat d) (← parseDyadic val) + let vec ← setVecEntry dModel st.direction (← parseNat d) (← parseRat val) return { st with direction := vec } | _ => throw s!"unrecognized line: '{String.intercalate " " tokens}'" @@ -418,7 +418,7 @@ private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) else let (t1, i2) ← expectToken data i1 lineEnd ensureNoMoreTokens data i2 lineEnd - return { st with scale := some (← parseDyadicBytes data t1) } + return { st with scale := some (← parseRatBytes data t1) } else throw "unrecognized line" | 97 => -- a @@ -428,7 +428,7 @@ private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) setHeadActive st q else if len = kwAttnBias.size && tokenEq data t0 kwAttnBias then let (d, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseDyadicAt data i2 lineEnd + let (v, i3) ← 
parseRatAt data i2 lineEnd ensureNoMoreTokens data i3 lineEnd let vec ← setVecEntry dModel st.attnBias d v return { st with attnBias := vec } @@ -446,7 +446,7 @@ private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) if len = kwEmbed.size && tokenEq data t0 kwEmbed then let (q, i2) ← parseNatAt data i1 lineEnd let (d, i3) ← parseNatAt data i2 lineEnd - let (v, i4) ← parseDyadicAt data i3 lineEnd + let (v, i4) ← parseRatAt data i3 lineEnd ensureNoMoreTokens data i4 lineEnd let mat ← setMatEntry seq dModel st.embed q d v return { st with embed := mat } @@ -457,18 +457,18 @@ private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) if st.lnEps.isSome then throw "duplicate ln_eps entry" else - let (v, i2) ← parseDyadicAt data i1 lineEnd + let (v, i2) ← parseRatAt data i1 lineEnd ensureNoMoreTokens data i2 lineEnd return { st with lnEps := some v } else if len = kwLn1Gamma.size && tokenEq data t0 kwLn1Gamma then let (d, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseDyadicAt data i2 lineEnd + let (v, i3) ← parseRatAt data i2 lineEnd ensureNoMoreTokens data i3 lineEnd let vec ← setVecEntry dModel st.ln1Gamma d v return { st with ln1Gamma := vec } else if len = kwLn1Beta.size && tokenEq data t0 kwLn1Beta then let (d, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseDyadicAt data i2 lineEnd + let (v, i3) ← parseRatAt data i2 lineEnd ensureNoMoreTokens data i3 lineEnd let vec ← setVecEntry dModel st.ln1Beta d v return { st with ln1Beta := vec } @@ -479,7 +479,7 @@ private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) let b1 := data.get! 
(t0.start + 1) let (i, i2) ← parseNatAt data i1 lineEnd let (j, i3) ← parseNatAt data i2 lineEnd - let (v, i4) ← parseDyadicAt data i3 lineEnd + let (v, i4) ← parseRatAt data i3 lineEnd ensureNoMoreTokens data i4 lineEnd if b1 = 113 then let mat ← setMatEntry dModel dHead st.wq i j v @@ -501,7 +501,7 @@ private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) if len = 2 then let b1 := data.get! (t0.start + 1) let (j, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseDyadicAt data i2 lineEnd + let (v, i3) ← parseRatAt data i2 lineEnd ensureNoMoreTokens data i3 lineEnd if b1 = 113 then let vec ← setVecEntry dHead st.bq j v @@ -533,7 +533,7 @@ private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) if st.maskValue.isSome then throw "duplicate mask_value entry" else - let (v, i2) ← parseDyadicAt data i1 lineEnd + let (v, i2) ← parseRatAt data i1 lineEnd ensureNoMoreTokens data i2 lineEnd return { st with maskValue := some v } else @@ -545,7 +545,7 @@ private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) return st else if len = kwDirection.size && tokenEq data t0 kwDirection then let (d, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseDyadicAt data i2 lineEnd + let (v, i3) ← parseRatAt data i2 lineEnd ensureNoMoreTokens data i3 lineEnd let vec ← setVecEntry dModel st.direction d v return { st with direction := vec } @@ -648,58 +648,58 @@ private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) let defaultPrev : Fin seq := ⟨0, hpos⟩ let prevFun : Fin seq → Fin seq := fun q => (st.prev.getD q.1 none).getD defaultPrev - let embedArr : Array (Array Dyadic) := + let embedArr : Array (Array Rat) := st.embed.map (fun row => row.map (fun v => v.getD 0)) - let ln1GammaArr : Array Dyadic := + let ln1GammaArr : Array Rat := st.ln1Gamma.map (fun v => v.getD 0) - let ln1BetaArr : Array Dyadic := + let ln1BetaArr : Array Rat := st.ln1Beta.map (fun v => v.getD 0) - let wqArr : Array (Array Dyadic) := + 
let wqArr : Array (Array Rat) := st.wq.map (fun row => row.map (fun v => v.getD 0)) - let bqArr : Array Dyadic := + let bqArr : Array Rat := st.bq.map (fun v => v.getD 0) - let wkArr : Array (Array Dyadic) := + let wkArr : Array (Array Rat) := st.wk.map (fun row => row.map (fun v => v.getD 0)) - let bkArr : Array Dyadic := + let bkArr : Array Rat := st.bk.map (fun v => v.getD 0) - let wvArr : Array (Array Dyadic) := + let wvArr : Array (Array Rat) := st.wv.map (fun row => row.map (fun v => v.getD 0)) - let bvArr : Array Dyadic := + let bvArr : Array Rat := st.bv.map (fun v => v.getD 0) - let woArr : Array (Array Dyadic) := + let woArr : Array (Array Rat) := st.wo.map (fun row => row.map (fun v => v.getD 0)) - let attnBiasArr : Array Dyadic := + let attnBiasArr : Array Rat := st.attnBias.map (fun v => v.getD 0) - let directionArr : Array Dyadic := + let directionArr : Array Rat := st.direction.map (fun v => v.getD 0) - let embedFun : Fin seq → Fin dModel → Dyadic := fun q d => + let embedFun : Fin seq → Fin dModel → Rat := fun q d => (embedArr.getD q.1 #[]).getD d.1 0 - let ln1GammaFun : Fin dModel → Dyadic := fun d => + let ln1GammaFun : Fin dModel → Rat := fun d => ln1GammaArr.getD d.1 0 - let ln1BetaFun : Fin dModel → Dyadic := fun d => + let ln1BetaFun : Fin dModel → Rat := fun d => ln1BetaArr.getD d.1 0 - let wqFun : Fin dModel → Fin dHead → Dyadic := fun i j => + let wqFun : Fin dModel → Fin dHead → Rat := fun i j => (wqArr.getD i.1 #[]).getD j.1 0 - let bqFun : Fin dHead → Dyadic := fun j => + let bqFun : Fin dHead → Rat := fun j => bqArr.getD j.1 0 - let wkFun : Fin dModel → Fin dHead → Dyadic := fun i j => + let wkFun : Fin dModel → Fin dHead → Rat := fun i j => (wkArr.getD i.1 #[]).getD j.1 0 - let bkFun : Fin dHead → Dyadic := fun j => + let bkFun : Fin dHead → Rat := fun j => bkArr.getD j.1 0 - let wvFun : Fin dModel → Fin dHead → Dyadic := fun i j => + let wvFun : Fin dModel → Fin dHead → Rat := fun i j => (wvArr.getD i.1 #[]).getD j.1 0 - let bvFun : 
Fin dHead → Dyadic := fun j => + let bvFun : Fin dHead → Rat := fun j => bvArr.getD j.1 0 - let woFun : Fin dModel → Fin dHead → Dyadic := fun i j => + let woFun : Fin dModel → Fin dHead → Rat := fun i j => (woArr.getD i.1 #[]).getD j.1 0 - let attnBiasFun : Fin dModel → Dyadic := fun d => + let attnBiasFun : Fin dModel → Rat := fun d => attnBiasArr.getD d.1 0 let maskCausal := st.maskCausal.getD false let maskValue := match st.maskValue with | some v => v - | none => if maskCausal then (-10000 : Dyadic) else 0 - let directionFun : Fin dModel → Dyadic := fun d => + | none => if maskCausal then (-10000 : Rat) else 0 + let directionFun : Fin dModel → Rat := fun d => directionArr.getD d.1 0 let active := if st.activeSeen then diff --git a/Nfp/IO/Pure/Residual.lean b/Nfp/IO/Pure/Residual.lean index 72fb5c6..d02ef07 100644 --- a/Nfp/IO/Pure/Residual.lean +++ b/Nfp/IO/Pure/Residual.lean @@ -17,20 +17,20 @@ namespace Pure open Nfp.Circuit private structure ResidualBoundParseState (n : Nat) where - bounds : Fin n → Option Dyadic + bounds : Fin n → Option Rat private def initResidualBoundState (n : Nat) : ResidualBoundParseState n := { bounds := fun _ => none } -private def setVectorEntry {n : Nat} (bounds : Fin n → Option Dyadic) - (i : Nat) (v : Dyadic) : Except String (Fin n → Option Dyadic) := do +private def setVectorEntry {n : Nat} (bounds : Fin n → Option Rat) + (i : Nat) (v : Rat) : Except String (Fin n → Option Rat) := do if hi : i < n then let iFin : Fin n := ⟨i, hi⟩ match bounds iFin with | some _ => throw s!"duplicate bound entry at index {i}" | none => - let bounds' : Fin n → Option Dyadic := fun i' => + let bounds' : Fin n → Option Rat := fun i' => if i' = iFin then some v else @@ -43,7 +43,7 @@ private def parseResidualBoundLine {n : Nat} (st : ResidualBoundParseState n) (tokens : List String) : Except String (ResidualBoundParseState n) := do match tokens with | ["bound", i, val] => - let bounds ← setVectorEntry st.bounds (← parseNat i) (← parseDyadic val) + 
let bounds ← setVectorEntry st.bounds (← parseNat i) (← parseRat val) return { st with bounds := bounds } | ["dim", _] => throw "duplicate dim entry" @@ -54,7 +54,7 @@ private def finalizeResidualBoundState {n : Nat} (st : ResidualBoundParseState n Except String (Circuit.ResidualBoundCert n) := do if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.bounds i).isSome) then throw "missing bound entries" - let bound : Fin n → Dyadic := fun i => + let bound : Fin n → Rat := fun i => (st.bounds i).getD 0 return { bound := bound } @@ -78,8 +78,8 @@ def parseResidualBoundCert (input : String) : | _ => throw "expected header 'dim '" private structure ResidualIntervalParseState (n : Nat) where - lo : Fin n → Option Dyadic - hi : Fin n → Option Dyadic + lo : Fin n → Option Rat + hi : Fin n → Option Rat private def initResidualIntervalState (n : Nat) : ResidualIntervalParseState n := { lo := fun _ => none, hi := fun _ => none } @@ -88,10 +88,10 @@ private def parseResidualIntervalLine {n : Nat} (st : ResidualIntervalParseState (tokens : List String) : Except String (ResidualIntervalParseState n) := do match tokens with | ["lo", i, val] => - let lo ← setVectorEntry st.lo (← parseNat i) (← parseDyadic val) + let lo ← setVectorEntry st.lo (← parseNat i) (← parseRat val) return { st with lo := lo } | ["hi", i, val] => - let hi ← setVectorEntry st.hi (← parseNat i) (← parseDyadic val) + let hi ← setVectorEntry st.hi (← parseNat i) (← parseRat val) return { st with hi := hi } | ["dim", _] => throw "duplicate dim entry" @@ -104,9 +104,9 @@ private def finalizeResidualIntervalState {n : Nat} (st : ResidualIntervalParseS throw "missing lo entries" if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.hi i).isSome) then throw "missing hi entries" - let lo : Fin n → Dyadic := fun i => + let lo : Fin n → Rat := fun i => (st.lo i).getD 0 - let hi : Fin n → Dyadic := fun i => + let hi : Fin n → Rat := fun i => (st.hi i).getD 0 return { lo := lo, hi := hi } diff --git 
a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean b/Nfp/IO/Pure/SoftmaxMargin/Cert.lean index 29bd6aa..e43e378 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Cert.lean @@ -36,9 +36,9 @@ private def finalizeState {seq : Nat} (hpos : 0 < seq) let defaultPrev : Fin seq := ⟨0, hpos⟩ let prevFun : Fin seq → Fin seq := fun q => (st.prev q).getD defaultPrev - let scoresFun : Fin seq → Fin seq → Dyadic := fun q k => + let scoresFun : Fin seq → Fin seq → Rat := fun q k => (st.scores q k).getD 0 - let weightsFun : Fin seq → Fin seq → Dyadic := fun q k => + let weightsFun : Fin seq → Fin seq → Rat := fun q k => (st.weights q k).getD 0 let active := if st.activeSeen then diff --git a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean b/Nfp/IO/Pure/SoftmaxMargin/Raw.lean index 005d024..35d787f 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Raw.lean @@ -22,9 +22,9 @@ structure SoftmaxMarginRaw (seq : Nat) where /-- `prev` selector for induction-style attention. -/ prev : Fin seq → Fin seq /-- Score matrix entries. -/ - scores : Fin seq → Fin seq → Dyadic + scores : Fin seq → Fin seq → Rat /-- Attention weight entries. 
-/ - weights : Fin seq → Fin seq → Dyadic + weights : Fin seq → Fin seq → Rat private def finalizeRawState {seq : Nat} (hpos : 0 < seq) (st : SoftmaxMargin.ParseState seq) : Except String (SoftmaxMarginRaw seq) := do @@ -39,9 +39,9 @@ private def finalizeRawState {seq : Nat} (hpos : 0 < seq) let defaultPrev : Fin seq := ⟨0, hpos⟩ let prevFun : Fin seq → Fin seq := fun q => (st.prev q).getD defaultPrev - let scoresFun : Fin seq → Fin seq → Dyadic := fun q k => + let scoresFun : Fin seq → Fin seq → Rat := fun q k => (st.scores q k).getD 0 - let weightsFun : Fin seq → Fin seq → Dyadic := fun q k => + let weightsFun : Fin seq → Fin seq → Rat := fun q k => (st.weights q k).getD 0 let active := if st.activeSeen then diff --git a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean index 7430cff..2939c0a 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean @@ -18,9 +18,9 @@ namespace SoftmaxMargin /-- State for parsing softmax-margin payloads. -/ structure ParseState (seq : Nat) where /-- Optional epsilon bound. -/ - eps : Option Dyadic + eps : Option Rat /-- Optional margin bound. -/ - margin : Option Dyadic + margin : Option Rat /-- Active query set. -/ active : Finset (Fin seq) /-- Whether any active entries were parsed. -/ @@ -28,9 +28,9 @@ structure ParseState (seq : Nat) where /-- Optional predecessor pointer per query. -/ prev : Fin seq → Option (Fin seq) /-- Optional score matrix entries. -/ - scores : Fin seq → Fin seq → Option Dyadic + scores : Fin seq → Fin seq → Option Rat /-- Optional weight matrix entries. -/ - weights : Fin seq → Fin seq → Option Dyadic + weights : Fin seq → Fin seq → Option Rat /-- Initialize a softmax-margin parse state. -/ def initState (seq : Nat) : ParseState seq := @@ -75,8 +75,8 @@ def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : Except String (Parse throw s!"active index out of range: q={q}" /-- Insert a matrix entry for scores/weights. 
-/ -def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Dyadic) - (q k : Nat) (v : Dyadic) : Except String (Fin seq → Fin seq → Option Dyadic) := do +def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Rat) + (q k : Nat) (v : Rat) : Except String (Fin seq → Fin seq → Option Rat) := do if hq : q < seq then if hk : k < seq then let qFin : Fin seq := ⟨q, hq⟩ @@ -85,7 +85,7 @@ def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Dyadic) | some _ => throw s!"duplicate matrix entry at ({q}, {k})" | none => - let mat' : Fin seq → Fin seq → Option Dyadic := fun q' k' => + let mat' : Fin seq → Fin seq → Option Rat := fun q' k' => if q' = qFin then if k' = kFin then some v @@ -107,21 +107,21 @@ def parseLine {seq : Nat} (st : ParseState seq) if st.eps.isSome then throw "duplicate eps entry" else - return { st with eps := some (← parseDyadic val) } + return { st with eps := some (← parseRat val) } | ["margin", val] => if st.margin.isSome then throw "duplicate margin entry" else - return { st with margin := some (← parseDyadic val) } + return { st with margin := some (← parseRat val) } | ["active", q] => setActive st (← parseNat q) | ["prev", q, k] => setPrev st (← parseNat q) (← parseNat k) | ["score", q, k, val] => - let mat ← setMatrixEntry st.scores (← parseNat q) (← parseNat k) (← parseDyadic val) + let mat ← setMatrixEntry st.scores (← parseNat q) (← parseNat k) (← parseRat val) return { st with scores := mat } | ["weight", q, k, val] => - let mat ← setMatrixEntry st.weights (← parseNat q) (← parseNat k) (← parseDyadic val) + let mat ← setMatrixEntry st.weights (← parseNat q) (← parseNat k) (← parseRat val) return { st with weights := mat } | _ => throw s!"unrecognized line: '{String.intercalate " " tokens}'" diff --git a/Nfp/IO/Pure/ValueRange/Cert.lean b/Nfp/IO/Pure/ValueRange/Cert.lean index 87edb9b..5a54f32 100644 --- a/Nfp/IO/Pure/ValueRange/Cert.lean +++ b/Nfp/IO/Pure/ValueRange/Cert.lean @@ -27,7 +27,7 @@ private def 
finalizeValueState {seq : Nat} (st : ValueRange.ParseState seq) : | none => throw "missing hi entry" if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then throw "missing value entries" - let valsFun : Fin seq → Dyadic := fun k => + let valsFun : Fin seq → Rat := fun k => (st.vals k).getD 0 let direction ← match st.directionTarget, st.directionNegative with diff --git a/Nfp/IO/Pure/ValueRange/Raw.lean b/Nfp/IO/Pure/ValueRange/Raw.lean index f7c74fc..a9da85b 100644 --- a/Nfp/IO/Pure/ValueRange/Raw.lean +++ b/Nfp/IO/Pure/ValueRange/Raw.lean @@ -18,7 +18,7 @@ open Nfp.Circuit /-- Raw value-range payload without `lo`/`hi` bounds. -/ structure ValueRangeRaw (seq : Nat) where /-- Value entries. -/ - vals : Fin seq → Dyadic + vals : Fin seq → Rat /-- Optional logit-diff direction metadata. -/ direction : Option Circuit.DirectionSpec @@ -26,7 +26,7 @@ private def finalizeValueRawState {seq : Nat} (st : ValueRange.ParseState seq) : Except String (ValueRangeRaw seq) := do if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then throw "missing value entries" - let valsFun : Fin seq → Dyadic := fun k => + let valsFun : Fin seq → Rat := fun k => (st.vals k).getD 0 let direction ← match st.directionTarget, st.directionNegative with diff --git a/Nfp/IO/Pure/ValueRange/Shared.lean b/Nfp/IO/Pure/ValueRange/Shared.lean index 4a4c621..441600f 100644 --- a/Nfp/IO/Pure/ValueRange/Shared.lean +++ b/Nfp/IO/Pure/ValueRange/Shared.lean @@ -20,11 +20,11 @@ open Nfp.Circuit /-- State for parsing value-range payloads. -/ structure ParseState (seq : Nat) where /-- Optional lower bound. -/ - lo : Option Dyadic + lo : Option Rat /-- Optional upper bound. -/ - hi : Option Dyadic + hi : Option Rat /-- Optional per-position values. -/ - vals : Fin seq → Option Dyadic + vals : Fin seq → Option Rat /-- Optional direction target index. -/ directionTarget : Option Nat /-- Optional direction negative index. 
-/ @@ -41,7 +41,7 @@ def initState (seq : Nat) : ParseState seq := /-- Set a value entry from `(k, v)` tokens. -/ -def setVal {seq : Nat} (st : ParseState seq) (k : Nat) (v : Dyadic) : +def setVal {seq : Nat} (st : ParseState seq) (k : Nat) (v : Rat) : Except String (ParseState seq) := do if hk : k < seq then let kFin : Fin seq := ⟨k, hk⟩ @@ -49,7 +49,7 @@ def setVal {seq : Nat} (st : ParseState seq) (k : Nat) (v : Dyadic) : | some _ => throw s!"duplicate value entry for k={k}" | none => - let vals' : Fin seq → Option Dyadic := fun k' => + let vals' : Fin seq → Option Rat := fun k' => if k' = kFin then some v else @@ -67,14 +67,14 @@ def parseLine {seq : Nat} (st : ParseState seq) if st.lo.isSome then throw "duplicate lo entry" else - return { st with lo := some (← parseDyadic val) } + return { st with lo := some (← parseRat val) } | ["hi", val] => if st.hi.isSome then throw "duplicate hi entry" else - return { st with hi := some (← parseDyadic val) } + return { st with hi := some (← parseRat val) } | ["val", k, val] => - setVal st (← parseNat k) (← parseDyadic val) + setVal st (← parseNat k) (← parseRat val) | ["direction-target", tok] => if st.directionTarget.isSome then throw "duplicate direction-target entry" diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean index b510e7e..49006a1 100644 --- a/Nfp/IO/Timing.lean +++ b/Nfp/IO/Timing.lean @@ -114,21 +114,21 @@ def timeHeadScoreMarginList {seq dModel dHead : Nat} /-- Force marginAt evaluation without constructing the full score bounds record. 
-/ def timeHeadScoreMarginRaw {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Dyadic) + (dotAbs : Fin seq → Fin seq → Rat) (activeList : List (Fin seq)) : IO Unit := do IO.println "timing: head score marginRaw list start" (← IO.getStdout).flush let t0 ← monoUsNow let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Dyadic := fun q k => + let scoreLo : Fin seq → Fin seq → Rat := fun q k => if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Dyadic := fun q k => + let scoreHi : Fin seq → Fin seq → Rat := fun q k => if masked q k then inputs.maskValue else @@ -142,11 +142,11 @@ def timeHeadScoreMarginRaw {seq dModel dHead : Nat} (∅ : Finset (Fin seq)) let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => (otherKeys q) \ (maskedKeys q) - let maskedGap : Fin seq → Dyadic := fun q => + let maskedGap : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) - inputs.maskValue - let scoreGap : Fin seq → Fin seq → Dyadic := fun q k => + let scoreGap : Fin seq → Fin seq → Rat := fun q k => scoreLo q (inputs.prev q) - scoreHi q k - let marginAtRaw : Fin seq → Dyadic := fun q => + let marginAtRaw : Fin seq → Rat := fun q => let other := unmaskedKeys q let maskedSet := maskedKeys q if hunmasked : other.Nonempty then @@ -159,7 +159,7 @@ def timeHeadScoreMarginRaw {seq dModel dHead : Nat} if _hmasked : maskedSet.Nonempty then maskedGap q else - (0 : Dyadic) + (0 : Rat) for q in activeList do let _ := marginAtRaw q pure () diff --git a/Nfp/IO/Util.lean b/Nfp/IO/Util.lean index b7fe877..bb374fe 100644 --- a/Nfp/IO/Util.lean +++ b/Nfp/IO/Util.lean @@ -10,13 +10,13 @@ namespace Nfp namespace IO -/-- Parse an optional dyadic literal for CLI flags (rounded down if 
needed). -/ -def parseDyadicOpt (label : String) (raw? : Option String) : - Except String (Option Dyadic) := +/-- Parse an optional rational literal for CLI flags (rounded down if needed). -/ +def parseRatOpt (label : String) (raw? : Option String) : + Except String (Option Rat) := match raw? with | none => Except.ok none | some raw => - match Pure.parseDyadic raw with + match Pure.parseRat raw with | Except.ok v => Except.ok (some v) | Except.error msg => Except.error s!"invalid {label}: {msg}" diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean index 975a2cd..bfd9644 100644 --- a/Nfp/Model/Gpt2.lean +++ b/Nfp/Model/Gpt2.lean @@ -31,92 +31,92 @@ def DirectionTokens.spec {vocab : Nat} (dir : DirectionTokens vocab) : Direction /-- Exact GPT-2 head slice needed to build induction-head inputs. -/ structure Gpt2HeadSlice (seq dModel dHead vocab : Nat) where /-- Softmax scale factor (e.g. `1/8` for head dim 64). -/ - scale : Dyadic + scale : Rat /-- Token ids for the prompt. -/ tokens : Fin seq → Fin vocab /-- Token embedding matrix. -/ - wte : Fin vocab → Fin dModel → Dyadic + wte : Fin vocab → Fin dModel → Rat /-- Positional embedding matrix. -/ - wpe : Fin seq → Fin dModel → Dyadic + wpe : Fin seq → Fin dModel → Rat /-- Query projection weights. -/ - wq : Fin dModel → Fin dHead → Dyadic + wq : Fin dModel → Fin dHead → Rat /-- Query projection bias. -/ - bq : Fin dHead → Dyadic + bq : Fin dHead → Rat /-- Key projection weights. -/ - wk : Fin dModel → Fin dHead → Dyadic + wk : Fin dModel → Fin dHead → Rat /-- Key projection bias. -/ - bk : Fin dHead → Dyadic + bk : Fin dHead → Rat /-- Value projection weights. -/ - wv : Fin dModel → Fin dHead → Dyadic + wv : Fin dModel → Fin dHead → Rat /-- Value projection bias. -/ - bv : Fin dHead → Dyadic + bv : Fin dHead → Rat /-- Output projection weights for this head slice. -/ - wo : Fin dModel → Fin dHead → Dyadic + wo : Fin dModel → Fin dHead → Rat /-- Attention output bias (shared across heads). 
-/ - attnBias : Fin dModel → Dyadic + attnBias : Fin dModel → Rat /-- LayerNorm epsilon for the attention input. -/ - lnEps : Dyadic + lnEps : Rat /-- LayerNorm scale for the attention input. -/ - ln1Gamma : Fin dModel → Dyadic + ln1Gamma : Fin dModel → Rat /-- LayerNorm bias for the attention input. -/ - ln1Beta : Fin dModel → Dyadic + ln1Beta : Fin dModel → Rat /-- Direction tokens for logit-diff certification. -/ direction : DirectionTokens vocab /-- Exact per-head attention weights and biases. -/ structure Gpt2HeadWeights (dModel dHead : Nat) where /-- Query projection weights. -/ - wq : Fin dModel → Fin dHead → Dyadic + wq : Fin dModel → Fin dHead → Rat /-- Query projection bias. -/ - bq : Fin dHead → Dyadic + bq : Fin dHead → Rat /-- Key projection weights. -/ - wk : Fin dModel → Fin dHead → Dyadic + wk : Fin dModel → Fin dHead → Rat /-- Key projection bias. -/ - bk : Fin dHead → Dyadic + bk : Fin dHead → Rat /-- Value projection weights. -/ - wv : Fin dModel → Fin dHead → Dyadic + wv : Fin dModel → Fin dHead → Rat /-- Value projection bias. -/ - bv : Fin dHead → Dyadic + bv : Fin dHead → Rat /-- Output projection weights for this head slice. -/ - wo : Fin dModel → Fin dHead → Dyadic + wo : Fin dModel → Fin dHead → Rat /-- Exact GPT-2 layer slice with MLP and LayerNorm parameters. -/ structure Gpt2LayerSlice (dModel hidden : Nat) where /-- Attention output bias (shared across heads). -/ - attnBias : Fin dModel → Dyadic + attnBias : Fin dModel → Rat /-- MLP input projection weights. -/ - mlpWIn : Fin dModel → Fin hidden → Dyadic + mlpWIn : Fin dModel → Fin hidden → Rat /-- MLP input projection bias. -/ - mlpBIn : Fin hidden → Dyadic + mlpBIn : Fin hidden → Rat /-- MLP output projection weights. -/ - mlpWOut : Fin hidden → Fin dModel → Dyadic + mlpWOut : Fin hidden → Fin dModel → Rat /-- MLP output projection bias. -/ - mlpBOut : Fin dModel → Dyadic + mlpBOut : Fin dModel → Rat /-- LayerNorm scale for the attention input. 
-/ - ln1Gamma : Fin dModel → Dyadic + ln1Gamma : Fin dModel → Rat /-- LayerNorm bias for the attention input. -/ - ln1Beta : Fin dModel → Dyadic + ln1Beta : Fin dModel → Rat /-- LayerNorm scale for the MLP input. -/ - ln2Gamma : Fin dModel → Dyadic + ln2Gamma : Fin dModel → Rat /-- LayerNorm bias for the MLP input. -/ - ln2Beta : Fin dModel → Dyadic + ln2Beta : Fin dModel → Rat /-- Final LayerNorm parameters applied before unembedding. -/ structure Gpt2FinalLayerNorm (dModel : Nat) where /-- LayerNorm scale. -/ - gamma : Fin dModel → Dyadic + gamma : Fin dModel → Rat /-- LayerNorm bias. -/ - beta : Fin dModel → Dyadic + beta : Fin dModel → Rat /-- Token-plus-position embeddings for a GPT-2 head slice. -/ def Gpt2HeadSlice.embed {seq dModel dHead vocab : Nat} (slice : Gpt2HeadSlice seq dModel dHead vocab) : - Fin seq → Fin dModel → Dyadic := + Fin seq → Fin dModel → Rat := fun q d => slice.wte (slice.tokens q) d + slice.wpe q d /-- Direction vector in model space for a GPT-2 head slice. -/ def Gpt2HeadSlice.directionVec {seq dModel dHead vocab : Nat} - (slice : Gpt2HeadSlice seq dModel dHead vocab) : Fin dModel → Dyadic := + (slice : Gpt2HeadSlice seq dModel dHead vocab) : Fin dModel → Rat := fun d => slice.wte slice.direction.target d - slice.wte slice.direction.negative d end Model diff --git a/Nfp/Model/InductionHead.lean b/Nfp/Model/InductionHead.lean index 10df33e..697652d 100644 --- a/Nfp/Model/InductionHead.lean +++ b/Nfp/Model/InductionHead.lean @@ -7,7 +7,7 @@ import Nfp.Circuit.Cert.ValueRange /-! Exact inputs for induction-head scoring and value-direction computations. -These structures store exact dyadic inputs (embeddings and weights) for a +These structures store exact rational inputs (embeddings and weights) for a single attention head. They are intended to be consumed by sound builders. -/ @@ -20,43 +20,43 @@ open Nfp.Circuit /-- Exact head inputs for induction certification. 
-/ structure InductionHeadInputs (seq dModel dHead : Nat) where /-- Softmax scale factor (e.g. `1/8` for GPT-2-small head dim 64). -/ - scale : Dyadic + scale : Rat /-- Active queries for which bounds are required. -/ active : Finset (Fin seq) /-- `prev` selector for induction-style attention. -/ prev : Fin seq → Fin seq /-- Token embeddings for the sequence. -/ - embed : Fin seq → Fin dModel → Dyadic + embed : Fin seq → Fin dModel → Rat /-- LayerNorm epsilon used before attention. -/ - lnEps : Dyadic + lnEps : Rat /-- LayerNorm scale for pre-attention normalization. -/ - ln1Gamma : Fin dModel → Dyadic + ln1Gamma : Fin dModel → Rat /-- LayerNorm bias for pre-attention normalization. -/ - ln1Beta : Fin dModel → Dyadic + ln1Beta : Fin dModel → Rat /-- Query projection weights. -/ - wq : Fin dModel → Fin dHead → Dyadic + wq : Fin dModel → Fin dHead → Rat /-- Query projection bias. -/ - bq : Fin dHead → Dyadic + bq : Fin dHead → Rat /-- Key projection weights. -/ - wk : Fin dModel → Fin dHead → Dyadic + wk : Fin dModel → Fin dHead → Rat /-- Key projection bias. -/ - bk : Fin dHead → Dyadic + bk : Fin dHead → Rat /-- Value projection weights. -/ - wv : Fin dModel → Fin dHead → Dyadic + wv : Fin dModel → Fin dHead → Rat /-- Value projection bias. -/ - bv : Fin dHead → Dyadic + bv : Fin dHead → Rat /-- Output projection weights (head slice). -/ - wo : Fin dModel → Fin dHead → Dyadic + wo : Fin dModel → Fin dHead → Rat /-- Attention output bias (shared across heads). -/ - attnBias : Fin dModel → Dyadic + attnBias : Fin dModel → Rat /-- Whether to apply a causal mask to attention scores. -/ maskCausal : Bool /-- Score value for masked entries (e.g. `-10000` for GPT-2 causal masking). -/ - maskValue : Dyadic + maskValue : Rat /-- Logit-diff direction metadata. -/ directionSpec : DirectionSpec /-- Logit-diff direction vector in model space. 
-/ - direction : Fin dModel → Dyadic + direction : Fin dModel → Rat end Model diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index a60b3d7..950e8de 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -3,10 +3,11 @@ import Mathlib.Algebra.BigOperators.Field import Mathlib.Algebra.BigOperators.Ring.Finset import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Nfp.Core.Basic import Mathlib.Data.Real.Basic import Nfp.Circuit.Layers.Softmax +import Nfp.Core.Basic import Nfp.Model.Gpt2 +import Nfp.Sound.Bounds.Cache import Nfp.Sound.Bounds.LayerNorm import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Bounds.Mlp @@ -23,134 +24,11 @@ namespace Bounds open scoped BigOperators -/-! -Caching helpers for interval bounds. --/ - -/-- Cache a bound function in an array-backed lookup to avoid repeated evaluation. -/ -def cacheBound {n : Nat} (f : Fin n → Dyadic) : Fin n → Dyadic := - let data : Thunk (Array Dyadic) := Thunk.mk (fun _ => Array.ofFn f) - fun i => (Thunk.get data)[i.1]'(by - have hsize : (Thunk.get data).size = n := by - simp [Thunk.get, data] - simp [hsize]) - -/-- `cacheBound` preserves pointwise values. -/ -theorem cacheBound_apply {n : Nat} (f : Fin n → Dyadic) (i : Fin n) : - cacheBound f i = f i := by - simp [cacheBound, Thunk.get, Array.getElem_ofFn] - -/-- Cache a bound function on two indices. -/ -def cacheBound2 {m n : Nat} (f : Fin m → Fin n → Dyadic) : Fin m → Fin n → Dyadic := - let data : Thunk (Array (Thunk (Array Dyadic))) := Thunk.mk (fun _ => - Array.ofFn (fun q => Thunk.mk (fun _ => Array.ofFn (f q)))) - fun q i => - let rowThunk := (Thunk.get data)[q.1]'(by - have hsize : (Thunk.get data).size = m := by - simp [Thunk.get, data] - rw [hsize] - exact q.isLt) - let row := Thunk.get rowThunk - row[i.1]'(by - have hsize : row.size = n := by - simp [row, rowThunk, Thunk.get, data, Array.getElem_ofFn] - rw [hsize] - exact i.isLt) - -/-- `cacheBound2` preserves pointwise values. 
-/ -theorem cacheBound2_apply {m n : Nat} (f : Fin m → Fin n → Dyadic) (q : Fin m) (i : Fin n) : - cacheBound2 f q i = f q i := by - simp [cacheBound2, Thunk.get, Array.getElem_ofFn] - -/-- Cache a bound function on two indices using row tasks for parallel evaluation. -/ -def cacheBound2Task {m n : Nat} (f : Fin m → Fin n → Dyadic) : Fin m → Fin n → Dyadic := - let rowTasks : Array (Task { row : Array Dyadic // row.size = n }) := - Array.ofFn (fun q : Fin m => - Task.spawn (fun _ => ⟨Array.ofFn (f q), by simp⟩)) - fun q i => - let row := (rowTasks[q.1]'(by - simp [rowTasks, q.isLt])).get - row.1[i.1]'(by - simp [row.2]) - -/-- `cacheBound2Task` preserves pointwise values. -/ -theorem cacheBound2Task_apply {m n : Nat} (f : Fin m → Fin n → Dyadic) (q : Fin m) (i : Fin n) : - cacheBound2Task f q i = f q i := by - classical - simp [cacheBound2Task, Task.spawn, Array.getElem_ofFn] - -/-- Cache a bound function on two indices using per-element tasks for parallel evaluation. -/ -def cacheBound2TaskElem {m n : Nat} (f : Fin m → Fin n → Dyadic) : Fin m → Fin n → Dyadic := - let rowTasks : Array (Array (Task Dyadic)) := - Array.ofFn (fun q : Fin m => - Array.ofFn (fun i : Fin n => - Task.spawn (fun _ => f q i))) - fun q i => - let row := (rowTasks[q.1]'(by - simp [rowTasks, q.isLt])) - let t := row[i.1]'(by - have hsize : row.size = n := by - simp [row, rowTasks] - simp [hsize, i.isLt]) - t.get - -/-- `cacheBound2TaskElem` preserves pointwise values. -/ -theorem cacheBound2TaskElem_apply {m n : Nat} (f : Fin m → Fin n → Dyadic) (q : Fin m) (i : Fin n) : - cacheBound2TaskElem f q i = f q i := by - classical - simp [cacheBound2TaskElem, Task.spawn, Array.getElem_ofFn] - -/-- Cache a pair of bound functions on two indices. 
-/ -def cacheBoundPair2 {m n : Nat} - (f : Fin m → (Fin n → Dyadic) × (Fin n → Dyadic)) : - (Fin m → Fin n → Dyadic) × (Fin m → Fin n → Dyadic) := - let data : Thunk (Array (Array Dyadic × Array Dyadic)) := Thunk.mk (fun _ => - Array.ofFn (fun q => - let row := f q - (Array.ofFn row.1, Array.ofFn row.2))) - let lo : Fin m → Fin n → Dyadic := fun q i => - let row := (Thunk.get data)[q.1]'(by - have hsize : (Thunk.get data).size = m := by - simp [Thunk.get, data] - rw [hsize] - exact q.isLt) - let loRow := row.1 - loRow[i.1]'(by - have hsize : loRow.size = n := by - simp [loRow, row, Thunk.get, data, Array.getElem_ofFn] - rw [hsize] - exact i.isLt) - let hi : Fin m → Fin n → Dyadic := fun q i => - let row := (Thunk.get data)[q.1]'(by - have hsize : (Thunk.get data).size = m := by - simp [Thunk.get, data] - rw [hsize] - exact q.isLt) - let hiRow := row.2 - hiRow[i.1]'(by - have hsize : hiRow.size = n := by - simp [hiRow, row, Thunk.get, data, Array.getElem_ofFn] - rw [hsize] - exact i.isLt) - (lo, hi) - -/-- `cacheBoundPair2` preserves pointwise values (first component). -/ -theorem cacheBoundPair2_apply_left {m n : Nat} - (f : Fin m → (Fin n → Dyadic) × (Fin n → Dyadic)) (q : Fin m) (i : Fin n) : - (cacheBoundPair2 f).1 q i = (f q).1 i := by - simp [cacheBoundPair2, Thunk.get, Array.getElem_ofFn] - -/-- `cacheBoundPair2` preserves pointwise values (second component). -/ -theorem cacheBoundPair2_apply_right {m n : Nat} - (f : Fin m → (Fin n → Dyadic) × (Fin n → Dyadic)) (q : Fin m) (i : Fin n) : - (cacheBoundPair2 f).2 q i = (f q).2 i := by - simp [cacheBoundPair2, Thunk.get, Array.getElem_ofFn] - /-- Real-valued attention output for a query token and model coordinate. 
-/ noncomputable def attentionOutputReal {seq dModel dHead numHeads : Nat} [NeZero seq] - (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Dyadic) + (attnBias : Fin dModel → Rat) (scores : Fin numHeads → Fin seq → Fin seq → Real) (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := @@ -168,9 +46,9 @@ noncomputable def attentionOutputReal {seq dModel dHead numHeads : Nat} [NeZero /-- Unfolding lemma for `attentionOutputReal`. -/ theorem attentionOutputReal_def {seq dModel dHead numHeads : Nat} [NeZero seq] - (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Dyadic) + (attnBias : Fin dModel → Rat) (scores : Fin numHeads → Fin seq → Fin seq → Real) (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : @@ -189,34 +67,34 @@ theorem attentionOutputReal_def {seq dModel dHead numHeads : Nat} [NeZero seq] /-- Interval bounds for multi-head attention outputs from interval inputs. 
-/ def attentionOutputBounds {dModel dHead numHeads : Nat} - (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Dyadic) - (lo hi : Fin dModel → Dyadic) : - (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (attnBias : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := let absBound := intervalAbsBound lo hi let ln := layerNormAbsBounds eps ln1Gamma ln1Beta absBound let lnLo := ln.1 let lnHi := ln.2 - let vLo : Fin numHeads → Fin dHead → Dyadic := fun h d => + let vLo : Fin numHeads → Fin dHead → Rat := fun h d => dotIntervalLower (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d - let vHi : Fin numHeads → Fin dHead → Dyadic := fun h d => + let vHi : Fin numHeads → Fin dHead → Rat := fun h d => dotIntervalUpper (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d - let headLo : Fin numHeads → Fin dModel → Dyadic := fun h i => + let headLo : Fin numHeads → Fin dModel → Rat := fun h i => dotIntervalLower (fun d => (heads h).wo i d) (vLo h) (vHi h) - let headHi : Fin numHeads → Fin dModel → Dyadic := fun h i => + let headHi : Fin numHeads → Fin dModel → Rat := fun h i => dotIntervalUpper (fun d => (heads h).wo i d) (vLo h) (vHi h) - let sumLo : Fin dModel → Dyadic := fun i => ∑ h, headLo h i - let sumHi : Fin dModel → Dyadic := fun i => ∑ h, headHi h i + let sumLo : Fin dModel → Rat := fun i => ∑ h, headLo h i + let sumHi : Fin dModel → Rat := fun i => ∑ h, headHi h i (fun i => sumLo i + attnBias i, fun i => sumHi i + attnBias i) /-- `attentionOutputBounds` soundness for real attention outputs. 
-/ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] - (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Dyadic) + (attnBias : Fin dModel → Rat) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi @@ -233,16 +111,16 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq let lnHi := lnBounds.2 let lnOut : Fin seq → Fin dModel → Real := fun k j => layerNormRealOfReal eps ln1Gamma ln1Beta (x k) j - let vLo : Fin numHeads → Fin dHead → Dyadic := fun h d => + let vLo : Fin numHeads → Fin dHead → Rat := fun h d => dotIntervalLower (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d - let vHi : Fin numHeads → Fin dHead → Dyadic := fun h d => + let vHi : Fin numHeads → Fin dHead → Rat := fun h d => dotIntervalUpper (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d - let headLo : Fin numHeads → Fin dModel → Dyadic := fun h j => + let headLo : Fin numHeads → Fin dModel → Rat := fun h j => dotIntervalLower (fun d => (heads h).wo j d) (vLo h) (vHi h) - let headHi : Fin numHeads → Fin dModel → Dyadic := fun h j => + let headHi : Fin numHeads → Fin dModel → Rat := fun h j => dotIntervalUpper (fun d => (heads h).wo j d) (vLo h) (vHi h) - let sumLo : Fin dModel → Dyadic := fun j => ∑ h, headLo h j - let sumHi : Fin dModel → Dyadic := fun j => ∑ h, headHi h j + let sumLo : Fin dModel → Rat := fun j => ∑ h, headLo h j + let sumHi : Fin dModel → Rat := fun j => ∑ h, headHi h j let headValue : Fin numHeads → Fin seq → Fin dHead → Real 
:= fun h k d => dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d let headWeights : Fin numHeads → Fin seq → Fin seq → Real := fun h q k => @@ -260,9 +138,9 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq max_abs_le_intervalAbsBound lo hi i have hsup_real : max |(lo i : Real)| |(hi i : Real)| ≤ (absBound : Real) := by - have hsup' : dyadicToReal (max |lo i| |hi i|) ≤ dyadicToReal absBound := - dyadicToReal_le_of_le hsup - simpa [dyadicToReal_abs, dyadicToReal_max] using hsup' + have hsup' : ratToReal (max |lo i| |hi i|) ≤ ratToReal absBound := + ratToReal_le_of_le hsup + simpa [ratToReal_abs, ratToReal_max] using hsup' exact le_trans hbound hsup_real have hln_bounds : ∀ q i, (lnLo i : Real) ≤ lnOut q i ∧ lnOut q i ≤ (lnHi i : Real) := by intro q i @@ -369,12 +247,12 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq have hsum := Finset.sum_le_sum (s := (Finset.univ : Finset (Fin numHeads))) (fun h _ => (hproj_bounds h q i).1) - simpa [sumLo, Linear.dyadicToReal_sum_univ] using hsum + simpa [sumLo, Linear.ratToReal_sum_univ] using hsum have hhigh : ∑ h, headProj h q i ≤ (sumHi i : Real) := by have hsum := Finset.sum_le_sum (s := (Finset.univ : Finset (Fin numHeads))) (fun h _ => (hproj_bounds h q i).2) - simpa [sumHi, Linear.dyadicToReal_sum_univ] using hsum + simpa [sumHi, Linear.ratToReal_sum_univ] using hsum exact ⟨hlow, hhigh⟩ have hlow : (sumLo i : Real) + (attnBias i : Real) ≤ @@ -404,21 +282,21 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq /-- Interval bounds for the attention residual path. 
-/ def attentionResidualBounds {dModel dHead numHeads : Nat} - (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Dyadic) - (lo hi : Fin dModel → Dyadic) : - (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (attnBias : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := let attn := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi (fun i => lo i + attn.1 i, fun i => hi i + attn.2 i) /-- `attentionResidualBounds` soundness for attention residual outputs. -/ theorem attentionResidualBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] - (eps : Dyadic) (ln1Gamma ln1Beta : Fin dModel → Dyadic) + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Dyadic) + (attnBias : Fin dModel → Rat) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias lo hi @@ -441,14 +319,14 @@ theorem attentionResidualBounds_spec {seq dModel dHead numHeads : Nat} [NeZero s /-- Interval bounds for a full transformer layer (attention + MLP). 
-/ def transformerLayerBounds {dModel dHead numHeads hidden : Nat} - (eps : Dyadic) - (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Dyadic) + (eps : Rat) + (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Dyadic) - (mlpWIn : Fin dModel → Fin hidden → Dyadic) (mlpBIn : Fin hidden → Dyadic) - (mlpWOut : Fin hidden → Fin dModel → Dyadic) (mlpBOut : Fin dModel → Dyadic) - (lo hi : Fin dModel → Dyadic) : - (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (attnBias : Fin dModel → Rat) + (mlpWIn : Fin dModel → Fin hidden → Rat) (mlpBIn : Fin hidden → Rat) + (mlpWOut : Fin hidden → Fin dModel → Rat) (mlpBOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := let loCached := cacheBound lo let hiCached := cacheBound hi let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias loCached hiCached @@ -462,14 +340,14 @@ def transformerLayerBounds {dModel dHead numHeads hidden : Nat} /-- `transformerLayerBounds` soundness for full transformer-layer outputs. 
-/ theorem transformerLayerBounds_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Dyadic) - (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Dyadic) + (eps : Rat) + (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (attnBias : Fin dModel → Dyadic) - (mlpWIn : Fin dModel → Fin hidden → Dyadic) (mlpBIn : Fin hidden → Dyadic) - (mlpWOut : Fin hidden → Fin dModel → Dyadic) (mlpBOut : Fin dModel → Dyadic) + (attnBias : Fin dModel → Rat) + (mlpWIn : Fin dModel → Fin hidden → Rat) (mlpBIn : Fin hidden → Rat) + (mlpWOut : Fin hidden → Fin dModel → Rat) (mlpBOut : Fin dModel → Rat) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := transformerLayerBounds eps ln1Gamma ln1Beta ln2Gamma ln2Beta heads attnBias diff --git a/Nfp/Sound/Bounds/Cache.lean b/Nfp/Sound/Bounds/Cache.lean new file mode 100644 index 0000000..5084089 --- /dev/null +++ b/Nfp/Sound/Bounds/Cache.lean @@ -0,0 +1,245 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Core.Basic + +/-! +Caching helpers for interval bounds. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +/-- Cache a bound function in an array-backed lookup to avoid repeated evaluation. -/ +def cacheBound {n : Nat} (f : Fin n → Rat) : Fin n → Rat := + let data : Thunk (Array Rat) := Thunk.mk (fun _ => Array.ofFn f) + fun i => (Thunk.get data)[i.1]'(by + have hsize : (Thunk.get data).size = n := by + simp [Thunk.get, data] + simp [hsize]) + +/-- `cacheBound` preserves pointwise values. 
-/ +theorem cacheBound_apply {n : Nat} (f : Fin n → Rat) (i : Fin n) : + cacheBound f i = f i := by + simp [cacheBound, Thunk.get, Array.getElem_ofFn] + +/-- Cache a bound function using per-index thunks for lazy evaluation. -/ +def cacheBoundThunk {n : Nat} (f : Fin n → Rat) : Fin n → Rat := + let data : Array (Thunk Rat) := Array.ofFn (fun i => Thunk.mk (fun _ => f i)) + fun i => + let t := data[i.1]'(by + simp [data, i.isLt]) + Thunk.get t + +/-- `cacheBoundThunk` preserves pointwise values. -/ +theorem cacheBoundThunk_apply {n : Nat} (f : Fin n → Rat) (i : Fin n) : + cacheBoundThunk f i = f i := by + simp [cacheBoundThunk, Thunk.get, Array.getElem_ofFn] + +/-- Cache a bound function using tasks for parallel evaluation. -/ +def cacheBoundTask {n : Nat} (f : Fin n → Rat) : Fin n → Rat := + let tasks : Array (Task Rat) := + Array.ofFn (fun i : Fin n => + Task.spawn (fun _ => f i)) + fun i => + let hsize : tasks.size = n := by + simp [tasks] + let t := tasks[i.1]'(by + simp [hsize, i.isLt]) + t.get + +/-- `cacheBoundTask` preserves pointwise values. -/ +theorem cacheBoundTask_apply {n : Nat} (f : Fin n → Rat) (i : Fin n) : + cacheBoundTask f i = f i := by + classical + simp [cacheBoundTask, Task.spawn, Array.getElem_ofFn] + +/-- Cache a bound function on two indices. -/ +def cacheBound2 {m n : Nat} (f : Fin m → Fin n → Rat) : Fin m → Fin n → Rat := + let data : Thunk (Array (Thunk (Array Rat))) := Thunk.mk (fun _ => + Array.ofFn (fun q => Thunk.mk (fun _ => Array.ofFn (f q)))) + fun q i => + let rowThunk := (Thunk.get data)[q.1]'(by + have hsize : (Thunk.get data).size = m := by + simp [Thunk.get, data] + rw [hsize] + exact q.isLt) + let row := Thunk.get rowThunk + row[i.1]'(by + have hsize : row.size = n := by + simp [row, rowThunk, Thunk.get, data, Array.getElem_ofFn] + rw [hsize] + exact i.isLt) + +/-- `cacheBound2` preserves pointwise values. 
-/ +theorem cacheBound2_apply {m n : Nat} (f : Fin m → Fin n → Rat) (q : Fin m) (i : Fin n) : + cacheBound2 f q i = f q i := by + simp [cacheBound2, Thunk.get, Array.getElem_ofFn] + +/-- Cache a bound function on two indices using row tasks for parallel evaluation. -/ +def cacheBound2Task {m n : Nat} (f : Fin m → Fin n → Rat) : Fin m → Fin n → Rat := + let rowTasks : Array (Task { row : Array Rat // row.size = n }) := + Array.ofFn (fun q : Fin m => + Task.spawn (fun _ => ⟨Array.ofFn (f q), by simp⟩)) + fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get + row.1[i.1]'(by + simp [row.2]) + +/-- `cacheBound2Task` preserves pointwise values. -/ +theorem cacheBound2Task_apply {m n : Nat} (f : Fin m → Fin n → Rat) (q : Fin m) (i : Fin n) : + cacheBound2Task f q i = f q i := by + classical + simp [cacheBound2Task, Task.spawn, Array.getElem_ofFn] + +/-- Cache a bound function on two indices using per-element tasks for parallel evaluation. -/ +def cacheBound2TaskElem {m n : Nat} (f : Fin m → Fin n → Rat) : Fin m → Fin n → Rat := + let rowTasks : Array (Array (Task Rat)) := + Array.ofFn (fun q : Fin m => + Array.ofFn (fun i : Fin n => + Task.spawn (fun _ => f q i))) + fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])) + let t := row[i.1]'(by + have hsize : row.size = n := by + simp [row, rowTasks] + simp [hsize, i.isLt]) + t.get + +/-- `cacheBound2TaskElem` preserves pointwise values. -/ +theorem cacheBound2TaskElem_apply {m n : Nat} (f : Fin m → Fin n → Rat) (q : Fin m) (i : Fin n) : + cacheBound2TaskElem f q i = f q i := by + classical + simp [cacheBound2TaskElem, Task.spawn, Array.getElem_ofFn] + +/-- Cache a pair of bound functions on two indices. 
-/ +def cacheBoundPair2 {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) : + (Fin m → Fin n → Rat) × (Fin m → Fin n → Rat) := + let data : Thunk (Array (Array Rat × Array Rat)) := Thunk.mk (fun _ => + Array.ofFn (fun q => + let row := f q + (Array.ofFn row.1, Array.ofFn row.2))) + let lo : Fin m → Fin n → Rat := fun q i => + let row := (Thunk.get data)[q.1]'(by + have hsize : (Thunk.get data).size = m := by + simp [Thunk.get, data] + rw [hsize] + exact q.isLt) + let loRow := row.1 + loRow[i.1]'(by + have hsize : loRow.size = n := by + simp [loRow, row, Thunk.get, data, Array.getElem_ofFn] + rw [hsize] + exact i.isLt) + let hi : Fin m → Fin n → Rat := fun q i => + let row := (Thunk.get data)[q.1]'(by + have hsize : (Thunk.get data).size = m := by + simp [Thunk.get, data] + rw [hsize] + exact q.isLt) + let hiRow := row.2 + hiRow[i.1]'(by + have hsize : hiRow.size = n := by + simp [hiRow, row, Thunk.get, data, Array.getElem_ofFn] + rw [hsize] + exact i.isLt) + (lo, hi) + +/-- `cacheBoundPair2` preserves pointwise values (first component). -/ +theorem cacheBoundPair2_apply_left {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2 f).1 q i = (f q).1 i := by + simp [cacheBoundPair2, Thunk.get, Array.getElem_ofFn] + +/-- `cacheBoundPair2` preserves pointwise values (second component). -/ +theorem cacheBoundPair2_apply_right {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2 f).2 q i = (f q).2 i := by + simp [cacheBoundPair2, Thunk.get, Array.getElem_ofFn] + +/-- Cache a pair of bound functions on two indices using row tasks. 
-/ +def cacheBoundPair2Task {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) : + (Fin m → Fin n → Rat) × (Fin m → Fin n → Rat) := + let rowTasks : + Array (Task ({ rowLo : Array Rat // rowLo.size = n } × + { rowHi : Array Rat // rowHi.size = n })) := + Array.ofFn (fun q : Fin m => + Task.spawn (fun _ => + let row := f q + (⟨Array.ofFn row.1, by simp⟩, ⟨Array.ofFn row.2, by simp⟩))) + let lo : Fin m → Fin n → Rat := fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get + row.1.1[i.1]'(by + simp [row.1.2, i.isLt]) + let hi : Fin m → Fin n → Rat := fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get + row.2.1[i.1]'(by + simp [row.2.2, i.isLt]) + (lo, hi) + +/-- `cacheBoundPair2Task` preserves pointwise values (first component). -/ +theorem cacheBoundPair2Task_apply_left {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2Task f).1 q i = (f q).1 i := by + classical + simp [cacheBoundPair2Task, Task.spawn, Array.getElem_ofFn] + +/-- `cacheBoundPair2Task` preserves pointwise values (second component). -/ +theorem cacheBoundPair2Task_apply_right {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2Task f).2 q i = (f q).2 i := by + classical + simp [cacheBoundPair2Task, Task.spawn, Array.getElem_ofFn] + +/-- Cache a pair of bound functions on two indices using per-element tasks. 
-/ +def cacheBoundPair2TaskElem {m n : Nat} (f : Fin m → Fin n → Rat × Rat) : + (Fin m → Fin n → Rat) × (Fin m → Fin n → Rat) := + let rowTasks : Array (Array (Task (Rat × Rat))) := + Array.ofFn (fun q : Fin m => + Array.ofFn (fun i : Fin n => + Task.spawn (fun _ => f q i))) + let lo : Fin m → Fin n → Rat := fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])) + let t := row[i.1]'(by + have hsize : row.size = n := by + simp [row, rowTasks] + simp [hsize, i.isLt]) + (t.get).1 + let hi : Fin m → Fin n → Rat := fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])) + let t := row[i.1]'(by + have hsize : row.size = n := by + simp [row, rowTasks] + simp [hsize, i.isLt]) + (t.get).2 + (lo, hi) + +/-- `cacheBoundPair2TaskElem` preserves pointwise values (first component). -/ +theorem cacheBoundPair2TaskElem_apply_left {m n : Nat} + (f : Fin m → Fin n → Rat × Rat) (q : Fin m) (i : Fin n) : + (cacheBoundPair2TaskElem f).1 q i = (f q i).1 := by + classical + simp [cacheBoundPair2TaskElem, Task.spawn, Array.getElem_ofFn] + +/-- `cacheBoundPair2TaskElem` preserves pointwise values (second component). -/ +theorem cacheBoundPair2TaskElem_apply_right {m n : Nat} + (f : Fin m → Fin n → Rat × Rat) (q : Fin m) (i : Fin n) : + (cacheBoundPair2TaskElem f).2 q i = (f q i).2 := by + classical + simp [cacheBoundPair2TaskElem, Task.spawn, Array.getElem_ofFn] + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean index 8190bdb..6817dec 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -118,18 +118,18 @@ theorem geluTanh_bounds (x : Real) : simpa [min_eq_left hx', max_eq_right hx'] using And.intro h1 h0 /-- Interval bounds for GELU given input bounds. -/ -def geluInterval (lo hi : Dyadic) : Dyadic × Dyadic := +def geluInterval (lo hi : Rat) : Rat × Rat := (if lo ≤ 0 then lo else 0, if 0 ≤ hi then hi else 0) /-- `geluInterval` soundly bounds `geluTanh` on a real interval. 
-/ -theorem geluInterval_bounds {lo hi : Dyadic} {x : Real} +theorem geluInterval_bounds {lo hi : Rat} {x : Real} (hlo : (lo : Real) ≤ x) (hhi : x ≤ (hi : Real)) : (geluInterval lo hi).1 ≤ (geluTanh x : Real) ∧ (geluTanh x : Real) ≤ (geluInterval lo hi).2 := by have hgelu := geluTanh_bounds x by_cases hlo0 : lo ≤ 0 · have hlo0r : (lo : Real) ≤ 0 := by - exact (dyadicToReal_nonpos_iff (x := lo)).2 hlo0 + exact (ratToReal_nonpos_iff (x := lo)).2 hlo0 have hmin : min (lo : Real) 0 ≤ min x 0 := min_le_min hlo le_rfl have hlo' : (lo : Real) ≤ geluTanh x := by have hmin' : (lo : Real) ≤ min x 0 := by @@ -141,18 +141,18 @@ theorem geluInterval_bounds {lo hi : Dyadic} {x : Real} · simpa [geluInterval, hlo0] using hlo' · by_cases hhi0 : 0 ≤ hi · have hhi0r : 0 ≤ (hi : Real) := by - exact dyadicToReal_nonneg_of_nonneg hhi0 + exact ratToReal_nonneg_of_nonneg hhi0 have hmax' : max (hi : Real) 0 = (hi : Real) := max_eq_left hhi0r simpa [geluInterval, hhi0, hmax'] using hhi' · have hhi0r : (hi : Real) ≤ 0 := by - exact (dyadicToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) + exact (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) have hx0 : x ≤ 0 := le_trans hhi hhi0r have hmax' : max x 0 = 0 := max_eq_right hx0 have hhi'' : geluTanh x ≤ (0 : Real) := by simpa [hmax'] using hgelu.2 - simpa [geluInterval, hhi0, dyadicToReal_zero] using hhi'' + simpa [geluInterval, hhi0, ratToReal_zero] using hhi'' · have hlo0r : 0 ≤ (lo : Real) := by - exact dyadicToReal_nonneg_of_nonneg (le_of_not_ge hlo0) + exact ratToReal_nonneg_of_nonneg (le_of_not_ge hlo0) have hx0 : 0 ≤ x := le_trans hlo0r hlo have hmin' : min x 0 = 0 := min_eq_right hx0 have hlo' : (0 : Real) ≤ geluTanh x := by @@ -160,19 +160,19 @@ theorem geluInterval_bounds {lo hi : Dyadic} {x : Real} have hmax : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl have hhi' : geluTanh x ≤ max (hi : Real) 0 := le_trans hgelu.2 hmax constructor - · simpa [geluInterval, hlo0, dyadicToReal_zero] using hlo' + · simpa [geluInterval, 
hlo0, ratToReal_zero] using hlo' · by_cases hhi0 : 0 ≤ hi · have hhi0r : 0 ≤ (hi : Real) := by - exact dyadicToReal_nonneg_of_nonneg hhi0 + exact ratToReal_nonneg_of_nonneg hhi0 have hmax' : max (hi : Real) 0 = (hi : Real) := max_eq_left hhi0r simpa [geluInterval, hhi0, hmax'] using hhi' · have hhi0r : (hi : Real) ≤ 0 := by - exact (dyadicToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) + exact (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) have hx0' : x ≤ 0 := le_trans hhi hhi0r have hmax' : max x 0 = 0 := max_eq_right hx0' have hhi'' : geluTanh x ≤ (0 : Real) := by simpa [hmax'] using hgelu.2 - simpa [geluInterval, hhi0, dyadicToReal_zero] using hhi'' + simpa [geluInterval, hhi0, ratToReal_zero] using hhi'' end Bounds diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 37fae19..60e7e34 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -14,9 +14,9 @@ import Nfp.Sound.Bounds.LayerNorm.MeanVariance import Nfp.Sound.Linear.FinFold /-! -LayerNorm interval bounds for dyadic inputs. +LayerNorm interval bounds for rational inputs. -This module computes dyadic interval bounds for LayerNorm outputs and proves +This module computes rational interval bounds for LayerNorm outputs and proves those bounds sound for real-valued LayerNorm semantics. -/ @@ -30,61 +30,61 @@ open scoped BigOperators /-! Square-root bounds. -/ -lemma dyadic_nat_cast_nonneg (n : Nat) : (0 : Dyadic) ≤ (n : Dyadic) := by +lemma rat_nat_cast_nonneg (n : Nat) : (0 : Rat) ≤ (n : Rat) := by simp -lemma dyadic_nat_cast_pos {n : Nat} (h : 0 < n) : (0 : Dyadic) < (n : Dyadic) := by - exact (Nat.cast_pos (α := Dyadic)).2 h +lemma rat_nat_cast_pos {n : Nat} (h : 0 < n) : (0 : Rat) < (n : Rat) := by + exact (Nat.cast_pos (α := Rat)).2 h /-- Base rational lower bound for a square root. 
-/ -def sqrtLowerBase (q : Dyadic) : Dyadic := - let num := q.toRat.num.natAbs - let den := q.toRat.den +def sqrtLowerBase (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den let a := Nat.sqrt num let b := Nat.sqrt den - dyadicOfRatDown ((a : Rat) / (b + 1)) + ratRoundDown ((a : Rat) / (b + 1)) /-- Base rational upper bound for a square root. -/ -def sqrtUpperBase (q : Dyadic) : Dyadic := - let num := q.toRat.num.natAbs - let den := q.toRat.den +def sqrtUpperBase (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den let a := Nat.sqrt num let b := Nat.sqrt den - dyadicOfRatUp ((a + 1 : Rat) / b) + ratRoundUp ((a + 1 : Rat) / b) /-- Alternate rational lower bound for a square root. -/ -def sqrtLowerAlt (q : Dyadic) : Dyadic := - let num := q.toRat.num.natAbs - let den := q.toRat.den +def sqrtLowerAlt (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den let a := Nat.sqrt (num * den) - dyadicOfRatDown ((a : Rat) / den) + ratRoundDown ((a : Rat) / den) /-- Alternate rational upper bound for a square root. -/ -def sqrtUpperAlt (q : Dyadic) : Dyadic := - let num := q.toRat.num.natAbs - let den := q.toRat.den +def sqrtUpperAlt (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den let a := Nat.sqrt (num * den) - dyadicOfRatUp ((a + 1 : Rat) / den) + ratRoundUp ((a + 1 : Rat) / den) -/-- Dyadicional lower bound for a square root (tighter of two bounds). -/ -def sqrtLower (q : Dyadic) : Dyadic := +/-- Rational lower bound for a square root (tighter of two bounds). -/ +def sqrtLower (q : Rat) : Rat := max (sqrtLowerBase q) (sqrtLowerAlt q) -/-- Dyadicional upper bound for a square root (tighter of two bounds). -/ -def sqrtUpper (q : Dyadic) : Dyadic := +/-- Rational upper bound for a square root (tighter of two bounds). -/ +def sqrtUpper (q : Rat) : Rat := min (sqrtUpperBase q) (sqrtUpperAlt q) /-- `sqrtLowerBase` is nonnegative. 
-/ -theorem sqrtLowerBase_nonneg (q : Dyadic) : 0 ≤ sqrtLowerBase q := by +theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by classical unfold sqrtLowerBase - have hnum : 0 ≤ (Nat.sqrt q.toRat.num.natAbs : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt q.toRat.num.natAbs)) - have hden : 0 ≤ (Nat.sqrt q.toRat.den + 1 : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt q.toRat.den + 1)) - have hrat : 0 ≤ (Nat.sqrt q.toRat.num.natAbs : Rat) / (Nat.sqrt q.toRat.den + 1) := by + have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt q.num.natAbs)) + have hden : 0 ≤ (Nat.sqrt q.den + 1 : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt q.den + 1)) + have hrat : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) / (Nat.sqrt q.den + 1) := by exact div_nonneg hnum hden - exact dyadicOfRatDown_nonneg hrat + exact ratRoundDown_nonneg hrat /-! Strict positivity helpers. -/ @@ -92,107 +92,107 @@ theorem sqrtLowerBase_nonneg (q : Dyadic) : 0 ≤ sqrtLowerBase q := by /-- `sqrtUpperBase` is nonnegative. -/ -theorem sqrtUpperBase_nonneg (q : Dyadic) : 0 ≤ sqrtUpperBase q := by +theorem sqrtUpperBase_nonneg (q : Rat) : 0 ≤ sqrtUpperBase q := by classical unfold sqrtUpperBase - have hnum : 0 ≤ (Nat.sqrt q.toRat.num.natAbs + 1 : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt q.toRat.num.natAbs + 1)) - have hden : 0 ≤ (Nat.sqrt q.toRat.den : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt q.toRat.den)) + have hnum : 0 ≤ (Nat.sqrt q.num.natAbs + 1 : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt q.num.natAbs + 1)) + have hden : 0 ≤ (Nat.sqrt q.den : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt q.den)) have hrat : - 0 ≤ (Nat.sqrt q.toRat.num.natAbs + 1 : Rat) / (Nat.sqrt q.toRat.den) := by + 0 ≤ (Nat.sqrt q.num.natAbs + 1 : Rat) / (Nat.sqrt q.den) := by exact div_nonneg hnum hden - exact dyadicOfRatUp_nonneg hrat + exact ratRoundUp_nonneg hrat /-- `sqrtUpperBase` is always positive. 
-/ -theorem sqrtUpperBase_pos (q : Dyadic) : 0 < sqrtUpperBase q := by +theorem sqrtUpperBase_pos (q : Rat) : 0 < sqrtUpperBase q := by classical unfold sqrtUpperBase - have hnum_pos : (0 : Rat) < (Nat.sqrt q.toRat.num.natAbs + 1 : Rat) := by - exact_mod_cast (Nat.succ_pos (Nat.sqrt q.toRat.num.natAbs)) - have hden_pos : (0 : Rat) < (Nat.sqrt q.toRat.den : Rat) := by - have hden : 0 < q.toRat.den := q.toRat.den_pos + have hnum_pos : (0 : Rat) < (Nat.sqrt q.num.natAbs + 1 : Rat) := by + exact_mod_cast (Nat.succ_pos (Nat.sqrt q.num.natAbs)) + have hden_pos : (0 : Rat) < (Nat.sqrt q.den : Rat) := by + have hden : 0 < q.den := q.den_pos exact_mod_cast (Nat.sqrt_pos.2 hden) have hrat_pos : - (0 : Rat) < (Nat.sqrt q.toRat.num.natAbs + 1 : Rat) / (Nat.sqrt q.toRat.den) := by + (0 : Rat) < (Nat.sqrt q.num.natAbs + 1 : Rat) / (Nat.sqrt q.den) := by exact div_pos hnum_pos hden_pos - exact dyadicOfRatUp_pos hrat_pos + exact ratRoundUp_pos hrat_pos /-! Alternate bounds. -/ /-- `sqrtLowerAlt` is nonnegative. -/ -theorem sqrtLowerAlt_nonneg (q : Dyadic) : 0 ≤ sqrtLowerAlt q := by +theorem sqrtLowerAlt_nonneg (q : Rat) : 0 ≤ sqrtLowerAlt q := by classical unfold sqrtLowerAlt - have hnum : 0 ≤ (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den))) - have hden : 0 ≤ (q.toRat.den : Rat) := by - exact_mod_cast (Nat.zero_le q.toRat.den) + have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt (q.num.natAbs * q.den))) + have hden : 0 ≤ (q.den : Rat) := by + exact_mod_cast (Nat.zero_le q.den) have hrat : - 0 ≤ (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) : Rat) / q.toRat.den := by + 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) / q.den := by exact div_nonneg hnum hden - exact dyadicOfRatDown_nonneg hrat + exact ratRoundDown_nonneg hrat /-- `sqrtUpperAlt` is nonnegative. 
-/ -theorem sqrtUpperAlt_nonneg (q : Dyadic) : 0 ≤ sqrtUpperAlt q := by +theorem sqrtUpperAlt_nonneg (q : Rat) : 0 ≤ sqrtUpperAlt q := by classical unfold sqrtUpperAlt - have hnum : 0 ≤ (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1 : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1)) - have hden : 0 ≤ (q.toRat.den : Rat) := by - exact_mod_cast (Nat.zero_le q.toRat.den) + have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) := by + exact_mod_cast (Nat.zero_le (Nat.sqrt (q.num.natAbs * q.den) + 1)) + have hden : 0 ≤ (q.den : Rat) := by + exact_mod_cast (Nat.zero_le q.den) have hrat : - 0 ≤ (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1 : Rat) / q.toRat.den := by + 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) / q.den := by exact div_nonneg hnum hden - exact dyadicOfRatUp_nonneg hrat + exact ratRoundUp_nonneg hrat /-- `sqrtUpperAlt` is always positive. -/ -theorem sqrtUpperAlt_pos (q : Dyadic) : 0 < sqrtUpperAlt q := by +theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by classical unfold sqrtUpperAlt have hnum_pos : - (0 : Rat) < (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1 : Rat) := by - exact_mod_cast (Nat.succ_pos (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den))) - have hden_pos : (0 : Rat) < (q.toRat.den : Rat) := by - exact_mod_cast q.toRat.den_pos + (0 : Rat) < (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) := by + exact_mod_cast (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den))) + have hden_pos : (0 : Rat) < (q.den : Rat) := by + exact_mod_cast q.den_pos have hrat_pos : (0 : Rat) < - (Nat.sqrt (q.toRat.num.natAbs * q.toRat.den) + 1 : Rat) / q.toRat.den := by + (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) / q.den := by exact div_pos hnum_pos hden_pos - exact dyadicOfRatUp_pos hrat_pos + exact ratRoundUp_pos hrat_pos /-! Combined bounds. -/ /-- `sqrtLower` is nonnegative. 
-/ -theorem sqrtLower_nonneg (q : Dyadic) : 0 ≤ sqrtLower q := by +theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by have hbase : 0 ≤ sqrtLowerBase q := sqrtLowerBase_nonneg q exact le_trans hbase (le_max_left _ _) /-- `sqrtUpper` is nonnegative. -/ -theorem sqrtUpper_nonneg (q : Dyadic) : 0 ≤ sqrtUpper q := by +theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by have hbase : 0 ≤ sqrtUpperBase q := sqrtUpperBase_nonneg q have halt : 0 ≤ sqrtUpperAlt q := sqrtUpperAlt_nonneg q exact le_min hbase halt /-- `sqrtUpper` is always positive. -/ -theorem sqrtUpper_pos (q : Dyadic) : 0 < sqrtUpper q := by +theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by have hbase : 0 < sqrtUpperBase q := sqrtUpperBase_pos q have halt : 0 < sqrtUpperAlt q := sqrtUpperAlt_pos q exact lt_min hbase halt /-- Square-root lower bound in reals. -/ -theorem sqrtLowerBase_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : +theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : (sqrtLowerBase q : Real) ≤ Real.sqrt (q : Real) := by classical -- Set up numerator/denominator witnesses. 
- set num : Nat := q.toRat.num.natAbs - set den : Nat := q.toRat.den + set num : Nat := q.num.natAbs + set den : Nat := q.den set a : Nat := Nat.sqrt num set b : Nat := Nat.sqrt den have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.toRat.den_pos + exact_mod_cast q.den_pos have hbpos : 0 < (b + 1 : Real) := by exact_mod_cast (Nat.succ_pos b) have hnum_le : (a ^ 2 : Real) ≤ num := by @@ -210,16 +210,14 @@ theorem sqrtLowerBase_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : have hpow : ((a : Real) / (b + 1 : Real)) ^ 2 = (a ^ 2 : Real) / (b + 1) ^ 2 := by simp [pow_two, div_mul_div_comm] have hq_cast : (q : Real) = (num : Real) / den := by - have hnum_nonneg : 0 ≤ q.toRat.num := by - have hq' : (0 : Rat) ≤ q.toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := q)).2 hq - exact (Rat.num_nonneg (q := q.toRat)).2 hq' - have hnum_eq : (num : Int) = q.toRat.num := by + have hnum_nonneg : 0 ≤ q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.toRat.num : Real) = (num : Real) := by + have hnum_cast : (q.num : Real) = (num : Real) := by exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.toRat.num : Real) / q.toRat.den := by - simp [dyadicToReal, Rat.cast_def] + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] simpa [hnum_cast, den] using hq_rat have hsq : ((a : Real) / (b + 1 : Real)) ^ 2 ≤ (q : Real) := by simpa [hpow, hq_cast, den, num] using hdiv @@ -228,31 +226,31 @@ theorem sqrtLowerBase_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : have hden_nonneg : 0 ≤ (b + 1 : Real) := by exact_mod_cast (Nat.zero_le (b + 1)) exact div_nonneg hnum_nonneg hden_nonneg have hq_nonneg : 0 ≤ (q : Real) := by - exact dyadicToReal_nonneg_of_nonneg hq + exact ratToReal_nonneg_of_nonneg hq have hle : (a : Real) / (b + 1 : Real) ≤ Real.sqrt (q : Real) := (Real.le_sqrt hnonneg hq_nonneg).2 hsq have hdown : (sqrtLowerBase q : Real) ≤ (a : Real) / 
(b + 1 : Real) := by have hdown' : - dyadicToReal (dyadicOfRatDown ((a : Rat) / (b + 1))) ≤ + ratToReal (ratRoundDown ((a : Rat) / (b + 1))) ≤ (a : Real) / (b + 1 : Real) := by - simpa using dyadicOfRatDown_le_real ((a : Rat) / (b + 1)) + simpa using ratRoundDown_le_real ((a : Rat) / (b + 1)) simpa [sqrtLowerBase, num, den, a, b] using hdown' exact le_trans hdown hle /-- Square-root upper bound in reals. -/ -theorem real_sqrt_le_sqrtUpperBase {q : Dyadic} (hq : 0 ≤ q) : +theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpperBase q : Real) := by classical - set num : Nat := q.toRat.num.natAbs - set den : Nat := q.toRat.den + set num : Nat := q.num.natAbs + set den : Nat := q.den set a : Nat := Nat.sqrt num set b : Nat := Nat.sqrt den have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.toRat.den_pos + exact_mod_cast q.den_pos have hbpos : 0 < (b : Real) := by have hb : 0 < b := by - have hden : 0 < den := q.toRat.den_pos + have hden : 0 < den := q.den_pos exact (Nat.sqrt_pos).2 hden exact_mod_cast hb have hnum_lt : (num : Real) < (a + 1) ^ 2 := by @@ -272,16 +270,14 @@ theorem real_sqrt_le_sqrtUpperBase {q : Dyadic} (hq : 0 ≤ q) : have hpow : ((a + 1 : Real) / (b : Real)) ^ 2 = (a + 1) ^ 2 / (b : Real) ^ 2 := by simp [pow_two, div_mul_div_comm] have hq_cast : (q : Real) = (num : Real) / den := by - have hnum_nonneg : 0 ≤ q.toRat.num := by - have hq' : (0 : Rat) ≤ q.toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := q)).2 hq - exact (Rat.num_nonneg (q := q.toRat)).2 hq' - have hnum_eq : (num : Int) = q.toRat.num := by + have hnum_nonneg : 0 ≤ q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.toRat.num : Real) = (num : Real) := by + have hnum_cast : (q.num : Real) = (num : Real) := by exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.toRat.num : Real) / q.toRat.den := by - simp [dyadicToReal, 
Rat.cast_def] + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] simpa [hnum_cast, den] using hq_rat have hsq : (q : Real) ≤ ((a + 1 : Real) / (b : Real)) ^ 2 := by simpa [hpow, hq_cast, den, num] using hdiv @@ -295,20 +291,20 @@ theorem real_sqrt_le_sqrtUpperBase {q : Dyadic} (hq : 0 ≤ q) : (a + 1 : Real) / (b : Real) ≤ (sqrtUpperBase q : Real) := by have hup' : (a + 1 : Real) / (b : Real) ≤ - dyadicToReal (dyadicOfRatUp ((a + 1 : Rat) / b)) := by - simpa using real_le_dyadicOfRatUp ((a + 1 : Rat) / b) + ratToReal (ratRoundUp ((a + 1 : Rat) / b)) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / b) simpa [sqrtUpperBase, num, den, a, b] using hup' exact le_trans hle hup /-- Alternate square-root lower bound in reals. -/ -theorem sqrtLowerAlt_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : +theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : (sqrtLowerAlt q : Real) ≤ Real.sqrt (q : Real) := by classical - set num : Nat := q.toRat.num.natAbs - set den : Nat := q.toRat.den + set num : Nat := q.num.natAbs + set den : Nat := q.den set a : Nat := Nat.sqrt (num * den) have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.toRat.den_pos + exact_mod_cast q.den_pos have hnumden_le : (a ^ 2 : Real) ≤ (num * den : Nat) := by exact_mod_cast (Nat.sqrt_le' (num * den)) have hmul : (a ^ 2 : Real) ≤ (num : Real) * den := by @@ -324,18 +320,16 @@ theorem sqrtLowerAlt_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' have hden_ne : (den : Real) ≠ 0 := by - exact_mod_cast q.toRat.den_pos.ne' + exact_mod_cast q.den_pos.ne' have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by - have hnum_nonneg : 0 ≤ q.toRat.num := by - have hq' : (0 : Rat) ≤ q.toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := q)).2 hq - exact (Rat.num_nonneg (q := q.toRat)).2 hq' - have hnum_eq : (num : Int) = q.toRat.num := by + have hnum_nonneg : 0 ≤ q.num := by + 
exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.toRat.num : Real) = (num : Real) := by + have hnum_cast : (q.num : Real) = (num : Real) := by exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.toRat.num : Real) / q.toRat.den := by - simp [dyadicToReal, Rat.cast_def] + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] have hq_eq : (num : Real) / den = (num : Real) * den / (den : Real) ^ 2 := by field_simp [hden_ne] @@ -347,27 +341,27 @@ theorem sqrtLowerAlt_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) exact div_nonneg hnum_nonneg hden_nonneg have hq_nonneg : 0 ≤ (q : Real) := by - exact dyadicToReal_nonneg_of_nonneg hq + exact ratToReal_nonneg_of_nonneg hq have hle : (a : Real) / (den : Real) ≤ Real.sqrt (q : Real) := (Real.le_sqrt hnonneg hq_nonneg).2 hsq have hdown : (sqrtLowerAlt q : Real) ≤ (a : Real) / (den : Real) := by have hdown' : - dyadicToReal (dyadicOfRatDown ((a : Rat) / den)) ≤ + ratToReal (ratRoundDown ((a : Rat) / den)) ≤ (a : Real) / (den : Real) := by - simpa using dyadicOfRatDown_le_real ((a : Rat) / den) + simpa using ratRoundDown_le_real ((a : Rat) / den) simpa [sqrtLowerAlt, num, den, a] using hdown' exact le_trans hdown hle /-- Alternate square-root upper bound in reals. 
-/ -theorem real_sqrt_le_sqrtUpperAlt {q : Dyadic} (hq : 0 ≤ q) : +theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpperAlt q : Real) := by classical - set num : Nat := q.toRat.num.natAbs - set den : Nat := q.toRat.den + set num : Nat := q.num.natAbs + set den : Nat := q.den set a : Nat := Nat.sqrt (num * den) have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.toRat.den_pos + exact_mod_cast q.den_pos have hnumden_lt : (num * den : Real) < (a + 1) ^ 2 := by exact_mod_cast (Nat.lt_succ_sqrt' (num * den)) have hmul : (num : Real) * den ≤ (a + 1 : Real) ^ 2 := by @@ -383,18 +377,16 @@ theorem real_sqrt_le_sqrtUpperAlt {q : Dyadic} (hq : 0 ≤ q) : exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' have hden_ne : (den : Real) ≠ 0 := by - exact_mod_cast q.toRat.den_pos.ne' + exact_mod_cast q.den_pos.ne' have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by - have hnum_nonneg : 0 ≤ q.toRat.num := by - have hq' : (0 : Rat) ≤ q.toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := q)).2 hq - exact (Rat.num_nonneg (q := q.toRat)).2 hq' - have hnum_eq : (num : Int) = q.toRat.num := by + have hnum_nonneg : 0 ≤ q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.toRat.num : Real) = (num : Real) := by + have hnum_cast : (q.num : Real) = (num : Real) := by exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.toRat.num : Real) / q.toRat.den := by - simp [dyadicToReal, Rat.cast_def] + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] have hq_eq : (num : Real) / den = (num : Real) * den / (den : Real) ^ 2 := by field_simp [hden_ne] @@ -415,34 +407,34 @@ theorem real_sqrt_le_sqrtUpperAlt {q : Dyadic} (hq : 0 ≤ q) : (a + 1 : Real) / (den : Real) ≤ (sqrtUpperAlt q : Real) := by have hup' : (a + 1 : Real) / (den : Real) ≤ - 
dyadicToReal (dyadicOfRatUp ((a + 1 : Rat) / den)) := by - simpa using real_le_dyadicOfRatUp ((a + 1 : Rat) / den) + ratToReal (ratRoundUp ((a + 1 : Rat) / den)) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / den) simpa [sqrtUpperAlt, num, den, a] using hup' exact le_trans hle hup /-- Square-root lower bound in reals (tighter of two bounds). -/ -theorem sqrtLower_le_real_sqrt {q : Dyadic} (hq : 0 ≤ q) : +theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq simpa [sqrtLower] using (max_le_iff).2 ⟨hbase, halt⟩ /-- Square-root upper bound in reals (tighter of two bounds). -/ -theorem real_sqrt_le_sqrtUpper {q : Dyadic} (hq : 0 ≤ q) : +theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq simpa [sqrtUpper] using (le_min_iff).2 ⟨hbase, halt⟩ /-- Bounds for multiplying a scalar by a bounded value. -/ -def scaleInterval (x lo hi : Dyadic) : Dyadic × Dyadic := +def scaleInterval (x lo hi : Rat) : Rat × Rat := if 0 ≤ x then (x * lo, x * hi) else (x * hi, x * lo) /-- `scaleInterval` bounds a product. -/ -theorem scaleInterval_bounds {x lo hi y : Dyadic} +theorem scaleInterval_bounds {x lo hi y : Rat} (hlo : lo ≤ y) (hhi : y ≤ hi) : let bounds := scaleInterval x lo hi bounds.1 ≤ x * y ∧ x * y ≤ bounds.2 := by @@ -460,30 +452,30 @@ theorem scaleInterval_bounds {x lo hi y : Dyadic} simp [scaleInterval, hx, h1, h2] /-- `scaleInterval` bounds interpreted in the reals. 
-/ -theorem scaleInterval_bounds_real {x lo hi : Dyadic} {y : Real} +theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} (hlo : (lo : Real) ≤ y) (hhi : y ≤ (hi : Real)) : let bounds := scaleInterval x lo hi (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by by_cases hx : 0 ≤ x · have h1 : (x : Real) * (lo : Real) ≤ (x : Real) * y := by - have hx' : 0 ≤ (x : Real) := dyadicToReal_nonneg_of_nonneg hx + have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx exact mul_le_mul_of_nonneg_left hlo hx' have h2 : (x : Real) * y ≤ (x : Real) * (hi : Real) := by - have hx' : 0 ≤ (x : Real) := dyadicToReal_nonneg_of_nonneg hx + have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx exact mul_le_mul_of_nonneg_left hhi hx' simp [scaleInterval, hx, h1, h2] · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) have h1 : (x : Real) * (hi : Real) ≤ (x : Real) * y := by - have hx'' : (x : Real) ≤ 0 := (dyadicToReal_nonpos_iff (x := x)).2 hx' + have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' exact mul_le_mul_of_nonpos_left hhi hx'' have h2 : (x : Real) * y ≤ (x : Real) * (lo : Real) := by - have hx'' : (x : Real) ≤ 0 := (dyadicToReal_nonpos_iff (x := x)).2 hx' + have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' exact mul_le_mul_of_nonpos_left hlo hx'' simp [scaleInterval, hx, h1, h2] /-- Real-valued LayerNorm output for a vector. -/ noncomputable def layerNormReal {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (x : Fin n → Dyadic) : Fin n → Real := + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : Fin n → Real := if n = 0 then fun _ => 0 else @@ -494,7 +486,7 @@ noncomputable def layerNormReal {n : Nat} /-- Real-valued LayerNorm output for a real vector. 
-/ noncomputable def layerNormRealOfReal {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (x : Fin n → Real) : Fin n → Real := + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Real) : Fin n → Real := if n = 0 then fun _ => 0 else @@ -505,22 +497,22 @@ noncomputable def layerNormRealOfReal {n : Nat} /-- Interval bounds for LayerNorm outputs. -/ def layerNormBounds {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (x : Fin n → Dyadic) : - (Fin n → Dyadic) × (Fin n → Dyadic) := + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := if n = 0 then (fun _ => 0, fun _ => 0) else let μLo := mean x let μHi := meanUpper x - let centeredBound : Fin n → Dyadic := fun i => + let centeredBound : Fin n → Rat := fun i => max |x i - μHi| |x i - μLo| - let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) - let radius : Fin n → Dyadic := fun i => |gamma i| * centeredBound i * invStdBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound (fun i => beta i - radius i, fun i => beta i + radius i) /-- `layerNormBounds` soundness for real LayerNorm outputs. 
-/ theorem layerNormBounds_spec {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (x : Fin n → Dyadic) + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := layerNormBounds eps gamma beta x ∀ i, @@ -528,10 +520,10 @@ theorem layerNormBounds_spec {n : Nat} layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by classical intro bounds i - let μLo : Dyadic := mean x - let μHi : Dyadic := meanUpper x - let centeredBound : Fin n → Dyadic := fun j => max |x j - μHi| |x j - μLo| - let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let μLo : Rat := mean x + let μHi : Rat := meanUpper x + let centeredBound : Fin n → Rat := fun j => max |x j - μHi| |x j - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) let varEps : Real := (varianceRat x : Real) + (eps : Real) let μ : Real := meanRat x let invStd : Real := (Real.sqrt varEps)⁻¹ @@ -539,12 +531,12 @@ theorem layerNormBounds_spec {n : Nat} have h0 : 0 ≤ centeredBound i := by dsimp [centeredBound] exact le_trans (abs_nonneg _) (le_max_left _ _) - exact dyadicToReal_nonneg_of_nonneg h0 + exact ratToReal_nonneg_of_nonneg h0 have hmean_lo_real : (μLo : Real) ≤ μ := by - have h := dyadicOfRatDown_le_real (meanRat x) + have h := ratRoundDown_le_real (meanRat x) simpa [μLo, μ, mean_def x hne] using h have hmean_hi_real : μ ≤ (μHi : Real) := by - have h := real_le_dyadicOfRatUp (meanRat x) + have h := real_le_ratRoundUp (meanRat x) simpa [μHi, μ, meanUpper_def x hne] using h have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by have hlo : (x i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by @@ -552,8 +544,8 @@ theorem layerNormBounds_spec {n : Nat} have hhi : (x i : Real) - μ ≤ (x i : Real) - (μLo : Real) := by exact sub_le_sub_left hmean_lo_real (x i : Real) have hbound := abs_le_max_of_bounds hlo hhi - simpa [centeredBound, μLo, μHi, dyadicToReal_abs, dyadicToReal_sub, - dyadicToReal_max] using hbound + 
simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, + ratToReal_max] using hbound have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne have hsqrt_lower : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by @@ -569,14 +561,13 @@ theorem layerNormBounds_spec {n : Nat} exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - simpa [dyadicToReal_zero] using - (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower simpa [invStd] using h have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by exact le_trans hinv_sqrt hinv_bound @@ -607,7 +598,7 @@ theorem layerNormBounds_spec {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound j * invStdBound + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound have ht_abs' : |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by @@ -633,22 +624,22 @@ theorem layerNormBounds_spec {n : Nat} /-- Interval bounds for LayerNorm outputs from per-coordinate intervals. 
-/ def layerNormIntervalBounds {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (lo hi : Fin n → Dyadic) : - (Fin n → Dyadic) × (Fin n → Dyadic) := + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := if n = 0 then (fun _ => 0, fun _ => 0) else let μLo := mean lo let μHi := meanUpper hi - let centeredBound : Fin n → Dyadic := fun i => + let centeredBound : Fin n → Rat := fun i => max |lo i - μHi| |hi i - μLo| - let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) - let radius : Fin n → Dyadic := fun i => |gamma i| * centeredBound i * invStdBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound (fun i => beta i - radius i, fun i => beta i + radius i) /-- `layerNormIntervalBounds` soundness for real LayerNorm outputs. -/ theorem layerNormIntervalBounds_spec {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (lo hi : Fin n → Dyadic) (x : Fin n → Dyadic) + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) : let bounds := layerNormIntervalBounds eps gamma beta lo hi @@ -657,10 +648,10 @@ theorem layerNormIntervalBounds_spec {n : Nat} layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by classical intro bounds i - let μLo : Dyadic := mean lo - let μHi : Dyadic := meanUpper hi - let centeredBound : Fin n → Dyadic := fun j => max |lo j - μHi| |hi j - μLo| - let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let μLo : Rat := mean lo + let μHi : Rat := meanUpper hi + let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) let varEps : Real := (varianceRat x : Real) + (eps : Real) let μ : Real := meanRat x let invStd : Real := (Real.sqrt varEps)⁻¹ @@ -668,19 +659,19 @@ theorem layerNormIntervalBounds_spec 
{n : Nat} have h0 : 0 ≤ centeredBound i := by dsimp [centeredBound] exact le_trans (abs_nonneg _) (le_max_left _ _) - exact dyadicToReal_nonneg_of_nonneg h0 + exact ratToReal_nonneg_of_nonneg h0 have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by have hmean_lo_real : (μLo : Real) ≤ μ := by have hmean_rat : (meanRat lo : Real) ≤ (meanRat x : Real) := meanRat_le_meanRat_real lo x hne hlo have hdown : (μLo : Real) ≤ (meanRat lo : Real) := by - simpa [μLo, mean_def lo hne] using dyadicOfRatDown_le_real (meanRat lo) + simpa [μLo, mean_def lo hne] using ratRoundDown_le_real (meanRat lo) exact le_trans hdown hmean_rat have hmean_hi_real : μ ≤ (μHi : Real) := by have hmean_rat : (meanRat x : Real) ≤ (meanRat hi : Real) := meanRat_le_meanRat_real x hi hne hhi have hup : (meanRat hi : Real) ≤ (μHi : Real) := by - simpa [μHi, meanUpper_def hi hne] using real_le_dyadicOfRatUp (meanRat hi) + simpa [μHi, meanUpper_def hi hne] using real_le_ratRoundUp (meanRat hi) exact le_trans hmean_rat hup have hlo' : (lo i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by @@ -688,21 +679,21 @@ theorem layerNormIntervalBounds_spec {n : Nat} have h2 : (lo i : Real) - μ ≤ (x i : Real) - μ := by exact sub_le_sub_right (by - simpa using dyadicToReal_le_of_le (hlo i)) + exact ratToReal_le_of_le (hlo i)) μ exact le_trans h1 h2 have hhi' : (x i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by have h1 : (x i : Real) - μ ≤ (hi i : Real) - μ := by exact sub_le_sub_right (by - simpa using dyadicToReal_le_of_le (hhi i)) + exact ratToReal_le_of_le (hhi i)) μ have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by exact sub_le_sub_left hmean_lo_real (hi i : Real) exact le_trans h1 h2 have hbound := abs_le_max_of_bounds hlo' hhi' - simpa [centeredBound, μLo, μHi, dyadicToReal_abs, dyadicToReal_sub, - dyadicToReal_max] using hbound + simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, + ratToReal_max] using 
hbound have hsqrt_lower : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by have hsqrt_eps : @@ -718,14 +709,13 @@ theorem layerNormIntervalBounds_spec {n : Nat} exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - simpa [dyadicToReal_zero] using - (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower simpa [invStd] using h have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by exact le_trans hinv_sqrt hinv_bound @@ -756,7 +746,7 @@ theorem layerNormIntervalBounds_spec {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound j * invStdBound + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound have ht_abs' : |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by @@ -782,16 +772,16 @@ theorem layerNormIntervalBounds_spec {n : Nat} /-- Interval bounds for LayerNorm outputs from an absolute input bound. 
-/ def layerNormAbsBounds {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (absBound : Dyadic) : - (Fin n → Dyadic) × (Fin n → Dyadic) := - let centeredBound : Dyadic := 2 * absBound - let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) - let radius : Fin n → Dyadic := fun i => |gamma i| * centeredBound * invStdBound + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) : + (Fin n → Rat) × (Fin n → Rat) := + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound * invStdBound (fun i => beta i - radius i, fun i => beta i + radius i) /-- `layerNormAbsBounds` soundness for real LayerNorm outputs under absolute input bounds. -/ theorem layerNormAbsBounds_spec {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (absBound : Dyadic) (x : Fin n → Dyadic) + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (habs : ∀ i, |x i| ≤ absBound) : let bounds := layerNormAbsBounds eps gamma beta absBound @@ -805,15 +795,15 @@ theorem layerNormAbsBounds_spec {n : Nat} meanReal_abs_le_bound (x := fun j => (x j : Real)) (bound := absBound) hne (by intro j - exact dyadicToReal_abs_le_of_le (habs j)) + exact ratToReal_abs_le_of_le (habs j)) simpa [meanReal_eq_meanRat] using h have hbound_nonneg : 0 ≤ absBound := by have hposn : 0 < n := Nat.pos_of_ne_zero hne let i0 : Fin n := ⟨0, hposn⟩ have h0 : 0 ≤ |x i0| := abs_nonneg _ exact le_trans h0 (habs i0) - let centeredBound : Dyadic := 2 * absBound - let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) let varEps : Real := (varianceRat x : Real) + (eps : Real) let μ : Real := meanRat x let invStd : Real := (Real.sqrt varEps)⁻¹ @@ -821,7 +811,7 @@ theorem layerNormAbsBounds_spec {n : Nat} have h1 : |(x i : Real) - μ| ≤ |(x i : Real)| + 
|μ| := by simpa [sub_eq_add_neg, abs_neg] using abs_add_le (x i : Real) (-μ) have hx : |(x i : Real)| ≤ (absBound : Real) := by - exact dyadicToReal_abs_le_of_le (habs i) + exact ratToReal_abs_le_of_le (habs i) have hmu : |μ| ≤ (absBound : Real) := by simpa [μ] using hmean_abs_real have h2 : |(x i : Real)| + |μ| ≤ (absBound : Real) + (absBound : Real) := @@ -830,7 +820,7 @@ theorem layerNormAbsBounds_spec {n : Nat} le_trans h1 h2 simpa [centeredBound, two_mul] using h12 have hbound_nonneg_real : 0 ≤ (absBound : Real) := by - exact dyadicToReal_nonneg_of_nonneg hbound_nonneg + exact ratToReal_nonneg_of_nonneg hbound_nonneg have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real simpa [centeredBound, two_mul] using hsum @@ -849,14 +839,13 @@ theorem layerNormAbsBounds_spec {n : Nat} exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - simpa [dyadicToReal_zero] using - (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower simpa [invStd] using h have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by exact le_trans hinv_sqrt hinv_bound @@ -887,7 +876,7 @@ theorem layerNormAbsBounds_spec {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound * invStdBound + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound have ht_abs' 
: |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by @@ -911,7 +900,7 @@ theorem layerNormAbsBounds_spec {n : Nat} /-- `layerNormAbsBounds` soundness for real LayerNorm outputs on real inputs. -/ theorem layerNormAbsBounds_spec_real {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (absBound : Dyadic) (x : Fin n → Real) + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Real) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (habs : ∀ i, |x i| ≤ (absBound : Real)) : let bounds := layerNormAbsBounds eps gamma beta absBound @@ -927,8 +916,8 @@ theorem layerNormAbsBounds_spec_real {n : Nat} let i0 : Fin n := ⟨0, hposn⟩ have h0 : 0 ≤ |x i0| := abs_nonneg _ exact le_trans h0 (habs i0) - let centeredBound : Dyadic := 2 * absBound - let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) let varEps : Real := varianceReal x + (eps : Real) let μ : Real := meanReal x let invStd : Real := (Real.sqrt varEps)⁻¹ @@ -961,14 +950,13 @@ theorem layerNormAbsBounds_spec_real {n : Nat} exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - simpa [dyadicToReal_zero] using - (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower simpa [invStd] using h have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by exact le_trans hinv_sqrt 
hinv_bound @@ -999,7 +987,7 @@ theorem layerNormAbsBounds_spec_real {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound * invStdBound + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound have ht_abs' : |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by @@ -1023,7 +1011,7 @@ theorem layerNormAbsBounds_spec_real {n : Nat} /-- `layerNormIntervalBounds` soundness for real LayerNorm outputs on real inputs. -/ theorem layerNormIntervalBounds_spec_real {n : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) (lo hi : Fin n → Dyadic) (x : Fin n → Real) + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Real) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : let bounds := layerNormIntervalBounds eps gamma beta lo hi @@ -1039,7 +1027,7 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have hrat : (meanRat lo : Real) ≤ meanReal x := by simpa [meanReal_eq_meanRat] using h have hdown : (mean lo : Real) ≤ (meanRat lo : Real) := by - simpa [mean_def lo hne] using dyadicOfRatDown_le_real (meanRat lo) + simpa [mean_def lo hne] using ratRoundDown_le_real (meanRat lo) exact le_trans hdown hrat have hmean_hi : meanReal x ≤ (meanUpper hi : Real) := by have h := @@ -1048,12 +1036,12 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have hrat : meanReal x ≤ (meanRat hi : Real) := by simpa [meanReal_eq_meanRat] using h have hup : (meanRat hi : Real) ≤ (meanUpper hi : Real) := by - simpa [meanUpper_def hi hne] using real_le_dyadicOfRatUp (meanRat hi) + simpa [meanUpper_def hi hne] using real_le_ratRoundUp (meanRat hi) exact le_trans hrat hup - let μLo : Dyadic := mean lo - let μHi : Dyadic := meanUpper hi 
- let centeredBound : Fin n → Dyadic := fun j => max |lo j - μHi| |hi j - μLo| - let invStdBound : Dyadic := dyadicDivUp 1 (sqrtLower eps) + let μLo : Rat := mean lo + let μHi : Rat := meanUpper hi + let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) let varEps : Real := varianceReal x + (eps : Real) let μ : Real := meanReal x let invStd : Real := (Real.sqrt varEps)⁻¹ @@ -1061,7 +1049,7 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have h0 : 0 ≤ centeredBound i := by dsimp [centeredBound] exact le_trans (abs_nonneg _) (le_max_left _ _) - exact dyadicToReal_nonneg_of_nonneg h0 + exact ratToReal_nonneg_of_nonneg h0 have hcentered_abs : |x i - μ| ≤ (centeredBound i : Real) := by have hmean_lo_real : (μLo : Real) ≤ μ := by simpa [μLo, μ] using hmean_lo @@ -1080,8 +1068,8 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} exact sub_le_sub_left hmean_lo_real (hi i : Real) exact le_trans h1 h2 have hbound := abs_le_max_of_bounds hlo' hhi' - simpa [centeredBound, μLo, μHi, dyadicToReal_abs, dyadicToReal_sub, - dyadicToReal_max] using hbound + simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, + ratToReal_max] using hbound have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne have hsqrt_lower : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by @@ -1097,14 +1085,13 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} exact Real.sqrt_le_sqrt hle exact le_trans hsqrt_eps hsqrt_eps' have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - simpa [dyadicToReal_zero] using - (dyadicToReal_lt_iff (x := 0) (y := sqrtLower eps)).2 hsqrt + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower simpa [invStd] using h have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := 
dyadicDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by exact le_trans hinv_sqrt hinv_bound @@ -1134,7 +1121,7 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg simp [t, abs_mul, hinv_abs, mul_assoc] simpa [ht] using hmul2 - let radius : Fin n → Dyadic := fun j => |gamma j| * centeredBound j * invStdBound + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound have ht_abs' : |t| ≤ (radius i : Real) := by simpa [radius, centeredBound, invStdBound] using ht_abs have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by diff --git a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean index 1c4112e..ba506e8 100644 --- a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean +++ b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean @@ -12,7 +12,7 @@ import Nfp.Core.Basic /-! Mean/variance helpers for LayerNorm bounds. -This module isolates the dyadic and real mean/variance definitions and their +This module isolates the rational and real mean/variance definitions and their basic lemmas to keep `LayerNorm` bounds modular. -/ @@ -25,65 +25,65 @@ namespace Bounds open scoped BigOperators /-- Sum as a rational, used for exact mean/variance computations. -/ -def sumRat {n : Nat} (x : Fin n → Dyadic) : Rat := +def sumRat {n : Nat} (x : Fin n → Rat) : Rat := ∑ i, (x i : Rat) /-- Exact mean as a rational (defaults to `0` when `n = 0`). -/ -def meanRat {n : Nat} (x : Fin n → Dyadic) : Rat := +def meanRat {n : Nat} (x : Fin n → Rat) : Rat := if n = 0 then 0 else (sumRat x) / n -/-- Mean rounded down to dyadic precision (defaults to `0` when `n = 0`). -/ -def mean {n : Nat} (x : Fin n → Dyadic) : Dyadic := +/-- Mean rounded down (identity in exact-rational mode; defaults to `0` when `n = 0`). 
-/ +def mean {n : Nat} (x : Fin n → Rat) : Rat := if n = 0 then 0 else - dyadicOfRatDown (meanRat x) + ratRoundDown (meanRat x) -/-- Mean rounded up to dyadic precision (defaults to `0` when `n = 0`). -/ -def meanUpper {n : Nat} (x : Fin n → Dyadic) : Dyadic := +/-- Mean rounded up (identity in exact-rational mode; defaults to `0` when `n = 0`). -/ +def meanUpper {n : Nat} (x : Fin n → Rat) : Rat := if n = 0 then 0 else - dyadicOfRatUp (meanRat x) + ratRoundUp (meanRat x) /-- Unfold `mean` when `n ≠ 0`. -/ -theorem mean_def {n : Nat} (x : Fin n → Dyadic) (h : n ≠ 0) : - mean x = dyadicOfRatDown (meanRat x) := by +theorem mean_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + mean x = ratRoundDown (meanRat x) := by simp [mean, h] /-- Unfold `meanUpper` when `n ≠ 0`. -/ -theorem meanUpper_def {n : Nat} (x : Fin n → Dyadic) (h : n ≠ 0) : - meanUpper x = dyadicOfRatUp (meanRat x) := by +theorem meanUpper_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + meanUpper x = ratRoundUp (meanRat x) := by simp [meanUpper, h] /-- Exact variance as a rational (defaults to `0` when `n = 0`). -/ -def varianceRat {n : Nat} (x : Fin n → Dyadic) : Rat := +def varianceRat {n : Nat} (x : Fin n → Rat) : Rat := if n = 0 then 0 else let μ := meanRat x (∑ i, ((x i : Rat) - μ) ^ 2) / n -/-- Variance rounded down to dyadic precision (defaults to `0` when `n = 0`). -/ -def variance {n : Nat} (x : Fin n → Dyadic) : Dyadic := +/-- Variance rounded down (identity in exact-rational mode; defaults to `0` when `n = 0`). -/ +def variance {n : Nat} (x : Fin n → Rat) : Rat := if n = 0 then 0 else - dyadicOfRatDown (varianceRat x) + ratRoundDown (varianceRat x) -/-- Variance rounded up to dyadic precision (defaults to `0` when `n = 0`). -/ -def varianceUpper {n : Nat} (x : Fin n → Dyadic) : Dyadic := +/-- Variance rounded up (identity in exact-rational mode; defaults to `0` when `n = 0`). 
-/ +def varianceUpper {n : Nat} (x : Fin n → Rat) : Rat := if n = 0 then 0 else - dyadicOfRatUp (varianceRat x) + ratRoundUp (varianceRat x) /-- Unfold `variance` when `n ≠ 0`. -/ -theorem variance_def {n : Nat} (x : Fin n → Dyadic) (h : n ≠ 0) : - variance x = dyadicOfRatDown (varianceRat x) := by +theorem variance_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + variance x = ratRoundDown (varianceRat x) := by simp [variance, h] /-! Interval helpers. -/ @@ -124,7 +124,7 @@ theorem meanReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : simp [meanReal, h] /-- `meanReal` agrees with `mean` after casting. -/ -theorem meanReal_eq_meanRat {n : Nat} (x : Fin n → Dyadic) : +theorem meanReal_eq_meanRat {n : Nat} (x : Fin n → Rat) : meanReal (fun i => (x i : Real)) = (meanRat x : Real) := by by_cases h : n = 0 · simp [meanReal, meanRat, h] @@ -132,7 +132,7 @@ theorem meanReal_eq_meanRat {n : Nat} (x : Fin n → Dyadic) : (sumRat x : Real) = ∑ i, (x i : Real) := by classical unfold sumRat - simp [dyadicToReal, Rat.cast_sum] + simp [Rat.cast_sum] have hmean : (meanRat x : Real) = (sumRat x : Real) / n := by simp [meanRat, h] have hreal : meanReal (fun i => (x i : Real)) = (∑ i, (x i : Real)) / n := by @@ -153,15 +153,15 @@ theorem meanReal_le_meanReal {n : Nat} (x y : Fin n → Real) (hne : n ≠ 0) div_le_div_of_nonneg_right hsum hden simpa [meanReal, hne] using hdiv -/-- Mean monotonicity for dyadic inputs, interpreted in reals. -/ -theorem meanRat_le_meanRat_real {n : Nat} (x y : Fin n → Dyadic) (hne : n ≠ 0) +/-- Mean monotonicity for rational inputs, interpreted in reals. 
-/ +theorem meanRat_le_meanRat_real {n : Nat} (x y : Fin n → Rat) (hne : n ≠ 0) (hxy : ∀ i, x i ≤ y i) : (meanRat x : Real) ≤ (meanRat y : Real) := by have hreal : meanReal (fun i => (x i : Real)) ≤ meanReal (fun i => (y i : Real)) := by refine meanReal_le_meanReal (x := fun i => (x i : Real)) (y := fun i => (y i : Real)) hne ?_ intro i - exact dyadicToReal_le_of_le (hxy i) + exact ratToReal_le_of_le (hxy i) simpa [meanReal_eq_meanRat] using hreal /-- Variance of a real vector (defaults to `0` when `n = 0`). -/ @@ -193,7 +193,7 @@ theorem varianceReal_nonneg {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : div_nonneg hsum hden simpa [varianceReal_def x h] using hdiv -theorem varianceReal_eq_varianceRat {n : Nat} (x : Fin n → Dyadic) : +theorem varianceReal_eq_varianceRat {n : Nat} (x : Fin n → Rat) : varianceReal (fun i => (x i : Real)) = (varianceRat x : Real) := by by_cases h : n = 0 · simp [varianceReal, varianceRat, h] @@ -202,7 +202,7 @@ theorem varianceReal_eq_varianceRat {n : Nat} (x : Fin n → Dyadic) : (∑ i, ((x i : Real) - (meanRat x : Real)) ^ 2) = (∑ i, ((x i : Rat) - meanRat x) ^ 2 : Rat) := by classical - simp [dyadicToReal, Rat.cast_sum] + simp [Rat.cast_sum] have hreal : varianceReal (fun i => (x i : Real)) = (∑ i, ((x i : Real) - meanReal (fun j => (x j : Real))) ^ 2) / n := by simp [varianceReal, h] @@ -215,17 +215,17 @@ theorem varianceReal_eq_varianceRat {n : Nat} (x : Fin n → Dyadic) : _ = (∑ i, ((x i : Real) - (meanRat x : Real)) ^ 2) / n := by simp [hmean] _ = (∑ i, ((x i : Rat) - meanRat x) ^ 2 : Rat) / n := by - simp [hsum] + rw [hsum] _ = (varianceRat x : Real) := hrat.symm /-- Variance is nonnegative when `n ≠ 0`, interpreted in reals. 
-/ -theorem varianceRat_nonneg_real {n : Nat} (x : Fin n → Dyadic) (hne : n ≠ 0) : +theorem varianceRat_nonneg_real {n : Nat} (x : Fin n → Rat) (hne : n ≠ 0) : 0 ≤ (varianceRat x : Real) := by have hreal := varianceReal_nonneg (x := fun i => (x i : Real)) hne simpa [varianceReal_eq_varianceRat] using hreal /-- Absolute mean bound from per-coordinate bounds (real inputs). -/ -theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Dyadic) +theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Rat) (hne : n ≠ 0) (hbound : ∀ i, |x i| ≤ (bound : Real)) : |meanReal x| ≤ (bound : Real) := by classical diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index 0b0cab1..b544969 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -16,7 +16,7 @@ import Nfp.Sound.Linear.FinFold Row-sum matrix norms for downstream linear certificates. These bounds are used to compute verified downstream error certificates -from explicit Dyadic matrices. +from explicit Rat matrices. -/ namespace Nfp @@ -28,25 +28,25 @@ namespace Bounds open scoped BigOperators /-- Row-sum of absolute values for a matrix row. -/ -def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) (i : Fin m) : Dyadic := +def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : Rat := Linear.sumFin n (fun j => |W i j|) /-- Weighted row-sum using per-coordinate bounds. -/ -def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (bound : Fin n → Dyadic) (i : Fin m) : Dyadic := +def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) : Rat := Linear.sumFin n (fun j => |W i j| * bound j) /-- Maximum row-sum norm (defaults to `0` on empty matrices). 
-/ -def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) : Dyadic := +def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := Linear.foldlFin m (fun acc i => max acc (rowSum W i)) 0 /-- Maximum weighted row-sum (defaults to `0` on empty matrices). -/ -def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (bound : Fin n → Dyadic) : Dyadic := +def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : Rat := Linear.foldlFin m (fun acc i => max acc (rowSumWeighted W bound i)) 0 /-- Row-sums are nonnegative. -/ -theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) (i : Fin m) : +theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : 0 ≤ rowSum W i := by have hsum : rowSum W i = ∑ j, |W i j| := by simp [rowSum, Linear.sumFin_eq_sum_univ] @@ -57,8 +57,8 @@ theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) (i : Fin m simpa [hsum] using hnonneg /-- Weighted row-sums are nonnegative under nonnegative bounds. -/ -theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (bound : Fin n → Dyadic) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : +theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : 0 ≤ rowSumWeighted W bound i := by have hsum : rowSumWeighted W bound i = ∑ j, |W i j| * bound j := by simp [rowSumWeighted, Linear.sumFin_eq_sum_univ] @@ -69,45 +69,45 @@ theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) simpa [hsum] using hnonneg /-- Each row-sum is bounded by the row-sum norm. 
-/ -theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) (i : Fin m) : +theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : rowSum W i ≤ rowSumNorm W := by simpa [rowSumNorm] using (foldlFin_max_ge (f := fun j => rowSum W j) i) /-- Each weighted row-sum is bounded by the weighted row-sum norm. -/ -theorem rowSumWeighted_le_rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (bound : Fin n → Dyadic) (i : Fin m) : +theorem rowSumWeighted_le_rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) : rowSumWeighted W bound i ≤ rowSumWeightedNorm W bound := by simpa [rowSumWeightedNorm] using (foldlFin_max_ge (f := fun j => rowSumWeighted W bound j) i) /-- The row-sum norm is nonnegative. -/ -theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) : +theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : 0 ≤ rowSumNorm W := by simpa [rowSumNorm] using - (foldlFin_max_ge_init (f := fun i => rowSum W i) (init := (0 : Dyadic))) + (foldlFin_max_ge_init (f := fun i => rowSum W i) (init := (0 : Rat))) /-- Weighted row-sum norm is nonnegative. -/ -theorem rowSumWeightedNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (bound : Fin n → Dyadic) : +theorem rowSumWeightedNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : 0 ≤ rowSumWeightedNorm W bound := by simpa [rowSumWeightedNorm] using - (foldlFin_max_ge_init (f := fun i => rowSumWeighted W bound i) (init := (0 : Dyadic))) + (foldlFin_max_ge_init (f := fun i => rowSumWeighted W bound i) (init := (0 : Rat))) /-- Downstream error from per-coordinate residual bounds. 
-/ -def downstreamErrorFromBounds {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (bound : Fin n → Dyadic) : Dyadic := +def downstreamErrorFromBounds {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : Rat := rowSumWeightedNorm W bound /-- `downstreamErrorFromBounds` is nonnegative. -/ -theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (bound : Fin n → Dyadic) : +theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : 0 ≤ downstreamErrorFromBounds W bound := by simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound /-- Build a residual-interval certificate by applying a matrix to an input interval. -/ -def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (lo hi : Fin n → Dyadic) (hlohi : ∀ j, lo j ≤ hi j) : +def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : {c : Circuit.ResidualIntervalCert m // Circuit.ResidualIntervalBounds c} := by let lo' := mulVecIntervalLower W lo hi let hi' := mulVecIntervalUpper W lo hi @@ -117,8 +117,8 @@ def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) exact mulVecIntervalLower_le_upper W lo hi hlohi i /-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ -theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (x : Fin n → Dyadic) (inputBound : Dyadic) +theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (x : Fin n → Rat) (inputBound : Rat) (hx : ∀ j, |x j| ≤ inputBound) (hinput : 0 ≤ inputBound) : ∀ i, |Matrix.mulVec W x i| ≤ rowSumNorm W * inputBound := by intro i @@ -156,8 +156,8 @@ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) exact hrow.trans hmul /-- Build a downstream linear certificate from a matrix and input bound. 
-/ -def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (inputBound : Dyadic) (hinput : 0 ≤ inputBound) : +def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (inputBound : Rat) (hinput : 0 ≤ inputBound) : {c : Circuit.DownstreamLinearCert // Circuit.DownstreamLinearBounds c} := by let gain := rowSumNorm W let error := gain * inputBound diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index db10a41..2db676c 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -22,8 +22,8 @@ namespace Bounds open scoped BigOperators -lemma foldl_max_ge_init {α : Type _} (f : α → Dyadic) : - ∀ (l : List α) (init : Dyadic), +lemma foldl_max_ge_init {α : Type _} (f : α → Rat) : + ∀ (l : List α) (init : Rat), init ≤ l.foldl (fun acc x => max acc (f x)) init := by intro l init induction l generalizing init with @@ -35,8 +35,8 @@ lemma foldl_max_ge_init {α : Type _} (f : α → Dyadic) : ih (max init (f a)) simpa [List.foldl] using le_trans hinit hrest -lemma foldl_max_ge_mem {α : Type _} (f : α → Dyadic) : - ∀ (l : List α) (a : α) (init : Dyadic), +lemma foldl_max_ge_mem {α : Type _} (f : α → Rat) : + ∀ (l : List α) (a : α) (init : Rat), a ∈ l → f a ≤ l.foldl (fun acc x => max acc (f x)) init := by intro l a init hmem induction l generalizing init with @@ -57,7 +57,7 @@ lemma foldl_max_ge_mem {α : Type _} (f : α → Dyadic) : have h' := ih (init := max init (f b)) h simpa [List.foldl] using h' -lemma foldlFin_max_ge_init {n : Nat} (f : Fin n → Dyadic) (init : Dyadic) : +lemma foldlFin_max_ge_init {n : Nat} (f : Fin n → Rat) (init : Rat) : init ≤ Linear.foldlFin n (fun acc j => max acc (f j)) init := by classical have hlist : @@ -71,7 +71,7 @@ lemma foldlFin_max_ge_init {n : Nat} (f : Fin n → Dyadic) (init : Dyadic) : (f := fun acc j => max acc (f j)) (x := init) (n := n)) simpa [hfold] using hlist -lemma foldlFin_max_ge {n : Nat} (f : Fin 
n → Dyadic) (i : Fin n) : +lemma foldlFin_max_ge {n : Nat} (f : Fin n → Rat) (i : Fin n) : f i ≤ Linear.foldlFin n (fun acc j => max acc (f j)) 0 := by classical have hmem : i ∈ List.finRange n := by @@ -84,51 +84,135 @@ lemma foldlFin_max_ge {n : Nat} (f : Fin n → Dyadic) (i : Fin n) : (List.finRange n).foldl (fun acc j => max acc (f j)) 0 := by simpa [Linear.foldlFin_eq_foldl] using (Fin.foldl_eq_foldl_finRange - (f := fun acc j => max acc (f j)) (x := (0 : Dyadic)) (n := n)) + (f := fun acc j => max acc (f j)) (x := (0 : Rat)) (n := n)) simpa [hfold] using hlist /-- Lower interval endpoint for a dot product with per-coordinate bounds. -/ -def dotIntervalLower {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +def dotIntervalLower {n : Nat} (v lo hi : Fin n → Rat) : Rat := Linear.sumFin n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) /-- Upper interval endpoint for a dot product with per-coordinate bounds. -/ -def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Rat) : Rat := Linear.sumFin n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) /-- Lower interval endpoint using a shared-denominator accumulator. -/ -def dotIntervalLowerCommonDen {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +def dotIntervalLowerCommonDen {n : Nat} (v lo hi : Fin n → Rat) : Rat := Linear.sumFinCommonDen n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) /-- Upper interval endpoint using a shared-denominator accumulator. -/ -def dotIntervalUpperCommonDen {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +def dotIntervalUpperCommonDen {n : Nat} (v lo hi : Fin n → Rat) : Rat := Linear.sumFinCommonDen n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) +/-- Lower/upper interval endpoints computed in a single pass. 
-/ +def dotIntervalLowerUpperCommonDen {n : Nat} (v lo hi : Fin n → Rat) : Rat × Rat := + Linear.foldlFin n + (fun acc j => + (acc.1 + if 0 ≤ v j then v j * lo j else v j * hi j, + acc.2 + if 0 ≤ v j then v j * hi j else v j * lo j)) + (0, 0) + /-- Lower interval endpoint using unnormalized accumulation. -/ -def dotIntervalLowerUnnorm {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +def dotIntervalLowerUnnorm {n : Nat} (v lo hi : Fin n → Rat) : Rat := dotIntervalLower v lo hi /-- Upper interval endpoint using unnormalized accumulation. -/ -def dotIntervalUpperUnnorm {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +def dotIntervalUpperUnnorm {n : Nat} (v lo hi : Fin n → Rat) : Rat := dotIntervalUpper v lo hi -theorem dotIntervalLowerCommonDen_eq {n : Nat} (v lo hi : Fin n → Dyadic) : +theorem dotIntervalLowerCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalLowerCommonDen v lo hi = dotIntervalLower v lo hi := by simp [dotIntervalLowerCommonDen, dotIntervalLower, Linear.sumFinCommonDen_eq_sumFin] -theorem dotIntervalUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Dyadic) : +theorem dotIntervalUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalUpperCommonDen v lo hi = dotIntervalUpper v lo hi := by simp [dotIntervalUpperCommonDen, dotIntervalUpper, Linear.sumFinCommonDen_eq_sumFin] -theorem dotIntervalLowerUnnorm_eq {n : Nat} (v lo hi : Fin n → Dyadic) : +private lemma foldl_pair_fst {α : Type _} (xs : List α) (f g : α → Rat) (a b : Rat) : + (xs.foldl (fun acc x => (acc.1 + f x, acc.2 + g x)) (a, b)).1 = + xs.foldl (fun acc x => acc + f x) a := by + induction xs generalizing a b with + | nil => + simp + | cons x xs ih => + simp [List.foldl, ih] + +private lemma foldl_pair_snd {α : Type _} (xs : List α) (f g : α → Rat) (a b : Rat) : + (xs.foldl (fun acc x => (acc.1 + f x, acc.2 + g x)) (a, b)).2 = + xs.foldl (fun acc x => acc + g x) b := by + induction xs generalizing a b with + | nil => + simp + | cons x xs ih => + simp [List.foldl, ih] + 
+theorem dotIntervalLowerUpperCommonDen_fst {n : Nat} (v lo hi : Fin n → Rat) : + (dotIntervalLowerUpperCommonDen v lo hi).1 = dotIntervalLowerCommonDen v lo hi := by + classical + have hsum : + dotIntervalLowerCommonDen v lo hi = + (List.finRange n).foldl + (fun acc j => acc + if 0 ≤ v j then v j * lo j else v j * hi j) 0 := by + simp [dotIntervalLowerCommonDen, Linear.sumFinCommonDen_eq_sumFin, + Linear.sumFin_eq_list_foldl] + have hfold : + (dotIntervalLowerUpperCommonDen v lo hi).1 = + (List.finRange n).foldl + (fun acc j => acc + if 0 ≤ v j then v j * lo j else v j * hi j) 0 := by + simpa [dotIntervalLowerUpperCommonDen, Linear.foldlFin_eq_foldl, + Fin.foldl_eq_foldl_finRange] using + (foldl_pair_fst (xs := List.finRange n) + (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) + (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) + (a := 0) (b := 0)) + calc + (dotIntervalLowerUpperCommonDen v lo hi).1 + = + (List.finRange n).foldl + (fun acc j => acc + if 0 ≤ v j then v j * lo j else v j * hi j) 0 := hfold + _ = dotIntervalLowerCommonDen v lo hi := hsum.symm + +theorem dotIntervalLowerUpperCommonDen_snd {n : Nat} (v lo hi : Fin n → Rat) : + (dotIntervalLowerUpperCommonDen v lo hi).2 = dotIntervalUpperCommonDen v lo hi := by + classical + have hsum : + dotIntervalUpperCommonDen v lo hi = + (List.finRange n).foldl + (fun acc j => acc + if 0 ≤ v j then v j * hi j else v j * lo j) 0 := by + simp [dotIntervalUpperCommonDen, Linear.sumFinCommonDen_eq_sumFin, + Linear.sumFin_eq_list_foldl] + have hfold : + (dotIntervalLowerUpperCommonDen v lo hi).2 = + (List.finRange n).foldl + (fun acc j => acc + if 0 ≤ v j then v j * hi j else v j * lo j) 0 := by + simpa [dotIntervalLowerUpperCommonDen, Linear.foldlFin_eq_foldl, + Fin.foldl_eq_foldl_finRange] using + (foldl_pair_snd (xs := List.finRange n) + (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) + (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) + (a := 0) (b := 0)) + calc + 
(dotIntervalLowerUpperCommonDen v lo hi).2 + = + (List.finRange n).foldl + (fun acc j => acc + if 0 ≤ v j then v j * hi j else v j * lo j) 0 := hfold + _ = dotIntervalUpperCommonDen v lo hi := hsum.symm + +/-- Single-pass lower/upper endpoints agree with the common-denominator bounds. -/ +theorem dotIntervalLowerUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : + dotIntervalLowerUpperCommonDen v lo hi = + (dotIntervalLowerCommonDen v lo hi, dotIntervalUpperCommonDen v lo hi) := by + ext <;> simp [dotIntervalLowerUpperCommonDen_fst, dotIntervalLowerUpperCommonDen_snd] + +theorem dotIntervalLowerUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalLowerUnnorm v lo hi = dotIntervalLower v lo hi := rfl -theorem dotIntervalUpperUnnorm_eq {n : Nat} (v lo hi : Fin n → Dyadic) : +theorem dotIntervalUpperUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalUpperUnnorm v lo hi = dotIntervalUpper v lo hi := rfl /-! Cached endpoints. -/ -/-- Cached-array lower interval endpoint for a dot product using normalized dyadic sums. -/ -def dotIntervalLowerCachedDyadic {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +/-- Cached-array lower interval endpoint for a dot product using normalized rational sums. -/ +def dotIntervalLowerCachedRat {n : Nat} (v lo hi : Fin n → Rat) : Rat := let vArr := Array.ofFn v let loArr := Array.ofFn lo let hiArr := Array.ofFn hi @@ -144,8 +228,8 @@ def dotIntervalLowerCachedDyadic {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic simp [hsize, j.isLt]) if 0 ≤ vj then vj * loj else vj * hij) -/-- Cached-array upper interval endpoint for a dot product using normalized dyadic sums. -/ -def dotIntervalUpperCachedDyadic {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +/-- Cached-array upper interval endpoint for a dot product using normalized rational sums. 
-/ +def dotIntervalUpperCachedRat {n : Nat} (v lo hi : Fin n → Rat) : Rat := let vArr := Array.ofFn v let loArr := Array.ofFn lo let hiArr := Array.ofFn hi @@ -161,35 +245,35 @@ def dotIntervalUpperCachedDyadic {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic simp [hsize, j.isLt]) if 0 ≤ vj then vj * hij else vj * loj) -theorem dotIntervalLowerCachedRat_eq {n : Nat} (v lo hi : Fin n → Dyadic) : - dotIntervalLowerCachedDyadic v lo hi = dotIntervalLower v lo hi := by +theorem dotIntervalLowerCachedRat_eq {n : Nat} (v lo hi : Fin n → Rat) : + dotIntervalLowerCachedRat v lo hi = dotIntervalLower v lo hi := by classical - simp [dotIntervalLowerCachedDyadic, dotIntervalLower, Linear.sumFin_eq_list_foldl, + simp [dotIntervalLowerCachedRat, dotIntervalLower, Linear.sumFin_eq_list_foldl, Array.getElem_ofFn] -theorem dotIntervalUpperCachedRat_eq {n : Nat} (v lo hi : Fin n → Dyadic) : - dotIntervalUpperCachedDyadic v lo hi = dotIntervalUpper v lo hi := by +theorem dotIntervalUpperCachedRat_eq {n : Nat} (v lo hi : Fin n → Rat) : + dotIntervalUpperCachedRat v lo hi = dotIntervalUpper v lo hi := by classical - simp [dotIntervalUpperCachedDyadic, dotIntervalUpper, Linear.sumFin_eq_list_foldl, + simp [dotIntervalUpperCachedRat, dotIntervalUpper, Linear.sumFin_eq_list_foldl, Array.getElem_ofFn] /-! Absolute bounds. -/ /-- Absolute bound from interval endpoints for a dot product. -/ -def dotIntervalAbsBound {n : Nat} (v lo hi : Fin n → Dyadic) : Dyadic := +def dotIntervalAbsBound {n : Nat} (v lo hi : Fin n → Rat) : Rat := max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| /-- Lower interval endpoint for a matrix-vector product under input intervals. 
-/ -def mulVecIntervalLower {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (lo hi : Fin n → Dyadic) : Fin m → Dyadic := +def mulVecIntervalLower {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) : Fin m → Rat := fun i => dotIntervalLower (fun j => W i j) lo hi /-- Upper interval endpoint for a matrix-vector product under input intervals. -/ -def mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (lo hi : Fin n → Dyadic) : Fin m → Dyadic := +def mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) : Fin m → Rat := fun i => dotIntervalUpper (fun j => W i j) lo hi -theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Dyadic) +theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : dotIntervalLower v lo hi ≤ dotProduct v x := by classical @@ -205,7 +289,7 @@ theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Dyadic) mul_le_mul_of_nonpos_left (hhi j) hv' simpa [hv] using h1 -theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Dyadic) +theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : dotProduct v x ≤ dotIntervalUpper v lo hi := by classical @@ -221,7 +305,7 @@ theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Dyadic) mul_le_mul_of_nonpos_left (hlo j) hv' simpa [hv] using h1 -theorem abs_le_max_abs_abs_of_interval {a b x : Dyadic} (hlo : a ≤ x) (hhi : x ≤ b) : +theorem abs_le_max_abs_abs_of_interval {a b x : Rat} (hlo : a ≤ x) (hhi : x ≤ b) : |x| ≤ max |a| |b| := by by_cases hx : 0 ≤ x · have hb : 0 ≤ b := le_trans hx hhi @@ -243,17 +327,17 @@ theorem abs_le_max_abs_abs_of_interval {a b x : Dyadic} (hlo : a ≤ x) (hhi : x _ ≤ max |a| |b| := le_max_left _ _ /-- Global absolute bound from interval endpoints. 
-/ -def intervalAbsBound {n : Nat} (lo hi : Fin n → Dyadic) : Dyadic := +def intervalAbsBound {n : Nat} (lo hi : Fin n → Rat) : Rat := Linear.foldlFin n (fun acc i => max acc (max |lo i| |hi i|)) 0 /-- `intervalAbsBound` dominates each endpoint absolute value. -/ -theorem max_abs_le_intervalAbsBound {n : Nat} (lo hi : Fin n → Dyadic) (i : Fin n) : +theorem max_abs_le_intervalAbsBound {n : Nat} (lo hi : Fin n → Rat) (i : Fin n) : max |lo i| |hi i| ≤ intervalAbsBound lo hi := by simpa [intervalAbsBound] using (foldlFin_max_ge (f := fun j => max |lo j| |hi j|) i) /-- `intervalAbsBound` bounds any element inside the interval. -/ -theorem abs_le_intervalAbsBound {n : Nat} (lo hi x : Fin n → Dyadic) +theorem abs_le_intervalAbsBound {n : Nat} (lo hi x : Fin n → Rat) (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) (i : Fin n) : |x i| ≤ intervalAbsBound lo hi := by have hbound : |x i| ≤ max |lo i| |hi i| := @@ -262,7 +346,7 @@ theorem abs_le_intervalAbsBound {n : Nat} (lo hi x : Fin n → Dyadic) max_abs_le_intervalAbsBound lo hi i exact le_trans hbound hsup -theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → Dyadic) +theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : |dotProduct v x| ≤ dotIntervalAbsBound v lo hi := by have hlow : dotIntervalLower v lo hi ≤ dotProduct v x := @@ -277,7 +361,7 @@ theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → D /-! Real-valued bounds from rational intervals. 
-/ -theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Dyadic) +theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) (x : Fin n → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := by @@ -285,8 +369,8 @@ theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Dyadi have hcast : (dotIntervalLower v lo hi : Real) = ∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real) := by - simpa [dotIntervalLower, dyadicToReal_mul, dyadicToReal_if] using - (Linear.dyadicToReal_sumFin + simpa [dotIntervalLower, ratToReal_mul, ratToReal_if] using + (Linear.ratToReal_sumFin (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j)) have hsum : (∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real)) ≤ @@ -295,17 +379,17 @@ theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Dyadi intro j _ by_cases hv : 0 ≤ v j · have h1 : (v j : Real) * (lo j : Real) ≤ (v j : Real) * x j := by - have hv' : (0 : Real) ≤ (v j : Real) := dyadicToReal_nonneg_of_nonneg hv + have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv exact mul_le_mul_of_nonneg_left (hlo j) hv' simpa [hv] using h1 · have hv' : (v j : Real) ≤ 0 := by - exact (dyadicToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) + exact (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) have h1 : (v j : Real) * (hi j : Real) ≤ (v j : Real) * x j := by exact mul_le_mul_of_nonpos_left (hhi j) hv' simpa [hv] using h1 simpa [hcast, dotProduct] using hsum -theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Dyadic) +theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) (x : Fin n → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := by @@ -313,8 
+397,8 @@ theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Dyadi have hcast : (dotIntervalUpper v lo hi : Real) = ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by - simpa [dotIntervalUpper, dyadicToReal_mul, dyadicToReal_if] using - (Linear.dyadicToReal_sumFin + simpa [dotIntervalUpper, ratToReal_mul, ratToReal_if] using + (Linear.ratToReal_sumFin (f := fun j => if 0 ≤ v j then v j * hi j else v j * lo j)) have hsum : ∑ j, (v j : Real) * x j ≤ @@ -323,11 +407,11 @@ theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Dyadi intro j _ by_cases hv : 0 ≤ v j · have h1 : (v j : Real) * x j ≤ (v j : Real) * (hi j : Real) := by - have hv' : (0 : Real) ≤ (v j : Real) := dyadicToReal_nonneg_of_nonneg hv + have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv exact mul_le_mul_of_nonneg_left (hhi j) hv' simpa [hv] using h1 · have hv' : (v j : Real) ≤ 0 := by - exact (dyadicToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) + exact (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) have h1 : (v j : Real) * x j ≤ (v j : Real) * (lo j : Real) := by exact mul_le_mul_of_nonpos_left (hlo j) hv' simpa [hv] using h1 @@ -355,7 +439,7 @@ theorem abs_le_max_abs_abs_of_interval_real {a b x : Real} (hlo : a ≤ x) (hhi _ ≤ max |a| |b| := le_max_left _ _ /-- `intervalAbsBound` controls real-valued coordinates inside a rational interval. 
-/ -theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Dyadic) (x : Fin n → Real) +theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Rat) (x : Fin n → Real) (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) (i : Fin n) : |x i| ≤ (intervalAbsBound lo hi : Real) := by have hbound : |x i| ≤ max |(lo i : Real)| |(hi i : Real)| := @@ -370,14 +454,14 @@ theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Dyadic) (x : F le_trans (le_max_right _ _) hsup have hlo_real : |(lo i : Real)| ≤ (intervalAbsBound lo hi : Real) := by - exact dyadicToReal_abs_le_of_le hlo + exact ratToReal_abs_le_of_le hlo have hhi_real : |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by - exact dyadicToReal_abs_le_of_le hhi + exact ratToReal_abs_le_of_le hhi exact max_le_iff.mpr ⟨hlo_real, hhi_real⟩ exact le_trans hbound hsup_real -theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Dyadic) +theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Rat) (x : Fin n → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : |dotProduct (fun j => (v j : Real)) x| ≤ (dotIntervalAbsBound v lo hi : Real) := by @@ -394,29 +478,29 @@ theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n have hcast : (dotIntervalAbsBound v lo hi : Real) = max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := by - simp [dotIntervalAbsBound, dyadicToReal_abs, dyadicToReal_max] + simp [dotIntervalAbsBound] simpa [hcast] using habs /-! Matrix-vector interval bounds. 
-/ -theorem mulVecIntervalLower_le_mulVec {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (lo hi x : Fin n → Dyadic) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : +theorem mulVecIntervalLower_le_mulVec {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : ∀ i, mulVecIntervalLower W lo hi i ≤ Matrix.mulVec W x i := by intro i have h := dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi x hlo hhi simpa [mulVecIntervalLower, Matrix.mulVec, dotProduct] using h -theorem mulVec_le_mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (lo hi x : Fin n → Dyadic) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : +theorem mulVec_le_mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : ∀ i, Matrix.mulVec W x i ≤ mulVecIntervalUpper W lo hi i := by intro i have h := dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi x hlo hhi simpa [mulVecIntervalUpper, Matrix.mulVec, dotProduct] using h -theorem mulVecIntervalLower_le_upper {m n : Nat} (W : Matrix (Fin m) (Fin n) Dyadic) - (lo hi : Fin n → Dyadic) (hlohi : ∀ j, lo j ≤ hi j) : +theorem mulVecIntervalLower_le_upper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : ∀ i, mulVecIntervalLower W lo hi i ≤ mulVecIntervalUpper W lo hi i := by intro i have hlow : diff --git a/Nfp/Sound/Bounds/Mlp.lean b/Nfp/Sound/Bounds/Mlp.lean index 079c7e2..26a8574 100644 --- a/Nfp/Sound/Bounds/Mlp.lean +++ b/Nfp/Sound/Bounds/Mlp.lean @@ -20,8 +20,8 @@ open scoped BigOperators /-- Real-valued MLP with tanh-based GELU activations. 
-/ noncomputable def mlpReal {dModel hidden : Nat} - (wIn : Fin dModel → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin dModel → Dyadic) (bOut : Fin dModel → Dyadic) + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) (x : Fin dModel → Real) : Fin dModel → Real := fun i => let hidden : Fin hidden → Real := fun h => @@ -30,36 +30,36 @@ noncomputable def mlpReal {dModel hidden : Nat} /-- Interval bounds for a tanh-GELU MLP given input intervals. -/ def mlpBounds {dModel hidden : Nat} - (wIn : Fin dModel → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin dModel → Dyadic) (bOut : Fin dModel → Dyadic) - (lo hi : Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := - let preLo : Fin hidden → Dyadic := fun h => + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let preLo : Fin hidden → Rat := fun h => dotIntervalLower (fun j => wIn j h) lo hi + bIn h - let preHi : Fin hidden → Dyadic := fun h => + let preHi : Fin hidden → Rat := fun h => dotIntervalUpper (fun j => wIn j h) lo hi + bIn h - let geluBounds : Fin hidden → Dyadic × Dyadic := fun h => geluInterval (preLo h) (preHi h) - let geluLo : Fin hidden → Dyadic := fun h => (geluBounds h).1 - let geluHi : Fin hidden → Dyadic := fun h => (geluBounds h).2 - let outLo : Fin dModel → Dyadic := fun i => + let geluBounds : Fin hidden → Rat × Rat := fun h => geluInterval (preLo h) (preHi h) + let geluLo : Fin hidden → Rat := fun h => (geluBounds h).1 + let geluHi : Fin hidden → Rat := fun h => (geluBounds h).2 + let outLo : Fin dModel → Rat := fun i => dotIntervalLower (fun h => wOut h i) geluLo geluHi + bOut i - let outHi : Fin dModel → Dyadic := fun i => + let outHi : Fin dModel → Rat := fun i => dotIntervalUpper (fun h 
=> wOut h i) geluLo geluHi + bOut i (outLo, outHi) /-- `mlpBounds` soundness for real MLP outputs. -/ theorem mlpBounds_spec {dModel hidden : Nat} - (wIn : Fin dModel → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin dModel → Dyadic) (bOut : Fin dModel → Dyadic) - (lo hi : Fin dModel → Dyadic) (x : Fin dModel → Real) + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) (x : Fin dModel → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : let bounds := mlpBounds wIn bIn wOut bOut lo hi ∀ i, (bounds.1 i : Real) ≤ mlpReal wIn bIn wOut bOut x i ∧ mlpReal wIn bIn wOut bOut x i ≤ (bounds.2 i : Real) := by classical intro bounds i - let preLo : Fin hidden → Dyadic := fun h => + let preLo : Fin hidden → Rat := fun h => dotIntervalLower (fun j => wIn j h) lo hi + bIn h - let preHi : Fin hidden → Dyadic := fun h => + let preHi : Fin hidden → Rat := fun h => dotIntervalUpper (fun j => wIn j h) lo hi + bIn h let pre : Fin hidden → Real := fun h => dotProduct (fun j => (wIn j h : Real)) x + (bIn h : Real) @@ -75,18 +75,18 @@ theorem mlpBounds_spec {dModel hidden : Nat} add_le_add_right (dotProduct_le_dotIntervalUpper_real (v := fun j => wIn j h) lo hi x hlo hhi) (bIn h : Real) - let geluBounds : Fin hidden → Dyadic × Dyadic := fun h => geluInterval (preLo h) (preHi h) - let geluLo : Fin hidden → Dyadic := fun h => (geluBounds h).1 - let geluHi : Fin hidden → Dyadic := fun h => (geluBounds h).2 + let geluBounds : Fin hidden → Rat × Rat := fun h => geluInterval (preLo h) (preHi h) + let geluLo : Fin hidden → Rat := fun h => (geluBounds h).1 + let geluHi : Fin hidden → Rat := fun h => (geluBounds h).2 let hidden : Fin hidden → Real := fun h => geluTanh (pre h) have hgelu : ∀ h, (geluLo h : Real) ≤ hidden h ∧ hidden h ≤ (geluHi h : Real) := by intro h have hbounds := geluInterval_bounds (lo := preLo h) (hi := preHi h) 
(hpre_lower h) (hpre_upper h) simpa [geluLo, geluHi, geluBounds, hidden] using hbounds - let outLo : Fin dModel → Dyadic := fun i => + let outLo : Fin dModel → Rat := fun i => dotIntervalLower (fun h => wOut h i) geluLo geluHi + bOut i - let outHi : Fin dModel → Dyadic := fun i => + let outHi : Fin dModel → Rat := fun i => dotIntervalUpper (fun h => wOut h i) geluLo geluHi + bOut i have hout_lower : (outLo i : Real) ≤ dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) := by @@ -111,19 +111,19 @@ theorem mlpBounds_spec {dModel hidden : Nat} /-- Interval bounds for a LayerNorm + MLP sublayer from exact inputs. -/ def layerNormMlpBounds {n hidden : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) - (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) - (x : Fin n → Dyadic) : (Fin n → Dyadic) × (Fin n → Dyadic) := + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := let ln := layerNormBounds eps gamma beta x mlpBounds wIn bIn wOut bOut ln.1 ln.2 /-- `layerNormMlpBounds` soundness for real LayerNorm + MLP outputs. 
-/ theorem layerNormMlpBounds_spec {n hidden : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) - (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) - (x : Fin n → Dyadic) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x ∀ i, (bounds.1 i : Real) ≤ mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ∧ @@ -140,20 +140,20 @@ theorem layerNormMlpBounds_spec {n hidden : Nat} /-- Interval bounds for LayerNorm + MLP sublayer from interval inputs. -/ def layerNormAbsMlpBounds {n hidden : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) - (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) - (lo hi : Fin n → Dyadic) : (Fin n → Dyadic) × (Fin n → Dyadic) := + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := let absBound := intervalAbsBound lo hi let ln := layerNormAbsBounds eps gamma beta absBound mlpBounds wIn bIn wOut bOut ln.1 ln.2 /-- `layerNormAbsMlpBounds` soundness for real LayerNorm + MLP outputs. 
-/ theorem layerNormAbsMlpBounds_spec {n hidden : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) - (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) - (lo hi : Fin n → Dyadic) (x : Fin n → Real) + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) (x : Fin n → Real) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : let bounds := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi @@ -174,9 +174,9 @@ theorem layerNormAbsMlpBounds_spec {n hidden : Nat} have hsup_real : max |(lo j : Real)| |(hi j : Real)| ≤ (absBound : Real) := by have hsup' : - dyadicToReal (max |lo j| |hi j|) ≤ dyadicToReal absBound := - dyadicToReal_le_of_le hsup - simpa [dyadicToReal_abs, dyadicToReal_max] using hsup' + ratToReal (max |lo j| |hi j|) ≤ ratToReal absBound := + ratToReal_le_of_le hsup + simpa [ratToReal_abs, ratToReal_max] using hsup' exact le_trans hbound hsup_real have hln := layerNormAbsBounds_spec_real eps gamma beta absBound x hne heps hsqrt habs @@ -187,13 +187,13 @@ theorem layerNormAbsMlpBounds_spec {n hidden : Nat} simpa [bounds, layerNormAbsMlpBounds, absBound, ln] using hmlp i /-- Add residual inputs to interval bounds. -/ -def residualAddBounds {n : Nat} (x : Fin n → Dyadic) (lo hi : Fin n → Dyadic) : - (Fin n → Dyadic) × (Fin n → Dyadic) := +def residualAddBounds {n : Nat} (x : Fin n → Rat) (lo hi : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := (fun i => x i + lo i, fun i => x i + hi i) /-- `residualAddBounds` soundness for residual addition. 
-/ -theorem residualAddBounds_spec {n : Nat} (x : Fin n → Dyadic) - (lo hi : Fin n → Dyadic) (y : Fin n → Real) +theorem residualAddBounds_spec {n : Nat} (x : Fin n → Rat) + (lo hi : Fin n → Rat) (y : Fin n → Real) (hlo : ∀ i, (lo i : Real) ≤ y i) (hhi : ∀ i, y i ≤ (hi i : Real)) : let bounds := residualAddBounds x lo hi ∀ i, (bounds.1 i : Real) ≤ (x i : Real) + y i ∧ @@ -207,19 +207,19 @@ theorem residualAddBounds_spec {n : Nat} (x : Fin n → Dyadic) /-- Interval bounds for a full MLP residual path (LayerNorm + MLP + residual add). -/ def layerNormMlpResidualBounds {n hidden : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) - (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) - (x : Fin n → Dyadic) : (Fin n → Dyadic) × (Fin n → Dyadic) := + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := let mlp := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x residualAddBounds x mlp.1 mlp.2 /-- `layerNormMlpResidualBounds` soundness for the MLP residual path. 
-/ theorem layerNormMlpResidualBounds_spec {n hidden : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) - (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) - (x : Fin n → Dyadic) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := layerNormMlpResidualBounds eps gamma beta wIn bIn wOut bOut x ∀ i, (bounds.1 i : Real) ≤ @@ -239,19 +239,19 @@ theorem layerNormMlpResidualBounds_spec {n hidden : Nat} /-- Interval bounds for a full MLP residual path (LayerNorm + MLP + residual add) from intervals. -/ def layerNormAbsMlpResidualBounds {n hidden : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) - (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) - (lo hi : Fin n → Dyadic) : (Fin n → Dyadic) × (Fin n → Dyadic) := + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := let mlp := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi (fun i => lo i + mlp.1 i, fun i => hi i + mlp.2 i) /-- `layerNormAbsMlpResidualBounds` soundness for the MLP residual path. 
-/ theorem layerNormAbsMlpResidualBounds_spec {n hidden : Nat} - (eps : Dyadic) (gamma beta : Fin n → Dyadic) - (wIn : Fin n → Fin hidden → Dyadic) (bIn : Fin hidden → Dyadic) - (wOut : Fin hidden → Fin n → Dyadic) (bOut : Fin n → Dyadic) - (lo hi : Fin n → Dyadic) (x : Fin n → Real) + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) (x : Fin n → Real) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : let bounds := layerNormAbsMlpResidualBounds eps gamma beta wIn bIn wOut bOut lo hi diff --git a/Nfp/Sound/Bounds/Transformer.lean b/Nfp/Sound/Bounds/Transformer.lean index c4c3687..c6ae5a4 100644 --- a/Nfp/Sound/Bounds/Transformer.lean +++ b/Nfp/Sound/Bounds/Transformer.lean @@ -24,7 +24,7 @@ open scoped BigOperators /-- Real-valued output of a transformer layer. -/ noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Dyadic) (layer : Model.Gpt2LayerSlice dModel hidden) + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numHeads → Fin seq → Fin seq → Real) (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := @@ -38,10 +38,10 @@ noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} /-- `transformerLayerBounds` soundness for `transformerLayerReal`. 
-/ theorem transformerLayerBounds_spec_real {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Dyadic) (layer : Model.Gpt2LayerSlice dModel hidden) + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := transformerLayerBounds eps layer.ln1Gamma layer.ln1Beta layer.ln2Gamma @@ -63,10 +63,10 @@ theorem transformerLayerBounds_spec_real {seq dModel dHead numHeads hidden : Nat /-- Interval bounds for a transformer layer from per-position bounds. -/ def transformerLayerBoundsPos {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Dyadic) (layer : Model.Gpt2LayerSlice dModel hidden) + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Dyadic) : - (Fin seq → Fin dModel → Dyadic) × (Fin seq → Fin dModel → Dyadic) := + (lo hi : Fin seq → Fin dModel → Rat) : + (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := let positions := (Finset.univ : Finset (Fin seq)) let hpos : positions.Nonempty := by classical @@ -82,8 +82,8 @@ def transformerLayerBoundsPos {seq dModel dHead numHeads hidden : Nat} [NeZero s baseLo baseHi let attnLo := cacheBound attn.1 let attnHi := cacheBound attn.2 - let yLo : Fin seq → Fin dModel → Dyadic := fun q i => loCached q i + attnLo i - let yHi : Fin seq → Fin dModel → Dyadic := fun q i => hiCached q i + attnHi i + let yLo : Fin seq → Fin dModel → Rat := fun q i => loCached q i + attnLo i + let yHi : Fin seq → Fin dModel → Rat := fun q i => hiCached q i + attnHi i let yLoCached := cacheBound2 yLo let yHiCached := cacheBound2 yHi 
let out := cacheBoundPair2 (fun q => @@ -94,10 +94,10 @@ def transformerLayerBoundsPos {seq dModel dHead numHeads hidden : Nat} [NeZero s /-- `transformerLayerBoundsPos` soundness for `transformerLayerReal`. -/ theorem transformerLayerBoundsPos_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Dyadic) (layer : Model.Gpt2LayerSlice dModel hidden) + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo q i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : @@ -172,7 +172,7 @@ theorem transformerLayerBoundsPos_spec {seq dModel dHead numHeads hidden : Nat} /-- Real-valued transformer stack output (folded left over layers). -/ noncomputable def transformerStackReal - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) @@ -183,10 +183,10 @@ noncomputable def transformerStackReal /-- Interval bounds for a transformer stack (folded left over layers). 
-/ def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} - (eps : Dyadic) + (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := let step := fun bounds layerIdx => transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta (layers layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) @@ -196,23 +196,23 @@ def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} /-- Interval bounds for a transformer stack from per-position bounds. -/ def transformerStackBoundsPos {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Dyadic) + (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Dyadic) : - (Fin seq → Fin dModel → Dyadic) × (Fin seq → Fin dModel → Dyadic) := + (lo hi : Fin seq → Fin dModel → Rat) : + (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := let step := fun bounds layerIdx => transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2 Linear.foldlFin numLayers step (lo, hi) private theorem transformerStackBoundsPos_spec_list {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Dyadic) + (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - ∀ (ls : List (Fin numLayers)) (lo hi : Fin seq → Fin dModel → Dyadic) + ∀ (ls : List (Fin numLayers)) (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real), (∀ q 
i, (lo q i : Real) ≤ x q i) → (∀ q i, x q i ≤ (hi q i : Real)) → @@ -244,11 +244,11 @@ private theorem transformerStackBoundsPos_spec_list /-- `transformerStackBoundsPos` soundness for real transformer-stack outputs. -/ theorem transformerStackBoundsPos_spec {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Dyadic) + [NeZero seq] (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo q i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : @@ -264,12 +264,12 @@ theorem transformerStackBoundsPos_spec {seq dModel dHead numHeads hidden numLaye private theorem transformerStackBounds_spec_list {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Dyadic) + (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - ∀ (ls : List (Fin numLayers)) (lo hi : Fin dModel → Dyadic) + ∀ (ls : List (Fin numLayers)) (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real), (∀ q i, (lo i : Real) ≤ x q i) → (∀ q i, x q i ≤ (hi i : Real)) → @@ -307,11 +307,11 @@ private theorem transformerStackBounds_spec_list /-- `transformerStackBounds` soundness for real transformer-stack outputs. 
-/ theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Dyadic) + (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := transformerStackBounds eps layers heads lo hi @@ -326,7 +326,7 @@ theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers /-- Real-valued transformer stack output after the final LayerNorm. -/ noncomputable def transformerStackFinalReal {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Dyadic) (finalLn : Model.Gpt2FinalLayerNorm dModel) + [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) @@ -336,20 +336,20 @@ noncomputable def transformerStackFinalReal {seq dModel dHead numHeads hidden nu /-- Interval bounds for transformer stack outputs after the final LayerNorm. 
-/ def transformerStackFinalBounds {dModel dHead numHeads hidden numLayers : Nat} - (eps : Dyadic) (finalLn : Model.Gpt2FinalLayerNorm dModel) + (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := let stack := transformerStackBounds eps layers heads lo hi layerNormIntervalBounds eps finalLn.gamma finalLn.beta stack.1 stack.2 /-- `transformerStackFinalBounds` soundness for real outputs. -/ theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Dyadic) (finalLn : Model.Gpt2FinalLayerNorm dModel) + [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : let bounds := transformerStackFinalBounds eps finalLn layers heads lo hi @@ -372,12 +372,12 @@ theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLa /-- Interval bounds for transformer stack outputs after the final LayerNorm (per-position). 
-/ def transformerStackFinalBoundsPos - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Dyadic) : - (Fin seq → Fin dModel → Dyadic) × (Fin seq → Fin dModel → Dyadic) := + (lo hi : Fin seq → Fin dModel → Rat) : + (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := let stack := transformerStackBoundsPos eps layers heads lo hi let ln := fun q => layerNormIntervalBounds eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) @@ -385,12 +385,12 @@ def transformerStackFinalBoundsPos /-- `transformerStackFinalBoundsPos` soundness for real outputs. -/ theorem transformerStackFinalBoundsPos_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) (hlo : ∀ q i, (lo q i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : @@ -416,22 +416,22 @@ theorem transformerStackFinalBoundsPos_spec /-- Residual interval bounds for a GPT-2 stack from exact embeddings. 
-/ def gpt2ResidualIntervalBounds - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (embed : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := let base := embeddingIntervalBounds embed transformerStackFinalBounds eps finalLn layers heads base.1 base.2 /-- `gpt2ResidualIntervalBounds` soundness for real GPT-2 outputs. -/ theorem gpt2ResidualIntervalBounds_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Dyadic) + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (finalLn : Model.Gpt2FinalLayerNorm dModel) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Dyadic) + (embed : Fin seq → Fin dModel → Rat) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := gpt2ResidualIntervalBounds eps layers heads finalLn embed ∀ q i, @@ -454,25 +454,25 @@ theorem gpt2ResidualIntervalBounds_spec /-- Residual interval bounds over an active set from exact embeddings. 
-/ def gpt2ResidualIntervalBoundsActive {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Dyadic) + (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (embed : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := - let baseLo : Fin seq → Fin dModel → Dyadic := embed - let baseHi : Fin seq → Fin dModel → Dyadic := embed + (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let baseLo : Fin seq → Fin dModel → Rat := embed + let baseHi : Fin seq → Fin dModel → Rat := embed let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi intervalBoundsOn active hactive final.1 final.2 /-- `gpt2ResidualIntervalBoundsActive` soundness for real GPT-2 outputs. 
-/ theorem gpt2ResidualIntervalBoundsActive_spec {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Dyadic) + (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (finalLn : Model.Gpt2FinalLayerNorm dModel) (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Dyadic) + (embed : Fin seq → Fin dModel → Rat) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed ∀ q, q ∈ active → ∀ i, @@ -483,8 +483,8 @@ theorem gpt2ResidualIntervalBoundsActive_spec (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by classical intro bounds q hq i - let baseLo : Fin seq → Fin dModel → Dyadic := embed - let baseHi : Fin seq → Fin dModel → Dyadic := embed + let baseLo : Fin seq → Fin dModel → Rat := embed + let baseHi : Fin seq → Fin dModel → Rat := embed let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi have hfinal := transformerStackFinalBoundsPos_spec eps finalLn layers heads scores baseLo baseHi diff --git a/Nfp/Sound/Bounds/Transformer/Embedding.lean b/Nfp/Sound/Bounds/Transformer/Embedding.lean index 2c80bf2..8ce4e27 100644 --- a/Nfp/Sound/Bounds/Transformer/Embedding.lean +++ b/Nfp/Sound/Bounds/Transformer/Embedding.lean @@ -25,76 +25,76 @@ private lemma fin_univ_nonempty (seq : Nat) [NeZero seq] : /-- Interval bounds across tokens for an embedding map. 
-/ def embeddingIntervalBounds {seq dModel : Nat} [NeZero seq] - (x : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (x : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := let h : (Finset.univ : Finset (Fin seq)).Nonempty := fin_univ_nonempty (seq := seq) (fun i => (Finset.univ).inf' h (fun q => x q i), fun i => (Finset.univ).sup' h (fun q => x q i)) /-- `embeddingIntervalBounds` bounds embeddings coordinatewise. -/ theorem embeddingIntervalBounds_spec {seq dModel : Nat} [NeZero seq] - (x : Fin seq → Fin dModel → Dyadic) : + (x : Fin seq → Fin dModel → Rat) : let bounds := embeddingIntervalBounds x ∀ q i, (bounds.1 i : Real) ≤ (x q i : Real) ∧ (x q i : Real) ≤ (bounds.2 i : Real) := by classical intro bounds q i - have hloDyadic : bounds.1 i ≤ x q i := by + have hloRat : bounds.1 i ≤ x q i := by have h := Finset.inf'_le (s := (Finset.univ : Finset (Fin seq))) (f := fun k => x k i) (b := q) (by simp) simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h - have hhiDyadic : x q i ≤ bounds.2 i := by + have hhiRat : x q i ≤ bounds.2 i := by have h := Finset.le_sup' (s := (Finset.univ : Finset (Fin seq))) (f := fun k => x k i) (b := q) (by simp) simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h constructor - · simpa using (dyadicToReal_le_of_le hloDyadic) - · simpa using (dyadicToReal_le_of_le hhiDyadic) + · exact ratToReal_le_of_le hloRat + · exact ratToReal_le_of_le hhiRat /-- Interval bounds across a finite set of positions for an embedding map. 
-/ def embeddingIntervalBoundsOn {seq dModel : Nat} [NeZero seq] (positions : Finset (Fin seq)) (hpos : positions.Nonempty) - (x : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (x : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := (fun i => positions.inf' hpos (fun q => x q i), fun i => positions.sup' hpos (fun q => x q i)) /-- `embeddingIntervalBoundsOn` bounds embeddings on the chosen positions. -/ theorem embeddingIntervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] (positions : Finset (Fin seq)) (hpos : positions.Nonempty) - (x : Fin seq → Fin dModel → Dyadic) : + (x : Fin seq → Fin dModel → Rat) : let bounds := embeddingIntervalBoundsOn positions hpos x ∀ q, q ∈ positions → ∀ i, (bounds.1 i : Real) ≤ (x q i : Real) ∧ (x q i : Real) ≤ (bounds.2 i : Real) := by classical intro bounds q hq i - have hloDyadic : bounds.1 i ≤ x q i := by + have hloRat : bounds.1 i ≤ x q i := by have h := Finset.inf'_le (s := positions) (f := fun k => x k i) (b := q) hq simpa [bounds, embeddingIntervalBoundsOn] using h - have hhiDyadic : x q i ≤ bounds.2 i := by + have hhiRat : x q i ≤ bounds.2 i := by have h := Finset.le_sup' (s := positions) (f := fun k => x k i) (b := q) hq simpa [bounds, embeddingIntervalBoundsOn] using h constructor - · simpa using (dyadicToReal_le_of_le hloDyadic) - · simpa using (dyadicToReal_le_of_le hhiDyadic) + · exact ratToReal_le_of_le hloRat + · exact ratToReal_le_of_le hhiRat /-- Collapse per-position interval bounds over a finite set of positions. 
-/ def intervalBoundsOn {seq dModel : Nat} [NeZero seq] (positions : Finset (Fin seq)) (hpos : positions.Nonempty) - (lo hi : Fin seq → Fin dModel → Dyadic) : (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + (lo hi : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := (fun i => positions.inf' hpos (fun q => lo q i), fun i => positions.sup' hpos (fun q => hi q i)) /-- `intervalBoundsOn` soundness for bounds on the chosen positions. -/ theorem intervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] (positions : Finset (Fin seq)) (hpos : positions.Nonempty) - (lo hi : Fin seq → Fin dModel → Dyadic) (x : Fin seq → Fin dModel → Real) + (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) (hlo : ∀ q, q ∈ positions → ∀ i, (lo q i : Real) ≤ x q i) (hhi : ∀ q, q ∈ positions → ∀ i, x q i ≤ (hi q i : Real)) : let bounds := intervalBoundsOn positions hpos lo hi @@ -118,11 +118,11 @@ theorem intervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] constructor · have hmin_real : (bounds.1 i : Real) ≤ (lo q i : Real) := by - simpa using (dyadicToReal_le_of_le hmin) + exact ratToReal_le_of_le hmin exact le_trans hmin_real hlo' · have hmax_real : (hi q i : Real) ≤ (bounds.2 i : Real) := by - simpa using (dyadicToReal_le_of_le hmax) + exact ratToReal_le_of_le hmax exact le_trans hhi' hmax_real end Bounds diff --git a/Nfp/Sound/Bounds/UnnormRat.lean b/Nfp/Sound/Bounds/UnnormRat.lean index e2c0273..ff9f1c3 100644 --- a/Nfp/Sound/Bounds/UnnormRat.lean +++ b/Nfp/Sound/Bounds/UnnormRat.lean @@ -4,9 +4,9 @@ import Nfp.Core.Basic import Nfp.Sound.Linear.FinFold /-! -Unnormalized dyadic arithmetic. +Unnormalized rational arithmetic. -Dyadic values already avoid gcd normalization, so this module provides a +Rat values already avoid gcd normalization, so this module provides a lightweight alias and helper API used by older code paths. -/ @@ -16,42 +16,41 @@ namespace Sound namespace Bounds -/-- Unnormalized dyadic value (alias). 
-/ -abbrev UnnormDyadic := Dyadic +/-- Unnormalized rational value (alias). -/ +abbrev UnnormRat := Rat -/-- Interpret an unnormalized dyadic as a dyadic. -/ -def UnnormDyadic.toDyadic (q : UnnormDyadic) : Dyadic := +/-- Interpret an unnormalized rational as a rational. -/ +def UnnormRat.toRat (q : UnnormRat) : Rat := q -/-- Embed a dyadic as an unnormalized dyadic. -/ -def UnnormDyadic.ofDyadic (q : Dyadic) : UnnormDyadic := +/-- Embed a rational as an unnormalized rational. -/ +def UnnormRat.ofRat (q : Rat) : UnnormRat := q /-- Unnormalized zero. -/ -def UnnormDyadic.zero : UnnormDyadic := 0 +def UnnormRat.zero : UnnormRat := 0 /-- Unnormalized addition. -/ -def UnnormDyadic.add (a b : UnnormDyadic) : UnnormDyadic := +def UnnormRat.add (a b : UnnormRat) : UnnormRat := a + b /-- Unnormalized multiplication. -/ -def UnnormDyadic.mul (a b : UnnormDyadic) : UnnormDyadic := +def UnnormRat.mul (a b : UnnormRat) : UnnormRat := a * b -/-- `toDyadic` respects multiplication. -/ -theorem UnnormDyadic.toDyadic_mul_ofDyadic (a b : Dyadic) : - UnnormDyadic.toDyadic (UnnormDyadic.mul (UnnormDyadic.ofDyadic a) - (UnnormDyadic.ofDyadic b)) = a * b := by +/-- `toRat` respects multiplication. -/ +theorem UnnormRat.toRat_mul_ofRat (a b : Rat) : + UnnormRat.toRat (UnnormRat.mul (UnnormRat.ofRat a) (UnnormRat.ofRat b)) = a * b := by rfl -/-- Tail-recursive sum of unnormalized dyadics. -/ -def UnnormDyadic.sumFin (n : Nat) (f : Fin n → UnnormDyadic) : UnnormDyadic := +/-- Tail-recursive sum of unnormalized rationals. -/ +def UnnormRat.sumFin (n : Nat) (f : Fin n → UnnormRat) : UnnormRat := Linear.sumFin n f -/-- `toDyadic` commutes with `sumFin`. -/ -theorem UnnormDyadic.toDyadic_sumFin (n : Nat) (f : Fin n → UnnormDyadic) : - UnnormDyadic.toDyadic (UnnormDyadic.sumFin n f) = - Linear.sumFin n (fun i => UnnormDyadic.toDyadic (f i)) := by +/-- `toRat` commutes with `sumFin`. 
-/ +theorem UnnormRat.toRat_sumFin (n : Nat) (f : Fin n → UnnormRat) : + UnnormRat.toRat (UnnormRat.sumFin n f) = + Linear.sumFin n (fun i => UnnormRat.toRat (f i)) := by rfl end Bounds diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index 054cd55..7d66d41 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -40,7 +40,7 @@ def buildInductionHeadInputs {seq dModel dHead vocab : Nat} wo := slice.wo attnBias := slice.attnBias maskCausal := true - maskValue := (-10000 : Dyadic) + maskValue := (-10000 : Rat) directionSpec := slice.direction.spec direction := slice.directionVec } @@ -64,7 +64,7 @@ theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} wo := slice.wo attnBias := slice.attnBias maskCausal := true - maskValue := (-10000 : Dyadic) + maskValue := (-10000 : Rat) directionSpec := slice.direction.spec direction := slice.directionVec } := rfl diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index e8115c6..a38d0ac 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -29,7 +29,7 @@ variable {seq : Nat} /-- Cached direction head for head inputs. -/ def dirHeadVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Dyadic dHead := + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := Vector.ofFn (fun d : Fin dHead => Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) @@ -87,13 +87,13 @@ noncomputable def valsRealOfInputs {seq dModel dHead : Nat} /-- Interval data for direction values. -/ structure ValueInterval (seq : Nat) where /-- Lower bound for values. -/ - lo : Dyadic + lo : Rat /-- Upper bound for values. -/ - hi : Dyadic + hi : Rat /-- Lower bounds on per-key values. -/ - valsLo : Fin seq → Dyadic + valsLo : Fin seq → Rat /-- Upper bounds on per-key values. 
-/ - valsHi : Fin seq → Dyadic + valsHi : Fin seq → Rat /-- Optional logit-diff direction metadata (ignored by the checker). -/ direction : Option DirectionSpec @@ -113,11 +113,11 @@ structure ValueIntervalBounds {seq : Nat} (vals : Fin seq → Real) /-- Sound induction-certificate payload built from exact head inputs. -/ structure InductionHeadCert (seq : Nat) where /-- Weight tolerance. -/ - eps : Dyadic + eps : Rat /-- Per-query weight tolerance derived from local margins. -/ - epsAt : Fin seq → Dyadic + epsAt : Fin seq → Rat /-- Score margin used to justify the weight tolerance. -/ - margin : Dyadic + margin : Rat /-- Active queries for which bounds are required. -/ active : Finset (Fin seq) /-- `prev` selector for induction-style attention. -/ diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean index 747e054..0cdde49 100644 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -2,6 +2,7 @@ import Nfp.Core.Basic import Mathlib.Data.Finset.Basic +import Mathlib.Data.List.Range import Mathlib.Data.Vector.Defs import Nfp.Model.InductionHead import Nfp.Sound.Bounds.Attention @@ -23,16 +24,86 @@ open Nfp.Sound.Bounds variable {seq : Nat} +private def taskMin (t1 t2 : Task Rat) : Task Rat := + Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) + +private def taskMax (t1 t2 : Task Rat) : Task Rat := + Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) + +/-! Helpers for reducing cached arrays without extra allocation. -/ + +/-- Reduce an array of rational bounds to its minimum (defaulting to `0` on empty arrays). -/ +private def reduceMinArray (arr : Array Rat) : Rat := + let init := arr.getD 0 (0 : Rat) + arr.foldl (fun acc x => min acc x) init + +/-- Reduce an array of rational bounds to its maximum (defaulting to `0` on empty arrays). 
-/ +private def reduceMaxArray (arr : Array Rat) : Rat := + let init := arr.getD 0 (0 : Rat) + arr.foldl (fun acc x => max acc x) init + +private def reduceMinFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := + let n := seq + if n = 0 then + Task.pure (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + Task.spawn (fun _ => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => min acc (vals (idxs.getD i defaultIdx))) init)) + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => taskMin acc (chunkTasks.getD i defaultTask)) init + +private def reduceMaxFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := + let n := seq + if n = 0 then + Task.pure (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + Task.spawn (fun _ => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start 
+ i + 1) + rest.foldl (fun acc i => max acc (vals (idxs.getD i defaultIdx))) init)) + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => taskMax acc (chunkTasks.getD i defaultTask)) init + /-- Cached direction head for head inputs. -/ private def dirHeadVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Dyadic dHead := + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := Vector.ofFn (fun d : Fin dHead => Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) /-- LayerNorm bounds used by the induction-head builder. -/ def headLnBounds [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : - (Fin seq → Fin dModel → Dyadic) × (Fin seq → Fin dModel → Dyadic) := + (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := Bounds.cacheBoundPair2 (fun q => Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) @@ -45,55 +116,55 @@ theorem headLnBounds_spec [NeZero seq] {dModel dHead : Nat} /-- Q/K/V bounds used by the induction-head builder. -/ structure HeadQKVBounds (seq dModel dHead : Nat) where /-- Q lower bounds. -/ - qLo : Fin seq → Fin dHead → Dyadic + qLo : Fin seq → Fin dHead → Rat /-- Q upper bounds. -/ - qHi : Fin seq → Fin dHead → Dyadic + qHi : Fin seq → Fin dHead → Rat /-- K lower bounds. -/ - kLo : Fin seq → Fin dHead → Dyadic + kLo : Fin seq → Fin dHead → Rat /-- K upper bounds. -/ - kHi : Fin seq → Fin dHead → Dyadic + kHi : Fin seq → Fin dHead → Rat /-- V lower bounds. -/ - vLo : Fin seq → Fin dHead → Dyadic + vLo : Fin seq → Fin dHead → Rat /-- V upper bounds. -/ - vHi : Fin seq → Fin dHead → Dyadic + vHi : Fin seq → Fin dHead → Rat /-- Q absolute bounds. -/ - qAbs : Fin seq → Fin dHead → Dyadic + qAbs : Fin seq → Fin dHead → Rat /-- K absolute bounds. 
-/ - kAbs : Fin seq → Fin dHead → Dyadic + kAbs : Fin seq → Fin dHead → Rat /-- Compute Q/K/V bounds from LayerNorm bounds. -/ def headQKVBounds [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (lnLo lnHi : Fin seq → Fin dModel → Dyadic) : + (lnLo lnHi : Fin seq → Fin dModel → Rat) : HeadQKVBounds seq dModel dHead := let qLo := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d) let qHi := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d) let kLo := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d) let kHi := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d) let vLo := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d) let vHi := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d) let qAbs := - Bounds.cacheBound2TaskElem (fun q d => max |qLo q d| |qHi q d|) + Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) let kAbs := - Bounds.cacheBound2TaskElem (fun q d => max |kLo q d| |kHi q d|) + Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) { qLo := qLo qHi := qHi kLo := kLo @@ -105,36 +176,36 @@ def headQKVBounds [NeZero seq] {dModel dHead : Nat} theorem headQKVBounds_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (lnLo lnHi : Fin seq → Fin dModel → Dyadic) : + (lnLo lnHi : Fin seq → Fin dModel → Rat) : 
headQKVBounds inputs lnLo lnHi = let qLo := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d) let qHi := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d) let kLo := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d) let kHi := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d) let vLo := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d) let vHi := - Bounds.cacheBound2TaskElem (fun q d => + Bounds.cacheBound2 (fun q d => Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d) let qAbs := - Bounds.cacheBound2TaskElem (fun q d => max |qLo q d| |qHi q d|) + Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) let kAbs := - Bounds.cacheBound2TaskElem (fun q d => max |kLo q d| |kHi q d|) + Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) { qLo := qLo qHi := qHi kLo := kLo @@ -147,35 +218,35 @@ theorem headQKVBounds_spec [NeZero seq] {dModel dHead : Nat} /-- Score and margin bounds used by the induction-head builder. -/ structure HeadScoreBounds (seq dModel dHead : Nat) where /-- Absolute dot-product bound. -/ - dotAbs : Fin seq → Fin seq → Dyadic + dotAbs : Fin seq → Fin seq → Rat /-- Base score absolute bound. -/ - scoreBaseAbs : Fin seq → Fin seq → Dyadic + scoreBaseAbs : Fin seq → Fin seq → Rat /-- Score absolute bound with causal masking. -/ - scoreAbs : Fin seq → Fin seq → Dyadic + scoreAbs : Fin seq → Fin seq → Rat /-- Score lower bound. 
-/ - scoreLo : Fin seq → Fin seq → Dyadic + scoreLo : Fin seq → Fin seq → Rat /-- Score upper bound. -/ - scoreHi : Fin seq → Fin seq → Dyadic + scoreHi : Fin seq → Fin seq → Rat /-- Margin per query. -/ - marginAt : Fin seq → Dyadic + marginAt : Fin seq → Rat /-- Epsilon per query. -/ - epsAt : Fin seq → Dyadic + epsAt : Fin seq → Rat /-- Global margin. -/ - margin : Dyadic + margin : Rat /-- Global epsilon. -/ - eps : Dyadic + eps : Rat /-- Compute score and margin bounds from cached score lower/upper bounds. -/ def headScoreBoundsFromCaches [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Dyadic) - (scoreLo scoreHi : Fin seq → Fin seq → Dyadic) : + (dotAbs : Fin seq → Fin seq → Rat) + (scoreLo scoreHi : Fin seq → Fin seq → Rat) : HeadScoreBounds seq dModel dHead := let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Dyadic := fun q k => + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k - let scoreAbs : Fin seq → Fin seq → Dyadic := fun q k => + let scoreAbs : Fin seq → Fin seq → Rat := fun q k => if masked q k then |inputs.maskValue| else scoreBaseAbs q k let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) @@ -186,9 +257,9 @@ def headScoreBoundsFromCaches [NeZero seq] {dModel dHead : Nat} (∅ : Finset (Fin seq)) let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => (otherKeys q) \ (maskedKeys q) - let maskedGap : Fin seq → Dyadic := fun q => + let maskedGap : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) - inputs.maskValue - let marginTasks : Array (Task Dyadic) := + let marginTasks : Array (Task Rat) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => if q ∈ inputs.active then @@ -205,33 +276,33 @@ def headScoreBoundsFromCaches [NeZero seq] {dModel dHead : Nat} if hmasked : masked.Nonempty then maskedGap q else - 
(0 : Dyadic) + (0 : Rat) else - (0 : Dyadic))) - let marginAt : Fin seq → Dyadic := fun q => + (0 : Rat))) + let marginAt : Fin seq → Rat := fun q => (marginTasks[q.1]'(by simp [marginTasks, q.isLt])).get - let epsTasks : Array (Task Dyadic) := + let epsTasks : Array (Task Rat) := Array.ofFn (fun q : Fin seq => (marginTasks[q.1]'(by simp [marginTasks, q.isLt])).map (fun m => if m < 0 then - (1 : Dyadic) + (1 : Rat) else - dyadicDivUp (seq - 1) (1 + m))) - let epsAt : Fin seq → Dyadic := fun q => + ratDivUp (seq - 1) (1 + m))) + let epsAt : Fin seq → Rat := fun q => (epsTasks[q.1]'(by simp [epsTasks, q.isLt])).get - let margin : Dyadic := + let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt else - (0 : Dyadic) - let eps : Dyadic := + (0 : Rat) + let eps : Rat := if margin < 0 then - (1 : Dyadic) + (1 : Rat) else - dyadicDivUp (seq - 1) (1 + margin) + ratDivUp (seq - 1) (1 + margin) { dotAbs := dotAbs scoreBaseAbs := scoreBaseAbs scoreAbs := scoreAbs @@ -245,123 +316,348 @@ def headScoreBoundsFromCaches [NeZero seq] {dModel dHead : Nat} /-- Compute score and margin bounds from dot-product absolute bounds. 
-/ def headScoreBoundsFromDotAbs [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Dyadic) : HeadScoreBounds seq dModel dHead := + (dotAbs : Fin seq → Fin seq → Rat) : HeadScoreBounds seq dModel dHead := let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let dotAbsRowTasks : Array (Task { row : Array Dyadic // row.size = seq }) := + let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scoreRowTasks : Array (Task (Array Dyadic × Array Dyadic)) := - Array.ofFn (fun q : Fin seq => - (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).map (fun row => - let rowArr := row.1 - let scoreBaseAt : Fin seq → Dyadic := fun k => - |inputs.scale| * rowArr.getD k.1 0 - let loRow := Array.ofFn (fun k : Fin seq => - if masked q k then inputs.maskValue else -scoreBaseAt k) - let hiRow := Array.ofFn (fun k : Fin seq => - if masked q k then inputs.maskValue else scoreBaseAt k) - (loRow, hiRow))) - let scoreLoCached : Fin seq → Fin seq → Dyadic := fun q k => - let rowPair := (scoreRowTasks[q.1]'(by - simp [scoreRowTasks, q.isLt])).get - rowPair.1.getD k.1 0 - let scoreHiCached : Fin seq → Fin seq → Dyadic := fun q k => - let rowPair := (scoreRowTasks[q.1]'(by - simp [scoreRowTasks, q.isLt])).get - rowPair.2.getD k.1 0 - headScoreBoundsFromCaches inputs dotAbs scoreLoCached scoreHiCached + let scaleAbs : Rat := |inputs.scale| + let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else -base + let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k 
then inputs.maskValue else base + let marginAtRaw : Fin seq → Rat := fun q => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + if q ∈ inputs.active then + let rowArr := row.1 + let prev := inputs.prev q + let dotAbsPrev := rowArr.getD prev.1 0 + if masked q prev then + let scoreLoPrev := inputs.maskValue + let scoreHiAt : Fin seq → Rat := fun k => + if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let v := scoreLoPrev - scoreHiAt k + match acc.1 with + | none => (some v, acc.2) + | some cur => (some (min cur v), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min unmaskedMin maskedGap + | some unmaskedMin, false => unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + let scoreLoPrev := -(scaleAbs * dotAbsPrev) + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let raw := -(dotAbsPrev + rowArr.getD k.1 0) + match acc.1 with + | none => (some raw, acc.2) + | some cur => (some (min cur raw), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap + | some unmaskedMin, false => scaleAbs * unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + (0 : Rat) + let marginAtCached := Bounds.cacheBoundThunk marginAtRaw + let marginAt : Fin seq → Rat := fun q => + marginAtCached q + let epsAtRaw : Fin seq → Rat := fun q => + let m := marginAt q + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m) + let epsAtCached := Bounds.cacheBoundThunk 
epsAtRaw + let epsAt : Fin seq → Rat := fun q => + epsAtCached q + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + { dotAbs := dotAbs + scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k + scoreAbs := fun q k => + if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k + scoreLo := scoreLoCached + scoreHi := scoreHiCached + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } theorem headScoreBoundsFromDotAbs_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Dyadic) : + (dotAbs : Fin seq → Fin seq → Rat) : headScoreBoundsFromDotAbs inputs dotAbs = let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let dotAbsRowTasks : Array (Task { row : Array Dyadic // row.size = seq }) := + let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scoreRowTasks : Array (Task (Array Dyadic × Array Dyadic)) := - Array.ofFn (fun q : Fin seq => - (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).map (fun row => - let rowArr := row.1 - let scoreBaseAt : Fin seq → Dyadic := fun k => - |inputs.scale| * rowArr.getD k.1 0 - let loRow := Array.ofFn (fun k : Fin seq => - if masked q k then inputs.maskValue else -scoreBaseAt k) - let hiRow := Array.ofFn (fun k : Fin seq => - if masked q k then inputs.maskValue else scoreBaseAt k) - (loRow, hiRow))) - let scoreLoCached : Fin seq → Fin seq → Dyadic := fun q k => - let rowPair := (scoreRowTasks[q.1]'(by - simp [scoreRowTasks, q.isLt])).get - rowPair.1.getD k.1 0 - let scoreHiCached : Fin seq → Fin seq → Dyadic := fun q k => - let rowPair := (scoreRowTasks[q.1]'(by - simp [scoreRowTasks, q.isLt])).get - 
rowPair.2.getD k.1 0 - headScoreBoundsFromCaches inputs dotAbs scoreLoCached scoreHiCached := rfl + let scaleAbs : Rat := |inputs.scale| + let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else -base + let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else base + let marginAtRaw : Fin seq → Rat := fun q => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + if q ∈ inputs.active then + let rowArr := row.1 + let prev := inputs.prev q + let dotAbsPrev := rowArr.getD prev.1 0 + if masked q prev then + let scoreLoPrev := inputs.maskValue + let scoreHiAt : Fin seq → Rat := fun k => + if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let v := scoreLoPrev - scoreHiAt k + match acc.1 with + | none => (some v, acc.2) + | some cur => (some (min cur v), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min unmaskedMin maskedGap + | some unmaskedMin, false => unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + let scoreLoPrev := -(scaleAbs * dotAbsPrev) + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let raw := -(dotAbsPrev + rowArr.getD k.1 0) + match acc.1 with + | none => (some raw, acc.2) + | some cur => (some (min cur raw), acc.2) + let acc := 
Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap + | some unmaskedMin, false => scaleAbs * unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + (0 : Rat) + let marginAtCached := Bounds.cacheBoundThunk marginAtRaw + let marginAt : Fin seq → Rat := fun q => + marginAtCached q + let epsAtRaw : Fin seq → Rat := fun q => + let m := marginAt q + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m) + let epsAtCached := Bounds.cacheBoundThunk epsAtRaw + let epsAt : Fin seq → Rat := fun q => + epsAtCached q + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + { dotAbs := dotAbs + scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k + scoreAbs := fun q k => + if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k + scoreLo := scoreLoCached + scoreHi := scoreHiCached + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } := rfl /-- Compute score and margin bounds from Q/K absolute bounds. 
-/ def headScoreBounds [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : + (qAbs kAbs : Fin seq → Fin dHead → Rat) : HeadScoreBounds seq dModel dHead := headScoreBoundsFromDotAbs inputs (fun q k => Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) theorem headScoreBounds_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Dyadic) : + (qAbs kAbs : Fin seq → Fin dHead → Rat) : headScoreBounds inputs qAbs kAbs = let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let dotAbs : Fin seq → Fin seq → Dyadic := fun q k => + let dotAbs : Fin seq → Fin seq → Rat := fun q k => Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d) - let dotAbsRowTasks : Array (Task { row : Array Dyadic // row.size = seq }) := + let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scoreRowTasks : Array (Task (Array Dyadic × Array Dyadic)) := - Array.ofFn (fun q : Fin seq => - (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).map (fun row => - let rowArr := row.1 - let scoreBaseAt : Fin seq → Dyadic := fun k => - |inputs.scale| * rowArr.getD k.1 0 - let loRow := Array.ofFn (fun k : Fin seq => - if masked q k then inputs.maskValue else -scoreBaseAt k) - let hiRow := Array.ofFn (fun k : Fin seq => - if masked q k then inputs.maskValue else scoreBaseAt k) - (loRow, hiRow))) - let scoreLoCached : Fin seq → Fin seq → Dyadic := fun q k => - let rowPair := (scoreRowTasks[q.1]'(by - simp [scoreRowTasks, q.isLt])).get - rowPair.1.getD k.1 0 - let scoreHiCached : Fin seq → Fin seq → Dyadic := fun q k => - let rowPair := (scoreRowTasks[q.1]'(by - simp [scoreRowTasks, q.isLt])).get - rowPair.2.getD k.1 0 - headScoreBoundsFromCaches inputs dotAbs 
scoreLoCached scoreHiCached := rfl + let scaleAbs : Rat := |inputs.scale| + let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else -base + let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else base + let marginAtRaw : Fin seq → Rat := fun q => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + if q ∈ inputs.active then + let rowArr := row.1 + let prev := inputs.prev q + let dotAbsPrev := rowArr.getD prev.1 0 + if masked q prev then + let scoreLoPrev := inputs.maskValue + let scoreHiAt : Fin seq → Rat := fun k => + if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let v := scoreLoPrev - scoreHiAt k + match acc.1 with + | none => (some v, acc.2) + | some cur => (some (min cur v), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min unmaskedMin maskedGap + | some unmaskedMin, false => unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + let scoreLoPrev := -(scaleAbs * dotAbsPrev) + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let raw := -(dotAbsPrev + rowArr.getD k.1 0) + match acc.1 with + | none => (some raw, acc.2) + | some cur => (some (min cur raw), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | 
some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap + | some unmaskedMin, false => scaleAbs * unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + (0 : Rat) + let marginAtCached := Bounds.cacheBoundThunk marginAtRaw + let marginAt : Fin seq → Rat := fun q => + marginAtCached q + let epsAtRaw : Fin seq → Rat := fun q => + let m := marginAt q + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m) + let epsAtCached := Bounds.cacheBoundThunk epsAtRaw + let epsAt : Fin seq → Rat := fun q => + epsAtCached q + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + { dotAbs := dotAbs + scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k + scoreAbs := fun q k => + if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k + scoreLo := scoreLoCached + scoreHi := scoreHiCached + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } := rfl /-- Value bounds used by the induction-head builder. -/ structure HeadValueBounds (seq dModel dHead : Nat) where /-- Value lower bounds. -/ - valsLo : Fin seq → Dyadic + valsLo : Fin seq → Rat /-- Value upper bounds. -/ - valsHi : Fin seq → Dyadic + valsHi : Fin seq → Rat /-- Global value lower bound. -/ - lo : Dyadic + lo : Rat /-- Global value upper bound. -/ - hi : Dyadic + hi : Rat /-- Cached direction vector for value bounds. -/ def headValueDirHead {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin dHead → Dyadic := + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin dHead → Rat := let dirHeadVec := dirHeadVecOfInputs inputs fun d => dirHeadVec.get d @@ -372,160 +668,324 @@ theorem headValueDirHead_spec {seq dModel dHead : Nat} fun d => dirHeadVec.get d := rfl /-- Cached lower value bounds from V intervals. 
-/ -def headValueValsLo {seq dModel dHead : Nat} +def headValueValsLoArray {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : Fin seq → Dyadic := + (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := let dirHead := headValueDirHead inputs - Bounds.cacheBound (fun k => - Bounds.dotIntervalLowerCachedDyadic dirHead (vLo k) (vHi k)) + Array.ofFn (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + +/-- Unfold `headValueValsLoArray` to its `Array.ofFn` definition. -/ +theorem headValueValsLoArray_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoArray inputs vLo vHi = + let dirHead := headValueDirHead inputs + Array.ofFn (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl + +/-- Cached lower value bounds from V intervals. -/ +def headValueValsLo {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := + let arr := headValueValsLoArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) theorem headValueValsLo_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsLo inputs vLo vHi = - let dirHead := headValueDirHead inputs - Bounds.cacheBound (fun k => - Bounds.dotIntervalLowerCachedDyadic dirHead (vLo k) (vHi k)) := rfl + let arr := headValueValsLoArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) := rfl /-- Cached lower value bounds from V intervals using a common-denominator sum. 
-/ -def headValueValsLoCommonDen {seq dModel dHead : Nat} +def headValueValsLoCommonDenArray {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : Fin seq → Dyadic := + (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := let dirHead := headValueDirHead inputs - Bounds.cacheBound (fun k => + Array.ofFn (fun k => Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) -theorem headValueValsLoCommonDen_spec {seq dModel dHead : Nat} +/-- Unfold `headValueValsLoCommonDenArray` to its `Array.ofFn` definition. -/ +theorem headValueValsLoCommonDenArray_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : - headValueValsLoCommonDen inputs vLo vHi = + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoCommonDenArray inputs vLo vHi = let dirHead := headValueDirHead inputs - Bounds.cacheBound (fun k => + Array.ofFn (fun k => Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl +/-- Cached lower value bounds from V intervals using a common-denominator sum. -/ +def headValueValsLoCommonDen {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := + let arr := headValueValsLoCommonDenArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) + +theorem headValueValsLoCommonDen_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoCommonDen inputs vLo vHi = + let arr := headValueValsLoCommonDenArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) := rfl + +/-- Common-denominator lower bounds agree with cached rational bounds pointwise. 
-/ +theorem headValueValsLoCommonDenArray_eq {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoCommonDenArray inputs vLo vHi = headValueValsLoArray inputs vLo vHi := by + rfl + theorem headValueValsLoCommonDen_eq {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsLoCommonDen inputs vLo vHi = headValueValsLo inputs vLo vHi := by - classical funext k - simp [headValueValsLoCommonDen, headValueValsLo, Bounds.cacheBound_apply, - Bounds.dotIntervalLowerCommonDen_eq, Bounds.dotIntervalLowerCachedRat_eq] + simp [headValueValsLoCommonDen, headValueValsLo, headValueValsLoCommonDenArray, + headValueValsLoArray] /-- Cached upper value bounds from V intervals. -/ -def headValueValsHi {seq dModel dHead : Nat} +def headValueValsHiArray {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : Fin seq → Dyadic := + (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := let dirHead := headValueDirHead inputs - Bounds.cacheBound (fun k => - Bounds.dotIntervalUpperCachedDyadic dirHead (vLo k) (vHi k)) + Array.ofFn (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + +/-- Unfold `headValueValsHiArray` to its `Array.ofFn` definition. -/ +theorem headValueValsHiArray_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiArray inputs vLo vHi = + let dirHead := headValueDirHead inputs + Array.ofFn (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl + +/-- Cached upper value bounds from V intervals. 
-/ +def headValueValsHi {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := + let arr := headValueValsHiArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) theorem headValueValsHi_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsHi inputs vLo vHi = - let dirHead := headValueDirHead inputs - Bounds.cacheBound (fun k => - Bounds.dotIntervalUpperCachedDyadic dirHead (vLo k) (vHi k)) := rfl + let arr := headValueValsHiArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) := rfl /-- Cached upper value bounds from V intervals using a common-denominator sum. -/ -def headValueValsHiCommonDen {seq dModel dHead : Nat} +def headValueValsHiCommonDenArray {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : Fin seq → Dyadic := + (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := let dirHead := headValueDirHead inputs - Bounds.cacheBound (fun k => + Array.ofFn (fun k => Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) -theorem headValueValsHiCommonDen_spec {seq dModel dHead : Nat} +/-- Unfold `headValueValsHiCommonDenArray` to its `Array.ofFn` definition. -/ +theorem headValueValsHiCommonDenArray_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : - headValueValsHiCommonDen inputs vLo vHi = + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiCommonDenArray inputs vLo vHi = let dirHead := headValueDirHead inputs - Bounds.cacheBound (fun k => + Array.ofFn (fun k => Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl +/-- Cached upper value bounds from V intervals using a common-denominator sum. 
-/ +def headValueValsHiCommonDen {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := + let arr := headValueValsHiCommonDenArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) + +theorem headValueValsHiCommonDen_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiCommonDen inputs vLo vHi = + let arr := headValueValsHiCommonDenArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) := rfl + +/-- Common-denominator upper bounds agree with cached rational bounds pointwise. -/ +theorem headValueValsHiCommonDenArray_eq {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiCommonDenArray inputs vLo vHi = headValueValsHiArray inputs vLo vHi := by + rfl + theorem headValueValsHiCommonDen_eq {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsHiCommonDen inputs vLo vHi = headValueValsHi inputs vLo vHi := by - classical funext k - simp [headValueValsHiCommonDen, headValueValsHi, Bounds.cacheBound_apply, - Bounds.dotIntervalUpperCommonDen_eq, Bounds.dotIntervalUpperCachedRat_eq] + simp [headValueValsHiCommonDen, headValueValsHi, headValueValsHiCommonDenArray, + headValueValsHiArray] + +/-- Global lower value bound from an array of per-key values. -/ +def headValueLoArray (valsLo : Array Rat) : Rat := + reduceMinArray valsLo + +/-- Unfold `headValueLoArray` to its reduction helper. -/ +theorem headValueLoArray_spec (valsLo : Array Rat) : + headValueLoArray valsLo = reduceMinArray valsLo := rfl /-- Global lower value bound from cached per-key values. 
-/ -def headValueLo [NeZero seq] (valsLo : Fin seq → Dyadic) : Dyadic := - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - univ.inf' hnonempty valsLo +def headValueLo [NeZero seq] (valsLo : Fin seq → Rat) : Rat := + headValueLoArray (Array.ofFn valsLo) + +theorem headValueLo_spec [NeZero seq] (valsLo : Fin seq → Rat) : + headValueLo valsLo = headValueLoArray (Array.ofFn valsLo) := rfl + +/-- Task wrapper for `headValueLo`. -/ +def headValueLoTask [NeZero seq] (valsLo : Fin seq → Rat) : Task Rat := + reduceMinFnTask valsLo + +theorem headValueLoTask_spec [NeZero seq] (valsLo : Fin seq → Rat) : + headValueLoTask valsLo = reduceMinFnTask valsLo := rfl + +/-- Global upper value bound from an array of per-key values. -/ +def headValueHiArray (valsHi : Array Rat) : Rat := + reduceMaxArray valsHi -theorem headValueLo_spec [NeZero seq] (valsLo : Fin seq → Dyadic) : - headValueLo valsLo = - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - univ.inf' hnonempty valsLo := rfl +/-- Unfold `headValueHiArray` to its reduction helper. -/ +theorem headValueHiArray_spec (valsHi : Array Rat) : + headValueHiArray valsHi = reduceMaxArray valsHi := rfl /-- Global upper value bound from cached per-key values. -/ -def headValueHi [NeZero seq] (valsHi : Fin seq → Dyadic) : Dyadic := - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - univ.sup' hnonempty valsHi +def headValueHi [NeZero seq] (valsHi : Fin seq → Rat) : Rat := + headValueHiArray (Array.ofFn valsHi) + +theorem headValueHi_spec [NeZero seq] (valsHi : Fin seq → Rat) : + headValueHi valsHi = headValueHiArray (Array.ofFn valsHi) := rfl + +/-- Task wrapper for `headValueHi`. 
-/ +def headValueHiTask [NeZero seq] (valsHi : Fin seq → Rat) : Task Rat := + reduceMaxFnTask valsHi + +theorem headValueHiTask_spec [NeZero seq] (valsHi : Fin seq → Rat) : + headValueHiTask valsHi = reduceMaxFnTask valsHi := rfl + +/-- Build `HeadValueBounds` from precomputed arrays. -/ +private def headValueBoundsOfArrays {seq dModel dHead : Nat} + (valsLoArr valsHiArr : Array Rat) : HeadValueBounds seq dModel dHead := + let valsLo : Fin seq → Rat := fun k => valsLoArr.getD k.1 (0 : Rat) + let valsHi : Fin seq → Rat := fun k => valsHiArr.getD k.1 (0 : Rat) + let lo := headValueLoArray valsLoArr + let hi := headValueHiArray valsHiArr + { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } -theorem headValueHi_spec [NeZero seq] (valsHi : Fin seq → Dyadic) : - headValueHi valsHi = - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - univ.sup' hnonempty valsHi := rfl +/-- Build a cached bounds array in parallel from a per-key computation. -/ +private def buildBoundArrayTask [NeZero seq] (f : Fin seq → Rat) : Task (Array Rat) := + let n := seq + let chunkSize : Nat := 64 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let chunkTasks : List (Task (Array Rat)) := + (List.range chunks).map (fun c => + Task.spawn (fun _ => + let start := c * chunkSize + let stop := Nat.min n (start + chunkSize) + let vals := + (List.range (stop - start)).map (fun i => + f (idxs.getD (start + i) defaultIdx)) + vals.toArray)) + Task.mapList (fun xs => xs.foldl (fun acc arr => acc ++ arr) #[]) chunkTasks /-- Compute value bounds from V interval bounds. 
-/ def headValueBounds [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : HeadValueBounds seq dModel dHead := - let valsLo := headValueValsLo inputs vLo vHi - let valsHi := headValueValsHi inputs vLo vHi - let lo := headValueLo valsLo - let hi := headValueHi valsHi - { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } + let valsLoArr := headValueValsLoArray inputs vLo vHi + let valsHiArr := headValueValsHiArray inputs vLo vHi + headValueBoundsOfArrays valsLoArr valsHiArr theorem headValueBounds_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : headValueBounds inputs vLo vHi = - let valsLo := headValueValsLo inputs vLo vHi - let valsHi := headValueValsHi inputs vLo vHi - let lo := headValueLo valsLo - let hi := headValueHi valsHi - { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } := rfl + let valsLoArr := headValueValsLoArray inputs vLo vHi + let valsHiArr := headValueValsHiArray inputs vLo vHi + headValueBoundsOfArrays valsLoArr valsHiArr := rfl + +/-- Compute value bounds from V interval bounds in parallel. -/ +def headValueBoundsTask [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + Task (HeadValueBounds seq dModel dHead) := + let dirHead := headValueDirHead inputs + let valsLoTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + let valsHiTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + Task.bind valsLoTask (fun valsLoArr => + Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) + +/-- Unfold `headValueBoundsTask` to its task graph. 
-/ +theorem headValueBoundsTask_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueBoundsTask inputs vLo vHi = + let dirHead := headValueDirHead inputs + let valsLoTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + let valsHiTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + Task.bind valsLoTask (fun valsLoArr => + Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) := rfl /-- Compute value bounds from V interval bounds using a common-denominator sum. -/ def headValueBoundsCommonDen [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : HeadValueBounds seq dModel dHead := - let valsLo := headValueValsLoCommonDen inputs vLo vHi - let valsHi := headValueValsHiCommonDen inputs vLo vHi - let lo := headValueLo valsLo - let hi := headValueHi valsHi - { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } + let valsLoArr := headValueValsLoCommonDenArray inputs vLo vHi + let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi + headValueBoundsOfArrays valsLoArr valsHiArr theorem headValueBoundsCommonDen_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : headValueBoundsCommonDen inputs vLo vHi = - let valsLo := headValueValsLoCommonDen inputs vLo vHi - let valsHi := headValueValsHiCommonDen inputs vLo vHi - let lo := headValueLo valsLo - let hi := headValueHi valsHi - { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } := rfl + let valsLoArr := headValueValsLoCommonDenArray inputs vLo vHi + let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi + headValueBoundsOfArrays valsLoArr 
valsHiArr := rfl + +/-- Compute value bounds from V intervals using a common-denominator sum in parallel. -/ +def headValueBoundsCommonDenTask [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + Task (HeadValueBounds seq dModel dHead) := + let dirHead := headValueDirHead inputs + let valsLoTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + let valsHiTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + Task.bind valsLoTask (fun valsLoArr => + Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) + +/-- Unfold `headValueBoundsCommonDenTask` to its task graph. -/ +theorem headValueBoundsCommonDenTask_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueBoundsCommonDenTask inputs vLo vHi = + let dirHead := headValueDirHead inputs + let valsLoTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + let valsHiTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + Task.bind valsLoTask (fun valsLoArr => + Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) := rfl theorem headValueBoundsCommonDen_eq [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Dyadic) : + (vLo vHi : Fin seq → Fin dHead → Rat) : headValueBoundsCommonDen inputs vLo vHi = headValueBounds inputs vLo vHi := by classical - simp [headValueBoundsCommonDen, headValueBounds, headValueValsLoCommonDen_eq, - headValueValsHiCommonDen_eq] + simp [headValueBoundsCommonDen, headValueBounds, headValueValsLoCommonDenArray_eq, + headValueValsHiCommonDenArray_eq] end Sound diff --git a/Nfp/Sound/Induction/HeadOutput.lean 
b/Nfp/Sound/Induction/HeadOutput.lean index e7dbf8e..d14fe03 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -103,19 +103,19 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] | none => exact none | some certWithProof => rcases certWithProof with ⟨cert, hcert⟩ - let lnBounds : Fin (Nat.succ n) → (Fin dModel → Dyadic) × (Fin dModel → Dyadic) := + let lnBounds : Fin (Nat.succ n) → (Fin dModel → Rat) × (Fin dModel → Rat) := fun q => Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) - let lnLo : Fin (Nat.succ n) → Fin dModel → Dyadic := fun q => (lnBounds q).1 - let lnHi : Fin (Nat.succ n) → Fin dModel → Dyadic := fun q => (lnBounds q).2 - let vLo : Fin (Nat.succ n) → Fin dHead → Dyadic := fun q d => + let lnLo : Fin (Nat.succ n) → Fin dModel → Rat := fun q => (lnBounds q).1 + let lnHi : Fin (Nat.succ n) → Fin dModel → Rat := fun q => (lnBounds q).2 + let vLo : Fin (Nat.succ n) → Fin dHead → Rat := fun q d => dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let vHi : Fin (Nat.succ n) → Fin dHead → Dyadic := fun q d => + let vHi : Fin (Nat.succ n) → Fin dHead → Rat := fun q d => dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let headValueLo : Fin (Nat.succ n) → Fin dModel → Dyadic := fun k i => + let headValueLo : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => dotIntervalLower (fun d => inputs.wo i d) (vLo k) (vHi k) - let headValueHi : Fin (Nat.succ n) → Fin dModel → Dyadic := fun k i => + let headValueHi : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => dotIntervalUpper (fun d => inputs.wo i d) (vLo k) (vHi k) have hln_bounds : ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ @@ -143,10 +143,10 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi constructor · simpa [vLo, vRealOfInputs, Bounds.cacheBound2_apply, - Bounds.dotIntervalLowerCachedRat_eq, dyadicToReal_add] using + Bounds.dotIntervalLowerCachedRat_eq, ratToReal_add] using add_le_add_right hlow (inputs.bv d : Real) · simpa [vHi, vRealOfInputs, Bounds.cacheBound2_apply, - Bounds.dotIntervalUpperCachedRat_eq, dyadicToReal_add] using + Bounds.dotIntervalUpperCachedRat_eq, ratToReal_add] using add_le_add_right hhigh (inputs.bv d : Real) have hhead_bounds : ∀ k i, (headValueLo k i : Real) ≤ headValueRealOfInputs inputs k i ∧ @@ -173,9 +173,9 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] let activeSet : Finset (Fin (Nat.succ n)) := cert.active let univ : Finset (Fin (Nat.succ n)) := Finset.univ have huniv : univ.Nonempty := by simp [univ] - let loVal : Fin dModel → Dyadic := fun i => + let loVal : Fin dModel → Rat := fun i => univ.inf' huniv (fun k => headValueLo k i) - let hiVal : Fin dModel → Dyadic := fun i => + let hiVal : Fin dModel → Rat := fun i => univ.sup' huniv (fun k => headValueHi k i) have hvalsBoundsReal : ∀ i, Layers.ValueRangeBounds (Val := Real) @@ -185,14 +185,14 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ have hmem0 : k0 ∈ univ := hk0 - have hloDyadic : loVal i ≤ headValueLo k0 i := by + have hloRat : loVal i ≤ headValueLo k0 i := by change loVal i ≤ headValueLo k0 i dsimp [loVal] refine (Finset.inf'_le_iff (s := univ) (H := huniv) (f := fun k => headValueLo k i) (a := headValueLo k0 i)).2 ?_ refine ⟨k0, hmem0, ?_⟩ exact le_rfl - have hhiDyadic : headValueHi k0 i ≤ hiVal i := by + have hhiRat : headValueHi k0 i ≤ hiVal i := by change headValueHi k0 i ≤ hiVal i dsimp [hiVal] refine (Finset.le_sup'_iff (s := univ) (H := huniv) @@ -200,13 +200,13 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ have hbounds := hhead_bounds k0 i have hreal : (loVal i : Real) ≤ (hiVal i : Real) := by - refine le_trans (dyadicToReal_le_of_le hloDyadic) ?_ + refine le_trans (ratToReal_le_of_le hloRat) ?_ refine le_trans hbounds.1 ?_ - exact le_trans hbounds.2 (dyadicToReal_le_of_le hhiDyadic) + exact le_trans hbounds.2 (ratToReal_le_of_le hhiRat) exact hreal · intro k have hmem : k ∈ univ := by simp [univ] - have hloDyadic : loVal i ≤ headValueLo k i := by + have hloRat : loVal i ≤ headValueLo k i := by change loVal i ≤ headValueLo k i dsimp [loVal] refine (Finset.inf'_le_iff (s := univ) (H := huniv) @@ -214,10 +214,10 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] refine ⟨k, hmem, ?_⟩ exact le_rfl have hbounds := hhead_bounds k i - exact (dyadicToReal_le_of_le hloDyadic) |>.trans hbounds.1 + exact (ratToReal_le_of_le hloRat) |>.trans hbounds.1 · intro k have hmem : k ∈ univ := by simp [univ] - have hhiDyadic : headValueHi k i ≤ hiVal i := by + have hhiRat : headValueHi k i ≤ hiVal i := by change headValueHi k i ≤ hiVal i dsimp [hiVal] refine (Finset.le_sup'_iff (s := univ) (H := huniv) @@ -225,7 +225,7 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] exact ⟨k, ⟨hmem, le_rfl⟩⟩ have hbounds := hhead_bounds k i exact hbounds.2.trans - (dyadicToReal_le_of_le hhiDyadic) + (ratToReal_le_of_le hhiRat) have hsoftmax : Layers.SoftmaxMarginBoundsOn (Val := Real) (cert.eps : Real) (cert.margin : Real) @@ -263,19 +263,19 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] (vals := fun k => headValueRealOfInputs inputs k i) (hweights := hweights) (hvals := hvalsBoundsReal i) - let delta : Fin dModel → Dyadic := fun i => hiVal i - loVal i - let boundLoDyadic : Fin (Nat.succ n) → Fin dModel → Dyadic := fun q i => + let delta : Fin dModel → Rat := fun i => hiVal i - loVal i + let boundLoRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => headValueLo (cert.prev q) i - cert.eps * delta i - let boundHiDyadic : Fin (Nat.succ n) → Fin dModel → Dyadic := fun q i => + let boundHiRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => headValueHi (cert.prev q) i + cert.eps * delta i - let loOut : Fin dModel → Dyadic := fun i => + let loOut : Fin dModel → Rat := fun i => if h : activeSet.Nonempty then - activeSet.inf' h (fun q => boundLoDyadic q i) + activeSet.inf' h (fun q => boundLoRat q i) else 0 - let hiOut : Fin dModel → Dyadic := fun i => + let hiOut : Fin dModel → Rat := fun i => if h : activeSet.Nonempty then - activeSet.sup' h (fun q => boundHiDyadic q i) + activeSet.sup' h (fun q => boundHiRat q i) else 0 have hout : @@ -291,7 +291,7 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] simp [headOutput, headOutputWithScores, scoresReal, weights] have hprev_bounds := hhead_bounds (cert.prev q) i have hupper : - headOutput inputs q i ≤ (boundHiDyadic q i : Real) := by + headOutput inputs q i ≤ (boundHiRat q i : Real) := by have hupper' : dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ headValueRealOfInputs inputs (cert.prev q) i + @@ -307,11 +307,11 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] hprev_bounds.2 exact le_trans hupper' hprev_bounds' simpa - [hout_def, boundHiDyadic, delta, dyadicToReal_add, dyadicToReal_mul, - dyadicToReal_sub] using + [hout_def, boundHiRat, delta, ratToReal_add, ratToReal_mul, + ratToReal_sub] using hupper'' have hlower : - (boundLoDyadic q i : Real) ≤ headOutput inputs q i := by + (boundLoRat q i : Real) ≤ headOutput inputs q i := by have hlower' : (headValueRealOfInputs inputs (cert.prev q) i : Real) - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ @@ -324,26 +324,26 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] refine le_trans (sub_le_sub_right hprev_bounds.1 ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))) ?_ exact hlower' - simpa [hout_def, boundLoDyadic, delta, dyadicToReal_mul, dyadicToReal_sub] using + simpa [hout_def, boundLoRat, delta, ratToReal_mul, ratToReal_sub] using hlower'' have hlo : - (loOut i : Real) ≤ (boundLoDyadic q i : Real) := by - have hloDyadic : loOut i ≤ boundLoDyadic q i := by + (loOut i : Real) ≤ (boundLoRat q i : Real) := by + have hloRat : loOut i ≤ boundLoRat q i := by simpa [loOut, hactive] using (Finset.inf'_le (s := activeSet) - (f := fun q => boundLoDyadic q i) + (f := fun q => boundLoRat q i) (b := q) hq) - exact dyadicToReal_le_of_le hloDyadic + exact ratToReal_le_of_le hloRat have hhi : - (boundHiDyadic q i : Real) ≤ (hiOut i : Real) := by - have hhiDyadic : boundHiDyadic q i ≤ hiOut i := by + (boundHiRat q i : Real) ≤ (hiOut i : Real) := by + have hhiRat : boundHiRat q i ≤ hiOut i := by simpa [hiOut, hactive] using (Finset.le_sup' (s := activeSet) - (f := fun q => boundHiDyadic q i) + (f := fun q => boundHiRat q i) (b := q) hq) - exact dyadicToReal_le_of_le hhiDyadic + exact ratToReal_le_of_le hhiRat exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by refine { lo_le_hi := ?_ } @@ -353,7 +353,7 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] have hout_i := hout q hq i have hleReal : (loOut i : Real) ≤ (hiOut i : Real) := le_trans hout_i.1 hout_i.2 - exact (dyadicToReal_le_iff (x := loOut i) (y := hiOut i)).1 hleReal + exact (ratToReal_le_iff (x := loOut i) (y := hiOut i)).1 hleReal · simp [loOut, hiOut, hactive] let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } exact some diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 27c363c..90c14ac 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -25,14 +25,14 @@ noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel d dotProduct (weights q) (valsRealOfInputs inputs) /-- Lower bound computed from the per-key lower bounds in an induction certificate. -/ -def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Dyadic := +def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := Circuit.logitDiffLowerBoundAt c.active c.prev c.epsAt c.values.lo c.values.hi c.values.valsLo theorem logitDiffLowerBoundFromCert_le (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) - {lb : Dyadic} (hbound : logitDiffLowerBoundFromCert c = some lb) + {lb : Rat} (hbound : logitDiffLowerBoundFromCert c = some lb) {q : Fin seq} (hq : q ∈ c.active) : (lb : Real) ≤ headLogitDiff inputs q := by classical @@ -50,7 +50,7 @@ theorem logitDiffLowerBoundFromCert_le Layers.ValueRangeBounds (Val := Real) (c.values.lo : Real) (c.values.hi : Real) (valsRealOfInputs inputs) := by refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } - · exact dyadicToReal_le_of_le hsound.value_bounds.lo_le_hi + · exact ratToReal_le_of_le hsound.value_bounds.lo_le_hi · intro k exact le_trans (hsound.value_bounds.lo_le_valsLo k) @@ -71,7 +71,7 @@ theorem logitDiffLowerBoundFromCert_le (weights := weights) (vals := valsRealOfInputs inputs) hweights hvalsRange - have 
hboundDyadic : + have hboundRat : lb ≤ c.values.valsLo (c.prev q) - c.epsAt q * (c.values.hi - c.values.lo) := by refine @@ -90,9 +90,9 @@ theorem logitDiffLowerBoundFromCert_le (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) := by have hboundReal' : (lb : Real) ≤ - (c.values.valsLo (c.prev q) - c.epsAt q * (c.values.hi - c.values.lo) : Dyadic) := by - exact dyadicToReal_le_of_le hboundDyadic - simpa [dyadicToReal_sub, dyadicToReal_mul] using hboundReal' + (c.values.valsLo (c.prev q) - c.epsAt q * (c.values.hi - c.values.lo) : Rat) := by + exact ratToReal_le_of_le hboundRat + simpa [ratToReal_sub, ratToReal_mul] using hboundReal' have hvalsLo : (c.values.valsLo (c.prev q) : Real) ≤ valsRealOfInputs inputs (c.prev q) := by @@ -123,7 +123,7 @@ structure InductionLogitLowerBoundResult /-- Soundness proof for the induction certificate. -/ sound : InductionHeadCertSound inputs cert /-- Reported lower bound on logit diff. -/ - lb : Dyadic + lb : Rat /-- `lb` is computed from `logitDiffLowerBoundFromCert`. -/ lb_def : logitDiffLowerBoundFromCert cert = some lb /-- The lower bound is sound on active queries. 
-/ diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 42200eb..048db7a 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -25,12 +25,12 @@ theorem oneHot_bounds_at_of_marginAt (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) (scoresReal : Fin seq → Fin seq → Real) - (marginAt : Fin seq → Dyadic) - (epsAt : Fin seq → Dyadic) + (marginAt : Fin seq → Rat) + (epsAt : Fin seq → Rat) (hepsAt : ∀ q, epsAt q = - if marginAt q < 0 then (1 : Dyadic) else - dyadicDivUp (seq - 1) (1 + marginAt q)) + if marginAt q < 0 then (1 : Rat) else + ratDivUp (seq - 1) (1 + marginAt q)) (hseq : (1 : Nat) ≤ seq) (hscore_margin_real_at : ∀ q, q ∈ active → ∀ k, k ≠ prev q → @@ -63,7 +63,7 @@ theorem oneHot_bounds_at_of_marginAt have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by by_cases hneg : marginAt q < 0 · have heps : (epsAt q : Real) = 1 := by - simp [hepsAt, hneg, dyadicToReal_one] + simp [hepsAt, hneg] have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by intro k hk simp @@ -85,7 +85,7 @@ theorem oneHot_bounds_at_of_marginAt simpa [heps] using hsum_le' · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg have hnonneg_real : 0 ≤ (marginAt q : Real) := by - exact dyadicToReal_nonneg_of_nonneg hnonneg + exact ratToReal_nonneg_of_nonneg hnonneg have hbound : ∀ k ∈ others q, weights q k ≤ (1 + (marginAt q : Real))⁻¹ := by @@ -117,18 +117,16 @@ theorem oneHot_bounds_at_of_marginAt (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ (epsAt q : Real) := by have hden : (1 + marginAt q) ≠ 0 := by intro hzero - have hrat : (1 : Rat) + (marginAt q).toRat = 0 := by - have := congrArg Dyadic.toRat hzero - simpa [Dyadic.toRat_add, Dyadic.toRat_natCast] using this - have hnonneg_rat : (0 : Rat) ≤ (marginAt q).toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := marginAt q)).2 hnonneg + have hrat : (1 : Rat) + marginAt q = 0 := by + simpa using hzero + have hnonneg_rat : (0 : Rat) ≤ marginAt q 
:= hnonneg linarith have hrat : (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ - (dyadicDivUp (seq - 1) (1 + marginAt q) : Real) := by - have hrat' := dyadicDivUp_ge_real (seq - 1) (1 + marginAt q) hden - simpa [dyadicToReal, Dyadic.toRat_add, Dyadic.toRat_natCast, - Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' + (ratDivUp (seq - 1) (1 + marginAt q) : Real) := by + have hrat' := ratDivUp_ge_real (seq - 1) (1 + marginAt q) hden + simpa [ratToReal, Rat.cast_div, Rat.cast_add, Rat.cast_natCast, + div_eq_mul_inv] using hrat' simpa [hepsAt, hneg] using hrat exact le_trans hsum_le' heps have hsum_eq : @@ -180,7 +178,7 @@ theorem oneHot_bounds_at_of_marginAt simpa [heps] using hsum_le' · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg have hnonneg_real : 0 ≤ (marginAt q : Real) := by - exact dyadicToReal_nonneg_of_nonneg hnonneg + exact ratToReal_nonneg_of_nonneg hnonneg have hbound : ∀ j ∈ others q, weights q j ≤ (1 + (marginAt q : Real))⁻¹ := by @@ -212,18 +210,16 @@ theorem oneHot_bounds_at_of_marginAt (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ (epsAt q : Real) := by have hden : (1 + marginAt q) ≠ 0 := by intro hzero - have hrat : (1 : Rat) + (marginAt q).toRat = 0 := by - have := congrArg Dyadic.toRat hzero - simpa [Dyadic.toRat_add, Dyadic.toRat_natCast] using this - have hnonneg_rat : (0 : Rat) ≤ (marginAt q).toRat := - (Dyadic.toRat_le_toRat_iff (x := 0) (y := marginAt q)).2 hnonneg + have hrat : (1 : Rat) + marginAt q = 0 := by + simpa using hzero + have hnonneg_rat : (0 : Rat) ≤ marginAt q := hnonneg linarith have hrat : (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ - (dyadicDivUp (seq - 1) (1 + marginAt q) : Real) := by - have hrat' := dyadicDivUp_ge_real (seq - 1) (1 + marginAt q) hden - simpa [dyadicToReal, Dyadic.toRat_add, Dyadic.toRat_natCast, - Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' + (ratDivUp (seq - 1) (1 + marginAt q) : Real) := by + have hrat' := ratDivUp_ge_real (seq - 1) (1 
+ marginAt q) hden + simpa [ratToReal, Rat.cast_div, Rat.cast_add, Rat.cast_natCast, + div_eq_mul_inv] using hrat' simpa [hepsAt, hneg] using hrat exact le_trans hsum_le' heps have hk' : k ∈ others q := by diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean index c722133..a8cff01 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Sound/Linear/FinFold.lean @@ -29,42 +29,42 @@ theorem foldlFin_eq_foldl (n : Nat) (f : α → Fin n → α) (init : α) : simpa [foldlFin] using (Fin.dfoldl_eq_foldl (n := n) (f := fun i acc => f acc i) (x := init)) -/-- Tail-recursive sum over `Fin n` (Dyadic-valued). -/ -def sumFin (n : Nat) (f : Fin n → Dyadic) : Dyadic := +/-- Tail-recursive sum over `Fin n` (Rat-valued). -/ +def sumFin (n : Nat) (f : Fin n → Rat) : Rat := foldlFin n (fun acc i => acc + f i) 0 /-- Tail-recursive sum over `Fin n` (alias for `sumFin`). -/ -def sumFinCommonDen (n : Nat) (f : Fin n → Dyadic) : Dyadic := +def sumFinCommonDen (n : Nat) (f : Fin n → Rat) : Rat := sumFin n f /-- `sumFin` as a left fold over the finite range list. -/ -theorem sumFin_eq_list_foldl (n : Nat) (f : Fin n → Dyadic) : +theorem sumFin_eq_list_foldl (n : Nat) (f : Fin n → Rat) : sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by simpa [sumFin, foldlFin_eq_foldl] using - (Fin.foldl_eq_foldl_finRange (f := fun acc i => acc + f i) (x := (0 : Dyadic)) (n := n)) + (Fin.foldl_eq_foldl_finRange (f := fun acc i => acc + f i) (x := (0 : Rat)) (n := n)) /-- `sumFin` agrees with the `Finset.univ` sum. 
-/ -theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Dyadic) : +theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Rat) : sumFin n f = ∑ i, f i := by classical have hfold : sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by simpa using sumFin_eq_list_foldl n f have hmap : - ((List.finRange n).map f).foldl (fun acc x : Dyadic => acc + x) 0 = + ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by simpa using - (List.foldl_map (f := f) (g := fun acc x : Dyadic => acc + x) - (l := List.finRange n) (init := (0 : Dyadic))) - let _ : Std.Commutative (fun a b : Dyadic => a + b) := + (List.foldl_map (f := f) (g := fun acc x : Rat => acc + x) + (l := List.finRange n) (init := (0 : Rat))) + let _ : Std.Commutative (fun a b : Rat => a + b) := ⟨by intro a b; exact add_comm _ _⟩ - let _ : Std.Associative (fun a b : Dyadic => a + b) := + let _ : Std.Associative (fun a b : Rat => a + b) := ⟨by intro a b c; exact add_assoc _ _ _⟩ have hfoldr : - ((List.finRange n).map f).foldl (fun acc x : Dyadic => acc + x) 0 = + ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by simpa using - (List.foldl_eq_foldr (f := fun acc x : Dyadic => acc + x) + (List.foldl_eq_foldr (f := fun acc x : Rat => acc + x) (a := 0) (l := (List.finRange n).map f)) have hsum_list : ((List.finRange n).map f).sum = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by @@ -72,7 +72,7 @@ theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Dyadic) : ((List.finRange n).map f).sum = ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by rfl - _ = ((List.finRange n).map f).foldl (fun acc x : Dyadic => acc + x) 0 := by + _ = ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 := by exact hfoldr.symm _ = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by exact hmap @@ -84,38 +84,38 @@ theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → 
Dyadic) : _ = ((List.finRange n).map f).sum := hsum_list.symm _ = ∑ i, f i := hsum_univ -/-- Casting a `Finset.univ` dyadic sum to `Real` commutes with summation. -/ -theorem dyadicToReal_sum_univ {n : Nat} (f : Fin n → Dyadic) : - ((∑ i, f i : Dyadic) : Real) = ∑ i, (f i : Real) := by +/-- Casting a `Finset.univ` rational sum to `Real` commutes with summation. -/ +theorem ratToReal_sum_univ {n : Nat} (f : Fin n → Rat) : + ratToReal (∑ i, f i) = ∑ i, ratToReal (f i) := by classical refine Finset.induction_on (Finset.univ : Finset (Fin n)) ?_ ?_ · simp · intro a s ha hs - simp [Finset.sum_insert, ha, hs, dyadicToReal_add] + simp [Finset.sum_insert, ha, hs, ratToReal_add] -/-- Casting a dyadic `sumFin` to `Real` commutes with summation. -/ -theorem dyadicToReal_sumFin {n : Nat} (f : Fin n → Dyadic) : - (sumFin n f : Real) = ∑ i, (f i : Real) := by +/-- Casting a rational `sumFin` to `Real` commutes with summation. -/ +theorem ratToReal_sumFin {n : Nat} (f : Fin n → Rat) : + ratToReal (sumFin n f) = ∑ i, ratToReal (f i) := by classical have hsum : sumFin n f = ∑ i, f i := sumFin_eq_sum_univ (f := f) - have hcast : ((∑ i, f i : Dyadic) : Real) = ∑ i, (f i : Real) := - dyadicToReal_sum_univ (f := f) + have hcast : ratToReal (∑ i, f i) = ∑ i, ratToReal (f i) := + ratToReal_sum_univ (f := f) simpa [hsum] using hcast /-- `sumFinCommonDen` agrees with `sumFin`. -/ -theorem sumFinCommonDen_eq_sumFin (n : Nat) (f : Fin n → Dyadic) : +theorem sumFinCommonDen_eq_sumFin (n : Nat) (f : Fin n → Rat) : sumFinCommonDen n f = sumFin n f := rfl -/-- Dot product over `Fin n` (Dyadic-valued). -/ -def dotFin (n : Nat) (x y : Fin n → Dyadic) : Dyadic := +/-- Dot product over `Fin n` (Rat-valued). -/ +def dotFin (n : Nat) (x y : Fin n → Rat) : Rat := sumFin n (fun i => x i * y i) /-- Unfolding lemma for `dotFin`. 
-/ -theorem dotFin_def (n : Nat) (x y : Fin n → Dyadic) : +theorem dotFin_def (n : Nat) (x y : Fin n → Rat) : dotFin n x y = sumFin n (fun i => x i * y i) := rfl /-- `dotFin` matches `dotProduct`. -/ -theorem dotFin_eq_dotProduct (n : Nat) (x y : Fin n → Dyadic) : +theorem dotFin_eq_dotProduct (n : Nat) (x y : Fin n → Rat) : dotFin n x y = dotProduct x y := by simp [dotFin_def, sumFin_eq_sum_univ, dotProduct] From 52f9b3234798bc3a5b964f05641a93bd1a3f0e70 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Thu, 8 Jan 2026 13:11:29 +0100 Subject: [PATCH 112/244] Update AGENTS build instructions and module map --- AGENTS.md | 111 +++++------------------------------------------------- 1 file changed, 10 insertions(+), 101 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f68907e..097fb54 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -46,15 +46,6 @@ or `-uuu` (or equivalent) when searching. ### 1.1 No fake proofs - **Forbidden:** `sorry` - **Forbidden:** introducing new nontrivial axioms beyond what mathlib already uses. -- If you can’t prove a lemma as stated: - - reconsider the statement (missing assumptions? wrong generality?), - - introduce helper lemmas, - - or refactor the structure so the proof becomes natural. - - Do **not** “paper over” gaps. - -> Lean 4.26+ exploration tools (`finish?`, `try?`, `grind => finish?`, etc.) may *suggest* scripts that -> contain `sorry` (useful for debugging). Treat those suggestions as **scratch** only. -> **No `sorry` may reach the branch.** ### 1.2 Linting stays on - **Never** disable linters globally or locally. @@ -62,7 +53,7 @@ or `-uuu` (or equivalent) when searching. - Fix the code/proofs instead. ### 1.3 Clean build -- `lake build -q --wfail` must succeed. +- `lake build --wfail` must succeed. - Any warning is treated as an error: resolve it, do not ignore it. 
### 1.4 Core invariants must remain true @@ -93,21 +84,12 @@ The library’s claims rest on these being preserved (preferably with explicit l - Prefer `NNReal` for masses/capacities/probabilities. - Prefer finite types (`[Fintype ι]`) where possible. -### 2.2 Keep proofs readable and local -- Prefer: `simp`, `rw`, `linarith`/`nlinarith` when appropriate, small `calc` blocks, - and restrained `aesop` usage backed by named helper lemmas. -- Avoid huge opaque “mega proofs”. If a proof is long, factor it. - -> Lean 4.26+ note for agents: -> - Use stronger automation (`simp?`, `finish?`, `grind`, `try?`) primarily as **proof discovery** tools. -> - The final committed proof should be **explicit, minimal, and stable** (often: a small lemma + `simp [..]` / `rw [..]`). - -### 2.3 Don’t duplicate mathlib +### 2.2 Don’t duplicate mathlib - Search for existing lemmas before inventing new ones. - If you introduce a lemma that feels “standard”, consider whether mathlib already has it (or whether it belongs in a more general file in this repo). -### 2.4 Verify, Don't Trust +### 2.3 Verify, Don't Trust - Distinguish between **witness generation** (untrusted, can use heuristics) and **verification** (trusted, must contain proofs). - The trusted kernel should only check that a candidate witness is valid; it should not be responsible for finding it if the search is complex. @@ -130,7 +112,7 @@ The library’s claims rest on these being preserved (preferably with explicit l - small lemmas, smaller proof terms, fewer global simp rules. ### 3.3 After coding -- Ensure `lake build -q --wfail` passes. +- Ensure `lake build --wfail` passes. - Ensure no `sorry`. - Ensure no linter toggles were introduced. - If you changed module responsibilities/structure, update §5 in the same commit. @@ -151,88 +133,14 @@ The library’s claims rest on these being preserved (preferably with explicit l - non-explosive, - and broadly safe. - Prefer `simp [foo]` over global simp-set growth. 
-- Prefer `simp?` **only to discover** what `simp [..]` should be. - -### 4.3 Tactic usage -- `aesop` is allowed, but: - - avoid relying on “magic” if it makes failures hard to debug, - - extract key steps into named lemmas so proofs stay stable. -### 4.4 Refactors are allowed—but must be principled +### 4.3 Refactors are allowed—but must be principled - You may do nontrivial refactors to improve conceptual cleanliness. - If you rename/reshape core APIs: - update all call sites, - leave a brief comment (or commit message rationale), - keep the module map (§5) accurate. -### 4.5 Lean 4.26+ proof exploration toolkit (for LLM agents) -These tools can dramatically reduce “stuck time” for lemma discovery. Use them like a **search assistant**. -They are *not* a substitute for readable proofs. - -**Allowed for exploration (scratch / development):** -- `simp?` (optionally with suggestions, if available) -- `finish?` -- `grind` / `grind?`, and `grind => finish?` -- `try?` (as a hint generator) - -**Rules for using exploration tools:** -1. **Never commit generated `sorry`.** If an exploration tactic suggests a script with `sorry`, treat it as debugging output and delete it. -2. **Never commit giant opaque scripts.** If a generated script is long: - - identify the key lemmas it used, - - create named helper lemmas, - - replace the script with a small proof built from those lemmas. -3. **Minimize lemma sets.** - - If `simp?` / `finish?` / `grind` suggests many lemmas, shrink to the smallest stable subset. -4. **Prefer stable shapes:** - - a short `calc` block, - - or a couple of `simp [..]` / `rw [..]` steps, - - plus one helper lemma if necessary. -5. **Keep it local.** Prefer adding lemmas to the local simp set (`simp [foo, bar]`) over tagging globally `[simp]`. - -**Agent “proof playbook” (recommended loop):** -- Step A: Try the obvious: `simp`, `simp [defs]`, `rw [defs]`, `linarith`, `nlinarith`, `ring`, `field_simp` (as appropriate). 
-- Step B: If stuck, run `simp?` to discover missing rewrite/simp lemmas. -- Step C: If still stuck, use `finish?` or `grind => finish?` to learn the *shape* of the proof and which lemmas matter. -- Step D: Replace the discovered script with: - - a helper lemma (named + documented) capturing the crucial step, - - and a short final proof using `simp`/`rw`/`calc`. -- Step E: Re-run `lake build -q --wfail`. - ---- - -## Lean 4 performance & scalability (use when justified) - -Default: write the simplest correct thing first. Use the levers below only when there is a clear payoff -(hot path, large workload, or expensive work that’s often unused). Add a short comment explaining the trigger. - -### Parallelism: `Task` (opt-in, deterministic-by-construction) -Use `Task` when work is independent and CPU-heavy (e.g., per-candidate / per-layer computations). -- Prefer *pure* tasks: `Task.spawn (fun () => ...)` and later `Task.get`. - Tasks cache their result; subsequent `get`s do not recompute. (Tasks are like “opportunistic thunks”.) -- Use `IO.asTask` only when you truly need effects; remember a task is spawned each time the returned `IO` action is executed. -- Keep results deterministic: never depend on completion order; aggregate by stable keys. -- Keep granularity coarse enough to amortize scheduling overhead. -- Cancellation: pure tasks stop when dropped; `IO.asTask` tasks must check for cancellation (`IO.checkCanceled`), and can be canceled via `IO.cancel`. -- If benchmarking, note that the runtime task thread pool size is controlled by `LEAN_NUM_THREADS` (or defaults to logical CPU count). - -### Laziness: `Thunk` / delayed computations (opt-in, for expensive-but-often-unused work) -Use `Thunk` to defer work that is expensive and frequently unused (debug traces, optional certificates, rare branches). -- Prefer `Thunk` over “manual caching”: the runtime forces at most once and caches the value. 
-- Force explicitly at the boundary (`Thunk.get`), not “deep inside” unrelated logic. -- If a thunk is forced from multiple threads, other threads will wait while one thread computes it—avoid forcing in places where blocking could deadlock. - -### Compile-time / elaboration performance nudge -When proofs or declarations get large, prefer factoring them into smaller independent theorems/lemmas when it improves clarity. -Lean can elaborate theorem bodies in parallel, so smaller independent units can help the compiler do more work concurrently. - -### Transparency / unfolding control (use sparingly) -Unfolding choices affect performance of simplification and typeclass search. -- The simplifier unfolds *reducible* definitions by default; semireducible/irreducible require explicit rewrite rules or different settings. -- `opaque` definitions are not δ-reduced in the kernel; use them to prevent expensive kernel reduction when unfolding is not needed for reasoning. -- Avoid cargo-culting reducibility attributes: use `local`/`scoped` when possible, and leave a short comment about why. - -Note: Recent Lean versions changed the story around well-founded recursion transparency; don’t rely on old recipes like making well-founded recursion “reducible” via attributes. - --- ## 5. Module Map (Where Things Live) @@ -398,10 +306,12 @@ but you **must** update this list in the same commit. - Head-output interval certificates built from induction head inputs. - `Nfp/Sound/Induction/HeadBounds.lean` - Helper bounds used to stage head-induction certificate construction. +- `Nfp/Sound/Bounds/Cache.lean` + - Cached bound evaluators (thunk/task backed) for interval computations. - `Nfp/Sound/Bounds/MatrixNorm.lean` - Row-sum matrix norms and downstream linear certificate builders. - `Nfp/Sound/Bounds/MatrixNorm/Interval.lean` - - Dot-product and matrix-vector interval bounds (dyadic and real). + - Dot-product and matrix-vector interval bounds (rational and real). 
- `Nfp/Sound/Bounds/LayerNorm.lean` - LayerNorm interval bounds and end-to-end soundness lemmas. - `Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean` @@ -455,7 +365,7 @@ This repo treats “axioms creep” as a serious regression. ## 7. Definition of Done (Checklist) -- [ ] `lake build -q --wfail` succeeds. +- [ ] `lake build --wfail` succeeds. - [ ] No `sorry`. - [ ] No new axioms were introduced. - [ ] **Total Soundness:** Every pure definition in the trusted section is verified/proven. @@ -463,8 +373,7 @@ This repo treats “axioms creep” as a serious regression. - [ ] New nontrivial definitions/theorems have short, accurate docstrings. - [ ] Core invariants (nonnegativity, normalization, finiteness, acyclicity) are preserved and, where possible, explicitly proved. - [ ] §5 Module Map is accurate (updated in the same commit if needed). -- [ ] If CLI behavior changed: `lake build nfp -q --wfail` succeeds and basic `nfp ... --help` works. -- [ ] If you used Lean 4.26+ exploration tools, the final committed proof is short, explicit, and stable (no giant generated scripts). +- [ ] If CLI behavior changed: `lake build nfp --wfail` succeeds and basic `nfp ... --help` works. 
When forced to choose between: - “slightly breaking but conceptually clean redesign” From 072669d9c8ae344ed0a306be32ca2d4e0aee627f Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 9 Jan 2026 03:06:00 +0100 Subject: [PATCH 113/244] Tighten head score bounds via interval dot products --- Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 312 +++++++++++ Nfp/Sound/Induction/Core.lean | 603 ++++++++++++---------- Nfp/Sound/Induction/HeadBounds.lean | 89 ++++ 3 files changed, 743 insertions(+), 261 deletions(-) diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index 2db676c..ba3f66a 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -95,6 +95,130 @@ def dotIntervalLower {n : Nat} (v lo hi : Fin n → Rat) : Rat := def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Rat) : Rat := Linear.sumFin n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) +/-- Lower interval endpoint for a product of two intervals. -/ +def mulIntervalLower (a b c d : Rat) : Rat := + min (min (a * c) (a * d)) (min (b * c) (b * d)) + +/-- Upper interval endpoint for a product of two intervals. -/ +def mulIntervalUpper (a b c d : Rat) : Rat := + max (max (a * c) (a * d)) (max (b * c) (b * d)) + +/-- `x * y` lies between `min (a * y) (b * y)` and `max (a * y) (b * y)` when `a ≤ x ≤ b`. 
-/ +lemma mul_between_of_bounds {a b x y : Rat} (hx : a ≤ x) (hx' : x ≤ b) : + min (a * y) (b * y) ≤ x * y ∧ x * y ≤ max (a * y) (b * y) := by + have hab : a ≤ b := le_trans hx hx' + by_cases hy : 0 ≤ y + · have hmin : min (a * y) (b * y) = a * y := by + have hle : a * y ≤ b * y := mul_le_mul_of_nonneg_right hab hy + exact min_eq_left hle + have hmax : max (a * y) (b * y) = b * y := by + have hle : a * y ≤ b * y := mul_le_mul_of_nonneg_right hab hy + exact max_eq_right hle + constructor + · simpa [hmin] using (mul_le_mul_of_nonneg_right hx hy) + · simpa [hmax] using (mul_le_mul_of_nonneg_right hx' hy) + · have hy' : y ≤ 0 := le_of_not_ge hy + have hmin : min (a * y) (b * y) = b * y := by + have hle : b * y ≤ a * y := mul_le_mul_of_nonpos_right hab hy' + exact min_eq_right hle + have hmax : max (a * y) (b * y) = a * y := by + have hle : b * y ≤ a * y := mul_le_mul_of_nonpos_right hab hy' + exact max_eq_left hle + constructor + · simpa [hmin] using (mul_le_mul_of_nonpos_right hx' hy') + · simpa [hmax] using (mul_le_mul_of_nonpos_right hx hy') + +/-- Lower interval endpoint bounds `x * y` when both factors are interval-bounded. -/ +lemma mulIntervalLower_le_mul {a b c d x y : Rat} + (hx : a ≤ x) (hx' : x ≤ b) (hy : c ≤ y) (hy' : y ≤ d) : + mulIntervalLower a b c d ≤ x * y := by + have hAy : + min (a * c) (a * d) ≤ a * y := by + have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := a) hy hy' + simpa [mul_comm] using h.1 + have hBy : + min (b * c) (b * d) ≤ b * y := by + have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := b) hy hy' + simpa [mul_comm] using h.1 + have hmin : + min (min (a * c) (a * d)) (min (b * c) (b * d)) ≤ min (a * y) (b * y) := by + apply le_min + · exact le_trans (min_le_left _ _) hAy + · exact le_trans (min_le_right _ _) hBy + have hxy := (mul_between_of_bounds (a := a) (b := b) (x := x) (y := y) hx hx').1 + exact le_trans hmin hxy + +/-- Upper interval endpoint bounds `x * y` when both factors are interval-bounded. 
-/ +lemma mul_le_mulIntervalUpper {a b c d x y : Rat} + (hx : a ≤ x) (hx' : x ≤ b) (hy : c ≤ y) (hy' : y ≤ d) : + x * y ≤ mulIntervalUpper a b c d := by + have hAy : + a * y ≤ max (a * c) (a * d) := by + have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := a) hy hy' + simpa [mul_comm] using h.2 + have hBy : + b * y ≤ max (b * c) (b * d) := by + have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := b) hy hy' + simpa [mul_comm] using h.2 + have hmax : + max (a * y) (b * y) ≤ max (max (a * c) (a * d)) (max (b * c) (b * d)) := by + apply max_le + · exact le_trans hAy (le_max_left _ _) + · exact le_trans hBy (le_max_right _ _) + have hxy := (mul_between_of_bounds (a := a) (b := b) (x := x) (y := y) hx hx').2 + exact le_trans hxy hmax + +/-- Lower interval endpoint for a dot product with bounds on both vectors. -/ +def dotIntervalLower2 {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat := + Linear.sumFin n (fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + +/-- Upper interval endpoint for a dot product with bounds on both vectors. -/ +def dotIntervalUpper2 {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat := + Linear.sumFin n (fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + +/-- Lower/upper interval endpoints for a dot product with bounds on both vectors. 
-/ +def dotIntervalLowerUpper2CommonDen {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat × Rat := + Linear.foldlFin n + (fun acc j => + (acc.1 + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j), + acc.2 + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j))) + (0, 0) + +theorem dotIntervalLower2_le_dotProduct {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n → Rat) + (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) + (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : + dotIntervalLower2 lo1 hi1 lo2 hi2 ≤ dotProduct x y := by + classical + have hterm : + ∀ j, + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) ≤ x j * y j := by + intro j + exact mulIntervalLower_le_mul (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + have hsum : + ∑ j, mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) ≤ + ∑ j, x j * y j := by + refine Finset.sum_le_sum ?_ + intro j _ + exact hterm j + simpa [dotIntervalLower2, Linear.sumFin_eq_sum_univ, dotProduct] using hsum + +theorem dotProduct_le_dotIntervalUpper2 {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n → Rat) + (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) + (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : + dotProduct x y ≤ dotIntervalUpper2 lo1 hi1 lo2 hi2 := by + classical + have hterm : + ∀ j, + x j * y j ≤ mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) := by + intro j + exact mul_le_mulIntervalUpper (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + have hsum : + ∑ j, x j * y j ≤ + ∑ j, mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) := by + refine Finset.sum_le_sum ?_ + intro j _ + exact hterm j + simpa [dotIntervalUpper2, Linear.sumFin_eq_sum_univ, dotProduct] using hsum /-- Lower interval endpoint using a shared-denominator accumulator. 
-/ def dotIntervalLowerCommonDen {n : Nat} (v lo hi : Fin n → Rat) : Rat := Linear.sumFinCommonDen n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) @@ -145,6 +269,61 @@ private lemma foldl_pair_snd {α : Type _} (xs : List α) (f g : α → Rat) (a | cons x xs ih => simp [List.foldl, ih] +theorem dotIntervalLowerUpper2CommonDen_fst {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : + (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).1 = + dotIntervalLower2 lo1 hi1 lo2 hi2 := by + classical + have hfold : + (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).1 = + (List.finRange n).foldl + (fun acc j => acc + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := by + simpa [dotIntervalLowerUpper2CommonDen, Linear.foldlFin_eq_foldl, + Fin.foldl_eq_foldl_finRange] using + (foldl_pair_fst (xs := List.finRange n) + (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + (a := 0) (b := 0)) + have hsum : + dotIntervalLower2 lo1 hi1 lo2 hi2 = + (List.finRange n).foldl + (fun acc j => acc + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := by + simp [dotIntervalLower2, Linear.sumFin_eq_list_foldl] + calc + (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).1 + = (List.finRange n).foldl + (fun acc j => acc + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := hfold + _ = dotIntervalLower2 lo1 hi1 lo2 hi2 := hsum.symm + +theorem dotIntervalLowerUpper2CommonDen_snd {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : + (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).2 = + dotIntervalUpper2 lo1 hi1 lo2 hi2 := by + classical + have hfold : + (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).2 = + (List.finRange n).foldl + (fun acc j => acc + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := by + simpa [dotIntervalLowerUpper2CommonDen, Linear.foldlFin_eq_foldl, + Fin.foldl_eq_foldl_finRange] using + (foldl_pair_snd (xs := List.finRange n) + (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 
+ (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + (a := 0) (b := 0)) + have hsum : + dotIntervalUpper2 lo1 hi1 lo2 hi2 = + (List.finRange n).foldl + (fun acc j => acc + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := by + simp [dotIntervalUpper2, Linear.sumFin_eq_list_foldl] + calc + (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).2 + = (List.finRange n).foldl + (fun acc j => acc + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := hfold + _ = dotIntervalUpper2 lo1 hi1 lo2 hi2 := hsum.symm + +theorem dotIntervalLowerUpper2CommonDen_eq {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : + dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2 = + (dotIntervalLower2 lo1 hi1 lo2 hi2, dotIntervalUpper2 lo1 hi1 lo2 hi2) := by + ext <;> simp [dotIntervalLowerUpper2CommonDen_fst, dotIntervalLowerUpper2CommonDen_snd] + theorem dotIntervalLowerUpperCommonDen_fst {n : Nat} (v lo hi : Fin n → Rat) : (dotIntervalLowerUpperCommonDen v lo hi).1 = dotIntervalLowerCommonDen v lo hi := by classical @@ -361,6 +540,139 @@ theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → R /-! Real-valued bounds from rational intervals. 
-/ +lemma mul_between_of_bounds_real {a b x y : Real} (hx : a ≤ x) (hx' : x ≤ b) : + min (a * y) (b * y) ≤ x * y ∧ x * y ≤ max (a * y) (b * y) := by + have hab : a ≤ b := le_trans hx hx' + by_cases hy : 0 ≤ y + · have hmin : min (a * y) (b * y) = a * y := by + have hle : a * y ≤ b * y := mul_le_mul_of_nonneg_right hab hy + exact min_eq_left hle + have hmax : max (a * y) (b * y) = b * y := by + have hle : a * y ≤ b * y := mul_le_mul_of_nonneg_right hab hy + exact max_eq_right hle + constructor + · simpa [hmin] using (mul_le_mul_of_nonneg_right hx hy) + · simpa [hmax] using (mul_le_mul_of_nonneg_right hx' hy) + · have hy' : y ≤ 0 := le_of_not_ge hy + have hmin : min (a * y) (b * y) = b * y := by + have hle : b * y ≤ a * y := mul_le_mul_of_nonpos_right hab hy' + exact min_eq_right hle + have hmax : max (a * y) (b * y) = a * y := by + have hle : b * y ≤ a * y := mul_le_mul_of_nonpos_right hab hy' + exact max_eq_left hle + constructor + · simpa [hmin] using (mul_le_mul_of_nonpos_right hx' hy') + · simpa [hmax] using (mul_le_mul_of_nonpos_right hx hy') + +lemma mulIntervalLower_le_mul_real {a b c d : Rat} {x y : Real} + (hx : (a : Real) ≤ x) (hx' : x ≤ (b : Real)) + (hy : (c : Real) ≤ y) (hy' : y ≤ (d : Real)) : + (mulIntervalLower a b c d : Real) ≤ x * y := by + have hAy : + min ((a : Real) * (c : Real)) ((a : Real) * (d : Real)) ≤ (a : Real) * y := by + have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) + (y := (a : Real)) hy hy' + simpa [mul_comm] using h.1 + have hBy : + min ((b : Real) * (c : Real)) ((b : Real) * (d : Real)) ≤ (b : Real) * y := by + have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) + (y := (b : Real)) hy hy' + simpa [mul_comm] using h.1 + have hmin : + min (min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) + (min ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) ≤ + min ((a : Real) * y) ((b : Real) * y) := by + apply le_min + · exact le_trans (min_le_left _ _) hAy + · exact 
le_trans (min_le_right _ _) hBy + have hxy := (mul_between_of_bounds_real (a := (a : Real)) (b := (b : Real)) (x := x) + (y := y) hx hx').1 + have hcast : + (mulIntervalLower a b c d : Real) = + min (min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) + (min ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := by + simp [mulIntervalLower, Rat.cast_min, Rat.cast_mul] + calc + (mulIntervalLower a b c d : Real) + = min (min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) + (min ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := hcast + _ ≤ min ((a : Real) * y) ((b : Real) * y) := hmin + _ ≤ x * y := hxy + +lemma mul_le_mulIntervalUpper_real {a b c d : Rat} {x y : Real} + (hx : (a : Real) ≤ x) (hx' : x ≤ (b : Real)) + (hy : (c : Real) ≤ y) (hy' : y ≤ (d : Real)) : + x * y ≤ (mulIntervalUpper a b c d : Real) := by + have hAy : + (a : Real) * y ≤ max ((a : Real) * (c : Real)) ((a : Real) * (d : Real)) := by + have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) + (y := (a : Real)) hy hy' + simpa [mul_comm] using h.2 + have hBy : + (b : Real) * y ≤ max ((b : Real) * (c : Real)) ((b : Real) * (d : Real)) := by + have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) + (y := (b : Real)) hy hy' + simpa [mul_comm] using h.2 + have hmax : + max ((a : Real) * y) ((b : Real) * y) ≤ + max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) + (max ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := by + apply max_le + · exact le_trans hAy (le_max_left _ _) + · exact le_trans hBy (le_max_right _ _) + have hxy := (mul_between_of_bounds_real (a := (a : Real)) (b := (b : Real)) (x := x) + (y := y) hx hx').2 + have hcast : + (mulIntervalUpper a b c d : Real) = + max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) + (max ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := by + simp [mulIntervalUpper, Rat.cast_max, Rat.cast_mul] + calc + x * y ≤ max ((a : Real) * y) ((b : Real) * y) := hxy + 
_ ≤ max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) + (max ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := hmax + _ = (mulIntervalUpper a b c d : Real) := hcast.symm + +theorem dotIntervalLower2_le_dotProduct_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) + (x y : Fin n → Real) + (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) + (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : + (dotIntervalLower2 lo1 hi1 lo2 hi2 : Real) ≤ dotProduct x y := by + classical + have hcast : + (dotIntervalLower2 lo1 hi1 lo2 hi2 : Real) = + ∑ j, (mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by + simpa [dotIntervalLower2, ratToReal] using + (Linear.ratToReal_sumFin + (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j))) + have hsum : + (∑ j, (mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real)) ≤ + ∑ j, x j * y j := by + refine Finset.sum_le_sum ?_ + intro j _ + exact mulIntervalLower_le_mul_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + simpa [hcast, dotProduct] using hsum + +theorem dotProduct_le_dotIntervalUpper2_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) + (x y : Fin n → Real) + (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) + (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : + dotProduct x y ≤ (dotIntervalUpper2 lo1 hi1 lo2 hi2 : Real) := by + classical + have hcast : + (dotIntervalUpper2 lo1 hi1 lo2 hi2 : Real) = + ∑ j, (mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by + simpa [dotIntervalUpper2, ratToReal] using + (Linear.ratToReal_sumFin + (f := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j))) + have hsum : + ∑ j, x j * y j ≤ + ∑ j, (mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by + refine Finset.sum_le_sum ?_ + intro j _ + exact mul_le_mulIntervalUpper_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + simpa [hcast, dotProduct] using hsum + theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : 
Fin n → Rat) (x : Fin n → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 266ffe0..1f7339d 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -133,63 +133,129 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := Finset.univ_nonempty univ.sup' hnonempty (fun q => lnAbsMax q) - let qAbsRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + let qLoRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - |inputs.bq d|), + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d), by simp⟩)) - let qAbsBaseArr : Array { row : Array Rat // row.size = dHead } := + let qHiRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := Array.ofFn (fun q : Fin seq => - (qAbsRowTasks[q.1]'(by - have hsize : qAbsRowTasks.size = seq := by - simp [qAbsRowTasks] + Task.spawn (fun _ => + ⟨Array.ofFn (fun d : Fin dHead => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d), + by simp⟩)) + let qLoArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qLoRowTasks[q.1]'(by + have hsize : qLoRowTasks.size = seq := by + simp [qLoRowTasks] + simp [hsize])).get) + let qHiArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qHiRowTasks[q.1]'(by + have hsize : qHiRowTasks.size = seq := by + simp [qHiRowTasks] simp [hsize])).get) - let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := qAbsBaseArr[q.1]'(by - have hsize : qAbsBaseArr.size = seq := by - simp [qAbsBaseArr] + let qLo : Fin seq → Fin dHead → Rat := fun q d => 
+ let row := qLoArr[q.1]'(by + have hsize : qLoArr.size = seq := by + simp [qLoArr] simp [hsize]) row.1[d.1]'(by simp [row.2]) - let kAbsRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let row := qHiArr[q.1]'(by + have hsize : qHiArr.size = seq := by + simp [qHiArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let kLoRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun d : Fin dHead => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d), + by simp⟩)) + let kHiRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - |inputs.bk d|), + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d), by simp⟩)) - let kAbsBaseArr : Array { row : Array Rat // row.size = dHead } := + let kLoArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kLoRowTasks[q.1]'(by + have hsize : kLoRowTasks.size = seq := by + simp [kLoRowTasks] + simp [hsize])).get) + let kHiArr : Array { row : Array Rat // row.size = dHead } := Array.ofFn (fun q : Fin seq => - (kAbsRowTasks[q.1]'(by - have hsize : kAbsRowTasks.size = seq := by - simp [kAbsRowTasks] + (kHiRowTasks[q.1]'(by + have hsize : kHiRowTasks.size = seq := by + simp [kHiRowTasks] simp [hsize])).get) - let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := kAbsBaseArr[q.1]'(by - have hsize : kAbsBaseArr.size = seq := by - simp [kAbsBaseArr] + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let row := kLoArr[q.1]'(by + have hsize : kLoArr.size = seq := by + simp [kLoArr] simp [hsize]) row.1[d.1]'(by simp [row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => -qAbsBase q d - 
let qHi : Fin seq → Fin dHead → Rat := fun q d => qAbsBase q d - let kLo : Fin seq → Fin dHead → Rat := fun q d => -kAbsBase q d - let kHi : Fin seq → Fin dHead → Rat := fun q d => kAbsBase q d - let qAbs : Fin seq → Fin dHead → Rat := qAbsBase - let kAbs : Fin seq → Fin dHead → Rat := kAbsBase + let kHi : Fin seq → Fin dHead → Rat := fun q d => + let row := kHiArr[q.1]'(by + have hsize : kHiArr.size = seq := by + simp [kHiArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let dotAbs := - Bounds.cacheBound2Task (fun q k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2CommonDen + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else -scoreBaseAbs q k + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k let scoreHi : Fin seq → Fin seq → Rat := 
fun q k => - if masked q k then inputs.maskValue else scoreBaseAbs q k + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k let scoreLoPrev : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) let otherKeys : Fin seq → Finset (Fin seq) := fun q => @@ -286,65 +352,129 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := Finset.univ_nonempty univ.sup' hnonempty (fun q => lnAbsMax q) - let qAbsRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + let qLoRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - |inputs.bq d|), + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d), by simp⟩)) - let qAbsBaseArr : Array { row : Array Rat // row.size = dHead } := + let qHiRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := Array.ofFn (fun q : Fin seq => - (qAbsRowTasks[q.1]'(by - have hsize : qAbsRowTasks.size = seq := by - simp [qAbsRowTasks] + Task.spawn (fun _ => + ⟨Array.ofFn (fun d : Fin dHead => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d), + by simp⟩)) + let qLoArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qLoRowTasks[q.1]'(by + have hsize : qLoRowTasks.size = seq := by + simp [qLoRowTasks] simp [hsize])).get) - let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := qAbsBaseArr[q.1]'(by - have hsize : qAbsBaseArr.size = seq := by - simp [qAbsBaseArr] + let qHiArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qHiRowTasks[q.1]'(by + have hsize : qHiRowTasks.size = seq := by + simp [qHiRowTasks] 
+ simp [hsize])).get) + let qLo : Fin seq → Fin dHead → Rat := fun q d => + let row := qLoArr[q.1]'(by + have hsize : qLoArr.size = seq := by + simp [qLoArr] simp [hsize]) row.1[d.1]'(by simp [row.2]) - let kAbsRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let row := qHiArr[q.1]'(by + have hsize : qHiArr.size = seq := by + simp [qHiArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let kLoRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - |inputs.bk d|), + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d), by simp⟩)) - let kAbsBaseArr : Array { row : Array Rat // row.size = dHead } := + let kHiRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun d : Fin dHead => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d), + by simp⟩)) + let kLoArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kLoRowTasks[q.1]'(by + have hsize : kLoRowTasks.size = seq := by + simp [kLoRowTasks] + simp [hsize])).get) + let kHiArr : Array { row : Array Rat // row.size = dHead } := Array.ofFn (fun q : Fin seq => - (kAbsRowTasks[q.1]'(by - have hsize : kAbsRowTasks.size = seq := by - simp [kAbsRowTasks] + (kHiRowTasks[q.1]'(by + have hsize : kHiRowTasks.size = seq := by + simp [kHiRowTasks] simp [hsize])).get) - let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := kAbsBaseArr[q.1]'(by - have hsize : kAbsBaseArr.size = seq := by - simp [kAbsBaseArr] + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let row := kLoArr[q.1]'(by + have hsize : kLoArr.size = seq := by + simp [kLoArr] simp [hsize]) row.1[d.1]'(by simp 
[row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => -qAbsBase q d - let qHi : Fin seq → Fin dHead → Rat := fun q d => qAbsBase q d - let kLo : Fin seq → Fin dHead → Rat := fun q d => -kAbsBase q d - let kHi : Fin seq → Fin dHead → Rat := fun q d => kAbsBase q d - let qAbs : Fin seq → Fin dHead → Rat := qAbsBase - let kAbs : Fin seq → Fin dHead → Rat := kAbsBase + let kHi : Fin seq → Fin dHead → Rat := fun q d => + let row := kHiArr[q.1]'(by + have hsize : kHiArr.size = seq := by + simp [kHiArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let dotAbs := - Bounds.cacheBound2Task (fun q k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2CommonDen + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k - let scoreAbs : Fin seq → Fin seq → Rat := fun q k => - if masked q k then |inputs.maskValue| else scoreBaseAbs q k let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else 
-scoreBaseAbs q k + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else scoreBaseAbs q k + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k let scoreLoPrev : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) let otherKeys : Fin seq → Finset (Fin seq) := fun q => @@ -409,12 +539,13 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simpa [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, lnBounds, lnLo, lnHi, lnAbsMaxTask, lnAbsMaxArr, lnAbsMax, lnAbsMaxMax, - qAbsRowTasks, qAbsBaseArr, qAbsBase, kAbsRowTasks, kAbsBaseArr, kAbsBase, - qLo, qHi, kLo, kHi, qAbs, kAbs, masked, dotAbs, scoreBaseAbs, scoreLo, + qLoRowTasks, qHiRowTasks, qLoArr, qHiArr, qLo, qHi, + kLoRowTasks, kHiRowTasks, kLoArr, kHiArr, kLo, kHi, + qAbs, kAbs, masked, dotRowTasks, dotLo, dotHi, dotAbs, scoreBaseAbs, scoreLo, scoreHi, scoreLoPrev, otherKeys, marginAt, epsAt, margin, eps, dirHeadVec, dirHead, wvDir, bDir, valsAbsBase, valsLoBase, valsHiBase, valsLo, valsHi, univ, lo, hi, valCert, cert, Task.spawn, Bounds.cacheBoundTask_apply, - Bounds.cacheBound2Task_apply, Array.getElem_ofFn] + Array.getElem_ofFn] using hcore have hc : c = cert := by simpa using (Option.some.inj hcore').symm @@ -517,70 +648,58 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ qRealOfInputs inputs q d ≤ (qHi q d : Real) := by intro q d - have hdot := hdot_abs_bound (fun j => inputs.wq j d) q - have hq_abs : - |qRealOfInputs inputs q d| ≤ (qAbsBase q d : Real) := by - have hsum : - |qRealOfInputs inputs q d| ≤ - (Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) : - Real) + |(inputs.bq d : Real)| := 
by - calc - |qRealOfInputs inputs q d| - = |dotProduct (fun j => (inputs.wq j d : Real)) - (lnRealOfInputs inputs q) + (inputs.bq d : Real)| := by - simp [qRealOfInputs] - _ ≤ |dotProduct (fun j => (inputs.wq j d : Real)) - (lnRealOfInputs inputs q)| + |(inputs.bq d : Real)| := by - exact - (abs_add_le (a := dotProduct (fun j => (inputs.wq j d : Real)) - (lnRealOfInputs inputs q)) (b := (inputs.bq d : Real))) - _ ≤ (Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) : - Real) + |(inputs.bq d : Real)| := by - exact add_le_add hdot (le_rfl) - have hsum' : - |qRealOfInputs inputs q d| ≤ (qAbsBase q d : Real) := by - simpa [qAbsBase, qAbsBaseArr, qAbsRowTasks, lnLo, lnHi, Task.spawn, - Bounds.dotIntervalAbsBound, ratToReal_add, ratToReal_abs] - using hsum - exact hsum' - have hq_bounds := (abs_le).1 hq_abs - constructor - · simpa [qLo] using hq_bounds.1 - · simpa [qHi] using hq_bounds.2 + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => + (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => + (hln j).2 + have hdot_lo := + Bounds.dotIntervalLower_le_dotProduct_real + (v := fun j => inputs.wq j d) (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) hlo hhi + have hdot_hi := + Bounds.dotProduct_le_dotIntervalUpper_real + (v := fun j => inputs.wq j d) (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) hlo hhi + have hlow : + (qLo q d : Real) ≤ qRealOfInputs inputs q d := by + have h := add_le_add_right hdot_lo (inputs.bq d : Real) + simpa [qRealOfInputs, qLo, qLoArr, qLoRowTasks, lnLo, lnHi, Task.spawn, + Bounds.dotIntervalLowerUnnorm, ratToReal_add] using h + have hhigh : + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + have h := add_le_add_right hdot_hi (inputs.bq d : Real) + simpa [qRealOfInputs, qHi, qHiArr, qHiRowTasks, lnLo, lnHi, Task.spawn, + Bounds.dotIntervalUpperUnnorm, ratToReal_add] using h + exact ⟨hlow, hhigh⟩ have hk_bounds : ∀ 
q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ kRealOfInputs inputs q d ≤ (kHi q d : Real) := by intro q d - have hdot := hdot_abs_bound (fun j => inputs.wk j d) q - have hk_abs : - |kRealOfInputs inputs q d| ≤ (kAbsBase q d : Real) := by - have hsum : - |kRealOfInputs inputs q d| ≤ - (Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) : - Real) + |(inputs.bk d : Real)| := by - calc - |kRealOfInputs inputs q d| - = |dotProduct (fun j => (inputs.wk j d : Real)) - (lnRealOfInputs inputs q) + (inputs.bk d : Real)| := by - simp [kRealOfInputs] - _ ≤ |dotProduct (fun j => (inputs.wk j d : Real)) - (lnRealOfInputs inputs q)| + |(inputs.bk d : Real)| := by - exact - (abs_add_le (a := dotProduct (fun j => (inputs.wk j d : Real)) - (lnRealOfInputs inputs q)) (b := (inputs.bk d : Real))) - _ ≤ (Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) : - Real) + |(inputs.bk d : Real)| := by - exact add_le_add hdot (le_rfl) - have hsum' : - |kRealOfInputs inputs q d| ≤ (kAbsBase q d : Real) := by - simpa [kAbsBase, kAbsBaseArr, kAbsRowTasks, lnLo, lnHi, Task.spawn, - Bounds.dotIntervalAbsBound, ratToReal_add, ratToReal_abs] - using hsum - exact hsum' - have hk_bounds := (abs_le).1 hk_abs - constructor - · simpa [kLo] using hk_bounds.1 - · simpa [kHi] using hk_bounds.2 + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => + (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => + (hln j).2 + have hdot_lo := + Bounds.dotIntervalLower_le_dotProduct_real + (v := fun j => inputs.wk j d) (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) hlo hhi + have hdot_hi := + Bounds.dotProduct_le_dotIntervalUpper_real + (v := fun j => inputs.wk j d) (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) hlo hhi + have hlow : + (kLo q d : Real) ≤ kRealOfInputs inputs q d := by + have h := add_le_add_right hdot_lo (inputs.bk d : Real) + simpa [kRealOfInputs, 
kLo, kLoArr, kLoRowTasks, lnLo, lnHi, Task.spawn, + Bounds.dotIntervalLowerUnnorm, ratToReal_add] using h + have hhigh : + kRealOfInputs inputs q d ≤ (kHi q d : Real) := by + have h := add_le_add_right hdot_hi (inputs.bk d : Real) + simpa [kRealOfInputs, kHi, kHiArr, kHiRowTasks, lnLo, lnHi, Task.spawn, + Bounds.dotIntervalUpperUnnorm, ratToReal_add] using h + exact ⟨hlow, hhigh⟩ have hscore_bounds : ∀ q k, (scoreLo q k : Real) ≤ scoresRealOfInputs inputs q k ∧ scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by @@ -589,147 +708,109 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let base := (inputs.scale : Real) * dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - have hq_abs : ∀ d, |qRealOfInputs inputs q d| ≤ (qAbs q d : Real) := by - intro d - have hq := hq_bounds q d - have hq' : - -(qAbsBase q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qAbsBase q d : Real) := by - simpa [qLo, qHi] using hq - have h := (abs_le).2 hq' - simpa [qAbs, qAbsBase] using h - have hk_abs : ∀ d, |kRealOfInputs inputs k d| ≤ (kAbs k d : Real) := by - intro d - have hk := hk_bounds k d - have hk' : - -(kAbsBase k d : Real) ≤ kRealOfInputs inputs k d ∧ - kRealOfInputs inputs k d ≤ (kAbsBase k d : Real) := by - simpa [kLo, kHi] using hk - have h := (abs_le).2 hk' - simpa [kAbs, kAbsBase] using h - have hdot_abs : - |dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)| ≤ - (dotAbs q k : Real) := by - have hsum : - |∑ d, qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ - ∑ d, |qRealOfInputs inputs q d * kRealOfInputs inputs k d| := by - simpa [dotProduct] using - (Finset.abs_sum_le_sum_abs (s := (Finset.univ : Finset (Fin dHead))) - (f := fun d => qRealOfInputs inputs q d * kRealOfInputs inputs k d)) - have hterm : - ∀ d, - |qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ - (qAbs q d : Real) * (kAbs k d : Real) := by - intro d - have hq := hq_abs 
d - have hk := hk_abs d - have hqnonneg : 0 ≤ (qAbs q d : Real) := by - have hdot_nonneg : - 0 ≤ Bounds.dotIntervalAbsBound - (fun j => inputs.wq j d) (lnLo q) (lnHi q) := by - have hleft : - 0 ≤ |Bounds.dotIntervalLower (fun j => inputs.wq j d) - (lnLo q) (lnHi q)| := by - exact abs_nonneg _ - exact le_trans hleft (le_max_left _ _) - have hbq_nonneg : 0 ≤ |inputs.bq d| := abs_nonneg _ - have hsum_nonneg : - 0 ≤ Bounds.dotIntervalAbsBound - (fun j => inputs.wq j d) (lnLo q) (lnHi q) + |inputs.bq d| := by - exact add_nonneg hdot_nonneg hbq_nonneg - have hqnonneg' : 0 ≤ qAbs q d := by - simpa [qAbs, qAbsBase, qAbsBaseArr, qAbsRowTasks, lnLo, lnHi, - Task.spawn, Bounds.dotIntervalAbsBound] using hsum_nonneg - exact ratToReal_nonneg_of_nonneg hqnonneg' - calc - |qRealOfInputs inputs q d * kRealOfInputs inputs k d| = - |qRealOfInputs inputs q d| * |kRealOfInputs inputs k d| := by - simp [abs_mul] - _ ≤ (qAbs q d : Real) * (kAbs k d : Real) := - mul_le_mul hq hk (abs_nonneg _) hqnonneg - have hsum_le : - ∑ d, |qRealOfInputs inputs q d * kRealOfInputs inputs k d| ≤ - ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by - refine Finset.sum_le_sum ?_ - intro d _ - exact hterm d - have hcast : - (dotAbs q k : Real) = - ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by - have hsum : - ((∑ d, qAbs q d * kAbs k d : Rat) : Real) = - ∑ d, ((qAbs q d * kAbs k d : Rat) : Real) := by - have h := Linear.ratToReal_sum_univ (f := fun d => qAbs q d * kAbs k d) - dsimp [ratToReal] at h - exact h - have hsum' : - ∑ d, ((qAbs q d * kAbs k d : Rat) : Real) = - ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := by - refine Finset.sum_congr rfl ?_ - intro d _ - simp - have hfinal := hsum.trans hsum' - calc - (dotAbs q k : Real) - = ((∑ d, qAbs q d * kAbs k d : Rat) : Real) := by - simp [dotAbs, Bounds.cacheBound2Task_apply, - Linear.dotFin_eq_dotProduct, dotProduct] - _ = ∑ d, (qAbs q d : Real) * (kAbs k d : Real) := hfinal - have hfinal := hsum.trans (hsum_le.trans_eq hcast.symm) - simpa [dotProduct] 
using hfinal - have hscale_abs : 0 ≤ (|inputs.scale| : Real) := by - exact abs_nonneg (ratToReal inputs.scale) - have hbase_abs : - |base| ≤ (scoreBaseAbs q k : Real) := by - have hdot_abs' := hdot_abs - have hmul : - |base| = - (|inputs.scale| : Real) * - |dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)| := by - simp [base, abs_mul] - have hmul_le : - (|inputs.scale| : Real) * - |dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)| ≤ - (|inputs.scale| : Real) * (dotAbs q k : Real) := by - exact mul_le_mul_of_nonneg_left hdot_abs' hscale_abs - simpa [scoreBaseAbs, hmul] using hmul_le + have hdot_bounds : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + have hq := hq_bounds q + have hk := hk_bounds k + have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq d).1 + have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => + (hq d).2 + have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => + (hk d).1 + have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => + (hk d).2 + have hlow := + _root_.Nfp.Sound.Bounds.dotIntervalLower2_le_dotProduct_real + (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hhigh := + _root_.Nfp.Sound.Bounds.dotProduct_le_dotIntervalUpper2_real + (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hlow' : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => 
kRealOfInputs inputs k d) := by + simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2CommonDen_fst] using hlow + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2CommonDen_snd] using hhigh + exact ⟨hlow', hhigh'⟩ by_cases hcausal : inputs.maskCausal · by_cases hle : k ≤ q · have hnot : ¬ q < k := not_lt_of_ge hle have hscore_eq : scoresReal q k = base := by simp [scoresReal, scoresRealOfInputs, hcausal, hle, base] - have hscore_abs' : |scoresReal q k| ≤ (scoreBaseAbs q k : Real) := by - simpa [hscore_eq] using hbase_abs - have hscore_abs : - |scoresReal q k| ≤ (scoreAbs q k : Real) := by - simpa [scoreAbs, masked, hcausal, hnot] - using hscore_abs' - have hscore_bounds := (abs_le).1 hscore_abs - constructor - · simpa [scoresReal, scoreLo, scoreAbs, masked, hcausal, hnot] - using hscore_bounds.1 - · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal, hnot] - using hscore_bounds.2 + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + ratToReal_nonneg_of_nonneg hscale + have hlow := + mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real + have hhigh := + mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real + constructor + · simpa [scoresReal, scoreLo, masked, hcausal, hnot, hscale, hscore_eq, base] + using hlow + · simpa [scoresReal, scoreHi, masked, hcausal, hnot, hscale, hscore_eq, base] + using hhigh + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hlow := + mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real + have hhigh := + mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real + constructor + · simpa [scoresReal, scoreLo, masked, hcausal, 
hnot, hscale, hscore_eq, base] + using hlow + · simpa [scoresReal, scoreHi, masked, hcausal, hnot, hscale, hscore_eq, base] + using hhigh · have hlt : q < k := lt_of_not_ge hle constructor · simp [scoresRealOfInputs, scoreLo, masked, hcausal, hle, hlt] · simp [scoresRealOfInputs, scoreHi, masked, hcausal, hle, hlt] · have hscore_eq : scoresReal q k = base := by simp [scoresReal, scoresRealOfInputs, hcausal, base] - have hscore_abs' : |scoresReal q k| ≤ (scoreBaseAbs q k : Real) := by - simpa [hscore_eq] using hbase_abs - have hscore_abs : - |scoresReal q k| ≤ (scoreAbs q k : Real) := by - simpa [scoreAbs, masked, hcausal] using hscore_abs' - have hscore_bounds := (abs_le).1 hscore_abs - constructor - · simpa [scoresReal, scoreLo, scoreAbs, masked, hcausal] - using hscore_bounds.1 - · simpa [scoresReal, scoreHi, scoreAbs, masked, hcausal] - using hscore_bounds.2 + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + ratToReal_nonneg_of_nonneg hscale + have hlow := + mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real + have hhigh := + mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real + constructor + · simpa [scoresReal, scoreLo, masked, hcausal, hscale, hscore_eq, base] + using hlow + · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] + using hhigh + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hlow := + mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real + have hhigh := + mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real + constructor + · simpa [scoresReal, scoreLo, masked, hcausal, hscale, hscore_eq, base] + using hlow + · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] + using hhigh let scoresReal := scoresRealOfInputs inputs have hmarginAt_le : ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → diff --git 
a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean index 0cdde49..61ceb93 100644 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -527,6 +527,95 @@ theorem headScoreBoundsFromDotAbs_spec [NeZero seq] {dModel dHead : Nat} margin := margin eps := eps } := rfl +/-- Compute score and margin bounds from Q/K interval bounds. -/ +def headScoreBoundsFromIntervals [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : + HeadScoreBounds seq dModel dHead := + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => + dotIntervalLowerUpper2CommonDen (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi + +theorem headScoreBoundsFromIntervals_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq 
dModel dHead) + (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : + headScoreBoundsFromIntervals inputs qLo qHi kLo kHi = + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => + dotIntervalLowerUpper2CommonDen (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi := rfl + /-- Compute score and margin bounds from Q/K absolute bounds. 
-/ def headScoreBounds [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) From b8da417167294ecd2e76bc7a417394bd6f2ec4e8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 9 Jan 2026 03:17:42 +0100 Subject: [PATCH 114/244] Tighten LayerNorm inv-std bound with variance lower --- Nfp/Sound/Bounds/LayerNorm.lean | 87 ++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 23 deletions(-) diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 60e7e34..7236851 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -506,7 +506,10 @@ def layerNormBounds {n : Nat} let μHi := meanUpper x let centeredBound : Fin n → Rat := fun i => max |x i - μHi| |x i - μLo| - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varLo : Rat := variance x + let varEpsLo : Rat := varLo + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsLo) + let invStdBound : Rat := ratDivUp 1 sqrtLowerBound let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound (fun i => beta i - radius i, fun i => beta i + radius i) @@ -523,7 +526,10 @@ theorem layerNormBounds_spec {n : Nat} let μLo : Rat := mean x let μHi : Rat := meanUpper x let centeredBound : Fin n → Rat := fun j => max |x j - μHi| |x j - μLo| - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varLo : Rat := variance x + let varEpsLo : Rat := varLo + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsLo) + let invStdBound : Rat := ratDivUp 1 sqrtLowerBound let varEps : Real := (varianceRat x : Real) + (eps : Real) let μ : Real := meanRat x let invStd : Real := (Real.sqrt varEps)⁻¹ @@ -547,27 +553,62 @@ theorem layerNormBounds_spec {n : Nat} simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, ratToReal_max] using hbound have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hvar_nonneg_rat : 0 ≤ varianceRat x := by + have 
hreal : 0 ≤ (varianceRat x : Real) := hvar_nonneg + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hreal + have hvarLo_nonneg : 0 ≤ varLo := by + have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat + simpa [varLo, variance_def x hne] using h + have hvarEpsLo_nonneg : 0 ≤ varEpsLo := by + exact add_nonneg hvarLo_nonneg (le_of_lt heps) have hsqrt_lower : - (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by have hsqrt_eps : - (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by - have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) - simpa using h - have hle : (eps : Real) ≤ varEps := by - have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := - le_add_of_nonneg_left hvar_nonneg - simpa [varEps] using hle' - have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by - exact Real.sqrt_le_sqrt hle - exact le_trans hsqrt_eps hsqrt_eps' - have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt - have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps' : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps'' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps' hsqrt_eps'' + have hsqrt_var : + (sqrtLower varEpsLo : Real) ≤ Real.sqrt varEps := by + have hsqrt_var' : + (sqrtLower varEpsLo : Real) ≤ Real.sqrt (varEpsLo : Real) := by + have h := sqrtLower_le_real_sqrt (q := varEpsLo) hvarEpsLo_nonneg + simpa using h + have hle : (varEpsLo : Real) ≤ varEps := by + have hle' : (varLo : Real) ≤ (varianceRat x : Real) 
:= by + have h := ratRoundDown_le_real (varianceRat x) + simpa [varLo, variance_def x hne] using h + have hle'' := add_le_add_right hle' (eps : Real) + simpa [varEpsLo, varEps, ratToReal_add] using hle'' + have hsqrt_var'' : Real.sqrt (varEpsLo : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_var' hsqrt_var'' + have hmax : + max (sqrtLower eps : Real) (sqrtLower varEpsLo : Real) ≤ Real.sqrt varEps := + (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ + simpa [sqrtLowerBound, ratToReal_max] using hmax + have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by + have hpos : 0 < (sqrtLower eps : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt + have hpos' : 0 < max (sqrtLower eps : Real) (sqrtLower varEpsLo : Real) := by + exact lt_of_lt_of_le hpos (le_max_left _ _) + simpa [sqrtLowerBound, ratToReal_max] using hpos' + have hinv_sqrt : invStd ≤ (sqrtLowerBound : Real)⁻¹ := by have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower simpa [invStd] using h - have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by - have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hinv_bound : (sqrtLowerBound : Real)⁻¹ ≤ (invStdBound : Real) := by + have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by + exact lt_of_lt_of_le hsqrt (le_max_left _ _) + have hy : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLowerBound) hy simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by exact le_trans hinv_sqrt hinv_bound @@ -615,11 +656,11 @@ theorem layerNormBounds_spec {n : Nat} layerNormReal eps gamma beta x i = t + (beta i : Real) := by simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by - simpa [bounds, layerNormBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, - hreal] using hlow + simpa 
[bounds, layerNormBounds, hne, radius, centeredBound, varLo, varEpsLo, + sqrtLowerBound, invStdBound, μLo, μHi, hreal] using hlow have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - simpa [bounds, layerNormBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, - hreal] using hhigh + simpa [bounds, layerNormBounds, hne, radius, centeredBound, varLo, varEpsLo, + sqrtLowerBound, invStdBound, μLo, μHi, hreal] using hhigh exact And.intro hlo hhi /-- Interval bounds for LayerNorm outputs from per-coordinate intervals. -/ From fbc3567b5588b0dbd5133b7042d34e30ad09cf86 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 9 Jan 2026 03:30:29 +0100 Subject: [PATCH 115/244] Add scaled sqrtLower bound for LayerNorm --- Nfp/Sound/Bounds/LayerNorm.lean | 112 ++++++++++++++++++++++++++++++-- 1 file changed, 108 insertions(+), 4 deletions(-) diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 7236851..4e445a7 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -66,9 +66,20 @@ def sqrtUpperAlt (q : Rat) : Rat := let a := Nat.sqrt (num * den) ratRoundUp ((a + 1 : Rat) / den) +/-- Extra precision scale for `sqrtLowerScaled`. -/ +def sqrtLowerScale : Nat := 4096 + +/-- Scaled rational lower bound for a square root (extra precision). -/ +def sqrtLowerScaled (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let scale := sqrtLowerScale + let a := Nat.sqrt (num * den * scale * scale) + ratRoundDown ((a : Rat) / (den * scale)) + /-- Rational lower bound for a square root (tighter of two bounds). -/ def sqrtLower (q : Rat) : Rat := - max (sqrtLowerBase q) (sqrtLowerAlt q) + max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) /-- Rational upper bound for a square root (tighter of two bounds). -/ def sqrtUpper (q : Rat) : Rat := @@ -167,7 +178,9 @@ theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by /-- `sqrtLower` is nonnegative. 
-/ theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by have hbase : 0 ≤ sqrtLowerBase q := sqrtLowerBase_nonneg q - exact le_trans hbase (le_max_left _ _) + have hmax : 0 ≤ max (sqrtLowerBase q) (sqrtLowerAlt q) := + le_trans hbase (le_max_left _ _) + exact le_trans hmax (le_max_left _ _) /-- `sqrtUpper` is nonnegative. -/ @@ -353,6 +366,89 @@ theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : simpa [sqrtLowerAlt, num, den, a] using hdown' exact le_trans hdown hle +/-- Scaled square-root lower bound in reals. -/ +theorem sqrtLowerScaled_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerScaled q : Real) ≤ Real.sqrt (q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set scale : Nat := sqrtLowerScale + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos : 0 < (scale : Real) := by + have hscale_pos_nat : 0 < scale := by + simp [scale, sqrtLowerScale] + exact_mod_cast hscale_pos_nat + have hnumden_le : (a ^ 2 : Real) ≤ (num * den * scale * scale : Nat) := by + exact_mod_cast (Nat.sqrt_le' (num * den * scale * scale)) + have hmul : + (a ^ 2 : Real) ≤ (num : Real) * den * (scale : Real) * (scale : Real) := by + simpa [num, den, scale, Nat.cast_mul, mul_assoc, mul_left_comm, mul_comm] using hnumden_le + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := + mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hmul' : + (a ^ 2 : Real) * ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) * + ((den : Real) * (scale : Real)) ^ 2 := by + have hnonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul_of_nonneg_right hmul hnonneg + have hdiv : + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * 
(scale : Real)) ^ 2 := by + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hdenScale_ne : ((den : Real) * (scale : Real)) ≠ 0 := + ne_of_gt hdenScale_pos + have hq_cast : (q : Real) = (num : Real) / den := by + have hnum_nonneg : 0 ≤ q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by + simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) + have hnum_cast : (q.num : Real) = (num : Real) := by + exact_mod_cast hnum_eq.symm + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] + simpa [hnum_cast, den] using hq_rat + have hq_eq : + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 = (num : Real) / den := by + field_simp [hdenScale_ne] + have hpow : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 ≤ (q : Real) := by + calc + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 + = (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := hpow + _ ≤ ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := hdiv + _ = (num : Real) / den := hq_eq + _ = (q : Real) := by simp [hq_cast] + have hnonneg : 0 ≤ (a : Real) / ((den : Real) * (scale : Real)) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + nlinarith [hden_pos, hscale_pos] + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + exact ratToReal_nonneg_of_nonneg hq + have hle : + (a : Real) / ((den : Real) * (scale : Real)) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerScaled q : Real) ≤ (a : Real) / ((den : Real) * (scale : Real)) := by + have hdown' : + ratToReal (ratRoundDown ((a : Rat) / (den * scale))) ≤ + (a : 
Real) / ((den : Real) * (scale : Real)) := by + simpa using ratRoundDown_le_real ((a : Rat) / (den * scale)) + simpa [sqrtLowerScaled, num, den, scale, a] using hdown' + exact le_trans hdown hle + /-- Alternate square-root upper bound in reals. -/ theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpperAlt q : Real) := by @@ -412,12 +508,20 @@ theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : simpa [sqrtUpperAlt, num, den, a] using hup' exact le_trans hle hup -/-- Square-root lower bound in reals (tighter of two bounds). -/ +/-- Square-root lower bound in reals (tighter of three bounds). -/ theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq - simpa [sqrtLower] using (max_le_iff).2 ⟨hbase, halt⟩ + have hscaled := sqrtLowerScaled_le_real_sqrt (q := q) hq + have hmax1 : + (max (sqrtLowerBase q) (sqrtLowerAlt q) : Real) ≤ Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hbase, halt⟩ + have hmax2 : + (max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) : Real) ≤ + Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hmax1, hscaled⟩ + simpa [sqrtLower] using hmax2 /-- Square-root upper bound in reals (tighter of two bounds). 
-/ theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : From a8c33dfb4001ae5cc3432044df13ffda250c3340 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 9 Jan 2026 07:15:55 +0100 Subject: [PATCH 116/244] Add sign-split dot-product bounds --- Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 102 ++++++++++++++++++++++ Nfp/Sound/Induction/Core.lean | 41 +++++---- 2 files changed, 126 insertions(+), 17 deletions(-) diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index ba3f66a..d932d3d 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -184,6 +184,31 @@ def dotIntervalLowerUpper2CommonDen {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) acc.2 + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j))) (0, 0) +/-! Sign-splitting bounds. -/ + +/-- Clamp a single coordinate interval to be nonnegative or nonpositive. -/ +def clampAt {n : Nat} (i : Fin n) (nonneg : Bool) (lo hi : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if nonneg then + (fun j => if j = i then max 0 (lo j) else lo j, hi) + else + (lo, fun j => if j = i then min 0 (hi j) else hi j) + +/-- Lower/upper interval endpoints with sign-splitting on selected coordinates. 
-/ +def dotIntervalLowerUpper2SignSplit {n : Nat} (dims : List (Fin n)) + (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat × Rat := + match dims with + | [] => + dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2 + | i :: rest => + let boundsPos := + let clamped := clampAt i true lo1 hi1 + dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 + let boundsNeg := + let clamped := clampAt i false lo1 hi1 + dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 + (min boundsPos.1 boundsNeg.1, max boundsPos.2 boundsNeg.2) + theorem dotIntervalLower2_le_dotProduct {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n → Rat) (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : @@ -673,6 +698,83 @@ theorem dotProduct_le_dotIntervalUpper2_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n exact mul_le_mulIntervalUpper_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) simpa [hcast, dotProduct] using hsum +theorem dotIntervalLowerUpper2SignSplit_spec_real {n : Nat} (dims : List (Fin n)) + (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) + (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) + (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : + let bounds := dotIntervalLowerUpper2SignSplit dims lo1 hi1 lo2 hi2 + (bounds.1 : Real) ≤ dotProduct x y ∧ dotProduct x y ≤ (bounds.2 : Real) := by + classical + induction dims generalizing lo1 hi1 with + | nil => + have hlow := + dotIntervalLower2_le_dotProduct_real + (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) + (x := x) (y := y) hlo1 hhi1 hlo2 hhi2 + have hhigh := + dotProduct_le_dotIntervalUpper2_real + (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) + (x := x) (y := y) hlo1 hhi1 hlo2 hhi2 + simpa [dotIntervalLowerUpper2SignSplit, dotIntervalLowerUpper2CommonDen_fst, + dotIntervalLowerUpper2CommonDen_snd] using And.intro hlow hhigh + | cons i rest ih => + by_cases hx : 0 ≤ x i + · let clamped := clampAt i true lo1 hi1 + let boundsPos := 
dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 + let boundsNeg := + dotIntervalLowerUpper2SignSplit rest (clampAt i false lo1 hi1).1 + (clampAt i false lo1 hi1).2 lo2 hi2 + have hlo1' : ∀ j, (clamped.1 j : Real) ≤ x j := by + intro j + by_cases hji : j = i + · have hmax : max (0 : Real) (lo1 i : Real) ≤ x i := + (max_le_iff).2 ⟨hx, hlo1 i⟩ + simpa [clamped, clampAt, hji, ratToReal_max] using hmax + · simpa [clamped, clampAt, hji] using hlo1 j + have hhi1' : ∀ j, x j ≤ (clamped.2 j : Real) := by + intro j + simpa [clamped, clampAt] using hhi1 j + have hpos := + ih (lo1 := clamped.1) (hi1 := clamped.2) hlo1' hhi1' + have hlow : (min boundsPos.1 boundsNeg.1 : Real) ≤ dotProduct x y := by + have hmin : (min boundsPos.1 boundsNeg.1 : Real) ≤ (boundsPos.1 : Real) := by + exact min_le_left _ _ + exact le_trans hmin hpos.1 + have hhigh : dotProduct x y ≤ (max boundsPos.2 boundsNeg.2 : Real) := by + have hmax : (boundsPos.2 : Real) ≤ (max boundsPos.2 boundsNeg.2 : Real) := by + exact le_max_left _ _ + exact le_trans hpos.2 hmax + simpa [dotIntervalLowerUpper2SignSplit, boundsPos, boundsNeg, clamped] using + And.intro hlow hhigh + · have hxneg : x i ≤ 0 := le_of_lt (lt_of_not_ge hx) + let clamped := clampAt i false lo1 hi1 + let boundsPos := + dotIntervalLowerUpper2SignSplit rest (clampAt i true lo1 hi1).1 + (clampAt i true lo1 hi1).2 lo2 hi2 + let boundsNeg := dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 + have hlo1' : ∀ j, (clamped.1 j : Real) ≤ x j := by + intro j + simpa [clamped, clampAt] using hlo1 j + have hhi1' : ∀ j, x j ≤ (clamped.2 j : Real) := by + intro j + by_cases hji : j = i + · have hmin : x i ≤ min (0 : Real) (hi1 i : Real) := + (le_min_iff).2 ⟨hxneg, hhi1 i⟩ + simpa [clamped, clampAt, hji, ratToReal_min] using hmin + · simpa [clamped, clampAt, hji] using hhi1 j + have hneg := + ih (lo1 := clamped.1) (hi1 := clamped.2) hlo1' hhi1' + have hlow : (min boundsPos.1 boundsNeg.1 : Real) ≤ dotProduct x y := by + have hmin : 
(min boundsPos.1 boundsNeg.1 : Real) ≤ (boundsNeg.1 : Real) := by + exact min_le_right _ _ + exact le_trans hmin hneg.1 + have hhigh : dotProduct x y ≤ (max boundsPos.2 boundsNeg.2 : Real) := by + have hmax : (boundsNeg.2 : Real) ≤ (max boundsPos.2 boundsNeg.2 : Real) := by + exact le_max_right _ _ + exact le_trans hneg.2 hmax + simpa [dotIntervalLowerUpper2SignSplit, boundsPos, boundsNeg, clamped] using + And.intro hlow hhigh + theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) (x : Fin n → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 1f7339d..3210aa8 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -217,11 +217,17 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k + let splitBudget : Nat := 2 + let splitDims : Fin seq → List (Fin dHead) := fun q => + let ambig := + (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) + ambig.take splitBudget let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2CommonDen + let dims := splitDims q + ⟨Array.ofFn (fun k : Fin seq => + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplit dims (fun d => qLo q d) (fun d => qHi q d) (fun d => kLo k d) (fun d => kHi k d)), by simp⟩)) @@ -436,11 +442,17 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k + let splitBudget : Nat := 2 + let splitDims : Fin seq → List (Fin dHead) := fun 
q => + let ambig := + (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) + ambig.take splitBudget let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => + let dims := splitDims q ⟨Array.ofFn (fun k : Fin seq => - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2CommonDen + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplit dims (fun d => qLo q d) (fun d => qHi q d) (fun d => kLo k d) (fun d => kHi k d)), by simp⟩)) @@ -541,7 +553,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} lnLo, lnHi, lnAbsMaxTask, lnAbsMaxArr, lnAbsMax, lnAbsMaxMax, qLoRowTasks, qHiRowTasks, qLoArr, qHiArr, qLo, qHi, kLoRowTasks, kHiRowTasks, kLoArr, kHiArr, kLo, kHi, - qAbs, kAbs, masked, dotRowTasks, dotLo, dotHi, dotAbs, scoreBaseAbs, scoreLo, + qAbs, kAbs, masked, splitBudget, splitDims, dotRowTasks, dotLo, dotHi, dotAbs, + scoreBaseAbs, scoreLo, scoreHi, scoreLoPrev, otherKeys, marginAt, epsAt, margin, eps, dirHeadVec, dirHead, wvDir, bDir, valsAbsBase, valsLoBase, valsHiBase, valsLo, valsHi, univ, lo, hi, valCert, cert, Task.spawn, Bounds.cacheBoundTask_apply, @@ -724,15 +737,9 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (hk d).1 have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => (hk d).2 - have hlow := - _root_.Nfp.Sound.Bounds.dotIntervalLower2_le_dotProduct_real - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hhigh := - _root_.Nfp.Sound.Bounds.dotProduct_le_dotIntervalUpper2_real + have hspec := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplit_spec_real + (dims := splitDims q) (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) (x := fun d => qRealOfInputs inputs q 
d) @@ -742,13 +749,13 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (dotLo q k : Real) ≤ dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) := by - simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2CommonDen_fst] using hlow + simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, splitDims] + using hspec.1 have hhigh' : dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2CommonDen_snd] using hhigh + simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, splitDims] + using hspec.2 exact ⟨hlow', hhigh'⟩ by_cases hcausal : inputs.maskCausal · by_cases hle : k ≤ q From 2621aa8a460d3666374b4db758f48bf666d5c1ae Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 9 Jan 2026 14:46:53 +0100 Subject: [PATCH 117/244] Prioritize sign-split dims by bound score --- Nfp/Sound/Induction/Core.lean | 82 ++++++++++++++++++++++++++++++----- 1 file changed, 72 insertions(+), 10 deletions(-) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 3210aa8..f3911b8 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -215,13 +215,44 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} simp [row.2]) let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let kAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d)) + let kAbsMax : Fin dHead → Rat := fun d => + kAbsMaxArr[d.1]'(by + have hsize : kAbsMaxArr.size = dHead := by + simp [kAbsMaxArr] + simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k let splitBudget : Nat := 2 let splitDims : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) - ambig.take splitBudget + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let best := ambig.foldl step (none, none) + let dims := + match best with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + dims.take splitBudget let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => @@ -440,13 +471,44 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simp [row.2]) let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| let kAbs : Fin seq → Fin dHead → 
Rat := fun q d => max |kLo q d| |kHi q d| + let kAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d)) + let kAbsMax : Fin dHead → Rat := fun d => + kAbsMaxArr[d.1]'(by + have hsize : kAbsMaxArr.size = dHead := by + simp [kAbsMaxArr] + simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k let splitBudget : Nat := 2 let splitDims : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) - ambig.take splitBudget + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let best := ambig.foldl step (none, none) + let dims := + match best with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + dims.take splitBudget let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => @@ -553,12 +615,12 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} lnLo, lnHi, lnAbsMaxTask, lnAbsMaxArr, lnAbsMax, lnAbsMaxMax, qLoRowTasks, qHiRowTasks, qLoArr, qHiArr, qLo, qHi, kLoRowTasks, kHiRowTasks, kLoArr, kHiArr, kLo, kHi, - qAbs, kAbs, masked, splitBudget, splitDims, dotRowTasks, dotLo, dotHi, dotAbs, - scoreBaseAbs, scoreLo, - 
scoreHi, scoreLoPrev, otherKeys, marginAt, epsAt, margin, eps, dirHeadVec, - dirHead, wvDir, bDir, valsAbsBase, valsLoBase, valsHiBase, valsLo, - valsHi, univ, lo, hi, valCert, cert, Task.spawn, Bounds.cacheBoundTask_apply, - Array.getElem_ofFn] + qAbs, kAbs, kAbsMaxArr, kAbsMax, masked, splitBudget, splitDims, + dotRowTasks, dotLo, dotHi, dotAbs, + scoreBaseAbs, scoreLo, scoreHi, scoreLoPrev, otherKeys, marginAt, epsAt, + margin, eps, dirHeadVec, dirHead, wvDir, bDir, valsAbsBase, valsLoBase, + valsHiBase, valsLo, valsHi, univ, lo, hi, valCert, cert, Task.spawn, + Bounds.cacheBoundTask_apply, Array.getElem_ofFn] using hcore have hc : c = cert := by simpa using (Option.some.inj hcore').symm @@ -749,12 +811,12 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (dotLo q k : Real) ≤ dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) := by - simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, splitDims] + simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn] using hspec.1 have hhigh' : dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, splitDims] + simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn] using hspec.2 exact ⟨hlow', hhigh'⟩ by_cases hcausal : inputs.maskCausal From d385169da5d52cebf481d2d182fce3dd63c498e7 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 9 Jan 2026 15:17:04 +0100 Subject: [PATCH 118/244] Tighten dot bounds with k-side sign splits --- Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 36 +++++++ Nfp/Sound/Induction/Core.lean | 115 +++++++++++++++++----- 2 files changed, 124 insertions(+), 27 deletions(-) diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index d932d3d..d233878 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -209,6 +209,13 @@ 
def dotIntervalLowerUpper2SignSplit {n : Nat} (dims : List (Fin n)) dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 (min boundsPos.1 boundsNeg.1, max boundsPos.2 boundsNeg.2) +/-- Lower/upper interval endpoints with sign-splitting on both sides. -/ +def dotIntervalLowerUpper2SignSplitBoth {n : Nat} (dims1 dims2 : List (Fin n)) + (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat × Rat := + let bounds1 := dotIntervalLowerUpper2SignSplit dims1 lo1 hi1 lo2 hi2 + let bounds2 := dotIntervalLowerUpper2SignSplit dims2 lo2 hi2 lo1 hi1 + (max bounds1.1 bounds2.1, min bounds1.2 bounds2.2) + theorem dotIntervalLower2_le_dotProduct {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n → Rat) (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : @@ -775,6 +782,35 @@ theorem dotIntervalLowerUpper2SignSplit_spec_real {n : Nat} (dims : List (Fin n) simpa [dotIntervalLowerUpper2SignSplit, boundsPos, boundsNeg, clamped] using And.intro hlow hhigh +theorem dotIntervalLowerUpper2SignSplitBoth_spec_real {n : Nat} (dims1 dims2 : List (Fin n)) + (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) + (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) + (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : + let bounds := dotIntervalLowerUpper2SignSplitBoth dims1 dims2 lo1 hi1 lo2 hi2 + (bounds.1 : Real) ≤ dotProduct x y ∧ dotProduct x y ≤ (bounds.2 : Real) := by + classical + let bounds1 := dotIntervalLowerUpper2SignSplit dims1 lo1 hi1 lo2 hi2 + let bounds2 := dotIntervalLowerUpper2SignSplit dims2 lo2 hi2 lo1 hi1 + have h1 := + dotIntervalLowerUpper2SignSplit_spec_real + (dims := dims1) (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) + (x := x) (y := y) hlo1 hhi1 hlo2 hhi2 + have h2swap := + dotIntervalLowerUpper2SignSplit_spec_real + (dims := dims2) (lo1 := lo2) (hi1 := hi2) (lo2 := lo1) (hi2 := hi1) + (x := y) (y := x) hlo2 hhi2 hlo1 hhi1 + have h2 : (bounds2.1 : Real) ≤ dotProduct x y ∧ dotProduct x y ≤ 
(bounds2.2 : Real) := by + simpa [dotProduct_comm] using h2swap + have hlow' : max (bounds1.1 : Real) (bounds2.1 : Real) ≤ dotProduct x y := + (max_le_iff).2 ⟨h1.1, h2.1⟩ + have hhigh' : dotProduct x y ≤ min (bounds1.2 : Real) (bounds2.2 : Real) := + (le_min_iff).2 ⟨h1.2, h2.2⟩ + have hlow : ((max bounds1.1 bounds2.1 : Rat) : Real) ≤ dotProduct x y := by + simpa [ratToReal_max] using hlow' + have hhigh : dotProduct x y ≤ ((min bounds1.2 bounds2.2 : Rat) : Real) := by + simpa [ratToReal_min] using hhigh' + simpa [dotIntervalLowerUpper2SignSplitBoth, bounds1, bounds2] using And.intro hlow hhigh + theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) (x : Fin n → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index f3911b8..28daee9 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -1,5 +1,4 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later - import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Algebra.Order.BigOperators.Group.Finset import Mathlib.Algebra.Order.Field.Basic @@ -15,26 +14,18 @@ import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Induction.CoreDefs import Nfp.Sound.Induction.OneHot import Nfp.Sound.Linear.FinFold - /-! Sound builders for induction certificates. - These builders recompute certificate bounds inside Lean from exact inputs and return proof-carrying results. The head-input path derives softmax tolerances from score margins rather than trusting external weight dumps. -/ - namespace Nfp - namespace Sound - open scoped BigOperators - open Nfp.Circuit open Nfp.Sound.Bounds - variable {seq : Nat} - /-- Build and certify a softmax-margin certificate from exact scores/weights. -/ def buildSoftmaxMarginCert? [NeZero seq] (active : Finset (Fin seq)) @@ -81,7 +72,6 @@ def buildSoftmaxMarginCert? 
[NeZero seq] exact some ⟨cert, h⟩ else exact none - /-- Build and certify a value-range certificate from exact values. -/ def buildValueRangeCert? [NeZero seq] (vals : Fin seq → Rat) @@ -104,7 +94,6 @@ def buildValueRangeCert? [NeZero seq] exact some ⟨cert, h⟩ else exact none - /-- Build induction certificates from exact head inputs (core computation). -/ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : @@ -215,6 +204,16 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} simp [row.2]) let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let qAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => qAbs q d)) + let qAbsMax : Fin dHead → Rat := fun d => + qAbsMaxArr[d.1]'(by + have hsize : qAbsMaxArr.size = dHead := by + simp [qAbsMaxArr] + simp [hsize]) let kAbsMaxArr : Array Rat := Array.ofFn (fun d : Fin dHead => let univ : Finset (Fin seq) := Finset.univ @@ -227,8 +226,9 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudget : Nat := 2 - let splitDims : Fin seq → List (Fin dHead) := fun q => + let splitBudgetQ : Nat := 2 + let splitBudgetK : Nat := 1 + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d @@ -252,13 +252,39 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} | (some b1, some b2) => [b1.2, b2.2] | (some b1, none) => [b1.2] | (none, _) => [] - dims.take splitBudget + dims.take splitBudgetQ + let splitDimsK : Fin seq → List (Fin dHead) := fun k => + let ambig := + (List.finRange dHead).filter (fun d => kLo k d < 0 ∧ 0 < kHi k d) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let best := ambig.foldl step (none, none) + let dims := + match best with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + dims.take splitBudgetK let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => - let dims := splitDims q + let dimsQ := splitDimsQ q ⟨Array.ofFn (fun k : Fin seq => - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplit dims + let dimsK := splitDimsK k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK (fun d => qLo q d) (fun d => qHi q d) (fun d => kLo k d) (fun d => kHi k d)), by simp⟩)) @@ -355,7 +381,6 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} · exact none · exact none · exact none - set_option maxHeartbeats 1000000 in -- Large softmax/interval proof expands many bounds; bump heartbeats to avoid timeouts. /-- Soundness for `buildInductionCertFromHeadCore?`. 
-/ @@ -471,6 +496,16 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simp [row.2]) let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let qAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => qAbs q d)) + let qAbsMax : Fin dHead → Rat := fun d => + qAbsMaxArr[d.1]'(by + have hsize : qAbsMaxArr.size = dHead := by + simp [qAbsMaxArr] + simp [hsize]) let kAbsMaxArr : Array Rat := Array.ofFn (fun d : Fin dHead => let univ : Finset (Fin seq) := Finset.univ @@ -483,8 +518,9 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudget : Nat := 2 - let splitDims : Fin seq → List (Fin dHead) := fun q => + let splitBudgetQ : Nat := 2 + let splitBudgetK : Nat := 1 + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d @@ -508,13 +544,39 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} | (some b1, some b2) => [b1.2, b2.2] | (some b1, none) => [b1.2] | (none, _) => [] - dims.take splitBudget + dims.take splitBudgetQ + let splitDimsK : Fin seq → List (Fin dHead) := fun k => + let ambig := + (List.finRange dHead).filter (fun d => kLo k d < 0 ∧ 0 < kHi k d) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < 
s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let best := ambig.foldl step (none, none) + let dims := + match best with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + dims.take splitBudgetK let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => - let dims := splitDims q + let dimsQ := splitDimsQ q ⟨Array.ofFn (fun k : Fin seq => - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplit dims + let dimsK := splitDimsK k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK (fun d => qLo q d) (fun d => qHi q d) (fun d => kLo k d) (fun d => kHi k d)), by simp⟩)) @@ -615,7 +677,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} lnLo, lnHi, lnAbsMaxTask, lnAbsMaxArr, lnAbsMax, lnAbsMaxMax, qLoRowTasks, qHiRowTasks, qLoArr, qHiArr, qLo, qHi, kLoRowTasks, kHiRowTasks, kLoArr, kHiArr, kLo, kHi, - qAbs, kAbs, kAbsMaxArr, kAbsMax, masked, splitBudget, splitDims, + qAbs, kAbs, qAbsMaxArr, qAbsMax, kAbsMaxArr, kAbsMax, masked, + splitBudgetQ, splitBudgetK, splitDimsQ, splitDimsK, dotRowTasks, dotLo, dotHi, dotAbs, scoreBaseAbs, scoreLo, scoreHi, scoreLoPrev, otherKeys, marginAt, epsAt, margin, eps, dirHeadVec, dirHead, wvDir, bDir, valsAbsBase, valsLoBase, @@ -800,8 +863,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => (hk d).2 have hspec := - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplit_spec_real - (dims := splitDims q) + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := splitDimsQ q) (dims2 := splitDimsK k) (lo1 
:= fun d => qLo q d) (hi1 := fun d => qHi q d) (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) (x := fun d => qRealOfInputs inputs q d) @@ -1431,7 +1494,5 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} · have : False := by simp [buildInductionCertFromHeadCore?, hEps] at hcore exact this.elim - end Sound - end Nfp From 3a71d8b517a4bb2b74ef508e9cc10b31898d63b5 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 9 Jan 2026 15:38:37 +0100 Subject: [PATCH 119/244] Increase k-side split budget --- Nfp/Sound/Induction/Core.lean | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 28daee9..3831993 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -227,7 +227,7 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k let splitBudgetQ : Nat := 2 - let splitBudgetK : Nat := 1 + let splitBudgetK : Nat := 2 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) @@ -519,7 +519,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k let splitBudgetQ : Nat := 2 - let splitBudgetK : Nat := 1 + let splitBudgetK : Nat := 2 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) From d24130caa113e75f6a18e0c94c04f5a5e28f66f4 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 10 Jan 2026 02:44:17 +0100 Subject: [PATCH 120/244] Refactor induction head soundness and invStd bounds --- AGENTS.md | 6 + Nfp/Sound/Bounds/LayerNorm.lean | 422 +++++-- Nfp/Sound/Bounds/LayerNorm/InvStd.lean | 124 ++ Nfp/Sound/Induction.lean | 1 + Nfp/Sound/Induction/Core.lean | 1349 
+++----------------- Nfp/Sound/Induction/CoreSound.lean | 1382 +++++++++++++++++++++ Nfp/Sound/Induction/CoreSound/Values.lean | 162 +++ Nfp/Sound/Induction/HeadOutput.lean | 2 +- 8 files changed, 2117 insertions(+), 1331 deletions(-) create mode 100644 Nfp/Sound/Bounds/LayerNorm/InvStd.lean create mode 100644 Nfp/Sound/Induction/CoreSound.lean create mode 100644 Nfp/Sound/Induction/CoreSound/Values.lean diff --git a/AGENTS.md b/AGENTS.md index 097fb54..1a3f1c1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -300,6 +300,10 @@ but you **must** update this list in the same commit. - Aggregator for induction soundness modules. - `Nfp/Sound/Induction/Core.lean` - Sound builders and core proofs for induction certificates from exact inputs. +- `Nfp/Sound/Induction/CoreSound.lean` + - Soundness proof for `buildInductionCertFromHeadCore?`. +- `Nfp/Sound/Induction/CoreSound/Values.lean` + - Helper lemmas for value-direction projections in the core soundness proof. - `Nfp/Sound/Induction/CoreDefs.lean` - Core definitions and soundness predicates for induction certificates. - `Nfp/Sound/Induction/HeadOutput.lean` @@ -316,6 +320,8 @@ but you **must** update this list in the same commit. - LayerNorm interval bounds and end-to-end soundness lemmas. - `Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean` - Mean/variance helpers for LayerNorm bounds. +- `Nfp/Sound/Bounds/LayerNorm/InvStd.lean` + - Inverse-standard-deviation bounds for LayerNorm. - `Nfp/Sound/Bounds/UnnormRat.lean` - Unnormalized rational helpers for deferred normalization in bounds kernels. - `Nfp/Sound/Bounds/Gelu.lean` diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 4e445a7..3a11a47 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -67,7 +67,7 @@ def sqrtUpperAlt (q : Rat) : Rat := ratRoundUp ((a + 1 : Rat) / den) /-- Extra precision scale for `sqrtLowerScaled`. 
-/ -def sqrtLowerScale : Nat := 4096 +def sqrtLowerScale : Nat := 65536 /-- Scaled rational lower bound for a square root (extra precision). -/ def sqrtLowerScaled (q : Rat) : Rat := @@ -77,13 +77,21 @@ def sqrtLowerScaled (q : Rat) : Rat := let a := Nat.sqrt (num * den * scale * scale) ratRoundDown ((a : Rat) / (den * scale)) -/-- Rational lower bound for a square root (tighter of two bounds). -/ +/-- Scaled rational upper bound for a square root (extra precision). -/ +def sqrtUpperScaled (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let scale := sqrtLowerScale + let a := Nat.sqrt (num * den * scale * scale) + ratRoundUp ((a + 1 : Rat) / (den * scale)) + +/-- Rational lower bound for a square root (tighter of three bounds). -/ def sqrtLower (q : Rat) : Rat := max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) -/-- Rational upper bound for a square root (tighter of two bounds). -/ +/-- Rational upper bound for a square root (tighter of three bounds). -/ def sqrtUpper (q : Rat) : Rat := - min (sqrtUpperBase q) (sqrtUpperAlt q) + min (min (sqrtUpperBase q) (sqrtUpperAlt q)) (sqrtUpperScaled q) /-- `sqrtLowerBase` is nonnegative. -/ theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by @@ -173,6 +181,43 @@ theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by exact div_pos hnum_pos hden_pos exact ratRoundUp_pos hrat_pos +/-- `sqrtUpperScaled` is nonnegative. 
-/ +theorem sqrtUpperScaled_nonneg (q : Rat) : 0 ≤ sqrtUpperScaled q := by + classical + unfold sqrtUpperScaled + have hnum : + 0 ≤ (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) := by + exact_mod_cast + (Nat.zero_le (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1)) + have hden : 0 ≤ (q.den * sqrtLowerScale : Rat) := by + exact_mod_cast (Nat.zero_le (q.den * sqrtLowerScale)) + have hrat : + 0 ≤ (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) / + (q.den * sqrtLowerScale) := by + exact div_nonneg hnum hden + exact ratRoundUp_nonneg hrat + +/-- `sqrtUpperScaled` is always positive. -/ +theorem sqrtUpperScaled_pos (q : Rat) : 0 < sqrtUpperScaled q := by + classical + unfold sqrtUpperScaled + have hnum_pos : + (0 : Rat) < + (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) := by + exact_mod_cast + (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale))) + have hden_pos : (0 : Rat) < (q.den * sqrtLowerScale : Rat) := by + have hden : 0 < q.den := q.den_pos + have hscale : 0 < sqrtLowerScale := by + simp [sqrtLowerScale] + exact_mod_cast (Nat.mul_pos hden hscale) + have hrat_pos : + (0 : Rat) < + (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) / + (q.den * sqrtLowerScale) := by + exact div_pos hnum_pos hden_pos + exact ratRoundUp_pos hrat_pos + /-! Combined bounds. -/ /-- `sqrtLower` is nonnegative. -/ @@ -187,13 +232,19 @@ theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by have hbase : 0 ≤ sqrtUpperBase q := sqrtUpperBase_nonneg q have halt : 0 ≤ sqrtUpperAlt q := sqrtUpperAlt_nonneg q - exact le_min hbase halt + have hscaled : 0 ≤ sqrtUpperScaled q := sqrtUpperScaled_nonneg q + have hmin1 : 0 ≤ min (sqrtUpperBase q) (sqrtUpperAlt q) := by + exact le_min hbase halt + exact le_min hmin1 hscaled /-- `sqrtUpper` is always positive. 
-/ theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by have hbase : 0 < sqrtUpperBase q := sqrtUpperBase_pos q have halt : 0 < sqrtUpperAlt q := sqrtUpperAlt_pos q - exact lt_min hbase halt + have hscaled : 0 < sqrtUpperScaled q := sqrtUpperScaled_pos q + have hmin1 : 0 < min (sqrtUpperBase q) (sqrtUpperAlt q) := by + exact lt_min hbase halt + exact lt_min hmin1 hscaled /-- Square-root lower bound in reals. -/ theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : @@ -508,6 +559,90 @@ theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : simpa [sqrtUpperAlt, num, den, a] using hup' exact le_trans hle hup +/-- Scaled square-root upper bound in reals. -/ +theorem real_sqrt_le_sqrtUpperScaled {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperScaled q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set scale : Nat := sqrtLowerScale + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos : 0 < (scale : Real) := by + have hscale_pos_nat : 0 < scale := by + simp [scale, sqrtLowerScale] + exact_mod_cast hscale_pos_nat + have hnumden_lt : (num * den * scale * scale : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' (num * den * scale * scale)) + have hmul : + (num : Real) * den * (scale : Real) * (scale : Real) ≤ (a + 1 : Real) ^ 2 := by + exact le_of_lt hnumden_lt + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := by + exact mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hdiv : + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by + have hmul' : + (num : Real) * den * (scale : Real) * (scale : Real) * + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 * ((den : Real) * (scale : Real)) ^ 2 := by + have 
hden_sq_nonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hdenScale_ne : ((den : Real) * (scale : Real)) ≠ 0 := by + exact ne_of_gt hdenScale_pos + have hq_cast : (q : Real) = (num : Real) / den := by + have hnum_nonneg : 0 ≤ q.num := by + exact (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (num : Int) = q.num := by + simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) + have hnum_cast : (q.num : Real) = (num : Real) := by + exact_mod_cast hnum_eq.symm + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] + simpa [hnum_cast, den] using hq_rat + have hq_eq : + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 = (num : Real) / den := by + field_simp [hdenScale_ne] + have hq_cast' : + (q : Real) = + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := by + calc + (q : Real) = (num : Real) / den := hq_cast + _ = ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := hq_eq.symm + have hpow : + ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : + (q : Real) ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 := by + simpa [hq_cast', hpow] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by + exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + nlinarith [hden_pos, hscale_pos] + exact div_nonneg hnum_nonneg hden_nonneg + have hle : + Real.sqrt (q : Real) ≤ (a + 1 : Real) / ((den : Real) * (scale : Real)) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ 
(sqrtUpperScaled q : Real) := by + have hup' : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ + ratToReal (ratRoundUp ((a + 1 : Rat) / (den * scale))) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / (den * scale)) + simpa [sqrtUpperScaled, num, den, scale, a] using hup' + exact le_trans hle hup + /-- Square-root lower bound in reals (tighter of three bounds). -/ theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by @@ -523,12 +658,20 @@ theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : simpa [ratToReal_max] using (max_le_iff).2 ⟨hmax1, hscaled⟩ simpa [sqrtLower] using hmax2 -/-- Square-root upper bound in reals (tighter of two bounds). -/ +/-- Square-root upper bound in reals (tighter of three bounds). -/ theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq - simpa [sqrtUpper] using (le_min_iff).2 ⟨hbase, halt⟩ + have hscaled := real_sqrt_le_sqrtUpperScaled (q := q) hq + have hmin1 : + Real.sqrt (q : Real) ≤ min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real) := by + exact (le_min_iff).2 ⟨hbase, halt⟩ + have hmin2 : + Real.sqrt (q : Real) ≤ + min (min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real)) (sqrtUpperScaled q : Real) := by + exact (le_min_iff).2 ⟨hmin1, hscaled⟩ + simpa [sqrtUpper, ratToReal_min] using hmin2 /-- Bounds for multiplying a scalar by a bounded value. 
-/ def scaleInterval (x lo hi : Rat) : Rat × Rat := @@ -606,16 +749,26 @@ def layerNormBounds {n : Nat} if n = 0 then (fun _ => 0, fun _ => 0) else - let μLo := mean x - let μHi := meanUpper x - let centeredBound : Fin n → Rat := fun i => - max |x i - μHi| |x i - μLo| - let varLo : Rat := variance x - let varEpsLo : Rat := varLo + eps - let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsLo) - let invStdBound : Rat := ratDivUp 1 sqrtLowerBound - let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound - (fun i => beta i - radius i, fun i => beta i + radius i) + let μ : Rat := mean x + let centered : Fin n → Rat := fun i => x i - μ + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEps) + let sqrtUpperBound : Rat := sqrtUpper varEps + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let coeff : Fin n → Rat := fun i => gamma i * centered i + let lo : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdLower + else + beta i + coeff i * invStdUpper + let hi : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdUpper + else + beta i + coeff i * invStdLower + (lo, hi) /-- `layerNormBounds` soundness for real LayerNorm outputs. 
-/ theorem layerNormBounds_spec {n : Nat} @@ -627,145 +780,138 @@ theorem layerNormBounds_spec {n : Nat} layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by classical intro bounds i - let μLo : Rat := mean x - let μHi : Rat := meanUpper x - let centeredBound : Fin n → Rat := fun j => max |x j - μHi| |x j - μLo| - let varLo : Rat := variance x - let varEpsLo : Rat := varLo + eps - let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsLo) - let invStdBound : Rat := ratDivUp 1 sqrtLowerBound - let varEps : Real := (varianceRat x : Real) + (eps : Real) + let μRat : Rat := mean x + let varRat : Rat := variance x + let varEpsRat : Rat := varRat + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsRat) + let sqrtUpperBound : Rat := sqrtUpper varEpsRat + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let centered : Rat := x i - μRat + let coeff : Rat := gamma i * centered let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) let invStd : Real := (Real.sqrt varEps)⁻¹ - have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by - have h0 : 0 ≤ centeredBound i := by - dsimp [centeredBound] - exact le_trans (abs_nonneg _) (le_max_left _ _) - exact ratToReal_nonneg_of_nonneg h0 - have hmean_lo_real : (μLo : Real) ≤ μ := by - have h := ratRoundDown_le_real (meanRat x) - simpa [μLo, μ, mean_def x hne] using h - have hmean_hi_real : μ ≤ (μHi : Real) := by - have h := real_le_ratRoundUp (meanRat x) - simpa [μHi, μ, meanUpper_def x hne] using h - have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by - have hlo : (x i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by - exact sub_le_sub_left hmean_hi_real (x i : Real) - have hhi : (x i : Real) - μ ≤ (x i : Real) - (μLo : Real) := by - exact sub_le_sub_left hmean_lo_real (x i : Real) - have hbound := abs_le_max_of_bounds hlo hhi - simpa [centeredBound, μLo, μHi, ratToReal_abs, 
ratToReal_sub, - ratToReal_max] using hbound + have hmu : (μRat : Real) = μ := by + simp [μRat, μ, mean_def, hne, ratRoundDown] + have hvar : (varRat : Real) = (varianceRat x : Real) := by + simp [varRat, variance_def, hne, ratRoundDown] + have hvarEps : (varEpsRat : Real) = varEps := by + simp [varEpsRat, varEps, hvar] have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne have hvar_nonneg_rat : 0 ≤ varianceRat x := by - have hreal : 0 ≤ (varianceRat x : Real) := hvar_nonneg - exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hreal - have hvarLo_nonneg : 0 ≤ varLo := by + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg + have hvarRat_nonneg : 0 ≤ varRat := by have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat - simpa [varLo, variance_def x hne] using h - have hvarEpsLo_nonneg : 0 ≤ varEpsLo := by - exact add_nonneg hvarLo_nonneg (le_of_lt heps) + simpa [varRat, variance_def x hne] using h + have hvarEps_nonneg : 0 ≤ varEpsRat := by + exact add_nonneg hvarRat_nonneg (le_of_lt heps) have hsqrt_lower : (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by - have hsqrt_eps : - (sqrtLower eps : Real) ≤ Real.sqrt varEps := by - have hsqrt_eps' : - (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps' : (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) simpa using h have hle : (eps : Real) ≤ varEps := by have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := le_add_of_nonneg_left hvar_nonneg simpa [varEps] using hle' - have hsqrt_eps'' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by - exact Real.sqrt_le_sqrt hle - exact le_trans hsqrt_eps' hsqrt_eps'' - have hsqrt_var : - (sqrtLower varEpsLo : Real) ≤ Real.sqrt varEps := by + exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) + have hsqrt_var : (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := by 
have hsqrt_var' : - (sqrtLower varEpsLo : Real) ≤ Real.sqrt (varEpsLo : Real) := by - have h := sqrtLower_le_real_sqrt (q := varEpsLo) hvarEpsLo_nonneg + (sqrtLower varEpsRat : Real) ≤ Real.sqrt (varEpsRat : Real) := by + have h := sqrtLower_le_real_sqrt (q := varEpsRat) hvarEps_nonneg simpa using h - have hle : (varEpsLo : Real) ≤ varEps := by - have hle' : (varLo : Real) ≤ (varianceRat x : Real) := by - have h := ratRoundDown_le_real (varianceRat x) - simpa [varLo, variance_def x hne] using h - have hle'' := add_le_add_right hle' (eps : Real) - simpa [varEpsLo, varEps, ratToReal_add] using hle'' - have hsqrt_var'' : Real.sqrt (varEpsLo : Real) ≤ Real.sqrt varEps := by - exact Real.sqrt_le_sqrt hle - exact le_trans hsqrt_var' hsqrt_var'' + have hle : (varEpsRat : Real) ≤ varEps := by + simp [hvarEps] + exact le_trans hsqrt_var' (Real.sqrt_le_sqrt hle) have hmax : - max (sqrtLower eps : Real) (sqrtLower varEpsLo : Real) ≤ Real.sqrt varEps := + max (sqrtLower eps : Real) (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ simpa [sqrtLowerBound, ratToReal_max] using hmax + have hsqrt_upper : + Real.sqrt varEps ≤ (sqrtUpperBound : Real) := by + have h := real_sqrt_le_sqrtUpper (q := varEpsRat) hvarEps_nonneg + simpa [sqrtUpperBound, hvarEps] using h + have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by + have hpos : 0 < sqrtLower eps := hsqrt + have hpos' : 0 < max (sqrtLower eps) (sqrtLower varEpsRat) := + lt_of_lt_of_le hpos (le_max_left _ _) + simpa [sqrtLowerBound] using hpos' have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by - have hpos : 0 < (sqrtLower eps : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt - have hpos' : 0 < max (sqrtLower eps : Real) (sqrtLower varEpsLo : Real) := by - exact lt_of_lt_of_le hpos (le_max_left _ _) - simpa [sqrtLowerBound, ratToReal_max] using hpos' - have hinv_sqrt : invStd ≤ (sqrtLowerBound : Real)⁻¹ := by - have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - 
simpa [invStd] using h - have hinv_bound : (sqrtLowerBound : Real)⁻¹ ≤ (invStdBound : Real) := by - have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by - exact lt_of_lt_of_le hsqrt (le_max_left _ _) - have hy : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLowerBound) hy - simpa [invStdBound, one_div] using hdiv - have hinv : invStd ≤ (invStdBound : Real) := by - exact le_trans hinv_sqrt hinv_bound - have hinv_nonneg : 0 ≤ invStd := by - have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by - exact Real.sqrt_nonneg _ - exact inv_nonneg.2 hsqrt_nonneg - have hmul1 : |(x i : Real) - μ| * invStd ≤ - (centeredBound i : Real) * (invStdBound : Real) := by - have hleft : - |(x i : Real) - μ| * invStd ≤ (centeredBound i : Real) * invStd := by - exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg - have hright : - (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by - exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg - exact le_trans hleft hright - have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ - |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by - have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ - have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ - |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by - exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa [mul_assoc] using hmul2' - let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd - have ht_abs : - |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by - have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by - have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg - simp [t, abs_mul, hinv_abs, mul_assoc] - simpa [ht] using hmul2 - let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound - have ht_abs' : |t| ≤ (radius i : Real) := by - simpa 
[radius, centeredBound, invStdBound] using ht_abs - have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by - exact abs_le.mp ht_abs' - have hlow : - (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h := add_le_add_left hbounds.1 (beta i : Real) - simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h - have hhigh : - t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - have h := add_le_add_left hbounds.2 (beta i : Real) - simpa [add_comm, add_left_comm, add_assoc] using h - have hreal : - layerNormReal eps gamma beta x i = t + (beta i : Real) := by - simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] - have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by - simpa [bounds, layerNormBounds, hne, radius, centeredBound, varLo, varEpsLo, - sqrtLowerBound, invStdBound, μLo, μHi, hreal] using hlow - have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - simpa [bounds, layerNormBounds, hne, radius, centeredBound, varLo, varEpsLo, - sqrtLowerBound, invStdBound, μLo, μHi, hreal] using hhigh - exact And.intro hlo hhi + exact (Rat.cast_pos (K := Real) (q := sqrtLowerBound)).2 hsqrt_lower_pos_rat + have hsqrt_upper_pos_rat : 0 < sqrtUpperBound := by + simpa [sqrtUpperBound] using sqrtUpper_pos varEpsRat + have hsqrt_upper_pos : 0 < (sqrtUpperBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtUpperBound)).2 hsqrt_upper_pos_rat + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + have hsqrt_pos : 0 < Real.sqrt varEps := Real.sqrt_pos.2 hvarEps_pos + have hinv_lower_real : + (sqrtUpperBound : Real)⁻¹ ≤ invStd := by + have hle := inv_anti₀ hsqrt_pos hsqrt_upper + simpa [invStd] using hle + have hinv_upper_real : + invStd ≤ (sqrtLowerBound : Real)⁻¹ := by + have hle := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using 
hle + have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat + have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat + have hinv_lower : (invStdLower : Real) ≤ invStd := by + simpa [invStdLower, ratDivDown, hupper_ne, one_div] using hinv_lower_real + have hinv_upper : invStd ≤ (invStdUpper : Real) := by + simpa [invStdUpper, ratDivUp, hlower_ne, one_div] using hinv_upper_real + have hlayer : + layerNormReal eps gamma beta x i = + (beta i : Real) + (coeff : Real) * invStd := by + simp [layerNormReal, hne, coeff, centered, μ, hmu, invStd, varEps, add_comm, mul_assoc] + by_cases hcoeff : 0 ≤ coeff + · have hcoeff_real : 0 ≤ (coeff : Real) := + ratToReal_nonneg_of_nonneg hcoeff + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdLower : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonneg_left hinv_lower hcoeff_real + exact add_le_add_right hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : Real) + (coeff : Real) * (invStdUpper : Real) := by + have hmul := mul_le_mul_of_nonneg_left hinv_upper hcoeff_real + exact add_le_add_right hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + · have hcoeff_lt : coeff < 0 := lt_of_not_ge hcoeff + have hcoeff_real : (coeff : Real) ≤ 0 := by + exact_mod_cast (le_of_lt hcoeff_lt) + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdUpper : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := 
mul_le_mul_of_nonpos_left hinv_upper hcoeff_real + exact add_le_add_right hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : Real) + (coeff : Real) * (invStdLower : Real) := by + have hmul := mul_le_mul_of_nonpos_left hinv_lower hcoeff_real + exact add_le_add_right hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi /-- Interval bounds for LayerNorm outputs from per-coordinate intervals. -/ def layerNormIntervalBounds {n : Nat} diff --git a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean new file mode 100644 index 0000000..3bf1967 --- /dev/null +++ b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean @@ -0,0 +1,124 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Sound.Bounds.LayerNorm + +/-! +Inverse-standard-deviation bounds for LayerNorm. + +This module isolates invStd bounds and their soundness proof to keep +`LayerNorm.lean` below the style linter's file-length limit. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +/-- Bounds for the LayerNorm inverse standard deviation term. -/ +def invStdBounds {n : Nat} (eps : Rat) (x : Fin n → Rat) : Rat × Rat := + if n = 0 then + (0, 0) + else + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEps) + let sqrtUpperBound : Rat := sqrtUpper varEps + (ratDivDown 1 sqrtUpperBound, ratDivUp 1 sqrtLowerBound) + +/-- `invStdBounds` soundness for real inverse-std terms. 
-/ +theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := invStdBounds eps x + let invStd : Real := (Real.sqrt ((varianceRat x : Real) + (eps : Real)))⁻¹ + (bounds.1 : Real) ≤ invStd ∧ invStd ≤ (bounds.2 : Real) := by + classical + intro bounds invStd + let varRat : Rat := variance x + let varEpsRat : Rat := varRat + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsRat) + let sqrtUpperBound : Rat := sqrtUpper varEpsRat + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let varEps : Real := (varianceRat x : Real) + (eps : Real) + have hvar : (varRat : Real) = (varianceRat x : Real) := by + simp [varRat, variance_def, hne, ratRoundDown] + have hvarEps : (varEpsRat : Real) = varEps := by + simp [varEpsRat, varEps, hvar] + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hvar_nonneg_rat : 0 ≤ varianceRat x := by + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg + have hvarRat_nonneg : 0 ≤ varRat := by + have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat + simpa [varRat, variance_def x hne] using h + have hvarEps_nonneg : 0 ≤ varEpsRat := by + exact add_nonneg hvarRat_nonneg (le_of_lt heps) + have hsqrt_lower : + (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps' : (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) + have hsqrt_var : (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := by + have hsqrt_var' : + (sqrtLower varEpsRat : Real) ≤ 
Real.sqrt (varEpsRat : Real) := by + have h := sqrtLower_le_real_sqrt (q := varEpsRat) hvarEps_nonneg + simpa using h + have hle : (varEpsRat : Real) ≤ varEps := by + simp [hvarEps] + exact le_trans hsqrt_var' (Real.sqrt_le_sqrt hle) + have hmax : + max (sqrtLower eps : Real) (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := + (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ + simpa [sqrtLowerBound, ratToReal_max] using hmax + have hsqrt_upper : + Real.sqrt varEps ≤ (sqrtUpperBound : Real) := by + have h := real_sqrt_le_sqrtUpper (q := varEpsRat) hvarEps_nonneg + simpa [sqrtUpperBound, hvarEps] using h + have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by + have hpos : 0 < sqrtLower eps := hsqrt + have hpos' : 0 < max (sqrtLower eps) (sqrtLower varEpsRat) := + lt_of_lt_of_le hpos (le_max_left _ _) + simpa [sqrtLowerBound] using hpos' + have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLowerBound)).2 hsqrt_lower_pos_rat + have hsqrt_upper_pos_rat : 0 < sqrtUpperBound := by + simpa [sqrtUpperBound] using sqrtUpper_pos varEpsRat + have hsqrt_upper_pos : 0 < (sqrtUpperBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtUpperBound)).2 hsqrt_upper_pos_rat + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + have hsqrt_pos : 0 < Real.sqrt varEps := Real.sqrt_pos.2 hvarEps_pos + have hinv_lower_real : + (sqrtUpperBound : Real)⁻¹ ≤ invStd := by + have hle := inv_anti₀ hsqrt_pos hsqrt_upper + simpa [invStd, varEps] using hle + have hinv_upper_real : + invStd ≤ (sqrtLowerBound : Real)⁻¹ := by + have hle := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd, varEps] using hle + have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat + have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat + have hinv_lower : (invStdLower : Real) ≤ invStd := by + simpa 
[invStdLower, ratDivDown, hupper_ne, one_div] using hinv_lower_real + have hinv_upper : invStd ≤ (invStdUpper : Real) := by + simpa [invStdUpper, ratDivUp, hlower_ne, one_div] using hinv_upper_real + constructor + · simpa [bounds, invStdBounds, hne, varRat, varEpsRat, sqrtLowerBound, sqrtUpperBound, + invStdLower, invStdUpper] using hinv_lower + · simpa [bounds, invStdBounds, hne, varRat, varEpsRat, sqrtLowerBound, sqrtUpperBound, + invStdLower, invStdUpper] using hinv_upper + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index 6c0b295..7376f8e 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Sound.Induction.Core +import Nfp.Sound.Induction.CoreSound import Nfp.Sound.Induction.HeadOutput /-! diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 3831993..8d9b23d 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -10,16 +10,13 @@ import Nfp.Circuit.Cert.ValueRange import Nfp.Sound.Bounds.Attention import Nfp.Sound.Bounds.Cache import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.LayerNorm.InvStd import Nfp.Sound.Bounds.MatrixNorm import Nfp.Sound.Induction.CoreDefs import Nfp.Sound.Induction.OneHot import Nfp.Sound.Linear.FinFold -/-! -Sound builders for induction certificates. -These builders recompute certificate bounds inside Lean from exact inputs and -return proof-carrying results. The head-input path derives softmax tolerances -from score margins rather than trusting external weight dumps. --/ +/-! Sound builders for induction certificates; recompute bounds inside Lean from exact inputs and +derive softmax tolerances from score margins rather than trusting external weight dumps. -/ namespace Nfp namespace Sound open scoped BigOperators @@ -122,86 +119,99 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := Finset.univ_nonempty univ.sup' hnonempty (fun q => lnAbsMax q) - let qLoRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + let invStdBoundsTasks : Array (Task (Rat × Rat)) := Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d), - by simp⟩)) - let qHiRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) + let invStdBoundsArr : Array (Rat × Rat) := + Array.ofFn (fun q : Fin seq => + (invStdBoundsTasks[q.1]'(by + have hsize : invStdBoundsTasks.size = seq := by + simp [invStdBoundsTasks] + simp [hsize])).get) + let invStdLo : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + have hsize : invStdBoundsArr.size = seq := by + simp [invStdBoundsArr] + simp [hsize])).1 + let invStdHi : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + have hsize : invStdBoundsArr.size = seq := by + simp [invStdBoundsArr] + simp [hsize])).2 + let qBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + + inputs.bq d) + let qBase : Fin dHead → Rat := fun d => + qBaseArr[d.1]'(by + have hsize : qBaseArr.size = dHead := by + simp [qBaseArr] + simp [hsize]) + let kBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + + inputs.bk d) + let kBase : Fin dHead → Rat := fun d => + kBaseArr[d.1]'(by + have hsize : kBaseArr.size = dHead := by + simp [kBaseArr] + simp [hsize]) + let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * 
(inputs.embed q j - μ) ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d), + Linear.dotFin dModel (fun j => inputs.wq j d) coeff), by simp⟩)) - let qLoArr : Array { row : Array Rat // row.size = dHead } := + let qCoeffArr : Array { row : Array Rat // row.size = dHead } := Array.ofFn (fun q : Fin seq => - (qLoRowTasks[q.1]'(by - have hsize : qLoRowTasks.size = seq := by - simp [qLoRowTasks] + (qCoeffRowTasks[q.1]'(by + have hsize : qCoeffRowTasks.size = seq := by + simp [qCoeffRowTasks] simp [hsize])).get) - let qHiArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qHiRowTasks[q.1]'(by - have hsize : qHiRowTasks.size = seq := by - simp [qHiRowTasks] - simp [hsize])).get) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let row := qLoArr[q.1]'(by - have hsize : qLoArr.size = seq := by - simp [qLoArr] + let qCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := qCoeffArr[q.1]'(by + have hsize : qCoeffArr.size = seq := by + simp [qCoeffArr] simp [hsize]) row.1[d.1]'(by simp [row.2]) - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let row := qHiArr[q.1]'(by - have hsize : qHiArr.size = seq := by - simp [qHiArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let kLoRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d), - by simp⟩)) - let kHiRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) 
(lnLo q) (lnHi q) + - inputs.bk d), + Linear.dotFin dModel (fun j => inputs.wk j d) coeff), by simp⟩)) - let kLoArr : Array { row : Array Rat // row.size = dHead } := + let kCoeffArr : Array { row : Array Rat // row.size = dHead } := Array.ofFn (fun q : Fin seq => - (kLoRowTasks[q.1]'(by - have hsize : kLoRowTasks.size = seq := by - simp [kLoRowTasks] + (kCoeffRowTasks[q.1]'(by + have hsize : kCoeffRowTasks.size = seq := by + simp [kCoeffRowTasks] simp [hsize])).get) - let kHiArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kHiRowTasks[q.1]'(by - have hsize : kHiRowTasks.size = seq := by - simp [kHiRowTasks] - simp [hsize])).get) - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let row := kLoArr[q.1]'(by - have hsize : kLoArr.size = seq := by - simp [kLoArr] + let kCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := kCoeffArr[q.1]'(by + have hsize : kCoeffArr.size = seq := by + simp [kCoeffArr] simp [hsize]) row.1[d.1]'(by simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.1 + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.2 + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.1 let kHi : Fin seq → Fin dHead → Rat := fun q d => - let row := kHiArr[q.1]'(by - have hsize : kHiArr.size = seq := by - simp [kHiArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.2 let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| let qAbsMaxArr : Array Rat := @@ -226,8 +236,8 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 2 - let splitBudgetK : Nat := 2 + let splitBudgetQ : Nat := 8 + let splitBudgetK : Nat := 8 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) @@ -246,17 +256,18 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) | (none, some b2) => if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let best := ambig.foldl step (none, none) - let dims := - match best with + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with | (some b1, some b2) => [b1.2, b2.2] | (some b1, none) => [b1.2] | (none, _) => [] - dims.take splitBudgetQ - let splitDimsK : Fin seq → List (Fin dHead) := fun k => + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + (dims1 ++ dims2).take splitBudgetQ + let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => let ambig := (List.finRange dHead).filter (fun d => kLo k d < 0 ∧ 0 < kHi k d) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbsMax d + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d let step (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) (d : Fin dHead) : @@ -271,23 +282,68 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) | (none, some b2) => if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let best := ambig.foldl step (none, none) - let dims := + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + (dims1 ++ dims2).take splitBudgetK + let splitDimsDiff : Fin seq → Fin seq → List (Fin dHead) := fun q k => + let prev := inputs.prev q + let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d + let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d + let ambig := + (List.finRange dHead).filter (fun d => diffLo d < 0 ∧ 0 < diffHi d) + let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with | (some b1, some b2) => [b1.2, b2.2] | (some b1, none) => [b1.2] | (none, _) => [] - dims.take splitBudgetK + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + (dims1 ++ dims2).take splitBudgetK let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => let dimsQ := splitDimsQ q 
⟨Array.ofFn (fun k : Fin seq => - let dimsK := splitDimsK k + let dimsK := splitDimsK q k _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK (fun d => qLo q d) (fun d => qHi q d) (fun d => kLo k d) (fun d => kHi k d)), by simp⟩)) + let dotDiffRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsDiff := splitDimsDiff q k + let prev := inputs.prev q + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)), + by simp⟩)) let dotLo : Fin seq → Fin seq → Rat := fun q k => let row := (dotRowTasks[q.1]'(by simp [dotRowTasks, q.isLt])).get @@ -300,6 +356,18 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} let entry := row.1[k.1]'(by simp [row.2, k.isLt]) entry.2 + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasks[q.1]'(by + simp [dotDiffRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasks[q.1]'(by + simp [dotDiffRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => |inputs.scale| * dotAbs q k @@ -321,13 +389,22 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} inputs.scale * dotLo q k let scoreLoPrev : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) + let scoreGapLo : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLo q k + else + inputs.scale * dotDiffHi q k let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) let marginAt : Fin seq → Rat := fun q => if hq : q ∈ inputs.active then let other := otherKeys q if h : other.Nonempty then - other.inf' h (fun k => scoreLoPrev q - scoreHi q k) + other.inf' h (fun k => scoreGapLo q k) else (0 : Rat) else @@ -381,1118 +458,6 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} · exact none · exact none · exact none -set_option maxHeartbeats 1000000 in --- Large softmax/interval proof expands many bounds; bump heartbeats to avoid timeouts. -/-- Soundness for `buildInductionCertFromHeadCore?`. -/ -theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) - (hcore : buildInductionCertFromHeadCore? 
inputs = some c) : - InductionHeadCertSound inputs c := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : 0 < sqrtLower inputs.lnEps - · by_cases hmodel : dModel = 0 - · have : False := by - simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel] at hcore - exact this.elim - · by_cases hactive : inputs.active.Nonempty - · let lnBounds := Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Bounds.cacheBoundTask (fun q => - Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr : Array Rat := - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr[q.1]'(by - have hsize : lnAbsMaxArr.size = seq := by - simp [lnAbsMaxArr] - simp [hsize]) - let lnAbsMaxMax : Rat := - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => lnAbsMax q) - let qLoRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d), - by simp⟩)) - let qHiRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d), - by simp⟩)) - let qLoArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qLoRowTasks[q.1]'(by - have hsize : qLoRowTasks.size = seq := by - simp [qLoRowTasks] - simp [hsize])).get) - let qHiArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qHiRowTasks[q.1]'(by - have 
hsize : qHiRowTasks.size = seq := by - simp [qHiRowTasks] - simp [hsize])).get) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let row := qLoArr[q.1]'(by - have hsize : qLoArr.size = seq := by - simp [qLoArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let row := qHiArr[q.1]'(by - have hsize : qHiArr.size = seq := by - simp [qHiArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let kLoRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d), - by simp⟩)) - let kHiRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun d : Fin dHead => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d), - by simp⟩)) - let kLoArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kLoRowTasks[q.1]'(by - have hsize : kLoRowTasks.size = seq := by - simp [kLoRowTasks] - simp [hsize])).get) - let kHiArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kHiRowTasks[q.1]'(by - have hsize : kHiRowTasks.size = seq := by - simp [kHiRowTasks] - simp [hsize])).get) - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let row := kLoArr[q.1]'(by - have hsize : kLoArr.size = seq := by - simp [kLoArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let kHi : Fin seq → Fin dHead → Rat := fun q d => - let row := kHiArr[q.1]'(by - have hsize : kHiArr.size = seq := by - simp [kHiArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| - let qAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin 
dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => qAbs q d)) - let qAbsMax : Fin dHead → Rat := fun d => - qAbsMaxArr[d.1]'(by - have hsize : qAbsMaxArr.size = dHead := by - simp [qAbsMaxArr] - simp [hsize]) - let kAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d)) - let kAbsMax : Fin dHead → Rat := fun d => - kAbsMaxArr[d.1]'(by - have hsize : kAbsMaxArr.size = dHead := by - simp [kAbsMaxArr] - simp [hsize]) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 2 - let splitBudgetK : Nat := 2 - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - let ambig := - (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let best := ambig.foldl step (none, none) - let dims := - match best with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - dims.take splitBudgetQ - let splitDimsK : Fin seq → List (Fin dHead) := fun k => - let ambig := - (List.finRange dHead).filter (fun d => kLo k d < 0 ∧ 0 < kHi k d) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbsMax d - let step - 
(best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let best := ambig.foldl step (none, none) - let dims := - match best with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - dims.take splitBudgetK - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - let dimsK := splitDimsK k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale 
then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreLoPrev q - scoreHi q k) - else - (0 : Rat) - else - (0 : Rat) - let epsAt : Fin seq → Rat := fun q => - if marginAt q < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + marginAt q) - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - have hseq : (1 : Nat) ≤ seq := - Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) - let dirHeadVec := dirHeadVecOfInputs inputs - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := - Bounds.cacheBoundTask (fun j => - Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsAbsBase : Rat := - Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax - let valsLoBase := bDir - valsAbsBase - let valsHiBase := bDir + valsAbsBase - let valsLo : Fin seq → Rat := fun _ => valsLoBase - let valsHi : Fin seq → Rat := fun _ => valsHiBase - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec } - let cert : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } - have hcore' : some cert = some c := by - simpa - 
[buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, lnBounds, - lnLo, lnHi, lnAbsMaxTask, lnAbsMaxArr, lnAbsMax, lnAbsMaxMax, - qLoRowTasks, qHiRowTasks, qLoArr, qHiArr, qLo, qHi, - kLoRowTasks, kHiRowTasks, kLoArr, kHiArr, kLo, kHi, - qAbs, kAbs, qAbsMaxArr, qAbsMax, kAbsMaxArr, kAbsMax, masked, - splitBudgetQ, splitBudgetK, splitDimsQ, splitDimsK, - dotRowTasks, dotLo, dotHi, dotAbs, - scoreBaseAbs, scoreLo, scoreHi, scoreLoPrev, otherKeys, marginAt, epsAt, - margin, eps, dirHeadVec, dirHead, wvDir, bDir, valsAbsBase, valsLoBase, - valsHiBase, valsLo, valsHi, univ, lo, hi, valCert, cert, Task.spawn, - Bounds.cacheBoundTask_apply, Array.getElem_ofFn] - using hcore - have hc : c = cert := by - simpa using (Option.some.inj hcore').symm - subst hc - have hln_bounds : - ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ - lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by - intro q i - have hln := - Bounds.layerNormBounds_spec (eps := inputs.lnEps) - (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) - (x := inputs.embed q) hmodel hEps hSqrt - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs, - Bounds.cacheBoundPair2_apply_left, Bounds.cacheBoundPair2_apply_right] - using hln i - have hln_abs : ∀ q j, |lnRealOfInputs inputs q j| ≤ (lnAbsMax q : Real) := by - intro q j - have hln := hln_bounds q - have h := - Bounds.abs_le_intervalAbsBound_real (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) (hlo := fun j => (hln j).1) - (hhi := fun j => (hln j).2) j - simpa [lnAbsMax, lnAbsMaxArr, lnAbsMaxTask, Bounds.cacheBoundTask_apply, - Array.getElem_ofFn] using h - have hln_abs_max : ∀ q, lnAbsMax q ≤ lnAbsMaxMax := by - intro q - have hnonempty : (Finset.univ : Finset (Fin seq)).Nonempty := - Finset.univ_nonempty - have hmem : q ∈ (Finset.univ : Finset (Fin seq)) := by simp - simpa [lnAbsMaxMax] using - (Finset.le_sup'_iff (s := (Finset.univ : Finset (Fin seq))) - (H := hnonempty) (f := fun q => lnAbsMax q) (a := lnAbsMax q)).2 - ⟨q, hmem, le_rfl⟩ - 
have hdot_abs_bound : - ∀ (v : Fin dModel → Rat) (q : Fin seq), - |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ - (Bounds.dotIntervalAbsBound v (lnLo q) (lnHi q) : Real) := by - intro v q - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => - (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => - (hln j).2 - simpa using - (Bounds.abs_dotProduct_le_dotIntervalAbsBound_real - (v := v) (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) hlo hhi) - have hdot_abs_bound_sum : - ∀ (v : Fin dModel → Rat) (q : Fin seq), - |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ - (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by - intro v q - have hsum : - |∑ j, (v j : Real) * lnRealOfInputs inputs q j| ≤ - ∑ j, |(v j : Real) * lnRealOfInputs inputs q j| := by - simpa [dotProduct] using - (Finset.abs_sum_le_sum_abs (s := (Finset.univ : Finset (Fin dModel))) - (f := fun j => (v j : Real) * lnRealOfInputs inputs q j)) - have hterm : - ∀ j, |(v j : Real) * lnRealOfInputs inputs q j| ≤ - (|v j| : Real) * (lnAbsMax q : Real) := by - intro j - have hln := hln_abs q j - have hnonneg : 0 ≤ (|v j| : Real) := by - exact abs_nonneg _ - calc - |(v j : Real) * lnRealOfInputs inputs q j| = - |(v j : Real)| * |lnRealOfInputs inputs q j| := by - simp [abs_mul] - _ ≤ (|v j| : Real) * (lnAbsMax q : Real) := - mul_le_mul_of_nonneg_left hln hnonneg - have hsum_le : - ∑ j, |(v j : Real) * lnRealOfInputs inputs q j| ≤ - ∑ j, (|v j| : Real) * (lnAbsMax q : Real) := by - refine Finset.sum_le_sum ?_ - intro j _ - exact hterm j - have hsum_mul : - ∑ j, (|v j| : Real) * (lnAbsMax q : Real) = - (∑ j, (|v j| : Real)) * (lnAbsMax q : Real) := by - symm - simpa using - (Finset.sum_mul (s := (Finset.univ : Finset (Fin dModel))) - (f := fun j => (|v j| : Real)) (a := (lnAbsMax q : Real))) - have hsum_cast : - (Linear.sumFin dModel (fun j => |v j|) : Real) = ∑ j, 
(|v j| : Real) := by - simpa [ratToReal] using - (Linear.ratToReal_sumFin (f := fun j => |v j|)) - have hsum_eq : - ∑ j, (|v j| : Real) * (lnAbsMax q : Real) = - (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by - calc - ∑ j, (|v j| : Real) * (lnAbsMax q : Real) - = (∑ j, (|v j| : Real)) * (lnAbsMax q : Real) := hsum_mul - _ = (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by - simp [hsum_cast] - have hfinal := hsum.trans (hsum_le.trans_eq hsum_eq) - simpa [dotProduct] using hfinal - have hq_bounds : - ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by - intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => - (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => - (hln j).2 - have hdot_lo := - Bounds.dotIntervalLower_le_dotProduct_real - (v := fun j => inputs.wq j d) (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) hlo hhi - have hdot_hi := - Bounds.dotProduct_le_dotIntervalUpper_real - (v := fun j => inputs.wq j d) (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) hlo hhi - have hlow : - (qLo q d : Real) ≤ qRealOfInputs inputs q d := by - have h := add_le_add_right hdot_lo (inputs.bq d : Real) - simpa [qRealOfInputs, qLo, qLoArr, qLoRowTasks, lnLo, lnHi, Task.spawn, - Bounds.dotIntervalLowerUnnorm, ratToReal_add] using h - have hhigh : - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by - have h := add_le_add_right hdot_hi (inputs.bq d : Real) - simpa [qRealOfInputs, qHi, qHiArr, qHiRowTasks, lnLo, lnHi, Task.spawn, - Bounds.dotIntervalUpperUnnorm, ratToReal_add] using h - exact ⟨hlow, hhigh⟩ - have hk_bounds : - ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ (kHi q d : Real) := by - intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => - (hln j).1 - 
have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => - (hln j).2 - have hdot_lo := - Bounds.dotIntervalLower_le_dotProduct_real - (v := fun j => inputs.wk j d) (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) hlo hhi - have hdot_hi := - Bounds.dotProduct_le_dotIntervalUpper_real - (v := fun j => inputs.wk j d) (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) hlo hhi - have hlow : - (kLo q d : Real) ≤ kRealOfInputs inputs q d := by - have h := add_le_add_right hdot_lo (inputs.bk d : Real) - simpa [kRealOfInputs, kLo, kLoArr, kLoRowTasks, lnLo, lnHi, Task.spawn, - Bounds.dotIntervalLowerUnnorm, ratToReal_add] using h - have hhigh : - kRealOfInputs inputs q d ≤ (kHi q d : Real) := by - have h := add_le_add_right hdot_hi (inputs.bk d : Real) - simpa [kRealOfInputs, kHi, kHiArr, kHiRowTasks, lnLo, lnHi, Task.spawn, - Bounds.dotIntervalUpperUnnorm, ratToReal_add] using h - exact ⟨hlow, hhigh⟩ - have hscore_bounds : - ∀ q k, (scoreLo q k : Real) ≤ scoresRealOfInputs inputs q k ∧ - scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by - intro q k - let scoresReal := scoresRealOfInputs inputs - let base := - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - have hdot_bounds : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - have hq := hq_bounds q - have hk := hk_bounds k - have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq d).1 - have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq d).2 - have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => - (hk d).1 - have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => - (hk d).2 - have hspec := - 
_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := splitDimsQ q) (dims2 := splitDimsK k) - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hlow' : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn] - using hspec.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn] - using hspec.2 - exact ⟨hlow', hhigh'⟩ - by_cases hcausal : inputs.maskCausal - · by_cases hle : k ≤ q - · have hnot : ¬ q < k := not_lt_of_ge hle - have hscore_eq : scoresReal q k = base := by - simp [scoresReal, scoresRealOfInputs, hcausal, hle, base] - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - ratToReal_nonneg_of_nonneg hscale - have hlow := - mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real - have hhigh := - mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real - constructor - · simpa [scoresReal, scoreLo, masked, hcausal, hnot, hscale, hscore_eq, base] - using hlow - · simpa [scoresReal, scoreHi, masked, hcausal, hnot, hscale, hscore_eq, base] - using hhigh - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hlow := - mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real - have hhigh := - mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real - constructor - · simpa [scoresReal, scoreLo, masked, hcausal, hnot, hscale, hscore_eq, base] - using hlow - · simpa [scoresReal, scoreHi, masked, hcausal, hnot, hscale, hscore_eq, base] - using hhigh - · have hlt : 
q < k := lt_of_not_ge hle - constructor - · simp [scoresRealOfInputs, scoreLo, masked, hcausal, hle, hlt] - · simp [scoresRealOfInputs, scoreHi, masked, hcausal, hle, hlt] - · have hscore_eq : scoresReal q k = base := by - simp [scoresReal, scoresRealOfInputs, hcausal, base] - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - ratToReal_nonneg_of_nonneg hscale - have hlow := - mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real - have hhigh := - mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real - constructor - · simpa [scoresReal, scoreLo, masked, hcausal, hscale, hscore_eq, base] - using hlow - · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] - using hhigh - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hlow := - mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real - have hhigh := - mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real - constructor - · simpa [scoresReal, scoreLo, masked, hcausal, hscale, hscore_eq, base] - using hlow - · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] - using hhigh - let scoresReal := scoresRealOfInputs inputs - have hmarginAt_le : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - marginAt q ≤ scoreLoPrev q - scoreHi q k := by - intro q hq k hk - have hmem : k ∈ otherKeys q := by - simp [otherKeys, hk] - have hnonempty : (otherKeys q).Nonempty := ⟨k, hmem⟩ - have hle : - (otherKeys q).inf' hnonempty (fun k => scoreLoPrev q - scoreHi q k) ≤ - scoreLoPrev q - scoreHi q k := by - exact - (Finset.inf'_le_iff (s := otherKeys q) (H := hnonempty) - (f := fun k => scoreLoPrev q - scoreHi q k) - (a := scoreLoPrev q - scoreHi q k)).2 - ⟨k, hmem, le_rfl⟩ - simpa [marginAt, hq, hnonempty] using hle - let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax (scoresReal q) k - let others : Fin seq → 
Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - have hscore_margin_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hmargin_le : marginAt q ≤ scoreLoPrev q - scoreHi q k := - hmarginAt_le q hq k hk - have hmargin_le_real : - (marginAt q : Real) ≤ (scoreLoPrev q : Real) - (scoreHi q k : Real) := - by - simpa [ratToReal_sub] using (ratToReal_le_of_le hmargin_le) - have hscore_hi : scoresReal q k ≤ (scoreHi q k : Real) := - (hscore_bounds q k).2 - have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by - have hprev_bounds := hscore_bounds q (inputs.prev q) - simpa [scoreLoPrev] using hprev_bounds.1 - have hscore_diff : scoresReal q k - (scoreHi q k : Real) ≤ 0 := by - have h := sub_le_sub_right hscore_hi (scoreHi q k : Real) - simpa using h - have hsum_le' : - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ - (scoreLoPrev q : Real) := by - have hsub : - (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ - (scoreLoPrev q : Real) - scoresReal q k := - sub_le_sub_left hscore_hi (scoreLoPrev q : Real) - have hsum_le'' := add_le_add_left hsub (scoresReal q k) - have hsum_le''' : - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ - (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by - simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' - calc - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k - ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := hsum_le''' - _ = (scoreLoPrev q : Real) := by - simp [sub_add_cancel] - have hgap : - scoresReal q k + (marginAt q : Real) ≤ (scoreLoPrev q : Real) := by - have hstep := add_le_add_left hmargin_le_real (scoresReal q k) - have hstep' : - scoresReal q k + (marginAt q : Real) ≤ - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by - simpa [add_comm, add_left_comm, add_assoc] 
using hstep - exact hstep'.trans hsum_le' - exact hgap.trans hscore_prev - have hscore_margin_real : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hmargin_le : margin ≤ marginAt q := by - have hmem : q ∈ inputs.active := hq - have hnonempty : inputs.active.Nonempty := hactive - have hle := - (Finset.inf'_le_iff (s := inputs.active) (H := hnonempty) - (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ - simpa [margin, hnonempty] using hle - have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := - ratToReal_le_of_le hmargin_le - have hscore := hscore_margin_real_at q hq k hk - have hscore' : - (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by - simpa [add_comm, add_left_comm, add_assoc] using hscore - have hstep := add_le_add_left hmargin_le_real (scoresReal q k) - have hstep' : - scoresReal q k + (margin : Real) ≤ (marginAt q : Real) + scoresReal q k := by - simpa [add_comm, add_left_comm, add_assoc] using hstep - exact hstep'.trans hscore' - have hsoftmax_bounds : - Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) - (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by - classical - refine - { score_margin := ?_ - nonneg := ?_ - sum_one := ?_ - prev_large := ?_ - other_le := ?_ } - · intro q hq k hk - exact hscore_margin_real q hq k hk - · intro q _ k - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) - · intro q _ - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - · intro q hq - have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (eps : Real) := by - by_cases hneg : margin < 0 - · have heps : (eps : Real) = 1 := by - simp [eps, hneg] - have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by - intro k hk - simp - have hnonneg : - ∀ k ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q k := by - intro k _ - simpa [weights] using - 
(Circuit.softmax_nonneg (scores := scoresReal q) k) - have hsum_le : - (∑ k ∈ others q, weights q k) ≤ - ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := - Finset.sum_le_sum_of_subset_of_nonneg hsubset (by - intro k hk _; exact hnonneg k hk) - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by - simpa [hsum_one] using hsum_le - simpa [heps] using hsum_le' - · have hnonneg : 0 ≤ margin := le_of_not_gt hneg - have hnonneg_real : 0 ≤ (margin : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg - have hbound : - ∀ k ∈ others q, - weights q k ≤ (1 + (margin : Real))⁻¹ := by - intro k hk - have hkne : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 - have hscore := hscore_margin_real q hq k hkne - simpa [weights] using - (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) - (prev := inputs.prev q) (k := k) (m := (margin : Real)) - hnonneg_real hscore) - have hsum_le : - (∑ k ∈ others q, weights q k) ≤ - ∑ k ∈ others q, (1 + (margin : Real))⁻¹ := - Finset.sum_le_sum hbound - have hsum_const : - (∑ k ∈ others q, (1 + (margin : Real))⁻¹) = - (others q).card * (1 + (margin : Real))⁻¹ := by - simp - have hcard : (others q).card = seq - 1 := by - simp [others, Finset.card_erase_of_mem] - have hsum_le' : - (∑ k ∈ others q, weights q k) ≤ - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - have hsum_le'' := hsum_le.trans_eq hsum_const - have hsum_le''' := hsum_le'' - simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' - exact hsum_le''' - have hpos : (0 : Rat) < 1 + margin := by - have hone : (0 : Rat) < 1 := by - exact zero_lt_one - have hle : (1 : Rat) ≤ 1 + margin := by - exact le_add_of_nonneg_right hnonneg - exact lt_of_lt_of_le hone hle - have hden : (1 + margin) ≠ 0 := by - exact ne_of_gt hpos - have hrat' := ratDivUp_ge_real (seq - 1) (1 + margin) hden - have heps : - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ 
≤ (eps : Real) := by - simpa [eps, hneg, ratToReal, Rat.cast_div, Rat.cast_add, - Rat.cast_natCast, div_eq_mul_inv] using hrat' - exact le_trans hsum_le' heps - have hsum_eq : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = 1 := by - have hsum' : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = - ∑ k, weights q k := by - simp [others] - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - calc - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = - ∑ k, weights q k := hsum' - _ = 1 := hsum_one - have hsum_le' : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k ≤ - weights q (inputs.prev q) + (eps : Real) := by - have hsum_le'' := add_le_add_left hsum_others_le (weights q (inputs.prev q)) - have hsum_le''' := hsum_le'' - rw [add_comm (∑ k ∈ others q, weights q k) - (weights q (inputs.prev q))] at hsum_le''' - rw [add_comm (eps : Real) (weights q (inputs.prev q))] at hsum_le''' - exact hsum_le''' - have hprev : - 1 ≤ weights q (inputs.prev q) + (eps : Real) := by - have hsum_le'' := hsum_le' - rw [hsum_eq] at hsum_le'' - exact hsum_le'' - exact hprev - · intro q hq k hk - have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (eps : Real) := by - by_cases hneg : margin < 0 - · have heps : (eps : Real) = 1 := by - simp [eps, hneg] - have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by - intro j hj - simp - have hnonneg : - ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q j := by - intro j _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) j) - have hsum_le : - (∑ j ∈ others q, weights q j) ≤ - ∑ j ∈ (Finset.univ : Finset (Fin seq)), weights q j := - Finset.sum_le_sum_of_subset_of_nonneg hsubset (by - intro j hj _; exact hnonneg j hj) - have hsum_one : (∑ j, weights q j) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - have hsum_le' : (∑ j ∈ others q, weights q j) ≤ 1 := by - simpa 
[hsum_one] using hsum_le - simpa [heps] using hsum_le' - · have hnonneg : 0 ≤ margin := le_of_not_gt hneg - have hnonneg_real : 0 ≤ (margin : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg - have hbound : - ∀ j ∈ others q, - weights q j ≤ (1 + (margin : Real))⁻¹ := by - intro j hj - have hjne : j ≠ inputs.prev q := (Finset.mem_erase.mp hj).1 - have hscore := hscore_margin_real q hq j hjne - simpa [weights] using - (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) - (prev := inputs.prev q) (k := j) (m := (margin : Real)) - hnonneg_real hscore) - have hsum_le : - (∑ j ∈ others q, weights q j) ≤ - ∑ j ∈ others q, (1 + (margin : Real))⁻¹ := - Finset.sum_le_sum hbound - have hsum_const : - (∑ j ∈ others q, (1 + (margin : Real))⁻¹) = - (others q).card * (1 + (margin : Real))⁻¹ := by - simp - have hcard : (others q).card = seq - 1 := by - simp [others, Finset.card_erase_of_mem] - have hsum_le' : - (∑ j ∈ others q, weights q j) ≤ - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - have hsum_le'' := hsum_le.trans_eq hsum_const - have hsum_le''' := hsum_le'' - simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' - exact hsum_le''' - have hpos : (0 : Rat) < 1 + margin := by - have hone : (0 : Rat) < 1 := by - exact zero_lt_one - have hle : (1 : Rat) ≤ 1 + margin := by - exact le_add_of_nonneg_right hnonneg - exact lt_of_lt_of_le hone hle - have hden : (1 + margin) ≠ 0 := by - exact ne_of_gt hpos - have hrat' := ratDivUp_ge_real (seq - 1) (1 + margin) hden - have heps : - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : Real) := by - simpa [eps, hneg, ratToReal, Rat.cast_div, Rat.cast_add, - Rat.cast_natCast, div_eq_mul_inv] using hrat' - exact le_trans hsum_le' heps - have hk' : k ∈ others q := by - simp [others, hk] - have hnonneg : - ∀ j ∈ others q, 0 ≤ weights q j := by - intro j _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) j) - have hle : - weights q k ≤ ∑ j ∈ others q, weights q j := by - have h := 
Finset.single_le_sum hnonneg hk' - simpa using h - exact hle.trans hsum_others_le - have hepsAt : - ∀ q, epsAt q = - if marginAt q < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + marginAt q) := by - intro q - rfl - have oneHot_bounds_at : - ∀ q, q ∈ inputs.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) - (fun q' => q' = q) inputs.prev weights := by - intro q hq - exact - Sound.oneHot_bounds_at_of_marginAt - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (marginAt := marginAt) - (epsAt := epsAt) - (hepsAt := hepsAt) - (hseq := hseq) - (hscore_margin_real_at := hscore_margin_real_at) - q hq - have hdir_wv : - ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by - intro j - have hsum : - ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) = - ∑ d, ((dirHead d * inputs.wv j d : Rat) : Real) := by - have h := - Linear.ratToReal_sum_univ (f := fun d => dirHead d * inputs.wv j d) - dsimp [ratToReal] at h - exact h - have hsum' : - ∑ d, ((dirHead d * inputs.wv j d : Rat) : Real) = - ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by - refine Finset.sum_congr rfl ?_ - intro d _ - simp - have hfinal := hsum.trans hsum' - calc - (wvDir j : Real) - = ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) := by - simp [wvDir, Bounds.cacheBoundTask_apply, Linear.dotFin_eq_dotProduct, - dotProduct] - _ = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := hfinal - have hdir_bv : - (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by - have hsum : - ((∑ d, dirHead d * inputs.bv d : Rat) : Real) = - ∑ d, ((dirHead d * inputs.bv d : Rat) : Real) := by - have h := - Linear.ratToReal_sum_univ (f := fun d => dirHead d * inputs.bv d) - dsimp [ratToReal] at h - exact h - have hsum' : - ∑ d, ((dirHead d * inputs.bv d : Rat) : Real) = - ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by - refine Finset.sum_congr rfl ?_ - intro d _ - simp - have hfinal := hsum.trans hsum' - calc - (bDir 
: Real) - = ((∑ d, dirHead d * inputs.bv d : Rat) : Real) := by - simp [bDir, Linear.dotFin_eq_dotProduct, dotProduct] - _ = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := hfinal - have hvals_eq : - ∀ k, - valsRealOfInputs inputs k = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - (bDir : Real) := by - intro k - classical - have hdot_add : - dotProduct (fun d => (dirHead d : Real)) - (fun d => - dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k) + - (inputs.bv d : Real)) = - dotProduct (fun d => (dirHead d : Real)) - (fun d => dotProduct (fun j => (inputs.wv j d : Real)) - (lnRealOfInputs inputs k)) + - dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) := by - simp [dotProduct, mul_add, Finset.sum_add_distrib] - have hdot_wv : - dotProduct (fun d => (dirHead d : Real)) - (fun d => dotProduct (fun j => (inputs.wv j d : Real)) - (lnRealOfInputs inputs k)) = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) := by - classical - calc - dotProduct (fun d => (dirHead d : Real)) - (fun d => dotProduct (fun j => (inputs.wv j d : Real)) - (lnRealOfInputs inputs k)) = - ∑ d, (dirHead d : Real) * ∑ j, - (inputs.wv j d : Real) * lnRealOfInputs inputs k j := by - simp [dotProduct] - _ = ∑ d, ∑ j, - (dirHead d : Real) * - ((inputs.wv j d : Real) * lnRealOfInputs inputs k j) := by - simp [Finset.mul_sum] - _ = ∑ j, ∑ d, - (dirHead d : Real) * - ((inputs.wv j d : Real) * lnRealOfInputs inputs k j) := by - simpa using - (Finset.sum_comm (s := (Finset.univ : Finset (Fin dHead))) - (t := (Finset.univ : Finset (Fin dModel))) - (f := fun d j => - (dirHead d : Real) * - ((inputs.wv j d : Real) * lnRealOfInputs inputs k j))) - _ = ∑ j, (∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) * - lnRealOfInputs inputs k j := by - refine Finset.sum_congr rfl ?_ - intro j _ - have hsum : - (∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) * - lnRealOfInputs inputs k j = - ∑ d, (dirHead d : Real) * (inputs.wv j d 
: Real) * - lnRealOfInputs inputs k j := by - simp [Finset.sum_mul, mul_assoc] - simpa [mul_assoc] using hsum.symm - _ = ∑ j, (wvDir j : Real) * lnRealOfInputs inputs k j := by - refine Finset.sum_congr rfl ?_ - intro j _ - simp [hdir_wv j] - _ = dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) := by - simp [dotProduct] - calc - valsRealOfInputs inputs k = - dotProduct (fun d => (dirHead d : Real)) - (fun d => - dotProduct (fun j => (inputs.wv j d : Real)) - (lnRealOfInputs inputs k) + - (inputs.bv d : Real)) := by - simp [valsRealOfInputs, vRealOfInputs, dirHeadVec, dirHead] - _ = - dotProduct (fun d => (dirHead d : Real)) - (fun d => dotProduct (fun j => (inputs.wv j d : Real)) - (lnRealOfInputs inputs k)) + - dotProduct (fun d => (dirHead d : Real)) - (fun d => (inputs.bv d : Real)) := hdot_add - _ = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - dotProduct (fun d => (dirHead d : Real)) - (fun d => (inputs.bv d : Real)) := by - simp [hdot_wv] - _ = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - (bDir : Real) := by - have hb : - dotProduct (fun d => (dirHead d : Real)) - (fun d => (inputs.bv d : Real)) = - (bDir : Real) := by - have hb : (dotProduct (fun d => (dirHead d : Real)) - (fun d => (inputs.bv d : Real)) : Real) = (bDir : Real) := by - calc - dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) - = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by - simp [dotProduct] - _ = (bDir : Real) := hdir_bv.symm - exact hb - simp [hb] - have hvals_bounds_at : - ∀ k, - (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k ∧ - valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - intro k - have hdot_abs : - |dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k)| ≤ - (valsAbsBase : Real) := by - have hdot := hdot_abs_bound_sum (fun j => wvDir j) k - have hln_max_real : - (lnAbsMax k : Real) ≤ (lnAbsMaxMax : Real) := - ratToReal_le_of_le (hln_abs_max k) - have 
hsum_nonneg : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Real) := by - have hsum_nonneg' : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Rat) := by - have hsum_nonneg'' : 0 ≤ ∑ j, |wvDir j| := by - refine Finset.sum_nonneg ?_ - intro j _ - exact abs_nonneg _ - simpa [Linear.sumFin_eq_sum_univ] using hsum_nonneg'' - exact ratToReal_nonneg_of_nonneg hsum_nonneg' - have hmul : - (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMax k : Real) ≤ - (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMaxMax : Real) := - mul_le_mul_of_nonneg_left hln_max_real hsum_nonneg - have hfinal := hdot.trans hmul - simpa [valsAbsBase, ratToReal_mul] using hfinal - have hdot_bounds := (abs_le).1 hdot_abs - have hlow' := add_le_add_right hdot_bounds.1 (bDir : Real) - have hhigh' := add_le_add_right hdot_bounds.2 (bDir : Real) - have hlow : - (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by - simpa [valCert, valsLo, valsLoBase, valsAbsBase, hvals_eq k, ratToReal_sub, - sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using hlow' - have hhigh : - valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - simpa [valCert, valsHi, valsHiBase, valsAbsBase, hvals_eq k, ratToReal_add, - add_comm, add_left_comm, add_assoc] using hhigh' - exact ⟨hlow, hhigh⟩ - have hvals_bounds : - ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by - refine - { lo_le_hi := ?_ - lo_le_valsLo := ?_ - vals_bounds := ?_ - valsHi_le_hi := ?_ } - · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ - have hmem0 : k0 ∈ univ := hk0 - have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by - have hloRat : valCert.lo ≤ valCert.valsLo k0 := by - change lo ≤ valsLo k0 - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k0)).2 ?_ - refine ⟨k0, hmem0, ?_⟩ - exact le_rfl - exact ratToReal_le_of_le hloRat - have hvals : - (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ - valsRealOfInputs 
inputs k0 ≤ (valCert.valsHi k0 : Real) := by - exact hvals_bounds_at k0 - have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by - have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by - change valsHi k0 ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k0)).2 ?_ - exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ - exact ratToReal_le_of_le hhiRat - have hreal : - (valCert.lo : Real) ≤ (valCert.hi : Real) := - le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) - exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal - · intro k - have hmem : k ∈ univ := by simp [univ] - have hloRat : valCert.lo ≤ valCert.valsLo k := by - change lo ≤ valsLo k - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k)).2 ?_ - refine ⟨k, hmem, ?_⟩ - exact le_rfl - exact ratToReal_le_of_le hloRat - · intro k - exact hvals_bounds_at k - · intro k - have hmem : k ∈ univ := by simp [univ] - have hhiRat : valCert.valsHi k ≤ valCert.hi := by - change valsHi k ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k)).2 ?_ - refine ⟨k, hmem, ?_⟩ - exact le_rfl - exact ratToReal_le_of_le hhiRat - exact - { softmax_bounds := hsoftmax_bounds - oneHot_bounds_at := oneHot_bounds_at - value_bounds := hvals_bounds } - · have : False := by - simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] at hcore - exact this.elim - · have : False := by - simp [buildInductionCertFromHeadCore?, hEps, hSqrt] at hcore - exact this.elim - · have : False := by - simp [buildInductionCertFromHeadCore?, hEps] at hcore - exact this.elim + end Sound end Nfp diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean new file mode 100644 index 0000000..cc30268 --- /dev/null +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -0,0 +1,1382 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later +import Nfp.Sound.Induction.Core +import 
Nfp.Sound.Induction.CoreSound.Values +namespace Nfp +namespace Sound +open scoped BigOperators +open Nfp.Circuit +open Nfp.Sound.Bounds +variable {seq : Nat} +set_option maxHeartbeats 5000000 in +-- The soundness proof expands many cached bounds; extra heartbeats avoid spurious timeouts. +/-- Soundness for `buildInductionCertFromHeadCore?`. -/ +theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) + (hcore : buildInductionCertFromHeadCore? inputs = some c) : + InductionHeadCertSound inputs c := by + classical + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : 0 < sqrtLower inputs.lnEps + · by_cases hmodel : dModel = 0 + · have : False := by + simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel] at hcore + exact this.elim + · by_cases hactive : inputs.active.Nonempty + · let lnBounds := Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Bounds.cacheBoundTask (fun q => + Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr : Array Rat := + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr[q.1]'(by + have hsize : lnAbsMaxArr.size = seq := by simp [lnAbsMaxArr] + simp [hsize]) + let lnAbsMaxMax : Rat := + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => lnAbsMax q) + let invStdBoundsTasks : Array (Task (Rat × Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) + let invStdBoundsArr : Array (Rat × Rat) := + Array.ofFn (fun q : Fin seq => + (invStdBoundsTasks[q.1]'(by + have hsize : invStdBoundsTasks.size = seq := by simp 
[invStdBoundsTasks] + simp [hsize])).get) + let invStdLo : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + have hsize : invStdBoundsArr.size = seq := by simp [invStdBoundsArr] + simp [hsize])).1 + let invStdHi : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + have hsize : invStdBoundsArr.size = seq := by simp [invStdBoundsArr] + simp [hsize])).2 + let qBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + + inputs.bq d) + let qBase : Fin dHead → Rat := fun d => + qBaseArr[d.1]'(by + have hsize : qBaseArr.size = dHead := by simp [qBaseArr] + simp [hsize]) + let kBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + + inputs.bk d) + let kBase : Fin dHead → Rat := fun d => + kBaseArr[d.1]'(by + have hsize : kBaseArr.size = dHead := by simp [kBaseArr] + simp [hsize]) + let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) coeff), + by simp⟩)) + let qCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qCoeffRowTasks[q.1]'(by + have hsize : qCoeffRowTasks.size = seq := by simp [qCoeffRowTasks] + simp [hsize])).get) + let qCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := qCoeffArr[q.1]'(by + have hsize : qCoeffArr.size = seq := by simp [qCoeffArr] + simp [hsize]) + row.1[d.1]'(by simp [row.2]) + let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn 
(fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) coeff), + by simp⟩)) + let kCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kCoeffRowTasks[q.1]'(by + have hsize : kCoeffRowTasks.size = seq := by simp [kCoeffRowTasks] + simp [hsize])).get) + let kCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := kCoeffArr[q.1]'(by + have hsize : kCoeffArr.size = seq := by simp [kCoeffArr] + simp [hsize]) + row.1[d.1]'(by simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.1 + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.2 + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.1 + let kHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.2 + let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let qAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => qAbs q d)) + let qAbsMax : Fin dHead → Rat := fun d => + qAbsMaxArr[d.1]'(by + have hsize : qAbsMaxArr.size = dHead := by + simp [qAbsMaxArr] + simp [hsize]) + let kAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d)) + let kAbsMax : Fin dHead → Rat := fun d => + kAbsMaxArr[d.1]'(by + have hsize : kAbsMaxArr.size = dHead := by + simp [kAbsMaxArr] + simp [hsize]) + let masked : Fin seq → Fin seq → Prop := fun q k => + 
inputs.maskCausal = true ∧ q < k + let splitBudgetQ : Nat := 8 + let splitBudgetK : Nat := 8 + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => + let ambig := + (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + (dims1 ++ dims2).take splitBudgetQ + let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => + let ambig := + (List.finRange dHead).filter (fun d => kLo k d < 0 ∧ 0 < kHi k d) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin 
dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + (dims1 ++ dims2).take splitBudgetK + let splitDimsDiff : Fin seq → Fin seq → List (Fin dHead) := fun q k => + let prev := inputs.prev q + let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d + let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d + let ambig := + (List.finRange dHead).filter (fun d => diffLo d < 0 ∧ 0 < diffHi d) + let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + (dims1 ++ dims2).take splitBudgetK + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + let dimsK := splitDimsK q k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotDiffRowTasks : Array (Task { row : Array (Rat × Rat) // 
row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsDiff := splitDimsDiff q k + let prev := inputs.prev q + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)), + by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasks[q.1]'(by + simp [dotDiffRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasks[q.1]'(by + simp [dotDiffRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) + let scoreGapLo : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + 
scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLo q k + else + inputs.scale * dotDiffHi q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreGapLo q k) + else + (0 : Rat) + else + (0 : Rat) + let epsAt : Fin seq → Rat := fun q => + if marginAt q < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + marginAt q) + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + have hseq : (1 : Nat) ≤ seq := + Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) + let dirHeadVec := dirHeadVecOfInputs inputs + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin dModel → Rat := + Bounds.cacheBoundTask (fun j => + Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv d) + let valsAbsBase : Rat := + Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax + let valsLoBase := bDir - valsAbsBase + let valsHiBase := bDir + valsAbsBase + let valsLo : Fin seq → Rat := fun _ => valsLoBase + let valsHi : Fin seq → Rat := fun _ => valsHiBase + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + let valCert : ValueInterval seq := + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec } + let cert : InductionHeadCert seq := + { eps := eps + epsAt := epsAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } + have hcore' : some cert = some c := by + simpa 
[buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, lnBounds, lnLo, + lnHi, lnAbsMaxTask, lnAbsMaxArr, lnAbsMax, lnAbsMaxMax, invStdBoundsTasks, + invStdBoundsArr, invStdLo, invStdHi, qBaseArr, qBase, kBaseArr, kBase, + qCoeffRowTasks, qCoeffArr, qCoeff, kCoeffRowTasks, kCoeffArr, kCoeff, qLo, qHi, kLo, + kHi, qAbs, kAbs, qAbsMaxArr, qAbsMax, kAbsMaxArr, kAbsMax, masked, splitBudgetQ, + splitBudgetK, splitDimsQ, splitDimsK, splitDimsDiff, dotRowTasks, dotDiffRowTasks, + dotLo, dotHi, dotDiffLo, dotDiffHi, dotAbs, scoreBaseAbs, scoreLo, scoreHi, + scoreLoPrev, scoreGapLo, otherKeys, marginAt, epsAt, margin, eps, dirHeadVec, + dirHead, wvDir, bDir, valsAbsBase, valsLoBase, valsHiBase, valsLo, valsHi, univ, lo, + hi, valCert, cert, Task.spawn, Bounds.cacheBoundTask_apply, + Array.getElem_ofFn] using hcore + have hc : c = cert := by + simpa using (Option.some.inj hcore').symm + subst hc + have hln_bounds : + ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ + lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by + intro q i + have hln := + Bounds.layerNormBounds_spec (eps := inputs.lnEps) + (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) + (x := inputs.embed q) hmodel hEps hSqrt + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs, Bounds.cacheBoundPair2_apply_left, + Bounds.cacheBoundPair2_apply_right] using hln i + have hln_abs : ∀ q j, |lnRealOfInputs inputs q j| ≤ (lnAbsMax q : Real) := by + intro q j + have hln := hln_bounds q + have h := + Bounds.abs_le_intervalAbsBound_real (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) (hlo := fun j => (hln j).1) + (hhi := fun j => (hln j).2) j + simpa [lnAbsMax, lnAbsMaxArr, lnAbsMaxTask, Bounds.cacheBoundTask_apply, + Array.getElem_ofFn] using h + have hln_abs_max : ∀ q, lnAbsMax q ≤ lnAbsMaxMax := by + intro q + have hnonempty : (Finset.univ : Finset (Fin seq)).Nonempty := + Finset.univ_nonempty + have hmem : q ∈ (Finset.univ : Finset (Fin seq)) := by simp + simpa [lnAbsMaxMax] using + 
(Finset.le_sup'_iff (s := (Finset.univ : Finset (Fin seq))) + (H := hnonempty) (f := fun q => lnAbsMax q) (a := lnAbsMax q)).2 + ⟨q, hmem, le_rfl⟩ + have hdot_abs_bound : + ∀ (v : Fin dModel → Rat) (q : Fin seq), + |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ + (Bounds.dotIntervalAbsBound v (lnLo q) (lnHi q) : Real) := by + intro v q + have hln := hln_bounds q + have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => + (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => + (hln j).2 + simpa using + (Bounds.abs_dotProduct_le_dotIntervalAbsBound_real + (v := v) (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) hlo hhi) + have hdot_abs_bound_sum : + ∀ (v : Fin dModel → Rat) (q : Fin seq), + |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ + (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by + intro v q + have hsum : + |∑ j, (v j : Real) * lnRealOfInputs inputs q j| ≤ + ∑ j, |(v j : Real) * lnRealOfInputs inputs q j| := by simpa [dotProduct] using + (Finset.abs_sum_le_sum_abs (s := (Finset.univ : Finset (Fin dModel))) + (f := fun j => (v j : Real) * lnRealOfInputs inputs q j)) + have hterm : + ∀ j, |(v j : Real) * lnRealOfInputs inputs q j| ≤ + (|v j| : Real) * (lnAbsMax q : Real) := by + intro j + have hln := hln_abs q j + have hnonneg : 0 ≤ (|v j| : Real) := by + exact abs_nonneg _ + calc + |(v j : Real) * lnRealOfInputs inputs q j| = + |(v j : Real)| * |lnRealOfInputs inputs q j| := by + simp [abs_mul] + _ ≤ (|v j| : Real) * (lnAbsMax q : Real) := + mul_le_mul_of_nonneg_left hln hnonneg + have hsum_le : + ∑ j, |(v j : Real) * lnRealOfInputs inputs q j| ≤ + ∑ j, (|v j| : Real) * (lnAbsMax q : Real) := by + refine Finset.sum_le_sum ?_ + intro j _ + exact hterm j + have hsum_mul : + ∑ j, (|v j| : Real) * (lnAbsMax q : Real) = + (∑ j, (|v j| : Real)) * (lnAbsMax q : Real) := by + symm + simpa using + (Finset.sum_mul (s := (Finset.univ : Finset (Fin 
dModel))) + (f := fun j => (|v j| : Real)) (a := (lnAbsMax q : Real))) + have hsum_cast : + (Linear.sumFin dModel (fun j => |v j|) : Real) = ∑ j, (|v j| : Real) := by + simpa [ratToReal] using (Linear.ratToReal_sumFin (f := fun j => |v j|)) + have hsum_eq : + ∑ j, (|v j| : Real) * (lnAbsMax q : Real) = + (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by + calc + ∑ j, (|v j| : Real) * (lnAbsMax q : Real) + = (∑ j, (|v j| : Real)) * (lnAbsMax q : Real) := hsum_mul + _ = (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by + simp [hsum_cast] + have hfinal := hsum.trans (hsum_le.trans_eq hsum_eq) + simpa [dotProduct] using hfinal + have dotFin_cast {n : Nat} (f g : Fin n → Rat) : + (Linear.dotFin n f g : Real) = + dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by + simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] + have hq_bounds : + ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + intro q d + let x := inputs.embed q; let μRat : Rat := mean x + let centered : Fin dModel → Rat := fun j => x j - μRat + let coeff : Fin dModel → Rat := fun j => inputs.ln1Gamma j * centered j + let invStd : Real := + (Real.sqrt ((varianceRat x : Real) + (inputs.lnEps : Real)))⁻¹ + have hmu : (μRat : Real) = meanRat x := by + have hmu_rat : μRat = meanRat x := by simp [μRat, mean_def, hmodel, ratRoundDown] + simpa [ratToReal] using congrArg ratToReal hmu_rat + have hinv : (invStdLo q : Real) ≤ invStd ∧ invStd ≤ (invStdHi q : Real) := by + simpa [invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, invStdBounds, + Task.spawn, Array.getElem_ofFn] using + (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := x) hmodel hEps hSqrt) + have hln : ∀ j, + lnRealOfInputs inputs q j = + (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd := by + intro j; simp [lnRealOfInputs, Bounds.layerNormReal, hmodel, coeff, centered, μRat, + hmu, invStd, x, add_comm, 
mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] + have hln_fun : + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd := by + funext j; exact hln j + have hbase : + (qBase d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bq d : Real) := by simp [qBase, qBaseArr, dotFin_cast] + have hcoeff : + (qCoeff q d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) (fun j => (coeff j : Real)) := by + simp [qCoeff, qCoeffArr, qCoeffRowTasks, Task.spawn, coeff, centered, μRat, x, + dotFin_cast] + have hdot_add : + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (coeff j : Real) * invStd) := by + simpa [dotProduct, mul_add] using + (Finset.sum_add_distrib (s := (Finset.univ : Finset (Fin dModel))) + (f := fun j => (inputs.wq j d : Real) * (inputs.ln1Beta j : Real)) + (g := fun j => (inputs.wq j d : Real) * ((coeff j : Real) * invStd))) + have hdot_coeff : + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (coeff j : Real) * invStd) = + dotProduct (fun j => (inputs.wq j d : Real)) (fun j => (coeff j : Real)) * + invStd := by + simp [dotProduct, mul_assoc, Finset.sum_mul] + have hq_real : + qRealOfInputs inputs q d = + (qBase d : Real) + (qCoeff q d : Real) * invStd := by + calc + qRealOfInputs inputs q d = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd) + + (inputs.bq d : Real) := by simp [qRealOfInputs, hln_fun] + _ = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + dotProduct (fun j => (inputs.wq j d : Real)) (fun j => (coeff j : Real)) * + invStd + + (inputs.bq d : Real) := by simp [hdot_add, hdot_coeff, add_assoc] + _ = + (dotProduct 
(fun j => (inputs.wq j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bq d : Real)) + + dotProduct (fun j => (inputs.wq j d : Real)) (fun j => (coeff j : Real)) * + invStd := by ac_rfl + _ = (qBase d : Real) + (qCoeff q d : Real) * invStd := by simp [hbase, hcoeff] + have hscale : + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + (bounds.1 : Real) ≤ (qCoeff q d : Real) * invStd ∧ + (qCoeff q d : Real) * invStd ≤ (bounds.2 : Real) := by + exact scaleInterval_bounds_real (x := qCoeff q d) (lo := invStdLo q) + (hi := invStdHi q) (y := invStd) hinv.1 hinv.2 + have hlow : + (qLo q d : Real) ≤ qRealOfInputs inputs q d := by + simpa [qLo, hq_real] using add_le_add_left hscale.1 (qBase d : Real) + have hhigh : + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + simpa [qHi, hq_real] using add_le_add_left hscale.2 (qBase d : Real) + exact ⟨hlow, hhigh⟩ + have hk_bounds : + ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ (kHi q d : Real) := by + intro q d + let x := inputs.embed q; let μRat : Rat := mean x + let centered : Fin dModel → Rat := fun j => x j - μRat + let coeff : Fin dModel → Rat := fun j => inputs.ln1Gamma j * centered j + let invStd : Real := + (Real.sqrt ((varianceRat x : Real) + (inputs.lnEps : Real)))⁻¹ + have hmu : (μRat : Real) = meanRat x := by + have hmu_rat : μRat = meanRat x := by simp [μRat, mean_def, hmodel, ratRoundDown] + simpa [ratToReal] using congrArg ratToReal hmu_rat + have hinv : (invStdLo q : Real) ≤ invStd ∧ invStd ≤ (invStdHi q : Real) := by + simpa [invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, invStdBounds, + Task.spawn, Array.getElem_ofFn] using + (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := x) hmodel hEps hSqrt) + have hln : ∀ j, + lnRealOfInputs inputs q j = + (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd := by + intro j; simp [lnRealOfInputs, Bounds.layerNormReal, hmodel, coeff, centered, μRat, + hmu, invStd, x, add_comm, mul_assoc, 
-mul_eq_mul_left_iff, -mul_eq_mul_right_iff] + have hln_fun : + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd := by + funext j; exact hln j + have hbase : + (kBase d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bk d : Real) := by simp [kBase, kBaseArr, dotFin_cast] + have hcoeff : + (kCoeff q d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) (fun j => (coeff j : Real)) := by + simp [kCoeff, kCoeffArr, kCoeffRowTasks, Task.spawn, coeff, centered, μRat, x, + dotFin_cast] + have hdot_add : + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (coeff j : Real) * invStd) := by + simpa [dotProduct, mul_add] using + (Finset.sum_add_distrib (s := (Finset.univ : Finset (Fin dModel))) + (f := fun j => (inputs.wk j d : Real) * (inputs.ln1Beta j : Real)) + (g := fun j => (inputs.wk j d : Real) * ((coeff j : Real) * invStd))) + have hdot_coeff : + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (coeff j : Real) * invStd) = + dotProduct (fun j => (inputs.wk j d : Real)) (fun j => (coeff j : Real)) * + invStd := by + simp [dotProduct, mul_assoc, Finset.sum_mul] + have hk_real : + kRealOfInputs inputs q d = + (kBase d : Real) + (kCoeff q d : Real) * invStd := by + calc + kRealOfInputs inputs q d = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd) + + (inputs.bk d : Real) := by simp [kRealOfInputs, hln_fun] + _ = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + dotProduct (fun j => (inputs.wk j d : Real)) (fun j => (coeff j : Real)) * + invStd + + (inputs.bk d : Real) := by simp [hdot_add, hdot_coeff, add_assoc] + _ = + (dotProduct (fun j => 
(inputs.wk j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bk d : Real)) + + dotProduct (fun j => (inputs.wk j d : Real)) (fun j => (coeff j : Real)) * + invStd := by ac_rfl + _ = (kBase d : Real) + (kCoeff q d : Real) * invStd := by simp [hbase, hcoeff] + have hscale : + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + (bounds.1 : Real) ≤ (kCoeff q d : Real) * invStd ∧ + (kCoeff q d : Real) * invStd ≤ (bounds.2 : Real) := by + exact scaleInterval_bounds_real (x := kCoeff q d) (lo := invStdLo q) + (hi := invStdHi q) (y := invStd) hinv.1 hinv.2 + have hlow : + (kLo q d : Real) ≤ kRealOfInputs inputs q d := by + simpa [kLo, hk_real] using add_le_add_left hscale.1 (kBase d : Real) + have hhigh : + kRealOfInputs inputs q d ≤ (kHi q d : Real) := by + simpa [kHi, hk_real] using add_le_add_left hscale.2 (kBase d : Real) + exact ⟨hlow, hhigh⟩ + have hscore_bounds : + ∀ q k, (scoreLo q k : Real) ≤ scoresRealOfInputs inputs q k ∧ + scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by + intro q k + let scoresReal := scoresRealOfInputs inputs + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + have hdot_bounds : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + have hq := hq_bounds q + have hk := hk_bounds k + have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq d).1 + have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => + (hq d).2 + have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => + (hk d).1 + have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => + (hk d).2 + have hspec := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := splitDimsQ q) (dims2 := splitDimsK q k) + 
(lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hlow' : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn] + using hspec.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn] + using hspec.2 + exact ⟨hlow', hhigh'⟩ + by_cases hcausal : inputs.maskCausal + · by_cases hle : k ≤ q + · have hnot : ¬ q < k := not_lt_of_ge hle + have hscore_eq : scoresReal q k = base := by + simp [scoresReal, scoresRealOfInputs, hcausal, hle, base] + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + ratToReal_nonneg_of_nonneg hscale + have hlow := + mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real + have hhigh := + mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real + constructor + · simpa [scoresReal, scoreLo, masked, hcausal, hnot, hscale, hscore_eq, base] + using hlow + · simpa [scoresReal, scoreHi, masked, hcausal, hnot, hscale, hscore_eq, base] + using hhigh + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hlow := + mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real + have hhigh := + mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real + constructor + · simpa [scoresReal, scoreLo, masked, hcausal, hnot, hscale, hscore_eq, base] + using hlow + · simpa [scoresReal, scoreHi, masked, hcausal, hnot, hscale, hscore_eq, base] + using hhigh + · have hlt : q < k := lt_of_not_ge hle + constructor + · simp [scoresRealOfInputs, scoreLo, masked, hcausal, hle, hlt] + · simp 
[scoresRealOfInputs, scoreHi, masked, hcausal, hle, hlt] + · have hscore_eq : scoresReal q k = base := by + simp [scoresReal, scoresRealOfInputs, hcausal, base] + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + ratToReal_nonneg_of_nonneg hscale + have hlow := + mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real + have hhigh := + mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real + constructor + · simpa [scoresReal, scoreLo, masked, hcausal, hscale, hscore_eq, base] + using hlow + · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] + using hhigh + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hlow := + mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real + have hhigh := + mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real + constructor + · simpa [scoresReal, scoreLo, masked, hcausal, hscale, hscore_eq, base] + using hlow + · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] + using hhigh + have hdot_diff_bounds : + ∀ q k, ¬ masked q k → + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + intro q k hmask + have hq := hq_bounds q + have hkprev := hk_bounds (inputs.prev q) + have hk := hk_bounds k + have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq d).1 + have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => + (hq d).2 + have hlo2 : + ∀ d, + (kLo (inputs.prev q) d - kHi k d : Rat) ≤ + (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by + intro d + have hprev_lo := (hkprev d).1 + have hk_hi := (hk d).2 + have 
h := sub_le_sub hprev_lo hk_hi + simpa [ratToReal_sub] using h + have hhi2 : + ∀ d, + (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ + (kHi (inputs.prev q) d - kLo k d : Rat) := by + intro d + have hprev_hi := (hkprev d).2 + have hk_lo := (hk d).1 + have h := sub_le_sub hprev_hi hk_lo + simpa [ratToReal_sub] using h + have hspec := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := splitDimsQ q) (dims2 := splitDimsDiff q k) + (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo (inputs.prev q) d - kHi k d) + (hi2 := fun d => kHi (inputs.prev q) d - kLo k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => + kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, dotDiffRowTasks, hmask, Task.spawn, Array.getElem_ofFn] + using hspec.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, dotDiffRowTasks, hmask, Task.spawn, Array.getElem_ofFn] + using hspec.2 + exact ⟨hlow', hhigh'⟩ + let scoresReal := scoresRealOfInputs inputs + have hmarginAt_le : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + marginAt q ≤ scoreGapLo q k := by + intro q hq k hk + have hmem : k ∈ otherKeys q := by simp [otherKeys, hk] + have hnonempty : (otherKeys q).Nonempty := ⟨k, hmem⟩ + have hle : + (otherKeys q).inf' hnonempty (fun k => scoreGapLo q k) ≤ scoreGapLo q k := by + exact (Finset.inf'_le_iff (s := otherKeys q) (H := hnonempty) + (f := fun k => scoreGapLo q k) (a := scoreGapLo q k)).2 + ⟨k, hmem, le_rfl⟩ + simpa [marginAt, hq, hnonempty] using hle + have hscore_gap_real_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q 
→ + scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + by_cases hprevmask : masked q (inputs.prev q) + · have hscore_hi : scoresReal q k ≤ (scoreHi q k : Real) := + (hscore_bounds q k).2 + have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by + have hprev_bounds := hscore_bounds q (inputs.prev q) + simpa [scoreLoPrev] using hprev_bounds.1 + have hsum_le' : + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ + (scoreLoPrev q : Real) := by + have hsub : + (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ + (scoreLoPrev q : Real) - scoresReal q k := + sub_le_sub_left hscore_hi (scoreLoPrev q : Real) + have hsum_le'' := add_le_add_left hsub (scoresReal q k) + have hsum_le''' : + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ + (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by + simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' + calc + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k + ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := hsum_le''' + _ = (scoreLoPrev q : Real) := by + simp [sub_add_cancel] + calc + scoresReal q k + (scoreGapLo q k : Real) + = (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by + simp [scoreGapLo, hprevmask, add_comm] + _ ≤ (scoreLoPrev q : Real) := hsum_le' + _ ≤ scoresReal q (inputs.prev q) := hscore_prev + · by_cases hmask : masked q k + · have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by + have hprev_bounds := hscore_bounds q (inputs.prev q) + simpa [scoreLoPrev] using hprev_bounds.1 + have hmask' : inputs.maskCausal = true ∧ q < k := by + simpa [masked] using hmask + have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 + have hscore_k : scoresReal q k = (inputs.maskValue : Real) := by + simp [scoresReal, scoresRealOfInputs, hmask'.1, hle] + calc + scoresReal q k + (scoreGapLo q k : Real) + = (inputs.maskValue : Real) + (scoreLoPrev q : Real) - + 
(inputs.maskValue : Real) := by + simp [scoreGapLo, hprevmask, hmask, hscore_k] + _ = (scoreLoPrev q : Real) := by + simp [add_sub_cancel_left] + _ ≤ scoresReal q (inputs.prev q) := hscore_prev + · have hdiff := hdot_diff_bounds q k hmask + have hgap_le : + (scoreGapLo q k : Real) ≤ + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + ratToReal_nonneg_of_nonneg hscale + have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real + simpa [scoreGapLo, hprevmask, hmask, hscale] using hle + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real + simpa [scoreGapLo, hprevmask, hmask, hscale] using hle + have hscore_prev : + scoresReal q (inputs.prev q) = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) := by + by_cases hcausal : inputs.maskCausal + · have hlt_prev : ¬ q < inputs.prev q := by + intro hlt + exact hprevmask (by exact ⟨hcausal, hlt⟩) + have hle_prev : inputs.prev q ≤ q := le_of_not_gt hlt_prev + simp [scoresReal, scoresRealOfInputs, hcausal, hle_prev] + · simp [scoresReal, scoresRealOfInputs, hcausal] + have hscore_k : + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + by_cases hcausal : inputs.maskCausal + · have hlt : ¬ q < k := by + intro hlt + exact hmask (by exact ⟨hcausal, hlt⟩) + have hle : k ≤ q := le_of_not_gt hlt + simp [scoresReal, scoresRealOfInputs, hcausal, hle] + · simp [scoresReal, scoresRealOfInputs, hcausal] + have hdot_sub : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => 
kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) = + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + classical + simp [dotProduct, mul_sub, Finset.sum_sub_distrib] + have hscore_diff : + scoresReal q (inputs.prev q) - scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + calc + scoresReal q (inputs.prev q) - scoresReal q k + = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simp [hscore_prev, hscore_k] + _ = + (inputs.scale : Real) * + (dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)) := by + simp [mul_sub] + _ = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simp [hdot_sub] + have hgap_le' : + (scoreGapLo q k : Real) ≤ + scoresReal q (inputs.prev q) - scoresReal q k := by + simpa [hscore_diff] using hgap_le + have hgap_add := + add_le_add_right hgap_le' (scoresReal q k) + have hgap_add' : + scoresReal q k + (scoreGapLo q k : Real) ≤ + scoresReal q (inputs.prev q) := by + have hcancel : + scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) = + scoresReal q (inputs.prev q) := by + calc + scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) + = + scoresReal q k + scoresReal q (inputs.prev q) - + scoresReal q k := by + symm + exact add_sub_assoc (scoresReal q k) + (scoresReal q (inputs.prev q)) 
(scoresReal q k) + _ = scoresReal q (inputs.prev q) := by + simp [add_sub_cancel_left] + calc + scoresReal q k + (scoreGapLo q k : Real) + ≤ scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) := + hgap_add + _ = scoresReal q (inputs.prev q) := hcancel + exact hgap_add' + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + have hscore_margin_real_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + have hmargin_le : marginAt q ≤ scoreGapLo q k := + hmarginAt_le q hq k hk + have hmargin_le_real : (marginAt q : Real) ≤ (scoreGapLo q k : Real) := + ratToReal_le_of_le hmargin_le + have hscore_gap := hscore_gap_real_at q hq k hk + have hstep := add_le_add_right hmargin_le_real (scoresReal q k) + have hstep' : + scoresReal q k + (marginAt q : Real) ≤ + scoresReal q k + (scoreGapLo q k : Real) := by + exact hstep + exact hstep'.trans hscore_gap + have hscore_margin_real : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + have hmargin_le : margin ≤ marginAt q := by + have hmem : q ∈ inputs.active := hq + have hnonempty : inputs.active.Nonempty := hactive + have hle := + (Finset.inf'_le_iff (s := inputs.active) (H := hnonempty) + (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ + simpa [margin, hnonempty] using hle + have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := + ratToReal_le_of_le hmargin_le + have hscore := hscore_margin_real_at q hq k hk + have hscore' : + (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by + simpa [add_comm] using hscore + have hstep := add_le_add_right hmargin_le_real (scoresReal q k) + have hstep' : + scoresReal q k + (margin : Real) ≤ (marginAt q : Real) 
+ scoresReal q k := by + calc + scoresReal q k + (margin : Real) ≤ scoresReal q k + (marginAt q : Real) := hstep + _ = (marginAt q : Real) + scoresReal q k := by + simp [add_comm] + exact hstep'.trans hscore' + have hsoftmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) + (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by + classical + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + exact hscore_margin_real q hq k hk + · intro q _ k + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + · intro q _ + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + · intro q hq + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (eps : Real) := by + by_cases hneg : margin < 0 + · have heps : (eps : Real) = 1 := by + simp [eps, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hnonneg : + ∀ k ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q k := by + intro k _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k hk _; exact hnonneg k hk) + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by + simpa [hsum_one] using hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ margin := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (margin : Real) := by + exact ratToReal_nonneg_of_nonneg hnonneg + have hbound : + ∀ k ∈ others q, + weights q k ≤ (1 + (margin : Real))⁻¹ := by + intro k hk + have hkne : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 + have hscore := hscore_margin_real q hq k hkne + simpa [weights] using + 
(Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := inputs.prev q) (k := k) (m := (margin : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (1 + (margin : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ k ∈ others q, (1 + (margin : Real))⁻¹) = + (others q).card * (1 + (margin : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ k ∈ others q, weights q k) ≤ + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have hpos : (0 : Rat) < 1 + margin := by + have hone : (0 : Rat) < 1 := by + exact zero_lt_one + have hle : (1 : Rat) ≤ 1 + margin := by + exact le_add_of_nonneg_right hnonneg + exact lt_of_lt_of_le hone hle + have hden : (1 + margin) ≠ 0 := by + exact ne_of_gt hpos + have hrat' := ratDivUp_ge_real (seq - 1) (1 + margin) hden + have heps : + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : Real) := by + simpa [eps, hneg, ratToReal, Rat.cast_div, Rat.cast_add, + Rat.cast_natCast, div_eq_mul_inv] using hrat' + exact le_trans hsum_le' heps + have hsum_eq : + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + calc + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (inputs.prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (inputs.prev q) + (eps : Real) := by + have hsum_le'' := add_le_add_left hsum_others_le (weights q (inputs.prev q)) + have hsum_le''' := hsum_le'' + rw [add_comm (∑ k ∈ 
others q, weights q k) + (weights q (inputs.prev q))] at hsum_le''' + rw [add_comm (eps : Real) (weights q (inputs.prev q))] at hsum_le''' + exact hsum_le''' + have hprev : + 1 ≤ weights q (inputs.prev q) + (eps : Real) := by + have hsum_le'' := hsum_le' + rw [hsum_eq] at hsum_le'' + exact hsum_le'' + exact hprev + · intro q hq k hk + have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (eps : Real) := by + by_cases hneg : margin < 0 + · have heps : (eps : Real) = 1 := by + simp [eps, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro j hj + simp + have hnonneg : + ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q j := by + intro j _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) j) + have hsum_le : + (∑ j ∈ others q, weights q j) ≤ + ∑ j ∈ (Finset.univ : Finset (Fin seq)), weights q j := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro j hj _; exact hnonneg j hj) + have hsum_one : (∑ j, weights q j) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_le' : (∑ j ∈ others q, weights q j) ≤ 1 := by + simpa [hsum_one] using hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ margin := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (margin : Real) := by + exact ratToReal_nonneg_of_nonneg hnonneg + have hbound : + ∀ j ∈ others q, + weights q j ≤ (1 + (margin : Real))⁻¹ := by + intro j hj + have hjne : j ≠ inputs.prev q := (Finset.mem_erase.mp hj).1 + have hscore := hscore_margin_real q hq j hjne + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := inputs.prev q) (k := j) (m := (margin : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ j ∈ others q, weights q j) ≤ + ∑ j ∈ others q, (1 + (margin : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ j ∈ others q, (1 + (margin : Real))⁻¹) = + (others q).card * (1 + (margin : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 
:= by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ j ∈ others q, weights q j) ≤ + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have hpos : (0 : Rat) < 1 + margin := by + have hone : (0 : Rat) < 1 := by + exact zero_lt_one + have hle : (1 : Rat) ≤ 1 + margin := by + exact le_add_of_nonneg_right hnonneg + exact lt_of_lt_of_le hone hle + have hden : (1 + margin) ≠ 0 := by + exact ne_of_gt hpos + have hrat' := ratDivUp_ge_real (seq - 1) (1 + margin) hden + have heps : + (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : Real) := by + simpa [eps, hneg, ratToReal, Rat.cast_div, Rat.cast_add, + Rat.cast_natCast, div_eq_mul_inv] using hrat' + exact le_trans hsum_le' heps + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ j ∈ others q, 0 ≤ weights q j := by + intro j _ + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) j) + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + have hepsAt : + ∀ q, epsAt q = + if marginAt q < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + marginAt q) := by + intro q + rfl + have oneHot_bounds_at : + ∀ q, q ∈ inputs.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) inputs.prev weights := by + intro q hq + exact + Sound.oneHot_bounds_at_of_marginAt + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (marginAt := marginAt) + (epsAt := epsAt) + (hepsAt := hepsAt) + (hseq := hseq) + (hscore_margin_real_at := hscore_margin_real_at) + q hq + have hdirHead : + dirHead = fun d => (dirHeadVecOfInputs inputs).get d := by + simp [dirHead, dirHeadVec] + have hwvDir : + ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := 
by + intro j + simp [wvDir, Bounds.cacheBoundTask_apply] + have hbDir : + bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d) := by + rfl + have hdir_wv : + ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := + wvDir_real_eq_sum inputs dirHead wvDir hwvDir + have hdir_bv : + (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := + bDir_real_eq_sum inputs dirHead bDir hbDir + have hvals_eq : + ∀ k, + valsRealOfInputs inputs k = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + (bDir : Real) := + valsReal_eq_of_dir inputs dirHead hdirHead wvDir bDir hdir_wv hdir_bv + have hvals_bounds_at : + ∀ k, + (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k ∧ + valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by + intro k + have hdot_abs : + |dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k)| ≤ + (valsAbsBase : Real) := by + have hdot := hdot_abs_bound_sum (fun j => wvDir j) k + have hln_max_real : + (lnAbsMax k : Real) ≤ (lnAbsMaxMax : Real) := + ratToReal_le_of_le (hln_abs_max k) + have hsum_nonneg : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Real) := by + have hsum_nonneg' : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Rat) := by + have hsum_nonneg'' : 0 ≤ ∑ j, |wvDir j| := by + refine Finset.sum_nonneg ?_ + intro j _ + exact abs_nonneg _ + simpa [Linear.sumFin_eq_sum_univ] using hsum_nonneg'' + exact ratToReal_nonneg_of_nonneg hsum_nonneg' + have hmul : + (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMax k : Real) ≤ + (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMaxMax : Real) := + mul_le_mul_of_nonneg_left hln_max_real hsum_nonneg + have hfinal := hdot.trans hmul + simpa [valsAbsBase, ratToReal_mul] using hfinal + have hdot_bounds := (abs_le).1 hdot_abs + have hlow' := add_le_add_right hdot_bounds.1 (bDir : Real) + have hhigh' := add_le_add_right hdot_bounds.2 (bDir : Real) + have hlow : + (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by 
+ simpa [valCert, valsLo, valsLoBase, valsAbsBase, hvals_eq k, ratToReal_sub, + sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using hlow' + have hhigh : + valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by + simpa [valCert, valsHi, valsHiBase, valsAbsBase, hvals_eq k, ratToReal_add, + add_comm, add_left_comm, add_assoc] using hhigh' + exact ⟨hlow, hhigh⟩ + have hvals_bounds : + ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by + refine + { lo_le_hi := ?_ + lo_le_valsLo := ?_ + vals_bounds := ?_ + valsHi_le_hi := ?_ } + · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ + have hmem0 : k0 ∈ univ := hk0 + have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by + have hloRat : valCert.lo ≤ valCert.valsLo k0 := by + change lo ≤ valsLo k0 + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k0)).2 ?_ + refine ⟨k0, hmem0, ?_⟩ + exact le_rfl + exact ratToReal_le_of_le hloRat + have hvals : + (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ + valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by + exact hvals_bounds_at k0 + have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by + have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by + change valsHi k0 ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k0)).2 ?_ + exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ + exact ratToReal_le_of_le hhiRat + have hreal : + (valCert.lo : Real) ≤ (valCert.hi : Real) := + le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) + exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal + · intro k + have hmem : k ∈ univ := by simp [univ] + have hloRat : valCert.lo ≤ valCert.valsLo k := by + change lo ≤ valsLo k + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k)).2 ?_ + refine ⟨k, hmem, ?_⟩ + exact le_rfl + exact ratToReal_le_of_le hloRat + · intro k + exact 
hvals_bounds_at k + · intro k + have hmem : k ∈ univ := by simp [univ] + have hhiRat : valCert.valsHi k ≤ valCert.hi := by + change valsHi k ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k)).2 ?_ + refine ⟨k, hmem, ?_⟩ + exact le_rfl + exact ratToReal_le_of_le hhiRat + exact + { softmax_bounds := hsoftmax_bounds + oneHot_bounds_at := oneHot_bounds_at + value_bounds := hvals_bounds } + · have : False := by + simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] at hcore + exact this.elim + · have : False := by + simp [buildInductionCertFromHeadCore?, hEps, hSqrt] at hcore + exact this.elim + · have : False := by + simp [buildInductionCertFromHeadCore?, hEps] at hcore + exact this.elim +end Sound +end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Values.lean b/Nfp/Sound/Induction/CoreSound/Values.lean new file mode 100644 index 0000000..457b658 --- /dev/null +++ b/Nfp/Sound/Induction/CoreSound/Values.lean @@ -0,0 +1,162 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Nfp.Sound.Induction.CoreDefs +import Nfp.Sound.Linear.FinFold + +/-! +Helper lemmas for value-direction bounds in induction-head soundness. + +These isolate the algebra needed to rewrite direction-value projections into +dot products over cached `wvDir`/`bDir` terms. +-/ + +namespace Nfp + +namespace Sound + +open scoped BigOperators + +open Nfp.Sound.Linear + +variable {seq dModel dHead : Nat} + +/-- Cast a cached `wvDir` dot to a Real-valued sum over head weights. 
-/ +theorem wvDir_real_eq_sum (inputs : Model.InductionHeadInputs seq dModel dHead) + (dirHead : Fin dHead → Rat) (wvDir : Fin dModel → Rat) + (hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) : + ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by + intro j + have hsum : + ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) = + ∑ d, ((dirHead d * inputs.wv j d : Rat) : Real) := by + simp + have hsum' : + ∑ d, ((dirHead d * inputs.wv j d : Rat) : Real) = + ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by + refine Finset.sum_congr rfl ?_ + intro d _ + simp + have hfinal := hsum.trans hsum' + calc + (wvDir j : Real) + = ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) := by + simp [hwvDir j, Linear.dotFin_eq_dotProduct, dotProduct] + _ = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := hfinal + +/-- Cast a cached `bDir` dot to a Real-valued sum over head biases. -/ +theorem bDir_real_eq_sum (inputs : Model.InductionHeadInputs seq dModel dHead) + (dirHead : Fin dHead → Rat) (bDir : Rat) + (hbDir : bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d)) : + (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by + have hsum : + ((∑ d, dirHead d * inputs.bv d : Rat) : Real) = + ∑ d, ((dirHead d * inputs.bv d : Rat) : Real) := by + simp + have hsum' : + ∑ d, ((dirHead d * inputs.bv d : Rat) : Real) = + ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by + refine Finset.sum_congr rfl ?_ + intro d _ + simp + have hfinal := hsum.trans hsum' + calc + (bDir : Real) + = ((∑ d, dirHead d * inputs.bv d : Rat) : Real) := by + simp [hbDir, Linear.dotFin_eq_dotProduct, dotProduct] + _ = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := hfinal + +/-- Rewrite direction values using cached `wvDir` and `bDir` sums. 
-/ +theorem valsReal_eq_of_dir (inputs : Model.InductionHeadInputs seq dModel dHead) + (dirHead : Fin dHead → Rat) + (hdirHead : dirHead = fun d => (dirHeadVecOfInputs inputs).get d) + (wvDir : Fin dModel → Rat) (bDir : Rat) + (hdir_wv : ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) + (hdir_bv : (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real)) : + ∀ k, + valsRealOfInputs inputs k = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + (bDir : Real) := by + intro k + classical + have hdirHead_real : + (fun d => (dirHeadVecOfInputs inputs).get d : Fin dHead → Real) = + fun d => (dirHead d : Real) := by + funext d + simp [hdirHead] + have hdot_add : + dotProduct (fun d => (dirHead d : Real)) + (fun d => + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k) + + (inputs.bv d : Real)) = + dotProduct (fun d => (dirHead d : Real)) + (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) + + dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) := by + simp [dotProduct, mul_add, Finset.sum_add_distrib] + have hdot_wv : + dotProduct (fun d => (dirHead d : Real)) + (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) := by + calc + dotProduct (fun d => (dirHead d : Real)) + (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) = + ∑ d, (dirHead d : Real) * ∑ j, + (inputs.wv j d : Real) * lnRealOfInputs inputs k j := by + simp [dotProduct] + _ = ∑ d, ∑ j, + (dirHead d : Real) * + ((inputs.wv j d : Real) * lnRealOfInputs inputs k j) := by + simp [Finset.mul_sum] + _ = ∑ j, ∑ d, + (dirHead d : Real) * + ((inputs.wv j d : Real) * lnRealOfInputs inputs k j) := by + simpa using + (Finset.sum_comm (s := (Finset.univ : Finset (Fin dHead))) + (t := (Finset.univ : Finset (Fin dModel))) + (f := fun d j => + (dirHead d : Real) * ((inputs.wv j 
d : Real) * lnRealOfInputs inputs k j))) + _ = ∑ j, (∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) * + lnRealOfInputs inputs k j := by + refine Finset.sum_congr rfl ?_ + intro j _ + have hsum : + (∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) * + lnRealOfInputs inputs k j = + ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) * + lnRealOfInputs inputs k j := by + simp [Finset.sum_mul, mul_assoc] + simpa [mul_assoc] using hsum.symm + _ = ∑ j, (wvDir j : Real) * lnRealOfInputs inputs k j := by + refine Finset.sum_congr rfl ?_ + intro j _ + simp [hdir_wv j] + _ = dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) := by + simp [dotProduct] + calc + valsRealOfInputs inputs k = + dotProduct (fun d => (dirHead d : Real)) + (fun d => + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k) + + (inputs.bv d : Real)) := by + simp [valsRealOfInputs, vRealOfInputs, hdirHead_real] + _ = + dotProduct (fun d => (dirHead d : Real)) + (fun d => dotProduct (fun j => (inputs.wv j d : Real)) + (lnRealOfInputs inputs k)) + + dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) := hdot_add + _ = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) := by + simp [hdot_wv] + _ = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + (bDir : Real) := by + have hb : + dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) = + (bDir : Real) := by + calc + dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) + = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by + simp [dotProduct] + _ = (bDir : Real) := hdir_bv.symm + simp [hb] + +end Sound +end Nfp diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index d14fe03..649f108 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -1,6 +1,6 @@ -- SPDX-License-Identifier: 
AGPL-3.0-or-later -import Nfp.Sound.Induction.Core +import Nfp.Sound.Induction.CoreSound /-! Head-output interval certificates for induction heads. From 322621e026155b1163240518b15736f66426b49e Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 10 Jan 2026 08:12:43 +0100 Subject: [PATCH 121/244] Tighten dot-diff bounds for active queries --- Nfp/Sound/Induction/Core.lean | 51 +++++++++++++--------- Nfp/Sound/Induction/CoreSound.lean | 70 +++++++++++++++++------------- 2 files changed, 72 insertions(+), 49 deletions(-) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 8d9b23d..6897d03 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -238,9 +238,10 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} inputs.maskCausal = true ∧ q < k let splitBudgetQ : Nat := 8 let splitBudgetK : Nat := 8 + let splitBudgetDiff : Nat := 6 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := - (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) + (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d let step (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) @@ -262,11 +263,11 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} | (some b1, none) => [b1.2] | (none, _) => [] let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) (dims1 ++ dims2).take splitBudgetQ let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => let ambig := - (List.finRange dHead).filter (fun d => kLo k d < 0 ∧ 0 < kHi k d) + (List.finRange dHead).filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d let step (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) @@ -288,14 +289,14 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} | (some b1, none) => [b1.2] | (none, _) => [] let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) (dims1 ++ dims2).take splitBudgetK let splitDimsDiff : Fin seq → Fin seq → List (Fin dHead) := fun q k => let prev := inputs.prev q let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d let ambig := - (List.finRange dHead).filter (fun d => diffLo d < 0 ∧ 0 < diffHi d) + (List.finRange dHead).filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d let step (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) @@ -316,9 +317,16 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} | (some b1, some b2) => [b1.2, b2.2] | (some b1, none) => [b1.2] | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) - (dims1 ++ dims2).take splitBudgetK + let dims1 : List (Fin dHead) := top2 ambig + let dims2 : List (Fin dHead) := + top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let memDims2 : Fin dHead → Bool := fun d => + dims2.any (fun d' => decide (d' = d)) + let dims3 : List (Fin dHead) := + top2 + ((ambig.filter (fun d => decide (d ∉ dims1))).filter + (fun d => !memDims2 d)) + (dims1 ++ dims2 ++ dims3).take splitBudgetDiff let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => @@ -332,18 +340,21 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} let dotDiffRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsDiff := splitDimsDiff q k - let prev := inputs.prev q - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)), - by simp⟩)) + if hq : q ∈ inputs.active then + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsDiff := splitDimsDiff q k + let prev := inputs.prev q + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)), + by simp⟩ + else + ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) let dotLo : Fin seq → Fin seq → Rat := fun q k => let row := (dotRowTasks[q.1]'(by simp [dotRowTasks, q.isLt])).get diff --git a/Nfp/Sound/Induction/CoreSound.lean 
b/Nfp/Sound/Induction/CoreSound.lean index cc30268..c3075c3 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -147,9 +147,10 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} inputs.maskCausal = true ∧ q < k let splitBudgetQ : Nat := 8 let splitBudgetK : Nat := 8 + let splitBudgetDiff : Nat := 6 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := - (List.finRange dHead).filter (fun d => qLo q d < 0 ∧ 0 < qHi q d) + (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d let step (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) @@ -171,11 +172,11 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} | (some b1, none) => [b1.2] | (none, _) => [] let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) (dims1 ++ dims2).take splitBudgetQ let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => let ambig := - (List.finRange dHead).filter (fun d => kLo k d < 0 ∧ 0 < kHi k d) + (List.finRange dHead).filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d let step (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) @@ -197,14 +198,14 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} | (some b1, none) => [b1.2] | (none, _) => [] let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) (dims1 ++ dims2).take splitBudgetK let splitDimsDiff : Fin seq → Fin seq → List (Fin dHead) := fun q k => let prev := inputs.prev q let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d let ambig := - 
(List.finRange dHead).filter (fun d => diffLo d < 0 ∧ 0 < diffHi d) + (List.finRange dHead).filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d let step (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) @@ -225,9 +226,16 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} | (some b1, some b2) => [b1.2, b2.2] | (some b1, none) => [b1.2] | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => d ∉ dims1)) - (dims1 ++ dims2).take splitBudgetK + let dims1 : List (Fin dHead) := top2 ambig + let dims2 : List (Fin dHead) := + top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let memDims2 : Fin dHead → Bool := fun d => + dims2.any (fun d' => decide (d' = d)) + let dims3 : List (Fin dHead) := + top2 + ((ambig.filter (fun d => decide (d ∉ dims1))).filter + (fun d => !memDims2 d)) + (dims1 ++ dims2 ++ dims3).take splitBudgetDiff let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => @@ -241,18 +249,21 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let dotDiffRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsDiff := splitDimsDiff q k - let prev := inputs.prev q - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)), - by simp⟩)) + if hq : q ∈ inputs.active then + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsDiff := splitDimsDiff q k + let prev := inputs.prev q + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => 
qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)), + by simp⟩ + else + ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) let dotLo : Fin seq → Fin seq → Rat := fun q k => let row := (dotRowTasks[q.1]'(by simp [dotRowTasks, q.isLt])).get @@ -371,7 +382,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} invStdBoundsArr, invStdLo, invStdHi, qBaseArr, qBase, kBaseArr, kBase, qCoeffRowTasks, qCoeffArr, qCoeff, kCoeffRowTasks, kCoeffArr, kCoeff, qLo, qHi, kLo, kHi, qAbs, kAbs, qAbsMaxArr, qAbsMax, kAbsMaxArr, kAbsMax, masked, splitBudgetQ, - splitBudgetK, splitDimsQ, splitDimsK, splitDimsDiff, dotRowTasks, dotDiffRowTasks, + splitBudgetK, splitBudgetDiff, splitDimsQ, splitDimsK, splitDimsDiff, dotRowTasks, + dotDiffRowTasks, dotLo, dotHi, dotDiffLo, dotDiffHi, dotAbs, scoreBaseAbs, scoreLo, scoreHi, scoreLoPrev, scoreGapLo, otherKeys, marginAt, epsAt, margin, eps, dirHeadVec, dirHead, wvDir, bDir, valsAbsBase, valsLoBase, valsHiBase, valsLo, valsHi, univ, lo, @@ -757,7 +769,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] using hhigh have hdot_diff_bounds : - ∀ q k, ¬ masked q k → + ∀ q, q ∈ inputs.active → ∀ k, ¬ masked q k → (dotDiffLo q k : Real) ≤ dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - @@ -765,14 +777,14 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - intro q k hmask - have hq := hq_bounds q + intro q hq k hmask + have hq_bounds' := hq_bounds q have hkprev := hk_bounds (inputs.prev q) have hk := hk_bounds k have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq d).1 + (hq_bounds' d).1 have hhi1 : ∀ d, 
qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq d).2 + (hq_bounds' d).2 have hlo2 : ∀ d, (kLo (inputs.prev q) d - kHi k d : Rat) ≤ @@ -806,13 +818,13 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, dotDiffRowTasks, hmask, Task.spawn, Array.getElem_ofFn] + simpa [dotDiffLo, dotDiffRowTasks, hq, hmask, Task.spawn, Array.getElem_ofFn] using hspec.1 have hhigh' : dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, dotDiffRowTasks, hmask, Task.spawn, Array.getElem_ofFn] + simpa [dotDiffHi, dotDiffRowTasks, hq, hmask, Task.spawn, Array.getElem_ofFn] using hspec.2 exact ⟨hlow', hhigh'⟩ let scoresReal := scoresRealOfInputs inputs @@ -878,7 +890,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} _ = (scoreLoPrev q : Real) := by simp [add_sub_cancel_left] _ ≤ scoresReal q (inputs.prev q) := hscore_prev - · have hdiff := hdot_diff_bounds q k hmask + · have hdiff := hdot_diff_bounds q hq k hmask have hgap_le : (scoreGapLo q k : Real) ≤ (inputs.scale : Real) * From 879aed24dda42634ce65e70f10f9b91db9a2e143 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 10 Jan 2026 11:30:26 +0100 Subject: [PATCH 122/244] Update agent docs and module map --- AGENTS.md | 274 +++++++------------------------------------------- MODULE_MAP.md | 220 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+), 240 deletions(-) create mode 100644 MODULE_MAP.md diff --git a/AGENTS.md b/AGENTS.md index 1a3f1c1..be8613e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -16,7 +16,7 @@ but keep the core invariants and the “no fake proofs” ethos. ## 0. 
Quick Start (What to run) -### Build (warnings are errors) +### Build - `lake build --wfail` ### Build the CLI @@ -27,15 +27,12 @@ One of these typically works (depending on your Lake setup): - `lake exe nfp --help` If you add or change CLI behavior, validate at least: -- `nfp --help` -- `nfp analyze --help` -- `nfp induction --help` -- `nfp --version` (if supported) - -Before you finish any change: -- `lake build --wfail` -- `lake build nfp --wfail` +- `lake exe nfp --help` (or `nfp --help` if on PATH) +- `lake exe nfp analyze --help` (or `nfp analyze --help`) +- `lake exe nfp induction --help` (or `nfp induction --help`) +- `lake exe nfp --version` (or `nfp --version`) if supported +### Search tips Note: `models/` is gitignored, so `rg` will skip it unless you pass `--no-ignore` or `-uuu` (or equivalent) when searching. @@ -45,12 +42,14 @@ or `-uuu` (or equivalent) when searching. ### 1.1 No fake proofs - **Forbidden:** `sorry` -- **Forbidden:** introducing new nontrivial axioms beyond what mathlib already uses. ### 1.2 Linting stays on - **Never** disable linters globally or locally. - **Forbidden:** any `set_option linter.* false` (including e.g. `linter.unnecessarySimpa`). - Fix the code/proofs instead. +- If linters warn about line length or file length, prefer principled refactors + (split modules, extract helpers) and keep docstrings with their code; avoid + squashing whitespace or formatting. ### 1.3 Clean build - `lake build --wfail` must succeed. @@ -65,14 +64,15 @@ The library’s claims rest on these being preserved (preferably with explicit l - Finiteness assumptions (`[Fintype _]`) are used intentionally and consistently ### 1.5 Trusted Code Verification (Total Soundness) -**All code** in trusted namespaces (e.g., `Nfp.Sound.*`) must be **verified**. +**All code** in trusted namespaces (see §6) must be **verified**. - **Requirement:** Every pure definition in the trusted scope must be characterized by a theorem or return a proof-carrying structure. 
- *Example (Bad):* `def addOne (x : Nat) := x + 1` (Unverified logic) - *Example (Good):* `def addOne (x : Nat) : { y // y > x } := ⟨x + 1, Nat.lt_succ_self _⟩` - *Example (Good):* `def addOne ...` followed immediately by `theorem addOne_gt_input ...` - **Scope:** This applies to **everything**: parsers, converters, arithmetic helpers, and bound computations. -- **IO Exception:** Low-level IO primitives (reading bytes/files) cannot be "proven" correct but must be kept **logic-free**. +- **IO Exception:** Low-level IO primitives (reading bytes/files) cannot be "proven" + correct but must be kept **logic-free**. - IO code should only read data and pass it to verified Pure code. - No mathematical transformations or complex branching allowed in IO functions. @@ -90,15 +90,24 @@ The library’s claims rest on these being preserved (preferably with explicit l (or whether it belongs in a more general file in this repo). ### 2.3 Verify, Don't Trust -- Distinguish between **witness generation** (untrusted, can use heuristics) and **verification** (trusted, must contain proofs). -- The trusted kernel should only check that a candidate witness is valid; it should not be responsible for finding it if the search is complex. +- Distinguish between **witness generation** (untrusted, can use heuristics) and + **verification** (trusted, must contain proofs). +- The trusted kernel should only check that a candidate witness is valid; it + should not be responsible for finding it if the search is complex. + +### 2.4 Prefer principled redesigns +When forced to choose between: +- “slightly breaking but conceptually clean redesign” +- vs “preserve an awkward design forever” + +prefer the **clean redesign**, but do it consciously and document the rationale. --- ## 3. Workflow Expectations (How to make changes) ### 3.1 Before coding -- Identify the right module (see §5 Module Map). +- Identify the right module (see `MODULE_MAP.md`). - Skim the top docstring / main definitions in that module. 
- Look for existing lemmas and naming patterns to match. @@ -112,10 +121,7 @@ The library’s claims rest on these being preserved (preferably with explicit l - small lemmas, smaller proof terms, fewer global simp rules. ### 3.3 After coding -- Ensure `lake build --wfail` passes. -- Ensure no `sorry`. -- Ensure no linter toggles were introduced. -- If you changed module responsibilities/structure, update §5 in the same commit. +- Ensure the Definition of Done checklist is satisfied. --- @@ -138,223 +144,13 @@ The library’s claims rest on these being preserved (preferably with explicit l - You may do nontrivial refactors to improve conceptual cleanliness. - If you rename/reshape core APIs: - update all call sites, - - leave a brief comment (or commit message rationale), - - keep the module map (§5) accurate. + - leave a brief comment (or commit message rationale). --- ## 5. Module Map (Where Things Live) -This is a *map*, not a prison. You may reshuffle if a better design emerges, -but you **must** update this list in the same commit. - -### 5.1 Core types -- `Nfp/Core/Basic.lean` - - `Mass` alias for nonnegative weights used throughout the rewrite. -- `Nfp/Core.lean` - - Aggregator for core shared definitions. - -### 5.2 Probability vectors -- `Nfp/Prob/Basic.lean` - - `ProbVec` definition + invariants. -- `Nfp/Prob/Operations.lean` - - `pure`, `mix`, and basic lemmas. -- `Nfp/Prob.lean` - - Aggregator for probability modules. - -### 5.3 Mixers -- `Nfp/Mixer/Basic.lean` - - `Mixer` structure and row-stochastic invariant. -- `Nfp/Mixer/Operations.lean` - - `push`, `comp`, and `id` mixers. -- `Nfp/Mixer.lean` - - Aggregator for mixer modules. - -### 5.4 Systems (DAG + local mixing) -- `Nfp/System/Dag.lean` - - DAG relation + parent/child sets. -- `Nfp/System/LocalSystem.lean` - - `LocalSystem` with edge support, row-stochastic predicate, and evaluation semantics. -- `Nfp/System.lean` - - Aggregator for system modules. 
- -### 5.5 Circuits (certification core) -- `Nfp/Circuit/Basic.lean` - - DAG-based circuit structure with inputs/outputs and gate semantics. -- `Nfp/Circuit/Combinators.lean` - - Core circuit combinators (relabeling, interface transport). -- `Nfp/Circuit/Interface.lean` - - Typed input/output interfaces and interface-based evaluation. -- `Nfp/Circuit/Semantics.lean` - - Well-founded evaluation semantics for circuits. -- `Nfp/Circuit/WellFormed.lean` - - Basic well-formedness conditions for circuit inputs. -- `Nfp/Circuit/Cert.lean` - - Equivalence definition and finite checker. -- `Nfp/Circuit/Cert/SoftmaxMargin.lean` - - Softmax-margin certificate payloads and checker soundness. -- `Nfp/Circuit/Cert/ValueRange.lean` - - Value-range certificate payloads and checker soundness. -- `Nfp/Circuit/Cert/LogitDiff.lean` - - Logit-diff lower-bound computation for induction certificates. -- `Nfp/Circuit/Cert/DownstreamLinear.lean` - - Downstream linear error certificates for end-to-end induction bounds. -- `Nfp/Circuit/Cert/ResidualBound.lean` - - Residual-stream bound certificates for downstream error computation. -- `Nfp/Circuit/Cert/ResidualInterval.lean` - - Residual-stream interval certificates for downstream dot-product bounds. -- `Nfp/Circuit/Typed.lean` - - Typed circuit wrapper and interface-level equivalence checker. -- `Nfp/Circuit/Compose.lean` - - Sequential composition and residual wiring for typed circuits. -- `Nfp/Circuit/Gates/Basic.lean` - - Basic gate combinators for aggregating parent values. -- `Nfp/Circuit/Gates/Linear.lean` - - Linear and affine gate combinators built from `Matrix.mulVec`. -- `Nfp/Circuit/Gates.lean` - - Aggregator for gate combinator modules. -- `Nfp/Circuit/Tensor.lean` - - Typed tensor indices and tensor aliases. -- `Nfp/Circuit/Layers/Linear.lean` - - Linear/affine layer circuits with typed interfaces. -- `Nfp/Circuit/Layers/Tensor.lean` - - Batched linear/affine layer circuits for tensor-shaped data. 
-- `Nfp/Circuit/Layers/Reshape.lean` - - Reshape combinators for product-typed circuit interfaces. -- `Nfp/Circuit/Layers/Heads.lean` - - Head split/merge combinators for transformer-shaped indices. -- `Nfp/Circuit/Layers/Softmax.lean` - - Softmax helpers and margin-based bounds for layer reasoning. -- `Nfp/Circuit/Layers/Attention.lean` - - Q/K/V, output projection wiring, and attention score/mixing core. -- `Nfp/Circuit/Layers/Induction.lean` - - Induction-head weight specs and attention-core output lemmas. -- `Nfp/Circuit/Layers/TransformerBlock.lean` - - GPT-style transformer block wiring from LN/attention/MLP circuits. -- `Nfp/Circuit/Layers.lean` - - Aggregator for circuit layer modules. -- `Nfp/Circuit.lean` - - Aggregator for circuit modules. - -### 5.6 CLI surface -- `Nfp/IO/Pure.lean` - - Aggregator for pure parsing helpers. -- `Nfp/IO/Pure/Basic.lean` - - Shared parsing helpers (`Nat`/`Int`/`Rat`, token cleanup). -- `Nfp/IO/Pure/InductionHead.lean` - - Induction-head input payload parsing from text/bytes. -- `Nfp/IO/Pure/InductionHead/Bytes.lean` - - Byte-level parser for induction-head input payloads. -- `Nfp/IO/Pure/SoftmaxMargin.lean` - - Aggregator for softmax-margin parsing helpers. -- `Nfp/IO/Pure/SoftmaxMargin/Shared.lean` - - Shared parsing helpers for softmax-margin payloads. -- `Nfp/IO/Pure/SoftmaxMargin/Cert.lean` - - Softmax-margin certificate parser. -- `Nfp/IO/Pure/SoftmaxMargin/Raw.lean` - - Softmax-margin raw-input parser. -- `Nfp/IO/Pure/ValueRange.lean` - - Aggregator for value-range parsing helpers. -- `Nfp/IO/Pure/ValueRange/Shared.lean` - - Shared parsing helpers for value-range payloads. -- `Nfp/IO/Pure/ValueRange/Cert.lean` - - Value-range certificate parser. -- `Nfp/IO/Pure/ValueRange/Raw.lean` - - Value-range raw-input parser. -- `Nfp/IO/Pure/Downstream.lean` - - Downstream linear and matrix payload parsers. -- `Nfp/IO/Pure/Residual.lean` - - Residual-bound and residual-interval payload parsers. 
-- `Nfp/IO/NfptPure.lean` - - Pure parsing helpers for `NFP_BINARY_V1` model slices. -- `Nfp/IO/HeadScore.lean` - - Pure task-based cache builder for head score dot-abs bounds. -- `Nfp/IO/Loaders.lean` - - IO loaders for certificates and raw inputs. -- `Nfp/IO/Checks.lean` - - IO checks for certificate validity. -- `Nfp/IO/Derive.lean` - - IO derivations building certificates from model binaries. -- `Nfp/IO/Timing.lean` - - IO timing helpers with microsecond reporting and phase wrappers. -- `Nfp/IO/Util.lean` - - Small CLI parsing utilities shared across IO entrypoints. -- `Nfp/IO/InductionHead.lean` - - Induction-head IO pipeline with timing instrumentation. -- `Nfp/IO/Bench/Rational.lean` - - Microbenchmarks for rational arithmetic and caching. -- `Nfp/IO.lean` - - IO-only wrappers for loading inputs and running checks. -- `Nfp/Cli.lean` - - CLI commands and `main` implementation. -- `Main.lean` - - Thin entrypoint delegating to `Nfp.Cli.main`. - - Benchmark entrypoint for rational microbenchmarks. -- `Nfp.lean` - - Top-level reexports. -- `TheoremAxioms.lean` - - Axiom dashboard for `theorem-axioms` build target (`#print axioms`). - -### 5.7 Sound certification -- `Nfp/Sound/Induction.lean` - - Aggregator for induction soundness modules. -- `Nfp/Sound/Induction/Core.lean` - - Sound builders and core proofs for induction certificates from exact inputs. -- `Nfp/Sound/Induction/CoreSound.lean` - - Soundness proof for `buildInductionCertFromHeadCore?`. -- `Nfp/Sound/Induction/CoreSound/Values.lean` - - Helper lemmas for value-direction projections in the core soundness proof. -- `Nfp/Sound/Induction/CoreDefs.lean` - - Core definitions and soundness predicates for induction certificates. -- `Nfp/Sound/Induction/HeadOutput.lean` - - Head-output interval certificates built from induction head inputs. -- `Nfp/Sound/Induction/HeadBounds.lean` - - Helper bounds used to stage head-induction certificate construction. 
-- `Nfp/Sound/Bounds/Cache.lean` - - Cached bound evaluators (thunk/task backed) for interval computations. -- `Nfp/Sound/Bounds/MatrixNorm.lean` - - Row-sum matrix norms and downstream linear certificate builders. -- `Nfp/Sound/Bounds/MatrixNorm/Interval.lean` - - Dot-product and matrix-vector interval bounds (rational and real). -- `Nfp/Sound/Bounds/LayerNorm.lean` - - LayerNorm interval bounds and end-to-end soundness lemmas. -- `Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean` - - Mean/variance helpers for LayerNorm bounds. -- `Nfp/Sound/Bounds/LayerNorm/InvStd.lean` - - Inverse-standard-deviation bounds for LayerNorm. -- `Nfp/Sound/Bounds/UnnormRat.lean` - - Unnormalized rational helpers for deferred normalization in bounds kernels. -- `Nfp/Sound/Bounds/Gelu.lean` - - Tanh-GELU bounds for interval propagation through MLPs. -- `Nfp/Sound/Bounds/Mlp.lean` - - Interval bounds for GPT-2 MLP blocks and LayerNorm composition. -- `Nfp/Sound/Bounds/Attention.lean` - - Interval bounds for multi-head attention and transformer layers. -- `Nfp/Sound/Bounds/Transformer.lean` - - Interval bounds for transformer stacks and final LayerNorm outputs. -- `Nfp/Sound/Bounds/Transformer/Embedding.lean` - - Embedding interval bounds and position-restricted bounds. -- `Nfp/Sound/Linear/FinFold.lean` - - Tail-recursive folds and sums for sound linear computations. -- `Nfp/Sound/Gpt2/HeadInputs.lean` - - Sound construction of GPT-2 induction head inputs. -- `Nfp/Sound.lean` - - Aggregator for sound certification modules. - -### 5.8 Model inputs -- `Nfp/Model/InductionHead.lean` - - Exact induction-head input payloads (embeddings and projection weights). -- `Nfp/Model/InductionPrompt.lean` - - Prompt utilities (`prev` map and active set for periodic prompts). -- `Nfp/Model/Gpt2.lean` - - Exact GPT-2 head-slice data, layer/MLP/LayerNorm parameters, and embedding helpers. -- `Nfp/Model.lean` - - Aggregator for model input modules. 
- -If you introduce a new conceptual layer: -- either extend the closest existing file, -- or add a new module with a clear name + top docstring, -- and update this map in the same commit. +The module map lives in `MODULE_MAP.md`. --- @@ -364,6 +160,9 @@ This repo treats “axioms creep” as a serious regression. - Do not add axioms. - Keep an eye on classical assumptions; they may be unavoidable, but should be explicit. +- Trusted namespaces are `Nfp.Sound.*`, `Nfp.IO.Pure.*`, and `Nfp.IO.NfptPure`. + If another module is intended to be trusted, say so explicitly in its docstring + and treat it as in-scope here. - Use `TheoremAxioms.lean` / `lake build theorem-axioms --wfail` as the trust dashboard for `#print axioms` / dependency visibility. @@ -374,15 +173,10 @@ This repo treats “axioms creep” as a serious regression. - [ ] `lake build --wfail` succeeds. - [ ] No `sorry`. - [ ] No new axioms were introduced. -- [ ] **Total Soundness:** Every pure definition in the trusted section is verified/proven. +- [ ] **Total Soundness:** Every pure definition in trusted namespaces is verified/proven. - [ ] No linters were disabled (`set_option linter.* false` is absent). - [ ] New nontrivial definitions/theorems have short, accurate docstrings. -- [ ] Core invariants (nonnegativity, normalization, finiteness, acyclicity) are preserved and, where possible, explicitly proved. -- [ ] §5 Module Map is accurate (updated in the same commit if needed). +- [ ] Core invariants (nonnegativity, normalization, finiteness, acyclicity) are + preserved and, where possible, explicitly proved. +- [ ] Module map in `MODULE_MAP.md` is accurate (updated in the same commit if needed). - [ ] If CLI behavior changed: `lake build nfp --wfail` succeeds and basic `nfp ... --help` works. - -When forced to choose between: -- “slightly breaking but conceptually clean redesign” -- vs “preserve an awkward design forever” - -prefer the **clean redesign**, but do it consciously and document the rationale. 
diff --git a/MODULE_MAP.md b/MODULE_MAP.md new file mode 100644 index 0000000..9899b32 --- /dev/null +++ b/MODULE_MAP.md @@ -0,0 +1,220 @@ +# Module Map (Where Things Live) + +This is a *map*, not a prison. You may reshuffle if a better design emerges, +but you **must** update this list in the same commit. + +## Core types +- `Nfp/Core/Basic.lean` + - `Mass` alias for nonnegative weights used throughout the rewrite. +- `Nfp/Core.lean` + - Aggregator for core shared definitions. + +## Probability vectors +- `Nfp/Prob/Basic.lean` + - `ProbVec` definition + invariants. +- `Nfp/Prob/Operations.lean` + - `pure`, `mix`, and basic lemmas. +- `Nfp/Prob.lean` + - Aggregator for probability modules. + +## Mixers +- `Nfp/Mixer/Basic.lean` + - `Mixer` structure and row-stochastic invariant. +- `Nfp/Mixer/Operations.lean` + - `push`, `comp`, and `id` mixers. +- `Nfp/Mixer.lean` + - Aggregator for mixer modules. + +## Systems (DAG + local mixing) +- `Nfp/System/Dag.lean` + - DAG relation + parent/child sets. +- `Nfp/System/LocalSystem.lean` + - `LocalSystem` with edge support, row-stochastic predicate, and evaluation semantics. +- `Nfp/System.lean` + - Aggregator for system modules. + +## Circuits (certification core) +- `Nfp/Circuit/Basic.lean` + - DAG-based circuit structure with inputs/outputs and gate semantics. +- `Nfp/Circuit/Combinators.lean` + - Core circuit combinators (relabeling, interface transport). +- `Nfp/Circuit/Interface.lean` + - Typed input/output interfaces and interface-based evaluation. +- `Nfp/Circuit/Semantics.lean` + - Well-founded evaluation semantics for circuits. +- `Nfp/Circuit/WellFormed.lean` + - Basic well-formedness conditions for circuit inputs. +- `Nfp/Circuit/Cert.lean` + - Equivalence definition and finite checker. +- `Nfp/Circuit/Cert/SoftmaxMargin.lean` + - Softmax-margin certificate payloads and checker soundness. +- `Nfp/Circuit/Cert/ValueRange.lean` + - Value-range certificate payloads and checker soundness. 
+- `Nfp/Circuit/Cert/LogitDiff.lean` + - Logit-diff lower-bound computation for induction certificates. +- `Nfp/Circuit/Cert/DownstreamLinear.lean` + - Downstream linear error certificates for end-to-end induction bounds. +- `Nfp/Circuit/Cert/ResidualBound.lean` + - Residual-stream bound certificates for downstream error computation. +- `Nfp/Circuit/Cert/ResidualInterval.lean` + - Residual-stream interval certificates for downstream dot-product bounds. +- `Nfp/Circuit/Typed.lean` + - Typed circuit wrapper and interface-level equivalence checker. +- `Nfp/Circuit/Compose.lean` + - Sequential composition and residual wiring for typed circuits. +- `Nfp/Circuit/Gates/Basic.lean` + - Basic gate combinators for aggregating parent values. +- `Nfp/Circuit/Gates/Linear.lean` + - Linear and affine gate combinators built from `Matrix.mulVec`. +- `Nfp/Circuit/Gates.lean` + - Aggregator for gate combinator modules. +- `Nfp/Circuit/Tensor.lean` + - Typed tensor indices and tensor aliases. +- `Nfp/Circuit/Layers/Linear.lean` + - Linear/affine layer circuits with typed interfaces. +- `Nfp/Circuit/Layers/Tensor.lean` + - Batched linear/affine layer circuits for tensor-shaped data. +- `Nfp/Circuit/Layers/Reshape.lean` + - Reshape combinators for product-typed circuit interfaces. +- `Nfp/Circuit/Layers/Heads.lean` + - Head split/merge combinators for transformer-shaped indices. +- `Nfp/Circuit/Layers/Softmax.lean` + - Softmax helpers and margin-based bounds for layer reasoning. +- `Nfp/Circuit/Layers/Attention.lean` + - Q/K/V, output projection wiring, and attention score/mixing core. +- `Nfp/Circuit/Layers/Induction.lean` + - Induction-head weight specs and attention-core output lemmas. +- `Nfp/Circuit/Layers/TransformerBlock.lean` + - GPT-style transformer block wiring from LN/attention/MLP circuits. +- `Nfp/Circuit/Layers.lean` + - Aggregator for circuit layer modules. +- `Nfp/Circuit.lean` + - Aggregator for circuit modules. 
+ +## CLI surface +- `Nfp/IO/Pure.lean` + - Aggregator for pure parsing helpers. +- `Nfp/IO/Pure/Basic.lean` + - Shared parsing helpers (`Nat`/`Int`/`Rat`, token cleanup). +- `Nfp/IO/Pure/InductionHead.lean` + - Induction-head input payload parsing from text/bytes. +- `Nfp/IO/Pure/InductionHead/Bytes.lean` + - Byte-level parser for induction-head input payloads. +- `Nfp/IO/Pure/SoftmaxMargin.lean` + - Aggregator for softmax-margin parsing helpers. +- `Nfp/IO/Pure/SoftmaxMargin/Shared.lean` + - Shared parsing helpers for softmax-margin payloads. +- `Nfp/IO/Pure/SoftmaxMargin/Cert.lean` + - Softmax-margin certificate parser. +- `Nfp/IO/Pure/SoftmaxMargin/Raw.lean` + - Softmax-margin raw-input parser. +- `Nfp/IO/Pure/ValueRange.lean` + - Aggregator for value-range parsing helpers. +- `Nfp/IO/Pure/ValueRange/Shared.lean` + - Shared parsing helpers for value-range payloads. +- `Nfp/IO/Pure/ValueRange/Cert.lean` + - Value-range certificate parser. +- `Nfp/IO/Pure/ValueRange/Raw.lean` + - Value-range raw-input parser. +- `Nfp/IO/Pure/Downstream.lean` + - Downstream linear and matrix payload parsers. +- `Nfp/IO/Pure/Residual.lean` + - Residual-bound and residual-interval payload parsers. +- `Nfp/IO/NfptPure.lean` + - Pure parsing helpers for `NFP_BINARY_V1` model slices. +- `Nfp/IO/HeadScore.lean` + - Pure task-based cache builder for head score dot-abs bounds. +- `Nfp/IO/Loaders.lean` + - IO loaders for certificates and raw inputs. +- `Nfp/IO/Checks.lean` + - IO checks for certificate validity. +- `Nfp/IO/Derive.lean` + - IO derivations building certificates from model binaries. +- `Nfp/IO/Timing.lean` + - IO timing helpers with microsecond reporting and phase wrappers. +- `Nfp/IO/Util.lean` + - Small CLI parsing utilities shared across IO entrypoints. +- `Nfp/IO/InductionHead.lean` + - Induction-head IO pipeline with timing instrumentation. +- `Nfp/IO/Bench/Rational.lean` + - Microbenchmarks for rational arithmetic and caching. 
+- `Nfp/IO/Bench/InductionCore.lean` + - Benchmark helpers for induction-head core certification. +- `Nfp/IO/Bench/InductionCounts.lean` + - Call-count instrumentation for induction-head computations. +- `Nfp/IO.lean` + - IO-only wrappers for loading inputs and running checks. +- `Nfp/Cli.lean` + - CLI commands and `main` implementation. +- `Main.lean` + - Thin entrypoint delegating to `Nfp.Cli.main`. + - Benchmark entrypoint for rational microbenchmarks. +- `Nfp.lean` + - Top-level reexports. +- `TheoremAxioms.lean` + - Axiom dashboard for `theorem-axioms` build target (`#print axioms`). + +## Sound certification +- `Nfp/Sound/Induction.lean` + - Aggregator for induction soundness modules. +- `Nfp/Sound/Induction/Core.lean` + - Sound builders and core proofs for induction certificates from exact inputs. +- `Nfp/Sound/Induction/CoreSound.lean` + - Soundness proof for `buildInductionCertFromHeadCore?`. +- `Nfp/Sound/Induction/CoreSound/Values.lean` + - Helper lemmas for value-direction projections in the core soundness proof. +- `Nfp/Sound/Induction/CoreDefs.lean` + - Core definitions and soundness predicates for induction certificates. +- `Nfp/Sound/Induction/HeadOutput.lean` + - Head-output interval certificates built from induction head inputs. +- `Nfp/Sound/Induction/HeadBounds.lean` + - Helper bounds used to stage head-induction certificate construction. +- `Nfp/Sound/Induction/LogitDiff.lean` + - Logit-diff bounds derived from induction certificates. +- `Nfp/Sound/Induction/OneHot.lean` + - Per-query one-hot bounds derived from score margins. +- `Nfp/Sound/Bounds/Cache.lean` + - Cached bound evaluators (thunk/task backed) for interval computations. +- `Nfp/Sound/Bounds/MatrixNorm.lean` + - Row-sum matrix norms and downstream linear certificate builders. +- `Nfp/Sound/Bounds/MatrixNorm/Interval.lean` + - Dot-product and matrix-vector interval bounds (rational and real). 
+- `Nfp/Sound/Bounds/LayerNorm.lean` + - LayerNorm interval bounds and end-to-end soundness lemmas. +- `Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean` + - Mean/variance helpers for LayerNorm bounds. +- `Nfp/Sound/Bounds/LayerNorm/InvStd.lean` + - Inverse-standard-deviation bounds for LayerNorm. +- `Nfp/Sound/Bounds/UnnormRat.lean` + - Unnormalized rational helpers for deferred normalization in bounds kernels. +- `Nfp/Sound/Bounds/Gelu.lean` + - Tanh-GELU bounds for interval propagation through MLPs. +- `Nfp/Sound/Bounds/Mlp.lean` + - Interval bounds for GPT-2 MLP blocks and LayerNorm composition. +- `Nfp/Sound/Bounds/Attention.lean` + - Interval bounds for multi-head attention and transformer layers. +- `Nfp/Sound/Bounds/Transformer.lean` + - Interval bounds for transformer stacks and final LayerNorm outputs. +- `Nfp/Sound/Bounds/Transformer/Embedding.lean` + - Embedding interval bounds and position-restricted bounds. +- `Nfp/Sound/Linear/FinFold.lean` + - Tail-recursive folds and sums for sound linear computations. +- `Nfp/Sound/Gpt2/HeadInputs.lean` + - Sound construction of GPT-2 induction head inputs. +- `Nfp/Sound.lean` + - Aggregator for sound certification modules. + +## Model inputs +- `Nfp/Model/InductionHead.lean` + - Exact induction-head input payloads (embeddings and projection weights). +- `Nfp/Model/InductionPrompt.lean` + - Prompt utilities (`prev` map and active set for periodic prompts). +- `Nfp/Model/Gpt2.lean` + - Exact GPT-2 head-slice data, layer/MLP/LayerNorm parameters, and embedding helpers. +- `Nfp/Model.lean` + - Aggregator for model input modules. + +If you introduce a new conceptual layer: +- either extend the closest existing file, +- or add a new module with a clear name + top docstring, +- and update this map in the same commit. 
From b9bbe4e1c2b98193d5307a7deaf49099ff04298c Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 10 Jan 2026 11:31:06 +0100 Subject: [PATCH 123/244] Refine induction core bounds --- Nfp/Sound/Induction/Core.lean | 84 +++++++++++--- Nfp/Sound/Induction/CoreSound.lean | 169 +++++++++++++++++++++-------- 2 files changed, 194 insertions(+), 59 deletions(-) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 6897d03..1834cdf 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -236,9 +236,10 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 8 - let splitBudgetK : Nat := 8 - let splitBudgetDiff : Nat := 6 + let splitBudgetQ : Nat := 2 + let splitBudgetK : Nat := 2 + let splitBudgetDiffBase : Nat := 0 + let splitBudgetDiffRefined : Nat := 12 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) @@ -291,7 +292,7 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} let dims1 := top2 ambig let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) (dims1 ++ dims2).take splitBudgetK - let splitDimsDiff : Fin seq → Fin seq → List (Fin dHead) := fun q k => + let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => let prev := inputs.prev q let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d @@ -326,7 +327,11 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} top2 ((ambig.filter (fun d => decide (d ∉ dims1))).filter (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take splitBudgetDiff + (dims1 ++ dims2 ++ dims3).take budget + let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffBase + let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffRefined let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => @@ -337,7 +342,7 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} (fun d => qLo q d) (fun d => qHi q d) (fun d => kLo k d) (fun d => kHi k d)), by simp⟩)) - let dotDiffRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => if hq : q ∈ inputs.active then @@ -346,7 +351,7 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} if masked q k then (0, 0) else - let dimsDiff := splitDimsDiff q k + let dimsDiff := splitDimsDiffBase q k let prev := inputs.prev q _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff (fun d => qLo q d) (fun d => qHi q d) @@ -367,15 +372,15 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} let entry := row.1[k.1]'(by simp [row.2, k.isLt]) entry.2 - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasks[q.1]'(by - simp [dotDiffRowTasks, q.isLt])).get + let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get let entry := row.1[k.1]'(by simp [row.2, k.isLt]) entry.1 - let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasks[q.1]'(by - simp [dotDiffRowTasks, q.isLt])).get + let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get let entry := row.1[k.1]'(by simp [row.2, k.isLt]) entry.2 @@ -400,6 +405,57 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} inputs.scale * dotLo q k let scoreLoPrev : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) + let scoreGapLoBase : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLoBase q k + else + inputs.scale * dotDiffHiBase q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let worstKey : Fin seq → Option (Fin seq) := fun q => + if hq : q ∈ inputs.active then + let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (scoreGapLoBase q k, k)).2 + else + none + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + 
(_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k + | none => dotDiffLoBase q k + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k + | none => dotDiffHiBase q k let scoreGapLo : Fin seq → Fin seq → Rat := fun q k => if masked q (inputs.prev q) then scoreLoPrev q - scoreHi q k @@ -409,8 +465,6 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} inputs.scale * dotDiffLo q k else inputs.scale * dotDiffHi q k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) let marginAt : Fin seq → Rat := fun q => if hq : q ∈ inputs.active then let other := otherKeys q diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index c3075c3..2515ed0 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -9,6 +9,8 @@ open Nfp.Sound.Bounds variable {seq : Nat} set_option maxHeartbeats 5000000 in -- The soundness proof expands many cached bounds; extra heartbeats avoid spurious timeouts. +set_option synthInstance.maxHeartbeats 200000 in +-- Instance search also touches the expanded caches; allow more room to avoid timeouts. /-- Soundness for `buildInductionCertFromHeadCore?`. 
-/ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) @@ -145,9 +147,10 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 8 - let splitBudgetK : Nat := 8 - let splitBudgetDiff : Nat := 6 + let splitBudgetQ : Nat := 2 + let splitBudgetK : Nat := 2 + let splitBudgetDiffBase : Nat := 0 + let splitBudgetDiffRefined : Nat := 12 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) @@ -200,7 +203,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let dims1 := top2 ambig let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) (dims1 ++ dims2).take splitBudgetK - let splitDimsDiff : Fin seq → Fin seq → List (Fin dHead) := fun q k => + let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => let prev := inputs.prev q let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d @@ -235,7 +238,11 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} top2 ((ambig.filter (fun d => decide (d ∉ dims1))).filter (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take splitBudgetDiff + (dims1 ++ dims2 ++ dims3).take budget + let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffBase + let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffRefined let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => @@ -246,7 +253,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} 
(fun d => qLo q d) (fun d => qHi q d) (fun d => kLo k d) (fun d => kHi k d)), by simp⟩)) - let dotDiffRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => if hq : q ∈ inputs.active then @@ -255,7 +262,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} if masked q k then (0, 0) else - let dimsDiff := splitDimsDiff q k + let dimsDiff := splitDimsDiffBase q k let prev := inputs.prev q _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff (fun d => qLo q d) (fun d => qHi q d) @@ -276,15 +283,15 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let entry := row.1[k.1]'(by simp [row.2, k.isLt]) entry.2 - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasks[q.1]'(by - simp [dotDiffRowTasks, q.isLt])).get + let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get let entry := row.1[k.1]'(by simp [row.2, k.isLt]) entry.1 - let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasks[q.1]'(by - simp [dotDiffRowTasks, q.isLt])).get + let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get let entry := row.1[k.1]'(by simp [row.2, k.isLt]) entry.2 @@ -309,6 +316,57 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} inputs.scale * dotLo q k let scoreLoPrev : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) + let scoreGapLoBase : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLoBase q k + else + 
inputs.scale * dotDiffHiBase q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let worstKey : Fin seq → Option (Fin seq) := fun q => + if hq : q ∈ inputs.active then + let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (scoreGapLoBase q k, k)).2 + else + none + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k + | none => dotDiffLoBase q k + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k + | none => dotDiffHiBase q k let scoreGapLo : Fin seq → Fin seq → Rat := fun q k => if masked q (inputs.prev q) then scoreLoPrev q - scoreHi q k @@ -318,8 +376,6 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} inputs.scale * dotDiffLo q k else inputs.scale * dotDiffHi q k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) let marginAt : Fin seq → Rat := fun q => if hq : q ∈ inputs.active then let other := otherKeys q @@ -376,21 +432,14 @@ theorem 
buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} active := inputs.active prev := inputs.prev values := valCert } - have hcore' : some cert = some c := by - simpa [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, lnBounds, lnLo, - lnHi, lnAbsMaxTask, lnAbsMaxArr, lnAbsMax, lnAbsMaxMax, invStdBoundsTasks, - invStdBoundsArr, invStdLo, invStdHi, qBaseArr, qBase, kBaseArr, kBase, - qCoeffRowTasks, qCoeffArr, qCoeff, kCoeffRowTasks, kCoeffArr, kCoeff, qLo, qHi, kLo, - kHi, qAbs, kAbs, qAbsMaxArr, qAbsMax, kAbsMaxArr, kAbsMax, masked, splitBudgetQ, - splitBudgetK, splitBudgetDiff, splitDimsQ, splitDimsK, splitDimsDiff, dotRowTasks, - dotDiffRowTasks, - dotLo, dotHi, dotDiffLo, dotDiffHi, dotAbs, scoreBaseAbs, scoreLo, scoreHi, - scoreLoPrev, scoreGapLo, otherKeys, marginAt, epsAt, margin, eps, dirHeadVec, - dirHead, wvDir, bDir, valsAbsBase, valsLoBase, valsHiBase, valsLo, valsHi, univ, lo, - hi, valCert, cert, Task.spawn, Bounds.cacheBoundTask_apply, - Array.getElem_ofFn] using hcore + have hcore' : buildInductionCertFromHeadCore? 
inputs = some cert := by + simp (config := { zeta := false }) + [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, cert, valCert] + rfl have hc : c = cert := by - simpa using (Option.some.inj hcore').symm + have hcert : cert = c := by + exact Option.some.inj (hcore'.symm.trans hcore) + simpa using hcert.symm subst hc have hln_bounds : ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ @@ -803,9 +852,9 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hk_lo := (hk d).1 have h := sub_le_sub hprev_hi hk_lo simpa [ratToReal_sub] using h - have hspec := + have hspec (dimsDiff : List (Fin dHead)) := _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := splitDimsQ q) (dims2 := splitDimsDiff q k) + (dims1 := splitDimsQ q) (dims2 := dimsDiff) (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) (lo2 := fun d => kLo (inputs.prev q) d - kHi k d) (hi2 := fun d => kHi (inputs.prev q) d - kLo k d) @@ -813,20 +862,52 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (y := fun d => kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) hlo1 hhi1 hlo2 hhi2 - have hlow' : - (dotDiffLo q k : Real) ≤ + have hspecBase := hspec (splitDimsDiffBase q k) + have hspecRef := hspec (splitDimsDiffRefined q k) + cases hkey : worstKey q with + | none => + have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, hkey, hq, dotDiffLoBase, dotDiffRowTasksBase, hmask, + Task.spawn, Array.getElem_ofFn] using hspecBase.1 + have hhigh' : dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, dotDiffRowTasks, hq, hmask, Task.spawn, Array.getElem_ofFn] - using hspec.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => 
kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, dotDiffRowTasks, hq, hmask, Task.spawn, Array.getElem_ofFn] - using hspec.2 - exact ⟨hlow', hhigh'⟩ + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, hkey, hq, dotDiffHiBase, dotDiffRowTasksBase, hmask, + Task.spawn, Array.getElem_ofFn] using hspecBase.2 + exact ⟨hlow', hhigh'⟩ + | some k' => + by_cases hk : k = k' + · have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, hkey, hk] using hspecRef.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, hkey, hk] using hspecRef.2 + exact ⟨hlow', hhigh'⟩ + · have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, hkey, hk, hq, dotDiffLoBase, dotDiffRowTasksBase, hmask, + Task.spawn, Array.getElem_ofFn] using hspecBase.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, hkey, hk, hq, dotDiffHiBase, dotDiffRowTasksBase, hmask, + Task.spawn, Array.getElem_ofFn] using hspecBase.2 + exact ⟨hlow', hhigh'⟩ let scoresReal := scoresRealOfInputs inputs have hmarginAt_le : ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → From c0db78dab2248b0ec68b0816a3ef7473e0333385 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 10 Jan 2026 13:28:24 +0100 Subject: [PATCH 124/244] Add proof automation guidance to AGENTS --- AGENTS.md | 10 +++++++++- 1 file changed, 9 
insertions(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index be8613e..502d212 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -140,7 +140,15 @@ prefer the **clean redesign**, but do it consciously and document the rationale. - and broadly safe. - Prefer `simp [foo]` over global simp-set growth. -### 4.3 Refactors are allowed—but must be principled +### 4.3 Proof automation discipline +- Use automation to *discover* proofs, then write the small explicit proof (or a minimal + `simp only [...]` set) that captures it. +- Avoid large one-line automation proofs (e.g. `aesop`, `simp` without a controlled set) + in core library code; they are brittle and can slow down elaboration. +- Prefer local simplification: use `simp?` to get a minimal `simp only [...]` set for + non-terminal goals, and keep custom simp sets local via `registerSimpAttr` when needed. + +### 4.4 Refactors are allowed—but must be principled - You may do nontrivial refactors to improve conceptual cleanliness. - If you rename/reshape core APIs: - update all call sites, From 82185f6fdaf98754b8b9ff925b44e96735d79eeb Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 02:04:16 +0100 Subject: [PATCH 125/244] refactor proofs with controlled automation --- Nfp/Circuit/Cert.lean | 17 +- Nfp/Circuit/Cert/DownstreamLinear.lean | 2 +- Nfp/Circuit/Cert/LogitDiff.lean | 20 +- Nfp/Circuit/Cert/ResidualBound.lean | 2 +- Nfp/Circuit/Cert/ResidualInterval.lean | 2 +- Nfp/Circuit/Cert/SoftmaxMargin.lean | 265 +++--------------- Nfp/Circuit/Cert/ValueRange.lean | 22 +- Nfp/Circuit/Compose.lean | 19 +- Nfp/Circuit/Semantics.lean | 4 +- Nfp/Circuit/WellFormed.lean | 5 +- Nfp/Core/Basic.lean | 12 +- Nfp/Mixer/Operations.lean | 25 +- Nfp/Prob/Operations.lean | 16 +- Nfp/Sound/Bounds/Gelu.lean | 19 +- Nfp/Sound/Bounds/LayerNorm.lean | 34 +-- Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean | 34 +-- Nfp/Sound/Bounds/MatrixNorm.lean | 25 +- Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 270 +++++-------------- 
Nfp/Sound/Induction/CoreSound/Values.lean | 34 +-- Nfp/Sound/Induction/HeadBounds.lean | 6 +- Nfp/Sound/Induction/LogitDiff.lean | 6 +- Nfp/Sound/Induction/OneHot.lean | 213 +++++---------- Nfp/Sound/Linear/FinFold.lean | 47 +--- Nfp/System/LocalSystem.lean | 6 +- 24 files changed, 268 insertions(+), 837 deletions(-) diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index 24f69cc..8ca7959 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -68,22 +68,7 @@ theorem finsetAll_eq_true_iff {β : Type v} {s : Finset β} {p : β → Bool} : | @insert a s ha ih => have hfold : finsetAll (insert a s) p = true ↔ p a = true ∧ finsetAll s p = true := by simp [finsetAll, ha, Bool.and_eq_true] - calc - finsetAll (insert a s) p = true - ↔ p a = true ∧ finsetAll s p = true := hfold - _ ↔ p a = true ∧ ∀ a ∈ s, p a = true := by simp [ih] - _ ↔ ∀ x ∈ insert a s, p x = true := by - constructor - · intro h x hx - rcases h with ⟨ha', hs⟩ - by_cases hx' : x = a - · simpa [hx'] using ha' - · exact hs x (Finset.mem_of_mem_insert_of_ne hx hx') - · intro h - refine ⟨?_, ?_⟩ - · exact h a (Finset.mem_insert_self a s) - · intro x hx - exact h x (Finset.mem_insert_of_mem hx) + simp [hfold, ih] /-- Boolean check for interface equality. 
-/ def sameInterface (C₁ C₂ : Circuit ι α) : Bool := diff --git a/Nfp/Circuit/Cert/DownstreamLinear.lean b/Nfp/Circuit/Cert/DownstreamLinear.lean index 8e01f23..85e1d96 100644 --- a/Nfp/Circuit/Cert/DownstreamLinear.lean +++ b/Nfp/Circuit/Cert/DownstreamLinear.lean @@ -49,7 +49,7 @@ theorem checkDownstreamLinearCert_sound (c : DownstreamLinearCert) : have h' : ((0 ≤ c.error ∧ 0 ≤ c.gain) ∧ 0 ≤ c.inputBound) ∧ c.error = c.gain * c.inputBound := by - simpa [checkDownstreamLinearCert, Bool.and_eq_true] using h + simpa [checkDownstreamLinearCert, Bool.and_eq_true, decide_eq_true_iff] using h rcases h' with ⟨⟨⟨herror, hgain⟩, hinput⟩, heq⟩ refine { error_nonneg := herror diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index 46a3992..28742e7 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -52,20 +52,16 @@ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) classical intro lb hbound have hnonempty : active.Nonempty := ⟨q, hq⟩ - have hbound' : - (active.image (fun q => vals (prev q) - eps * (hi - lo))).min' - (hnonempty.image (fun q => vals (prev q) - eps * (hi - lo))) = lb := by - simpa [logitDiffLowerBound, hnonempty] using hbound let gap := eps * (hi - lo) let f : Fin seq → Rat := fun q => vals (prev q) - gap + have hbound' : (active.image f).min' (hnonempty.image f) = lb := by + simpa [logitDiffLowerBound, hnonempty, f, gap] using hbound have hmem : f q ∈ (active.image f) := by refine Finset.mem_image.2 ?_ exact ⟨q, hq, rfl⟩ have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := Finset.min'_le _ _ hmem - have hlb : lb = (active.image f).min' (hnonempty.image f) := by - simpa [f, gap] using hbound'.symm - simpa [f, gap, hlb] using hmin + simpa [f, gap, hbound'] using hmin /-- The per-query lower bound is below every active `prev` value minus the local gap. 
-/ theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) @@ -77,20 +73,16 @@ theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) classical intro lb hbound have hnonempty : active.Nonempty := ⟨q, hq⟩ - have hbound' : - (active.image (fun q => vals (prev q) - epsAt q * (hi - lo))).min' - (hnonempty.image (fun q => vals (prev q) - epsAt q * (hi - lo))) = lb := by - simpa [logitDiffLowerBoundAt, hnonempty] using hbound let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) let f : Fin seq → Rat := fun q => vals (prev q) - gap q + have hbound' : (active.image f).min' (hnonempty.image f) = lb := by + simpa [logitDiffLowerBoundAt, hnonempty, f, gap] using hbound have hmem : f q ∈ (active.image f) := by refine Finset.mem_image.2 ?_ exact ⟨q, hq, rfl⟩ have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := Finset.min'_le _ _ hmem - have hlb : lb = (active.image f).min' (hnonempty.image f) := by - simpa [f, gap] using hbound'.symm - simpa [f, gap, hlb] using hmin + simpa [f, gap, hbound'] using hmin end Circuit diff --git a/Nfp/Circuit/Cert/ResidualBound.lean b/Nfp/Circuit/Cert/ResidualBound.lean index 9adc684..e7aa6bd 100644 --- a/Nfp/Circuit/Cert/ResidualBound.lean +++ b/Nfp/Circuit/Cert/ResidualBound.lean @@ -40,7 +40,7 @@ theorem checkResidualBoundCert_sound {n : Nat} (c : ResidualBoundCert n) : refine { bound_nonneg := ?_ } intro i have hi := hall' i (by simp) - exact (decide_eq_true_iff).1 hi + simpa [decide_eq_true_iff] using hi end Circuit diff --git a/Nfp/Circuit/Cert/ResidualInterval.lean b/Nfp/Circuit/Cert/ResidualInterval.lean index 88370cd..aa36547 100644 --- a/Nfp/Circuit/Cert/ResidualInterval.lean +++ b/Nfp/Circuit/Cert/ResidualInterval.lean @@ -42,7 +42,7 @@ theorem checkResidualIntervalCert_sound {n : Nat} (c : ResidualIntervalCert n) : refine { lo_le_hi := ?_ } intro i have hi := hall' i (by simp) - exact (decide_eq_true_iff).1 hi + simpa [decide_eq_true_iff] using hi end Circuit diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean 
b/Nfp/Circuit/Cert/SoftmaxMargin.lean index 0e9de16..3dc3784 100644 --- a/Nfp/Circuit/Cert/SoftmaxMargin.lean +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -59,45 +59,52 @@ theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : c.prev c.scores c.weights := by classical intro hcheck + let weightsOk (q : Fin seq) : Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) + let scoresOk (q : Fin seq) : Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) have hqall : ∀ q ∈ (Finset.univ : Finset (Fin seq)), (if q ∈ c.active then - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) && - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + weightsOk q && + scoresOk q && decide (1 ≤ c.weights q (c.prev q) + c.eps) && decide ((∑ k, c.weights q k) = 1) else true) = true := by - have hcheck' : checkSoftmaxMarginCert c = true := hcheck have hcheck'' : finsetAll (Finset.univ : Finset (Fin seq)) (fun q => if q ∈ c.active then - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) && - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + weightsOk q && + scoresOk q && decide (1 ≤ c.weights q (c.prev q) + c.eps) && decide ((∑ k, c.weights q k) = 1) else true) = true := by - simpa [checkSoftmaxMarginCert] using hcheck' + simpa [checkSoftmaxMarginCert, weightsOk, scoresOk] using hcheck exact (finsetAll_eq_true_iff (s := (Finset.univ : 
Finset (Fin seq)))).1 hcheck'' + have hqchecks {q} (hq : q ∈ c.active) : + weightsOk q = true ∧ + scoresOk q = true ∧ + decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ + decide ((∑ k, c.weights q k) = 1) = true := by + have hqall' := hqall q (by simp) + have hqall'' : + weightsOk q && + scoresOk q && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [hq] using hqall' + simpa [Bool.and_eq_true, and_assoc] using hqall'' refine { score_margin := ?_ nonneg := ?_ @@ -105,95 +112,18 @@ theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : prev_large := ?_ other_le := ?_ } · intro q hq k hk - have hqcheck := hqall q (by simp) - have hqcheck' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true := by - have hqcheck'' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) && - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && - decide (1 ≤ c.weights q (c.prev q) + c.eps) && - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [hq] using hqcheck - have hqcheck''' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) = true ∧ - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ - decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [Bool.and_eq_true, and_assoc] using hqcheck'' - rcases hqcheck''' with ⟨_, hscoreOk, _, _⟩ - exact hscoreOk + rcases hqchecks hq with ⟨_, hscore, _, _⟩ have hscoreall := - 
(finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hqcheck' + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hscore have hscorek := hscoreall k (by simp) have hscorek' : decide (c.scores q k + c.margin ≤ c.scores q (c.prev q)) = true := by simpa [hk] using hscorek - exact (decide_eq_true_iff).1 hscorek' + simpa [decide_eq_true_iff] using hscorek' · intro q hq k - have hqcheck := hqall q (by simp) - have hqcheck' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) = true := by - have hqcheck'' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) && - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && - decide (1 ≤ c.weights q (c.prev q) + c.eps) && - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [hq] using hqcheck - have hqcheck''' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) = true ∧ - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ - decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [Bool.and_eq_true, and_assoc] using hqcheck'' - rcases hqcheck''' with ⟨hweightsOk, _, _, _⟩ - exact hweightsOk + rcases hqchecks hq with ⟨hweights, _, _, _⟩ have hweightsall := - (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hqcheck' + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hweights have hweightsk := hweightsall k (by simp) have hweightsk' : decide (0 ≤ c.weights q k) = true ∧ @@ -202,124 +132,17 @@ 
theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : else decide (c.weights q k ≤ c.eps)) = true := by simpa [Bool.and_eq_true] using hweightsk - exact (decide_eq_true_iff).1 hweightsk'.1 + simpa [decide_eq_true_iff] using hweightsk'.1 · intro q hq - have hqcheck := hqall q (by simp) - have hqcheck' : - decide ((∑ k, c.weights q k) = 1) = true := by - have hqcheck'' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) && - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && - decide (1 ≤ c.weights q (c.prev q) + c.eps) && - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [hq] using hqcheck - have hqcheck''' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) = true ∧ - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ - decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [Bool.and_eq_true, and_assoc] using hqcheck'' - rcases hqcheck''' with ⟨_, _, _, hsumOk⟩ - exact hsumOk - exact (decide_eq_true_iff).1 hqcheck' + rcases hqchecks hq with ⟨_, _, _, hsum⟩ + simpa [decide_eq_true_iff] using hsum · intro q hq - have hqcheck := hqall q (by simp) - have hqcheck' : - decide (1 ≤ c.weights q (c.prev q) + c.eps) = true := by - have hqcheck'' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) && - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && - 
decide (1 ≤ c.weights q (c.prev q) + c.eps) && - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [hq] using hqcheck - have hqcheck''' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) = true ∧ - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ - decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [Bool.and_eq_true, and_assoc] using hqcheck'' - rcases hqcheck''' with ⟨_, _, hprevOk, _⟩ - exact hprevOk - exact (decide_eq_true_iff).1 hqcheck' + rcases hqchecks hq with ⟨_, _, hprev, _⟩ + simpa [decide_eq_true_iff] using hprev · intro q hq k hk - have hqcheck := hqall q (by simp) - have hqcheck' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) = true := by - have hqcheck'' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) && - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && - decide (1 ≤ c.weights q (c.prev q) + c.eps) && - decide ((∑ k, c.weights q k) = 1) = true := by - simpa [hq] using hqcheck - have hqcheck''' : - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - decide (0 ≤ c.weights q k) && - (if k = c.prev q then - true - else - decide (c.weights q k ≤ c.eps))) = true ∧ - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => - if k = c.prev q then - true - else - decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) = true ∧ - decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ - decide ((∑ k, c.weights q k) = 1) = true := by - simpa 
[Bool.and_eq_true, and_assoc] using hqcheck'' - rcases hqcheck''' with ⟨hweightsOk, _, _, _⟩ - exact hweightsOk + rcases hqchecks hq with ⟨hweights, _, _, _⟩ have hweightsall := - (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hqcheck' + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hweights have hweightsk := hweightsall k (by simp) have hweightsk' : decide (0 ≤ c.weights q k) = true ∧ @@ -331,7 +154,7 @@ theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : have hother : decide (c.weights q k ≤ c.eps) = true := by simpa [hk] using hweightsk'.2 - exact (decide_eq_true_iff).1 hother + simpa [decide_eq_true_iff] using hother end Circuit diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean index fb70c55..ad20a04 100644 --- a/Nfp/Circuit/Cert/ValueRange.lean +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -53,24 +53,18 @@ theorem checkValueRangeCert_sound [NeZero seq] (c : ValueRangeCert seq) : decide (c.lo ≤ c.vals k) && decide (c.vals k ≤ c.hi)) = true := by simpa [checkValueRangeCert, Bool.and_eq_true] using hcheck rcases hcheck' with ⟨hlohi, hall⟩ - have hlohi' : c.lo ≤ c.hi := (decide_eq_true_iff).1 hlohi + have hlohi' : c.lo ≤ c.hi := by + simpa [decide_eq_true_iff] using hlohi have hall' := (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hall - have hlo : ∀ k, c.lo ≤ c.vals k := by + have hbounds : ∀ k, c.lo ≤ c.vals k ∧ c.vals k ≤ c.hi := by intro k have hk := hall' k (by simp) - have hk' : - decide (c.lo ≤ c.vals k) = true ∧ decide (c.vals k ≤ c.hi) = true := by - simpa [Bool.and_eq_true] using hk - exact (decide_eq_true_iff).1 hk'.1 - have hhi : ∀ k, c.vals k ≤ c.hi := by - intro k - have hk := hall' k (by simp) - have hk' : - decide (c.lo ≤ c.vals k) = true ∧ decide (c.vals k ≤ c.hi) = true := by - simpa [Bool.and_eq_true] using hk - exact (decide_eq_true_iff).1 hk'.2 - exact { lo_le_hi := hlohi', lo_le := hlo, le_hi := hhi } + simpa [Bool.and_eq_true, 
decide_eq_true_iff] using hk + exact + { lo_le_hi := hlohi' + lo_le := fun k => (hbounds k).1 + le_hi := fun k => (hbounds k).2 } end Circuit diff --git a/Nfp/Circuit/Compose.lean b/Nfp/Circuit/Compose.lean index 88e654b..77fedb7 100644 --- a/Nfp/Circuit/Compose.lean +++ b/Nfp/Circuit/Compose.lean @@ -116,16 +116,9 @@ theorem seqBridge_iff_eq {j : Node₁} {i : Node₂} (hmem : i ∈ C2.inputs) : · rintro ⟨h, hEq⟩ have hSubtype : (⟨i, h⟩ : { i // i ∈ C2.inputs }) = ⟨i, hmem⟩ := by - apply Subtype.ext + ext rfl - have hMid : - I2.inputs.symm ⟨i, h⟩ = I2.inputs.symm ⟨i, hmem⟩ := by - exact congrArg I2.inputs.symm hSubtype - have hOut : - (I1.outputs (I2.inputs.symm ⟨i, h⟩)).1 = - (I1.outputs (I2.inputs.symm ⟨i, hmem⟩)).1 := by - exact congrArg Subtype.val (congrArg I1.outputs hMid) - exact hEq.trans hOut + simpa [hSubtype] using hEq · intro hEq exact ⟨hmem, hEq⟩ @@ -203,9 +196,7 @@ def seqCircuit : Circuit (Node₁ ⊕ Node₂) Val := cases i with | inl i => exact C1.gate i (fun j h => - rec (Sum.inl j) (by - change C1.dag.rel j i - exact h)) + rec (Sum.inl j) (by simpa using h)) | inr i => by_cases hinput : i ∈ C2.inputs · let mid : Mid := I2.inputs.symm ⟨i, hinput⟩ @@ -213,9 +204,7 @@ def seqCircuit : Circuit (Node₁ ⊕ Node₂) Val := exact rec (Sum.inl out) (by refine ⟨hinput, rfl⟩) · exact C2.gate i (fun j h => - rec (Sum.inr j) (by - change C2.dag.rel j i - exact h)) } + rec (Sum.inr j) (by simpa using h)) } /-- Interface for sequentially composed circuits. -/ def seqInterface : diff --git a/Nfp/Circuit/Semantics.lean b/Nfp/Circuit/Semantics.lean index 2105fbf..300129f 100644 --- a/Nfp/Circuit/Semantics.lean +++ b/Nfp/Circuit/Semantics.lean @@ -32,7 +32,7 @@ theorem eval_eq (C : Circuit ι α) (input : ι → α) (i : ι) : change C.dag.wf.fix F i = if _ : i ∈ C.inputs then input i else C.gate i (fun j _ => C.dag.wf.fix F j) rw [WellFounded.fix_eq] - dsimp [F, evalStep] + simp [F, evalStep] /-- Input nodes evaluate to their assigned input value. 
-/ theorem eval_eq_input (C : Circuit ι α) (input : ι → α) {i : ι} (h : i ∈ C.inputs) : @@ -61,7 +61,7 @@ theorem evalInput_eq (C : Circuit ι α) (input : C.InputAssignment) (i : ι) : change C.dag.wf.fix F i = if h : i ∈ C.inputs then input ⟨i, h⟩ else C.gate i (fun j _ => C.dag.wf.fix F j) rw [WellFounded.fix_eq] - dsimp [F, evalInputStep] + simp [F, evalInputStep] /-- Input nodes evaluate to their assigned input value (input-only form). -/ theorem evalInput_eq_input (C : Circuit ι α) (input : C.InputAssignment) {i : ι} diff --git a/Nfp/Circuit/WellFormed.lean b/Nfp/Circuit/WellFormed.lean index 070f565..f4a0535 100644 --- a/Nfp/Circuit/WellFormed.lean +++ b/Nfp/Circuit/WellFormed.lean @@ -22,13 +22,14 @@ def WellFormed (C : Circuit ι α) : Prop := /-- Inputs have no parents in a well-formed circuit. -/ theorem wellFormed_no_parent {C : Circuit ι α} (h : WellFormed C) {i j : ι} (hi : i ∈ C.inputs) : ¬ C.dag.rel j i := - h i hi j + by + simpa using h i hi j /-- Input nodes have empty parent sets in a well-formed circuit. -/ theorem wellFormed_parents_empty {C : Circuit ι α} (h : WellFormed C) {i : ι} (hi : i ∈ C.inputs) : C.dag.parents i = ∅ := by ext j - simp [Dag.mem_parents, h i hi j] + simp [Dag.mem_parents, h i hi] end Circuit diff --git a/Nfp/Core/Basic.lean b/Nfp/Core/Basic.lean index f9589c2..758f488 100644 --- a/Nfp/Core/Basic.lean +++ b/Nfp/Core/Basic.lean @@ -119,8 +119,8 @@ theorem ratToReal_le_iff {x y : Rat} : /-- Rational order implies real order after casting. 
-/ theorem ratToReal_le_of_le {x y : Rat} (h : x ≤ y) : - ratToReal x ≤ ratToReal y := - (ratToReal_le_iff (x := x) (y := y)).2 h + ratToReal x ≤ ratToReal y := by + simpa [ratToReal_le_iff] using h theorem ratToReal_lt_iff {x y : Rat} : ratToReal x < ratToReal y ↔ x < y := by @@ -131,8 +131,8 @@ theorem ratToReal_nonneg_iff {x : Rat} : simp [ratToReal] theorem ratToReal_nonneg_of_nonneg {x : Rat} (h : 0 ≤ x) : - 0 ≤ ratToReal x := - (ratToReal_nonneg_iff (x := x)).2 h + 0 ≤ ratToReal x := by + simpa [ratToReal_nonneg_iff] using h theorem ratToReal_nonpos_iff {x : Rat} : ratToReal x ≤ 0 ↔ x ≤ 0 := by @@ -144,9 +144,7 @@ theorem ratToReal_nonpos_iff {x : Rat} : theorem ratToReal_abs_le_of_le {x y : Rat} (h : |x| ≤ y) : |ratToReal x| ≤ ratToReal y := by - have h' : ratToReal |x| ≤ ratToReal y := - ratToReal_le_of_le h - simpa [ratToReal_abs] using h' + simpa [ratToReal_abs] using ratToReal_le_of_le h @[simp] theorem ratToReal_max (x y : Rat) : ratToReal (max x y) = max (ratToReal x) (ratToReal y) := by diff --git a/Nfp/Mixer/Operations.lean b/Nfp/Mixer/Operations.lean index 8fc4a64..5508ddf 100644 --- a/Nfp/Mixer/Operations.lean +++ b/Nfp/Mixer/Operations.lean @@ -23,26 +23,20 @@ private lemma sum_mul_sum (p : ι → Mass) (w : ι → κ → Mass) : classical calc ∑ k, ∑ i, p i * w i k = ∑ i, ∑ k, p i * w i k := by - simpa using - (Finset.sum_comm : - (∑ k : κ, ∑ i : ι, p i * w i k) = ∑ i : ι, ∑ k : κ, p i * w i k) + exact + (Finset.sum_comm (s := (Finset.univ : Finset κ)) (t := (Finset.univ : Finset ι)) + (f := fun k i => p i * w i k)) _ = ∑ i, p i * ∑ k, w i k := by refine Finset.sum_congr rfl ?_ intro i _ - simpa using - (Finset.mul_sum (a := p i) (s := (Finset.univ : Finset κ)) - (f := fun k => w i k)).symm + simp [Finset.mul_sum] /-- Push a probability vector forward along a mixer. 
-/ def push (M : Mixer ι κ) (p : ProbVec ι) : ProbVec κ := { mass := fun k => ∑ i, p.mass i * M.weight i k sum_mass := by classical - calc - ∑ k, ∑ i, p.mass i * M.weight i k - = ∑ i, p.mass i * ∑ k, M.weight i k := by - simpa using sum_mul_sum (p := fun i => p.mass i) (w := fun i => M.weight i) - _ = 1 := by simp } + simp [sum_mul_sum] } /-- Composition of two mixers. -/ def comp (M : Mixer ι κ) (N : Mixer κ α) : Mixer ι α := @@ -50,19 +44,14 @@ def comp (M : Mixer ι κ) (N : Mixer κ α) : Mixer ι α := row_sum := by classical intro i - calc - ∑ a, ∑ k, M.weight i k * N.weight k a - = ∑ k, M.weight i k * ∑ a, N.weight k a := by - simpa using - sum_mul_sum (p := fun k => M.weight i k) (w := fun k => N.weight k) - _ = 1 := by simp } + simp [sum_mul_sum] } /-- Identity mixer. -/ def id (ι : Type u) [Fintype ι] [DecidableEq ι] : Mixer ι ι := { weight := fun i j => (ProbVec.pure i).mass j row_sum := by intro i - exact (ProbVec.pure i).sum_mass } + simp } end Mixer end Nfp diff --git a/Nfp/Prob/Operations.lean b/Nfp/Prob/Operations.lean index d32dd74..79440d4 100644 --- a/Nfp/Prob/Operations.lean +++ b/Nfp/Prob/Operations.lean @@ -28,28 +28,18 @@ def pure (i0 : ι) [DecidableEq ι] : ProbVec ι := by refine { mass := Pi.single i0 (1 : Mass) sum_mass := ?_ } - exact (Fintype.sum_pi_single' (ι := ι) (i := i0) (a := (1 : Mass))) + simp @[simp] theorem mass_pure (i0 i : ι) [DecidableEq ι] : (pure i0).mass i = if i = i0 then 1 else 0 := by - by_cases h : i = i0 - · subst h - simp [pure, Pi.single] - · simp [pure, Pi.single, h] + by_cases h : i = i0 <;> simp [pure, Pi.single, h] /-- Convex combination of two probability vectors with weights that sum to one. 
-/ def mix (a b : Mass) (h : a + b = 1) (p q : ProbVec ι) : ProbVec ι := { mass := fun i => a * p.mass i + b * q.mass i sum_mass := by classical - calc - ∑ i, (a * p.mass i + b * q.mass i) - = (∑ i, a * p.mass i) + (∑ i, b * q.mass i) := by - simp [Finset.sum_add_distrib] - _ = a * ∑ i, p.mass i + b * ∑ i, q.mass i := by - simp [sum_mul_const] - _ = a * 1 + b * 1 := by simp - _ = 1 := by simp [h] } + simp [Finset.sum_add_distrib, sum_mul_const, h] } end ProbVec end Nfp diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean index 6817dec..adf58df 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -86,11 +86,8 @@ theorem geluTanh_bounds (x : Real) : 2) ≤ x := by have h := mul_le_mul_of_nonneg_left hcoeff.2 hx simpa [mul_one] using h - have h0 : 0 ≤ geluTanh x := by - simpa [geluTanh] using hnonneg - have h1 : geluTanh x ≤ x := by - simpa [geluTanh] using hle - simpa [min_eq_right hx, max_eq_left hx] using And.intro h0 h1 + simpa [geluTanh, min_eq_right hx, max_eq_left hx] using + And.intro hnonneg hle · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) have hcoeff := geluTanh_coeff_bounds x have hle0 : @@ -111,11 +108,7 @@ theorem geluTanh_bounds (x : Real) : 2) := by have h := mul_le_mul_of_nonpos_left hcoeff.2 hx' simpa [mul_one] using h - have h0 : geluTanh x ≤ 0 := by - simpa [geluTanh] using hle0 - have h1 : x ≤ geluTanh x := by - simpa [geluTanh] using hxle - simpa [min_eq_left hx', max_eq_right hx'] using And.intro h1 h0 + simpa [geluTanh, min_eq_left hx', max_eq_right hx'] using And.intro hxle hle0 /-- Interval bounds for GELU given input bounds. 
-/ def geluInterval (lo hi : Rat) : Rat × Rat := @@ -142,8 +135,7 @@ theorem geluInterval_bounds {lo hi : Rat} {x : Real} · by_cases hhi0 : 0 ≤ hi · have hhi0r : 0 ≤ (hi : Real) := by exact ratToReal_nonneg_of_nonneg hhi0 - have hmax' : max (hi : Real) 0 = (hi : Real) := max_eq_left hhi0r - simpa [geluInterval, hhi0, hmax'] using hhi' + simpa [geluInterval, hhi0, max_eq_left hhi0r] using hhi' · have hhi0r : (hi : Real) ≤ 0 := by exact (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) have hx0 : x ≤ 0 := le_trans hhi hhi0r @@ -164,8 +156,7 @@ theorem geluInterval_bounds {lo hi : Rat} {x : Real} · by_cases hhi0 : 0 ≤ hi · have hhi0r : 0 ≤ (hi : Real) := by exact ratToReal_nonneg_of_nonneg hhi0 - have hmax' : max (hi : Real) 0 = (hi : Real) := max_eq_left hhi0r - simpa [geluInterval, hhi0, hmax'] using hhi' + simpa [geluInterval, hhi0, max_eq_left hhi0r] using hhi' · have hhi0r : (hi : Real) ≤ 0 := by exact (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) have hx0' : x ≤ 0 := le_trans hhi hhi0r diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 3a11a47..c742a23 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -686,17 +686,11 @@ theorem scaleInterval_bounds {x lo hi y : Rat} let bounds := scaleInterval x lo hi bounds.1 ≤ x * y ∧ x * y ≤ bounds.2 := by by_cases hx : 0 ≤ x - · have h1 : x * lo ≤ x * y := by - exact mul_le_mul_of_nonneg_left hlo hx - have h2 : x * y ≤ x * hi := by - exact mul_le_mul_of_nonneg_left hhi hx - simp [scaleInterval, hx, h1, h2] + · simpa [scaleInterval, hx] using + And.intro (mul_le_mul_of_nonneg_left hlo hx) (mul_le_mul_of_nonneg_left hhi hx) · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have h1 : x * hi ≤ x * y := by - exact mul_le_mul_of_nonpos_left hhi hx' - have h2 : x * y ≤ x * lo := by - exact mul_le_mul_of_nonpos_left hlo hx' - simp [scaleInterval, hx, h1, h2] + simpa [scaleInterval, hx] using + And.intro (mul_le_mul_of_nonpos_left hhi hx') 
(mul_le_mul_of_nonpos_left hlo hx') /-- `scaleInterval` bounds interpreted in the reals. -/ theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} @@ -704,21 +698,13 @@ theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} let bounds := scaleInterval x lo hi (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by by_cases hx : 0 ≤ x - · have h1 : (x : Real) * (lo : Real) ≤ (x : Real) * y := by - have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx - exact mul_le_mul_of_nonneg_left hlo hx' - have h2 : (x : Real) * y ≤ (x : Real) * (hi : Real) := by - have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx - exact mul_le_mul_of_nonneg_left hhi hx' - simp [scaleInterval, hx, h1, h2] + · have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx + simpa [scaleInterval, hx] using + And.intro (mul_le_mul_of_nonneg_left hlo hx') (mul_le_mul_of_nonneg_left hhi hx') · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have h1 : (x : Real) * (hi : Real) ≤ (x : Real) * y := by - have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' - exact mul_le_mul_of_nonpos_left hhi hx'' - have h2 : (x : Real) * y ≤ (x : Real) * (lo : Real) := by - have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' - exact mul_le_mul_of_nonpos_left hlo hx'' - simp [scaleInterval, hx, h1, h2] + have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' + simpa [scaleInterval, hx] using + And.intro (mul_le_mul_of_nonpos_left hhi hx'') (mul_le_mul_of_nonpos_left hlo hx'') /-- Real-valued LayerNorm output for a vector. 
-/ noncomputable def layerNormReal {n : Nat} diff --git a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean index ba506e8..257a492 100644 --- a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean +++ b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean @@ -128,16 +128,8 @@ theorem meanReal_eq_meanRat {n : Nat} (x : Fin n → Rat) : meanReal (fun i => (x i : Real)) = (meanRat x : Real) := by by_cases h : n = 0 · simp [meanReal, meanRat, h] - · have hsum : - (sumRat x : Real) = ∑ i, (x i : Real) := by - classical - unfold sumRat - simp [Rat.cast_sum] - have hmean : (meanRat x : Real) = (sumRat x : Real) / n := by - simp [meanRat, h] - have hreal : meanReal (fun i => (x i : Real)) = (∑ i, (x i : Real)) / n := by - simp [meanReal, h] - simpa [hmean, hsum] using hreal + · classical + simp [meanReal, meanRat, sumRat, h, Rat.cast_sum] /-- Mean is monotone under pointwise order (real inputs). -/ theorem meanReal_le_meanReal {n : Nat} (x y : Fin n → Real) (hne : n ≠ 0) @@ -197,26 +189,8 @@ theorem varianceReal_eq_varianceRat {n : Nat} (x : Fin n → Rat) : varianceReal (fun i => (x i : Real)) = (varianceRat x : Real) := by by_cases h : n = 0 · simp [varianceReal, varianceRat, h] - · have hmean := meanReal_eq_meanRat (n := n) x - have hsum : - (∑ i, ((x i : Real) - (meanRat x : Real)) ^ 2) = - (∑ i, ((x i : Rat) - meanRat x) ^ 2 : Rat) := by - classical - simp [Rat.cast_sum] - have hreal : varianceReal (fun i => (x i : Real)) = - (∑ i, ((x i : Real) - meanReal (fun j => (x j : Real))) ^ 2) / n := by - simp [varianceReal, h] - have hrat : (varianceRat x : Real) = - (∑ i, ((x i : Rat) - meanRat x) ^ 2 : Rat) / n := by - simp [varianceRat, h] - calc - varianceReal (fun i => (x i : Real)) - = (∑ i, ((x i : Real) - meanReal (fun j => (x j : Real))) ^ 2) / n := hreal - _ = (∑ i, ((x i : Real) - (meanRat x : Real)) ^ 2) / n := by - simp [hmean] - _ = (∑ i, ((x i : Rat) - meanRat x) ^ 2 : Rat) / n := by - rw [hsum] - _ = (varianceRat x : Real) := hrat.symm + 
· classical + simp [varianceReal, varianceRat, h, meanReal_eq_meanRat, Rat.cast_sum] /-- Variance is nonnegative when `n ≠ 0`, interpreted in reals. -/ theorem varianceRat_nonneg_real {n : Nat} (x : Fin n → Rat) (hne : n ≠ 0) : diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index b544969..d301a0f 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -48,25 +48,15 @@ def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) /-- Row-sums are nonnegative. -/ theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : 0 ≤ rowSum W i := by - have hsum : rowSum W i = ∑ j, |W i j| := by - simp [rowSum, Linear.sumFin_eq_sum_univ] - have hnonneg : 0 ≤ ∑ j, |W i j| := by - refine Finset.sum_nonneg ?_ - intro j _ - exact abs_nonneg (W i j) - simpa [hsum] using hnonneg + simpa [rowSum, Linear.sumFin_eq_sum_univ] using + (Finset.sum_nonneg (fun j _ => abs_nonneg (W i j))) /-- Weighted row-sums are nonnegative under nonnegative bounds. -/ theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : 0 ≤ rowSumWeighted W bound i := by - have hsum : rowSumWeighted W bound i = ∑ j, |W i j| * bound j := by - simp [rowSumWeighted, Linear.sumFin_eq_sum_univ] - have hnonneg : 0 ≤ ∑ j, |W i j| * bound j := by - refine Finset.sum_nonneg ?_ - intro j _ - exact mul_nonneg (abs_nonneg (W i j)) (hbound j) - simpa [hsum] using hnonneg + simpa [rowSumWeighted, Linear.sumFin_eq_sum_univ] using + (Finset.sum_nonneg (fun j _ => mul_nonneg (abs_nonneg (W i j)) (hbound j))) /-- Each row-sum is bounded by the row-sum norm. 
-/ theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : @@ -131,13 +121,12 @@ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) have h2 : ∑ j, |W i j * x j| ≤ ∑ j, |W i j| * inputBound := by refine Finset.sum_le_sum ?_ intro j _ - have hxj := hx j have hnonneg : 0 ≤ |W i j| := abs_nonneg (W i j) calc |W i j * x j| = |W i j| * |x j| := by simp [abs_mul] _ ≤ |W i j| * inputBound := by - exact mul_le_mul_of_nonneg_left hxj hnonneg + exact mul_le_mul_of_nonneg_left (hx j) hnonneg have h3 : ∑ j, |W i j| * inputBound = rowSum W i * inputBound := by have hsum : (∑ j, |W i j|) * inputBound = ∑ j, |W i j| * inputBound := by @@ -147,9 +136,7 @@ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (f := fun j => |W i j|) (a := inputBound)) simpa [rowSum, Linear.sumFin_eq_sum_univ] using hsum.symm - have hmul : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by - simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) - exact hmul + simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) have hle : rowSum W i ≤ rowSumNorm W := rowSum_le_rowSumNorm W i have hmul : rowSum W i * inputBound ≤ rowSumNorm W * inputBound := mul_le_mul_of_nonneg_right hle hinput diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index d233878..5421204 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -30,10 +30,8 @@ lemma foldl_max_ge_init {α : Type _} (f : α → Rat) : | nil => simp | cons a l ih => - have hinit : init ≤ max init (f a) := le_max_left _ _ - have hrest : max init (f a) ≤ l.foldl (fun acc x => max acc (f x)) (max init (f a)) := - ih (max init (f a)) - simpa [List.foldl] using le_trans hinit hrest + simpa [List.foldl] using + le_trans (le_max_left _ _) (ih (max init (f a))) lemma foldl_max_ge_mem {α : Type _} (f : α → Rat) : ∀ (l : List α) (a : α) (init : Rat), @@ -43,49 +41,25 
@@ lemma foldl_max_ge_mem {α : Type _} (f : α → Rat) : | nil => cases hmem | cons b l ih => - have hmem' : a = b ∨ a ∈ l := by - simpa using hmem - cases hmem' with - | inl h => - subst h - have hstep : f a ≤ max init (f a) := le_max_right _ _ - have hrest : - max init (f a) ≤ l.foldl (fun acc x => max acc (f x)) (max init (f a)) := - foldl_max_ge_init (f := f) l (max init (f a)) - simpa [List.foldl] using le_trans hstep hrest - | inr h => - have h' := ih (init := max init (f b)) h - simpa [List.foldl] using h' + rcases (List.mem_cons.mp hmem) with rfl | hmem + · simpa [List.foldl] using + le_trans (le_max_right _ _) + (foldl_max_ge_init (f := f) l (max init (f a))) + · simpa [List.foldl] using ih (init := max init (f b)) hmem lemma foldlFin_max_ge_init {n : Nat} (f : Fin n → Rat) (init : Rat) : init ≤ Linear.foldlFin n (fun acc j => max acc (f j)) init := by classical - have hlist : - init ≤ (List.finRange n).foldl (fun acc j => max acc (f j)) init := - foldl_max_ge_init (f := f) (List.finRange n) init - have hfold : - Linear.foldlFin n (fun acc j => max acc (f j)) init = - (List.finRange n).foldl (fun acc j => max acc (f j)) init := by - simpa [Linear.foldlFin_eq_foldl] using - (Fin.foldl_eq_foldl_finRange - (f := fun acc j => max acc (f j)) (x := init) (n := n)) - simpa [hfold] using hlist + simpa [Linear.foldlFin_eq_foldl, Fin.foldl_eq_foldl_finRange] using + (foldl_max_ge_init (f := f) (List.finRange n) init) lemma foldlFin_max_ge {n : Nat} (f : Fin n → Rat) (i : Fin n) : f i ≤ Linear.foldlFin n (fun acc j => max acc (f j)) 0 := by classical have hmem : i ∈ List.finRange n := by simp - have hlist : - f i ≤ (List.finRange n).foldl (fun acc j => max acc (f j)) 0 := - foldl_max_ge_mem (f := f) (List.finRange n) i 0 hmem - have hfold : - Linear.foldlFin n (fun acc j => max acc (f j)) 0 = - (List.finRange n).foldl (fun acc j => max acc (f j)) 0 := by - simpa [Linear.foldlFin_eq_foldl] using - (Fin.foldl_eq_foldl_finRange - (f := fun acc j => max acc (f j)) (x := (0 
: Rat)) (n := n)) - simpa [hfold] using hlist + simpa [Linear.foldlFin_eq_foldl, Fin.foldl_eq_foldl_finRange] using + (foldl_max_ge_mem (f := f) (List.finRange n) i 0 hmem) /-- Lower interval endpoint for a dot product with per-coordinate bounds. -/ def dotIntervalLower {n : Nat} (v lo hi : Fin n → Rat) : Rat := @@ -305,51 +279,23 @@ theorem dotIntervalLowerUpper2CommonDen_fst {n : Nat} (lo1 hi1 lo2 hi2 : Fin n (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).1 = dotIntervalLower2 lo1 hi1 lo2 hi2 := by classical - have hfold : - (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).1 = - (List.finRange n).foldl - (fun acc j => acc + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := by - simpa [dotIntervalLowerUpper2CommonDen, Linear.foldlFin_eq_foldl, - Fin.foldl_eq_foldl_finRange] using - (foldl_pair_fst (xs := List.finRange n) - (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (a := 0) (b := 0)) - have hsum : - dotIntervalLower2 lo1 hi1 lo2 hi2 = - (List.finRange n).foldl - (fun acc j => acc + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := by - simp [dotIntervalLower2, Linear.sumFin_eq_list_foldl] - calc - (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).1 - = (List.finRange n).foldl - (fun acc j => acc + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := hfold - _ = dotIntervalLower2 lo1 hi1 lo2 hi2 := hsum.symm + simpa [dotIntervalLowerUpper2CommonDen, dotIntervalLower2, Linear.foldlFin_eq_foldl, + Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using + (foldl_pair_fst (xs := List.finRange n) + (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + (a := 0) (b := 0)) theorem dotIntervalLowerUpper2CommonDen_snd {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).2 = dotIntervalUpper2 lo1 hi1 lo2 hi2 := by classical - have hfold : - 
(dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).2 = - (List.finRange n).foldl - (fun acc j => acc + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := by - simpa [dotIntervalLowerUpper2CommonDen, Linear.foldlFin_eq_foldl, - Fin.foldl_eq_foldl_finRange] using - (foldl_pair_snd (xs := List.finRange n) - (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (a := 0) (b := 0)) - have hsum : - dotIntervalUpper2 lo1 hi1 lo2 hi2 = - (List.finRange n).foldl - (fun acc j => acc + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := by - simp [dotIntervalUpper2, Linear.sumFin_eq_list_foldl] - calc - (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).2 - = (List.finRange n).foldl - (fun acc j => acc + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) 0 := hfold - _ = dotIntervalUpper2 lo1 hi1 lo2 hi2 := hsum.symm + simpa [dotIntervalLowerUpper2CommonDen, dotIntervalUpper2, Linear.foldlFin_eq_foldl, + Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using + (foldl_pair_snd (xs := List.finRange n) + (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) + (a := 0) (b := 0)) theorem dotIntervalLowerUpper2CommonDen_eq {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2 = @@ -359,54 +305,24 @@ theorem dotIntervalLowerUpper2CommonDen_eq {n : Nat} (lo1 hi1 lo2 hi2 : Fin n theorem dotIntervalLowerUpperCommonDen_fst {n : Nat} (v lo hi : Fin n → Rat) : (dotIntervalLowerUpperCommonDen v lo hi).1 = dotIntervalLowerCommonDen v lo hi := by classical - have hsum : - dotIntervalLowerCommonDen v lo hi = - (List.finRange n).foldl - (fun acc j => acc + if 0 ≤ v j then v j * lo j else v j * hi j) 0 := by - simp [dotIntervalLowerCommonDen, Linear.sumFinCommonDen_eq_sumFin, - Linear.sumFin_eq_list_foldl] - have hfold : - (dotIntervalLowerUpperCommonDen v lo hi).1 = - (List.finRange 
n).foldl - (fun acc j => acc + if 0 ≤ v j then v j * lo j else v j * hi j) 0 := by - simpa [dotIntervalLowerUpperCommonDen, Linear.foldlFin_eq_foldl, - Fin.foldl_eq_foldl_finRange] using - (foldl_pair_fst (xs := List.finRange n) - (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) - (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - (a := 0) (b := 0)) - calc - (dotIntervalLowerUpperCommonDen v lo hi).1 - = - (List.finRange n).foldl - (fun acc j => acc + if 0 ≤ v j then v j * lo j else v j * hi j) 0 := hfold - _ = dotIntervalLowerCommonDen v lo hi := hsum.symm + simpa [dotIntervalLowerUpperCommonDen, dotIntervalLowerCommonDen, + Linear.foldlFin_eq_foldl, Linear.sumFinCommonDen_eq_sumFin, + Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using + (foldl_pair_fst (xs := List.finRange n) + (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) + (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) + (a := 0) (b := 0)) theorem dotIntervalLowerUpperCommonDen_snd {n : Nat} (v lo hi : Fin n → Rat) : (dotIntervalLowerUpperCommonDen v lo hi).2 = dotIntervalUpperCommonDen v lo hi := by classical - have hsum : - dotIntervalUpperCommonDen v lo hi = - (List.finRange n).foldl - (fun acc j => acc + if 0 ≤ v j then v j * hi j else v j * lo j) 0 := by - simp [dotIntervalUpperCommonDen, Linear.sumFinCommonDen_eq_sumFin, - Linear.sumFin_eq_list_foldl] - have hfold : - (dotIntervalLowerUpperCommonDen v lo hi).2 = - (List.finRange n).foldl - (fun acc j => acc + if 0 ≤ v j then v j * hi j else v j * lo j) 0 := by - simpa [dotIntervalLowerUpperCommonDen, Linear.foldlFin_eq_foldl, - Fin.foldl_eq_foldl_finRange] using - (foldl_pair_snd (xs := List.finRange n) - (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) - (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - (a := 0) (b := 0)) - calc - (dotIntervalLowerUpperCommonDen v lo hi).2 - = - (List.finRange n).foldl - (fun acc j => acc + if 0 ≤ v j then v j * hi j else v j * lo j) 0 := 
hfold - _ = dotIntervalUpperCommonDen v lo hi := hsum.symm + simpa [dotIntervalLowerUpperCommonDen, dotIntervalUpperCommonDen, + Linear.foldlFin_eq_foldl, Linear.sumFinCommonDen_eq_sumFin, + Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using + (foldl_pair_snd (xs := List.finRange n) + (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) + (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) + (a := 0) (b := 0)) /-- Single-pass lower/upper endpoints agree with the common-denominator bounds. -/ theorem dotIntervalLowerUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : @@ -492,13 +408,9 @@ theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Rat) refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j - · have h1 : v j * lo j ≤ v j * x j := - mul_le_mul_of_nonneg_left (hlo j) hv - simpa [hv] using h1 + · simpa [hv] using (mul_le_mul_of_nonneg_left (hlo j) hv) · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) - have h1 : v j * hi j ≤ v j * x j := - mul_le_mul_of_nonpos_left (hhi j) hv' - simpa [hv] using h1 + simpa [hv] using (mul_le_mul_of_nonpos_left (hhi j) hv') theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : @@ -508,34 +420,13 @@ theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j - · have h1 : v j * x j ≤ v j * hi j := - mul_le_mul_of_nonneg_left (hhi j) hv - simpa [hv] using h1 + · simpa [hv] using (mul_le_mul_of_nonneg_left (hhi j) hv) · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) - have h1 : v j * x j ≤ v j * lo j := - mul_le_mul_of_nonpos_left (hlo j) hv' - simpa [hv] using h1 + simpa [hv] using (mul_le_mul_of_nonpos_left (hlo j) hv') theorem abs_le_max_abs_abs_of_interval {a b x : Rat} (hlo : a ≤ x) (hhi : x ≤ b) : |x| ≤ max |a| |b| := by - by_cases hx : 0 ≤ x - · have hb : 0 ≤ b := le_trans hx hhi - have hx' : |x| = x := abs_of_nonneg hx - have 
hb' : |b| = b := abs_of_nonneg hb - calc - |x| = x := hx' - _ ≤ b := hhi - _ = |b| := hb'.symm - _ ≤ max |a| |b| := le_max_right _ _ - · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have ha : a ≤ 0 := le_trans hlo hx' - have hxabs : |x| = -x := abs_of_nonpos hx' - have haabs : |a| = -a := abs_of_nonpos ha - calc - |x| = -x := hxabs - _ ≤ -a := neg_le_neg hlo - _ = |a| := by simp [haabs] - _ ≤ max |a| |b| := le_max_left _ _ + exact abs_le_max_abs_abs hlo hhi /-- Global absolute bound from interval endpoints. -/ def intervalAbsBound {n : Nat} (lo hi : Fin n → Rat) : Rat := @@ -567,8 +458,7 @@ theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → R have habs : |dotProduct v x| ≤ max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| := abs_le_max_abs_abs_of_interval hlow hhigh - unfold dotIntervalAbsBound - exact habs + simpa [dotIntervalAbsBound] using habs /-! Real-valued bounds from rational intervals. -/ @@ -828,15 +718,11 @@ theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j - · have h1 : (v j : Real) * (lo j : Real) ≤ (v j : Real) * x j := by - have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv - exact mul_le_mul_of_nonneg_left (hlo j) hv' - simpa [hv] using h1 + · have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv + simpa [hv] using (mul_le_mul_of_nonneg_left (hlo j) hv') · have hv' : (v j : Real) ≤ 0 := by exact (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) - have h1 : (v j : Real) * (hi j : Real) ≤ (v j : Real) * x j := by - exact mul_le_mul_of_nonpos_left (hhi j) hv' - simpa [hv] using h1 + simpa [hv] using (mul_le_mul_of_nonpos_left (hhi j) hv') simpa [hcast, dotProduct] using hsum theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) @@ -856,37 +742,16 @@ theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) refine Finset.sum_le_sum ?_ intro j _ 
by_cases hv : 0 ≤ v j - · have h1 : (v j : Real) * x j ≤ (v j : Real) * (hi j : Real) := by - have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv - exact mul_le_mul_of_nonneg_left (hhi j) hv' - simpa [hv] using h1 + · have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv + simpa [hv] using (mul_le_mul_of_nonneg_left (hhi j) hv') · have hv' : (v j : Real) ≤ 0 := by exact (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) - have h1 : (v j : Real) * x j ≤ (v j : Real) * (lo j : Real) := by - exact mul_le_mul_of_nonpos_left (hlo j) hv' - simpa [hv] using h1 + simpa [hv] using (mul_le_mul_of_nonpos_left (hlo j) hv') simpa [hcast, dotProduct] using hsum theorem abs_le_max_abs_abs_of_interval_real {a b x : Real} (hlo : a ≤ x) (hhi : x ≤ b) : |x| ≤ max |a| |b| := by - by_cases hx : 0 ≤ x - · have hb : 0 ≤ b := le_trans hx hhi - have hx' : |x| = x := abs_of_nonneg hx - have hb' : |b| = b := abs_of_nonneg hb - calc - |x| = x := hx' - _ ≤ b := hhi - _ = |b| := hb'.symm - _ ≤ max |a| |b| := le_max_right _ _ - · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have ha : a ≤ 0 := le_trans hlo hx' - have hxabs : |x| = -x := abs_of_nonpos hx' - have haabs : |a| = -a := abs_of_nonpos ha - calc - |x| = -x := hxabs - _ ≤ -a := neg_le_neg hlo - _ = |a| := by simp [haabs] - _ ≤ max |a| |b| := le_max_left _ _ + exact abs_le_max_abs_abs hlo hhi /-- `intervalAbsBound` controls real-valued coordinates inside a rational interval. 
-/ theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Rat) (x : Fin n → Real) @@ -894,21 +759,14 @@ theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Rat) (x : Fin |x i| ≤ (intervalAbsBound lo hi : Real) := by have hbound : |x i| ≤ max |(lo i : Real)| |(hi i : Real)| := abs_le_max_abs_abs_of_interval_real (hlo i) (hhi i) + have hsup : max |lo i| |hi i| ≤ intervalAbsBound lo hi := + max_abs_le_intervalAbsBound lo hi i have hsup_real : max |(lo i : Real)| |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by - have hsup : max |lo i| |hi i| ≤ intervalAbsBound lo hi := - max_abs_le_intervalAbsBound lo hi i - have hlo : |lo i| ≤ intervalAbsBound lo hi := - le_trans (le_max_left _ _) hsup - have hhi : |hi i| ≤ intervalAbsBound lo hi := - le_trans (le_max_right _ _) hsup - have hlo_real : - |(lo i : Real)| ≤ (intervalAbsBound lo hi : Real) := by - exact ratToReal_abs_le_of_le hlo - have hhi_real : - |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by - exact ratToReal_abs_le_of_le hhi - exact max_le_iff.mpr ⟨hlo_real, hhi_real⟩ + refine max_le_iff.mpr ?_ + constructor + · exact ratToReal_abs_le_of_le (le_trans (le_max_left _ _) hsup) + · exact ratToReal_abs_le_of_le (le_trans (le_max_right _ _) hsup) exact le_trans hbound hsup_real theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Rat) @@ -925,11 +783,7 @@ theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n |dotProduct (fun j => (v j : Real)) x| ≤ max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := abs_le_max_abs_abs_of_interval_real hlow hhigh - have hcast : - (dotIntervalAbsBound v lo hi : Real) = - max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := by - simp [dotIntervalAbsBound] - simpa [hcast] using habs + simpa [dotIntervalAbsBound] using habs /-! Matrix-vector interval bounds. 
-/ @@ -937,17 +791,15 @@ theorem mulVecIntervalLower_le_mulVec {m n : Nat} (W : Matrix (Fin m) (Fin n) Ra (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : ∀ i, mulVecIntervalLower W lo hi i ≤ Matrix.mulVec W x i := by intro i - have h := - dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi x hlo hhi - simpa [mulVecIntervalLower, Matrix.mulVec, dotProduct] using h + simpa [mulVecIntervalLower, Matrix.mulVec, dotProduct] using + (dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi x hlo hhi) theorem mulVec_le_mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : ∀ i, Matrix.mulVec W x i ≤ mulVecIntervalUpper W lo hi i := by intro i - have h := - dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi x hlo hhi - simpa [mulVecIntervalUpper, Matrix.mulVec, dotProduct] using h + simpa [mulVecIntervalUpper, Matrix.mulVec, dotProduct] using + (dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi x hlo hhi) theorem mulVecIntervalLower_le_upper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : diff --git a/Nfp/Sound/Induction/CoreSound/Values.lean b/Nfp/Sound/Induction/CoreSound/Values.lean index 457b658..4938e70 100644 --- a/Nfp/Sound/Induction/CoreSound/Values.lean +++ b/Nfp/Sound/Induction/CoreSound/Values.lean @@ -26,44 +26,24 @@ theorem wvDir_real_eq_sum (inputs : Model.InductionHeadInputs seq dModel dHead) (hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) : ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by intro j - have hsum : - ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) = - ∑ d, ((dirHead d * inputs.wv j d : Rat) : Real) := by - simp - have hsum' : - ∑ d, ((dirHead d * inputs.wv j d : Rat) : Real) = - ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by - refine Finset.sum_congr rfl ?_ - intro d _ - simp - have hfinal := 
hsum.trans hsum' calc (wvDir j : Real) = ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) := by simp [hwvDir j, Linear.dotFin_eq_dotProduct, dotProduct] - _ = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := hfinal + _ = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by + simp /-- Cast a cached `bDir` dot to a Real-valued sum over head biases. -/ theorem bDir_real_eq_sum (inputs : Model.InductionHeadInputs seq dModel dHead) (dirHead : Fin dHead → Rat) (bDir : Rat) (hbDir : bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d)) : (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by - have hsum : - ((∑ d, dirHead d * inputs.bv d : Rat) : Real) = - ∑ d, ((dirHead d * inputs.bv d : Rat) : Real) := by - simp - have hsum' : - ∑ d, ((dirHead d * inputs.bv d : Rat) : Real) = - ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by - refine Finset.sum_congr rfl ?_ - intro d _ - simp - have hfinal := hsum.trans hsum' calc (bDir : Real) = ((∑ d, dirHead d * inputs.bv d : Rat) : Real) := by simp [hbDir, Linear.dotFin_eq_dotProduct, dotProduct] - _ = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := hfinal + _ = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by + simp /-- Rewrite direction values using cached `wvDir` and `bDir` sums. 
-/ theorem valsReal_eq_of_dir (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -151,11 +131,7 @@ theorem valsReal_eq_of_dir (inputs : Model.InductionHeadInputs seq dModel dHead) have hb : dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) = (bDir : Real) := by - calc - dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) - = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by - simp [dotProduct] - _ = (bDir : Real) := hdir_bv.symm + simpa [dotProduct] using hdir_bv.symm simp [hb] end Sound diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean index 61ceb93..7acbd04 100644 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -830,8 +830,7 @@ theorem headValueValsLoCommonDen_eq {seq dModel dHead : Nat} (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsLoCommonDen inputs vLo vHi = headValueValsLo inputs vLo vHi := by funext k - simp [headValueValsLoCommonDen, headValueValsLo, headValueValsLoCommonDenArray, - headValueValsLoArray] + simp [headValueValsLoCommonDen, headValueValsLo, headValueValsLoCommonDenArray_eq] /-- Cached upper value bounds from V intervals. -/ def headValueValsHiArray {seq dModel dHead : Nat} @@ -907,8 +906,7 @@ theorem headValueValsHiCommonDen_eq {seq dModel dHead : Nat} (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsHiCommonDen inputs vLo vHi = headValueValsHi inputs vLo vHi := by funext k - simp [headValueValsHiCommonDen, headValueValsHi, headValueValsHiCommonDenArray, - headValueValsHiArray] + simp [headValueValsHiCommonDen, headValueValsHi, headValueValsHiCommonDenArray_eq] /-- Global lower value bound from an array of per-key values. 
-/ def headValueLoArray (valsLo : Array Rat) : Rat := diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 90c14ac..89ddbd7 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -88,11 +88,7 @@ theorem logitDiffLowerBoundFromCert_le (lb : Real) ≤ (c.values.valsLo (c.prev q) : Real) - (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) := by - have hboundReal' : - (lb : Real) ≤ - (c.values.valsLo (c.prev q) - c.epsAt q * (c.values.hi - c.values.lo) : Rat) := by - exact ratToReal_le_of_le hboundRat - simpa [ratToReal_sub, ratToReal_mul] using hboundReal' + simpa [ratToReal_sub, ratToReal_mul] using ratToReal_le_of_le hboundRat have hvalsLo : (c.values.valsLo (c.prev q) : Real) ≤ valsRealOfInputs inputs (c.prev q) := by diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 048db7a..8c9eb93 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -45,6 +45,75 @@ theorem oneHot_bounds_at_of_marginAt Circuit.softmax (scoresReal q) k let others : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (prev q) + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + intro k + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by + by_cases hneg : marginAt q < 0 + · have heps : (epsAt q : Real) = 1 := by + simp [hepsAt, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k _ _ + exact hweights_nonneg k) + have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by + simpa [hsum_one] 
using hsum_le + simpa [heps] using hsum_le' + · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (marginAt q : Real) := by + exact ratToReal_nonneg_of_nonneg hnonneg + have hbound : + ∀ k ∈ others q, + weights q k ≤ (1 + (marginAt q : Real))⁻¹ := by + intro k hk + have hkne : k ≠ prev q := (Finset.mem_erase.mp hk).1 + have hscore := hscore_margin_real_at q hq k hkne + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := prev q) (k := k) (m := (marginAt q : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (1 + (marginAt q : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ k ∈ others q, (1 + (marginAt q : Real))⁻¹) = + (others q).card * (1 + (marginAt q : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ k ∈ others q, weights q k) ≤ + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have heps : + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ (epsAt q : Real) := by + have hden : (1 + marginAt q) ≠ 0 := by + intro hzero + have hrat : (1 : Rat) + marginAt q = 0 := by + simpa using hzero + have hnonneg_rat : (0 : Rat) ≤ marginAt q := hnonneg + linarith + have hrat : + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ + (ratDivUp (seq - 1) (1 + marginAt q) : Real) := by + have hrat' := ratDivUp_ge_real (seq - 1) (1 + marginAt q) hden + simpa [ratToReal, Rat.cast_div, Rat.cast_add, Rat.cast_natCast, + div_eq_mul_inv] using hrat' + simpa [hepsAt, hneg] using hrat + exact le_trans hsum_le' heps refine { nonneg := ?_ sum_one := ?_ @@ -60,84 +129,12 @@ theorem oneHot_bounds_at_of_marginAt exact Circuit.softmax_sum_one (scores := scoresReal q) · intro q' hq' subst q' - have hsum_others_le : 
(∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by - by_cases hneg : marginAt q < 0 - · have heps : (epsAt q : Real) = 1 := by - simp [hepsAt, hneg] - have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by - intro k hk - simp - have hnonneg : - ∀ k ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q k := by - intro k _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) - have hsum_le : - (∑ k ∈ others q, weights q k) ≤ - ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := - Finset.sum_le_sum_of_subset_of_nonneg hsubset (by - intro k hk _; exact hnonneg k hk) - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by - simpa [hsum_one] using hsum_le - simpa [heps] using hsum_le' - · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg - have hnonneg_real : 0 ≤ (marginAt q : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg - have hbound : - ∀ k ∈ others q, - weights q k ≤ (1 + (marginAt q : Real))⁻¹ := by - intro k hk - have hkne : k ≠ prev q := (Finset.mem_erase.mp hk).1 - have hscore := hscore_margin_real_at q hq k hkne - simpa [weights] using - (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) - (prev := prev q) (k := k) (m := (marginAt q : Real)) - hnonneg_real hscore) - have hsum_le : - (∑ k ∈ others q, weights q k) ≤ - ∑ k ∈ others q, (1 + (marginAt q : Real))⁻¹ := - Finset.sum_le_sum hbound - have hsum_const : - (∑ k ∈ others q, (1 + (marginAt q : Real))⁻¹) = - (others q).card * (1 + (marginAt q : Real))⁻¹ := by - simp - have hcard : (others q).card = seq - 1 := by - simp [others, Finset.card_erase_of_mem] - have hsum_le' : - (∑ k ∈ others q, weights q k) ≤ - (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by - have hsum_le'' := hsum_le.trans_eq hsum_const - have hsum_le''' := hsum_le'' - simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' - exact hsum_le''' - have 
heps : - (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ (epsAt q : Real) := by - have hden : (1 + marginAt q) ≠ 0 := by - intro hzero - have hrat : (1 : Rat) + marginAt q = 0 := by - simpa using hzero - have hnonneg_rat : (0 : Rat) ≤ marginAt q := hnonneg - linarith - have hrat : - (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ - (ratDivUp (seq - 1) (1 + marginAt q) : Real) := by - have hrat' := ratDivUp_ge_real (seq - 1) (1 + marginAt q) hden - simpa [ratToReal, Rat.cast_div, Rat.cast_add, Rat.cast_natCast, - div_eq_mul_inv] using hrat' - simpa [hepsAt, hneg] using hrat - exact le_trans hsum_le' heps have hsum_eq : weights q (prev q) + ∑ k ∈ others q, weights q k = 1 := by have hsum' : weights q (prev q) + ∑ k ∈ others q, weights q k = ∑ k, weights q k := by simp [others] - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) calc weights q (prev q) + ∑ k ∈ others q, weights q k = ∑ k, weights q k := hsum' @@ -153,82 +150,12 @@ theorem oneHot_bounds_at_of_marginAt exact hprev · intro q' hq' k hk subst q' - have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (epsAt q : Real) := by - by_cases hneg : marginAt q < 0 - · have heps : (epsAt q : Real) = 1 := by - simp [hepsAt, hneg] - have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by - intro j hj - simp - have hnonneg : - ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q j := by - intro j _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) j) - have hsum_le : - (∑ j ∈ others q, weights q j) ≤ - ∑ j ∈ (Finset.univ : Finset (Fin seq)), weights q j := - Finset.sum_le_sum_of_subset_of_nonneg hsubset (by - intro j hj _; exact hnonneg j hj) - have hsum_one : (∑ j, weights q j) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - have hsum_le' : (∑ j ∈ others q, weights q j) ≤ 1 := by - simpa [hsum_one] using hsum_le - simpa [heps] using hsum_le' - · have hnonneg : 0 ≤ 
marginAt q := le_of_not_gt hneg - have hnonneg_real : 0 ≤ (marginAt q : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg - have hbound : - ∀ j ∈ others q, - weights q j ≤ (1 + (marginAt q : Real))⁻¹ := by - intro j hj - have hjne : j ≠ prev q := (Finset.mem_erase.mp hj).1 - have hscore := hscore_margin_real_at q hq j hjne - simpa [weights] using - (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) - (prev := prev q) (k := j) (m := (marginAt q : Real)) - hnonneg_real hscore) - have hsum_le : - (∑ j ∈ others q, weights q j) ≤ - ∑ j ∈ others q, (1 + (marginAt q : Real))⁻¹ := - Finset.sum_le_sum hbound - have hsum_const : - (∑ j ∈ others q, (1 + (marginAt q : Real))⁻¹) = - (others q).card * (1 + (marginAt q : Real))⁻¹ := by - simp - have hcard : (others q).card = seq - 1 := by - simp [others, Finset.card_erase_of_mem] - have hsum_le' : - (∑ j ∈ others q, weights q j) ≤ - (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by - have hsum_le'' := hsum_le.trans_eq hsum_const - have hsum_le''' := hsum_le'' - simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' - exact hsum_le''' - have heps : - (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ (epsAt q : Real) := by - have hden : (1 + marginAt q) ≠ 0 := by - intro hzero - have hrat : (1 : Rat) + marginAt q = 0 := by - simpa using hzero - have hnonneg_rat : (0 : Rat) ≤ marginAt q := hnonneg - linarith - have hrat : - (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ - (ratDivUp (seq - 1) (1 + marginAt q) : Real) := by - have hrat' := ratDivUp_ge_real (seq - 1) (1 + marginAt q) hden - simpa [ratToReal, Rat.cast_div, Rat.cast_add, Rat.cast_natCast, - div_eq_mul_inv] using hrat' - simpa [hepsAt, hneg] using hrat - exact le_trans hsum_le' heps have hk' : k ∈ others q := by simp [others, hk] have hnonneg : ∀ j ∈ others q, 0 ≤ weights q j := by intro j _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) j) + exact hweights_nonneg j have hle : weights q k ≤ ∑ j ∈ others q, weights 
q j := by have h := Finset.single_le_sum hnonneg hk' diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean index a8cff01..a224473 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Sound/Linear/FinFold.lean @@ -2,6 +2,7 @@ import Mathlib.Algebra.BigOperators.Fin import Mathlib.Data.Matrix.Mul +import Mathlib.Data.Rat.BigOperators import Batteries.Data.Fin.Fold import Nfp.Core.Basic @@ -47,40 +48,26 @@ theorem sumFin_eq_list_foldl (n : Nat) (f : Fin n → Rat) : theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Rat) : sumFin n f = ∑ i, f i := by classical - have hfold : - sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by - simpa using sumFin_eq_list_foldl n f - have hmap : - ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = - (List.finRange n).foldl (fun acc i => acc + f i) 0 := by - simpa using - (List.foldl_map (f := f) (g := fun acc x : Rat => acc + x) - (l := List.finRange n) (init := (0 : Rat))) - let _ : Std.Commutative (fun a b : Rat => a + b) := - ⟨by intro a b; exact add_comm _ _⟩ - let _ : Std.Associative (fun a b : Rat => a + b) := - ⟨by intro a b c; exact add_assoc _ _ _⟩ - have hfoldr : - ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = - ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by - simpa using - (List.foldl_eq_foldr (f := fun acc x : Rat => acc + x) - (a := 0) (l := (List.finRange n).map f)) have hsum_list : ((List.finRange n).map f).sum = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + have hmap : + ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = + (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa using + (List.foldl_map (f := f) (g := fun acc x : Rat => acc + x) + (l := List.finRange n) (init := (0 : Rat))) calc ((List.finRange n).map f).sum - = ((List.finRange n).map f).foldr (fun x acc => x + acc) 0 := by - rfl - _ = ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 := by - exact hfoldr.symm + = 
((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 := by + simpa using (List.sum_eq_foldl (l := (List.finRange n).map f)) _ = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by exact hmap have hsum_univ : ((List.finRange n).map f).sum = ∑ i, f i := by - exact (Fin.sum_univ_def f).symm + simpa using (Fin.sum_univ_def f).symm calc sumFin n f - = (List.finRange n).foldl (fun acc i => acc + f i) 0 := hfold + = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa using sumFin_eq_list_foldl n f _ = ((List.finRange n).map f).sum := hsum_list.symm _ = ∑ i, f i := hsum_univ @@ -88,19 +75,13 @@ theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Rat) : theorem ratToReal_sum_univ {n : Nat} (f : Fin n → Rat) : ratToReal (∑ i, f i) = ∑ i, ratToReal (f i) := by classical - refine Finset.induction_on (Finset.univ : Finset (Fin n)) ?_ ?_ - · simp - · intro a s ha hs - simp [Finset.sum_insert, ha, hs, ratToReal_add] + simp [ratToReal] /-- Casting a rational `sumFin` to `Real` commutes with summation. -/ theorem ratToReal_sumFin {n : Nat} (f : Fin n → Rat) : ratToReal (sumFin n f) = ∑ i, ratToReal (f i) := by classical - have hsum : sumFin n f = ∑ i, f i := sumFin_eq_sum_univ (f := f) - have hcast : ratToReal (∑ i, f i) = ∑ i, ratToReal (f i) := - ratToReal_sum_univ (f := f) - simpa [hsum] using hcast + simpa [sumFin_eq_sum_univ] using ratToReal_sum_univ (f := f) /-- `sumFinCommonDen` agrees with `sumFin`. -/ theorem sumFinCommonDen_eq_sumFin (n : Nat) (f : Fin n → Rat) : diff --git a/Nfp/System/LocalSystem.lean b/Nfp/System/LocalSystem.lean index 249d337..0ee1662 100644 --- a/Nfp/System/LocalSystem.lean +++ b/Nfp/System/LocalSystem.lean @@ -40,7 +40,8 @@ def toMixer (L : LocalSystem ι) (h : IsRowStochastic L) : Mixer ι ι := /-- Off-edge weights are zero. 
-/ theorem weight_eq_zero_of_not_parent (L : LocalSystem ι) {i j : ι} (h : ¬ L.dag.rel j i) : L.weight i j = 0 := - L.support i j h + by + simpa using L.support i j h /-- One-step evaluation functional used by `eval`. -/ def evalStep (L : LocalSystem ι) (input : ι → Mass) @@ -57,11 +58,12 @@ theorem eval_eq (L : LocalSystem ι) (input : ι → Mass) (i : ι) : eval L input i = input i + ∑ j, (if _ : L.dag.rel j i then L.weight i j * eval L input j else 0) := by + classical set F : ∀ i, (∀ j, L.dag.rel j i → Mass) → Mass := fun i rec => evalStep L input i rec change L.dag.wf.fix F i = input i + ∑ j, (if _ : L.dag.rel j i then L.weight i j * L.dag.wf.fix F j else 0) rw [WellFounded.fix_eq] - dsimp [F, evalStep] + simp [F, evalStep] end LocalSystem From 67efe483deae33bac67c97c0a7a18ec9f44c4167 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 02:26:37 +0100 Subject: [PATCH 126/244] tighten induction proof scaffolding --- Nfp/Sound/Induction/LogitDiff.lean | 2 +- Nfp/Sound/Induction/OneHot.lean | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 89ddbd7..ec6fe19 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -172,7 +172,7 @@ def buildInductionLogitLowerBoundNonvacuous? 
| none => exact none | some base => by_cases hpos : 0 < base.lb - · exact some { base := base, lb_pos := hpos } + · exact some ⟨base, hpos⟩ · exact none end LogitDiffLowerBound diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 8c9eb93..6f12e26 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -65,9 +65,7 @@ theorem oneHot_bounds_at_of_marginAt Finset.sum_le_sum_of_subset_of_nonneg hsubset (by intro k _ _ exact hweights_nonneg k) - have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by - simpa [hsum_one] using hsum_le - simpa [heps] using hsum_le' + simpa [heps, hsum_one] using hsum_le · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg have hnonneg_real : 0 ≤ (marginAt q : Real) := by exact ratToReal_nonneg_of_nonneg hnonneg @@ -142,8 +140,8 @@ theorem oneHot_bounds_at_of_marginAt have hsum_le' : weights q (prev q) + ∑ k ∈ others q, weights q k ≤ weights q (prev q) + (epsAt q : Real) := by - have hsum_le'' := add_le_add_left hsum_others_le (weights q (prev q)) - simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_others_le (weights q (prev q))) have hprev : 1 ≤ weights q (prev q) + (epsAt q : Real) := by simpa [hsum_eq] using hsum_le' From 0250f48e910161cd225963fa8f40c937119f9e39 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 02:41:45 +0100 Subject: [PATCH 127/244] simplify CoreSound proof steps --- Nfp/Sound/Induction/CoreSound.lean | 44 +++++++++--------------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index 2515ed0..9cef2a5 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -938,14 +938,11 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ (scoreLoPrev q : Real) - scoresReal 
q k := sub_le_sub_left hscore_hi (scoreLoPrev q : Real) - have hsum_le'' := add_le_add_left hsub (scoresReal q k) - have hsum_le''' : - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ - (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by - simpa [add_comm, add_left_comm, add_assoc] using hsum_le'' calc (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k - ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := hsum_le''' + ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsub (scoresReal q k)) _ = (scoreLoPrev q : Real) := by simp [sub_add_cancel] calc @@ -1163,9 +1160,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hsum_one : (∑ k, weights q k) = 1 := by simpa [weights] using (Circuit.softmax_sum_one (scores := scoresReal q)) - have hsum_le' : (∑ k ∈ others q, weights q k) ≤ 1 := by - simpa [hsum_one] using hsum_le - simpa [heps] using hsum_le' + simpa [heps, hsum_one] using hsum_le · have hnonneg : 0 ≤ margin := le_of_not_gt hneg have hnonneg_real : 0 ≤ (margin : Real) := by exact ratToReal_nonneg_of_nonneg hnonneg @@ -1192,10 +1187,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hsum_le' : (∑ k ∈ others q, weights q k) ≤ (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - have hsum_le'' := hsum_le.trans_eq hsum_const - have hsum_le''' := hsum_le'' - simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' - exact hsum_le''' + simpa [hcard, Nat.cast_sub hseq, Nat.cast_one] using + (hsum_le.trans_eq hsum_const) have hpos : (0 : Rat) < 1 + margin := by have hone : (0 : Rat) < 1 := by exact zero_lt_one @@ -1226,17 +1219,11 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hsum_le' : weights q (inputs.prev q) + ∑ k ∈ others q, weights q k ≤ weights q (inputs.prev q) + (eps : Real) := by - have hsum_le'' := 
add_le_add_left hsum_others_le (weights q (inputs.prev q)) - have hsum_le''' := hsum_le'' - rw [add_comm (∑ k ∈ others q, weights q k) - (weights q (inputs.prev q))] at hsum_le''' - rw [add_comm (eps : Real) (weights q (inputs.prev q))] at hsum_le''' - exact hsum_le''' + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_others_le (weights q (inputs.prev q))) have hprev : 1 ≤ weights q (inputs.prev q) + (eps : Real) := by - have hsum_le'' := hsum_le' - rw [hsum_eq] at hsum_le'' - exact hsum_le'' + simpa [hsum_eq] using hsum_le' exact hprev · intro q hq k hk have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (eps : Real) := by @@ -1259,9 +1246,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hsum_one : (∑ j, weights q j) = 1 := by simpa [weights] using (Circuit.softmax_sum_one (scores := scoresReal q)) - have hsum_le' : (∑ j ∈ others q, weights q j) ≤ 1 := by - simpa [hsum_one] using hsum_le - simpa [heps] using hsum_le' + simpa [heps, hsum_one] using hsum_le · have hnonneg : 0 ≤ margin := le_of_not_gt hneg have hnonneg_real : 0 ≤ (margin : Real) := by exact ratToReal_nonneg_of_nonneg hnonneg @@ -1288,10 +1273,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hsum_le' : (∑ j ∈ others q, weights q j) ≤ (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - have hsum_le'' := hsum_le.trans_eq hsum_const - have hsum_le''' := hsum_le'' - simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' - exact hsum_le''' + simpa [hcard, Nat.cast_sub hseq, Nat.cast_one] using + (hsum_le.trans_eq hsum_const) have hpos : (0 : Rat) < 1 + margin := by have hone : (0 : Rat) < 1 := by exact zero_lt_one @@ -1315,8 +1298,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (Circuit.softmax_nonneg (scores := scoresReal q) j) have hle : weights q k ≤ ∑ j ∈ others q, weights q j := by - have h := Finset.single_le_sum hnonneg hk' - simpa using h + simpa 
using (Finset.single_le_sum hnonneg hk') exact hle.trans hsum_others_le have hepsAt : ∀ q, epsAt q = From 2eed498a5cdf4c0f06c76acf6869efb4a33daae4 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 02:59:59 +0100 Subject: [PATCH 128/244] Adjusted guidance on CLI command verification --- AGENTS.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 502d212..47a1f9f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,12 +26,6 @@ but keep the core invariants and the “no fake proofs” ethos. One of these typically works (depending on your Lake setup): - `lake exe nfp --help` -If you add or change CLI behavior, validate at least: -- `lake exe nfp --help` (or `nfp --help` if on PATH) -- `lake exe nfp analyze --help` (or `nfp analyze --help`) -- `lake exe nfp induction --help` (or `nfp induction --help`) -- `lake exe nfp --version` (or `nfp --version`) if supported - ### Search tips Note: `models/` is gitignored, so `rg` will skip it unless you pass `--no-ignore` or `-uuu` (or equivalent) when searching. 
From 721ed128c498f54d4fd3c057d09f505e6c0941d4 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 03:00:33 +0100 Subject: [PATCH 129/244] refactor induction bounds proofs --- Nfp/Sound/Induction/CoreSound.lean | 30 ++++++++++------------------ Nfp/Sound/Induction/HeadOutput.lean | 31 +++++++++++++---------------- 2 files changed, 24 insertions(+), 37 deletions(-) diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index 9cef2a5..dad5344 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -1359,13 +1359,12 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (lnAbsMax k : Real) ≤ (lnAbsMaxMax : Real) := ratToReal_le_of_le (hln_abs_max k) have hsum_nonneg : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Real) := by - have hsum_nonneg' : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Rat) := by - have hsum_nonneg'' : 0 ≤ ∑ j, |wvDir j| := by - refine Finset.sum_nonneg ?_ - intro j _ - exact abs_nonneg _ - simpa [Linear.sumFin_eq_sum_univ] using hsum_nonneg'' - exact ratToReal_nonneg_of_nonneg hsum_nonneg' + refine ratToReal_nonneg_of_nonneg ?_ + have : 0 ≤ ∑ j, |wvDir j| := by + refine Finset.sum_nonneg ?_ + intro j _ + exact abs_nonneg _ + simpa [Linear.sumFin_eq_sum_univ] using this have hmul : (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMax k : Real) ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMaxMax : Real) := @@ -1398,9 +1397,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} change lo ≤ valsLo k0 dsimp [lo] refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k0)).2 ?_ - refine ⟨k0, hmem0, ?_⟩ - exact le_rfl + (f := valsLo) (a := valsLo k0)).2 ⟨k0, hmem0, le_rfl⟩ exact ratToReal_le_of_le hloRat have hvals : (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ @@ -1411,34 +1408,27 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] 
{dModel dHead : Nat} change valsHi k0 ≤ hi dsimp [hi] refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k0)).2 ?_ - exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ + (f := valsHi) (a := valsHi k0)).2 ⟨k0, ⟨hmem0, le_rfl⟩⟩ exact ratToReal_le_of_le hhiRat have hreal : (valCert.lo : Real) ≤ (valCert.hi : Real) := le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal · intro k - have hmem : k ∈ univ := by simp [univ] have hloRat : valCert.lo ≤ valCert.valsLo k := by change lo ≤ valsLo k dsimp [lo] refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k)).2 ?_ - refine ⟨k, hmem, ?_⟩ - exact le_rfl + (f := valsLo) (a := valsLo k)).2 ⟨k, by simp [univ], le_rfl⟩ exact ratToReal_le_of_le hloRat · intro k exact hvals_bounds_at k · intro k - have hmem : k ∈ univ := by simp [univ] have hhiRat : valCert.valsHi k ≤ valCert.hi := by change valsHi k ≤ hi dsimp [hi] refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k)).2 ?_ - refine ⟨k, hmem, ?_⟩ - exact le_rfl + (f := valsHi) (a := valsHi k)).2 ⟨k, by simp [univ], le_rfl⟩ exact ratToReal_le_of_le hhiRat exact { softmax_bounds := hsoftmax_bounds diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index 649f108..7ba3966 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -328,22 +328,20 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] hlower'' have hlo : (loOut i : Real) ≤ (boundLoRat q i : Real) := by - have hloRat : loOut i ≤ boundLoRat q i := by - simpa [loOut, hactive] using - (Finset.inf'_le - (s := activeSet) - (f := fun q => boundLoRat q i) - (b := q) hq) - exact ratToReal_le_of_le hloRat + refine ratToReal_le_of_le ?_ + simpa [loOut, hactive] using + (Finset.inf'_le + (s := activeSet) + (f := fun q => boundLoRat q i) + (b := q) hq) have hhi : (boundHiRat q i : Real) ≤ (hiOut i : Real) := by - have hhiRat : boundHiRat q i ≤ hiOut i := by - simpa [hiOut, hactive] using - (Finset.le_sup' - (s := activeSet) - (f := fun q => boundHiRat q i) - (b := q) hq) - exact ratToReal_le_of_le hhiRat + refine ratToReal_le_of_le ?_ + simpa [hiOut, hactive] using + (Finset.le_sup' + (s := activeSet) + (f := fun q => boundHiRat q i) + (b := q) hq) exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by refine { lo_le_hi := ?_ } @@ -351,9 +349,8 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] by_cases hactive : activeSet.Nonempty · rcases hactive with ⟨q, hq⟩ have hout_i := hout q hq i - have hleReal : (loOut i : Real) ≤ (hiOut i : Real) := - le_trans hout_i.1 hout_i.2 - exact (ratToReal_le_iff (x := loOut i) (y := hiOut i)).1 hleReal + exact (ratToReal_le_iff (x := loOut i) (y := hiOut i)).1 + (le_trans hout_i.1 hout_i.2) · simp [loOut, hiOut, hactive] let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } exact some From c5b774d3011b888ac5923e11e6ec64307bd9feed Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 11:03:03 +0100 Subject: [PATCH 130/244] Tighten induction margins and add auto head CLI --- CLAIMS.md | 2 + Nfp/Cli.lean | 64 +++++++++++++++ Nfp/IO/InductionHead.lean | 144 ++++++++++++++++++++++++++++++++++ Nfp/Sound/Induction/Core.lean | 8 +- README.md | 16 ++++ SOUNDNESS_LIMITATIONS.md | 2 + 6 files changed, 232 insertions(+), 4 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index 3b1ac7a..c77870a 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -30,6 +30,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - `nfp induction certify_head_model` reads a model binary, derives head inputs in Lean, and verifies the resulting induction certificate (includes attention projection biases and derives `prev`/active from the stored token sequence by default). +- `nfp induction certify_head_model_auto` derives the logit-diff direction from the prompt + tokens stored in the model file before running the same head-input checker. - `nfp induction certify_end_to_end` composes a head-level logit-diff lower bound with a downstream error certificate (arithmetic consistency only). 
- `nfp induction certify_end_to_end_matrix` computes a downstream bound from a matrix payload diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index ec4998f..545ab3b 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -247,6 +247,32 @@ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do IO.runInductionCertifyHeadModelNonvacuous modelPath layer head dirTarget dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? +/-- `nfp induction certify_head_model_auto` subcommand. -/ +def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do + let modelPath := p.flag! "model" |>.as! String + let layer := p.flag! "layer" |>.as! Nat + let head := p.flag! "head" |>.as! Nat + let period? := (p.flag? "period").map (·.as! Nat) + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyHeadModelAuto modelPath layer head period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + +/-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ +def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do + let modelPath := p.flag! "model" |>.as! String + let layer := p.flag! "layer" |>.as! Nat + let head := p.flag! "head" |>.as! Nat + let period? := (p.flag? "period").map (·.as! Nat) + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer head period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + /-- `nfp induction certify_head_model` subcommand. 
-/ def inductionCertifyHeadModelCmd : Cmd := `[Cli| certify_head_model VIA runInductionCertifyHeadModel; @@ -285,6 +311,42 @@ def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] +/-- `nfp induction certify_head_model_auto` subcommand. -/ +def inductionCertifyHeadModelAutoCmd : Cmd := `[Cli| + certify_head_model_auto VIA runInductionCertifyHeadModelAuto; + "Check induction certificates by reading a model binary and deriving the direction \ + from the prompt tokens." + FLAGS: + model : String; "Path to the NFP_BINARY_V1 model file." + layer : Nat; "Layer index for the induction head." + head : Nat; "Head index for the induction head." + period : Nat; "Optional prompt period override (default: derive from tokens)." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal). Defaults to 0." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + +/-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ +def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| + certify_head_model_auto_nonvacuous VIA runInductionCertifyHeadModelAutoNonvacuous; + "Require a strictly positive logit-diff bound from a model binary, with the direction \ + derived from the prompt tokens." + FLAGS: + model : String; "Path to the NFP_BINARY_V1 model file." + layer : Nat; "Layer index for the induction head." + head : Nat; "Head index for the induction head." + period : Nat; "Optional prompt period override (default: derive from tokens)." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." 
+ "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal; default: 0)." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + /-- `nfp induction head_interval` subcommand. -/ def runInductionHeadInterval (p : Parsed) : IO UInt32 := do let inputsPath := p.flag! "inputs" |>.as! String @@ -339,6 +401,8 @@ def inductionCmd : Cmd := `[Cli| inductionCertifyHeadNonvacuousCmd; inductionCertifyHeadModelCmd; inductionCertifyHeadModelNonvacuousCmd; + inductionCertifyHeadModelAutoCmd; + inductionCertifyHeadModelAutoNonvacuousCmd; inductionHeadIntervalCmd; inductionHeadIntervalModelCmd ] diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index 97abeeb..dfdba89 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -1137,6 +1137,150 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) | Except.ok inputs => checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? minMargin maxEps +/-- Heuristic logit-diff direction derived from prompt tokens. -/ +private def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : + Except String (Nat × Nat) := do + let tokenArr : Array Nat := Array.ofFn (fun i : Fin seq => tokens i) + let n := tokenArr.size + if n < 2 then + throw "token sequence must have length at least 2" + let lastTok := tokenArr.getD (n - 1) 0 + let prevIdx? := + (List.range (n - 1)).reverse.find? (fun i => + tokenArr.getD i lastTok = lastTok) + let targetTok := + match prevIdx? 
with + | some i => tokenArr.getD (i + 1) lastTok + | none => lastTok + let neg0 := tokenArr.getD (n - 2) lastTok + let neg := + if neg0 = targetTok then + if lastTok ≠ targetTok then + lastTok + else if targetTok ≠ 0 then + 0 + else + 1 + else + neg0 + return (targetTok, neg) + +/-- Build and check induction certificates from a model binary, deriving direction tokens from the +prompt sequence. -/ +def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) + (layer head : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + logTiming "start: read model file" + IO.println "timing: read model file start" + flushStdout + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let tokensE ← timePure "read prompt tokens" (fun () => + NfptPure.readTokens data start header) + match tokensE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok tokens => + match deriveDirectionFromTokens tokens with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨dirTarget, dirNegative⟩ => + IO.println + s!"info: direction-target={dirTarget} direction-negative={dirNegative}" + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputs inputs minActive? minLogitDiff? + minMargin maxEps + +/-- Build and check a strictly positive induction logit-diff bound from a model binary, deriving +direction tokens from the prompt sequence. -/ +def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) + (layer head : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + logTiming "start: read model file" + IO.println "timing: read model file start" + flushStdout + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let tokensE ← timePure "read prompt tokens" (fun () => + NfptPure.readTokens data start header) + match tokensE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok tokens => + match deriveDirectionFromTokens tokens with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨dirTarget, dirNegative⟩ => + IO.println + s!"info: direction-target={dirTarget} direction-negative={dirNegative}" + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? + minMargin maxEps + /-- Build head-output interval bounds from exact head inputs. -/ def runInductionHeadInterval (inputsPath : System.FilePath) (outPath? 
: Option System.FilePath) : IO UInt32 := do diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 1834cdf..480d66a 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -236,10 +236,10 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 2 - let splitBudgetK : Nat := 2 - let splitBudgetDiffBase : Nat := 0 - let splitBudgetDiffRefined : Nat := 12 + let splitBudgetQ : Nat := 4 + let splitBudgetK : Nat := 4 + let splitBudgetDiffBase : Nat := 4 + let splitBudgetDiffRefined : Nat := 16 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) diff --git a/README.md b/README.md index 02d1be1..b56fdef 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,22 @@ By default, `certify_head_model` derives the `prev` map and active set from the token sequence stored in the model file. Use `--period ` to override with a fixed periodic prompt. +### GPT2-small (model-driven) + +To certify induction heads from GPT2-small weights, export a model binary and +let the CLI derive the logit-diff direction from the stored prompt tokens +(prefix matching: [A][B] ... [A] -> [B]): + +```bash +python scripts/export_gpt2.py models/gpt2_small.nfpt + +lake exe nfp induction certify_head_model_auto \ + --model models/gpt2_small.nfpt \ + --layer 5 --head 1 +``` + +Use `--period ` to override the prompt period derived from tokens. + ### End-to-end check with downstream bound (prototype) ```bash diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index fa1af52..67fcb45 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -19,6 +19,8 @@ It is intentionally brief and focused on the soundness boundary. 
attention projection biases, and derives `prev`/active from the stored token sequence by default, but still ignores LayerNorm and the shared attention output bias. It currently requires `head_dim` to be a perfect square to represent the scale as an exact rational. +- The `certify_head_model_auto` path derives the logit-diff direction from the stored prompt + tokens using a heuristic; use explicit direction tokens for fixed claims. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. - There is no bridge theorem connecting certificate validity to a full circuit/model semantics statement (for example, a formal statement about logits under a transformer block stack). From c56dfd49ce7cd6d636793306066876afe8e3f2bc Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 11:28:28 +0100 Subject: [PATCH 131/244] Tune sign-split budgets for tighter QK bounds --- Nfp/Sound/Induction/Core.lean | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 480d66a..94b3dc4 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -236,10 +236,10 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 4 - let splitBudgetK : Nat := 4 - let splitBudgetDiffBase : Nat := 4 - let splitBudgetDiffRefined : Nat := 16 + let splitBudgetQ : Nat := 3 + let splitBudgetK : Nat := 3 + let splitBudgetDiffBase : Nat := 2 + let splitBudgetDiffRefined : Nat := 12 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) From 83da20bd19f50da9b15c112798e747d591b21537 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 11:37:50 +0100 Subject: [PATCH 132/244] Increase LayerNorm sqrt precision --- Nfp/Sound/Bounds/LayerNorm.lean | 2 +- Nfp/Sound/Induction/Core.lean | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index c742a23..e9a5d10 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -67,7 +67,7 @@ def sqrtUpperAlt (q : Rat) : Rat := ratRoundUp ((a + 1 : Rat) / den) /-- Extra precision scale for `sqrtLowerScaled`. -/ -def sqrtLowerScale : Nat := 65536 +def sqrtLowerScale : Nat := 1048576 /-- Scaled rational lower bound for a square root (extra precision). -/ def sqrtLowerScaled (q : Rat) : Rat := diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 94b3dc4..1834cdf 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -236,9 +236,9 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 3 - let splitBudgetK : Nat := 3 - let splitBudgetDiffBase : Nat := 2 + let splitBudgetQ : Nat := 2 + let splitBudgetK : Nat := 2 + let splitBudgetDiffBase : Nat := 0 let splitBudgetDiffRefined : Nat := 12 let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := From e525dde7fd63b24c604aa7306d318b15fb31b09a Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 11:56:10 +0100 Subject: [PATCH 133/244] Tighten value bounds per position --- Nfp/Sound/Induction/Core.lean | 14 +++------ Nfp/Sound/Induction/CoreSound.lean | 46 ++++++------------------------ 2 files changed, 12 insertions(+), 48 deletions(-) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 1834cdf..9f03f88 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -115,10 +115,6 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} have hsize : lnAbsMaxArr.size = seq := by simp [lnAbsMaxArr] simp [hsize]) - let lnAbsMaxMax : Rat := - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => lnAbsMax q) let invStdBoundsTasks : Array (Task (Rat × Rat)) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) @@ -496,12 +492,10 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) let bDir : Rat := Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsAbsBase : Rat := - Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax - let valsLoBase := bDir - valsAbsBase - let valsHiBase := bDir + valsAbsBase - let valsLo : Fin seq → Rat := fun _ => valsLoBase - let valsHi : Fin seq → Rat := fun _ => valsHiBase + let valsAbs : Fin seq → Rat := fun q => + Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMax q + let valsLo : Fin seq → Rat := fun q => bDir - valsAbs q + let valsHi : Fin seq → Rat := fun q => bDir + valsAbs q let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := by simp [univ] let lo := univ.inf' hnonempty valsLo diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index dad5344..dd0e99f 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -37,10 +37,6 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} lnAbsMaxArr[q.1]'(by have hsize : lnAbsMaxArr.size = seq := by simp [lnAbsMaxArr] simp [hsize]) - let lnAbsMaxMax : Rat := - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => lnAbsMax q) let invStdBoundsTasks : Array (Task (Rat × Rat)) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) @@ -409,12 +405,10 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) let bDir : Rat := Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsAbsBase : Rat := - Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax - let valsLoBase := bDir - valsAbsBase - let valsHiBase := bDir + valsAbsBase - let valsLo : Fin seq → Rat := fun _ => valsLoBase - let valsHi : Fin seq → Rat := fun _ => valsHiBase + let valsAbs : Fin seq → 
Rat := fun q => + Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMax q + let valsLo : Fin seq → Rat := fun q => bDir - valsAbs q + let valsHi : Fin seq → Rat := fun q => bDir + valsAbs q let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := by simp [univ] let lo := univ.inf' hnonempty valsLo @@ -460,15 +454,6 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (hhi := fun j => (hln j).2) j simpa [lnAbsMax, lnAbsMaxArr, lnAbsMaxTask, Bounds.cacheBoundTask_apply, Array.getElem_ofFn] using h - have hln_abs_max : ∀ q, lnAbsMax q ≤ lnAbsMaxMax := by - intro q - have hnonempty : (Finset.univ : Finset (Fin seq)).Nonempty := - Finset.univ_nonempty - have hmem : q ∈ (Finset.univ : Finset (Fin seq)) := by simp - simpa [lnAbsMaxMax] using - (Finset.le_sup'_iff (s := (Finset.univ : Finset (Fin seq))) - (H := hnonempty) (f := fun q => lnAbsMax q) (a := lnAbsMax q)).2 - ⟨q, hmem, le_rfl⟩ have hdot_abs_bound : ∀ (v : Fin dModel → Rat) (q : Fin seq), |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ @@ -1353,34 +1338,19 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} intro k have hdot_abs : |dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k)| ≤ - (valsAbsBase : Real) := by + (valsAbs k : Real) := by have hdot := hdot_abs_bound_sum (fun j => wvDir j) k - have hln_max_real : - (lnAbsMax k : Real) ≤ (lnAbsMaxMax : Real) := - ratToReal_le_of_le (hln_abs_max k) - have hsum_nonneg : 0 ≤ (Linear.sumFin dModel (fun j => |wvDir j|) : Real) := by - refine ratToReal_nonneg_of_nonneg ?_ - have : 0 ≤ ∑ j, |wvDir j| := by - refine Finset.sum_nonneg ?_ - intro j _ - exact abs_nonneg _ - simpa [Linear.sumFin_eq_sum_univ] using this - have hmul : - (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMax k : Real) ≤ - (Linear.sumFin dModel (fun j => |wvDir j|) : Real) * (lnAbsMaxMax : Real) := - mul_le_mul_of_nonneg_left hln_max_real hsum_nonneg - have hfinal := hdot.trans 
hmul - simpa [valsAbsBase, ratToReal_mul] using hfinal + simpa [valsAbs, ratToReal_mul] using hdot have hdot_bounds := (abs_le).1 hdot_abs have hlow' := add_le_add_right hdot_bounds.1 (bDir : Real) have hhigh' := add_le_add_right hdot_bounds.2 (bDir : Real) have hlow : (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by - simpa [valCert, valsLo, valsLoBase, valsAbsBase, hvals_eq k, ratToReal_sub, + simpa [valCert, valsLo, valsAbs, hvals_eq k, ratToReal_sub, sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using hlow' have hhigh : valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - simpa [valCert, valsHi, valsHiBase, valsAbsBase, hvals_eq k, ratToReal_add, + simpa [valCert, valsHi, valsAbs, hvals_eq k, ratToReal_add, add_comm, add_left_comm, add_assoc] using hhigh' exact ⟨hlow, hhigh⟩ have hvals_bounds : From f7dc710f8e84645696deb1bd5de88f68e9b92b2f Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 11 Jan 2026 23:51:36 +0100 Subject: [PATCH 134/244] Guidance adjustment for aesop --- AGENTS.md | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 47a1f9f..fcc4ffb 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -134,13 +134,33 @@ prefer the **clean redesign**, but do it consciously and document the rationale. - and broadly safe. - Prefer `simp [foo]` over global simp-set growth. -### 4.3 Proof automation discipline -- Use automation to *discover* proofs, then write the small explicit proof (or a minimal - `simp only [...]` set) that captures it. -- Avoid large one-line automation proofs (e.g. `aesop`, `simp` without a controlled set) - in core library code; they are brittle and can slow down elaboration. -- Prefer local simplification: use `simp?` to get a minimal `simp only [...]` set for - non-terminal goals, and keep custom simp sets local via `registerSimpAttr` when needed. 
+### 4.3 Proof automation discipline (Aesop-aware) + +- Use `aesop?` (and `simp?`) to *discover* a proof or a minimal set of steps, then + prefer to: + - write a small explicit proof (`simp only [...]`, `constructor`, `cases`, `refine`, `exact`, etc.), or + - keep Aesop but constrain it with a local ruleset and/or targeted rules. + +- Avoid unconditional `by aesop` in core/trusted library code unless: + - the goal is genuinely routine, + - it stays fast and stable under small refactors, and + - it does not rely on a large implicit rule universe. + +- Prefer local rules over global rules: + - If a lemma is meant to be reused by Aesop, tag it deliberately (e.g. `@[aesop safe]`) + and explain why it is safe. + - Avoid tagging “utility” lemmas as Aesop rules unless you want them participating + in search broadly. + +- If Aesop succeeds but produces a long/fragile search: + - extract a helper lemma that expresses the key reasoning step, + - prove that lemma explicitly (or with tightly scoped automation), + - then let Aesop use the helper. + +- Keep automation predictable: + - prefer `simp only [...]` and small custom simp sets locally, + - avoid growing the global simp set and global Aesop rule set without strong reason. + ### 4.4 Refactors are allowed—but must be principled - You may do nontrivial refactors to improve conceptual cleanliness. 
From 654405b4d97a0ceb52e7831594f19b4574ce4101 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 15:31:37 +0100 Subject: [PATCH 135/244] bd sync: 2026-01-12 15:31:37 --- .beads/interactions.jsonl | 0 .beads/issues.jsonl | 0 .beads/metadata.json | 4 ++++ 3 files changed, 4 insertions(+) create mode 100644 .beads/interactions.jsonl create mode 100644 .beads/issues.jsonl create mode 100644 .beads/metadata.json diff --git a/.beads/interactions.jsonl b/.beads/interactions.jsonl new file mode 100644 index 0000000..e69de29 diff --git a/.beads/issues.jsonl b/.beads/issues.jsonl new file mode 100644 index 0000000..e69de29 diff --git a/.beads/metadata.json b/.beads/metadata.json new file mode 100644 index 0000000..c787975 --- /dev/null +++ b/.beads/metadata.json @@ -0,0 +1,4 @@ +{ + "database": "beads.db", + "jsonl_export": "issues.jsonl" +} \ No newline at end of file From 4fe3efddfcd5a73a34edfe7eec97b505a16644ce Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 15:47:42 +0100 Subject: [PATCH 136/244] Refactor induction bounds and LayerNorm proofs --- .beads/.gitignore | 44 + .beads/README.md | 81 ++ .beads/config.yaml | 62 + .gitattributes | 3 + AGENTS.md | 28 + MODULE_MAP.md | 2 + Nfp/Circuit/Layers/Induction.lean | 14 +- Nfp/Circuit/Layers/Softmax.lean | 1 - Nfp/Sound/Bounds/Attention.lean | 94 +- Nfp/Sound/Bounds/Cache.lean | 59 +- Nfp/Sound/Bounds/Gelu.lean | 43 +- Nfp/Sound/Bounds/LayerNorm.lean | 868 ++------------ Nfp/Sound/Bounds/LayerNorm/InvStd.lean | 23 +- Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean | 29 +- Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean | 640 ++++++++++ Nfp/Sound/Bounds/MatrixNorm.lean | 73 +- Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 268 +++-- Nfp/Sound/Bounds/Mlp.lean | 58 +- Nfp/Sound/Bounds/Transformer.lean | 25 +- Nfp/Sound/Bounds/Transformer/Embedding.lean | 57 +- Nfp/Sound/Induction/Core.lean | 1110 +++++++++++------- Nfp/Sound/Induction/CoreDefs.lean | 55 + Nfp/Sound/Induction/CoreSound.lean | 599 +++++----- 
Nfp/Sound/Induction/CoreSound/Values.lean | 6 +- Nfp/Sound/Induction/HeadBounds.lean | 8 +- Nfp/Sound/Induction/HeadOutput.lean | 78 +- Nfp/Sound/Induction/LogitDiff.lean | 32 +- Nfp/Sound/Linear/FinFold.lean | 18 + 28 files changed, 2543 insertions(+), 1835 deletions(-) create mode 100644 .beads/.gitignore create mode 100644 .beads/README.md create mode 100644 .beads/config.yaml create mode 100644 .gitattributes create mode 100644 Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean diff --git a/.beads/.gitignore b/.beads/.gitignore new file mode 100644 index 0000000..d27a1db --- /dev/null +++ b/.beads/.gitignore @@ -0,0 +1,44 @@ +# SQLite databases +*.db +*.db?* +*.db-journal +*.db-wal +*.db-shm + +# Daemon runtime files +daemon.lock +daemon.log +daemon.pid +bd.sock +sync-state.json +last-touched + +# Local version tracking (prevents upgrade notification spam after git ops) +.local_version + +# Legacy database files +db.sqlite +bd.db + +# Worktree redirect file (contains relative path to main repo's .beads/) +# Must not be committed as paths would be wrong in other clones +redirect + +# Merge artifacts (temporary files from 3-way merge) +beads.base.jsonl +beads.base.meta.json +beads.left.jsonl +beads.left.meta.json +beads.right.jsonl +beads.right.meta.json + +# Sync state (local-only, per-machine) +# These files are machine-specific and should not be shared across clones +.sync.lock +sync_base.jsonl + +# NOTE: Do NOT add negation patterns (e.g., !issues.jsonl) here. +# They would override fork protection in .git/info/exclude, allowing +# contributors to accidentally commit upstream issue databases. +# The JSONL files (issues.jsonl, interactions.jsonl) and config files +# are tracked by git by default since no pattern above ignores them. diff --git a/.beads/README.md b/.beads/README.md new file mode 100644 index 0000000..50f281f --- /dev/null +++ b/.beads/README.md @@ -0,0 +1,81 @@ +# Beads - AI-Native Issue Tracking + +Welcome to Beads! 
This repository uses **Beads** for issue tracking - a modern, AI-native tool designed to live directly in your codebase alongside your code. + +## What is Beads? + +Beads is issue tracking that lives in your repo, making it perfect for AI coding agents and developers who want their issues close to their code. No web UI required - everything works through the CLI and integrates seamlessly with git. + +**Learn more:** [github.com/steveyegge/beads](https://github.com/steveyegge/beads) + +## Quick Start + +### Essential Commands + +```bash +# Create new issues +bd create "Add user authentication" + +# View all issues +bd list + +# View issue details +bd show + +# Update issue status +bd update --status in_progress +bd update --status done + +# Sync with git remote +bd sync +``` + +### Working with Issues + +Issues in Beads are: +- **Git-native**: Stored in `.beads/issues.jsonl` and synced like code +- **AI-friendly**: CLI-first design works perfectly with AI coding agents +- **Branch-aware**: Issues can follow your branch workflow +- **Always in sync**: Auto-syncs with your commits + +## Why Beads? 
+ +✨ **AI-Native Design** +- Built specifically for AI-assisted development workflows +- CLI-first interface works seamlessly with AI coding agents +- No context switching to web UIs + +🚀 **Developer Focused** +- Issues live in your repo, right next to your code +- Works offline, syncs when you push +- Fast, lightweight, and stays out of your way + +🔧 **Git Integration** +- Automatic sync with git commits +- Branch-aware issue tracking +- Intelligent JSONL merge resolution + +## Get Started with Beads + +Try Beads in your own projects: + +```bash +# Install Beads +curl -sSL https://raw.githubusercontent.com/steveyegge/beads/main/scripts/install.sh | bash + +# Initialize in your repo +bd init + +# Create your first issue +bd create "Try out Beads" +``` + +## Learn More + +- **Documentation**: [github.com/steveyegge/beads/docs](https://github.com/steveyegge/beads/tree/main/docs) +- **Quick Start Guide**: Run `bd quickstart` +- **Examples**: [github.com/steveyegge/beads/examples](https://github.com/steveyegge/beads/tree/main/examples) + +--- + +*Beads: Issue tracking that moves at the speed of thought* ⚡ diff --git a/.beads/config.yaml b/.beads/config.yaml new file mode 100644 index 0000000..f242785 --- /dev/null +++ b/.beads/config.yaml @@ -0,0 +1,62 @@ +# Beads Configuration File +# This file configures default behavior for all bd commands in this repository +# All settings can also be set via environment variables (BD_* prefix) +# or overridden with command-line flags + +# Issue prefix for this repository (used by bd init) +# If not set, bd init will auto-detect from directory name +# Example: issue-prefix: "myproject" creates issues like "myproject-1", "myproject-2", etc. 
+# issue-prefix: "" + +# Use no-db mode: load from JSONL, no SQLite, write back after each command +# When true, bd will use .beads/issues.jsonl as the source of truth +# instead of SQLite database +# no-db: false + +# Disable daemon for RPC communication (forces direct database access) +# no-daemon: false + +# Disable auto-flush of database to JSONL after mutations +# no-auto-flush: false + +# Disable auto-import from JSONL when it's newer than database +# no-auto-import: false + +# Enable JSON output by default +# json: false + +# Default actor for audit trails (overridden by BD_ACTOR or --actor) +# actor: "" + +# Path to database (overridden by BEADS_DB or --db) +# db: "" + +# Auto-start daemon if not running (can also use BEADS_AUTO_START_DAEMON) +# auto-start-daemon: true + +# Debounce interval for auto-flush (can also use BEADS_FLUSH_DEBOUNCE) +# flush-debounce: "5s" + +# Git branch for beads commits (bd sync will commit to this branch) +# IMPORTANT: Set this for team projects so all clones use the same sync branch. +# This setting persists across clones (unlike database config which is gitignored). +# Can also use BEADS_SYNC_BRANCH env var for local override. +# If not set, bd sync will require you to run 'bd config set sync.branch '. +# sync-branch: "beads-sync" + +# Multi-repo configuration (experimental - bd-307) +# Allows hydrating from multiple repositories and routing writes to the correct JSONL +# repos: +# primary: "." 
# Primary repo (where this database lives) +# additional: # Additional repos to hydrate from (read-only) +# - ~/beads-planning # Personal planning repo +# - ~/work-planning # Work planning repo + +# Integration settings (access with 'bd config get/set') +# These are stored in the database, not in this file: +# - jira.url +# - jira.project +# - linear.url +# - linear.api-key +# - github.org +# - github.repo diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..807d598 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ + +# Use bd merge for beads JSONL files +.beads/issues.jsonl merge=beads diff --git a/AGENTS.md b/AGENTS.md index fcc4ffb..9b257ed 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,6 +14,8 @@ but keep the core invariants and the “no fake proofs” ethos. --- +**Use 'bd' for task tracking** + ## 0. Quick Start (What to run) ### Build @@ -202,3 +204,29 @@ This repo treats “axioms creep” as a serious regression. preserved and, where possible, explicitly proved. - [ ] Module map in `MODULE_MAP.md` is accurate (updated in the same commit if needed). - [ ] If CLI behavior changed: `lake build nfp --wfail` succeeds and basic `nfp ... --help` works. + +## Landing the Plane (Session Completion) + +**When ending a work session**, you MUST complete ALL steps below. Work is NOT complete until `git push` succeeds. + +**MANDATORY WORKFLOW:** + +1. **File issues for remaining work** - Create issues for anything that needs follow-up +2. **Run quality gates** (if code changed) - Tests, linters, builds +3. **Update issue status** - Close finished work, update in-progress items +4. **PUSH TO REMOTE** - This is MANDATORY: + ```bash + git pull --rebase + bd sync + git push + git status # MUST show "up to date with origin" + ``` +5. **Clean up** - Clear stashes, prune remote branches +6. **Verify** - All changes committed AND pushed +7. 
**Hand off** - Provide context for next session + +**CRITICAL RULES:** +- Work is NOT complete until `git push` succeeds +- NEVER stop before pushing - that leaves work stranded locally +- NEVER say "ready to push when you are" - YOU must push +- If push fails, resolve and retry until it succeeds diff --git a/MODULE_MAP.md b/MODULE_MAP.md index 9899b32..93d67fe 100644 --- a/MODULE_MAP.md +++ b/MODULE_MAP.md @@ -183,6 +183,8 @@ but you **must** update this list in the same commit. - LayerNorm interval bounds and end-to-end soundness lemmas. - `Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean` - Mean/variance helpers for LayerNorm bounds. +- `Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean` + - Square-root bounds (rational + real) used by LayerNorm and invStd bounds. - `Nfp/Sound/Bounds/LayerNorm/InvStd.lean` - Inverse-standard-deviation bounds for LayerNorm. - `Nfp/Sound/Bounds/UnnormRat.lean` diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index 9b2ac97..8e1c3b6 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ -362,13 +362,14 @@ theorem inductionSpecApproxOn_of_oneHotApprox_valueRange simp [add_assoc] _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by have h := - add_le_add_right hsum_prev_le ((∑ k ∈ others, weights q k) * (hi - lo)) + add_le_add_right hsum_prev_le + ((∑ k ∈ others, weights q k) * (hi - lo)) simpa [add_comm, add_left_comm, add_assoc] using h have hupper : dotProduct (weights q) vals ≤ vals (prev q) + ε * (hi - lo) := by have hmul : - (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := - mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg calc dotProduct (weights q) vals = weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := hout_eq @@ -378,7 +379,8 @@ theorem inductionSpecApproxOn_of_oneHotApprox_valueRange 
simpa [add_comm, add_left_comm, add_assoc] using h _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := hupper_mid _ ≤ vals (prev q) + ε * (hi - lo) := by - have h := add_le_add_left hmul (vals (prev q)) + have h := + add_le_add_left hmul (vals (prev q)) simpa [add_comm, add_left_comm, add_assoc] using h have hprev_le : vals (prev q) ≤ @@ -420,8 +422,8 @@ theorem inductionSpecApproxOn_of_oneHotApprox_valueRange vals (prev q) - ε * (hi - lo) ≤ vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := by have hmul : - (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := - mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg exact sub_le_sub_left hmul (vals (prev q)) have hlow : vals (prev q) - ε * (hi - lo) ≤ dotProduct (weights q) vals := by diff --git a/Nfp/Circuit/Layers/Softmax.lean b/Nfp/Circuit/Layers/Softmax.lean index 8a11656..3e7968f 100644 --- a/Nfp/Circuit/Layers/Softmax.lean +++ b/Nfp/Circuit/Layers/Softmax.lean @@ -17,7 +17,6 @@ namespace Nfp namespace Circuit open scoped BigOperators - noncomputable section variable {seq : Nat} diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index 950e8de..731e1f0 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -88,6 +88,41 @@ def attentionOutputBounds {dModel dHead numHeads : Nat} let sumHi : Fin dModel → Rat := fun i => ∑ h, headHi h i (fun i => sumLo i + attnBias i, fun i => sumHi i + attnBias i) +private theorem sum_weighted_const {seq : Nat} (w : Fin seq → Real) (c : Real) + (hsum : ∑ k, w k = 1) : + ∑ k, w k * c = c := by + calc + ∑ k, w k * c = (∑ k, w k) * c := by + simpa using + (Finset.sum_mul (s := (Finset.univ : Finset (Fin seq))) (f := w) (a := c)).symm + _ = c := by simp [hsum] + +/-- Weighted dot-products preserve interval bounds. 
-/ +theorem dotProduct_bounds_of_weights {seq : Nat} {lo hi : Real} + {vals w : Fin seq → Real} + (hlo : ∀ k, lo ≤ vals k) (hhi : ∀ k, vals k ≤ hi) + (hnonneg : ∀ k, 0 ≤ w k) (hsum : ∑ k, w k = 1) : + lo ≤ dotProduct w vals ∧ dotProduct w vals ≤ hi := by + have hsum_lo : ∑ k, w k * lo ≤ ∑ k, w k * vals k := by + refine Finset.sum_le_sum ?_ + intro k _ + exact mul_le_mul_of_nonneg_left (hlo k) (hnonneg k) + have hsum_lo' : ∑ k, w k * lo = lo := sum_weighted_const w lo hsum + have hlow : lo ≤ dotProduct w vals := by + have hsum_le : lo ≤ ∑ k, w k * vals k := by + simpa [hsum_lo'] using hsum_lo + simpa [dotProduct] using hsum_le + have hsum_hi : ∑ k, w k * vals k ≤ ∑ k, w k * hi := by + refine Finset.sum_le_sum ?_ + intro k _ + exact mul_le_mul_of_nonneg_left (hhi k) (hnonneg k) + have hsum_hi' : ∑ k, w k * hi = hi := sum_weighted_const w hi hsum + have hhigh : dotProduct w vals ≤ hi := by + have hsum_le : ∑ k, w k * vals k ≤ hi := by + simpa [hsum_hi'] using hsum_hi + simpa [dotProduct] using hsum_le + exact ⟨hlow, hhigh⟩ + /-- `attentionOutputBounds` soundness for real attention outputs. 
-/ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) @@ -155,52 +190,15 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq have hln := hln_bounds k have hlo' : ∀ j, (lnLo j : Real) ≤ lnOut k j := fun j => (hln j).1 have hhi' : ∀ j, lnOut k j ≤ (lnHi j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => (heads h).wv j d) - (lo := lnLo) (hi := lnHi) (x := lnOut k) hlo' hhi' - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => (heads h).wv j d) - (lo := lnLo) (hi := lnHi) (x := lnOut k) hlo' hhi' - have hlow' := add_le_add_right hlow ((heads h).bv d : Real) - have hhigh' := add_le_add_right hhigh ((heads h).bv d : Real) + have hlow' := + dotIntervalLower_le_dotProduct_real_add (v := fun j => (heads h).wv j d) + (lo := lnLo) (hi := lnHi) (x := lnOut k) (b := ((heads h).bv d : Real)) hlo' hhi' + have hhigh' := + dotProduct_le_dotIntervalUpper_real_add (v := fun j => (heads h).wv j d) + (lo := lnLo) (hi := lnHi) (x := lnOut k) (b := ((heads h).bv d : Real)) hlo' hhi' constructor · simpa [headValue, vLo] using hlow' · simpa [headValue, vHi] using hhigh' - have weighted_bounds : - ∀ {lo hi : Real} {vals : Fin seq → Real} {w : Fin seq → Real}, - (∀ k, lo ≤ vals k) → (∀ k, vals k ≤ hi) → - (∀ k, 0 ≤ w k) → (∑ k, w k = 1) → - lo ≤ dotProduct w vals ∧ dotProduct w vals ≤ hi := by - intro lo hi vals w hlo' hhi' hnonneg hsum - have hsum_lo : ∑ k, w k * lo ≤ ∑ k, w k * vals k := by - refine Finset.sum_le_sum ?_ - intro k _ - exact mul_le_mul_of_nonneg_left (hlo' k) (hnonneg k) - have hsum_lo' : ∑ k, w k * lo = lo := by - calc - ∑ k, w k * lo = (∑ k, w k) * lo := by - simpa using - (Finset.sum_mul (s := (Finset.univ : Finset (Fin seq))) (f := w) (a := lo)).symm - _ = lo := by simp [hsum] - have hlow : lo ≤ dotProduct w vals := by - have hsum_le : lo ≤ ∑ k, w k * vals k := by - simpa [hsum_lo'] using hsum_lo 
- simpa [dotProduct] using hsum_le - have hsum_hi : ∑ k, w k * vals k ≤ ∑ k, w k * hi := by - refine Finset.sum_le_sum ?_ - intro k _ - exact mul_le_mul_of_nonneg_left (hhi' k) (hnonneg k) - have hsum_hi' : ∑ k, w k * hi = hi := by - calc - ∑ k, w k * hi = (∑ k, w k) * hi := by - simpa using - (Finset.sum_mul (s := (Finset.univ : Finset (Fin seq))) (f := w) (a := hi)).symm - _ = hi := by simp [hsum] - have hhigh : dotProduct w vals ≤ hi := by - have hsum_le : ∑ k, w k * vals k ≤ hi := by - simpa [hsum_hi'] using hsum_hi - simpa [dotProduct] using hsum_le - exact ⟨hlow, hhigh⟩ have hhead_output_bounds : ∀ h q d, (vLo h d : Real) ≤ headOutput h q d ∧ @@ -216,7 +214,7 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq exact Circuit.softmax_nonneg (scores h q) k have hsum : ∑ k, headWeights h q k = 1 := by simpa [headWeights] using Circuit.softmax_sum_one (scores h q) - have h := weighted_bounds (lo := (vLo h d : Real)) (hi := (vHi h d : Real)) + have h := dotProduct_bounds_of_weights (lo := (vLo h d : Real)) (hi := (vHi h d : Real)) (vals := fun k => headValue h k d) (w := headWeights h q) hlo' hhi' hnonneg hsum simpa [headOutput] using h @@ -257,13 +255,13 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq have hlow : (sumLo i : Real) + (attnBias i : Real) ≤ (∑ h, headProj h q i) + (attnBias i : Real) := by - have h := add_le_add_left hsum_bounds.1 (attnBias i : Real) - simpa [add_comm, add_left_comm, add_assoc] using h + simpa [add_comm] using + add_le_add_left hsum_bounds.1 (attnBias i : Real) have hhigh : (∑ h, headProj h q i) + (attnBias i : Real) ≤ (sumHi i : Real) + (attnBias i : Real) := by - have h := add_le_add_left hsum_bounds.2 (attnBias i : Real) - simpa [add_comm, add_left_comm, add_assoc] using h + simpa [add_comm] using + add_le_add_left hsum_bounds.2 (attnBias i : Real) have hreal : attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i = (∑ h, headProj h q i) + (attnBias 
i : Real) := by diff --git a/Nfp/Sound/Bounds/Cache.lean b/Nfp/Sound/Bounds/Cache.lean index 5084089..7e21f0b 100644 --- a/Nfp/Sound/Bounds/Cache.lean +++ b/Nfp/Sound/Bounds/Cache.lean @@ -16,9 +16,7 @@ namespace Bounds def cacheBound {n : Nat} (f : Fin n → Rat) : Fin n → Rat := let data : Thunk (Array Rat) := Thunk.mk (fun _ => Array.ofFn f) fun i => (Thunk.get data)[i.1]'(by - have hsize : (Thunk.get data).size = n := by - simp [Thunk.get, data] - simp [hsize]) + simp [Thunk.get, data, i.isLt]) /-- `cacheBound` preserves pointwise values. -/ theorem cacheBound_apply {n : Nat} (f : Fin n → Rat) (i : Fin n) : @@ -44,10 +42,8 @@ def cacheBoundTask {n : Nat} (f : Fin n → Rat) : Fin n → Rat := Array.ofFn (fun i : Fin n => Task.spawn (fun _ => f i)) fun i => - let hsize : tasks.size = n := by - simp [tasks] let t := tasks[i.1]'(by - simp [hsize, i.isLt]) + simp [tasks, i.isLt]) t.get /-- `cacheBoundTask` preserves pointwise values. -/ @@ -61,17 +57,12 @@ def cacheBound2 {m n : Nat} (f : Fin m → Fin n → Rat) : Fin m → Fin n → let data : Thunk (Array (Thunk (Array Rat))) := Thunk.mk (fun _ => Array.ofFn (fun q => Thunk.mk (fun _ => Array.ofFn (f q)))) fun q i => - let rowThunk := (Thunk.get data)[q.1]'(by - have hsize : (Thunk.get data).size = m := by - simp [Thunk.get, data] - rw [hsize] - exact q.isLt) + let rows := Thunk.get data + let rowThunk := rows[q.1]'(by + simp [rows, Thunk.get, data, q.isLt]) let row := Thunk.get rowThunk row[i.1]'(by - have hsize : row.size = n := by - simp [row, rowThunk, Thunk.get, data, Array.getElem_ofFn] - rw [hsize] - exact i.isLt) + simp [row, rowThunk, rows, Thunk.get, data, i.isLt]) /-- `cacheBound2` preserves pointwise values. 
-/ theorem cacheBound2_apply {m n : Nat} (f : Fin m → Fin n → Rat) (q : Fin m) (i : Fin n) : @@ -105,9 +96,7 @@ def cacheBound2TaskElem {m n : Nat} (f : Fin m → Fin n → Rat) : Fin m → Fi let row := (rowTasks[q.1]'(by simp [rowTasks, q.isLt])) let t := row[i.1]'(by - have hsize : row.size = n := by - simp [row, rowTasks] - simp [hsize, i.isLt]) + simp [row, rowTasks, i.isLt]) t.get /-- `cacheBound2TaskElem` preserves pointwise values. -/ @@ -125,29 +114,19 @@ def cacheBoundPair2 {m n : Nat} let row := f q (Array.ofFn row.1, Array.ofFn row.2))) let lo : Fin m → Fin n → Rat := fun q i => - let row := (Thunk.get data)[q.1]'(by - have hsize : (Thunk.get data).size = m := by - simp [Thunk.get, data] - rw [hsize] - exact q.isLt) + let rows := Thunk.get data + let row := rows[q.1]'(by + simp [rows, Thunk.get, data, q.isLt]) let loRow := row.1 loRow[i.1]'(by - have hsize : loRow.size = n := by - simp [loRow, row, Thunk.get, data, Array.getElem_ofFn] - rw [hsize] - exact i.isLt) + simp [loRow, row, rows, Thunk.get, data, i.isLt]) let hi : Fin m → Fin n → Rat := fun q i => - let row := (Thunk.get data)[q.1]'(by - have hsize : (Thunk.get data).size = m := by - simp [Thunk.get, data] - rw [hsize] - exact q.isLt) + let rows := Thunk.get data + let row := rows[q.1]'(by + simp [rows, Thunk.get, data, q.isLt]) let hiRow := row.2 hiRow[i.1]'(by - have hsize : hiRow.size = n := by - simp [hiRow, row, Thunk.get, data, Array.getElem_ofFn] - rw [hsize] - exact i.isLt) + simp [hiRow, row, rows, Thunk.get, data, i.isLt]) (lo, hi) /-- `cacheBoundPair2` preserves pointwise values (first component). 
-/ @@ -210,17 +189,13 @@ def cacheBoundPair2TaskElem {m n : Nat} (f : Fin m → Fin n → Rat × Rat) : let row := (rowTasks[q.1]'(by simp [rowTasks, q.isLt])) let t := row[i.1]'(by - have hsize : row.size = n := by - simp [row, rowTasks] - simp [hsize, i.isLt]) + simp [row, rowTasks, i.isLt]) (t.get).1 let hi : Fin m → Fin n → Rat := fun q i => let row := (rowTasks[q.1]'(by simp [rowTasks, q.isLt])) let t := row[i.1]'(by - have hsize : row.size = n := by - simp [row, rowTasks] - simp [hsize, i.isLt]) + simp [row, rowTasks, i.isLt]) (t.get).2 (lo, hi) diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean index adf58df..5fdeb93 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -120,6 +120,21 @@ theorem geluInterval_bounds {lo hi : Rat} {x : Real} (geluInterval lo hi).1 ≤ (geluTanh x : Real) ∧ (geluTanh x : Real) ≤ (geluInterval lo hi).2 := by have hgelu := geluTanh_bounds x + have hupper : geluTanh x ≤ (geluInterval lo hi).2 := by + have hmax : geluTanh x ≤ max (hi : Real) 0 := by + have hmax' : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl + exact le_trans hgelu.2 hmax' + by_cases hhi0 : 0 ≤ hi + · have hhi0r : 0 ≤ (hi : Real) := by + exact ratToReal_nonneg_of_nonneg hhi0 + simpa [geluInterval, hhi0, max_eq_left hhi0r] using hmax + · have hhi0r : (hi : Real) ≤ 0 := by + exact (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) + have hx0 : x ≤ 0 := le_trans hhi hhi0r + have hmax' : max x 0 = 0 := max_eq_right hx0 + have hhi'' : geluTanh x ≤ (0 : Real) := by + simpa [hmax'] using hgelu.2 + simpa [geluInterval, hhi0, ratToReal_zero] using hhi'' by_cases hlo0 : lo ≤ 0 · have hlo0r : (lo : Real) ≤ 0 := by exact (ratToReal_nonpos_iff (x := lo)).2 hlo0 @@ -128,42 +143,18 @@ theorem geluInterval_bounds {lo hi : Rat} {x : Real} have hmin' : (lo : Real) ≤ min x 0 := by simpa [min_eq_left hlo0r] using hmin exact le_trans hmin' hgelu.1 - have hmax : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl - have hhi' : geluTanh x ≤ 
max (hi : Real) 0 := le_trans hgelu.2 hmax constructor · simpa [geluInterval, hlo0] using hlo' - · by_cases hhi0 : 0 ≤ hi - · have hhi0r : 0 ≤ (hi : Real) := by - exact ratToReal_nonneg_of_nonneg hhi0 - simpa [geluInterval, hhi0, max_eq_left hhi0r] using hhi' - · have hhi0r : (hi : Real) ≤ 0 := by - exact (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) - have hx0 : x ≤ 0 := le_trans hhi hhi0r - have hmax' : max x 0 = 0 := max_eq_right hx0 - have hhi'' : geluTanh x ≤ (0 : Real) := by - simpa [hmax'] using hgelu.2 - simpa [geluInterval, hhi0, ratToReal_zero] using hhi'' + · exact hupper · have hlo0r : 0 ≤ (lo : Real) := by exact ratToReal_nonneg_of_nonneg (le_of_not_ge hlo0) have hx0 : 0 ≤ x := le_trans hlo0r hlo have hmin' : min x 0 = 0 := min_eq_right hx0 have hlo' : (0 : Real) ≤ geluTanh x := by simpa [hmin'] using hgelu.1 - have hmax : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl - have hhi' : geluTanh x ≤ max (hi : Real) 0 := le_trans hgelu.2 hmax constructor · simpa [geluInterval, hlo0, ratToReal_zero] using hlo' - · by_cases hhi0 : 0 ≤ hi - · have hhi0r : 0 ≤ (hi : Real) := by - exact ratToReal_nonneg_of_nonneg hhi0 - simpa [geluInterval, hhi0, max_eq_left hhi0r] using hhi' - · have hhi0r : (hi : Real) ≤ 0 := by - exact (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) - have hx0' : x ≤ 0 := le_trans hhi hhi0r - have hmax' : max x 0 = 0 := max_eq_right hx0' - have hhi'' : geluTanh x ≤ (0 : Real) := by - simpa [hmax'] using hgelu.2 - simpa [geluInterval, hhi0, ratToReal_zero] using hhi'' + · exact hupper end Bounds diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index e9a5d10..f5e11ae 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -5,12 +5,12 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Algebra.Order.BigOperators.Group.Finset import Mathlib.Algebra.Order.Field.Basic import Mathlib.Algebra.Order.Ring.Basic -import Mathlib.Data.Nat.Sqrt import 
Mathlib.Data.Real.Sqrt import Mathlib.Data.Rat.BigOperators import Mathlib.Data.Rat.Cast.Order import Nfp.Core.Basic import Nfp.Sound.Bounds.LayerNorm.MeanVariance +import Nfp.Sound.Bounds.LayerNorm.SqrtBounds import Nfp.Sound.Linear.FinFold /-! @@ -28,651 +28,6 @@ namespace Bounds open scoped BigOperators -/-! Square-root bounds. -/ - -lemma rat_nat_cast_nonneg (n : Nat) : (0 : Rat) ≤ (n : Rat) := by - simp - -lemma rat_nat_cast_pos {n : Nat} (h : 0 < n) : (0 : Rat) < (n : Rat) := by - exact (Nat.cast_pos (α := Rat)).2 h - -/-- Base rational lower bound for a square root. -/ -def sqrtLowerBase (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den - let a := Nat.sqrt num - let b := Nat.sqrt den - ratRoundDown ((a : Rat) / (b + 1)) - -/-- Base rational upper bound for a square root. -/ -def sqrtUpperBase (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den - let a := Nat.sqrt num - let b := Nat.sqrt den - ratRoundUp ((a + 1 : Rat) / b) - -/-- Alternate rational lower bound for a square root. -/ -def sqrtLowerAlt (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den - let a := Nat.sqrt (num * den) - ratRoundDown ((a : Rat) / den) - -/-- Alternate rational upper bound for a square root. -/ -def sqrtUpperAlt (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den - let a := Nat.sqrt (num * den) - ratRoundUp ((a + 1 : Rat) / den) - -/-- Extra precision scale for `sqrtLowerScaled`. -/ -def sqrtLowerScale : Nat := 1048576 - -/-- Scaled rational lower bound for a square root (extra precision). -/ -def sqrtLowerScaled (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den - let scale := sqrtLowerScale - let a := Nat.sqrt (num * den * scale * scale) - ratRoundDown ((a : Rat) / (den * scale)) - -/-- Scaled rational upper bound for a square root (extra precision). 
-/ -def sqrtUpperScaled (q : Rat) : Rat := - let num := q.num.natAbs - let den := q.den - let scale := sqrtLowerScale - let a := Nat.sqrt (num * den * scale * scale) - ratRoundUp ((a + 1 : Rat) / (den * scale)) - -/-- Rational lower bound for a square root (tighter of three bounds). -/ -def sqrtLower (q : Rat) : Rat := - max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) - -/-- Rational upper bound for a square root (tighter of three bounds). -/ -def sqrtUpper (q : Rat) : Rat := - min (min (sqrtUpperBase q) (sqrtUpperAlt q)) (sqrtUpperScaled q) - -/-- `sqrtLowerBase` is nonnegative. -/ -theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by - classical - unfold sqrtLowerBase - have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt q.num.natAbs)) - have hden : 0 ≤ (Nat.sqrt q.den + 1 : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt q.den + 1)) - have hrat : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) / (Nat.sqrt q.den + 1) := by - exact div_nonneg hnum hden - exact ratRoundDown_nonneg hrat - -/-! Strict positivity helpers. -/ - -/-! Base bounds. -/ - - -/-- `sqrtUpperBase` is nonnegative. -/ -theorem sqrtUpperBase_nonneg (q : Rat) : 0 ≤ sqrtUpperBase q := by - classical - unfold sqrtUpperBase - have hnum : 0 ≤ (Nat.sqrt q.num.natAbs + 1 : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt q.num.natAbs + 1)) - have hden : 0 ≤ (Nat.sqrt q.den : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt q.den)) - have hrat : - 0 ≤ (Nat.sqrt q.num.natAbs + 1 : Rat) / (Nat.sqrt q.den) := by - exact div_nonneg hnum hden - exact ratRoundUp_nonneg hrat - -/-- `sqrtUpperBase` is always positive. 
-/ -theorem sqrtUpperBase_pos (q : Rat) : 0 < sqrtUpperBase q := by - classical - unfold sqrtUpperBase - have hnum_pos : (0 : Rat) < (Nat.sqrt q.num.natAbs + 1 : Rat) := by - exact_mod_cast (Nat.succ_pos (Nat.sqrt q.num.natAbs)) - have hden_pos : (0 : Rat) < (Nat.sqrt q.den : Rat) := by - have hden : 0 < q.den := q.den_pos - exact_mod_cast (Nat.sqrt_pos.2 hden) - have hrat_pos : - (0 : Rat) < (Nat.sqrt q.num.natAbs + 1 : Rat) / (Nat.sqrt q.den) := by - exact div_pos hnum_pos hden_pos - exact ratRoundUp_pos hrat_pos - -/-! Alternate bounds. -/ - -/-- `sqrtLowerAlt` is nonnegative. -/ -theorem sqrtLowerAlt_nonneg (q : Rat) : 0 ≤ sqrtLowerAlt q := by - classical - unfold sqrtLowerAlt - have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt (q.num.natAbs * q.den))) - have hden : 0 ≤ (q.den : Rat) := by - exact_mod_cast (Nat.zero_le q.den) - have hrat : - 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) / q.den := by - exact div_nonneg hnum hden - exact ratRoundDown_nonneg hrat - - -/-- `sqrtUpperAlt` is nonnegative. -/ -theorem sqrtUpperAlt_nonneg (q : Rat) : 0 ≤ sqrtUpperAlt q := by - classical - unfold sqrtUpperAlt - have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) := by - exact_mod_cast (Nat.zero_le (Nat.sqrt (q.num.natAbs * q.den) + 1)) - have hden : 0 ≤ (q.den : Rat) := by - exact_mod_cast (Nat.zero_le q.den) - have hrat : - 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) / q.den := by - exact div_nonneg hnum hden - exact ratRoundUp_nonneg hrat - -/-- `sqrtUpperAlt` is always positive. 
-/ -theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by - classical - unfold sqrtUpperAlt - have hnum_pos : - (0 : Rat) < (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) := by - exact_mod_cast (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den))) - have hden_pos : (0 : Rat) < (q.den : Rat) := by - exact_mod_cast q.den_pos - have hrat_pos : - (0 : Rat) < - (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) / q.den := by - exact div_pos hnum_pos hden_pos - exact ratRoundUp_pos hrat_pos - -/-- `sqrtUpperScaled` is nonnegative. -/ -theorem sqrtUpperScaled_nonneg (q : Rat) : 0 ≤ sqrtUpperScaled q := by - classical - unfold sqrtUpperScaled - have hnum : - 0 ≤ (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) := by - exact_mod_cast - (Nat.zero_le (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1)) - have hden : 0 ≤ (q.den * sqrtLowerScale : Rat) := by - exact_mod_cast (Nat.zero_le (q.den * sqrtLowerScale)) - have hrat : - 0 ≤ (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) / - (q.den * sqrtLowerScale) := by - exact div_nonneg hnum hden - exact ratRoundUp_nonneg hrat - -/-- `sqrtUpperScaled` is always positive. -/ -theorem sqrtUpperScaled_pos (q : Rat) : 0 < sqrtUpperScaled q := by - classical - unfold sqrtUpperScaled - have hnum_pos : - (0 : Rat) < - (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) := by - exact_mod_cast - (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale))) - have hden_pos : (0 : Rat) < (q.den * sqrtLowerScale : Rat) := by - have hden : 0 < q.den := q.den_pos - have hscale : 0 < sqrtLowerScale := by - simp [sqrtLowerScale] - exact_mod_cast (Nat.mul_pos hden hscale) - have hrat_pos : - (0 : Rat) < - (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) / - (q.den * sqrtLowerScale) := by - exact div_pos hnum_pos hden_pos - exact ratRoundUp_pos hrat_pos - -/-! Combined bounds. 
-/ - -/-- `sqrtLower` is nonnegative. -/ -theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by - have hbase : 0 ≤ sqrtLowerBase q := sqrtLowerBase_nonneg q - have hmax : 0 ≤ max (sqrtLowerBase q) (sqrtLowerAlt q) := - le_trans hbase (le_max_left _ _) - exact le_trans hmax (le_max_left _ _) - - -/-- `sqrtUpper` is nonnegative. -/ -theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by - have hbase : 0 ≤ sqrtUpperBase q := sqrtUpperBase_nonneg q - have halt : 0 ≤ sqrtUpperAlt q := sqrtUpperAlt_nonneg q - have hscaled : 0 ≤ sqrtUpperScaled q := sqrtUpperScaled_nonneg q - have hmin1 : 0 ≤ min (sqrtUpperBase q) (sqrtUpperAlt q) := by - exact le_min hbase halt - exact le_min hmin1 hscaled - -/-- `sqrtUpper` is always positive. -/ -theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by - have hbase : 0 < sqrtUpperBase q := sqrtUpperBase_pos q - have halt : 0 < sqrtUpperAlt q := sqrtUpperAlt_pos q - have hscaled : 0 < sqrtUpperScaled q := sqrtUpperScaled_pos q - have hmin1 : 0 < min (sqrtUpperBase q) (sqrtUpperAlt q) := by - exact lt_min hbase halt - exact lt_min hmin1 hscaled - -/-- Square-root lower bound in reals. -/ -theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : - (sqrtLowerBase q : Real) ≤ Real.sqrt (q : Real) := by - classical - -- Set up numerator/denominator witnesses. 
- set num : Nat := q.num.natAbs - set den : Nat := q.den - set a : Nat := Nat.sqrt num - set b : Nat := Nat.sqrt den - have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos - have hbpos : 0 < (b + 1 : Real) := by - exact_mod_cast (Nat.succ_pos b) - have hnum_le : (a ^ 2 : Real) ≤ num := by - exact_mod_cast (Nat.sqrt_le' num) - have hden_le : (den : Real) ≤ (b + 1) ^ 2 := by - exact_mod_cast (le_of_lt (Nat.lt_succ_sqrt' den)) - have hmul : (a ^ 2 : Real) * den ≤ (num : Real) * (b + 1) ^ 2 := by - have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) - have hnum_nonneg : 0 ≤ (num : Real) := by exact_mod_cast (Nat.zero_le num) - exact mul_le_mul hnum_le hden_le hden_nonneg hnum_nonneg - have hbpos2 : 0 < (b + 1 : Real) ^ 2 := by - nlinarith [hbpos] - have hdiv : (a ^ 2 : Real) / (b + 1) ^ 2 ≤ (num : Real) / den := by - exact (div_le_div_iff₀ hbpos2 hden_pos).2 hmul - have hpow : ((a : Real) / (b + 1 : Real)) ^ 2 = (a ^ 2 : Real) / (b + 1) ^ 2 := by - simp [pow_two, div_mul_div_comm] - have hq_cast : (q : Real) = (num : Real) / den := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by - simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by - exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den := by - simp [Rat.cast_def] - simpa [hnum_cast, den] using hq_rat - have hsq : ((a : Real) / (b + 1 : Real)) ^ 2 ≤ (q : Real) := by - simpa [hpow, hq_cast, den, num] using hdiv - have hnonneg : 0 ≤ (a : Real) / (b + 1 : Real) := by - have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) - have hden_nonneg : 0 ≤ (b + 1 : Real) := by exact_mod_cast (Nat.zero_le (b + 1)) - exact div_nonneg hnum_nonneg hden_nonneg - have hq_nonneg : 0 ≤ (q : Real) := by - exact ratToReal_nonneg_of_nonneg hq - have hle : (a : Real) / (b + 1 : Real) ≤ Real.sqrt (q : Real) := - (Real.le_sqrt hnonneg 
hq_nonneg).2 hsq - have hdown : - (sqrtLowerBase q : Real) ≤ (a : Real) / (b + 1 : Real) := by - have hdown' : - ratToReal (ratRoundDown ((a : Rat) / (b + 1))) ≤ - (a : Real) / (b + 1 : Real) := by - simpa using ratRoundDown_le_real ((a : Rat) / (b + 1)) - simpa [sqrtLowerBase, num, den, a, b] using hdown' - exact le_trans hdown hle - -/-- Square-root upper bound in reals. -/ -theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : - Real.sqrt (q : Real) ≤ (sqrtUpperBase q : Real) := by - classical - set num : Nat := q.num.natAbs - set den : Nat := q.den - set a : Nat := Nat.sqrt num - set b : Nat := Nat.sqrt den - have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos - have hbpos : 0 < (b : Real) := by - have hb : 0 < b := by - have hden : 0 < den := q.den_pos - exact (Nat.sqrt_pos).2 hden - exact_mod_cast hb - have hnum_lt : (num : Real) < (a + 1) ^ 2 := by - exact_mod_cast (Nat.lt_succ_sqrt' num) - have hden_le : (b ^ 2 : Real) ≤ den := by - exact_mod_cast (Nat.sqrt_le' den) - have hmul : (num : Real) * (b ^ 2) ≤ (a + 1) ^ 2 * den := by - have hb2_nonneg : 0 ≤ (b ^ 2 : Real) := by - exact sq_nonneg (b : Real) - have hsq_nonneg : 0 ≤ (a + 1 : Real) ^ 2 := by - exact sq_nonneg (a + 1 : Real) - exact mul_le_mul (le_of_lt hnum_lt) hden_le hb2_nonneg hsq_nonneg - have hbpos2 : 0 < (b : Real) ^ 2 := by - nlinarith [hbpos] - have hdiv : (num : Real) / den ≤ (a + 1) ^ 2 / (b : Real) ^ 2 := by - exact (div_le_div_iff₀ hden_pos hbpos2).2 hmul - have hpow : ((a + 1 : Real) / (b : Real)) ^ 2 = (a + 1) ^ 2 / (b : Real) ^ 2 := by - simp [pow_two, div_mul_div_comm] - have hq_cast : (q : Real) = (num : Real) / den := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by - simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by - exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den := by - simp [Rat.cast_def] - simpa 
[hnum_cast, den] using hq_rat - have hsq : (q : Real) ≤ ((a + 1 : Real) / (b : Real)) ^ 2 := by - simpa [hpow, hq_cast, den, num] using hdiv - have hnonneg : 0 ≤ ((a + 1 : Real) / (b : Real)) := by - have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) - have hden_nonneg : 0 ≤ (b : Real) := by exact_mod_cast (Nat.zero_le b) - exact div_nonneg hnum_nonneg hden_nonneg - have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (b : Real) := - (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ - have hup : - (a + 1 : Real) / (b : Real) ≤ (sqrtUpperBase q : Real) := by - have hup' : - (a + 1 : Real) / (b : Real) ≤ - ratToReal (ratRoundUp ((a + 1 : Rat) / b)) := by - simpa using real_le_ratRoundUp ((a + 1 : Rat) / b) - simpa [sqrtUpperBase, num, den, a, b] using hup' - exact le_trans hle hup - -/-- Alternate square-root lower bound in reals. -/ -theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : - (sqrtLowerAlt q : Real) ≤ Real.sqrt (q : Real) := by - classical - set num : Nat := q.num.natAbs - set den : Nat := q.den - set a : Nat := Nat.sqrt (num * den) - have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos - have hnumden_le : (a ^ 2 : Real) ≤ (num * den : Nat) := by - exact_mod_cast (Nat.sqrt_le' (num * den)) - have hmul : (a ^ 2 : Real) ≤ (num : Real) * den := by - simpa [num, den, Nat.cast_mul] using hnumden_le - have hden_pos2 : 0 < (den : Real) ^ 2 := by - nlinarith [hden_pos] - have hdiv : - (a ^ 2 : Real) / (den : Real) ^ 2 ≤ (num : Real) * den / (den : Real) ^ 2 := by - have hmul' : - (a ^ 2 : Real) * (den : Real) ^ 2 ≤ (num : Real) * den * (den : Real) ^ 2 := by - have hden_sq_nonneg : 0 ≤ (den : Real) ^ 2 := by - exact sq_nonneg (den : Real) - exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg - exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' - have hden_ne : (den : Real) ≠ 0 := by - exact_mod_cast q.den_pos.ne' - have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by - have hnum_nonneg : 0 ≤ q.num := by - 
exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by - simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by - exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den := by - simp [Rat.cast_def] - have hq_eq : - (num : Real) / den = (num : Real) * den / (den : Real) ^ 2 := by - field_simp [hden_ne] - simpa [hnum_cast, den, hq_eq] using hq_rat - have hsq : ((a : Real) / (den : Real)) ^ 2 ≤ (q : Real) := by - simpa [hq_cast, pow_two, div_mul_div_comm] using hdiv - have hnonneg : 0 ≤ (a : Real) / (den : Real) := by - have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) - have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) - exact div_nonneg hnum_nonneg hden_nonneg - have hq_nonneg : 0 ≤ (q : Real) := by - exact ratToReal_nonneg_of_nonneg hq - have hle : (a : Real) / (den : Real) ≤ Real.sqrt (q : Real) := - (Real.le_sqrt hnonneg hq_nonneg).2 hsq - have hdown : - (sqrtLowerAlt q : Real) ≤ (a : Real) / (den : Real) := by - have hdown' : - ratToReal (ratRoundDown ((a : Rat) / den)) ≤ - (a : Real) / (den : Real) := by - simpa using ratRoundDown_le_real ((a : Rat) / den) - simpa [sqrtLowerAlt, num, den, a] using hdown' - exact le_trans hdown hle - -/-- Scaled square-root lower bound in reals. 
-/ -theorem sqrtLowerScaled_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : - (sqrtLowerScaled q : Real) ≤ Real.sqrt (q : Real) := by - classical - set num : Nat := q.num.natAbs - set den : Nat := q.den - set scale : Nat := sqrtLowerScale - set a : Nat := Nat.sqrt (num * den * scale * scale) - have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos - have hscale_pos : 0 < (scale : Real) := by - have hscale_pos_nat : 0 < scale := by - simp [scale, sqrtLowerScale] - exact_mod_cast hscale_pos_nat - have hnumden_le : (a ^ 2 : Real) ≤ (num * den * scale * scale : Nat) := by - exact_mod_cast (Nat.sqrt_le' (num * den * scale * scale)) - have hmul : - (a ^ 2 : Real) ≤ (num : Real) * den * (scale : Real) * (scale : Real) := by - simpa [num, den, scale, Nat.cast_mul, mul_assoc, mul_left_comm, mul_comm] using hnumden_le - have hdenScale_pos : 0 < (den : Real) * (scale : Real) := - mul_pos hden_pos hscale_pos - have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by - exact pow_pos hdenScale_pos 2 - have hmul' : - (a ^ 2 : Real) * ((den : Real) * (scale : Real)) ^ 2 ≤ - ((num : Real) * den * (scale : Real) * (scale : Real)) * - ((den : Real) * (scale : Real)) ^ 2 := by - have hnonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by - exact sq_nonneg _ - exact mul_le_mul_of_nonneg_right hmul hnonneg - have hdiv : - (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 ≤ - ((num : Real) * den * (scale : Real) * (scale : Real)) / - ((den : Real) * (scale : Real)) ^ 2 := by - exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' - have hdenScale_ne : ((den : Real) * (scale : Real)) ≠ 0 := - ne_of_gt hdenScale_pos - have hq_cast : (q : Real) = (num : Real) / den := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by - simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by - exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : 
Real) / q.den := by - simp [Rat.cast_def] - simpa [hnum_cast, den] using hq_rat - have hq_eq : - ((num : Real) * den * (scale : Real) * (scale : Real)) / - ((den : Real) * (scale : Real)) ^ 2 = (num : Real) / den := by - field_simp [hdenScale_ne] - have hpow : - ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 = - (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := by - simp [pow_two, div_mul_div_comm] - have hsq : - ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 ≤ (q : Real) := by - calc - ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 - = (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := hpow - _ ≤ ((num : Real) * den * (scale : Real) * (scale : Real)) / - ((den : Real) * (scale : Real)) ^ 2 := hdiv - _ = (num : Real) / den := hq_eq - _ = (q : Real) := by simp [hq_cast] - have hnonneg : 0 ≤ (a : Real) / ((den : Real) * (scale : Real)) := by - have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) - have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by - nlinarith [hden_pos, hscale_pos] - exact div_nonneg hnum_nonneg hden_nonneg - have hq_nonneg : 0 ≤ (q : Real) := by - exact ratToReal_nonneg_of_nonneg hq - have hle : - (a : Real) / ((den : Real) * (scale : Real)) ≤ Real.sqrt (q : Real) := - (Real.le_sqrt hnonneg hq_nonneg).2 hsq - have hdown : - (sqrtLowerScaled q : Real) ≤ (a : Real) / ((den : Real) * (scale : Real)) := by - have hdown' : - ratToReal (ratRoundDown ((a : Rat) / (den * scale))) ≤ - (a : Real) / ((den : Real) * (scale : Real)) := by - simpa using ratRoundDown_le_real ((a : Rat) / (den * scale)) - simpa [sqrtLowerScaled, num, den, scale, a] using hdown' - exact le_trans hdown hle - -/-- Alternate square-root upper bound in reals. 
-/ -theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : - Real.sqrt (q : Real) ≤ (sqrtUpperAlt q : Real) := by - classical - set num : Nat := q.num.natAbs - set den : Nat := q.den - set a : Nat := Nat.sqrt (num * den) - have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos - have hnumden_lt : (num * den : Real) < (a + 1) ^ 2 := by - exact_mod_cast (Nat.lt_succ_sqrt' (num * den)) - have hmul : (num : Real) * den ≤ (a + 1 : Real) ^ 2 := by - exact le_of_lt hnumden_lt - have hden_pos2 : 0 < (den : Real) ^ 2 := by - nlinarith [hden_pos] - have hdiv : - (num : Real) * den / (den : Real) ^ 2 ≤ (a + 1 : Real) ^ 2 / (den : Real) ^ 2 := by - have hmul' : - (num : Real) * den * (den : Real) ^ 2 ≤ (a + 1 : Real) ^ 2 * (den : Real) ^ 2 := by - have hden_sq_nonneg : 0 ≤ (den : Real) ^ 2 := by - exact sq_nonneg (den : Real) - exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg - exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' - have hden_ne : (den : Real) ≠ 0 := by - exact_mod_cast q.den_pos.ne' - have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by - simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by - exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den := by - simp [Rat.cast_def] - have hq_eq : - (num : Real) / den = (num : Real) * den / (den : Real) ^ 2 := by - field_simp [hden_ne] - simpa [hnum_cast, den, hq_eq] using hq_rat - have hpow : - ((a + 1 : Real) / (den : Real)) ^ 2 = - (a + 1 : Real) ^ 2 / (den : Real) ^ 2 := by - simp [pow_two, div_mul_div_comm] - have hsq : (q : Real) ≤ ((a + 1 : Real) / (den : Real)) ^ 2 := by - simpa [hq_cast, hpow] using hdiv - have hnonneg : 0 ≤ ((a + 1 : Real) / (den : Real)) := by - have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) - have hden_nonneg : 0 ≤ (den : 
Real) := by exact_mod_cast (Nat.zero_le den) - exact div_nonneg hnum_nonneg hden_nonneg - have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (den : Real) := - (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ - have hup : - (a + 1 : Real) / (den : Real) ≤ (sqrtUpperAlt q : Real) := by - have hup' : - (a + 1 : Real) / (den : Real) ≤ - ratToReal (ratRoundUp ((a + 1 : Rat) / den)) := by - simpa using real_le_ratRoundUp ((a + 1 : Rat) / den) - simpa [sqrtUpperAlt, num, den, a] using hup' - exact le_trans hle hup - -/-- Scaled square-root upper bound in reals. -/ -theorem real_sqrt_le_sqrtUpperScaled {q : Rat} (hq : 0 ≤ q) : - Real.sqrt (q : Real) ≤ (sqrtUpperScaled q : Real) := by - classical - set num : Nat := q.num.natAbs - set den : Nat := q.den - set scale : Nat := sqrtLowerScale - set a : Nat := Nat.sqrt (num * den * scale * scale) - have hden_pos : 0 < (den : Real) := by - exact_mod_cast q.den_pos - have hscale_pos : 0 < (scale : Real) := by - have hscale_pos_nat : 0 < scale := by - simp [scale, sqrtLowerScale] - exact_mod_cast hscale_pos_nat - have hnumden_lt : (num * den * scale * scale : Real) < (a + 1) ^ 2 := by - exact_mod_cast (Nat.lt_succ_sqrt' (num * den * scale * scale)) - have hmul : - (num : Real) * den * (scale : Real) * (scale : Real) ≤ (a + 1 : Real) ^ 2 := by - exact le_of_lt hnumden_lt - have hdenScale_pos : 0 < (den : Real) * (scale : Real) := by - exact mul_pos hden_pos hscale_pos - have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by - exact pow_pos hdenScale_pos 2 - have hdiv : - (num : Real) * den * (scale : Real) * (scale : Real) / - ((den : Real) * (scale : Real)) ^ 2 ≤ - (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by - have hmul' : - (num : Real) * den * (scale : Real) * (scale : Real) * - ((den : Real) * (scale : Real)) ^ 2 ≤ - (a + 1 : Real) ^ 2 * ((den : Real) * (scale : Real)) ^ 2 := by - have hden_sq_nonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by - exact sq_nonneg _ - exact mul_le_mul_of_nonneg_right hmul 
hden_sq_nonneg - exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' - have hdenScale_ne : ((den : Real) * (scale : Real)) ≠ 0 := by - exact ne_of_gt hdenScale_pos - have hq_cast : (q : Real) = (num : Real) / den := by - have hnum_nonneg : 0 ≤ q.num := by - exact (Rat.num_nonneg (q := q)).2 hq - have hnum_eq : (num : Int) = q.num := by - simpa [num] using (Int.natAbs_of_nonneg hnum_nonneg) - have hnum_cast : (q.num : Real) = (num : Real) := by - exact_mod_cast hnum_eq.symm - have hq_rat : (q : Real) = (q.num : Real) / q.den := by - simp [Rat.cast_def] - simpa [hnum_cast, den] using hq_rat - have hq_eq : - ((num : Real) * den * (scale : Real) * (scale : Real)) / - ((den : Real) * (scale : Real)) ^ 2 = (num : Real) / den := by - field_simp [hdenScale_ne] - have hq_cast' : - (q : Real) = - ((num : Real) * den * (scale : Real) * (scale : Real)) / - ((den : Real) * (scale : Real)) ^ 2 := by - calc - (q : Real) = (num : Real) / den := hq_cast - _ = ((num : Real) * den * (scale : Real) * (scale : Real)) / - ((den : Real) * (scale : Real)) ^ 2 := hq_eq.symm - have hpow : - ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 = - (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by - simp [pow_two, div_mul_div_comm] - have hsq : - (q : Real) ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 := by - simpa [hq_cast', hpow] using hdiv - have hnonneg : 0 ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) := by - have hnum_nonneg : 0 ≤ (a + 1 : Real) := by - exact_mod_cast (Nat.zero_le (a + 1)) - have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by - nlinarith [hden_pos, hscale_pos] - exact div_nonneg hnum_nonneg hden_nonneg - have hle : - Real.sqrt (q : Real) ≤ (a + 1 : Real) / ((den : Real) * (scale : Real)) := - (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ - have hup : - (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ (sqrtUpperScaled q : Real) := by - have hup' : - (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ - ratToReal (ratRoundUp ((a + 
1 : Rat) / (den * scale))) := by - simpa using real_le_ratRoundUp ((a + 1 : Rat) / (den * scale)) - simpa [sqrtUpperScaled, num, den, scale, a] using hup' - exact le_trans hle hup - -/-- Square-root lower bound in reals (tighter of three bounds). -/ -theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : - (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by - have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq - have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq - have hscaled := sqrtLowerScaled_le_real_sqrt (q := q) hq - have hmax1 : - (max (sqrtLowerBase q) (sqrtLowerAlt q) : Real) ≤ Real.sqrt (q : Real) := by - simpa [ratToReal_max] using (max_le_iff).2 ⟨hbase, halt⟩ - have hmax2 : - (max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) : Real) ≤ - Real.sqrt (q : Real) := by - simpa [ratToReal_max] using (max_le_iff).2 ⟨hmax1, hscaled⟩ - simpa [sqrtLower] using hmax2 - -/-- Square-root upper bound in reals (tighter of three bounds). -/ -theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : - Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by - have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq - have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq - have hscaled := real_sqrt_le_sqrtUpperScaled (q := q) hq - have hmin1 : - Real.sqrt (q : Real) ≤ min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real) := by - exact (le_min_iff).2 ⟨hbase, halt⟩ - have hmin2 : - Real.sqrt (q : Real) ≤ - min (min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real)) (sqrtUpperScaled q : Real) := by - exact (le_min_iff).2 ⟨hmin1, hscaled⟩ - simpa [sqrtUpper, ratToReal_min] using hmin2 - /-- Bounds for multiplying a scalar by a bounded value. 
-/ def scaleInterval (x lo hi : Rat) : Rat × Rat := if 0 ≤ x then @@ -686,11 +41,13 @@ theorem scaleInterval_bounds {x lo hi y : Rat} let bounds := scaleInterval x lo hi bounds.1 ≤ x * y ∧ x * y ≤ bounds.2 := by by_cases hx : 0 ≤ x - · simpa [scaleInterval, hx] using - And.intro (mul_le_mul_of_nonneg_left hlo hx) (mul_le_mul_of_nonneg_left hhi hx) + · have hbounds : x * lo ≤ x * y ∧ x * y ≤ x * hi := by + exact ⟨mul_le_mul_of_nonneg_left hlo hx, mul_le_mul_of_nonneg_left hhi hx⟩ + simpa [scaleInterval, hx] using hbounds · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - simpa [scaleInterval, hx] using - And.intro (mul_le_mul_of_nonpos_left hhi hx') (mul_le_mul_of_nonpos_left hlo hx') + have hbounds : x * hi ≤ x * y ∧ x * y ≤ x * lo := by + exact ⟨mul_le_mul_of_nonpos_left hhi hx', mul_le_mul_of_nonpos_left hlo hx'⟩ + simpa [scaleInterval, hx] using hbounds /-- `scaleInterval` bounds interpreted in the reals. -/ theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} @@ -699,12 +56,16 @@ theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by by_cases hx : 0 ≤ x · have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx - simpa [scaleInterval, hx] using - And.intro (mul_le_mul_of_nonneg_left hlo hx') (mul_le_mul_of_nonneg_left hhi hx') + have hbounds : (x : Real) * (lo : Real) ≤ (x : Real) * y ∧ + (x : Real) * y ≤ (x : Real) * (hi : Real) := by + exact ⟨mul_le_mul_of_nonneg_left hlo hx', mul_le_mul_of_nonneg_left hhi hx'⟩ + simpa [scaleInterval, hx] using hbounds · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' - simpa [scaleInterval, hx] using - And.intro (mul_le_mul_of_nonpos_left hhi hx'') (mul_le_mul_of_nonpos_left hlo hx'') + have hbounds : (x : Real) * (hi : Real) ≤ (x : Real) * y ∧ + (x : Real) * y ≤ (x : Real) * (lo : Real) := by + exact ⟨mul_le_mul_of_nonpos_left hhi hx'', mul_le_mul_of_nonpos_left hlo 
hx''⟩ + simpa [scaleInterval, hx] using hbounds /-- Real-valued LayerNorm output for a vector. -/ noncomputable def layerNormReal {n : Nat} @@ -861,12 +222,12 @@ theorem layerNormBounds_spec {n : Nat} (beta i : Real) + (coeff : Real) * (invStdLower : Real) ≤ (beta i : Real) + (coeff : Real) * invStd := by have hmul := mul_le_mul_of_nonneg_left hinv_lower hcoeff_real - exact add_le_add_right hmul (beta i : Real) + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) have hhigh_raw : (beta i : Real) + (coeff : Real) * invStd ≤ (beta i : Real) + (coeff : Real) * (invStdUpper : Real) := by have hmul := mul_le_mul_of_nonneg_left hinv_upper hcoeff_real - exact add_le_add_right hmul (beta i : Real) + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] @@ -883,12 +244,12 @@ theorem layerNormBounds_spec {n : Nat} (beta i : Real) + (coeff : Real) * (invStdUpper : Real) ≤ (beta i : Real) + (coeff : Real) * invStd := by have hmul := mul_le_mul_of_nonpos_left hinv_upper hcoeff_real - exact add_le_add_right hmul (beta i : Real) + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) have hhigh_raw : (beta i : Real) + (coeff : Real) * invStd ≤ (beta i : Real) + (coeff : Real) * (invStdLower : Real) := by have hmul := mul_le_mul_of_nonpos_left hinv_lower hcoeff_real - exact add_le_add_right hmul (beta i : Real) + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] @@ -899,6 +260,61 @@ theorem layerNormBounds_spec {n : Nat} using hhigh_raw exact And.intro hlo hhi +/-! 
+Local bounds for monotone multiplication in real-valued bounds. +-/ + +/-- Lower sqrt bound against the variance-plus-eps term. -/ +theorem sqrtLower_le_real_sqrt_varEps {n : Nat} (eps : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + intro varEps + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + +/-- Inverse-std upper bound from the lower sqrt bound. -/ +theorem invStd_le_invStdBound {n : Nat} (eps : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + invStd ≤ (invStdBound : Real) := by + intro varEps invStd invStdBound + have hsqrt_lower : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + simpa [varEps] using + (sqrtLower_le_real_sqrt_varEps (eps := eps) (x := x) hne heps) + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + 
exact le_trans hinv_sqrt hinv_bound + +/-- Inverse-std is nonnegative. -/ +theorem invStd_nonneg {n : Nat} (eps : Rat) (x : Fin n → Rat) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + 0 ≤ invStd := by + intro varEps invStd + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + /-- Interval bounds for LayerNorm outputs from per-coordinate intervals. -/ def layerNormIntervalBounds {n : Nat} (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) : @@ -971,35 +387,11 @@ theorem layerNormIntervalBounds_spec {n : Nat} have hbound := abs_le_max_of_bounds hlo' hhi' simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, ratToReal_max] using hbound - have hsqrt_lower : - (sqrtLower eps : Real) ≤ Real.sqrt varEps := by - have hsqrt_eps : - (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by - have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) - simpa using h - have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne - have hle : (eps : Real) ≤ varEps := by - have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := - le_add_of_nonneg_left hvar_nonneg - simpa [varEps] using hle' - have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by - exact Real.sqrt_le_sqrt hle - exact le_trans hsqrt_eps hsqrt_eps' - have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt - have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by - have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd] using h - have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by - have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy - simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by - exact le_trans hinv_sqrt hinv_bound + simpa [varEps, invStd, 
invStdBound] using + (invStd_le_invStdBound (eps := eps) (x := x) hne heps hsqrt) have hinv_nonneg : 0 ≤ invStd := by - have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by - exact Real.sqrt_nonneg _ - exact inv_nonneg.2 hsqrt_nonneg + simp [varEps, invStd] have hmul1 : |(x i : Real) - μ| * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by have hleft : @@ -1015,7 +407,7 @@ theorem layerNormIntervalBounds_spec {n : Nat} have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa [mul_assoc] using hmul2' + simpa only [mul_assoc] using hmul2' let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd have ht_abs : |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by @@ -1030,12 +422,16 @@ theorem layerNormIntervalBounds_spec {n : Nat} exact abs_le.mp ht_abs' have hlow : (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h := add_le_add_left hbounds.1 (beta i : Real) - simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h have hhigh : t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - have h := add_le_add_left hbounds.2 (beta i : Real) - simpa [add_comm, add_left_comm, add_assoc] using h + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h have hreal : layerNormReal eps gamma beta x i = t + (beta i : Real) := by simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] @@ -1056,6 +452,14 @@ def layerNormAbsBounds {n : 
Nat} let radius : Fin n → Rat := fun i => |gamma i| * centeredBound * invStdBound (fun i => beta i - radius i, fun i => beta i + radius i) +/-- Bound a centered value by double the absolute bound. -/ +private theorem abs_sub_le_double_bound {a b bound : Real} + (ha : |a| ≤ bound) (hb : |b| ≤ bound) : + |a - b| ≤ bound + bound := by + have h1 : |a - b| ≤ |a| + |b| := by + simpa [sub_eq_add_neg, abs_neg] using abs_add_le a (-b) + exact le_trans h1 (add_le_add ha hb) + /-- `layerNormAbsBounds` soundness for real LayerNorm outputs under absolute input bounds. -/ theorem layerNormAbsBounds_spec {n : Nat} (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Rat) @@ -1085,51 +489,23 @@ theorem layerNormAbsBounds_spec {n : Nat} let μ : Real := meanRat x let invStd : Real := (Real.sqrt varEps)⁻¹ have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound : Real) := by - have h1 : |(x i : Real) - μ| ≤ |(x i : Real)| + |μ| := by - simpa [sub_eq_add_neg, abs_neg] using abs_add_le (x i : Real) (-μ) have hx : |(x i : Real)| ≤ (absBound : Real) := by exact ratToReal_abs_le_of_le (habs i) have hmu : |μ| ≤ (absBound : Real) := by simpa [μ] using hmean_abs_real - have h2 : |(x i : Real)| + |μ| ≤ (absBound : Real) + (absBound : Real) := - add_le_add hx hmu have h12 : |(x i : Real) - μ| ≤ (absBound : Real) + (absBound : Real) := - le_trans h1 h2 + abs_sub_le_double_bound hx hmu simpa [centeredBound, two_mul] using h12 have hbound_nonneg_real : 0 ≤ (absBound : Real) := by exact ratToReal_nonneg_of_nonneg hbound_nonneg have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real simpa [centeredBound, two_mul] using hsum - have hsqrt_lower : - (sqrtLower eps : Real) ≤ Real.sqrt varEps := by - have hsqrt_eps : - (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by - have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) - simpa using h - have hvar_nonneg : 0 ≤ (varianceRat x : Real) := 
varianceRat_nonneg_real x hne - have hle : (eps : Real) ≤ varEps := by - have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := - le_add_of_nonneg_left hvar_nonneg - simpa [varEps] using hle' - have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by - exact Real.sqrt_le_sqrt hle - exact le_trans hsqrt_eps hsqrt_eps' - have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt - have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by - have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd] using h - have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by - have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy - simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by - exact le_trans hinv_sqrt hinv_bound + simpa [varEps, invStd, invStdBound] using + (invStd_le_invStdBound (eps := eps) (x := x) hne heps hsqrt) have hinv_nonneg : 0 ≤ invStd := by - have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by - exact Real.sqrt_nonneg _ - exact inv_nonneg.2 hsqrt_nonneg + simp [varEps, invStd] have hmul1 : |(x i : Real) - μ| * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by have hleft : @@ -1145,7 +521,7 @@ theorem layerNormAbsBounds_spec {n : Nat} have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa [mul_assoc] using hmul2' + simpa only [mul_assoc] using hmul2' let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd have ht_abs : |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by @@ -1160,12 +536,16 @@ theorem layerNormAbsBounds_spec {n : Nat} exact abs_le.mp ht_abs' have hlow : (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h := add_le_add_left hbounds.1 (beta i : Real) 
- simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h have hhigh : t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - have h := add_le_add_left hbounds.2 (beta i : Real) - simpa [add_comm, add_left_comm, add_assoc] using h + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h have hreal : layerNormReal eps gamma beta x i = t + (beta i : Real) := by simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] @@ -1199,15 +579,11 @@ theorem layerNormAbsBounds_spec_real {n : Nat} let μ : Real := meanReal x let invStd : Real := (Real.sqrt varEps)⁻¹ have hcentered_abs : |x i - μ| ≤ (centeredBound : Real) := by - have h1 : |x i - μ| ≤ |x i| + |μ| := by - simpa [sub_eq_add_neg, abs_neg] using abs_add_le (x i) (-μ) have hx : |x i| ≤ (absBound : Real) := habs i have hmu : |μ| ≤ (absBound : Real) := by simpa using hmean_abs - have h2 : |x i| + |μ| ≤ (absBound : Real) + (absBound : Real) := - add_le_add hx hmu have h12 : |x i - μ| ≤ (absBound : Real) + (absBound : Real) := - le_trans h1 h2 + abs_sub_le_double_bound hx hmu simpa [centeredBound, two_mul] using h12 have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real @@ -1256,7 +632,7 @@ theorem layerNormAbsBounds_spec_real {n : Nat} have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa [mul_assoc] using hmul2' + simpa only [mul_assoc] using hmul2' let t : Real := (gamma i : Real) 
* (x i - μ) * invStd have ht_abs : |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by @@ -1271,12 +647,16 @@ theorem layerNormAbsBounds_spec_real {n : Nat} exact abs_le.mp ht_abs' have hlow : (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h := add_le_add_left hbounds.1 (beta i : Real) - simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h have hhigh : t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - have h := add_le_add_left hbounds.2 (beta i : Real) - simpa [add_comm, add_left_comm, add_assoc] using h + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h have hreal : layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] @@ -1390,7 +770,7 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa [mul_assoc] using hmul2' + simpa only [mul_assoc] using hmul2' let t : Real := (gamma i : Real) * (x i - μ) * invStd have ht_abs : |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by @@ -1405,12 +785,16 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} exact abs_le.mp ht_abs' have hlow : (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h := add_le_add_left hbounds.1 (beta i : Real) - simpa [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + 
have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h have hhigh : t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - have h := add_le_add_left hbounds.2 (beta i : Real) - simpa [add_comm, add_left_comm, add_assoc] using h + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h have hreal : layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] diff --git a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean index 3bf1967..120d3c8 100644 --- a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean +++ b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.LayerNorm.MeanVariance +import Nfp.Sound.Bounds.LayerNorm.SqrtBounds /-! Inverse-standard-deviation bounds for LayerNorm. 
@@ -53,6 +54,9 @@ theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) simpa [varRat, variance_def x hne] using h have hvarEps_nonneg : 0 ≤ varEpsRat := by exact add_nonneg hvarRat_nonneg (le_of_lt heps) + have hsqrt_var : (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := by + have h := sqrtLower_le_real_sqrt (q := varEpsRat) hvarEps_nonneg + simpa [hvarEps] using h have hsqrt_lower : (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by @@ -64,14 +68,6 @@ theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) le_add_of_nonneg_left hvar_nonneg simpa [varEps] using hle' exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) - have hsqrt_var : (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := by - have hsqrt_var' : - (sqrtLower varEpsRat : Real) ≤ Real.sqrt (varEpsRat : Real) := by - have h := sqrtLower_le_real_sqrt (q := varEpsRat) hvarEps_nonneg - simpa using h - have hle : (varEpsRat : Real) ≤ varEps := by - simp [hvarEps] - exact le_trans hsqrt_var' (Real.sqrt_le_sqrt hle) have hmax : max (sqrtLower eps : Real) (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ @@ -111,11 +107,12 @@ theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) simpa [invStdLower, ratDivDown, hupper_ne, one_div] using hinv_lower_real have hinv_upper : invStd ≤ (invStdUpper : Real) := by simpa [invStdUpper, ratDivUp, hlower_ne, one_div] using hinv_upper_real + have hbounds : bounds = (invStdLower, invStdUpper) := by + simp [bounds, invStdBounds, hne, varRat, varEpsRat, sqrtLowerBound, sqrtUpperBound, + invStdLower, invStdUpper] constructor - · simpa [bounds, invStdBounds, hne, varRat, varEpsRat, sqrtLowerBound, sqrtUpperBound, - invStdLower, invStdUpper] using hinv_lower - · simpa [bounds, invStdBounds, hne, varRat, varEpsRat, sqrtLowerBound, sqrtUpperBound, - invStdLower, invStdUpper] using hinv_upper + · simpa [hbounds] using hinv_lower + · simpa [hbounds] 
using hinv_upper end Bounds diff --git a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean index 257a492..ffcd835 100644 --- a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean +++ b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean @@ -105,7 +105,8 @@ theorem abs_le_max_of_bounds {α : Type _} [Ring α] [LinearOrder α] [IsOrdered have hright : z ≤ max |a| |b| := by have hb : b ≤ |b| := by exact le_abs_self b - have hb' : b ≤ max |a| |b| := le_trans hb (le_max_right _ _) + have hb' : b ≤ max |a| |b| := by + exact le_trans hb (le_max_right _ _) exact le_trans hhi hb' exact (abs_le.mpr ⟨hleft, hright⟩) @@ -123,13 +124,24 @@ theorem meanReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : meanReal x = (∑ i, x i) / n := by simp [meanReal, h] +/-- `sumRat` agrees with the real sum after casting. -/ +theorem sumRat_cast {n : Nat} (x : Fin n → Rat) : + (sumRat x : Real) = ∑ i, (x i : Real) := by + classical + simp [sumRat, Rat.cast_sum] + +/-- `meanReal` agrees with `meanRat` when `n ≠ 0`. -/ +theorem meanReal_eq_meanRat_of_ne {n : Nat} (x : Fin n → Rat) (hne : n ≠ 0) : + meanReal (fun i => (x i : Real)) = (meanRat x : Real) := by + classical + simp [meanReal, meanRat, sumRat_cast, hne] + /-- `meanReal` agrees with `mean` after casting. -/ theorem meanReal_eq_meanRat {n : Nat} (x : Fin n → Rat) : meanReal (fun i => (x i : Real)) = (meanRat x : Real) := by by_cases h : n = 0 · simp [meanReal, meanRat, h] - · classical - simp [meanReal, meanRat, sumRat, h, Rat.cast_sum] + · simpa [h] using meanReal_eq_meanRat_of_ne (x := x) h /-- Mean is monotone under pointwise order (real inputs). 
-/ theorem meanReal_le_meanReal {n : Nat} (x y : Fin n → Real) (hne : n ≠ 0) @@ -190,7 +202,7 @@ theorem varianceReal_eq_varianceRat {n : Nat} (x : Fin n → Rat) : by_cases h : n = 0 · simp [varianceReal, varianceRat, h] · classical - simp [varianceReal, varianceRat, h, meanReal_eq_meanRat, Rat.cast_sum] + simp [varianceReal, varianceRat, h, meanReal_eq_meanRat_of_ne (x := x) h, Rat.cast_sum] /-- Variance is nonnegative when `n ≠ 0`, interpreted in reals. -/ theorem varianceRat_nonneg_real {n : Nat} (x : Fin n → Rat) (hne : n ≠ 0) : @@ -203,8 +215,7 @@ theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Rat) (hne : n ≠ 0) (hbound : ∀ i, |x i| ≤ (bound : Real)) : |meanReal x| ≤ (bound : Real) := by classical - have hsum_abs : - |∑ i : Fin n, x i| ≤ ∑ i : Fin n, |x i| := by + have hsum_abs : |∑ i : Fin n, x i| ≤ ∑ i : Fin n, |x i| := by simpa using (Finset.abs_sum_le_sum_abs (f := fun i : Fin n => x i) @@ -213,15 +224,13 @@ theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Rat) refine Finset.sum_le_sum ?_ intro i _ exact hbound i - have hsum_le : |∑ i : Fin n, x i| ≤ (n : Real) * (bound : Real) := by + have hsum_le : |∑ i : Fin n, x i| ≤ (bound : Real) * (n : Real) := by have hsum := le_trans hsum_abs hsum_bound simpa [Finset.sum_const, Finset.card_univ, mul_comm] using hsum have hpos : 0 < (n : Real) := by exact (Nat.cast_pos (α := Real)).2 (Nat.pos_of_ne_zero hne) - have hsum_le' : |∑ i : Fin n, x i| ≤ (bound : Real) * (n : Real) := by - simpa [mul_comm] using hsum_le have hdiv : |∑ i : Fin n, x i| / (n : Real) ≤ (bound : Real) := by - exact (div_le_iff₀ hpos).2 hsum_le' + exact (div_le_iff₀ hpos).2 hsum_le have habs_mean : |(∑ i : Fin n, x i) / (n : Real)| ≤ (bound : Real) := by simpa [abs_div, abs_of_nonneg (le_of_lt hpos)] using hdiv diff --git a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean new file mode 100644 index 0000000..d6b8c7c --- /dev/null +++ 
b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean @@ -0,0 +1,640 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.Order.Field.Basic +import Mathlib.Algebra.Order.Ring.Basic +import Mathlib.Data.Nat.Sqrt +import Mathlib.Data.Real.Sqrt +import Mathlib.Data.Rat.Cast.Order +import Nfp.Core.Basic + +/-! +Square-root bounds for LayerNorm intervals. + +This module isolates the rational sqrt lower/upper bounds and their basic +nonnegativity/positivity lemmas so the main LayerNorm bounds stay focused. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +/-! Square-root bounds. -/ + +lemma rat_nat_cast_nonneg (n : Nat) : (0 : Rat) ≤ (n : Rat) := by + simp + +lemma rat_nat_cast_pos {n : Nat} (h : 0 < n) : (0 : Rat) < (n : Rat) := by + exact (Nat.cast_pos (α := Rat)).2 h + +/-- Base rational lower bound for a square root. -/ +def sqrtLowerBase (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt num + let b := Nat.sqrt den + ratRoundDown ((a : Rat) / (b + 1)) + +/-- Base rational upper bound for a square root. -/ +def sqrtUpperBase (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt num + let b := Nat.sqrt den + ratRoundUp ((a + 1 : Rat) / b) + +/-- Alternate rational lower bound for a square root. -/ +def sqrtLowerAlt (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den) + ratRoundDown ((a : Rat) / den) + +/-- Alternate rational upper bound for a square root. -/ +def sqrtUpperAlt (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den) + ratRoundUp ((a + 1 : Rat) / den) + +/-- Extra precision scale for `sqrtLowerScaled`. -/ +def sqrtLowerScale : Nat := 1048576 + +/-- Scaled rational lower bound for a square root (extra precision). 
-/ +def sqrtLowerScaled (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let scale := sqrtLowerScale + let a := Nat.sqrt (num * den * scale * scale) + ratRoundDown ((a : Rat) / (den * scale)) + +/-- Scaled rational upper bound for a square root (extra precision). -/ +def sqrtUpperScaled (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let scale := sqrtLowerScale + let a := Nat.sqrt (num * den * scale * scale) + ratRoundUp ((a + 1 : Rat) / (den * scale)) + +/-- Rational lower bound for a square root (tighter of three bounds). -/ +def sqrtLower (q : Rat) : Rat := + max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) + +/-- Rational upper bound for a square root (tighter of three bounds). -/ +def sqrtUpper (q : Rat) : Rat := + min (min (sqrtUpperBase q) (sqrtUpperAlt q)) (sqrtUpperScaled q) + +/-- `sqrtLowerBase` is nonnegative. -/ +theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by + classical + unfold sqrtLowerBase + have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) := + rat_nat_cast_nonneg (Nat.sqrt q.num.natAbs) + have hden : 0 ≤ (Nat.sqrt q.den : Rat) + 1 := by + simpa using rat_nat_cast_nonneg (Nat.sqrt q.den + 1) + exact ratRoundDown_nonneg + (q := (Nat.sqrt q.num.natAbs : Rat) / (Nat.sqrt q.den + 1)) + (by exact div_nonneg hnum hden) + +/-- `sqrtUpperBase` is nonnegative. -/ +theorem sqrtUpperBase_nonneg (q : Rat) : 0 ≤ sqrtUpperBase q := by + classical + unfold sqrtUpperBase + have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) + 1 := by + simpa using rat_nat_cast_nonneg (Nat.sqrt q.num.natAbs + 1) + have hden : 0 ≤ (Nat.sqrt q.den : Rat) := + rat_nat_cast_nonneg (Nat.sqrt q.den) + exact ratRoundUp_nonneg + (q := (Nat.sqrt q.num.natAbs + 1 : Rat) / (Nat.sqrt q.den)) + (by exact div_nonneg hnum hden) + +/-- `sqrtUpperBase` is always positive. 
-/ +theorem sqrtUpperBase_pos (q : Rat) : 0 < sqrtUpperBase q := by + classical + unfold sqrtUpperBase + have hnum_pos : (0 : Rat) < (Nat.sqrt q.num.natAbs : Rat) + 1 := by + simpa using rat_nat_cast_pos (Nat.succ_pos (Nat.sqrt q.num.natAbs)) + have hden_pos : (0 : Rat) < (Nat.sqrt q.den : Rat) := by + have hden : 0 < q.den := q.den_pos + exact rat_nat_cast_pos (Nat.sqrt_pos.2 hden) + exact ratRoundUp_pos + (q := (Nat.sqrt q.num.natAbs + 1 : Rat) / (Nat.sqrt q.den)) + (by exact div_pos hnum_pos hden_pos) + +/-- `sqrtLowerAlt` is nonnegative. -/ +theorem sqrtLowerAlt_nonneg (q : Rat) : 0 ≤ sqrtLowerAlt q := by + classical + unfold sqrtLowerAlt + have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) := + rat_nat_cast_nonneg (Nat.sqrt (q.num.natAbs * q.den)) + have hden : 0 ≤ (q.den : Rat) := + rat_nat_cast_nonneg q.den + exact ratRoundDown_nonneg + (q := (Nat.sqrt (q.num.natAbs * q.den) : Rat) / q.den) + (by exact div_nonneg hnum hden) + +/-- `sqrtUpperAlt` is nonnegative. -/ +theorem sqrtUpperAlt_nonneg (q : Rat) : 0 ≤ sqrtUpperAlt q := by + classical + unfold sqrtUpperAlt + have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) + 1 := by + simpa using rat_nat_cast_nonneg (Nat.sqrt (q.num.natAbs * q.den) + 1) + have hden : 0 ≤ (q.den : Rat) := + rat_nat_cast_nonneg q.den + exact ratRoundUp_nonneg + (q := (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) / q.den) + (by exact div_nonneg hnum hden) + +/-- `sqrtUpperAlt` is always positive. -/ +theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by + classical + unfold sqrtUpperAlt + have hnum_pos : + (0 : Rat) < (Nat.sqrt (q.num.natAbs * q.den) : Rat) + 1 := by + simpa using rat_nat_cast_pos (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den))) + have hden_pos : (0 : Rat) < (q.den : Rat) := + rat_nat_cast_pos q.den_pos + exact ratRoundUp_pos + (q := (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) / q.den) + (by exact div_pos hnum_pos hden_pos) + +/-- `sqrtUpperScaled` is nonnegative. 
-/ +theorem sqrtUpperScaled_nonneg (q : Rat) : 0 ≤ sqrtUpperScaled q := by + classical + unfold sqrtUpperScaled + have hnum : + 0 ≤ (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) : Rat) + 1 := by + simpa using rat_nat_cast_nonneg + (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1) + have hden : 0 ≤ (q.den : Rat) * (sqrtLowerScale : Rat) := by + simpa [Nat.cast_mul] using rat_nat_cast_nonneg (q.den * sqrtLowerScale) + exact ratRoundUp_nonneg + (q := (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) + / (q.den * sqrtLowerScale)) + (by exact div_nonneg hnum hden) + +/-- `sqrtUpperScaled` is always positive. -/ +theorem sqrtUpperScaled_pos (q : Rat) : 0 < sqrtUpperScaled q := by + classical + unfold sqrtUpperScaled + have hnum_pos : + (0 : Rat) < + (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) : Rat) + 1 := by + simpa using rat_nat_cast_pos + (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale))) + have hden_pos : (0 : Rat) < (q.den : Rat) * (sqrtLowerScale : Rat) := by + have hden : 0 < q.den := q.den_pos + have hscale : 0 < sqrtLowerScale := by + simp [sqrtLowerScale] + simpa [Nat.cast_mul] using rat_nat_cast_pos (Nat.mul_pos hden hscale) + exact ratRoundUp_pos + (q := (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) + / (q.den * sqrtLowerScale)) + (by exact div_pos hnum_pos hden_pos) + +/-! Combined bounds. -/ + +/-- `sqrtLower` is nonnegative. -/ +theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by + have hbase : 0 ≤ sqrtLowerBase q := sqrtLowerBase_nonneg q + have hmax : 0 ≤ max (sqrtLowerBase q) (sqrtLowerAlt q) := + le_trans hbase (le_max_left _ _) + exact le_trans hmax (le_max_left _ _) + +/-- `sqrtUpper` is nonnegative. 
-/ +theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by + have hbase : 0 ≤ sqrtUpperBase q := sqrtUpperBase_nonneg q + have halt : 0 ≤ sqrtUpperAlt q := sqrtUpperAlt_nonneg q + have hscaled : 0 ≤ sqrtUpperScaled q := sqrtUpperScaled_nonneg q + have hmin1 : 0 ≤ min (sqrtUpperBase q) (sqrtUpperAlt q) := by + exact le_min hbase halt + exact le_min hmin1 hscaled + +/-- `sqrtUpper` is always positive. -/ +theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by + have hbase : 0 < sqrtUpperBase q := sqrtUpperBase_pos q + have halt : 0 < sqrtUpperAlt q := sqrtUpperAlt_pos q + have hscaled : 0 < sqrtUpperScaled q := sqrtUpperScaled_pos q + have hmin1 : 0 < min (sqrtUpperBase q) (sqrtUpperAlt q) := by + exact lt_min hbase halt + exact lt_min hmin1 hscaled + +/-! Real-valued bounds. -/ + +/-- Cast a nonnegative rational as `num.natAbs / den`. -/ +theorem rat_cast_eq_num_den {q : Rat} (hq : 0 ≤ q) : + (q : Real) = (q.num.natAbs : Real) / q.den := by + have hnum_nonneg : 0 ≤ q.num := (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (q.num.natAbs : Int) = q.num := by + exact (Int.natAbs_of_nonneg hnum_nonneg) + have hnum_cast : (q.num : Real) = (q.num.natAbs : Real) := by + exact (congrArg (fun z : Int => (z : Real)) hnum_eq).symm + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] + calc + (q : Real) = (q.num : Real) / q.den := hq_rat + _ = (q.num.natAbs : Real) / q.den := by + rw [hnum_cast] + +/-- Cast a nonnegative rational as `num.natAbs * den / den^2`. 
-/ +theorem rat_cast_eq_num_den_mul {q : Rat} (hq : 0 ≤ q) : + (q : Real) = (q.num.natAbs : Real) * q.den / (q.den : Real) ^ 2 := by + have hq_cast : (q : Real) = (q.num.natAbs : Real) / q.den := + rat_cast_eq_num_den (q := q) hq + have hden_ne : (q.den : Real) ≠ 0 := by + exact_mod_cast q.den_pos.ne' + have hq_eq : + (q.num.natAbs : Real) / q.den = + (q.num.natAbs : Real) * q.den / (q.den : Real) ^ 2 := by + field_simp [hden_ne] + exact hq_cast.trans hq_eq + +/-- Cast a nonnegative rational as `num.natAbs * den * scale^2 / (den * scale)^2`. -/ +theorem rat_cast_eq_num_den_scale {q : Rat} (hq : 0 ≤ q) {scale : Nat} (hscale : 0 < scale) : + (q : Real) = + (q.num.natAbs : Real) * q.den * (scale : Real) * (scale : Real) / + ((q.den : Real) * (scale : Real)) ^ 2 := by + have hq_cast : (q : Real) = (q.num.natAbs : Real) / q.den := + rat_cast_eq_num_den (q := q) hq + have hden_pos : 0 < (q.den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale + have hden_scale_ne : ((q.den : Real) * (scale : Real)) ≠ 0 := by + exact ne_of_gt (mul_pos hden_pos hscale_pos) + have hq_eq : + (q.num.natAbs : Real) / q.den = + (q.num.natAbs : Real) * q.den * (scale : Real) * (scale : Real) / + ((q.den : Real) * (scale : Real)) ^ 2 := by + field_simp [hden_scale_ne] + exact hq_cast.trans hq_eq + +/-- Square-root lower bound in reals. -/ +theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerBase q : Real) ≤ Real.sqrt (q : Real) := by + classical + -- Set up numerator/denominator witnesses. 
+ set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt num + set b : Nat := Nat.sqrt den + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hbpos : 0 < (b + 1 : Real) := by + exact_mod_cast (Nat.succ_pos b) + have hnum_le : (a ^ 2 : Real) ≤ num := by + exact_mod_cast (Nat.sqrt_le' num) + have hden_le : (den : Real) ≤ (b + 1) ^ 2 := by + exact_mod_cast (le_of_lt (Nat.lt_succ_sqrt' den)) + have hmul : (a ^ 2 : Real) * den ≤ (num : Real) * (b + 1) ^ 2 := by + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + have hnum_nonneg : 0 ≤ (num : Real) := by exact_mod_cast (Nat.zero_le num) + exact mul_le_mul hnum_le hden_le hden_nonneg hnum_nonneg + have hbpos2 : 0 < (b + 1 : Real) ^ 2 := by + simpa [pow_two] using (mul_pos hbpos hbpos) + have hdiv : (a ^ 2 : Real) / (b + 1) ^ 2 ≤ (num : Real) / den := by + exact (div_le_div_iff₀ hbpos2 hden_pos).2 hmul + have hpow : ((a : Real) / (b + 1 : Real)) ^ 2 = (a ^ 2 : Real) / (b + 1) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hq_cast : (q : Real) = (num : Real) / den := by + simpa [num, den] using (rat_cast_eq_num_den (q := q) hq) + have hsq : ((a : Real) / (b + 1 : Real)) ^ 2 ≤ (q : Real) := by + simpa [hpow, hq_cast, den, num] using hdiv + have hnonneg : 0 ≤ (a : Real) / (b + 1 : Real) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (b + 1 : Real) := by exact_mod_cast (Nat.zero_le (b + 1)) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + exact ratToReal_nonneg_of_nonneg hq + have hle : (a : Real) / (b + 1 : Real) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerBase q : Real) ≤ (a : Real) / (b + 1 : Real) := by + have hdown' : + ratToReal (ratRoundDown ((a : Rat) / (b + 1))) ≤ + (a : Real) / (b + 1 : Real) := by + simpa using ratRoundDown_le_real ((a : Rat) / (b + 1)) + simpa [sqrtLowerBase, num, den, a, b] 
using hdown' + exact le_trans hdown hle + +/-- Square-root upper bound in reals. -/ +theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperBase q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt num + set b : Nat := Nat.sqrt den + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hbpos : 0 < (b : Real) := by + have hb : 0 < b := by + have hden : 0 < den := q.den_pos + exact (Nat.sqrt_pos).2 hden + exact_mod_cast hb + have hnum_lt : (num : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' num) + have hden_le : (b ^ 2 : Real) ≤ den := by + exact_mod_cast (Nat.sqrt_le' den) + have hmul : (num : Real) * (b ^ 2) ≤ (a + 1) ^ 2 * den := by + have hb2_nonneg : 0 ≤ (b ^ 2 : Real) := by + exact sq_nonneg (b : Real) + have hsq_nonneg : 0 ≤ (a + 1 : Real) ^ 2 := by + exact sq_nonneg (a + 1 : Real) + exact mul_le_mul (le_of_lt hnum_lt) hden_le hb2_nonneg hsq_nonneg + have hbpos2 : 0 < (b : Real) ^ 2 := by + simpa [pow_two] using (mul_pos hbpos hbpos) + have hdiv : (num : Real) / den ≤ (a + 1) ^ 2 / (b : Real) ^ 2 := by + exact (div_le_div_iff₀ hden_pos hbpos2).2 hmul + have hpow : ((a + 1 : Real) / (b : Real)) ^ 2 = (a + 1) ^ 2 / (b : Real) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hq_cast : (q : Real) = (num : Real) / den := by + simpa [num, den] using (rat_cast_eq_num_den (q := q) hq) + have hsq : (q : Real) ≤ ((a + 1 : Real) / (b : Real)) ^ 2 := by + simpa [hpow, hq_cast, den, num] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / (b : Real)) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (b : Real) := by exact_mod_cast (Nat.zero_le b) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (b : Real) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / (b : Real) ≤ (sqrtUpperBase q : Real) := by + have 
hup' : + (a + 1 : Real) / (b : Real) ≤ + ratToReal (ratRoundUp ((a + 1 : Rat) / b)) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / b) + simpa [sqrtUpperBase, num, den, a, b] using hup' + exact le_trans hle hup + +/- + Local automation for monotone scaling steps in real-valued bounds. +-/ +/-- Alternate square-root lower bound in reals. -/ +theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerAlt q : Real) ≤ Real.sqrt (q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hnumden_le : (a ^ 2 : Real) ≤ (num * den : Nat) := by + exact_mod_cast (Nat.sqrt_le' (num * den)) + have hmul : (a ^ 2 : Real) ≤ (num : Real) * den := by + simpa [num, den, Nat.cast_mul] using hnumden_le + have hden_pos2 : 0 < (den : Real) ^ 2 := by + exact pow_pos hden_pos 2 + have hdiv : + (a ^ 2 : Real) / (den : Real) ^ 2 ≤ (num : Real) * den / (den : Real) ^ 2 := by + have hmul' : + (a ^ 2 : Real) * (den : Real) ^ 2 ≤ (num : Real) * den * (den : Real) ^ 2 := by + have hden_sq_nonneg : 0 ≤ (den : Real) ^ 2 := by + exact sq_nonneg (den : Real) + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' + have hden_ne : (den : Real) ≠ 0 := by + exact_mod_cast q.den_pos.ne' + have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by + simpa [num, den] using (rat_cast_eq_num_den_mul (q := q) hq) + have hsq : ((a : Real) / (den : Real)) ^ 2 ≤ (q : Real) := by + simpa [hq_cast, pow_two, div_mul_div_comm] using hdiv + have hnonneg : 0 ≤ (a : Real) / (den : Real) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + exact ratToReal_nonneg_of_nonneg hq + have hle : (a : Real) / (den : Real) 
≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerAlt q : Real) ≤ (a : Real) / (den : Real) := by + have hdown' : + ratToReal (ratRoundDown ((a : Rat) / den)) ≤ + (a : Real) / (den : Real) := by + simpa using ratRoundDown_le_real ((a : Rat) / den) + simpa [sqrtLowerAlt, num, den, a] using hdown' + exact le_trans hdown hle + +/-- Scaled square-root lower bound in reals. -/ +theorem sqrtLowerScaled_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerScaled q : Real) ≤ Real.sqrt (q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set scale : Nat := sqrtLowerScale + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos_nat : 0 < scale := by + simp [scale, sqrtLowerScale] + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale_pos_nat + have hnumden_le : (a ^ 2 : Real) ≤ (num * den * scale * scale : Nat) := by + exact_mod_cast (Nat.sqrt_le' (num * den * scale * scale)) + have hmul : + (a ^ 2 : Real) ≤ (num : Real) * den * (scale : Real) * (scale : Real) := by + simpa [num, den, scale, Nat.cast_mul, mul_assoc, mul_left_comm, mul_comm] using hnumden_le + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := + mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hmul' : + (a ^ 2 : Real) * ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) * + ((den : Real) * (scale : Real)) ^ 2 := by + have hnonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul_of_nonneg_right hmul hnonneg + have hdiv : + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := by + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hq_cast : + (q : 
Real) = + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 := by + simpa [num, den, scale] using + (rat_cast_eq_num_den_scale (q := q) hq (scale := scale) hscale_pos_nat) + have hpow : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 ≤ (q : Real) := by + calc + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 + = (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := hpow + _ ≤ ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := hdiv + _ = (q : Real) := by simp [hq_cast] + have hnonneg : 0 ≤ (a : Real) / ((den : Real) * (scale : Real)) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + exact ratToReal_nonneg_of_nonneg hq + have hle : + (a : Real) / ((den : Real) * (scale : Real)) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerScaled q : Real) ≤ (a : Real) / ((den : Real) * (scale : Real)) := by + have hdown' : + ratToReal (ratRoundDown ((a : Rat) / (den * scale))) ≤ + (a : Real) / ((den : Real) * (scale : Real)) := by + simpa using ratRoundDown_le_real ((a : Rat) / (den * scale)) + simpa [sqrtLowerScaled, num, den, scale, a] using hdown' + exact le_trans hdown hle + +/-- Alternate square-root upper bound in reals. 
-/ +theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperAlt q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hnumden_lt : (num * den : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' (num * den)) + have hmul : (num : Real) * den ≤ (a + 1 : Real) ^ 2 := by + exact le_of_lt hnumden_lt + have hden_pos2 : 0 < (den : Real) ^ 2 := by + exact pow_pos hden_pos 2 + have hdiv : + (num : Real) * den / (den : Real) ^ 2 ≤ (a + 1 : Real) ^ 2 / (den : Real) ^ 2 := by + have hmul' : + (num : Real) * den * (den : Real) ^ 2 ≤ (a + 1 : Real) ^ 2 * (den : Real) ^ 2 := by + have hden_sq_nonneg : 0 ≤ (den : Real) ^ 2 := by + exact sq_nonneg (den : Real) + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' + have hden_ne : (den : Real) ≠ 0 := by + exact_mod_cast q.den_pos.ne' + have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by + simpa [num, den] using (rat_cast_eq_num_den_mul (q := q) hq) + have hpow : + ((a + 1 : Real) / (den : Real)) ^ 2 = + (a + 1 : Real) ^ 2 / (den : Real) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : (q : Real) ≤ ((a + 1 : Real) / (den : Real)) ^ 2 := by + simpa [hq_cast, hpow] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / (den : Real)) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (den : Real) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / (den : Real) ≤ (sqrtUpperAlt q : Real) := by + have hup' : + (a + 1 : Real) / (den : Real) ≤ + ratToReal (ratRoundUp ((a + 1 : Rat) / den)) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / 
den) + simpa [sqrtUpperAlt, num, den, a] using hup' + exact le_trans hle hup + +/-- Scaled square-root upper bound in reals. -/ +theorem real_sqrt_le_sqrtUpperScaled {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperScaled q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set scale : Nat := sqrtLowerScale + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos_nat : 0 < scale := by + simp [scale, sqrtLowerScale] + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale_pos_nat + have hnumden_lt : (num * den * scale * scale : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' (num * den * scale * scale)) + have hmul : + (num : Real) * den * (scale : Real) * (scale : Real) ≤ (a + 1 : Real) ^ 2 := by + exact le_of_lt hnumden_lt + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := by + exact mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hdiv : + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by + have hmul' : + (num : Real) * den * (scale : Real) * (scale : Real) * + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 * ((den : Real) * (scale : Real)) ^ 2 := by + have hden_sq_nonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul hmul (le_refl _) hden_sq_nonneg (sq_nonneg _) + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hq_cast' : + (q : Real) = + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := by + simpa [num, den, scale] using + (rat_cast_eq_num_den_scale (q := q) hq (scale := scale) hscale_pos_nat) + have hpow : + ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a + 1 : Real) ^ 2 / 
((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : + (q : Real) ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 := by + simpa [hq_cast', hpow] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by + exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : + Real.sqrt (q : Real) ≤ (a + 1 : Real) / ((den : Real) * (scale : Real)) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ (sqrtUpperScaled q : Real) := by + have hup' : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ + ratToReal (ratRoundUp ((a + 1 : Rat) / (den * scale))) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / (den * scale)) + simpa [sqrtUpperScaled, num, den, scale, a] using hup' + exact le_trans hle hup + +/-- Square-root lower bound in reals (tighter of three bounds). -/ +theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by + have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq + have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq + have hscaled := sqrtLowerScaled_le_real_sqrt (q := q) hq + have hmax1 : + (max (sqrtLowerBase q) (sqrtLowerAlt q) : Real) ≤ Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hbase, halt⟩ + have hmax2 : + (max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) : Real) ≤ + Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hmax1, hscaled⟩ + simpa [sqrtLower] using hmax2 + +/-- Square-root upper bound in reals (tighter of three bounds). 
-/ +theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by + have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq + have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq + have hscaled := real_sqrt_le_sqrtUpperScaled (q := q) hq + have hmin1 : + Real.sqrt (q : Real) ≤ min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real) := by + exact (le_min_iff).2 ⟨hbase, halt⟩ + have hmin2 : + Real.sqrt (q : Real) ≤ + min (min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real)) (sqrtUpperScaled q : Real) := by + exact (le_min_iff).2 ⟨hmin1, hscaled⟩ + simpa [sqrtUpper, ratToReal_min] using hmin2 + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index d301a0f..3e29e40 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -55,8 +55,12 @@ theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : 0 ≤ rowSumWeighted W bound i := by - simpa [rowSumWeighted, Linear.sumFin_eq_sum_univ] using - (Finset.sum_nonneg (fun j _ => mul_nonneg (abs_nonneg (W i j)) (hbound j))) + classical + have hsum : 0 ≤ ∑ j, |W i j| * bound j := by + refine Finset.sum_nonneg ?_ + intro j _ + exact mul_nonneg (abs_nonneg (W i j)) (hbound j) + simpa [rowSumWeighted, Linear.sumFin_eq_sum_univ] using hsum /-- Each row-sum is bounded by the row-sum norm. -/ theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : @@ -106,37 +110,51 @@ def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) intro i exact mulVecIntervalLower_le_upper W lo hi hlohi i +/-- Summed absolute row entries factor out a scalar bound. 
-/ +theorem sum_abs_row_mul_eq_rowSum_mul {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (i : Fin m) (inputBound : Rat) : + (∑ j, |W i j| * inputBound) = rowSum W i * inputBound := by + have hsum : + (∑ j, |W i j|) * inputBound = ∑ j, |W i j| * inputBound := by + simpa using + (Finset.sum_mul + (s := (Finset.univ : Finset (Fin n))) + (f := fun j => |W i j|) + (a := inputBound)) + simpa [rowSum, Linear.sumFin_eq_sum_univ] using hsum.symm + +/-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ +theorem abs_mulVec_le_rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (x : Fin n → Rat) (inputBound : Rat) + (hx : ∀ j, |x j| ≤ inputBound) : + ∀ i, |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by + intro i + have h1 : |∑ j, W i j * x j| ≤ ∑ j, |W i j * x j| := by + simpa using + (Finset.abs_sum_le_sum_abs + (f := fun j => W i j * x j) + (s := (Finset.univ : Finset (Fin n)))) + have h2 : ∑ j, |W i j * x j| ≤ ∑ j, |W i j| * inputBound := by + refine Finset.sum_le_sum ?_ + intro j _ + have hnonneg : 0 ≤ |W i j| := abs_nonneg (W i j) + calc + |W i j * x j| = |W i j| * |x j| := by + simp [abs_mul] + _ ≤ |W i j| * inputBound := by + exact mul_le_mul_of_nonneg_left (hx j) hnonneg + have h3 : ∑ j, |W i j| * inputBound = rowSum W i * inputBound := + sum_abs_row_mul_eq_rowSum_mul W i inputBound + simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) + /-- Row-sum norm bounds a matrix-vector product under a uniform input bound. 
-/ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (x : Fin n → Rat) (inputBound : Rat) (hx : ∀ j, |x j| ≤ inputBound) (hinput : 0 ≤ inputBound) : ∀ i, |Matrix.mulVec W x i| ≤ rowSumNorm W * inputBound := by intro i - have hrow : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by - have h1 : |∑ j, W i j * x j| ≤ ∑ j, |W i j * x j| := by - simpa using - (Finset.abs_sum_le_sum_abs - (f := fun j => W i j * x j) - (s := (Finset.univ : Finset (Fin n)))) - have h2 : ∑ j, |W i j * x j| ≤ ∑ j, |W i j| * inputBound := by - refine Finset.sum_le_sum ?_ - intro j _ - have hnonneg : 0 ≤ |W i j| := abs_nonneg (W i j) - calc - |W i j * x j| = |W i j| * |x j| := by - simp [abs_mul] - _ ≤ |W i j| * inputBound := by - exact mul_le_mul_of_nonneg_left (hx j) hnonneg - have h3 : ∑ j, |W i j| * inputBound = rowSum W i * inputBound := by - have hsum : - (∑ j, |W i j|) * inputBound = ∑ j, |W i j| * inputBound := by - simpa using - (Finset.sum_mul - (s := (Finset.univ : Finset (Fin n))) - (f := fun j => |W i j|) - (a := inputBound)) - simpa [rowSum, Linear.sumFin_eq_sum_univ] using hsum.symm - simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) + have hrow : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := + abs_mulVec_le_rowSum W x inputBound hx i have hle : rowSum W i ≤ rowSumNorm W := rowSum_le_rowSumNorm W i have hmul : rowSum W i * inputBound ≤ rowSumNorm W * inputBound := mul_le_mul_of_nonneg_right hle hinput @@ -157,6 +175,7 @@ def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) · exact mul_nonneg (rowSumNorm_nonneg W) hinput · exact rowSumNorm_nonneg W + end Bounds end Sound diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index 5421204..dc875f6 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -83,24 +83,36 @@ lemma mul_between_of_bounds {a b x y : Rat} (hx : a ≤ x) (hx' : x ≤ b) : have hab : a ≤ b := 
le_trans hx hx' by_cases hy : 0 ≤ y · have hmin : min (a * y) (b * y) = a * y := by - have hle : a * y ≤ b * y := mul_le_mul_of_nonneg_right hab hy + have hle : a * y ≤ b * y := by + exact mul_le_mul_of_nonneg_right hab hy exact min_eq_left hle have hmax : max (a * y) (b * y) = b * y := by - have hle : a * y ≤ b * y := mul_le_mul_of_nonneg_right hab hy + have hle : a * y ≤ b * y := by + exact mul_le_mul_of_nonneg_right hab hy exact max_eq_right hle constructor - · simpa [hmin] using (mul_le_mul_of_nonneg_right hx hy) - · simpa [hmax] using (mul_le_mul_of_nonneg_right hx' hy) + · have hmul : a * y ≤ x * y := by + exact mul_le_mul_of_nonneg_right hx hy + simpa [hmin] using hmul + · have hmul : x * y ≤ b * y := by + exact mul_le_mul_of_nonneg_right hx' hy + simpa [hmax] using hmul · have hy' : y ≤ 0 := le_of_not_ge hy have hmin : min (a * y) (b * y) = b * y := by - have hle : b * y ≤ a * y := mul_le_mul_of_nonpos_right hab hy' + have hle : b * y ≤ a * y := by + exact mul_le_mul_of_nonpos_right hab hy' exact min_eq_right hle have hmax : max (a * y) (b * y) = a * y := by - have hle : b * y ≤ a * y := mul_le_mul_of_nonpos_right hab hy' + have hle : b * y ≤ a * y := by + exact mul_le_mul_of_nonpos_right hab hy' exact max_eq_left hle constructor - · simpa [hmin] using (mul_le_mul_of_nonpos_right hx' hy') - · simpa [hmax] using (mul_le_mul_of_nonpos_right hx hy') + · have hmul : b * y ≤ x * y := by + exact mul_le_mul_of_nonpos_right hx' hy' + simpa [hmin] using hmul + · have hmul : x * y ≤ a * y := by + exact mul_le_mul_of_nonpos_right hx hy' + simpa [hmax] using hmul /-- Lower interval endpoint bounds `x * y` when both factors are interval-bounded. 
-/ lemma mulIntervalLower_le_mul {a b c d x y : Rat} @@ -109,11 +121,11 @@ lemma mulIntervalLower_le_mul {a b c d x y : Rat} have hAy : min (a * c) (a * d) ≤ a * y := by have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := a) hy hy' - simpa [mul_comm] using h.1 + simpa only [mul_comm] using h.1 have hBy : min (b * c) (b * d) ≤ b * y := by have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := b) hy hy' - simpa [mul_comm] using h.1 + simpa only [mul_comm] using h.1 have hmin : min (min (a * c) (a * d)) (min (b * c) (b * d)) ≤ min (a * y) (b * y) := by apply le_min @@ -129,11 +141,11 @@ lemma mul_le_mulIntervalUpper {a b c d x y : Rat} have hAy : a * y ≤ max (a * c) (a * d) := by have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := a) hy hy' - simpa [mul_comm] using h.2 + simpa only [mul_comm] using h.2 have hBy : b * y ≤ max (b * c) (b * d) := by have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := b) hy hy' - simpa [mul_comm] using h.2 + simpa only [mul_comm] using h.2 have hmax : max (a * y) (b * y) ≤ max (max (a * c) (a * d)) (max (b * c) (b * d)) := by apply max_le @@ -190,22 +202,37 @@ def dotIntervalLowerUpper2SignSplitBoth {n : Nat} (dims1 dims2 : List (Fin n)) let bounds2 := dotIntervalLowerUpper2SignSplit dims2 lo2 hi2 lo1 hi1 (max bounds1.1 bounds2.1, min bounds1.2 bounds2.2) +/-- Sum of lower interval products bounds the dot-product sum (Rat). -/ +private theorem sum_mulIntervalLower_le_sum_mul {n : Nat} + (lo1 hi1 lo2 hi2 x y : Fin n → Rat) + (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) + (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : + ∑ j, mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) ≤ + ∑ j, x j * y j := by + refine Finset.sum_le_sum ?_ + intro j _ + exact mulIntervalLower_le_mul (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + +/-- Sum of products is bounded by the upper interval products (Rat). 
-/ +private theorem sum_mul_le_sum_mulIntervalUpper {n : Nat} + (lo1 hi1 lo2 hi2 x y : Fin n → Rat) + (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) + (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : + ∑ j, x j * y j ≤ + ∑ j, mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) := by + refine Finset.sum_le_sum ?_ + intro j _ + exact mul_le_mulIntervalUpper (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + theorem dotIntervalLower2_le_dotProduct {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n → Rat) (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : dotIntervalLower2 lo1 hi1 lo2 hi2 ≤ dotProduct x y := by classical - have hterm : - ∀ j, - mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) ≤ x j * y j := by - intro j - exact mulIntervalLower_le_mul (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) - have hsum : - ∑ j, mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) ≤ - ∑ j, x j * y j := by - refine Finset.sum_le_sum ?_ - intro j _ - exact hterm j + have hsum := + sum_mulIntervalLower_le_sum_mul + (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) (x := x) (y := y) + hlo1 hhi1 hlo2 hhi2 simpa [dotIntervalLower2, Linear.sumFin_eq_sum_univ, dotProduct] using hsum theorem dotProduct_le_dotIntervalUpper2 {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n → Rat) @@ -213,17 +240,10 @@ theorem dotProduct_le_dotIntervalUpper2 {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : dotProduct x y ≤ dotIntervalUpper2 lo1 hi1 lo2 hi2 := by classical - have hterm : - ∀ j, - x j * y j ≤ mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) := by - intro j - exact mul_le_mulIntervalUpper (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) - have hsum : - ∑ j, x j * y j ≤ - ∑ j, mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) := by - refine Finset.sum_le_sum ?_ - intro j _ - exact hterm j + have hsum := + sum_mul_le_sum_mulIntervalUpper + (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) (x := x) (y := y) + hlo1 hhi1 hlo2 hhi2 simpa 
[dotIntervalUpper2, Linear.sumFin_eq_sum_univ, dotProduct] using hsum /-- Lower interval endpoint using a shared-denominator accumulator. -/ def dotIntervalLowerCommonDen {n : Nat} (v lo hi : Fin n → Rat) : Rat := @@ -251,11 +271,11 @@ def dotIntervalUpperUnnorm {n : Nat} (v lo hi : Fin n → Rat) : Rat := theorem dotIntervalLowerCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalLowerCommonDen v lo hi = dotIntervalLower v lo hi := by - simp [dotIntervalLowerCommonDen, dotIntervalLower, Linear.sumFinCommonDen_eq_sumFin] + simp only [dotIntervalLowerCommonDen, dotIntervalLower, Linear.sumFinCommonDen_eq_sumFin] theorem dotIntervalUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalUpperCommonDen v lo hi = dotIntervalUpper v lo hi := by - simp [dotIntervalUpperCommonDen, dotIntervalUpper, Linear.sumFinCommonDen_eq_sumFin] + simp only [dotIntervalUpperCommonDen, dotIntervalUpper, Linear.sumFinCommonDen_eq_sumFin] private lemma foldl_pair_fst {α : Type _} (xs : List α) (f g : α → Rat) (a b : Rat) : (xs.foldl (fun acc x => (acc.1 + f x, acc.2 + g x)) (a, b)).1 = @@ -300,7 +320,7 @@ theorem dotIntervalLowerUpper2CommonDen_snd {n : Nat} (lo1 hi1 lo2 hi2 : Fin n theorem dotIntervalLowerUpper2CommonDen_eq {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2 = (dotIntervalLower2 lo1 hi1 lo2 hi2, dotIntervalUpper2 lo1 hi1 lo2 hi2) := by - ext <;> simp [dotIntervalLowerUpper2CommonDen_fst, dotIntervalLowerUpper2CommonDen_snd] + ext <;> simp only [dotIntervalLowerUpper2CommonDen_fst, dotIntervalLowerUpper2CommonDen_snd] theorem dotIntervalLowerUpperCommonDen_fst {n : Nat} (v lo hi : Fin n → Rat) : (dotIntervalLowerUpperCommonDen v lo hi).1 = dotIntervalLowerCommonDen v lo hi := by @@ -328,7 +348,7 @@ theorem dotIntervalLowerUpperCommonDen_snd {n : Nat} (v lo hi : Fin n → Rat) : theorem dotIntervalLowerUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalLowerUpperCommonDen v lo hi = 
(dotIntervalLowerCommonDen v lo hi, dotIntervalUpperCommonDen v lo hi) := by - ext <;> simp [dotIntervalLowerUpperCommonDen_fst, dotIntervalLowerUpperCommonDen_snd] + ext <;> simp only [dotIntervalLowerUpperCommonDen_fst, dotIntervalLowerUpperCommonDen_snd] theorem dotIntervalLowerUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalLowerUnnorm v lo hi = dotIntervalLower v lo hi := rfl @@ -375,13 +395,13 @@ def dotIntervalUpperCachedRat {n : Nat} (v lo hi : Fin n → Rat) : Rat := theorem dotIntervalLowerCachedRat_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalLowerCachedRat v lo hi = dotIntervalLower v lo hi := by classical - simp [dotIntervalLowerCachedRat, dotIntervalLower, Linear.sumFin_eq_list_foldl, + simp only [dotIntervalLowerCachedRat, dotIntervalLower, Linear.sumFin_eq_list_foldl, Array.getElem_ofFn] theorem dotIntervalUpperCachedRat_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalUpperCachedRat v lo hi = dotIntervalUpper v lo hi := by classical - simp [dotIntervalUpperCachedRat, dotIntervalUpper, Linear.sumFin_eq_list_foldl, + simp only [dotIntervalUpperCachedRat, dotIntervalUpper, Linear.sumFin_eq_list_foldl, Array.getElem_ofFn] /-! Absolute bounds. 
-/ @@ -408,9 +428,15 @@ theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Rat) refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j - · simpa [hv] using (mul_le_mul_of_nonneg_left (hlo j) hv) + · have hmul : v j * lo j ≤ v j * x j := by + exact mul_le_mul_of_nonneg_left (hlo j) hv + simpa [hv] using hmul · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) - simpa [hv] using (mul_le_mul_of_nonpos_left (hhi j) hv') + have hmul : v j * hi j ≤ v j * x j := by + have hmul' : hi j * v j ≤ x j * v j := by + exact mul_le_mul_of_nonpos_right (hhi j) hv' + simpa only [mul_comm] using hmul' + simpa [hv] using hmul theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : @@ -420,9 +446,15 @@ theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j - · simpa [hv] using (mul_le_mul_of_nonneg_left (hhi j) hv) + · have hmul : v j * x j ≤ v j * hi j := by + exact mul_le_mul_of_nonneg_left (hhi j) hv + simpa [hv] using hmul · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) - simpa [hv] using (mul_le_mul_of_nonpos_left (hlo j) hv') + have hmul : v j * x j ≤ v j * lo j := by + have hmul' : x j * v j ≤ lo j * v j := by + exact mul_le_mul_of_nonpos_right (hlo j) hv' + simpa only [mul_comm] using hmul' + simpa [hv] using hmul theorem abs_le_max_abs_abs_of_interval {a b x : Rat} (hlo : a ≤ x) (hhi : x ≤ b) : |x| ≤ max |a| |b| := by @@ -467,24 +499,36 @@ lemma mul_between_of_bounds_real {a b x y : Real} (hx : a ≤ x) (hx' : x ≤ b) have hab : a ≤ b := le_trans hx hx' by_cases hy : 0 ≤ y · have hmin : min (a * y) (b * y) = a * y := by - have hle : a * y ≤ b * y := mul_le_mul_of_nonneg_right hab hy + have hle : a * y ≤ b * y := by + exact mul_le_mul_of_nonneg_right hab hy exact min_eq_left hle have hmax : max (a * y) (b * y) = b * y := by - have hle : a * y ≤ b * y := mul_le_mul_of_nonneg_right hab hy + have hle : a * y ≤ b 
* y := by + exact mul_le_mul_of_nonneg_right hab hy exact max_eq_right hle constructor - · simpa [hmin] using (mul_le_mul_of_nonneg_right hx hy) - · simpa [hmax] using (mul_le_mul_of_nonneg_right hx' hy) + · have hmul : a * y ≤ x * y := by + exact mul_le_mul_of_nonneg_right hx hy + simpa [hmin] using hmul + · have hmul : x * y ≤ b * y := by + exact mul_le_mul_of_nonneg_right hx' hy + simpa [hmax] using hmul · have hy' : y ≤ 0 := le_of_not_ge hy have hmin : min (a * y) (b * y) = b * y := by - have hle : b * y ≤ a * y := mul_le_mul_of_nonpos_right hab hy' + have hle : b * y ≤ a * y := by + exact mul_le_mul_of_nonpos_right hab hy' exact min_eq_right hle have hmax : max (a * y) (b * y) = a * y := by - have hle : b * y ≤ a * y := mul_le_mul_of_nonpos_right hab hy' + have hle : b * y ≤ a * y := by + exact mul_le_mul_of_nonpos_right hab hy' exact max_eq_left hle constructor - · simpa [hmin] using (mul_le_mul_of_nonpos_right hx' hy') - · simpa [hmax] using (mul_le_mul_of_nonpos_right hx hy') + · have hmul : b * y ≤ x * y := by + exact mul_le_mul_of_nonpos_right hx' hy' + simpa [hmin] using hmul + · have hmul : x * y ≤ a * y := by + exact mul_le_mul_of_nonpos_right hx hy' + simpa [hmax] using hmul lemma mulIntervalLower_le_mul_real {a b c d : Rat} {x y : Real} (hx : (a : Real) ≤ x) (hx' : x ≤ (b : Real)) @@ -494,12 +538,12 @@ lemma mulIntervalLower_le_mul_real {a b c d : Rat} {x y : Real} min ((a : Real) * (c : Real)) ((a : Real) * (d : Real)) ≤ (a : Real) * y := by have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) (y := (a : Real)) hy hy' - simpa [mul_comm] using h.1 + simpa only [mul_comm] using h.1 have hBy : min ((b : Real) * (c : Real)) ((b : Real) * (d : Real)) ≤ (b : Real) * y := by have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) (y := (b : Real)) hy hy' - simpa [mul_comm] using h.1 + simpa only [mul_comm] using h.1 have hmin : min (min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) (min ((b : 
Real) * (c : Real)) ((b : Real) * (d : Real))) ≤ @@ -513,7 +557,7 @@ lemma mulIntervalLower_le_mul_real {a b c d : Rat} {x y : Real} (mulIntervalLower a b c d : Real) = min (min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) (min ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := by - simp [mulIntervalLower, Rat.cast_min, Rat.cast_mul] + simp only [mulIntervalLower, Rat.cast_min, Rat.cast_mul] calc (mulIntervalLower a b c d : Real) = min (min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) @@ -529,12 +573,12 @@ lemma mul_le_mulIntervalUpper_real {a b c d : Rat} {x y : Real} (a : Real) * y ≤ max ((a : Real) * (c : Real)) ((a : Real) * (d : Real)) := by have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) (y := (a : Real)) hy hy' - simpa [mul_comm] using h.2 + simpa only [mul_comm] using h.2 have hBy : (b : Real) * y ≤ max ((b : Real) * (c : Real)) ((b : Real) * (d : Real)) := by have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) (y := (b : Real)) hy hy' - simpa [mul_comm] using h.2 + simpa only [mul_comm] using h.2 have hmax : max ((a : Real) * y) ((b : Real) * y) ≤ max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) @@ -548,13 +592,35 @@ lemma mul_le_mulIntervalUpper_real {a b c d : Rat} {x y : Real} (mulIntervalUpper a b c d : Real) = max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) (max ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := by - simp [mulIntervalUpper, Rat.cast_max, Rat.cast_mul] + simp only [mulIntervalUpper, Rat.cast_max, Rat.cast_mul] calc x * y ≤ max ((a : Real) * y) ((b : Real) * y) := hxy _ ≤ max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) (max ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := hmax _ = (mulIntervalUpper a b c d : Real) := hcast.symm +/-- Sum of lower interval products bounds the dot-product sum (Real). 
-/ +private theorem sum_mulIntervalLower_le_sum_mul_real {n : Nat} + (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) + (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) + (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : + (∑ j, (mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real)) ≤ + ∑ j, x j * y j := by + refine Finset.sum_le_sum ?_ + intro j _ + exact mulIntervalLower_le_mul_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + +/-- Sum of products is bounded by the upper interval products (Real). -/ +private theorem sum_mul_le_sum_mulIntervalUpper_real {n : Nat} + (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) + (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) + (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : + ∑ j, x j * y j ≤ + ∑ j, (mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by + refine Finset.sum_le_sum ?_ + intro j _ + exact mul_le_mulIntervalUpper_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + theorem dotIntervalLower2_le_dotProduct_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) @@ -567,12 +633,10 @@ theorem dotIntervalLower2_le_dotProduct_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n simpa [dotIntervalLower2, ratToReal] using (Linear.ratToReal_sumFin (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j))) - have hsum : - (∑ j, (mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real)) ≤ - ∑ j, x j * y j := by - refine Finset.sum_le_sum ?_ - intro j _ - exact mulIntervalLower_le_mul_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + have hsum := + sum_mulIntervalLower_le_sum_mul_real + (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) (x := x) (y := y) + hlo1 hhi1 hlo2 hhi2 simpa [hcast, dotProduct] using hsum theorem dotProduct_le_dotIntervalUpper2_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) @@ -587,12 +651,10 @@ theorem dotProduct_le_dotIntervalUpper2_real {n 
: Nat} (lo1 hi1 lo2 hi2 : Fin n simpa [dotIntervalUpper2, ratToReal] using (Linear.ratToReal_sumFin (f := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j))) - have hsum : - ∑ j, x j * y j ≤ - ∑ j, (mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by - refine Finset.sum_le_sum ?_ - intro j _ - exact mul_le_mulIntervalUpper_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) + have hsum := + sum_mul_le_sum_mulIntervalUpper_real + (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) (x := x) (y := y) + hlo1 hhi1 hlo2 hhi2 simpa [hcast, dotProduct] using hsum theorem dotIntervalLowerUpper2SignSplit_spec_real {n : Nat} (dims : List (Fin n)) @@ -612,7 +674,7 @@ theorem dotIntervalLowerUpper2SignSplit_spec_real {n : Nat} (dims : List (Fin n) dotProduct_le_dotIntervalUpper2_real (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) (x := x) (y := y) hlo1 hhi1 hlo2 hhi2 - simpa [dotIntervalLowerUpper2SignSplit, dotIntervalLowerUpper2CommonDen_fst, + simpa only [dotIntervalLowerUpper2SignSplit, dotIntervalLowerUpper2CommonDen_fst, dotIntervalLowerUpper2CommonDen_snd] using And.intro hlow hhigh | cons i rest ih => by_cases hx : 0 ≤ x i @@ -634,13 +696,9 @@ theorem dotIntervalLowerUpper2SignSplit_spec_real {n : Nat} (dims : List (Fin n) have hpos := ih (lo1 := clamped.1) (hi1 := clamped.2) hlo1' hhi1' have hlow : (min boundsPos.1 boundsNeg.1 : Real) ≤ dotProduct x y := by - have hmin : (min boundsPos.1 boundsNeg.1 : Real) ≤ (boundsPos.1 : Real) := by - exact min_le_left _ _ - exact le_trans hmin hpos.1 + exact le_trans (min_le_left _ _) hpos.1 have hhigh : dotProduct x y ≤ (max boundsPos.2 boundsNeg.2 : Real) := by - have hmax : (boundsPos.2 : Real) ≤ (max boundsPos.2 boundsNeg.2 : Real) := by - exact le_max_left _ _ - exact le_trans hpos.2 hmax + exact le_trans hpos.2 (le_max_left _ _) simpa [dotIntervalLowerUpper2SignSplit, boundsPos, boundsNeg, clamped] using And.intro hlow hhigh · have hxneg : x i ≤ 0 := le_of_lt (lt_of_not_ge hx) @@ -662,13 +720,9 @@ theorem 
dotIntervalLowerUpper2SignSplit_spec_real {n : Nat} (dims : List (Fin n) have hneg := ih (lo1 := clamped.1) (hi1 := clamped.2) hlo1' hhi1' have hlow : (min boundsPos.1 boundsNeg.1 : Real) ≤ dotProduct x y := by - have hmin : (min boundsPos.1 boundsNeg.1 : Real) ≤ (boundsNeg.1 : Real) := by - exact min_le_right _ _ - exact le_trans hmin hneg.1 + exact le_trans (min_le_right _ _) hneg.1 have hhigh : dotProduct x y ≤ (max boundsPos.2 boundsNeg.2 : Real) := by - have hmax : (boundsNeg.2 : Real) ≤ (max boundsPos.2 boundsNeg.2 : Real) := by - exact le_max_right _ _ - exact le_trans hneg.2 hmax + exact le_trans hneg.2 (le_max_right _ _) simpa [dotIntervalLowerUpper2SignSplit, boundsPos, boundsNeg, clamped] using And.intro hlow hhigh @@ -690,7 +744,7 @@ theorem dotIntervalLowerUpper2SignSplitBoth_spec_real {n : Nat} (dims1 dims2 : L (dims := dims2) (lo1 := lo2) (hi1 := hi2) (lo2 := lo1) (hi2 := hi1) (x := y) (y := x) hlo2 hhi2 hlo1 hhi1 have h2 : (bounds2.1 : Real) ≤ dotProduct x y ∧ dotProduct x y ≤ (bounds2.2 : Real) := by - simpa [dotProduct_comm] using h2swap + simpa only [dotProduct_comm] using h2swap have hlow' : max (bounds1.1 : Real) (bounds2.1 : Real) ≤ dotProduct x y := (max_le_iff).2 ⟨h1.1, h2.1⟩ have hhigh' : dotProduct x y ≤ min (bounds1.2 : Real) (bounds2.2 : Real) := @@ -699,7 +753,7 @@ theorem dotIntervalLowerUpper2SignSplitBoth_spec_real {n : Nat} (dims1 dims2 : L simpa [ratToReal_max] using hlow' have hhigh : dotProduct x y ≤ ((min bounds1.2 bounds2.2 : Rat) : Real) := by simpa [ratToReal_min] using hhigh' - simpa [dotIntervalLowerUpper2SignSplitBoth, bounds1, bounds2] using And.intro hlow hhigh + simpa only [dotIntervalLowerUpper2SignSplitBoth, bounds1, bounds2] using And.intro hlow hhigh theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) (x : Fin n → Real) @@ -719,12 +773,26 @@ theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) intro j _ by_cases hv : 0 ≤ v j · have hv' : (0 : Real) ≤ (v j : Real) := 
ratToReal_nonneg_of_nonneg hv - simpa [hv] using (mul_le_mul_of_nonneg_left (hlo j) hv') + have hmul : (v j : Real) * (lo j : Real) ≤ (v j : Real) * x j := by + exact mul_le_mul_of_nonneg_left (hlo j) hv' + simpa [hv] using hmul · have hv' : (v j : Real) ≤ 0 := by exact (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) - simpa [hv] using (mul_le_mul_of_nonpos_left (hhi j) hv') + have hmul : (v j : Real) * (hi j : Real) ≤ (v j : Real) * x j := by + exact mul_le_mul_of_nonpos_left (hhi j) hv' + simpa [hv] using hmul simpa [hcast, dotProduct] using hsum +theorem dotIntervalLower_le_dotProduct_real_add {n : Nat} + (v lo hi : Fin n → Rat) (x : Fin n → Real) (b : Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + (dotIntervalLower v lo hi : Real) + b ≤ + dotProduct (fun j => (v j : Real)) x + b := by + have hlow := + dotIntervalLower_le_dotProduct_real (v := v) (lo := lo) (hi := hi) + (x := x) hlo hhi + simpa [add_comm] using add_le_add_left hlow b + theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) (x : Fin n → Real) (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : @@ -743,12 +811,26 @@ theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) intro j _ by_cases hv : 0 ≤ v j · have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv - simpa [hv] using (mul_le_mul_of_nonneg_left (hhi j) hv') + have hmul : (v j : Real) * x j ≤ (v j : Real) * (hi j : Real) := by + exact mul_le_mul_of_nonneg_left (hhi j) hv' + simpa [hv] using hmul · have hv' : (v j : Real) ≤ 0 := by exact (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) - simpa [hv] using (mul_le_mul_of_nonpos_left (hlo j) hv') + have hmul : (v j : Real) * x j ≤ (v j : Real) * (lo j : Real) := by + exact mul_le_mul_of_nonpos_left (hlo j) hv' + simpa [hv] using hmul simpa [hcast, dotProduct] using hsum +theorem dotProduct_le_dotIntervalUpper_real_add {n : Nat} + (v lo hi : Fin n → Rat) (x : Fin n → Real) (b : 
Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + dotProduct (fun j => (v j : Real)) x + b ≤ + (dotIntervalUpper v lo hi : Real) + b := by + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := v) (lo := lo) (hi := hi) + (x := x) hlo hhi + simpa [add_comm] using add_le_add_left hhigh b + theorem abs_le_max_abs_abs_of_interval_real {a b x : Real} (hlo : a ≤ x) (hhi : x ≤ b) : |x| ≤ max |a| |b| := by exact abs_le_max_abs_abs hlo hhi @@ -765,8 +847,12 @@ theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Rat) (x : Fin max |(lo i : Real)| |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by refine max_le_iff.mpr ?_ constructor - · exact ratToReal_abs_le_of_le (le_trans (le_max_left _ _) hsup) - · exact ratToReal_abs_le_of_le (le_trans (le_max_right _ _) hsup) + · have hleft : |lo i| ≤ intervalAbsBound lo hi := by + exact le_trans (le_max_left _ _) (max_abs_le_intervalAbsBound lo hi i) + exact ratToReal_abs_le_of_le hleft + · have hright : |hi i| ≤ intervalAbsBound lo hi := by + exact le_trans (le_max_right _ _) (max_abs_le_intervalAbsBound lo hi i) + exact ratToReal_abs_le_of_le hright exact le_trans hbound hsup_real theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Rat) diff --git a/Nfp/Sound/Bounds/Mlp.lean b/Nfp/Sound/Bounds/Mlp.lean index 26a8574..659e967 100644 --- a/Nfp/Sound/Bounds/Mlp.lean +++ b/Nfp/Sound/Bounds/Mlp.lean @@ -65,16 +65,16 @@ theorem mlpBounds_spec {dModel hidden : Nat} dotProduct (fun j => (wIn j h : Real)) x + (bIn h : Real) have hpre_lower : ∀ h, (preLo h : Real) ≤ pre h := by intro h - simpa [pre, preLo] using - add_le_add_right - (dotIntervalLower_le_dotProduct_real (v := fun j => wIn j h) lo hi x hlo hhi) - (bIn h : Real) + have hlow := + dotIntervalLower_le_dotProduct_real_add (v := fun j => wIn j h) + (lo := lo) (hi := hi) (x := x) (b := (bIn h : Real)) hlo hhi + simpa [pre, preLo] using hlow have hpre_upper : ∀ h, pre h ≤ (preHi h : Real) := by intro h - simpa 
[pre, preHi] using - add_le_add_right - (dotProduct_le_dotIntervalUpper_real (v := fun j => wIn j h) lo hi x hlo hhi) - (bIn h : Real) + have hhigh := + dotProduct_le_dotIntervalUpper_real_add (v := fun j => wIn j h) + (lo := lo) (hi := hi) (x := x) (b := (bIn h : Real)) hlo hhi + simpa [pre, preHi] using hhigh let geluBounds : Fin hidden → Rat × Rat := fun h => geluInterval (preLo h) (preHi h) let geluLo : Fin hidden → Rat := fun h => (geluBounds h).1 let geluHi : Fin hidden → Rat := fun h => (geluBounds h).2 @@ -90,18 +90,22 @@ theorem mlpBounds_spec {dModel hidden : Nat} dotIntervalUpper (fun h => wOut h i) geluLo geluHi + bOut i have hout_lower : (outLo i : Real) ≤ dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) := by - simpa [outLo] using - add_le_add_right - (dotIntervalLower_le_dotProduct_real (v := fun h => wOut h i) geluLo geluHi hidden - (fun h => (hgelu h).1) (fun h => (hgelu h).2)) - (bOut i : Real) + have hgelu_lo : ∀ h, (geluLo h : Real) ≤ hidden h := fun h => (hgelu h).1 + have hgelu_hi : ∀ h, hidden h ≤ (geluHi h : Real) := fun h => (hgelu h).2 + have hlow := + dotIntervalLower_le_dotProduct_real_add (v := fun h => wOut h i) + (lo := geluLo) (hi := geluHi) (x := hidden) (b := (bOut i : Real)) + hgelu_lo hgelu_hi + simpa [outLo] using hlow have hout_upper : dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) ≤ (outHi i : Real) := by - simpa [outHi] using - add_le_add_right - (dotProduct_le_dotIntervalUpper_real (v := fun h => wOut h i) geluLo geluHi hidden - (fun h => (hgelu h).1) (fun h => (hgelu h).2)) - (bOut i : Real) + have hgelu_lo : ∀ h, (geluLo h : Real) ≤ hidden h := fun h => (hgelu h).1 + have hgelu_hi : ∀ h, hidden h ≤ (geluHi h : Real) := fun h => (hgelu h).2 + have hhigh := + dotProduct_le_dotIntervalUpper_real_add (v := fun h => wOut h i) + (lo := geluLo) (hi := geluHi) (x := hidden) (b := (bOut i : Real)) + hgelu_lo hgelu_hi + simpa [outHi] using hhigh have hlo' : (outLo i : Real) ≤ mlpReal wIn bIn wOut 
bOut x i := by simpa [mlpReal, hidden, pre] using hout_lower have hhi' : mlpReal wIn bIn wOut bOut x i ≤ (outHi i : Real) := by @@ -199,8 +203,12 @@ theorem residualAddBounds_spec {n : Nat} (x : Fin n → Rat) ∀ i, (bounds.1 i : Real) ≤ (x i : Real) + y i ∧ (x i : Real) + y i ≤ (bounds.2 i : Real) := by intro bounds i - have hlow := add_le_add_left (hlo i) (x i : Real) - have hhigh := add_le_add_left (hhi i) (x i : Real) + have hlow : (x i : Real) + (lo i : Real) ≤ (x i : Real) + y i := by + simpa [add_comm] using + add_le_add_left (hlo i) (x i : Real) + have hhigh : (x i : Real) + y i ≤ (x i : Real) + (hi i : Real) := by + simpa [add_comm] using + add_le_add_left (hhi i) (x i : Real) constructor · simpa [bounds, residualAddBounds] using hlow · simpa [bounds, residualAddBounds] using hhigh @@ -267,8 +275,14 @@ theorem layerNormAbsMlpResidualBounds_spec {n hidden : Nat} layerNormAbsMlpBounds_spec eps gamma beta wIn bIn wOut bOut lo hi x hne heps hsqrt hlo hhi have hlo' := (hmlp i).1 have hhi' := (hmlp i).2 - have hlow := add_le_add (hlo i) hlo' - have hhigh := add_le_add (hhi i) hhi' + have hlow : + (lo i : Real) + (mlp.1 i : Real) ≤ + x i + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i := by + exact add_le_add (hlo i) hlo' + have hhigh : + x i + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ≤ + (hi i : Real) + (mlp.2 i : Real) := by + exact add_le_add (hhi i) hhi' constructor · simpa [bounds, layerNormAbsMlpResidualBounds, mlp] using hlow · simpa [bounds, layerNormAbsMlpResidualBounds, mlp] using hhigh diff --git a/Nfp/Sound/Bounds/Transformer.lean b/Nfp/Sound/Bounds/Transformer.lean index c6ae5a4..6544275 100644 --- a/Nfp/Sound/Bounds/Transformer.lean +++ b/Nfp/Sound/Bounds/Transformer.lean @@ -21,7 +21,6 @@ namespace Bounds open scoped BigOperators - /-- Real-valued output of a transformer layer. 
-/ noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} [NeZero seq] (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) @@ -69,10 +68,7 @@ def transformerLayerBoundsPos {seq dModel dHead numHeads hidden : Nat} [NeZero s (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := let positions := (Finset.univ : Finset (Fin seq)) let hpos : positions.Nonempty := by - classical - have h : Nonempty (Fin seq) := - ⟨⟨0, Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩⟩ - exact (Finset.univ_nonempty_iff.mpr h) + simp [positions] let loCached := cacheBound2 lo let hiCached := cacheBound2 hi let base := intervalBoundsOn positions hpos loCached hiCached @@ -109,10 +105,7 @@ theorem transformerLayerBoundsPos_spec {seq dModel dHead numHeads hidden : Nat} intro bounds q i let positions := (Finset.univ : Finset (Fin seq)) have hpos : positions.Nonempty := by - classical - have h : Nonempty (Fin seq) := - ⟨⟨0, Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩⟩ - exact (Finset.univ_nonempty_iff.mpr h) + simp [positions] let loCached := cacheBound2 lo let hiCached := cacheBound2 hi have hloCached : ∀ q i, (loCached q i : Real) ≤ x q i := by @@ -147,11 +140,21 @@ theorem transformerLayerBoundsPos_spec {seq dModel dHead numHeads hidden : Nat} layer.attnBias scores x q j have yLo : ∀ j, (loCached q j : Real) + (attn.1 j : Real) ≤ y j := by intro j - have hlow := add_le_add (hloCached q j) (hattn j).1 + have hlow : + (loCached q j : Real) + (attn.1 j : Real) ≤ + x q j + + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores x q j := by + exact add_le_add (hloCached q j) (hattn j).1 simpa [y] using hlow have yHi : ∀ j, y j ≤ (hiCached q j : Real) + (attn.2 j : Real) := by intro j - have hhigh := add_le_add (hhiCached q j) (hattn j).2 + have hhigh : + x q j + + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores x q j ≤ + (hiCached q j : Real) + (attn.2 j : Real) := by + exact add_le_add (hhiCached q j) 
(hattn j).2 simpa [y] using hhigh let yLoCached := cacheBound2 (fun q i => loCached q i + attnLo i) let yHiCached := cacheBound2 (fun q i => hiCached q i + attnHi i) diff --git a/Nfp/Sound/Bounds/Transformer/Embedding.lean b/Nfp/Sound/Bounds/Transformer/Embedding.lean index 8ce4e27..c63bff7 100644 --- a/Nfp/Sound/Bounds/Transformer/Embedding.lean +++ b/Nfp/Sound/Bounds/Transformer/Embedding.lean @@ -23,6 +23,15 @@ private lemma fin_univ_nonempty (seq : Nat) [NeZero seq] : refine ⟨⟨0, ?_⟩, by simp⟩ exact Nat.pos_of_ne_zero (NeZero.ne (n := seq)) +/-- `inf'`/`sup'` bounds for a selected position. -/ +private lemma inf_sup_bounds {seq : Nat} (positions : Finset (Fin seq)) + (hpos : positions.Nonempty) (f : Fin seq → Rat) + {q : Fin seq} (hq : q ∈ positions) : + positions.inf' hpos f ≤ f q ∧ f q ≤ positions.sup' hpos f := by + constructor + · exact Finset.inf'_le (s := positions) (f := f) (b := q) hq + · exact Finset.le_sup' (s := positions) (f := f) (b := q) hq + /-- Interval bounds across tokens for an embedding map. 
-/ def embeddingIntervalBounds {seq dModel : Nat} [NeZero seq] (x : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := @@ -39,19 +48,15 @@ theorem embeddingIntervalBounds_spec {seq dModel : Nat} [NeZero seq] (x q i : Real) ≤ (bounds.2 i : Real) := by classical intro bounds q i - have hloRat : bounds.1 i ≤ x q i := by - have h := - Finset.inf'_le (s := (Finset.univ : Finset (Fin seq))) - (f := fun k => x k i) (b := q) (by simp) - simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h - have hhiRat : x q i ≤ bounds.2 i := by - have h := - Finset.le_sup' (s := (Finset.univ : Finset (Fin seq))) - (f := fun k => x k i) (b := q) (by simp) + have hbounds : + bounds.1 i ≤ x q i ∧ x q i ≤ bounds.2 i := by + have h := inf_sup_bounds (positions := (Finset.univ : Finset (Fin seq))) + (hpos := fin_univ_nonempty (seq := seq)) (f := fun k => x k i) + (q := q) (hq := by simp) simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h constructor - · exact ratToReal_le_of_le hloRat - · exact ratToReal_le_of_le hhiRat + · exact ratToReal_le_of_le hbounds.1 + · exact ratToReal_le_of_le hbounds.2 /-- Interval bounds across a finite set of positions for an embedding map. 
-/ def embeddingIntervalBoundsOn {seq dModel : Nat} [NeZero seq] @@ -70,19 +75,13 @@ theorem embeddingIntervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] (x q i : Real) ≤ (bounds.2 i : Real) := by classical intro bounds q hq i - have hloRat : bounds.1 i ≤ x q i := by - have h := - Finset.inf'_le (s := positions) - (f := fun k => x k i) (b := q) hq - simpa [bounds, embeddingIntervalBoundsOn] using h - have hhiRat : x q i ≤ bounds.2 i := by - have h := - Finset.le_sup' (s := positions) - (f := fun k => x k i) (b := q) hq + have hbounds : bounds.1 i ≤ x q i ∧ x q i ≤ bounds.2 i := by + have h := inf_sup_bounds (positions := positions) (hpos := hpos) + (f := fun k => x k i) (q := q) (hq := hq) simpa [bounds, embeddingIntervalBoundsOn] using h constructor - · exact ratToReal_le_of_le hloRat - · exact ratToReal_le_of_le hhiRat + · exact ratToReal_le_of_le hbounds.1 + · exact ratToReal_le_of_le hbounds.2 /-- Collapse per-position interval bounds over a finite set of positions. -/ def intervalBoundsOn {seq dModel : Nat} [NeZero seq] @@ -104,15 +103,13 @@ theorem intervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] classical intro bounds q hq i have hmin : bounds.1 i ≤ lo q i := by - have h := - Finset.inf'_le (s := positions) - (f := fun k => lo k i) (b := q) hq - simpa [bounds, intervalBoundsOn] using h + have h := inf_sup_bounds (positions := positions) (hpos := hpos) + (f := fun k => lo k i) (q := q) (hq := hq) + simpa [bounds, intervalBoundsOn] using h.1 have hmax : hi q i ≤ bounds.2 i := by - have h := - Finset.le_sup' (s := positions) - (f := fun k => hi k i) (b := q) hq - simpa [bounds, intervalBoundsOn] using h + have h := inf_sup_bounds (positions := positions) (hpos := hpos) + (f := fun k => hi k i) (q := q) (hq := hq) + simpa [bounds, intervalBoundsOn] using h.2 have hlo' := hlo q hq i have hhi' := hhi q hq i constructor diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 9f03f88..54e9128 100644 --- a/Nfp/Sound/Induction/Core.lean 
+++ b/Nfp/Sound/Induction/Core.lean @@ -91,6 +91,658 @@ def buildValueRangeCert? [NeZero seq] exact some ⟨cert, h⟩ else exact none +/-- Cached bounds and derived quantities for induction-head core certificates. -/ +structure InductionHeadCoreCache (seq dModel dHead : Nat) where + /-- Cached LayerNorm bound pair. -/ + lnBounds : (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) + /-- LayerNorm lower bounds. -/ + lnLo : Fin seq → Fin dModel → Rat + /-- LayerNorm upper bounds. -/ + lnHi : Fin seq → Fin dModel → Rat + /-- Tasks for LayerNorm absolute maxima. -/ + lnAbsMaxTask : Fin seq → Rat + /-- Cached LayerNorm absolute maxima. -/ + lnAbsMaxArr : Array Rat + /-- LayerNorm absolute-max lookup. -/ + lnAbsMax : Fin seq → Rat + /-- Tasks for inverse-std bounds. -/ + invStdBoundsTasks : Array (Task (Rat × Rat)) + /-- Cached inverse-std bounds. -/ + invStdBoundsArr : Array (Rat × Rat) + /-- Inverse-std lower bounds. -/ + invStdLo : Fin seq → Rat + /-- Inverse-std upper bounds. -/ + invStdHi : Fin seq → Rat + /-- Cached query base terms. -/ + qBaseArr : Array Rat + /-- Query base lookup. -/ + qBase : Fin dHead → Rat + /-- Cached key base terms. -/ + kBaseArr : Array Rat + /-- Key base lookup. -/ + kBase : Fin dHead → Rat + /-- Tasks for query coefficient rows. -/ + qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) + /-- Cached query coefficient rows. -/ + qCoeffArr : Array { row : Array Rat // row.size = dHead } + /-- Query coefficient lookup. -/ + qCoeff : Fin seq → Fin dHead → Rat + /-- Tasks for key coefficient rows. -/ + kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) + /-- Cached key coefficient rows. -/ + kCoeffArr : Array { row : Array Rat // row.size = dHead } + /-- Key coefficient lookup. -/ + kCoeff : Fin seq → Fin dHead → Rat + /-- Query lower bounds. -/ + qLo : Fin seq → Fin dHead → Rat + /-- Query upper bounds. -/ + qHi : Fin seq → Fin dHead → Rat + /-- Key lower bounds. 
-/ + kLo : Fin seq → Fin dHead → Rat + /-- Key upper bounds. -/ + kHi : Fin seq → Fin dHead → Rat + /-- Query absolute bounds. -/ + qAbs : Fin seq → Fin dHead → Rat + /-- Key absolute bounds. -/ + kAbs : Fin seq → Fin dHead → Rat + /-- Cached max query abs bounds. -/ + qAbsMaxArr : Array Rat + /-- Max query abs bound lookup. -/ + qAbsMax : Fin dHead → Rat + /-- Cached max key abs bounds. -/ + kAbsMaxArr : Array Rat + /-- Max key abs bound lookup. -/ + kAbsMax : Fin dHead → Rat + /-- Causal mask predicate. -/ + masked : Fin seq → Fin seq → Prop + /-- Split budget for query dims. -/ + splitBudgetQ : Nat + /-- Split budget for key dims. -/ + splitBudgetK : Nat + /-- Split budget for base diff dims. -/ + splitBudgetDiffBase : Nat + /-- Split budget for refined diff dims. -/ + splitBudgetDiffRefined : Nat + /-- Split dims for query bounds. -/ + splitDimsQ : Fin seq → List (Fin dHead) + /-- Split dims for key bounds. -/ + splitDimsK : Fin seq → Fin seq → List (Fin dHead) + /-- Split dims for diff bounds with budget. -/ + splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) + /-- Split dims for base diff bounds. -/ + splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) + /-- Split dims for refined diff bounds. -/ + splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) + /-- Tasks for dot-product interval rows. -/ + dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) + /-- Tasks for base diff dot rows. -/ + dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) + /-- Dot-product lower bounds. -/ + dotLo : Fin seq → Fin seq → Rat + /-- Dot-product upper bounds. -/ + dotHi : Fin seq → Fin seq → Rat + /-- Base diff dot-product lower bounds. -/ + dotDiffLoBase : Fin seq → Fin seq → Rat + /-- Base diff dot-product upper bounds. -/ + dotDiffHiBase : Fin seq → Fin seq → Rat + /-- Dot-product absolute bounds. -/ + dotAbs : Fin seq → Fin seq → Rat + /-- Base score absolute bounds. 
-/ + scoreBaseAbs : Fin seq → Fin seq → Rat + /-- Score lower bounds. -/ + scoreLo : Fin seq → Fin seq → Rat + /-- Score upper bounds. -/ + scoreHi : Fin seq → Fin seq → Rat + /-- Score lower bounds at prev key. -/ + scoreLoPrev : Fin seq → Rat + /-- Base score-gap lower bounds. -/ + scoreGapLoBase : Fin seq → Fin seq → Rat + /-- Other-key set for each query. -/ + otherKeys : Fin seq → Finset (Fin seq) + /-- Worst key candidate per query. -/ + worstKey : Fin seq → Option (Fin seq) + /-- Refined diff dot-product lower bounds. -/ + dotDiffLo : Fin seq → Fin seq → Rat + /-- Refined diff dot-product upper bounds. -/ + dotDiffHi : Fin seq → Fin seq → Rat + /-- Score-gap lower bounds. -/ + scoreGapLo : Fin seq → Fin seq → Rat + /-- Margin per query. -/ + marginAt : Fin seq → Rat + /-- Epsilon per query. -/ + epsAt : Fin seq → Rat + /-- Global margin. -/ + margin : Rat + /-- Global epsilon. -/ + eps : Rat + /-- Cached direction head vector. -/ + dirHeadVec : Vector Rat dHead + /-- Direction head lookup. -/ + dirHead : Fin dHead → Rat + /-- Value-direction weight dot products. -/ + wvDir : Fin dModel → Rat + /-- Direction bias term. -/ + bDir : Rat + /-- Value absolute bounds. -/ + valsAbs : Fin seq → Rat + /-- Value lower bounds. -/ + valsLo : Fin seq → Rat + /-- Value upper bounds. -/ + valsHi : Fin seq → Rat + /-- Universe of query indices. -/ + univ : Finset (Fin seq) + /-- Global value lower bound. -/ + lo : Rat + /-- Global value upper bound. -/ + hi : Rat + /-- Value-interval certificate. -/ + valCert : ValueInterval seq + /-- Induction-head certificate. -/ + cert : InductionHeadCert seq + +/-- Build cached core quantities for induction-head certificates. 
-/ +def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + InductionHeadCoreCache seq dModel dHead := by + classical + let lnBounds := Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Bounds.cacheBoundTask (fun q => + Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr : Array Rat := + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr[q.1]'(by + have hsize : lnAbsMaxArr.size = seq := by + simp [lnAbsMaxArr] + simp [hsize]) + let invStdBoundsTasks : Array (Task (Rat × Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) + let invStdBoundsArr : Array (Rat × Rat) := + Array.ofFn (fun q : Fin seq => + (invStdBoundsTasks[q.1]'(by + have hsize : invStdBoundsTasks.size = seq := by + simp [invStdBoundsTasks] + simp [hsize])).get) + let invStdLo : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + have hsize : invStdBoundsArr.size = seq := by + simp [invStdBoundsArr] + simp [hsize])).1 + let invStdHi : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + have hsize : invStdBoundsArr.size = seq := by + simp [invStdBoundsArr] + simp [hsize])).2 + let qBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + + inputs.bq d) + let qBase : Fin dHead → Rat := fun d => + qBaseArr[d.1]'(by + have hsize : qBaseArr.size = dHead := by + simp [qBaseArr] + simp [hsize]) + let kBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + + inputs.bk d) + let kBase : Fin dHead → Rat := fun d => + kBaseArr[d.1]'(by + have hsize : 
kBaseArr.size = dHead := by + simp [kBaseArr] + simp [hsize]) + let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) coeff), + by simp⟩)) + let qCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qCoeffRowTasks[q.1]'(by + have hsize : qCoeffRowTasks.size = seq := by + simp [qCoeffRowTasks] + simp [hsize])).get) + let qCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := qCoeffArr[q.1]'(by + have hsize : qCoeffArr.size = seq := by + simp [qCoeffArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) coeff), + by simp⟩)) + let kCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kCoeffRowTasks[q.1]'(by + have hsize : kCoeffRowTasks.size = seq := by + simp [kCoeffRowTasks] + simp [hsize])).get) + let kCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := kCoeffArr[q.1]'(by + have hsize : kCoeffArr.size = seq := by + simp [kCoeffArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.1 + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.2 + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) 
(invStdLo q) (invStdHi q) + kBase d + bounds.1 + let kHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.2 + let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let qAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => qAbs q d)) + let qAbsMax : Fin dHead → Rat := fun d => + qAbsMaxArr[d.1]'(by + have hsize : qAbsMaxArr.size = dHead := by + simp [qAbsMaxArr] + simp [hsize]) + let kAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d)) + let kAbsMax : Fin dHead → Rat := fun d => + kAbsMaxArr[d.1]'(by + have hsize : kAbsMaxArr.size = dHead := by + simp [kAbsMaxArr] + simp [hsize]) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let splitBudgetQ : Nat := 2 + let splitBudgetK : Nat := 2 + let splitBudgetDiffBase : Nat := 0 + let splitBudgetDiffRefined : Nat := 12 + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => + let ambig := + (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if 
b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetQ + let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => + let ambig := + (List.finRange dHead).filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetK + let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => + let prev := inputs.prev q + let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d + let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d + let ambig := + (List.finRange dHead).filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) + let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : 
Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 : List (Fin dHead) := top2 ambig + let dims2 : List (Fin dHead) := + top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let memDims2 : Fin dHead → Bool := fun d => + dims2.any (fun d' => decide (d' = d)) + let dims3 : List (Fin dHead) := + top2 + ((ambig.filter (fun d => decide (d ∉ dims1))).filter + (fun d => !memDims2 d)) + (dims1 ++ dims2 ++ dims3).take budget + let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffBase + let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffRefined + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + let dimsK := splitDimsK q k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if hq : q ∈ inputs.active then + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsDiff := splitDimsDiffBase q k + let prev := 
inputs.prev q + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)), + by simp⟩ + else + ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) + let scoreGapLoBase : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLoBase q k + else + inputs.scale * 
dotDiffHiBase q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let worstKey : Fin seq → Option (Fin seq) := fun q => + if hq : q ∈ inputs.active then + let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (scoreGapLoBase q k, k)).2 + else + none + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k + | none => dotDiffLoBase q k + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k + | none => dotDiffHiBase q k + let scoreGapLo : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLo q k + else + inputs.scale * dotDiffHi q k + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreGapLo q k) + else + (0 : Rat) + else + (0 : Rat) + let epsAt : Fin seq → Rat := fun q => + if 
marginAt q < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + marginAt q) + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + let dirHeadVec := dirHeadVecOfInputs inputs + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin dModel → Rat := + Bounds.cacheBoundTask (fun j => + Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv d) + let valsAbs : Fin seq → Rat := fun q => + Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMax q + let valsLo : Fin seq → Rat := fun q => bDir - valsAbs q + let valsHi : Fin seq → Rat := fun q => bDir + valsAbs q + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + let valCert : ValueInterval seq := + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec } + let cert : InductionHeadCert seq := + { eps := eps + epsAt := epsAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } + exact + { lnBounds := lnBounds + lnLo := lnLo + lnHi := lnHi + lnAbsMaxTask := lnAbsMaxTask + lnAbsMaxArr := lnAbsMaxArr + lnAbsMax := lnAbsMax + invStdBoundsTasks := invStdBoundsTasks + invStdBoundsArr := invStdBoundsArr + invStdLo := invStdLo + invStdHi := invStdHi + qBaseArr := qBaseArr + qBase := qBase + kBaseArr := kBaseArr + kBase := kBase + qCoeffRowTasks := qCoeffRowTasks + qCoeffArr := qCoeffArr + qCoeff := qCoeff + kCoeffRowTasks := kCoeffRowTasks + kCoeffArr := kCoeffArr + kCoeff := kCoeff + qLo := qLo + qHi := qHi + kLo := kLo + kHi := kHi + qAbs := qAbs + kAbs := kAbs + qAbsMaxArr := qAbsMaxArr + qAbsMax := qAbsMax + kAbsMaxArr := kAbsMaxArr + kAbsMax := kAbsMax + masked := masked + splitBudgetQ := 
splitBudgetQ + splitBudgetK := splitBudgetK + splitBudgetDiffBase := splitBudgetDiffBase + splitBudgetDiffRefined := splitBudgetDiffRefined + splitDimsQ := splitDimsQ + splitDimsK := splitDimsK + splitDimsDiffCore := splitDimsDiffCore + splitDimsDiffBase := splitDimsDiffBase + splitDimsDiffRefined := splitDimsDiffRefined + dotRowTasks := dotRowTasks + dotDiffRowTasksBase := dotDiffRowTasksBase + dotLo := dotLo + dotHi := dotHi + dotDiffLoBase := dotDiffLoBase + dotDiffHiBase := dotDiffHiBase + dotAbs := dotAbs + scoreBaseAbs := scoreBaseAbs + scoreLo := scoreLo + scoreHi := scoreHi + scoreLoPrev := scoreLoPrev + scoreGapLoBase := scoreGapLoBase + otherKeys := otherKeys + worstKey := worstKey + dotDiffLo := dotDiffLo + dotDiffHi := dotDiffHi + scoreGapLo := scoreGapLo + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps + dirHeadVec := dirHeadVec + dirHead := dirHead + wvDir := wvDir + bDir := bDir + valsAbs := valsAbs + valsLo := valsLo + valsHi := valsHi + univ := univ + lo := lo + hi := hi + valCert := valCert + cert := cert } + +/-- The cached certificate is built from cache fields. -/ +theorem buildInductionHeadCoreCache_cert_eq [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + (buildInductionHeadCoreCache inputs).cert = + { eps := (buildInductionHeadCoreCache inputs).eps + epsAt := (buildInductionHeadCoreCache inputs).epsAt + margin := (buildInductionHeadCoreCache inputs).margin + active := inputs.active + prev := inputs.prev + values := (buildInductionHeadCoreCache inputs).valCert } := by + rfl /-- Build induction certificates from exact head inputs (core computation). -/ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : @@ -101,422 +753,54 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} · by_cases hmodel : dModel = 0 · exact none · by_cases hactive : inputs.active.Nonempty - · let lnBounds := Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Bounds.cacheBoundTask (fun q => - Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr : Array Rat := - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr[q.1]'(by - have hsize : lnAbsMaxArr.size = seq := by - simp [lnAbsMaxArr] - simp [hsize]) - let invStdBoundsTasks : Array (Task (Rat × Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) - let invStdBoundsArr : Array (Rat × Rat) := - Array.ofFn (fun q : Fin seq => - (invStdBoundsTasks[q.1]'(by - have hsize : invStdBoundsTasks.size = seq := by - simp [invStdBoundsTasks] - simp [hsize])).get) - let invStdLo : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - have hsize : invStdBoundsArr.size = seq := by - simp [invStdBoundsArr] - simp [hsize])).1 - let invStdHi : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - have hsize : invStdBoundsArr.size = seq := by - simp [invStdBoundsArr] - simp [hsize])).2 - let qBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + - inputs.bq d) - let qBase : Fin dHead → Rat := fun d => - qBaseArr[d.1]'(by - have hsize : qBaseArr.size = dHead := by - simp [qBaseArr] - simp [hsize]) - let kBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + - inputs.bk d) - let kBase : Fin dHead → Rat := fun d => - kBaseArr[d.1]'(by - have hsize : kBaseArr.size = dHead := by - simp [kBaseArr] - simp [hsize]) - 
let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) coeff), - by simp⟩)) - let qCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qCoeffRowTasks[q.1]'(by - have hsize : qCoeffRowTasks.size = seq := by - simp [qCoeffRowTasks] - simp [hsize])).get) - let qCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := qCoeffArr[q.1]'(by - have hsize : qCoeffArr.size = seq := by - simp [qCoeffArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) coeff), - by simp⟩)) - let kCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kCoeffRowTasks[q.1]'(by - have hsize : kCoeffRowTasks.size = seq := by - simp [kCoeffRowTasks] - simp [hsize])).get) - let kCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := kCoeffArr[q.1]'(by - have hsize : kCoeffArr.size = seq := by - simp [kCoeffArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.1 - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.2 - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.1 - let kHi : Fin seq 
→ Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.2 - let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| - let qAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => qAbs q d)) - let qAbsMax : Fin dHead → Rat := fun d => - qAbsMaxArr[d.1]'(by - have hsize : qAbsMaxArr.size = dHead := by - simp [qAbsMaxArr] - simp [hsize]) - let kAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d)) - let kAbsMax : Fin dHead → Rat := fun d => - kAbsMaxArr[d.1]'(by - have hsize : kAbsMaxArr.size = dHead := by - simp [kAbsMaxArr] - simp [hsize]) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 2 - let splitBudgetK : Nat := 2 - let splitBudgetDiffBase : Nat := 0 - let splitBudgetDiffRefined : Nat := 12 - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - let ambig := - (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - 
let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ - let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - let ambig := - (List.finRange dHead).filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK - let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => - let prev := inputs.prev q - let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d - let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d - let ambig := - (List.finRange dHead).filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) - let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) 
:= - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 : List (Fin dHead) := top2 ambig - let dims2 : List (Fin dHead) := - top2 (ambig.filter (fun d => decide (d ∉ dims1))) - let memDims2 : Fin dHead → Bool := fun d => - dims2.any (fun d' => decide (d' = d)) - let dims3 : List (Fin dHead) := - top2 - ((ambig.filter (fun d => decide (d ∉ dims1))).filter - (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take budget - let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffBase - let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffRefined - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if hq : q ∈ inputs.active then - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsDiff := splitDimsDiffBase q k - let prev := inputs.prev q - 
_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)), - by simp⟩ - else - ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let scoreGapLoBase : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLoBase q k - else - inputs.scale * dotDiffHiBase q k - let 
otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKey : Fin seq → Option (Fin seq) := fun q => - if hq : q ∈ inputs.active then - let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (scoreGapLoBase q k, k)).2 - else - none - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - | none => dotDiffLoBase q k - let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 - else - dotDiffHiBase q k - | none => dotDiffHiBase q k - let scoreGapLo : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLo q k - else - inputs.scale * dotDiffHi q k - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreGapLo q k) - else - (0 : Rat) - else - (0 : Rat) - let epsAt : Fin seq → Rat := fun q => - if marginAt q < 0 then - (1 
: Rat) - else - ratDivUp (seq - 1) (1 + marginAt q) - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - let dirHeadVec := dirHeadVecOfInputs inputs - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := - Bounds.cacheBoundTask (fun j => - Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsAbs : Fin seq → Rat := fun q => - Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMax q - let valsLo : Fin seq → Rat := fun q => bDir - valsAbs q - let valsHi : Fin seq → Rat := fun q => bDir + valsAbs q - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec } - let cert : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } - exact some cert + · exact some (buildInductionHeadCoreCache inputs).cert · exact none · exact none · exact none +/-- `buildInductionCertFromHeadCore?` succeeds under the guard conditions. -/ +theorem buildInductionCertFromHeadCore?_eq_some [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) (hactive : inputs.active.Nonempty) : + buildInductionCertFromHeadCore? inputs = + some (buildInductionHeadCoreCache inputs).cert := by + classical + simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] + +/-- `buildInductionCertFromHeadCore?` fails when `dModel = 0`. 
-/ +theorem buildInductionCertFromHeadCore?_eq_none_of_model_eq_zero [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel = 0) : + buildInductionCertFromHeadCore? inputs = none := by + classical + simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel] + +/-- `buildInductionCertFromHeadCore?` fails when `active` is empty. -/ +theorem buildInductionCertFromHeadCore?_eq_none_of_not_active [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : + buildInductionCertFromHeadCore? inputs = none := by + classical + simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] + +/-- `buildInductionCertFromHeadCore?` fails when the sqrt lower bound is nonpositive. -/ +theorem buildInductionCertFromHeadCore?_eq_none_of_not_sqrt [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : + buildInductionCertFromHeadCore? inputs = none := by + classical + simp [buildInductionCertFromHeadCore?, hEps, hSqrt] + +/-- `buildInductionCertFromHeadCore?` fails when `lnEps` is nonpositive. -/ +theorem buildInductionCertFromHeadCore?_eq_none_of_not_eps [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : ¬0 < inputs.lnEps) : + buildInductionCertFromHeadCore? 
inputs = none := by + classical + simp [buildInductionCertFromHeadCore?, hEps] + end Sound end Nfp diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index a38d0ac..aab2279 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -39,24 +39,51 @@ noncomputable def lnRealOfInputs {seq dModel dHead : Nat} fun q => Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) +/-- Unfolding lemma for `lnRealOfInputs`. -/ +theorem lnRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (i : Fin dModel) : + lnRealOfInputs inputs q i = + Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) i := rfl + /-- Real-valued query projections for head inputs. -/ noncomputable def qRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := fun q d => dotProduct (fun j => (inputs.wq j d : Real)) (lnRealOfInputs inputs q) + (inputs.bq d : Real) +/-- Unfolding lemma for `qRealOfInputs`. -/ +theorem qRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : + qRealOfInputs inputs q d = + dotProduct (fun j => (inputs.wq j d : Real)) (lnRealOfInputs inputs q) + + (inputs.bq d : Real) := rfl + /-- Real-valued key projections for head inputs. -/ noncomputable def kRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := fun q d => dotProduct (fun j => (inputs.wk j d : Real)) (lnRealOfInputs inputs q) + (inputs.bk d : Real) +/-- Unfolding lemma for `kRealOfInputs`. 
-/ +theorem kRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : + kRealOfInputs inputs q d = + dotProduct (fun j => (inputs.wk j d : Real)) (lnRealOfInputs inputs q) + + (inputs.bk d : Real) := rfl + /-- Real-valued value projections for head inputs. -/ noncomputable def vRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := fun q d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + (inputs.bv d : Real) +/-- Unfolding lemma for `vRealOfInputs`. -/ +theorem vRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : + vRealOfInputs inputs q d = + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + + (inputs.bv d : Real) := rfl + /-- Real-valued attention scores for head inputs. -/ noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin seq → Real := @@ -72,18 +99,46 @@ noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} else base +/-- Unfolding lemma for `scoresRealOfInputs`. -/ +theorem scoresRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q k : Fin seq) : + scoresRealOfInputs inputs q k = + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + if inputs.maskCausal then + if k ≤ q then + base + else + (inputs.maskValue : Real) + else + base := rfl + /-- Real-valued per-key head outputs in model space. -/ noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := fun k i => dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) +/-- Unfolding lemma for `headValueRealOfInputs`. 
-/ +theorem headValueRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) (i : Fin dModel) : + headValueRealOfInputs inputs k i = + dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) := rfl + /-- Real-valued direction scores for head inputs. -/ noncomputable def valsRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Real := let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d fun k => dotProduct dirHead (fun d => vRealOfInputs inputs k d) +/-- Unfolding lemma for `valsRealOfInputs`. -/ +theorem valsRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) : + valsRealOfInputs inputs k = + let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d + dotProduct dirHead (fun d => vRealOfInputs inputs k d) := rfl + /-- Interval data for direction values. -/ structure ValueInterval (seq : Nat) where /-- Lower bound for values. 
-/ diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index dd0e99f..e9ddabf 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -21,7 +21,13 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} · by_cases hSqrt : 0 < sqrtLower inputs.lnEps · by_cases hmodel : dModel = 0 · have : False := by - simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel] at hcore + have hnone := + buildInductionCertFromHeadCore?_eq_none_of_model_eq_zero + (inputs := inputs) hEps hSqrt hmodel + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' exact this.elim · by_cases hactive : inputs.active.Nonempty · let lnBounds := Bounds.cacheBoundPair2 (fun q => @@ -35,78 +41,100 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) let lnAbsMax : Fin seq → Rat := fun q => lnAbsMaxArr[q.1]'(by - have hsize : lnAbsMaxArr.size = seq := by simp [lnAbsMaxArr] - simp [hsize]) + simp [lnAbsMaxArr]) let invStdBoundsTasks : Array (Task (Rat × Rat)) := Array.ofFn (fun q : Fin seq => Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) let invStdBoundsArr : Array (Rat × Rat) := Array.ofFn (fun q : Fin seq => (invStdBoundsTasks[q.1]'(by - have hsize : invStdBoundsTasks.size = seq := by simp [invStdBoundsTasks] - simp [hsize])).get) + simp [invStdBoundsTasks])).get) let invStdLo : Fin seq → Rat := fun q => (invStdBoundsArr[q.1]'(by - have hsize : invStdBoundsArr.size = seq := by simp [invStdBoundsArr] - simp [hsize])).1 + simp [invStdBoundsArr])).1 let invStdHi : Fin seq → Rat := fun q => (invStdBoundsArr[q.1]'(by - have hsize : invStdBoundsArr.size = seq := by simp [invStdBoundsArr] - simp [hsize])).2 + simp [invStdBoundsArr])).2 + let lnCoeff : Fin seq → Fin dModel → Rat := fun q j => + inputs.ln1Gamma j * (inputs.embed q j - mean (inputs.embed q)) 
+ let invStd : Fin seq → Real := fun q => + (Real.sqrt ((varianceRat (inputs.embed q) : Real) + (inputs.lnEps : Real)))⁻¹ + have hmeanRat : + ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by + intro q + have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by + simp [mean_def, hmodel, ratRoundDown] + simpa [ratToReal] using congrArg ratToReal hmu_rat + have hln_affine : + ∀ q j, + lnRealOfInputs inputs q j = + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + intro q j + have hmu := hmeanRat q + simp [lnRealOfInputs, Bounds.layerNormReal, hmodel, lnCoeff, hmu, invStd, add_comm, + mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] + have hln_fun : + ∀ q, + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + intro q + funext j + exact hln_affine q j + have hinv_bounds : + ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by + intro q + simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, + invStdBounds, Task.spawn, Array.getElem_ofFn] using + (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) + hmodel hEps hSqrt) let qBaseArr : Array Rat := Array.ofFn (fun d : Fin dHead => Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + inputs.bq d) let qBase : Fin dHead → Rat := fun d => qBaseArr[d.1]'(by - have hsize : qBaseArr.size = dHead := by simp [qBaseArr] - simp [hsize]) + simp [qBaseArr]) let kBaseArr : Array Rat := Array.ofFn (fun d : Fin dHead => Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + inputs.bk d) let kBase : Fin dHead → Rat := fun d => kBaseArr[d.1]'(by - have hsize : kBaseArr.size = dHead := by simp [kBaseArr] - simp [hsize]) + simp [kBaseArr]) + let coeffRowTasks : + (Fin dModel → Fin dHead → Rat) → + Array (Task { row : Array Rat // row.size = dHead }) := + fun w => + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean 
(inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => w j d) coeff), + by simp⟩)) let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) coeff), - by simp⟩)) + coeffRowTasks inputs.wq let qCoeffArr : Array { row : Array Rat // row.size = dHead } := Array.ofFn (fun q : Fin seq => (qCoeffRowTasks[q.1]'(by - have hsize : qCoeffRowTasks.size = seq := by simp [qCoeffRowTasks] - simp [hsize])).get) + simp [qCoeffRowTasks, coeffRowTasks])).get) let qCoeff : Fin seq → Fin dHead → Rat := fun q d => let row := qCoeffArr[q.1]'(by - have hsize : qCoeffArr.size = seq := by simp [qCoeffArr] - simp [hsize]) - row.1[d.1]'(by simp [row.2]) + simp [qCoeffArr]) + row.1[d.1]'(by + simp [row.2]) let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) coeff), - by simp⟩)) + coeffRowTasks inputs.wk let kCoeffArr : Array { row : Array Rat // row.size = dHead } := Array.ofFn (fun q : Fin seq => (kCoeffRowTasks[q.1]'(by - have hsize : kCoeffRowTasks.size = seq := by simp [kCoeffRowTasks] - simp [hsize])).get) + simp [kCoeffRowTasks, coeffRowTasks])).get) let kCoeff : Fin seq → Fin dHead → Rat := fun q d => let row := kCoeffArr[q.1]'(by - have hsize : kCoeffArr.size = seq := by simp [kCoeffArr] - simp [hsize]) - row.1[d.1]'(by simp [row.2]) + simp [kCoeffArr]) + row.1[d.1]'(by + simp [row.2]) let qLo : Fin seq → Fin 
dHead → Rat := fun q d => let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) qBase d + bounds.1 @@ -128,9 +156,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} univ.sup' hnonempty (fun q => qAbs q d)) let qAbsMax : Fin dHead → Rat := fun d => qAbsMaxArr[d.1]'(by - have hsize : qAbsMaxArr.size = dHead := by - simp [qAbsMaxArr] - simp [hsize]) + simp [qAbsMaxArr]) let kAbsMaxArr : Array Rat := Array.ofFn (fun d : Fin dHead => let univ : Finset (Fin seq) := Finset.univ @@ -138,9 +164,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} univ.sup' hnonempty (fun k => kAbs k d)) let kAbsMax : Fin dHead → Rat := fun d => kAbsMaxArr[d.1]'(by - have hsize : kAbsMaxArr.size = dHead := by - simp [kAbsMaxArr] - simp [hsize]) + simp [kAbsMaxArr]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k let splitBudgetQ : Nat := 2 @@ -427,8 +451,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} prev := inputs.prev values := valCert } have hcore' : buildInductionCertFromHeadCore? 
inputs = some cert := by - simp (config := { zeta := false }) - [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive, cert, valCert] + simp (config := { zeta := false }) only + [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] rfl have hc : c = cert := by have hcert : cert = c := by @@ -452,7 +476,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} Bounds.abs_le_intervalAbsBound_real (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) (hlo := fun j => (hln j).1) (hhi := fun j => (hln j).2) j - simpa [lnAbsMax, lnAbsMaxArr, lnAbsMaxTask, Bounds.cacheBoundTask_apply, + simpa only [lnAbsMax, lnAbsMaxArr, lnAbsMaxTask, Bounds.cacheBoundTask_apply, Array.getElem_ofFn] using h have hdot_abs_bound : ∀ (v : Fin dModel → Rat) (q : Fin seq), @@ -521,187 +545,179 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (Linear.dotFin n f g : Real) = dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] - have hq_bounds : - ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + have proj_bounds + (w : Fin dModel → Fin dHead → Rat) + (b base : Fin dHead → Rat) + (coeff : Fin seq → Fin dHead → Rat) + (hbase : ∀ d, + (base d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (b d : Real)) + (hcoeff : ∀ q d, + (coeff q d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real))) : + ∀ q d, + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ∧ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ≤ + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by intro q d - let x := inputs.embed q; let μRat : Rat := mean x - let centered : Fin dModel → 
Rat := fun j => x j - μRat - let coeff : Fin dModel → Rat := fun j => inputs.ln1Gamma j * centered j - let invStd : Real := - (Real.sqrt ((varianceRat x : Real) + (inputs.lnEps : Real)))⁻¹ - have hmu : (μRat : Real) = meanRat x := by - have hmu_rat : μRat = meanRat x := by simp [μRat, mean_def, hmodel, ratRoundDown] - simpa [ratToReal] using congrArg ratToReal hmu_rat - have hinv : (invStdLo q : Real) ≤ invStd ∧ invStd ≤ (invStdHi q : Real) := by - simpa [invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, invStdBounds, - Task.spawn, Array.getElem_ofFn] using - (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := x) hmodel hEps hSqrt) - have hln : ∀ j, - lnRealOfInputs inputs q j = - (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd := by - intro j; simp [lnRealOfInputs, Bounds.layerNormReal, hmodel, coeff, centered, μRat, - hmu, invStd, x, add_comm, mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] - have hln_fun : + have hinv : (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := + hinv_bounds q + have hln_fun_q : lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd := by - funext j; exact hln j - have hbase : - (qBase d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bq d : Real) := by simp [qBase, qBaseArr, dotFin_cast] - have hcoeff : - (qCoeff q d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) (fun j => (coeff j : Real)) := by - simp [qCoeff, qCoeffArr, qCoeffRowTasks, Task.spawn, coeff, centered, μRat, x, - dotFin_cast] + fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + exact hln_fun q have hdot_add : - dotProduct (fun j => (inputs.wq j d : Real)) + dotProduct (fun j => (w j d : Real)) (fun j => - (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd) = - dotProduct (fun j => (inputs.wq j d : Real)) + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) = + dotProduct (fun j => (w j d : 
Real)) (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (coeff j : Real) * invStd) := by - simpa [dotProduct, mul_add] using - (Finset.sum_add_distrib (s := (Finset.univ : Finset (Fin dModel))) - (f := fun j => (inputs.wq j d : Real) * (inputs.ln1Beta j : Real)) - (g := fun j => (inputs.wq j d : Real) * ((coeff j : Real) * invStd))) + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real) * invStd q) := by + simpa using + (Nfp.Sound.Linear.dotProduct_add_right + (x := fun j => (w j d : Real)) + (y := fun j => (inputs.ln1Beta j : Real)) + (z := fun j => (lnCoeff q j : Real) * invStd q)) have hdot_coeff : - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (coeff j : Real) * invStd) = - dotProduct (fun j => (inputs.wq j d : Real)) (fun j => (coeff j : Real)) * - invStd := by - simp [dotProduct, mul_assoc, Finset.sum_mul] - have hq_real : - qRealOfInputs inputs q d = - (qBase d : Real) + (qCoeff q d : Real) * invStd := by + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real) * invStd q) = + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q := by + simpa using + (Nfp.Sound.Linear.dotProduct_mul_right + (x := fun j => (w j d : Real)) + (y := fun j => (lnCoeff q j : Real)) + (a := invStd q)) + have hreal : + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) = + (base d : Real) + (coeff q d : Real) * invStd q := by calc - qRealOfInputs inputs q d = - dotProduct (fun j => (inputs.wq j d : Real)) + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) = + dotProduct (fun j => (w j d : Real)) (fun j => - (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd) + - (inputs.bq d : Real) := by simp [qRealOfInputs, hln_fun] + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) + + (b d : Real) := by + simp [hln_fun_q] _ = - dotProduct (fun j => (inputs.wq j d : Real)) + dotProduct (fun j => (w j 
d : Real)) (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (inputs.wq j d : Real)) (fun j => (coeff j : Real)) * - invStd + - (inputs.bq d : Real) := by simp [hdot_add, hdot_coeff, add_assoc] + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q + + (b d : Real) := by + simp [hdot_add, hdot_coeff, add_assoc] _ = - (dotProduct (fun j => (inputs.wq j d : Real)) + (dotProduct (fun j => (w j d : Real)) (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bq d : Real)) + - dotProduct (fun j => (inputs.wq j d : Real)) (fun j => (coeff j : Real)) * - invStd := by ac_rfl - _ = (qBase d : Real) + (qCoeff q d : Real) * invStd := by simp [hbase, hcoeff] + (b d : Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q := by ac_rfl + _ = (base d : Real) + (coeff q d : Real) * invStd q := by + simp [hbase, hcoeff] have hscale : - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - (bounds.1 : Real) ≤ (qCoeff q d : Real) * invStd ∧ - (qCoeff q d : Real) * invStd ≤ (bounds.2 : Real) := by - exact scaleInterval_bounds_real (x := qCoeff q d) (lo := invStdLo q) - (hi := invStdHi q) (y := invStd) hinv.1 hinv.2 + let bounds := scaleInterval (coeff q d) (invStdLo q) (invStdHi q) + (bounds.1 : Real) ≤ (coeff q d : Real) * invStd q ∧ + (coeff q d : Real) * invStd q ≤ (bounds.2 : Real) := by + exact scaleInterval_bounds_real (x := coeff q d) (lo := invStdLo q) + (hi := invStdHi q) (y := invStd q) hinv.1 hinv.2 have hlow : - (qLo q d : Real) ≤ qRealOfInputs inputs q d := by - simpa [qLo, hq_real] using add_le_add_left hscale.1 (qBase d : Real) + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) := by + simpa [hreal] using add_le_add_left hscale.1 (base d : Real) have hhigh : - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by - simpa [qHi, hq_real] using add_le_add_left hscale.2 (qBase d : 
Real) + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ≤ + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by + simpa [hreal] using add_le_add_left hscale.2 (base d : Real) exact ⟨hlow, hhigh⟩ + have hq_bounds : + ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + intro q d + have hbase : + ∀ d, + (qBase d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bq d : Real) := by + intro d + simp [qBase, qBaseArr, dotFin_cast] + have hcoeff : + ∀ q' d, + (qCoeff q' d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (lnCoeff q' j : Real)) := by + intro q' d + simpa [qCoeff, qCoeffArr, qCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using + (dotFin_cast (f := fun j => inputs.wq j d) + (g := fun j => + inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) + have h := proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) + (coeff := qCoeff) hbase hcoeff q d + simpa [qLo, qHi, qRealOfInputs] using h have hk_bounds : ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ kRealOfInputs inputs q d ≤ (kHi q d : Real) := by intro q d - let x := inputs.embed q; let μRat : Rat := mean x - let centered : Fin dModel → Rat := fun j => x j - μRat - let coeff : Fin dModel → Rat := fun j => inputs.ln1Gamma j * centered j - let invStd : Real := - (Real.sqrt ((varianceRat x : Real) + (inputs.lnEps : Real)))⁻¹ - have hmu : (μRat : Real) = meanRat x := by - have hmu_rat : μRat = meanRat x := by simp [μRat, mean_def, hmodel, ratRoundDown] - simpa [ratToReal] using congrArg ratToReal hmu_rat - have hinv : (invStdLo q : Real) ≤ invStd ∧ invStd ≤ (invStdHi q : Real) := by - simpa [invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, invStdBounds, - Task.spawn, Array.getElem_ofFn] using - (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := x) hmodel hEps hSqrt) - have hln : ∀ j, 
- lnRealOfInputs inputs q j = - (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd := by - intro j; simp [lnRealOfInputs, Bounds.layerNormReal, hmodel, coeff, centered, μRat, - hmu, invStd, x, add_comm, mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] - have hln_fun : - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd := by - funext j; exact hln j have hbase : - (kBase d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bk d : Real) := by simp [kBase, kBaseArr, dotFin_cast] - have hcoeff : - (kCoeff q d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) (fun j => (coeff j : Real)) := by - simp [kCoeff, kCoeffArr, kCoeffRowTasks, Task.spawn, coeff, centered, μRat, x, - dotFin_cast] - have hdot_add : - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (coeff j : Real) * invStd) := by - simpa [dotProduct, mul_add] using - (Finset.sum_add_distrib (s := (Finset.univ : Finset (Fin dModel))) - (f := fun j => (inputs.wk j d : Real) * (inputs.ln1Beta j : Real)) - (g := fun j => (inputs.wk j d : Real) * ((coeff j : Real) * invStd))) - have hdot_coeff : - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (coeff j : Real) * invStd) = - dotProduct (fun j => (inputs.wk j d : Real)) (fun j => (coeff j : Real)) * - invStd := by - simp [dotProduct, mul_assoc, Finset.sum_mul] - have hk_real : - kRealOfInputs inputs q d = - (kBase d : Real) + (kCoeff q d : Real) * invStd := by - calc - kRealOfInputs inputs q d = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (coeff j : Real) * invStd) + - (inputs.bk d : Real) := by simp [kRealOfInputs, hln_fun] - _ = + ∀ d, + (kBase d : Real) = dotProduct (fun j => (inputs.wk j d 
: Real)) (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (inputs.wk j d : Real)) (fun j => (coeff j : Real)) * - invStd + - (inputs.bk d : Real) := by simp [hdot_add, hdot_coeff, add_assoc] - _ = - (dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bk d : Real)) + - dotProduct (fun j => (inputs.wk j d : Real)) (fun j => (coeff j : Real)) * - invStd := by ac_rfl - _ = (kBase d : Real) + (kCoeff q d : Real) * invStd := by simp [hbase, hcoeff] - have hscale : - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - (bounds.1 : Real) ≤ (kCoeff q d : Real) * invStd ∧ - (kCoeff q d : Real) * invStd ≤ (bounds.2 : Real) := by - exact scaleInterval_bounds_real (x := kCoeff q d) (lo := invStdLo q) - (hi := invStdHi q) (y := invStd) hinv.1 hinv.2 - have hlow : - (kLo q d : Real) ≤ kRealOfInputs inputs q d := by - simpa [kLo, hk_real] using add_le_add_left hscale.1 (kBase d : Real) - have hhigh : - kRealOfInputs inputs q d ≤ (kHi q d : Real) := by - simpa [kHi, hk_real] using add_le_add_left hscale.2 (kBase d : Real) - exact ⟨hlow, hhigh⟩ + (inputs.bk d : Real) := by + intro d + simp [kBase, kBaseArr, dotFin_cast] + have hcoeff : + ∀ q' d, + (kCoeff q' d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (lnCoeff q' j : Real)) := by + intro q' d + simpa [kCoeff, kCoeffArr, kCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using + (dotFin_cast (f := fun j => inputs.wk j d) + (g := fun j => + inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) + have h := proj_bounds (w := inputs.wk) (b := inputs.bk) (base := kBase) + (coeff := kCoeff) hbase hcoeff q d + simpa [kLo, kHi, kRealOfInputs] using h + let scoresReal := scoresRealOfInputs inputs + have scoresReal_eq_base_of_not_masked : + ∀ q k, ¬ masked q k → + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + intro q k hnot + by_cases 
hcausal : inputs.maskCausal + · have hnot_lt : ¬ q < k := by + intro hlt + exact hnot ⟨hcausal, hlt⟩ + have hle : k ≤ q := le_of_not_gt hnot_lt + simp [scoresReal, scoresRealOfInputs, hcausal, hle] + · simp [scoresReal, scoresRealOfInputs, hcausal] + have scoresReal_eq_masked : + ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by + intro q k hmask + have hmask' : inputs.maskCausal = true ∧ q < k := by + simpa [masked] using hmask + have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 + simp [scoresReal, scoresRealOfInputs, hmask'.1, hle] have hscore_bounds : - ∀ q k, (scoreLo q k : Real) ≤ scoresRealOfInputs inputs q k ∧ - scoresRealOfInputs inputs q k ≤ (scoreHi q k : Real) := by + ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ + scoresReal q k ≤ (scoreHi q k : Real) := by intro q k - let scoresReal := scoresRealOfInputs inputs let base := (inputs.scale : Real) * dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) @@ -741,67 +757,52 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn] using hspec.2 exact ⟨hlow', hhigh'⟩ - by_cases hcausal : inputs.maskCausal - · by_cases hle : k ≤ q - · have hnot : ¬ q < k := not_lt_of_ge hle - have hscore_eq : scoresReal q k = base := by - simp [scoresReal, scoresRealOfInputs, hcausal, hle, base] - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - ratToReal_nonneg_of_nonneg hscale - have hlow := - mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real - have hhigh := - mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real - constructor - · simpa [scoresReal, scoreLo, masked, hcausal, hnot, hscale, hscore_eq, base] - using hlow - · simpa [scoresReal, scoreHi, masked, hcausal, hnot, hscale, hscore_eq, base] - using hhigh - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := - (ratToReal_nonpos_iff 
(x := inputs.scale)).2 hscale_nonpos - have hlow := - mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real - have hhigh := - mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real - constructor - · simpa [scoresReal, scoreLo, masked, hcausal, hnot, hscale, hscore_eq, base] - using hlow - · simpa [scoresReal, scoreHi, masked, hcausal, hnot, hscale, hscore_eq, base] - using hhigh - · have hlt : q < k := lt_of_not_ge hle - constructor - · simp [scoresRealOfInputs, scoreLo, masked, hcausal, hle, hlt] - · simp [scoresRealOfInputs, scoreHi, masked, hcausal, hle, hlt] - · have hscore_eq : scoresReal q k = base := by - simp [scoresReal, scoresRealOfInputs, hcausal, base] + have hscore_base_bounds (hnot : ¬ masked q k) : + (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by by_cases hscale : 0 ≤ inputs.scale · have hscale_real : 0 ≤ (inputs.scale : Real) := ratToReal_nonneg_of_nonneg hscale - have hlow := - mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real - have hhigh := - mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real + have hlow := mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real + have hhigh := mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real constructor - · simpa [scoresReal, scoreLo, masked, hcausal, hscale, hscore_eq, base] - using hlow - · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] - using hhigh + · simpa [scoreLo, masked, hnot, hscale, base] using hlow + · simpa [scoreHi, masked, hnot, hscale, base] using hhigh · have hscale_nonpos : inputs.scale ≤ 0 := le_of_lt (lt_of_not_ge hscale) have hscale_real : (inputs.scale : Real) ≤ 0 := (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hlow := - mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real - have hhigh := - mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real + have hlow := mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real + have hhigh := mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real constructor - · simpa [scoresReal, scoreLo, masked, hcausal, 
hscale, hscore_eq, base] - using hlow - · simpa [scoresReal, scoreHi, masked, hcausal, hscale, hscore_eq, base] - using hhigh + · simpa [scoreLo, masked, hnot, hscale, base] using hlow + · simpa [scoreHi, masked, hnot, hscale, base] using hhigh + by_cases hcausal : inputs.maskCausal + · by_cases hle : k ≤ q + · have hnot : ¬ q < k := not_lt_of_ge hle + have hnot_masked : ¬ masked q k := fun hmk => hnot hmk.2 + have hscore_eq : scoresReal q k = base := + scoresReal_eq_base_of_not_masked q k hnot_masked + have hbase := hscore_base_bounds hnot_masked + constructor + · simpa [hscore_eq] using hbase.1 + · simpa [hscore_eq] using hbase.2 + · have hlt : q < k := lt_of_not_ge hle + have hmask : masked q k := ⟨hcausal, hlt⟩ + have hscore : scoresReal q k = (inputs.maskValue : Real) := + scoresReal_eq_masked q k hmask + constructor + · + simp [hscore, scoreLo, hmask] + · + simp [hscore, scoreHi, hmask] + · have hnot_masked : ¬ masked q k := by + simp [masked, hcausal] + have hscore_eq : scoresReal q k = base := + scoresReal_eq_base_of_not_masked q k hnot_masked + have hbase := hscore_base_bounds hnot_masked + constructor + · simpa [hscore_eq] using hbase.1 + · simpa [hscore_eq] using hbase.2 have hdot_diff_bounds : ∀ q, q ∈ inputs.active → ∀ k, ¬ masked q k → (dotDiffLo q k : Real) ≤ @@ -849,22 +850,22 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} hlo1 hhi1 hlo2 hhi2 have hspecBase := hspec (splitDimsDiffBase q k) have hspecRef := hspec (splitDimsDiffRefined q k) + have hspecBase_bounds : + (dotDiffLoBase q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHiBase q k : Real) := by + refine ⟨?_, ?_⟩ + · simpa [dotDiffLoBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, + Array.getElem_ofFn] using hspecBase.1 + 
· simpa [dotDiffHiBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, + Array.getElem_ofFn] using hspecBase.2 cases hkey : worstKey q with | none => - have hlow' : - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, hkey, hq, dotDiffLoBase, dotDiffRowTasksBase, hmask, - Task.spawn, Array.getElem_ofFn] using hspecBase.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, hkey, hq, dotDiffHiBase, dotDiffRowTasksBase, hmask, - Task.spawn, Array.getElem_ofFn] using hspecBase.2 - exact ⟨hlow', hhigh'⟩ + simpa [dotDiffLo, dotDiffHi, hkey] using hspecBase_bounds | some k' => by_cases hk : k = k' · have hlow' : @@ -884,16 +885,13 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, hkey, hk, hq, dotDiffLoBase, dotDiffRowTasksBase, hmask, - Task.spawn, Array.getElem_ofFn] using hspecBase.1 + simpa [dotDiffLo, hkey, hk] using hspecBase_bounds.1 have hhigh' : dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, hkey, hk, hq, dotDiffHiBase, dotDiffRowTasksBase, hmask, - Task.spawn, Array.getElem_ofFn] using hspecBase.2 + simpa [dotDiffHi, hkey, hk] using hspecBase_bounds.2 exact ⟨hlow', hhigh'⟩ - let scoresReal := scoresRealOfInputs inputs have hmarginAt_le : ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → marginAt q ≤ scoreGapLo q k := by @@ -940,11 +938,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} · have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev 
q) := by have hprev_bounds := hscore_bounds q (inputs.prev q) simpa [scoreLoPrev] using hprev_bounds.1 - have hmask' : inputs.maskCausal = true ∧ q < k := by - simpa [masked] using hmask - have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 - have hscore_k : scoresReal q k = (inputs.maskValue : Real) := by - simp [scoresReal, scoresRealOfInputs, hmask'.1, hle] + have hscore_k : scoresReal q k = (inputs.maskValue : Real) := + scoresReal_eq_masked q k hmask calc scoresReal q k + (scoreGapLo q k : Real) = (inputs.maskValue : Real) + (scoreLoPrev q : Real) - @@ -976,25 +971,14 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (inputs.scale : Real) * dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d) := by - by_cases hcausal : inputs.maskCausal - · have hlt_prev : ¬ q < inputs.prev q := by - intro hlt - exact hprevmask (by exact ⟨hcausal, hlt⟩) - have hle_prev : inputs.prev q ≤ q := le_of_not_gt hlt_prev - simp [scoresReal, scoresRealOfInputs, hcausal, hle_prev] - · simp [scoresReal, scoresRealOfInputs, hcausal] + simpa using + (scoresReal_eq_base_of_not_masked q (inputs.prev q) hprevmask) have hscore_k : scoresReal q k = (inputs.scale : Real) * dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) := by - by_cases hcausal : inputs.maskCausal - · have hlt : ¬ q < k := by - intro hlt - exact hmask (by exact ⟨hcausal, hlt⟩) - have hle : k ≤ q := le_of_not_gt hlt - simp [scoresReal, scoresRealOfInputs, hcausal, hle] - · simp [scoresReal, scoresRealOfInputs, hcausal] + simpa using (scoresReal_eq_base_of_not_masked q k hmask) have hdot_sub : dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - @@ -1004,7 +988,11 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) := by classical - simp [dotProduct, mul_sub, 
Finset.sum_sub_distrib] + simpa using + (Nfp.Sound.Linear.dotProduct_sub_right + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs (inputs.prev q) d) + (z := fun d => kRealOfInputs inputs k d)) have hscore_diff : scoresReal q (inputs.prev q) - scoresReal q k = (inputs.scale : Real) * @@ -1405,13 +1393,28 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} oneHot_bounds_at := oneHot_bounds_at value_bounds := hvals_bounds } · have : False := by - simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] at hcore + have hnone := + buildInductionCertFromHeadCore?_eq_none_of_not_active + (inputs := inputs) hEps hSqrt hmodel hactive + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' exact this.elim · have : False := by - simp [buildInductionCertFromHeadCore?, hEps, hSqrt] at hcore + have hnone := + buildInductionCertFromHeadCore?_eq_none_of_not_sqrt (inputs := inputs) hEps hSqrt + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' exact this.elim · have : False := by - simp [buildInductionCertFromHeadCore?, hEps] at hcore + have hnone := buildInductionCertFromHeadCore?_eq_none_of_not_eps (inputs := inputs) hEps + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' exact this.elim end Sound end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Values.lean b/Nfp/Sound/Induction/CoreSound/Values.lean index 4938e70..1a70284 100644 --- a/Nfp/Sound/Induction/CoreSound/Values.lean +++ b/Nfp/Sound/Induction/CoreSound/Values.lean @@ -70,7 +70,11 @@ theorem valsReal_eq_of_dir (inputs : Model.InductionHeadInputs seq dModel dHead) dotProduct (fun d => (dirHead d : Real)) (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) + dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : 
Real)) := by - simp [dotProduct, mul_add, Finset.sum_add_distrib] + simpa using + (Nfp.Sound.Linear.dotProduct_add_right + (x := fun d => (dirHead d : Real)) + (y := fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) + (z := fun d => (inputs.bv d : Real))) have hdot_wv : dotProduct (fun d => (dirHead d : Real)) (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) = diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean index 7acbd04..dec0ea2 100644 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -791,9 +791,7 @@ theorem headValueValsLo_spec {seq dModel dHead : Nat} def headValueValsLoCommonDenArray {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + headValueValsLoArray inputs vLo vHi /-- Unfold `headValueValsLoCommonDenArray` to its `Array.ofFn` definition. -/ theorem headValueValsLoCommonDenArray_spec {seq dModel dHead : Nat} @@ -867,9 +865,7 @@ theorem headValueValsHi_spec {seq dModel dHead : Nat} def headValueValsHiCommonDenArray {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + headValueValsHiArray inputs vLo vHi /-- Unfold `headValueValsHiCommonDenArray` to its `Array.ofFn` definition. 
-/ theorem headValueValsHiCommonDenArray_spec {seq dModel dHead : Nat} diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index 7ba3966..d67341a 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -1,5 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Aesop import Nfp.Sound.Induction.CoreSound /-! @@ -135,19 +136,35 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] (hln j).1 have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => (hln j).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) hlo hhi + have hlow' : + dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + (inputs.bv d : Real) ≤ + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + + (inputs.bv d : Real) := + by + simpa using + dotIntervalLower_le_dotProduct_real_add + (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) (b := (inputs.bv d : Real)) hlo hhi + have hhigh' : + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + + (inputs.bv d : Real) ≤ + dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + (inputs.bv d : Real) := + by + simpa using + dotProduct_le_dotIntervalUpper_real_add + (v := fun j => inputs.wv j d) + (lo := lnLo q) (hi := lnHi q) + (x := lnRealOfInputs inputs q) (b := (inputs.bv d : Real)) hlo hhi constructor · simpa [vLo, vRealOfInputs, Bounds.cacheBound2_apply, Bounds.dotIntervalLowerCachedRat_eq, ratToReal_add] using - add_le_add_right hlow (inputs.bv d : Real) + hlow' · simpa [vHi, vRealOfInputs, Bounds.cacheBound2_apply, Bounds.dotIntervalUpperCachedRat_eq, ratToReal_add] using - add_le_add_right hhigh (inputs.bv d : Real) 
+ hhigh' have hhead_bounds : ∀ k i, (headValueLo k i : Real) ≤ headValueRealOfInputs inputs k i ∧ headValueRealOfInputs inputs k i ≤ (headValueHi k i : Real) := by @@ -182,22 +199,22 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] (loVal i : Real) (hiVal i : Real) (fun k => headValueRealOfInputs inputs k i) := by intro i + have hloVal : ∀ k, loVal i ≤ headValueLo k i := by + intro k + dsimp [loVal] + refine (Finset.inf'_le_iff (s := univ) (H := huniv) + (f := fun k => headValueLo k i) (a := headValueLo k i)).2 ?_ + exact ⟨k, by simp [univ], le_rfl⟩ + have hhiVal : ∀ k, headValueHi k i ≤ hiVal i := by + intro k + dsimp [hiVal] + refine (Finset.le_sup'_iff (s := univ) (H := huniv) + (f := fun k => headValueHi k i) (a := headValueHi k i)).2 ?_ + exact ⟨k, ⟨by simp [univ], le_rfl⟩⟩ refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ - have hmem0 : k0 ∈ univ := hk0 - have hloRat : loVal i ≤ headValueLo k0 i := by - change loVal i ≤ headValueLo k0 i - dsimp [loVal] - refine (Finset.inf'_le_iff (s := univ) (H := huniv) - (f := fun k => headValueLo k i) (a := headValueLo k0 i)).2 ?_ - refine ⟨k0, hmem0, ?_⟩ - exact le_rfl - have hhiRat : headValueHi k0 i ≤ hiVal i := by - change headValueHi k0 i ≤ hiVal i - dsimp [hiVal] - refine (Finset.le_sup'_iff (s := univ) (H := huniv) - (f := fun k => headValueHi k i) (a := headValueHi k0 i)).2 ?_ - exact ⟨k0, ⟨hmem0, le_rfl⟩⟩ + have hloRat : loVal i ≤ headValueLo k0 i := hloVal k0 + have hhiRat : headValueHi k0 i ≤ hiVal i := hhiVal k0 have hbounds := hhead_bounds k0 i have hreal : (loVal i : Real) ≤ (hiVal i : Real) := by refine le_trans (ratToReal_le_of_le hloRat) ?_ @@ -205,24 +222,11 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] exact le_trans hbounds.2 (ratToReal_le_of_le hhiRat) exact hreal · intro k - have hmem : k ∈ univ := by simp [univ] - have hloRat : loVal i ≤ headValueLo k i := by - change loVal i ≤ headValueLo k i - dsimp [loVal] - refine (Finset.inf'_le_iff (s := univ) (H := huniv) - (f := fun k => headValueLo k i) (a := headValueLo k i)).2 ?_ - refine ⟨k, hmem, ?_⟩ - exact le_rfl + have hloRat : loVal i ≤ headValueLo k i := hloVal k have hbounds := hhead_bounds k i exact (ratToReal_le_of_le hloRat) |>.trans hbounds.1 · intro k - have hmem : k ∈ univ := by simp [univ] - have hhiRat : headValueHi k i ≤ hiVal i := by - change headValueHi k i ≤ hiVal i - dsimp [hiVal] - refine (Finset.le_sup'_iff (s := univ) (H := huniv) - (f := fun k => headValueHi k i) (a := headValueHi k i)).2 ?_ - exact ⟨k, ⟨hmem, le_rfl⟩⟩ + have hhiRat : headValueHi k i ≤ hiVal i := hhiVal k have hbounds := hhead_bounds k i exact hbounds.2.trans (ratToReal_le_of_le hhiRat) diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index ec6fe19..189e925 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -1,5 +1,6 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +import Aesop import Nfp.Circuit.Cert.LogitDiff import Nfp.Sound.Induction @@ -13,10 +14,25 @@ namespace Sound open Nfp.Circuit +variable {seq : Nat} + +private theorem valueRangeBounds_of_valueIntervalBounds + {vals : Fin seq → Real} {c : ValueInterval seq} + (h : ValueIntervalBounds vals c) : + Layers.ValueRangeBounds (Val := Real) (c.lo : Real) (c.hi : Real) vals := by + refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } + · exact ratToReal_le_of_le h.lo_le_hi + · intro k + exact le_trans (h.lo_le_valsLo k) (h.vals_bounds k).1 + · intro k + exact le_trans (h.vals_bounds k).2 (h.valsHi_le_hi k) + section LogitDiffLowerBound variable {seq dModel dHead : Nat} [NeZero seq] +section + /-- Real-valued logit-diff contribution for a query. 
-/ noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) : Real := @@ -48,17 +64,9 @@ theorem logitDiffLowerBoundFromCert_le hsound.oneHot_bounds_at q hq have hvalsRange : Layers.ValueRangeBounds (Val := Real) (c.values.lo : Real) (c.values.hi : Real) - (valsRealOfInputs inputs) := by - refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } - · exact ratToReal_le_of_le hsound.value_bounds.lo_le_hi - · intro k - exact - le_trans (hsound.value_bounds.lo_le_valsLo k) - (hsound.value_bounds.vals_bounds k).1 - · intro k - exact - le_trans (hsound.value_bounds.vals_bounds k).2 - (hsound.value_bounds.valsHi_le_hi k) + (valsRealOfInputs inputs) := + valueRangeBounds_of_valueIntervalBounds + (vals := valsRealOfInputs inputs) (c := c.values) hsound.value_bounds have happrox := Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange (Val := Real) @@ -175,6 +183,8 @@ def buildInductionLogitLowerBoundNonvacuous? · exact some ⟨base, hpos⟩ · exact none +end + end LogitDiffLowerBound end Sound diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean index a224473..da24769 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Sound/Linear/FinFold.lean @@ -100,6 +100,24 @@ theorem dotFin_eq_dotProduct (n : Nat) (x y : Fin n → Rat) : dotFin n x y = dotProduct x y := by simp [dotFin_def, sumFin_eq_sum_univ, dotProduct] +/-- Right-distribute dot products over subtraction (Real-valued). -/ +theorem dotProduct_sub_right {n : Nat} (x y z : Fin n → Real) : + dotProduct x (fun i => y i - z i) = dotProduct x y - dotProduct x z := by + classical + simp only [dotProduct, mul_sub, Finset.sum_sub_distrib] + +/-- Right-distribute dot products over addition (Real-valued). 
-/ +theorem dotProduct_add_right {n : Nat} (x y z : Fin n → Real) : + dotProduct x (fun i => y i + z i) = dotProduct x y + dotProduct x z := by + classical + simp only [dotProduct, mul_add, Finset.sum_add_distrib] + +/-- Pull a constant factor out of the right-hand side of a dot product. -/ +theorem dotProduct_mul_right {n : Nat} (x y : Fin n → Real) (a : Real) : + dotProduct x (fun i => y i * a) = dotProduct x y * a := by + classical + simp only [dotProduct, mul_assoc, Finset.sum_mul] + end Linear end Sound From 32a8ed7e80999a1e9d24a5f2b1f1ab5bca00340a Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 16:03:15 +0100 Subject: [PATCH 137/244] Refactor interval bounds foldl pair helper --- .beads/issues.jsonl | 1 + Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 59 +++++++++++------------ Nfp/Sound/Induction/CoreSound.lean | 6 +-- 3 files changed, 32 insertions(+), 34 deletions(-) diff --git a/.beads/issues.jsonl b/.beads/issues.jsonl index e69de29..889428c 100644 --- a/.beads/issues.jsonl +++ b/.beads/issues.jsonl @@ -0,0 +1 @@ +{"id":"nfp-snh","title":"Refactor MatrixNorm interval bounds proofs","description":"Refactor Nfp/Sound/Bounds/MatrixNorm/Interval.lean to reduce proof churn and simplify interval bound lemmas.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T15:54:50.752468+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:01:12.070535+01:00","closed_at":"2026-01-12T16:01:12.070535+01:00","close_reason":"Closed","comments":[{"id":1,"issue_id":"nfp-snh","author":"TheDarkchip","text":"Refactored foldl_pair helper for pair foldl projections in MatrixNorm interval bounds; adjusted proofs to use Prod.fst/Prod.snd. 
Fixed CoreSound cdot lint warnings encountered during build.","created_at":"2026-01-12T15:01:08Z"}]} diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index dc875f6..262f39a 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -277,18 +277,9 @@ theorem dotIntervalUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalUpperCommonDen v lo hi = dotIntervalUpper v lo hi := by simp only [dotIntervalUpperCommonDen, dotIntervalUpper, Linear.sumFinCommonDen_eq_sumFin] -private lemma foldl_pair_fst {α : Type _} (xs : List α) (f g : α → Rat) (a b : Rat) : - (xs.foldl (fun acc x => (acc.1 + f x, acc.2 + g x)) (a, b)).1 = - xs.foldl (fun acc x => acc + f x) a := by - induction xs generalizing a b with - | nil => - simp - | cons x xs ih => - simp [List.foldl, ih] - -private lemma foldl_pair_snd {α : Type _} (xs : List α) (f g : α → Rat) (a b : Rat) : - (xs.foldl (fun acc x => (acc.1 + f x, acc.2 + g x)) (a, b)).2 = - xs.foldl (fun acc x => acc + g x) b := by +private lemma foldl_pair {α : Type _} (xs : List α) (f g : α → Rat) (a b : Rat) : + xs.foldl (fun acc x => (acc.1 + f x, acc.2 + g x)) (a, b) = + (xs.foldl (fun acc x => acc + f x) a, xs.foldl (fun acc x => acc + g x) b) := by induction xs generalizing a b with | nil => simp @@ -299,23 +290,27 @@ theorem dotIntervalLowerUpper2CommonDen_fst {n : Nat} (lo1 hi1 lo2 hi2 : Fin n (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).1 = dotIntervalLower2 lo1 hi1 lo2 hi2 := by classical - simpa [dotIntervalLowerUpper2CommonDen, dotIntervalLower2, Linear.foldlFin_eq_foldl, - Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using - (foldl_pair_fst (xs := List.finRange n) + have hpair := + foldl_pair (xs := List.finRange n) (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (a := 0) (b := 0)) + (a := 0) (b := 0) + have hfst := congrArg 
Prod.fst hpair + simpa [dotIntervalLowerUpper2CommonDen, dotIntervalLower2, Linear.foldlFin_eq_foldl, + Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using hfst theorem dotIntervalLowerUpper2CommonDen_snd {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).2 = dotIntervalUpper2 lo1 hi1 lo2 hi2 := by classical - simpa [dotIntervalLowerUpper2CommonDen, dotIntervalUpper2, Linear.foldlFin_eq_foldl, - Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using - (foldl_pair_snd (xs := List.finRange n) + have hpair := + foldl_pair (xs := List.finRange n) (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (a := 0) (b := 0)) + (a := 0) (b := 0) + have hsnd := congrArg Prod.snd hpair + simpa [dotIntervalLowerUpper2CommonDen, dotIntervalUpper2, Linear.foldlFin_eq_foldl, + Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using hsnd theorem dotIntervalLowerUpper2CommonDen_eq {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2 = @@ -325,24 +320,28 @@ theorem dotIntervalLowerUpper2CommonDen_eq {n : Nat} (lo1 hi1 lo2 hi2 : Fin n theorem dotIntervalLowerUpperCommonDen_fst {n : Nat} (v lo hi : Fin n → Rat) : (dotIntervalLowerUpperCommonDen v lo hi).1 = dotIntervalLowerCommonDen v lo hi := by classical - simpa [dotIntervalLowerUpperCommonDen, dotIntervalLowerCommonDen, - Linear.foldlFin_eq_foldl, Linear.sumFinCommonDen_eq_sumFin, - Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using - (foldl_pair_fst (xs := List.finRange n) + have hpair := + foldl_pair (xs := List.finRange n) (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - (a := 0) (b := 0)) + (a := 0) (b := 0) + have hfst := congrArg Prod.fst hpair + simpa [dotIntervalLowerUpperCommonDen, dotIntervalLowerCommonDen, + Linear.foldlFin_eq_foldl, 
Linear.sumFinCommonDen_eq_sumFin, + Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using hfst theorem dotIntervalLowerUpperCommonDen_snd {n : Nat} (v lo hi : Fin n → Rat) : (dotIntervalLowerUpperCommonDen v lo hi).2 = dotIntervalUpperCommonDen v lo hi := by classical - simpa [dotIntervalLowerUpperCommonDen, dotIntervalUpperCommonDen, - Linear.foldlFin_eq_foldl, Linear.sumFinCommonDen_eq_sumFin, - Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using - (foldl_pair_snd (xs := List.finRange n) + have hpair := + foldl_pair (xs := List.finRange n) (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - (a := 0) (b := 0)) + (a := 0) (b := 0) + have hsnd := congrArg Prod.snd hpair + simpa [dotIntervalLowerUpperCommonDen, dotIntervalUpperCommonDen, + Linear.foldlFin_eq_foldl, Linear.sumFinCommonDen_eq_sumFin, + Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using hsnd /-- Single-pass lower/upper endpoints agree with the common-denominator bounds. 
-/ theorem dotIntervalLowerUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index e9ddabf..4420bba 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -791,10 +791,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} have hscore : scoresReal q k = (inputs.maskValue : Real) := scoresReal_eq_masked q k hmask constructor - · - simp [hscore, scoreLo, hmask] - · - simp [hscore, scoreHi, hmask] + · simp [hscore, scoreLo, hmask] + · simp [hscore, scoreHi, hmask] · have hnot_masked : ¬ masked q k := by simp [masked, hcausal] have hscore_eq : scoresReal q k = base := From be20a13c0de5f62491a9b64048677b07ca618633 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 16:10:44 +0100 Subject: [PATCH 138/244] Refactor sqrt rounding helpers --- Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean | 49 ++++++++++------------ 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean index d6b8c7c..535c98c 100644 --- a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean +++ b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean @@ -28,6 +28,21 @@ lemma rat_nat_cast_nonneg (n : Nat) : (0 : Rat) ≤ (n : Rat) := by lemma rat_nat_cast_pos {n : Nat} (h : 0 < n) : (0 : Rat) < (n : Rat) := by exact (Nat.cast_pos (α := Rat)).2 h +/-- `ratRoundDown` preserves nonnegativity for nonnegative divisions. -/ +theorem ratRoundDown_nonneg_div {a b : Rat} (ha : 0 ≤ a) (hb : 0 ≤ b) : + 0 ≤ ratRoundDown (a / b) := by + exact ratRoundDown_nonneg (q := a / b) (by exact div_nonneg ha hb) + +/-- `ratRoundUp` preserves nonnegativity for nonnegative divisions. 
-/ +theorem ratRoundUp_nonneg_div {a b : Rat} (ha : 0 ≤ a) (hb : 0 ≤ b) : + 0 ≤ ratRoundUp (a / b) := by + exact ratRoundUp_nonneg (q := a / b) (by exact div_nonneg ha hb) + +/-- `ratRoundUp` preserves positivity for positive divisions. -/ +theorem ratRoundUp_pos_div {a b : Rat} (ha : 0 < a) (hb : 0 < b) : + 0 < ratRoundUp (a / b) := by + exact ratRoundUp_pos (q := a / b) (by exact div_pos ha hb) + /-- Base rational lower bound for a square root. -/ def sqrtLowerBase (q : Rat) : Rat := let num := q.num.natAbs @@ -93,9 +108,7 @@ theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by rat_nat_cast_nonneg (Nat.sqrt q.num.natAbs) have hden : 0 ≤ (Nat.sqrt q.den : Rat) + 1 := by simpa using rat_nat_cast_nonneg (Nat.sqrt q.den + 1) - exact ratRoundDown_nonneg - (q := (Nat.sqrt q.num.natAbs : Rat) / (Nat.sqrt q.den + 1)) - (by exact div_nonneg hnum hden) + exact ratRoundDown_nonneg_div hnum hden /-- `sqrtUpperBase` is nonnegative. -/ theorem sqrtUpperBase_nonneg (q : Rat) : 0 ≤ sqrtUpperBase q := by @@ -105,9 +118,7 @@ theorem sqrtUpperBase_nonneg (q : Rat) : 0 ≤ sqrtUpperBase q := by simpa using rat_nat_cast_nonneg (Nat.sqrt q.num.natAbs + 1) have hden : 0 ≤ (Nat.sqrt q.den : Rat) := rat_nat_cast_nonneg (Nat.sqrt q.den) - exact ratRoundUp_nonneg - (q := (Nat.sqrt q.num.natAbs + 1 : Rat) / (Nat.sqrt q.den)) - (by exact div_nonneg hnum hden) + exact ratRoundUp_nonneg_div hnum hden /-- `sqrtUpperBase` is always positive. -/ theorem sqrtUpperBase_pos (q : Rat) : 0 < sqrtUpperBase q := by @@ -118,9 +129,7 @@ theorem sqrtUpperBase_pos (q : Rat) : 0 < sqrtUpperBase q := by have hden_pos : (0 : Rat) < (Nat.sqrt q.den : Rat) := by have hden : 0 < q.den := q.den_pos exact rat_nat_cast_pos (Nat.sqrt_pos.2 hden) - exact ratRoundUp_pos - (q := (Nat.sqrt q.num.natAbs + 1 : Rat) / (Nat.sqrt q.den)) - (by exact div_pos hnum_pos hden_pos) + exact ratRoundUp_pos_div hnum_pos hden_pos /-- `sqrtLowerAlt` is nonnegative. 
-/ theorem sqrtLowerAlt_nonneg (q : Rat) : 0 ≤ sqrtLowerAlt q := by @@ -130,9 +139,7 @@ theorem sqrtLowerAlt_nonneg (q : Rat) : 0 ≤ sqrtLowerAlt q := by rat_nat_cast_nonneg (Nat.sqrt (q.num.natAbs * q.den)) have hden : 0 ≤ (q.den : Rat) := rat_nat_cast_nonneg q.den - exact ratRoundDown_nonneg - (q := (Nat.sqrt (q.num.natAbs * q.den) : Rat) / q.den) - (by exact div_nonneg hnum hden) + exact ratRoundDown_nonneg_div hnum hden /-- `sqrtUpperAlt` is nonnegative. -/ theorem sqrtUpperAlt_nonneg (q : Rat) : 0 ≤ sqrtUpperAlt q := by @@ -142,9 +149,7 @@ theorem sqrtUpperAlt_nonneg (q : Rat) : 0 ≤ sqrtUpperAlt q := by simpa using rat_nat_cast_nonneg (Nat.sqrt (q.num.natAbs * q.den) + 1) have hden : 0 ≤ (q.den : Rat) := rat_nat_cast_nonneg q.den - exact ratRoundUp_nonneg - (q := (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) / q.den) - (by exact div_nonneg hnum hden) + exact ratRoundUp_nonneg_div hnum hden /-- `sqrtUpperAlt` is always positive. -/ theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by @@ -155,9 +160,7 @@ theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by simpa using rat_nat_cast_pos (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den))) have hden_pos : (0 : Rat) < (q.den : Rat) := rat_nat_cast_pos q.den_pos - exact ratRoundUp_pos - (q := (Nat.sqrt (q.num.natAbs * q.den) + 1 : Rat) / q.den) - (by exact div_pos hnum_pos hden_pos) + exact ratRoundUp_pos_div hnum_pos hden_pos /-- `sqrtUpperScaled` is nonnegative. 
-/ theorem sqrtUpperScaled_nonneg (q : Rat) : 0 ≤ sqrtUpperScaled q := by @@ -169,10 +172,7 @@ theorem sqrtUpperScaled_nonneg (q : Rat) : 0 ≤ sqrtUpperScaled q := by (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1) have hden : 0 ≤ (q.den : Rat) * (sqrtLowerScale : Rat) := by simpa [Nat.cast_mul] using rat_nat_cast_nonneg (q.den * sqrtLowerScale) - exact ratRoundUp_nonneg - (q := (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) - / (q.den * sqrtLowerScale)) - (by exact div_nonneg hnum hden) + exact ratRoundUp_nonneg_div hnum hden /-- `sqrtUpperScaled` is always positive. -/ theorem sqrtUpperScaled_pos (q : Rat) : 0 < sqrtUpperScaled q := by @@ -188,10 +188,7 @@ theorem sqrtUpperScaled_pos (q : Rat) : 0 < sqrtUpperScaled q := by have hscale : 0 < sqrtLowerScale := by simp [sqrtLowerScale] simpa [Nat.cast_mul] using rat_nat_cast_pos (Nat.mul_pos hden hscale) - exact ratRoundUp_pos - (q := (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1 : Rat) - / (q.den * sqrtLowerScale)) - (by exact div_pos hnum_pos hden_pos) + exact ratRoundUp_pos_div hnum_pos hden_pos /-! Combined bounds. 
-/ From a2d12c28675f2903da95a9b776bc1a10815a89bf Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 16:16:25 +0100 Subject: [PATCH 139/244] bd sync: 2026-01-12 16:16:25 --- .beads/issues.jsonl | 1 + 1 file changed, 1 insertion(+) diff --git a/.beads/issues.jsonl b/.beads/issues.jsonl index 889428c..97aca46 100644 --- a/.beads/issues.jsonl +++ b/.beads/issues.jsonl @@ -1 +1,2 @@ +{"id":"nfp-g78","title":"Refactor CoreSound lemmas for stability","description":"Refactor Nfp/Sound/Induction/CoreSound.lean to reduce proof churn and improve maintainability; keep proofs explicit and stable.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T16:12:25.347409+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:16:04.322186+01:00","closed_at":"2026-01-12T16:16:04.322186+01:00","close_reason":"Completed","labels":["in_progress"]} {"id":"nfp-snh","title":"Refactor MatrixNorm interval bounds proofs","description":"Refactor Nfp/Sound/Bounds/MatrixNorm/Interval.lean to reduce proof churn and simplify interval bound lemmas.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T15:54:50.752468+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:01:12.070535+01:00","closed_at":"2026-01-12T16:01:12.070535+01:00","close_reason":"Closed","comments":[{"id":1,"issue_id":"nfp-snh","author":"TheDarkchip","text":"Refactored foldl_pair helper for pair foldl projections in MatrixNorm interval bounds; adjusted proofs to use Prod.fst/Prod.snd. 
Fixed CoreSound cdot lint warnings encountered during build.","created_at":"2026-01-12T15:01:08Z"}]} From b50ccba8942eb311333c8a91b7394098aa82dafa Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 16:16:38 +0100 Subject: [PATCH 140/244] Deduplicate top2 split selector --- Nfp/Sound/Induction/CoreSound.lean | 46 +++++++++--------------------- 1 file changed, 14 insertions(+), 32 deletions(-) diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index 4420bba..c8b3ead 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -171,10 +171,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let splitBudgetK : Nat := 2 let splitBudgetDiffBase : Nat := 0 let splitBudgetDiffRefined : Nat := 12 - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - let ambig := - (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let top2ByScore : + (Fin dHead → Rat) → List (Fin dHead) → List (Fin dHead) := fun score ambig => let step (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) (d : Fin dHead) : @@ -189,39 +187,23 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) | (none, some b2) => if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => + let ambig := + (List.finRange 
dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let dims1 := top2ByScore score ambig + let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) (dims1 ++ dims2).take splitBudgetQ let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => let ambig := (List.finRange dHead).filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let dims1 := top2ByScore score ambig + let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) (dims1 ++ dims2).take splitBudgetK let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => let prev := inputs.prev q From 6bad98b557393cba4b81a144b2abb5b475dde9ca Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 16:22:24 +0100 Subject: [PATCH 141/244] bd sync: 2026-01-12 16:22:24 --- .beads/issues.jsonl | 1 + 1 file changed, 1 insertion(+) diff --git a/.beads/issues.jsonl b/.beads/issues.jsonl index 97aca46..c4f066f 100644 --- a/.beads/issues.jsonl +++ b/.beads/issues.jsonl 
@@ -1,2 +1,3 @@ +{"id":"nfp-g0j","title":"Refactor HeadBounds reduction helpers","description":"Deduplicate chunked reduction task helpers in Nfp/Sound/Induction/HeadBounds.lean to reduce proof churn while preserving behavior.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T16:21:11.583135+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:22:20.621327+01:00","closed_at":"2026-01-12T16:22:20.621327+01:00","close_reason":"Completed"} {"id":"nfp-g78","title":"Refactor CoreSound lemmas for stability","description":"Refactor Nfp/Sound/Induction/CoreSound.lean to reduce proof churn and improve maintainability; keep proofs explicit and stable.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T16:12:25.347409+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:16:04.322186+01:00","closed_at":"2026-01-12T16:16:04.322186+01:00","close_reason":"Completed","labels":["in_progress"]} {"id":"nfp-snh","title":"Refactor MatrixNorm interval bounds proofs","description":"Refactor Nfp/Sound/Bounds/MatrixNorm/Interval.lean to reduce proof churn and simplify interval bound lemmas.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T15:54:50.752468+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:01:12.070535+01:00","closed_at":"2026-01-12T16:01:12.070535+01:00","close_reason":"Closed","comments":[{"id":1,"issue_id":"nfp-snh","author":"TheDarkchip","text":"Refactored foldl_pair helper for pair foldl projections in MatrixNorm interval bounds; adjusted proofs to use Prod.fst/Prod.snd. 
Fixed CoreSound cdot lint warnings encountered during build.","created_at":"2026-01-12T15:01:08Z"}]} From 4e8b93bb2ff40578092b95f134b4c142ba9c584c Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 16:22:35 +0100 Subject: [PATCH 142/244] Refactor HeadBounds reduction task helper --- Nfp/Sound/Induction/HeadBounds.lean | 65 +++++++++++++++++------------ 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean index dec0ea2..11f529c 100644 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -42,7 +42,9 @@ private def reduceMaxArray (arr : Array Rat) : Rat := let init := arr.getD 0 (0 : Rat) arr.foldl (fun acc x => max acc x) init -private def reduceMinFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := +/-- Reduce a `Fin seq`-indexed function in parallel using chunked tasks. -/ +private def reduceFnTask [NeZero seq] (vals : Fin seq → Rat) + (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : Task Rat := let n := seq if n = 0 then Task.pure (0 : Rat) @@ -63,36 +65,45 @@ private def reduceMinFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := init else let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => min acc (vals (idxs.getD i defaultIdx))) init)) + rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init)) let init := chunkTasks.getD 0 defaultTask let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => taskMin acc (chunkTasks.getD i defaultTask)) init + rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init + +/-- Unfold `reduceFnTask` to its chunked-task definition. 
-/ +theorem reduceFnTask_spec [NeZero seq] (vals : Fin seq → Rat) + (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : + reduceFnTask (seq := seq) vals combine combineTask = + let n := seq + if n = 0 then + Task.pure (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + Task.spawn (fun _ => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init)) + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init := rfl + +private def reduceMinFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := + reduceFnTask vals min taskMin private def reduceMaxFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := - let n := seq - if n = 0 then - Task.pure (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - Task.spawn (fun _ => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start 
defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => max acc (vals (idxs.getD i defaultIdx))) init)) - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => taskMax acc (chunkTasks.getD i defaultTask)) init + reduceFnTask vals max taskMax /-- Cached direction head for head inputs. -/ private def dirHeadVecOfInputs {seq dModel dHead : Nat} From 8c9389a13a08a2c4080ed834764580526ca19c70 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 16:25:10 +0100 Subject: [PATCH 143/244] bd sync branch --- .beads/config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.beads/config.yaml b/.beads/config.yaml index f242785..1de3590 100644 --- a/.beads/config.yaml +++ b/.beads/config.yaml @@ -42,7 +42,7 @@ # This setting persists across clones (unlike database config which is gitignored). # Can also use BEADS_SYNC_BRANCH env var for local override. # If not set, bd sync will require you to run 'bd config set sync.branch '. 
-# sync-branch: "beads-sync" +sync-branch: "beads-sync" # Multi-repo configuration (experimental - bd-307) # Allows hydrating from multiple repositories and routing writes to the correct JSONL @@ -59,4 +59,4 @@ # - linear.url # - linear.api-key # - github.org -# - github.repo +# - github.repo \ No newline at end of file From a95635b8194bea67a839f0cf0cf07f6b3f0157b5 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 17:12:36 +0100 Subject: [PATCH 144/244] Tighten induction epsilon bounds --- Nfp/Sound/Induction/Core.lean | 19 ++- Nfp/Sound/Induction/CoreSound.lean | 249 +++++++---------------------- Nfp/Sound/Induction/OneHot.lean | 153 ++++++++++++++++++ 3 files changed, 225 insertions(+), 196 deletions(-) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 54e9128..9aa52ed 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -615,20 +615,25 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} else (0 : Rat) let epsAt : Fin seq → Rat := fun q => - if marginAt q < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + marginAt q) + let other := otherKeys q + let total := + other.sum (fun k => + let gap := scoreGapLo q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap)) + min (1 : Rat) total let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt else (0 : Rat) let eps : Rat := - if margin < 0 then - (1 : Rat) + if h : inputs.active.Nonempty then + inputs.active.sup' h epsAt else - ratDivUp (seq - 1) (1 + margin) + (0 : Rat) let dirHeadVec := dirHeadVecOfInputs inputs let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d let wvDir : Fin dModel → Rat := diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index c8b3ead..b1d5189 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -388,20 +388,25 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead 
: Nat} else (0 : Rat) let epsAt : Fin seq → Rat := fun q => - if marginAt q < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + marginAt q) + let other := otherKeys q + let total := + other.sum (fun k => + let gap := scoreGapLo q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap)) + min (1 : Rat) total let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt else (0 : Rat) let eps : Rat := - if margin < 0 then - (1 : Rat) + if h : inputs.active.Nonempty then + inputs.active.sup' h epsAt else - ratDivUp (seq - 1) (1 + margin) + (0 : Rat) have hseq : (1 : Nat) ≤ seq := Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) let dirHeadVec := dirHeadVecOfInputs inputs @@ -1074,6 +1079,44 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} _ = (marginAt q : Real) + scoresReal q k := by simp [add_comm] exact hstep'.trans hscore' + have hepsAt : + ∀ q, epsAt q = + min (1 : Rat) + ((otherKeys q).sum (fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k))) := by + intro q + rfl + have oneHot_bounds_at : + ∀ q, q ∈ inputs.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) inputs.prev weights := by + intro q hq + exact + Sound.oneHot_bounds_at_of_scoreGapLo + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (scoreGapLo := scoreGapLo) + (epsAt := epsAt) + (hepsAt := hepsAt) + (hscore_gap_real_at := hscore_gap_real_at) + q hq + have hepsAt_le_eps : + ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by + intro q hq + have hle : + epsAt q ≤ inputs.active.sup' hactive epsAt := by + exact + (Finset.le_sup'_iff (s := inputs.active) (H := hactive) + (f := epsAt) (a := epsAt q)).2 ⟨q, hq, le_rfl⟩ + simpa [eps, hactive] using hle + have hepsAt_le_eps_real : + ∀ q, q ∈ inputs.active → (epsAt q : Real) ≤ (eps : Real) := by + intro q hq + exact ratToReal_le_of_le (hepsAt_le_eps q hq) have hsoftmax_bounds : 
Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by @@ -1093,190 +1136,18 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simpa [weights] using (Circuit.softmax_sum_one (scores := scoresReal q)) · intro q hq - have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (eps : Real) := by - by_cases hneg : margin < 0 - · have heps : (eps : Real) = 1 := by - simp [eps, hneg] - have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by - intro k hk - simp - have hnonneg : - ∀ k ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q k := by - intro k _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) - have hsum_le : - (∑ k ∈ others q, weights q k) ≤ - ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := - Finset.sum_le_sum_of_subset_of_nonneg hsubset (by - intro k hk _; exact hnonneg k hk) - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - simpa [heps, hsum_one] using hsum_le - · have hnonneg : 0 ≤ margin := le_of_not_gt hneg - have hnonneg_real : 0 ≤ (margin : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg - have hbound : - ∀ k ∈ others q, - weights q k ≤ (1 + (margin : Real))⁻¹ := by - intro k hk - have hkne : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 - have hscore := hscore_margin_real q hq k hkne - simpa [weights] using - (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) - (prev := inputs.prev q) (k := k) (m := (margin : Real)) - hnonneg_real hscore) - have hsum_le : - (∑ k ∈ others q, weights q k) ≤ - ∑ k ∈ others q, (1 + (margin : Real))⁻¹ := - Finset.sum_le_sum hbound - have hsum_const : - (∑ k ∈ others q, (1 + (margin : Real))⁻¹) = - (others q).card * (1 + (margin : Real))⁻¹ := by - simp - have hcard : (others q).card = seq - 1 := by - simp [others, Finset.card_erase_of_mem] - have hsum_le' : - (∑ k ∈ others q, 
weights q k) ≤ - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - simpa [hcard, Nat.cast_sub hseq, Nat.cast_one] using - (hsum_le.trans_eq hsum_const) - have hpos : (0 : Rat) < 1 + margin := by - have hone : (0 : Rat) < 1 := by - exact zero_lt_one - have hle : (1 : Rat) ≤ 1 + margin := by - exact le_add_of_nonneg_right hnonneg - exact lt_of_lt_of_le hone hle - have hden : (1 + margin) ≠ 0 := by - exact ne_of_gt hpos - have hrat' := ratDivUp_ge_real (seq - 1) (1 + margin) hden - have heps : - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : Real) := by - simpa [eps, hneg, ratToReal, Rat.cast_div, Rat.cast_add, - Rat.cast_natCast, div_eq_mul_inv] using hrat' - exact le_trans hsum_le' heps - have hsum_eq : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = 1 := by - have hsum' : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = - ∑ k, weights q k := by - simp [others] - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - calc - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k = - ∑ k, weights q k := hsum' - _ = 1 := hsum_one - have hsum_le' : - weights q (inputs.prev q) + ∑ k ∈ others q, weights q k ≤ + have honehot := oneHot_bounds_at q hq + have hprev := honehot.prev_large q rfl + have hle : + weights q (inputs.prev q) + (epsAt q : Real) ≤ weights q (inputs.prev q) + (eps : Real) := by - simpa [add_comm, add_left_comm, add_assoc] using - (add_le_add_left hsum_others_le (weights q (inputs.prev q))) - have hprev : - 1 ≤ weights q (inputs.prev q) + (eps : Real) := by - simpa [hsum_eq] using hsum_le' - exact hprev + simpa [add_comm] using + (add_le_add_right (hepsAt_le_eps_real q hq) (weights q (inputs.prev q))) + exact hprev.trans hle · intro q hq k hk - have hsum_others_le : (∑ j ∈ others q, weights q j) ≤ (eps : Real) := by - by_cases hneg : margin < 0 - · have heps : (eps : Real) = 1 := by - simp [eps, hneg] - have hsubset : others q ⊆ (Finset.univ : Finset (Fin 
seq)) := by - intro j hj - simp - have hnonneg : - ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ weights q j := by - intro j _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) j) - have hsum_le : - (∑ j ∈ others q, weights q j) ≤ - ∑ j ∈ (Finset.univ : Finset (Fin seq)), weights q j := - Finset.sum_le_sum_of_subset_of_nonneg hsubset (by - intro j hj _; exact hnonneg j hj) - have hsum_one : (∑ j, weights q j) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - simpa [heps, hsum_one] using hsum_le - · have hnonneg : 0 ≤ margin := le_of_not_gt hneg - have hnonneg_real : 0 ≤ (margin : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg - have hbound : - ∀ j ∈ others q, - weights q j ≤ (1 + (margin : Real))⁻¹ := by - intro j hj - have hjne : j ≠ inputs.prev q := (Finset.mem_erase.mp hj).1 - have hscore := hscore_margin_real q hq j hjne - simpa [weights] using - (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) - (prev := inputs.prev q) (k := j) (m := (margin : Real)) - hnonneg_real hscore) - have hsum_le : - (∑ j ∈ others q, weights q j) ≤ - ∑ j ∈ others q, (1 + (margin : Real))⁻¹ := - Finset.sum_le_sum hbound - have hsum_const : - (∑ j ∈ others q, (1 + (margin : Real))⁻¹) = - (others q).card * (1 + (margin : Real))⁻¹ := by - simp - have hcard : (others q).card = seq - 1 := by - simp [others, Finset.card_erase_of_mem] - have hsum_le' : - (∑ j ∈ others q, weights q j) ≤ - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ := by - simpa [hcard, Nat.cast_sub hseq, Nat.cast_one] using - (hsum_le.trans_eq hsum_const) - have hpos : (0 : Rat) < 1 + margin := by - have hone : (0 : Rat) < 1 := by - exact zero_lt_one - have hle : (1 : Rat) ≤ 1 + margin := by - exact le_add_of_nonneg_right hnonneg - exact lt_of_lt_of_le hone hle - have hden : (1 + margin) ≠ 0 := by - exact ne_of_gt hpos - have hrat' := ratDivUp_ge_real (seq - 1) (1 + margin) hden - have heps : - (seq - 1 : Real) * (1 + (margin : Real))⁻¹ ≤ (eps : 
Real) := by - simpa [eps, hneg, ratToReal, Rat.cast_div, Rat.cast_add, - Rat.cast_natCast, div_eq_mul_inv] using hrat' - exact le_trans hsum_le' heps - have hk' : k ∈ others q := by - simp [others, hk] - have hnonneg : - ∀ j ∈ others q, 0 ≤ weights q j := by - intro j _ - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) j) - have hle : - weights q k ≤ ∑ j ∈ others q, weights q j := by - simpa using (Finset.single_le_sum hnonneg hk') - exact hle.trans hsum_others_le - have hepsAt : - ∀ q, epsAt q = - if marginAt q < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + marginAt q) := by - intro q - rfl - have oneHot_bounds_at : - ∀ q, q ∈ inputs.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) - (fun q' => q' = q) inputs.prev weights := by - intro q hq - exact - Sound.oneHot_bounds_at_of_marginAt - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (marginAt := marginAt) - (epsAt := epsAt) - (hepsAt := hepsAt) - (hseq := hseq) - (hscore_margin_real_at := hscore_margin_real_at) - q hq + have honehot := oneHot_bounds_at q hq + have hother := honehot.other_le q rfl k hk + exact hother.trans (hepsAt_le_eps_real q hq) have hdirHead : dirHead = fun d => (dirHeadVecOfInputs inputs).get d := by simp [dirHead, dirHeadVec] diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 6f12e26..8d65f76 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -2,6 +2,7 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Data.Rat.BigOperators import Nfp.Core.Basic import Nfp.Circuit.Layers.Induction import Nfp.Circuit.Layers.Softmax @@ -160,6 +161,158 @@ theorem oneHot_bounds_at_of_marginAt simpa using h exact hle.trans hsum_others_le +/-- One-hot bounds on a single active query, derived from per-key score gaps. 
-/ +theorem oneHot_bounds_at_of_scoreGapLo + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scoresReal : Fin seq → Fin seq → Real) + (scoreGapLo : Fin seq → Fin seq → Rat) + (epsAt : Fin seq → Rat) + (hepsAt : + ∀ q, epsAt q = + min (1 : Rat) + ((Finset.univ : Finset (Fin seq)).erase (prev q) |>.sum (fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k)))) + (hscore_gap_real_at : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (prev q)) : + ∀ q, q ∈ active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) prev + (fun q k => Circuit.softmax (scoresReal q) k) := by + classical + intro q hq + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + let bound : Fin seq → Rat := fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k) + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + intro k + simpa [weights] using + (Circuit.softmax_nonneg (scores := scoresReal q) k) + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + have hsum_others_le_one : (∑ k ∈ others q, weights q k) ≤ 1 := by + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k _ _ + exact hweights_nonneg k) + simpa [hsum_one] using hsum_le + have hbound : + ∀ k ∈ others q, weights q k ≤ (bound k : Real) := by + intro k hk + have hkne : k ≠ prev q := (Finset.mem_erase.mp hk).1 + by_cases hneg : scoreGapLo q k < 0 + · have hle : weights q k ≤ 1 := by + simpa [weights] using + (Circuit.softmax_le_one (scores := scoresReal q) k) + 
simpa [bound, hneg] using hle + · have hnonneg : 0 ≤ scoreGapLo q k := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (scoreGapLo q k : Real) := by + exact ratToReal_nonneg_of_nonneg hnonneg + have hscore := hscore_gap_real_at q hq k hkne + have hsoft : + weights q k ≤ 1 / (1 + (scoreGapLo q k : Real)) := by + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := prev q) (k := k) (m := (scoreGapLo q k : Real)) + hnonneg_real hscore) + have hpos : (0 : Rat) < 1 + scoreGapLo q k := by + have hle : (1 : Rat) ≤ 1 + scoreGapLo q k := by + exact le_add_of_nonneg_right hnonneg + exact lt_of_lt_of_le zero_lt_one hle + have hden : (1 + scoreGapLo q k) ≠ 0 := by + exact ne_of_gt hpos + have hrat : + 1 / (1 + (scoreGapLo q k : Real)) ≤ + ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := by + simpa [ratToReal] using + (ratDivUp_ge_real 1 (1 + scoreGapLo q k) hden) + have hbound' : + weights q k ≤ ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := + hsoft.trans hrat + simpa [bound, hneg] using hbound' + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (bound k : Real) := + Finset.sum_le_sum hbound + have hsum_le_min : + (∑ k ∈ others q, weights q k) ≤ + min (1 : Real) (∑ k ∈ others q, (bound k : Real)) := by + exact le_min hsum_others_le_one hsum_le + have hepsAtReal : + (epsAt q : Real) = min (1 : Real) (∑ k ∈ others q, (bound k : Real)) := by + have h' : epsAt q = min 1 ((others q).sum bound) := by + simpa only [others, bound] using hepsAt q + have h'' : + ratToReal (epsAt q) = ratToReal (min 1 ((others q).sum bound)) := by + exact congrArg ratToReal h' + -- Avoid rewriting the erased-sum into a difference. 
+ simpa [ratToReal_min, ratToReal, Rat.cast_sum] using h'' + simpa [hepsAtReal] using hsum_le_min + refine + { nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q' hq' k + subst q' + change 0 ≤ Circuit.softmax (scoresReal q) k + exact Circuit.softmax_nonneg (scores := scoresReal q) k + · intro q' hq' + subst q' + change (∑ k, Circuit.softmax (scoresReal q) k) = 1 + exact Circuit.softmax_sum_one (scores := scoresReal q) + · intro q' hq' + subst q' + have hsum_eq : + weights q (prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + calc + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (prev q) + (epsAt q : Real) := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_others_le (weights q (prev q))) + have hprev : + 1 ≤ weights q (prev q) + (epsAt q : Real) := by + simpa [hsum_eq] using hsum_le' + exact hprev + · intro q' hq' k hk + subst q' + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ j ∈ others q, 0 ≤ weights q j := by + intro j _ + exact hweights_nonneg j + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + end Sound end Nfp From 41d1787ec684e50cd19d9664179b4efe7b30fcb6 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Mon, 12 Jan 2026 23:29:26 +0100 Subject: [PATCH 145/244] Improve induction cert bounds and fix float parsing --- Nfp/Circuit/Cert/LogitDiff.lean | 97 +++++++++++++ Nfp/IO/InductionHead.lean | 81 ++++++++--- Nfp/IO/NfptPure.lean | 4 +- Nfp/Sound/Induction/Core.lean | 57 +++++--- Nfp/Sound/Induction/CoreDefs.lean | 6 + Nfp/Sound/Induction/CoreSound.lean | 212 +++++++++++++++-------------- Nfp/Sound/Induction/LogitDiff.lean | 188 
++++++++++++++++--------- Nfp/Sound/Induction/OneHot.lean | 51 +++++++ 8 files changed, 489 insertions(+), 207 deletions(-) diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index 28742e7..c521724 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -42,6 +42,42 @@ def logitDiffLowerBoundAt (active : Finset (Fin seq)) else exact none +/-- Compute a lower bound on the logit-diff contribution using per-query eps and the global + lower value bound. -/ +def logitDiffLowerBoundAtLo (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (lo : Rat) (valsLo : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let f : Fin seq → Rat := fun q => + valsLo (prev q) - epsAt q * (valsLo (prev q) - lo) + let img := active.image f + have himg : img.Nonempty := h.image f + exact some (Finset.min' img himg) + else + exact none + +/-- Compute a lower bound on the logit-diff contribution using per-key weight bounds and + per-key value lower bounds. -/ +def logitDiffLowerBoundWeightedAt (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (weightBoundAt : Fin seq → Fin seq → Rat) + (valsLo : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + let gap : Fin seq → Rat := fun q => + (others q).sum (fun k => + let diff := valsLo (prev q) - valsLo k + weightBoundAt q k * max (0 : Rat) diff) + let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q + let img := active.image f + have himg : img.Nonempty := h.image f + exact some (Finset.min' img himg) + else + exact none + /-- The computed lower bound is below every active `prev` value minus the tolerance gap. 
-/ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) @@ -84,6 +120,67 @@ theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) Finset.min'_le _ _ hmem simpa [f, gap, hbound'] using hmin +/-- The per-query lower bound is below every active `prev` value minus the `lo`-gap. -/ +theorem logitDiffLowerBoundAtLo_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (lo : Rat) (valsLo : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBoundAtLo active prev epsAt lo valsLo = some lb → + lb ≤ valsLo (prev q) - epsAt q * (valsLo (prev q) - lo) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + let f : Fin seq → Rat := fun q => + valsLo (prev q) - epsAt q * (valsLo (prev q) - lo) + have hbound' : (active.image f).min' (hnonempty.image f) = lb := by + simpa [logitDiffLowerBoundAtLo, hnonempty, f] using hbound + have hmem : f q ∈ (active.image f) := by + refine Finset.mem_image.2 ?_ + exact ⟨q, hq, rfl⟩ + have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := + Finset.min'_le _ _ hmem + have hmin' : lb ≤ f q := by + simpa [hbound'] using hmin + simpa [f] using hmin' + +/-- The weighted lower bound is below every active `prev` value minus the weighted gap. 
-/ +theorem logitDiffLowerBoundWeightedAt_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (weightBoundAt : Fin seq → Fin seq → Rat) + (valsLo : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBoundWeightedAt active prev weightBoundAt valsLo = some lb → + lb ≤ + valsLo (prev q) - + ((Finset.univ : Finset (Fin seq)).erase (prev q)).sum (fun k => + weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + let gap : Fin seq → Rat := fun q => + (others q).sum (fun k => + let diff := valsLo (prev q) - valsLo k + weightBoundAt q k * max (0 : Rat) diff) + let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q + have hbound' : (active.image f).min' (hnonempty.image f) = lb := by + simpa [logitDiffLowerBoundWeightedAt, hnonempty, f, gap, others] using hbound + have hmem : f q ∈ (active.image f) := by + refine Finset.mem_image.2 ?_ + exact ⟨q, hq, rfl⟩ + have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := + Finset.min'_le _ _ hmem + have hmin' : lb ≤ f q := by + simpa [hbound'] using hmin + have hmin'' : + lb ≤ valsLo (prev q) - + (others q).sum (fun k => + let diff := valsLo (prev q) - valsLo k + weightBoundAt q k * max (0 : Rat) diff) := by + simpa [f, gap] using hmin' + simpa [others] using hmin'' + end Circuit end Nfp diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index dfdba89..3993c03 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -952,18 +952,41 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} | Nat.succ n => let seq := Nat.succ n let _ : NeZero seq := ⟨by simp⟩ - logTiming "start: head build nonvacuous logit-diff" - let res : Option (Sound.InductionLogitLowerBoundNonvacuous inputs) ← - timePure "head: build nonvacuous logit-diff" (fun () => - 
Sound.buildInductionLogitLowerBoundNonvacuous? inputs) - logTiming "done: head build nonvacuous logit-diff" - match res with + logTiming "start: head build induction cert" + IO.println "timing: head build induction cert start" + flushStdout + let tCert0 ← monoUsNow + let certTask : + Task + (Option { c : Sound.InductionHeadCert seq // + Sound.InductionHeadCertSound inputs c }) := + Task.spawn (prio := Task.Priority.dedicated) (fun _ => + match Sound.buildInductionCertFromHead? inputs with + | none => none + | some ⟨cert, hcert⟩ => + let _ := cert.active.card + some ⟨cert, hcert⟩) + let heartbeatMs ← heartbeatMsFromEnv + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished certTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished certTask) + if !finished then + let now ← monoUsNow + IO.println s!"timing: head build induction cert running {now - tCert0} us" + flushStdout + let certOpt ← IO.wait certTask + let tCert1 ← monoUsNow + logTiming s!"done: head build induction cert {tCert1 - tCert0} us" + IO.println s!"timing: head build induction cert {tCert1 - tCert0} us" + IO.println "timing: head build induction cert returned" + flushStdout + match certOpt with | none => - IO.eprintln "error: nonvacuous logit-diff construction failed" + IO.eprintln "error: head inputs rejected" return 2 - | some result => - let cert := result.base.cert - let logitDiffLB := result.base.lb + | some ⟨cert, _hcert⟩ => let activeCount := cert.active.card let defaultMinActive := max 1 (seq / 8) let minActive := minActive?.getD defaultMinActive @@ -980,20 +1003,36 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} IO.eprintln s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" return 2 - match minLogitDiff? with - | some minLogitDiff => - if logitDiffLB < minLogitDiff then + logTiming "start: head logit-diff lower bound" + IO.println "timing: head logit-diff lower bound start" + flushStdout + let logitDiffLB? 
← timePure "head: logit-diff lower bound" (fun () => + Sound.logitDiffLowerBoundFromCert cert) + logTiming "done: head logit-diff lower bound" + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + if logitDiffLB ≤ 0 then IO.eprintln s!"error: logitDiffLB {ratToString logitDiffLB} \ - below minimum {ratToString minLogitDiff}" + is not strictly positive" return 2 - | none => pure () - let tol := cert.eps * (cert.values.hi - cert.values.lo) - IO.println - s!"ok: nonvacuous induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB})" - return 0 + match minLogitDiff? with + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" + return 2 + | none => pure () + let tol := cert.eps * (cert.values.hi - cert.values.lo) + IO.println + s!"ok: nonvacuous induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB})" + return 0 /-- Build and check induction certificates from exact head inputs. 
-/ def runInductionCertifyHead (inputsPath : System.FilePath) diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index 0164c6e..d2143b0 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -215,8 +215,8 @@ private def ratOfFloatBits (bits : Nat) : Option Rat := some (ratOfIntWithPrec num 1074) else let mant := mantBits + pow2 52 - let exp := expBits - 1023 - let shift : Int := Int.ofNat exp - 52 + let exp : Int := Int.ofNat expBits - 1023 + let shift : Int := exp - 52 let prec : Int := -shift some (ratOfIntWithPrec (sign * Int.ofNat mant) prec) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 9aa52ed..1ea11e2 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -211,6 +211,8 @@ structure InductionHeadCoreCache (seq dModel dHead : Nat) where marginAt : Fin seq → Rat /-- Epsilon per query. -/ epsAt : Fin seq → Rat + /-- Per-key weight bounds derived from score gaps. -/ + weightBoundAt : Fin seq → Fin seq → Rat /-- Global margin. -/ margin : Rat /-- Global epsilon. -/ @@ -223,8 +225,6 @@ structure InductionHeadCoreCache (seq dModel dHead : Nat) where wvDir : Fin dModel → Rat /-- Direction bias term. -/ bDir : Rat - /-- Value absolute bounds. -/ - valsAbs : Fin seq → Rat /-- Value lower bounds. -/ valsLo : Fin seq → Rat /-- Value upper bounds. 
-/ @@ -545,7 +545,7 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} inputs.scale * dotLo q k let scoreLoPrev : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) - let scoreGapLoBase : Fin seq → Fin seq → Rat := fun q k => + let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => if masked q (inputs.prev q) then scoreLoPrev q - scoreHi q k else if masked q k then @@ -554,9 +554,11 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} inputs.scale * dotDiffLoBase q k else inputs.scale * dotDiffHiBase q k + let scoreGapLoBase : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoBaseRaw let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKey : Fin seq → Option (Fin seq) := fun q => + let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => if hq : q ∈ inputs.active then let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) match ks with @@ -568,6 +570,12 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} some (ks.foldl step (scoreGapLoBase q k, k)).2 else none + let worstKeyArr : Array (Thunk (Option (Fin seq))) := + Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) + let worstKey : Fin seq → Option (Fin seq) := fun q => + let t := worstKeyArr[q.1]'(by + simp [worstKeyArr, q.isLt]) + Thunk.get t let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => match worstKey q with | some k' => @@ -596,7 +604,7 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} else dotDiffHiBase q k | none => dotDiffHiBase q k - let scoreGapLo : Fin seq → Fin seq → Rat := fun q k => + let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => if masked q (inputs.prev q) then scoreLoPrev q - scoreHi q k else if masked q k then @@ -605,6 +613,8 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} inputs.scale * dotDiffLo q k else inputs.scale * dotDiffHi q k + let scoreGapLo : Fin seq → Fin seq → Rat := 
+ Bounds.cacheBound2 scoreGapLoRaw let marginAt : Fin seq → Rat := fun q => if hq : q ∈ inputs.active then let other := otherKeys q @@ -614,16 +624,23 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} (0 : Rat) else (0 : Rat) - let epsAt : Fin seq → Rat := fun q => + let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => + if hk : k = inputs.prev q then + (0 : Rat) + else + let gap := scoreGapLo q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap) + let weightBoundAt : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 weightBoundAtBase + let epsAtBase : Fin seq → Rat := fun q => let other := otherKeys q - let total := - other.sum (fun k => - let gap := scoreGapLo q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap)) + let total := other.sum (fun k => weightBoundAt q k) min (1 : Rat) total + let epsAt : Fin seq → Rat := + Bounds.cacheBoundThunk epsAtBase let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt @@ -641,10 +658,10 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) let bDir : Rat := Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsAbs : Fin seq → Rat := fun q => - Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMax q - let valsLo : Fin seq → Rat := fun q => bDir - valsAbs q - let valsHi : Fin seq → Rat := fun q => bDir + valsAbs q + let valsLo : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalLower (fun j => wvDir j) (lnLo q) (lnHi q) + let valsHi : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo q) (lnHi q) let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := by simp [univ] let lo := univ.inf' hnonempty valsLo @@ -658,6 +675,7 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} let cert : InductionHeadCert seq := { eps := eps epsAt := epsAt + weightBoundAt := weightBoundAt margin := margin active := inputs.active 
prev := inputs.prev @@ -722,13 +740,13 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} scoreGapLo := scoreGapLo marginAt := marginAt epsAt := epsAt + weightBoundAt := weightBoundAt margin := margin eps := eps dirHeadVec := dirHeadVec dirHead := dirHead wvDir := wvDir bDir := bDir - valsAbs := valsAbs valsLo := valsLo valsHi := valsHi univ := univ @@ -740,9 +758,10 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} /-- The cached certificate is built from cache fields. -/ theorem buildInductionHeadCoreCache_cert_eq [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : - (buildInductionHeadCoreCache inputs).cert = + (buildInductionHeadCoreCache inputs).cert = { eps := (buildInductionHeadCoreCache inputs).eps epsAt := (buildInductionHeadCoreCache inputs).epsAt + weightBoundAt := (buildInductionHeadCoreCache inputs).weightBoundAt margin := (buildInductionHeadCoreCache inputs).margin active := inputs.active prev := inputs.prev diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index aab2279..fe0e217 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -171,6 +171,8 @@ structure InductionHeadCert (seq : Nat) where eps : Rat /-- Per-query weight tolerance derived from local margins. -/ epsAt : Fin seq → Rat + /-- Per-key weight bounds derived from score gaps. -/ + weightBoundAt : Fin seq → Fin seq → Rat /-- Score margin used to justify the weight tolerance. -/ margin : Rat /-- Active queries for which bounds are required. -/ @@ -195,6 +197,10 @@ structure InductionHeadCertSound [NeZero seq] {dModel dHead : Nat} Layers.OneHotApproxBoundsOnActive (Val := Real) (c.epsAt q : Real) (fun q' => q' = q) c.prev (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) + /-- Per-key weight bounds derived from local score gaps. 
-/ + weight_bounds_at : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ (c.weightBoundAt q k : Real) /-- Interval bounds hold for the direction values. -/ value_bounds : ValueIntervalBounds (vals := valsRealOfInputs inputs) c.values diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index b1d5189..4c8fbff 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -318,7 +318,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} inputs.scale * dotLo q k let scoreLoPrev : Fin seq → Rat := fun q => scoreLo q (inputs.prev q) - let scoreGapLoBase : Fin seq → Fin seq → Rat := fun q k => + let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => if masked q (inputs.prev q) then scoreLoPrev q - scoreHi q k else if masked q k then @@ -327,9 +327,11 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} inputs.scale * dotDiffLoBase q k else inputs.scale * dotDiffHiBase q k + let scoreGapLoBase : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoBaseRaw let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKey : Fin seq → Option (Fin seq) := fun q => + let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => if hq : q ∈ inputs.active then let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) match ks with @@ -341,6 +343,12 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} some (ks.foldl step (scoreGapLoBase q k, k)).2 else none + let worstKeyArr : Array (Thunk (Option (Fin seq))) := + Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) + let worstKey : Fin seq → Option (Fin seq) := fun q => + let t := worstKeyArr[q.1]'(by + simp [worstKeyArr, q.isLt]) + Thunk.get t let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => match worstKey q with | some k' => @@ -369,7 +377,7 @@ 
theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} else dotDiffHiBase q k | none => dotDiffHiBase q k - let scoreGapLo : Fin seq → Fin seq → Rat := fun q k => + let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => if masked q (inputs.prev q) then scoreLoPrev q - scoreHi q k else if masked q k then @@ -378,6 +386,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} inputs.scale * dotDiffLo q k else inputs.scale * dotDiffHi q k + let scoreGapLo : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoRaw let marginAt : Fin seq → Rat := fun q => if hq : q ∈ inputs.active then let other := otherKeys q @@ -387,16 +397,23 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (0 : Rat) else (0 : Rat) - let epsAt : Fin seq → Rat := fun q => + let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => + if hk : k = inputs.prev q then + (0 : Rat) + else + let gap := scoreGapLo q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap) + let weightBoundAt : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 weightBoundAtBase + let epsAtBase : Fin seq → Rat := fun q => let other := otherKeys q - let total := - other.sum (fun k => - let gap := scoreGapLo q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap)) + let total := other.sum (fun k => weightBoundAt q k) min (1 : Rat) total + let epsAt : Fin seq → Rat := + Bounds.cacheBoundThunk epsAtBase let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt @@ -416,10 +433,10 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) let bDir : Rat := Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsAbs : Fin seq → Rat := fun q => - Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMax q - let valsLo : Fin seq → Rat := fun q => bDir - valsAbs q - let valsHi : Fin seq → Rat := fun q => bDir + valsAbs q + let 
valsLo : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalLower (fun j => wvDir j) (lnLo q) (lnHi q) + let valsHi : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo q) (lnHi q) let univ : Finset (Fin seq) := Finset.univ have hnonempty : univ.Nonempty := by simp [univ] let lo := univ.inf' hnonempty valsLo @@ -433,6 +450,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let cert : InductionHeadCert seq := { eps := eps epsAt := epsAt + weightBoundAt := weightBoundAt margin := margin active := inputs.active prev := inputs.prev @@ -456,78 +474,6 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (x := inputs.embed q) hmodel hEps hSqrt simpa [lnBounds, lnLo, lnHi, lnRealOfInputs, Bounds.cacheBoundPair2_apply_left, Bounds.cacheBoundPair2_apply_right] using hln i - have hln_abs : ∀ q j, |lnRealOfInputs inputs q j| ≤ (lnAbsMax q : Real) := by - intro q j - have hln := hln_bounds q - have h := - Bounds.abs_le_intervalAbsBound_real (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) (hlo := fun j => (hln j).1) - (hhi := fun j => (hln j).2) j - simpa only [lnAbsMax, lnAbsMaxArr, lnAbsMaxTask, Bounds.cacheBoundTask_apply, - Array.getElem_ofFn] using h - have hdot_abs_bound : - ∀ (v : Fin dModel → Rat) (q : Fin seq), - |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ - (Bounds.dotIntervalAbsBound v (lnLo q) (lnHi q) : Real) := by - intro v q - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => - (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => - (hln j).2 - simpa using - (Bounds.abs_dotProduct_le_dotIntervalAbsBound_real - (v := v) (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) hlo hhi) - have hdot_abs_bound_sum : - ∀ (v : Fin dModel → Rat) (q : Fin seq), - |dotProduct (fun j => (v j : Real)) (lnRealOfInputs inputs q)| ≤ - (Linear.sumFin dModel 
(fun j => |v j|) : Real) * (lnAbsMax q : Real) := by - intro v q - have hsum : - |∑ j, (v j : Real) * lnRealOfInputs inputs q j| ≤ - ∑ j, |(v j : Real) * lnRealOfInputs inputs q j| := by simpa [dotProduct] using - (Finset.abs_sum_le_sum_abs (s := (Finset.univ : Finset (Fin dModel))) - (f := fun j => (v j : Real) * lnRealOfInputs inputs q j)) - have hterm : - ∀ j, |(v j : Real) * lnRealOfInputs inputs q j| ≤ - (|v j| : Real) * (lnAbsMax q : Real) := by - intro j - have hln := hln_abs q j - have hnonneg : 0 ≤ (|v j| : Real) := by - exact abs_nonneg _ - calc - |(v j : Real) * lnRealOfInputs inputs q j| = - |(v j : Real)| * |lnRealOfInputs inputs q j| := by - simp [abs_mul] - _ ≤ (|v j| : Real) * (lnAbsMax q : Real) := - mul_le_mul_of_nonneg_left hln hnonneg - have hsum_le : - ∑ j, |(v j : Real) * lnRealOfInputs inputs q j| ≤ - ∑ j, (|v j| : Real) * (lnAbsMax q : Real) := by - refine Finset.sum_le_sum ?_ - intro j _ - exact hterm j - have hsum_mul : - ∑ j, (|v j| : Real) * (lnAbsMax q : Real) = - (∑ j, (|v j| : Real)) * (lnAbsMax q : Real) := by - symm - simpa using - (Finset.sum_mul (s := (Finset.univ : Finset (Fin dModel))) - (f := fun j => (|v j| : Real)) (a := (lnAbsMax q : Real))) - have hsum_cast : - (Linear.sumFin dModel (fun j => |v j|) : Real) = ∑ j, (|v j| : Real) := by - simpa [ratToReal] using (Linear.ratToReal_sumFin (f := fun j => |v j|)) - have hsum_eq : - ∑ j, (|v j| : Real) * (lnAbsMax q : Real) = - (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by - calc - ∑ j, (|v j| : Real) * (lnAbsMax q : Real) - = (∑ j, (|v j| : Real)) * (lnAbsMax q : Real) := hsum_mul - _ = (Linear.sumFin dModel (fun j => |v j|) : Real) * (lnAbsMax q : Real) := by - simp [hsum_cast] - have hfinal := hsum.trans (hsum_le.trans_eq hsum_eq) - simpa [dotProduct] using hfinal have dotFin_cast {n : Nat} (f g : Fin n → Rat) : (Linear.dotFin n f g : Real) = dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by @@ -916,7 +862,8 @@ theorem 
buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} calc scoresReal q k + (scoreGapLo q k : Real) = (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by - simp [scoreGapLo, hprevmask, add_comm] + simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, add_comm] _ ≤ (scoreLoPrev q : Real) := hsum_le' _ ≤ scoresReal q (inputs.prev q) := hscore_prev · by_cases hmask : masked q k @@ -929,7 +876,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} scoresReal q k + (scoreGapLo q k : Real) = (inputs.maskValue : Real) + (scoreLoPrev q : Real) - (inputs.maskValue : Real) := by - simp [scoreGapLo, hprevmask, hmask, hscore_k] + simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscore_k] _ = (scoreLoPrev q : Real) := by simp [add_sub_cancel_left] _ ≤ scoresReal q (inputs.prev q) := hscore_prev @@ -944,13 +892,15 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} · have hscale_real : 0 ≤ (inputs.scale : Real) := ratToReal_nonneg_of_nonneg hscale have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real - simpa [scoreGapLo, hprevmask, hmask, hscale] using hle + simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscale] using hle · have hscale_nonpos : inputs.scale ≤ 0 := le_of_lt (lt_of_not_ge hscale) have hscale_real : (inputs.scale : Real) ≤ 0 := (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real - simpa [scoreGapLo, hprevmask, hmask, hscale] using hle + simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscale] using hle have hscore_prev : scoresReal q (inputs.prev q) = (inputs.scale : Real) * @@ -1079,6 +1029,16 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} _ = (marginAt q : Real) + scoresReal q k := by simp [add_comm] exact hstep'.trans hscore' + have hweightBoundAt : + ∀ q 
k, k ≠ inputs.prev q → + weightBoundAt q k = + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k) := by + intro q k hk + simpa [weightBoundAt, weightBoundAtBase, hk] using + (Bounds.cacheBound2_apply (f := weightBoundAtBase) q k) have hepsAt : ∀ q, epsAt q = min (1 : Rat) @@ -1088,7 +1048,19 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} else ratDivUp 1 (1 + scoreGapLo q k))) := by intro q - rfl + have hsum : + (otherKeys q).sum (fun k => weightBoundAt q k) = + (otherKeys q).sum (fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k)) := by + refine Finset.sum_congr rfl ?_ + intro k hk + have hk' : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 + simp [hweightBoundAt q k hk'] + simpa [epsAt, epsAtBase, hsum] using + (Bounds.cacheBoundThunk_apply (f := epsAtBase) q) have oneHot_bounds_at : ∀ q, q ∈ inputs.active → Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) @@ -1104,6 +1076,20 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (hepsAt := hepsAt) (hscore_gap_real_at := hscore_gap_real_at) q hq + have weight_bounds_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + weights q k ≤ (weightBoundAt q k : Real) := by + intro q hq k hk + exact + Sound.weight_bound_at_of_scoreGapLo + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (scoreGapLo := scoreGapLo) + (weightBoundAt := weightBoundAt) + (hweightBoundAt := hweightBoundAt) + (hscore_gap_real_at := hscore_gap_real_at) + q hq k hk have hepsAt_le_eps : ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by intro q hq @@ -1175,22 +1161,39 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k ∧ valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by intro k - have hdot_abs : - |dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k)| ≤ - (valsAbs k 
: Real) := by - have hdot := hdot_abs_bound_sum (fun j => wvDir j) k - simpa [valsAbs, ratToReal_mul] using hdot - have hdot_bounds := (abs_le).1 hdot_abs - have hlow' := add_le_add_right hdot_bounds.1 (bDir : Real) - have hhigh' := add_le_add_right hdot_bounds.2 (bDir : Real) + have hln := hln_bounds k + have hlo : ∀ j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j := fun j => + (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs k j ≤ (lnHi k j : Real) := fun j => + (hln j).2 + have hlow' : + (Bounds.dotIntervalLower (fun j => wvDir j) (lnLo k) (lnHi k) : Real) + + (bDir : Real) ≤ + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + (bDir : Real) := by + simpa using + (Bounds.dotIntervalLower_le_dotProduct_real_add + (v := fun j => wvDir j) + (lo := lnLo k) (hi := lnHi k) + (x := lnRealOfInputs inputs k) (b := (bDir : Real)) hlo hhi) + have hhigh' : + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + (bDir : Real) ≤ + (Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo k) (lnHi k) : Real) + + (bDir : Real) := by + simpa using + (Bounds.dotProduct_le_dotIntervalUpper_real_add + (v := fun j => wvDir j) + (lo := lnLo k) (hi := lnHi k) + (x := lnRealOfInputs inputs k) (b := (bDir : Real)) hlo hhi) have hlow : (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by - simpa [valCert, valsLo, valsAbs, hvals_eq k, ratToReal_sub, - sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using hlow' + simpa [valCert, valsLo, hvals_eq k, ratToReal_add, add_comm, add_left_comm, + add_assoc] using hlow' have hhigh : valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - simpa [valCert, valsHi, valsAbs, hvals_eq k, ratToReal_add, - add_comm, add_left_comm, add_assoc] using hhigh' + simpa [valCert, valsHi, hvals_eq k, ratToReal_add, add_comm, add_left_comm, + add_assoc] using hhigh' exact ⟨hlow, hhigh⟩ have hvals_bounds : ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by @@ -1242,6 +1245,7 @@ theorem 
buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} exact { softmax_bounds := hsoftmax_bounds oneHot_bounds_at := oneHot_bounds_at + weight_bounds_at := weight_bounds_at value_bounds := hvals_bounds } · have : False := by have hnone := diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 189e925..bd1a6ba 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -16,17 +16,6 @@ open Nfp.Circuit variable {seq : Nat} -private theorem valueRangeBounds_of_valueIntervalBounds - {vals : Fin seq → Real} {c : ValueInterval seq} - (h : ValueIntervalBounds vals c) : - Layers.ValueRangeBounds (Val := Real) (c.lo : Real) (c.hi : Real) vals := by - refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } - · exact ratToReal_le_of_le h.lo_le_hi - · intro k - exact le_trans (h.lo_le_valsLo k) (h.vals_bounds k).1 - · intro k - exact le_trans (h.vals_bounds k).2 (h.valsHi_le_hi k) - section LogitDiffLowerBound variable {seq dModel dHead : Nat} [NeZero seq] @@ -42,8 +31,8 @@ noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel d /-- Lower bound computed from the per-key lower bounds in an induction certificate. 
-/ def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := - Circuit.logitDiffLowerBoundAt c.active c.prev c.epsAt - c.values.lo c.values.hi c.values.valsLo + Circuit.logitDiffLowerBoundAtLo c.active c.prev c.epsAt + c.values.lo c.values.valsLo theorem logitDiffLowerBoundFromCert_le (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -58,66 +47,143 @@ theorem logitDiffLowerBoundFromCert_le | succ n => let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => Circuit.softmax (scoresRealOfInputs inputs q) k - have hweights : - Layers.OneHotApproxBoundsOnActive (Val := Real) (c.epsAt q : Real) - (fun q' => q' = q) c.prev weights := - hsound.oneHot_bounds_at q hq - have hvalsRange : - Layers.ValueRangeBounds (Val := Real) (c.values.lo : Real) (c.values.hi : Real) - (valsRealOfInputs inputs) := - valueRangeBounds_of_valueIntervalBounds - (vals := valsRealOfInputs inputs) (c := c.values) hsound.value_bounds - have happrox := - Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange - (Val := Real) - (n := n) - (ε := (c.epsAt q : Real)) - (lo := (c.values.lo : Real)) - (hi := (c.values.hi : Real)) - (active := fun q' => q' = q) - (prev := c.prev) - (weights := weights) - (vals := valsRealOfInputs inputs) - hweights hvalsRange + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let sumOthers : Real := ∑ k ∈ others, weights q k + let valsLoPrev : Real := (c.values.valsLo (c.prev q) : Real) + let lo : Real := (c.values.lo : Real) have hboundRat : lb ≤ c.values.valsLo (c.prev q) - - c.epsAt q * (c.values.hi - c.values.lo) := by + c.epsAt q * (c.values.valsLo (c.prev q) - c.values.lo) := by refine - Circuit.logitDiffLowerBoundAt_le + Circuit.logitDiffLowerBoundAtLo_le (active := c.active) (prev := c.prev) (epsAt := c.epsAt) (lo := c.values.lo) - (hi := c.values.hi) - (vals := c.values.valsLo) + (valsLo := c.values.valsLo) 
q hq lb ?_ simpa [logitDiffLowerBoundFromCert] using hbound have hboundReal : - (lb : Real) ≤ - (c.values.valsLo (c.prev q) : Real) - - (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) := by + (lb : Real) ≤ valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) := by simpa [ratToReal_sub, ratToReal_mul] using ratToReal_le_of_le hboundRat - have hvalsLo : - (c.values.valsLo (c.prev q) : Real) ≤ - valsRealOfInputs inputs (c.prev q) := by + have hweights_nonneg : ∀ k, 0 ≤ weights q k := + hsound.softmax_bounds.nonneg q hq + have hweights := hsound.oneHot_bounds_at q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hweights.sum_one q rfl + have hsum_others_le : sumOthers ≤ (c.epsAt q : Real) := by + have hprev : 1 ≤ weights q (c.prev q) + (c.epsAt q : Real) := + hweights.prev_large q rfl + have hprev' : + weights q (c.prev q) + sumOthers ≤ + weights q (c.prev q) + (c.epsAt q : Real) := by + simpa [hsum, sumOthers] using hprev + exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' + have hvals_lo : ∀ k, lo ≤ vals k := by + intro k + have hlo := hsound.value_bounds.lo_le_valsLo k + have hvals := (hsound.value_bounds.vals_bounds k).1 + exact le_trans hlo hvals + have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by exact (hsound.value_bounds.vals_bounds (c.prev q)).1 - have hvalsLo' : - (c.values.valsLo (c.prev q) : Real) - - (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) ≤ - valsRealOfInputs inputs (c.prev q) - - (c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real)) := by - exact - sub_le_sub_right hvalsLo - ((c.epsAt q : Real) * ((c.values.hi : Real) - (c.values.lo : Real))) - have hlow : - valsRealOfInputs inputs (c.prev q) - - (c.epsAt q : Real) * ((c.values.hi : Real) - 
(c.values.lo : Real)) ≤ - dotProduct (weights q) (valsRealOfInputs inputs) := by - exact (sub_le_iff_le_add).2 (happrox q rfl).2 - have hle : - (lb : Real) ≤ dotProduct (weights q) (valsRealOfInputs inputs) := - le_trans hboundReal (le_trans hvalsLo' hlow) - simpa [headLogitDiff, weights] using hle + have hsum_vals_ge : + sumOthers * lo ≤ ∑ k ∈ others, weights q k * vals k := by + have hsum_lo : + sumOthers * lo = ∑ k ∈ others, weights q k * lo := by + have hsum_lo' : + (∑ k ∈ others, weights q k) * lo = + ∑ k ∈ others, weights q k * lo := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := lo)) + simpa [sumOthers] using hsum_lo' + have hle : + ∀ k ∈ others, weights q k * lo ≤ weights q k * vals k := by + intro k _hk + have hval := hvals_lo k + have hnonneg := hweights_nonneg k + exact mul_le_mul_of_nonneg_left hval hnonneg + have hsum' : + ∑ k ∈ others, weights q k * lo ≤ + ∑ k ∈ others, weights q k * vals k := by + exact Finset.sum_le_sum hle + simpa [hsum_lo] using hsum' + have hsum_prod : + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k = + ∑ k, weights q k * vals k := by + simp [others] + have hout_eq : + dotProduct (weights q) vals = + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [dotProduct] using hsum_prod.symm + have hdot_ge : + weights q (c.prev q) * vals (c.prev q) + sumOthers * lo ≤ + dotProduct (weights q) vals := by + have hle : + weights q (c.prev q) * vals (c.prev q) + sumOthers * lo ≤ + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_vals_ge (weights q (c.prev q) * vals (c.prev q))) + simpa [sumOthers, hout_eq, add_comm, add_left_comm, add_assoc] using hle + have hprev_lo : + weights q (c.prev q) * valsLoPrev ≤ + weights q (c.prev q) * vals (c.prev q) := by + exact mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) 
+ have hdot_ge' : + weights q (c.prev q) * valsLoPrev + sumOthers * lo ≤ + dotProduct (weights q) vals := by + have hle : + weights q (c.prev q) * valsLoPrev + sumOthers * lo ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * lo := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_right hprev_lo (sumOthers * lo)) + exact hle.trans hdot_ge + have hsplit : + weights q (c.prev q) * valsLoPrev + sumOthers * lo = + valsLoPrev - sumOthers * (valsLoPrev - lo) := by + have hsplit' : + weights q (c.prev q) * valsLoPrev + sumOthers * lo = + (weights q (c.prev q) + sumOthers) * valsLoPrev - + sumOthers * (valsLoPrev - lo) := by + ring + calc + weights q (c.prev q) * valsLoPrev + sumOthers * lo = + (weights q (c.prev q) + sumOthers) * valsLoPrev - + sumOthers * (valsLoPrev - lo) := hsplit' + _ = valsLoPrev - sumOthers * (valsLoPrev - lo) := by + simp [hsum, sumOthers] + have hdiff_nonneg : 0 ≤ valsLoPrev - lo := by + exact sub_nonneg.mpr (hsound.value_bounds.lo_le_valsLo (c.prev q)) + have hsum_mul_le : + sumOthers * (valsLoPrev - lo) ≤ + (c.epsAt q : Real) * (valsLoPrev - lo) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + have hsub_le : + valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) ≤ + valsLoPrev - sumOthers * (valsLoPrev - lo) := by + exact sub_le_sub_left hsum_mul_le valsLoPrev + have hdot_lower : + valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) ≤ + valsLoPrev - sumOthers * (valsLoPrev - lo) := hsub_le + _ = weights q (c.prev q) * valsLoPrev + sumOthers * lo := by + simp [hsplit] + _ ≤ dotProduct (weights q) vals := hdot_ge' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal hdot_lower + simpa [headLogitDiff, weights, vals] using hle /-- Certified logit-diff lower bound derived from exact head inputs. 
-/ structure InductionLogitLowerBoundResult diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 8d65f76..8cbca6d 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -313,6 +313,57 @@ theorem oneHot_bounds_at_of_scoreGapLo simpa using h exact hle.trans hsum_others_le +/-- Per-key weight bounds on a single active query, derived from per-key score gaps. -/ +theorem weight_bound_at_of_scoreGapLo + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scoresReal : Fin seq → Fin seq → Real) + (scoreGapLo : Fin seq → Fin seq → Rat) + (weightBoundAt : Fin seq → Fin seq → Rat) + (hweightBoundAt : + ∀ q k, k ≠ prev q → + weightBoundAt q k = + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k)) + (hscore_gap_real_at : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (prev q)) : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + Circuit.softmax (scoresReal q) k ≤ (weightBoundAt q k : Real) := by + classical + intro q hq k hk + by_cases hneg : scoreGapLo q k < 0 + · have hle : Circuit.softmax (scoresReal q) k ≤ 1 := by + simpa using (Circuit.softmax_le_one (scores := scoresReal q) k) + simpa [hweightBoundAt q k hk, hneg] using hle + · have hnonneg : 0 ≤ scoreGapLo q k := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (scoreGapLo q k : Real) := by + exact ratToReal_nonneg_of_nonneg hnonneg + have hscore := hscore_gap_real_at q hq k hk + have hsoft : + Circuit.softmax (scoresReal q) k ≤ 1 / (1 + (scoreGapLo q k : Real)) := by + simpa using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := prev q) (k := k) (m := (scoreGapLo q k : Real)) + hnonneg_real hscore) + have hpos : (0 : Rat) < 1 + scoreGapLo q k := by + have hle : (1 : Rat) ≤ 1 + scoreGapLo q k := by + exact le_add_of_nonneg_right hnonneg + exact lt_of_lt_of_le zero_lt_one hle + have hden : (1 + scoreGapLo q k) ≠ 0 := by + exact ne_of_gt hpos + have hrat : + 1 / (1 
+ (scoreGapLo q k : Real)) ≤ + ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := by + simpa [ratToReal] using + (ratDivUp_ge_real 1 (1 + scoreGapLo q k) hden) + have hbound' : + Circuit.softmax (scoresReal q) k ≤ ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := + hsoft.trans hrat + simpa [hweightBoundAt q k hk, hneg] using hbound' + end Sound end Nfp From dba1722256ba68b528a25ae94d684b21b3083068 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 00:20:41 +0100 Subject: [PATCH 146/244] Parse layer_norm_eps exactly in head input builder --- scripts/build_gpt2_head_inputs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/build_gpt2_head_inputs.py b/scripts/build_gpt2_head_inputs.py index 4da754f..5b1c32c 100644 --- a/scripts/build_gpt2_head_inputs.py +++ b/scripts/build_gpt2_head_inputs.py @@ -332,7 +332,10 @@ def main() -> None: ln_eps_raw = header.get("layer_norm_eps") if ln_eps_raw is None: raise SystemExit("Missing layer_norm_eps in header.") - ln_eps = rat_from_float_exact(float(ln_eps_raw)) + try: + ln_eps = Fraction(ln_eps_raw) + except ValueError: + ln_eps = rat_from_float_exact(float(ln_eps_raw)) write_head_inputs( args.output, scale, From e151de633569be0a9d9d7d214eab27ab08aa54cc Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 00:46:12 +0100 Subject: [PATCH 147/244] Streamline timing output and skip masked dot bounds --- Nfp/Cli.lean | 36 ++- Nfp/IO/InductionHead.lean | 371 +++++++++++++++-------------- Nfp/IO/Timing.lean | 120 +++++++--- Nfp/Sound/Induction/Core.lean | 11 +- Nfp/Sound/Induction/CoreSound.lean | 27 ++- 5 files changed, 338 insertions(+), 227 deletions(-) diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 545ab3b..d2086f5 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -176,8 +176,10 @@ def runInductionCertifyHead (p : Parsed) : IO UInt32 := do let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? 
:= (p.flag? "max-eps").map (·.as! String) + let timing? := (p.flag? "timing").map (·.as! Nat) + let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? + minMarginStr? maxEpsStr? timing? heartbeatMs? /-- `nfp induction certify_head_nonvacuous` subcommand. -/ def runInductionCertifyHeadNonvacuous (p : Parsed) : IO UInt32 := do @@ -186,8 +188,10 @@ def runInductionCertifyHeadNonvacuous (p : Parsed) : IO UInt32 := do let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let timing? := (p.flag? "timing").map (·.as! Nat) + let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) IO.runInductionCertifyHeadNonvacuous inputsPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? + minMarginStr? maxEpsStr? timing? heartbeatMs? /-- `nfp induction certify_head` subcommand. -/ def inductionCertifyHeadCmd : Cmd := `[Cli| @@ -201,6 +205,8 @@ def inductionCertifyHeadCmd : Cmd := `[Cli| (rational literal). Defaults to 0." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." ] /-- `nfp induction certify_head_nonvacuous` subcommand. -/ @@ -215,6 +221,8 @@ def inductionCertifyHeadNonvacuousCmd : Cmd := `[Cli| (rational literal; default: 0)." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." 
] /-- `nfp induction certify_head_model` subcommand. -/ @@ -229,8 +237,10 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let timing? := (p.flag? "timing").map (·.as! Nat) + let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) IO.runInductionCertifyHeadModel modelPath layer head dirTarget dirNegative period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do @@ -244,8 +254,10 @@ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let timing? := (p.flag? "timing").map (·.as! Nat) + let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) IO.runInductionCertifyHeadModelNonvacuous modelPath layer head dirTarget dirNegative period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? /-- `nfp induction certify_head_model_auto` subcommand. -/ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do @@ -257,8 +269,10 @@ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let timing? := (p.flag? "timing").map (·.as! Nat) + let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! 
Nat) IO.runInductionCertifyHeadModelAuto modelPath layer head period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? /-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do @@ -270,8 +284,10 @@ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let timing? := (p.flag? "timing").map (·.as! Nat) + let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer head period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? /-- `nfp induction certify_head_model` subcommand. -/ def inductionCertifyHeadModelCmd : Cmd := `[Cli| @@ -290,6 +306,8 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| (rational literal). Defaults to 0." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." ] /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ @@ -309,6 +327,8 @@ def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| (rational literal; default: 0)." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." 
] /-- `nfp induction certify_head_model_auto` subcommand. -/ @@ -327,6 +347,8 @@ def inductionCertifyHeadModelAutoCmd : Cmd := `[Cli| (rational literal). Defaults to 0." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." ] /-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ @@ -345,6 +367,8 @@ def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| (rational literal; default: 0)." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." ] /-- `nfp induction head_interval` subcommand. -/ diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index 3993c03..b0ec922 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -26,6 +26,17 @@ private def unwrapTaskResult {α : Type} (res : Except IO.Error α) : IO α := | .ok a => pure a | .error e => throw e +private def configureTiming (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO Unit := do + match timing? with + | some v => setTimingStdout (v ≠ 0) + | none => pure () + match heartbeatMs? with + | some v => + setTimingHeartbeatMs (UInt32.ofNat v) + if timing?.isNone && (v != 0) then + setTimingStdout true + | none => pure () + open Nfp.Circuit private def valueBoundsModeFromEnv : IO (Option Bool) := do @@ -34,18 +45,16 @@ private def valueBoundsModeFromEnv : IO (Option Bool) := do | some "cached" => return some false | _ => return none -/-- Read the heartbeat interval (ms) for long-running induction cert builds. 
-/ -private def heartbeatMsFromEnv : IO UInt32 := do - let defaultMs : Nat := 10000 - let ms := (← IO.getEnv "NFP_TIMING_HEARTBEAT_MS").bind String.toNat? |>.getD defaultMs - return UInt32.ofNat ms +/-- Resolve the heartbeat interval (ms) for long-running induction cert builds. -/ +private def heartbeatMs : IO UInt32 := + timingHeartbeatMs private def timePureWithHeartbeat {α : Type} (label : String) (f : Unit → α) : IO α := do let t0 ← monoUsNow - IO.println s!"timing: {label} start" - flushStdout + timingPrint s!"timing: {label} start" + timingFlush let task : Task α := Task.spawn (fun _ => f ()) - let heartbeatMs ← heartbeatMsFromEnv + let heartbeatMs ← heartbeatMs if heartbeatMs ≠ 0 then let mut finished := (← IO.hasFinished task) while !finished do @@ -53,11 +62,11 @@ private def timePureWithHeartbeat {α : Type} (label : String) (f : Unit → α) finished := (← IO.hasFinished task) if !finished then let now ← monoUsNow - IO.println s!"timing: {label} running {now - t0} us" - flushStdout + timingPrint s!"timing: {label} running {now - t0} us" + timingFlush let res ← IO.wait task let t1 ← monoUsNow - IO.println s!"timing: {label} {t1 - t0} us" + timingPrint s!"timing: {label} {t1 - t0} us" return res private def forceRat (x : Rat) : IO Unit := do @@ -69,8 +78,8 @@ private def forceRat (x : Rat) : IO Unit := do /-- Profile the core induction-head bounds used by the sound certificate builder. 
-/ private def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do - IO.println "timing: core stages start" - flushStdout + timingPrint "timing: core stages start" + timingFlush let lnBounds ← timePureWithHeartbeat "core: ln bounds" (fun () => Sound.headLnBounds inputs) let lnLo := lnBounds.1 @@ -152,18 +161,18 @@ private def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] decide (margin < 0)) let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" if verboseTiming.isSome then - IO.println s!"timing: core: margin neg={marginNeg}" + timingPrint s!"timing: core: margin neg={marginNeg}" let tEps0 ← monoUsNow - IO.println "timing: core: eps start" - flushStdout + timingPrint "timing: core: eps start" + timingFlush let eps := if marginNeg then (1 : Rat) else ratDivUp (seq - 1) (1 + margin) let tEps1 ← monoUsNow - IO.println s!"timing: core: eps {tEps1 - tEps0} us" - flushStdout + timingPrint s!"timing: core: eps {tEps1 - tEps0} us" + timingFlush let _ := marginAt let dirHeadVec ← timePureWithHeartbeat "core: dir head vec" (fun () => Sound.dirHeadVecOfInputs inputs) @@ -188,8 +197,8 @@ private def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] let lo := univ.inf' hnonempty valsLo let hi := univ.sup' hnonempty valsHi (lo, hi)) - IO.println "timing: core stages done" - flushStdout + timingPrint "timing: core stages done" + timingFlush /-- Load induction head inputs from disk. 
-/ def loadInductionHeadInputs (path : System.FilePath) : @@ -198,14 +207,14 @@ def loadInductionHeadInputs (path : System.FilePath) : let t0 ← monoUsNow let data ← IO.FS.readFile path let t1 ← monoUsNow - IO.println s!"timing: read head input file {t1 - t0} us" + timingPrint s!"timing: read head input file {t1 - t0} us" let t2 ← monoUsNow let parsed := match Pure.parseInductionHeadInputs data with | Except.error msg => Except.error msg | Except.ok v => Except.ok v let t3 ← monoUsNow - IO.println s!"timing: parse head input file {t3 - t2} us" + timingPrint s!"timing: parse head input file {t3 - t2} us" return parsed private def ratToString (x : Rat) : String := @@ -333,36 +342,36 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let seq := Nat.succ n let _ : NeZero seq := ⟨by simp⟩ logTiming "start: head build induction cert" - IO.println "timing: head build induction cert start" - flushStdout + timingPrint "timing: head build induction cert start" + timingFlush let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" let taskBenchEnv ← IO.getEnv "NFP_TASK_BENCH" if taskBenchEnv.isSome then let n := taskBenchEnv.bind String.toNat? 
|>.getD 1000 Nfp.IO.taskBench n if verboseTiming.isSome then - IO.println s!"timing: head dims seq={seq} dModel={dModel} dHead={dHead}" - IO.println s!"timing: head active card={inputs.active.card}" - flushStdout + timingPrint s!"timing: head dims seq={seq} dModel={dModel} dHead={dHead}" + timingPrint s!"timing: head active card={inputs.active.card}" + timingFlush let precompute := (← IO.getEnv "NFP_TIMING_PRECOMPUTE").isSome if precompute then - IO.println "timing: head ln bounds start" - flushStdout + timingPrint "timing: head ln bounds start" + timingFlush let lnBounds ← timePure "head: ln bounds" (fun () => Sound.headLnBounds inputs) - IO.println "timing: head ln bounds done" - flushStdout - IO.println "timing: head qkv bounds start" - flushStdout + timingPrint "timing: head ln bounds done" + timingFlush + timingPrint "timing: head qkv bounds start" + timingFlush let lnLo := lnBounds.1 let lnHi := lnBounds.2 let qkv ← timePure "head: qkv bounds" (fun () => Sound.headQKVBounds inputs lnLo lnHi) - IO.println "timing: head qkv bounds done" - flushStdout + timingPrint "timing: head qkv bounds done" + timingFlush if verboseTiming.isSome then - IO.println "timing: head qkv abs force start" - flushStdout + timingPrint "timing: head qkv abs force start" + timingFlush let tAbs0 ← monoUsNow for q in List.finRange seq do for d in List.finRange dHead do @@ -370,36 +379,36 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let _ := qkv.kAbs q d pure () let tAbs1 ← monoUsNow - IO.println s!"timing: head qkv abs force {tAbs1 - tAbs0} us" - flushStdout - IO.println "timing: head score/value bounds spawn start" - flushStdout + timingPrint s!"timing: head qkv abs force {tAbs1 - tAbs0} us" + timingFlush + timingPrint "timing: head score/value bounds spawn start" + timingFlush let tSpawn0 ← monoUsNow if verboseTiming.isSome then - IO.println "timing: head score dotAbs tasks start" - flushStdout + timingPrint "timing: head score dotAbs tasks start" + timingFlush let 
dotAbs ← timePure "head: score dotAbs tasks" (fun () => dotAbsFromQKV qkv.qAbs qkv.kAbs) if verboseTiming.isSome then - IO.println "timing: head score dotAbs tasks done" - flushStdout + timingPrint "timing: head score dotAbs tasks done" + timingFlush if verboseTiming.isSome then - IO.println "timing: head score dotAbs force start" - flushStdout + timingPrint "timing: head score dotAbs force start" + timingFlush let tForce0 ← monoUsNow match List.finRange seq with | [] => - IO.println "timing: head score dotAbs force skipped (empty seq)" + timingPrint "timing: head score dotAbs force skipped (empty seq)" | q :: _ => match List.finRange seq with | [] => - IO.println "timing: head score dotAbs force skipped (empty seq)" + timingPrint "timing: head score dotAbs force skipped (empty seq)" | k :: _ => let _ := dotAbs q k pure () let tForce1 ← monoUsNow - IO.println s!"timing: head score dotAbs force {tForce1 - tForce0} us" - flushStdout + timingPrint s!"timing: head score dotAbs force {tForce1 - tForce0} us" + timingFlush let inlineVals := (← IO.getEnv "NFP_TIMING_VALUE_INLINE").isSome let valueMode? 
← valueBoundsModeFromEnv let useCommon := valueMode?.getD false @@ -422,16 +431,16 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} if verboseTiming.isSome then timeHeadScoreMarginRaw inputs dotAbs activeList let tSpawn1 ← monoUsNow - IO.println s!"timing: head score/value bounds spawn {tSpawn1 - tSpawn0} us" - flushStdout + timingPrint s!"timing: head score/value bounds spawn {tSpawn1 - tSpawn0} us" + timingFlush let skipScoreBounds := (← IO.getEnv "NFP_TIMING_SKIP_SCORE_BOUNDS").isSome let scoreTaskOpt ← if skipScoreBounds then - IO.println "timing: head score bounds skipped" + timingPrint "timing: head score bounds skipped" pure none else - IO.println "timing: head score bounds from dotAbs start" - flushStdout + timingPrint "timing: head score bounds from dotAbs start" + timingFlush let exactMargin := (← IO.getEnv "NFP_TIMING_EXACT_MARGIN").isSome let action := if exactMargin then @@ -441,65 +450,65 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let t ← action.asTask pure (some t) if verboseTiming.isSome then - IO.println "timing: head value parts start" - flushStdout - IO.println "timing: head value dirHead start" - flushStdout + timingPrint "timing: head value parts start" + timingFlush + timingPrint "timing: head value dirHead start" + timingFlush let tDir0 ← monoUsNow let dirHead := Sound.headValueDirHead inputs match List.finRange dHead with | [] => - IO.println "timing: head value dirHead forced skipped (empty dHead)" + timingPrint "timing: head value dirHead forced skipped (empty dHead)" | d :: _ => let _ := dirHead d pure () let tDir1 ← monoUsNow - IO.println s!"timing: head value dirHead {tDir1 - tDir0} us" - flushStdout - IO.println "timing: head value valsLo start" - flushStdout + timingPrint s!"timing: head value dirHead {tDir1 - tDir0} us" + timingFlush + timingPrint "timing: head value valsLo start" + timingFlush let tLo0 ← monoUsNow let valsLo := Sound.headValueValsLo inputs qkv.vLo qkv.vHi match List.finRange seq 
with | [] => - IO.println "timing: head value valsLo forced skipped (empty seq)" + timingPrint "timing: head value valsLo forced skipped (empty seq)" | k :: _ => let _ := valsLo k pure () let tLo1 ← monoUsNow - IO.println s!"timing: head value valsLo {tLo1 - tLo0} us" - flushStdout - IO.println "timing: head value valsHi start" - flushStdout + timingPrint s!"timing: head value valsLo {tLo1 - tLo0} us" + timingFlush + timingPrint "timing: head value valsHi start" + timingFlush let tHi0 ← monoUsNow let valsHi := Sound.headValueValsHi inputs qkv.vLo qkv.vHi match List.finRange seq with | [] => - IO.println "timing: head value valsHi forced skipped (empty seq)" + timingPrint "timing: head value valsHi forced skipped (empty seq)" | k :: _ => let _ := valsHi k pure () let tHi1 ← monoUsNow - IO.println s!"timing: head value valsHi {tHi1 - tHi0} us" - flushStdout - IO.println "timing: head value lo start" - flushStdout + timingPrint s!"timing: head value valsHi {tHi1 - tHi0} us" + timingFlush + timingPrint "timing: head value lo start" + timingFlush let tLo2 ← monoUsNow let _ := Sound.headValueLo valsLo let tLo3 ← monoUsNow - IO.println s!"timing: head value lo {tLo3 - tLo2} us" - flushStdout - IO.println "timing: head value hi start" - flushStdout + timingPrint s!"timing: head value lo {tLo3 - tLo2} us" + timingFlush + timingPrint "timing: head value hi start" + timingFlush let tHi2 ← monoUsNow let _ := Sound.headValueHi valsHi let tHi3 ← monoUsNow - IO.println s!"timing: head value hi {tHi3 - tHi2} us" - flushStdout - IO.println "timing: head value parts done" - flushStdout - IO.println "timing: head value bounds start" - flushStdout + timingPrint s!"timing: head value hi {tHi3 - tHi2} us" + timingFlush + timingPrint "timing: head value parts done" + timingFlush + timingPrint "timing: head value bounds start" + timingFlush let tVals0 ← monoUsNow let vals ← match valsInline?, valsTask? 
with @@ -511,16 +520,16 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} timePure "head: value bounds inline" (fun () => Sound.headValueBounds inputs qkv.vLo qkv.vHi) let tVals1 ← monoUsNow - IO.println s!"timing: head value bounds {tVals1 - tVals0} us" - flushStdout + timingPrint s!"timing: head value bounds {tVals1 - tVals0} us" + timingFlush let scoreOpt ← match scoreTaskOpt with | none => pure none | some scoreTask => do let res ← IO.wait scoreTask let score ← unwrapTaskResult res - IO.println "timing: head score bounds from dotAbs done" - flushStdout + timingPrint "timing: head score bounds from dotAbs done" + timingFlush pure (some score) match scoreOpt with | none => pure () @@ -532,14 +541,14 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} if verboseTiming.isSome then timeHeadScoreFieldForces score if verboseTiming.isSome then - IO.println "timing: head score bounds force start" - flushStdout + timingPrint "timing: head score bounds force start" + timingFlush let tScore0 ← monoUsNow let _ := score.margin let _ := score.eps let tScore1 ← monoUsNow - IO.println s!"timing: head score bounds force {tScore1 - tScore0} us" - flushStdout + timingPrint s!"timing: head score bounds force {tScore1 - tScore0} us" + timingFlush let coreStages := (← IO.getEnv "NFP_TIMING_CORE_STAGES").isSome let coreStagesOnly := (← IO.getEnv "NFP_TIMING_CORE_STAGES_ONLY").isSome if coreStages then @@ -550,8 +559,8 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} if breakdown then let lnBounds ← timePureWithHeartbeat "breakdown: ln bounds" (fun () => Sound.headLnBounds inputs) - IO.println "timing: breakdown ln bounds force start" - flushStdout + timingPrint "timing: breakdown ln bounds force start" + timingFlush let tLn0 ← monoUsNow for q in List.finRange seq do for i in List.finRange dModel do @@ -559,12 +568,12 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let _ := lnBounds.2 q i pure () let tLn1 ← monoUsNow - 
IO.println s!"timing: breakdown ln bounds force {tLn1 - tLn0} us" - flushStdout + timingPrint s!"timing: breakdown ln bounds force {tLn1 - tLn0} us" + timingFlush let qkv ← timePureWithHeartbeat "breakdown: qkv bounds" (fun () => Sound.headQKVBounds inputs lnBounds.1 lnBounds.2) - IO.println "timing: breakdown qkv bounds force start" - flushStdout + timingPrint "timing: breakdown qkv bounds force start" + timingFlush let tQkv0 ← monoUsNow for q in List.finRange seq do for d in List.finRange dHead do @@ -578,8 +587,8 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let _ := qkv.kAbs q d pure () let tQkv1 ← monoUsNow - IO.println s!"timing: breakdown qkv bounds force {tQkv1 - tQkv0} us" - flushStdout + timingPrint s!"timing: breakdown qkv bounds force {tQkv1 - tQkv0} us" + timingFlush let dotAbs : Fin seq → Fin seq → Rat := fun q k => Sound.Linear.dotFin dHead (fun d => qkv.qAbs q d) (fun d => qkv.kAbs k d) let dotAbsRowTasks : @@ -590,16 +599,16 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩))) let dotAbsRowDefault : Task { row : Array Rat // row.size = seq } := Task.spawn (fun _ => ⟨Array.ofFn (fun _ : Fin seq => (0 : Rat)), by simp⟩) - IO.println "timing: breakdown score dotAbs force start" - flushStdout + timingPrint "timing: breakdown score dotAbs force start" + timingFlush let tDot0 ← monoUsNow for q in List.finRange seq do let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get let _ := row pure () let tDot1 ← monoUsNow - IO.println s!"timing: breakdown score dotAbs force {tDot1 - tDot0} us" - flushStdout + timingPrint s!"timing: breakdown score dotAbs force {tDot1 - tDot0} us" + timingFlush let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k let scaleAbs : Rat := |inputs.scale| @@ -658,16 +667,16 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let marginAtCached : Fin seq → Rat ← timePureWithHeartbeat "breakdown: score 
margin cache" (fun () => Sound.Bounds.cacheBoundThunk marginAtRaw) - IO.println "timing: breakdown score margin force start" - flushStdout + timingPrint "timing: breakdown score margin force start" + timingFlush let tMargin0 ← monoUsNow for q in List.finRange seq do let m := marginAtCached q forceRat m pure () let tMargin1 ← monoUsNow - IO.println s!"timing: breakdown score margin force {tMargin1 - tMargin0} us" - flushStdout + timingPrint s!"timing: breakdown score margin force {tMargin1 - tMargin0} us" + timingFlush let epsAtRaw : Fin seq → Rat := fun q => let m := marginAtCached q if m < 0 then @@ -677,41 +686,41 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let epsAtCached : Fin seq → Rat ← timePureWithHeartbeat "breakdown: score eps cache" (fun () => Sound.Bounds.cacheBoundThunk epsAtRaw) - IO.println "timing: breakdown score eps force start" - flushStdout + timingPrint "timing: breakdown score eps force start" + timingFlush let tEps0 ← monoUsNow for q in List.finRange seq do let e := epsAtCached q forceRat e pure () let tEps1 ← monoUsNow - IO.println s!"timing: breakdown score eps force {tEps1 - tEps0} us" - flushStdout + timingPrint s!"timing: breakdown score eps force {tEps1 - tEps0} us" + timingFlush let valsLo ← timePureWithHeartbeat "breakdown: value valsLo" (fun () => Sound.headValueValsLo inputs qkv.vLo qkv.vHi) - IO.println "timing: breakdown value valsLo force start" - flushStdout + timingPrint "timing: breakdown value valsLo force start" + timingFlush let tValsLo0 ← monoUsNow for k in List.finRange seq do let v := valsLo k forceRat v pure () let tValsLo1 ← monoUsNow - IO.println s!"timing: breakdown value valsLo force {tValsLo1 - tValsLo0} us" - flushStdout + timingPrint s!"timing: breakdown value valsLo force {tValsLo1 - tValsLo0} us" + timingFlush let valsHi ← timePureWithHeartbeat "breakdown: value valsHi" (fun () => Sound.headValueValsHi inputs qkv.vLo qkv.vHi) - IO.println "timing: breakdown value valsHi force start" - 
flushStdout + timingPrint "timing: breakdown value valsHi force start" + timingFlush let tValsHi0 ← monoUsNow for k in List.finRange seq do let v := valsHi k forceRat v pure () let tValsHi1 ← monoUsNow - IO.println s!"timing: breakdown value valsHi force {tValsHi1 - tValsHi0} us" - flushStdout - let heartbeatMs ← heartbeatMsFromEnv + timingPrint s!"timing: breakdown value valsHi force {tValsHi1 - tValsHi0} us" + timingFlush + let heartbeatMs ← timingHeartbeatMs let taskMin (t1 t2 : Task Rat) : Task Rat := Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) let taskMax (t1 t2 : Task Rat) : Task Rat := @@ -747,8 +756,8 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} finished := count remaining := chunkTasks.size if finished < remaining then - IO.println s!"timing: breakdown value lo progress {finished}/{remaining}" - flushStdout + timingPrint s!"timing: breakdown value lo progress {finished}/{remaining}" + timingFlush let init := chunkTasks.getD 0 defaultTask let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) pure ((rest.foldl (fun acc i => taskMin acc (chunkTasks.getD i defaultTask)) init).get) @@ -783,8 +792,8 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} finished := count remaining := chunkTasks.size if finished < remaining then - IO.println s!"timing: breakdown value hi progress {finished}/{remaining}" - flushStdout + timingPrint s!"timing: breakdown value hi progress {finished}/{remaining}" + timingFlush let init := chunkTasks.getD 0 defaultTask let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) pure ((rest.foldl (fun acc i => taskMax acc (chunkTasks.getD i defaultTask)) init).get) @@ -800,7 +809,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} else let loTask := Sound.headValueLoTask valsLo let hiTask := Sound.headValueHiTask valsHi - let heartbeatMs ← heartbeatMsFromEnv + let heartbeatMs ← timingHeartbeatMs let tLo0 ← monoUsNow if heartbeatMs ≠ 0 then let mut finished := (←
IO.hasFinished loTask) @@ -809,12 +818,12 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} finished := (← IO.hasFinished loTask) if !finished then let now ← monoUsNow - IO.println s!"timing: breakdown: value lo running {now - tLo0} us" - flushStdout + timingPrint s!"timing: breakdown: value lo running {now - tLo0} us" + timingFlush let lo := loTask.get let tLo1 ← monoUsNow - IO.println s!"timing: breakdown: value lo {tLo1 - tLo0} us" - flushStdout + timingPrint s!"timing: breakdown: value lo {tLo1 - tLo0} us" + timingFlush let tHi0 ← monoUsNow if heartbeatMs ≠ 0 then let mut finished := (← IO.hasFinished hiTask) @@ -823,12 +832,12 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} finished := (← IO.hasFinished hiTask) if !finished then let now ← monoUsNow - IO.println s!"timing: breakdown: value hi running {now - tHi0} us" - flushStdout + timingPrint s!"timing: breakdown: value hi running {now - tHi0} us" + timingFlush let hi := hiTask.get let tHi1 ← monoUsNow - IO.println s!"timing: breakdown: value hi {tHi1 - tHi0} us" - flushStdout + timingPrint s!"timing: breakdown: value hi {tHi1 - tHi0} us" + timingFlush let _ := lo let _ := hi if (← IO.getEnv "NFP_TIMING_SEQ_REDUCE").isSome then @@ -857,7 +866,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} | some ⟨cert, hcert⟩ => let _ := cert.active.card some ⟨cert, hcert⟩) - let heartbeatMs ← heartbeatMsFromEnv + let heartbeatMs ← timingHeartbeatMs if heartbeatMs ≠ 0 then let mut finished := (← IO.hasFinished certTask) while !finished do @@ -865,24 +874,24 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} finished := (← IO.hasFinished certTask) if !finished then let now ← monoUsNow - IO.println s!"timing: head build induction cert running {now - tCert0} us" - flushStdout + timingPrint s!"timing: head build induction cert running {now - tCert0} us" + timingFlush let certOpt ← IO.wait certTask let tCert1 ← monoUsNow logTiming s!"done: head build induction cert 
{tCert1 - tCert0} us" - IO.println s!"timing: head build induction cert {tCert1 - tCert0} us" - IO.println "timing: head build induction cert returned" - flushStdout + timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" + timingPrint "timing: head build induction cert returned" + timingFlush match certOpt with | none => IO.eprintln "error: head inputs rejected" return 2 | some ⟨cert, _hcert⟩ => - IO.println "timing: head active count start" - flushStdout + timingPrint "timing: head active count start" + timingFlush let activeCount := cert.active.card - IO.println "timing: head active count done" - flushStdout + timingPrint "timing: head active count done" + timingFlush let defaultMinActive := max 1 (seq / 8) let minActive := minActive?.getD defaultMinActive if activeCount < minActive then @@ -899,14 +908,14 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} s!"error: eps {ratToString cert.eps} \ above maximum {ratToString maxEps}" return 2 - IO.println "timing: head tol start" - flushStdout + timingPrint "timing: head tol start" + timingFlush let tol := cert.eps * (cert.values.hi - cert.values.lo) - IO.println "timing: head tol done" - flushStdout + timingPrint "timing: head tol done" + timingFlush logTiming "start: head logit-diff lower bound" - IO.println "timing: head logit-diff lower bound start" - flushStdout + timingPrint "timing: head logit-diff lower bound start" + timingFlush let logitDiffLB? 
← timePure "head: logit-diff lower bound" (fun () => Circuit.logitDiffLowerBound cert.active cert.prev cert.eps cert.values.lo cert.values.hi cert.values.valsLo) @@ -953,8 +962,8 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} let seq := Nat.succ n let _ : NeZero seq := ⟨by simp⟩ logTiming "start: head build induction cert" - IO.println "timing: head build induction cert start" - flushStdout + timingPrint "timing: head build induction cert start" + timingFlush let tCert0 ← monoUsNow let certTask : Task @@ -966,7 +975,7 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} | some ⟨cert, hcert⟩ => let _ := cert.active.card some ⟨cert, hcert⟩) - let heartbeatMs ← heartbeatMsFromEnv + let heartbeatMs ← timingHeartbeatMs if heartbeatMs ≠ 0 then let mut finished := (← IO.hasFinished certTask) while !finished do @@ -974,14 +983,14 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} finished := (← IO.hasFinished certTask) if !finished then let now ← monoUsNow - IO.println s!"timing: head build induction cert running {now - tCert0} us" - flushStdout + timingPrint s!"timing: head build induction cert running {now - tCert0} us" + timingFlush let certOpt ← IO.wait certTask let tCert1 ← monoUsNow logTiming s!"done: head build induction cert {tCert1 - tCert0} us" - IO.println s!"timing: head build induction cert {tCert1 - tCert0} us" - IO.println "timing: head build induction cert returned" - flushStdout + timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" + timingPrint "timing: head build induction cert returned" + timingFlush match certOpt with | none => IO.eprintln "error: head inputs rejected" @@ -1004,8 +1013,8 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" return 2 logTiming "start: head logit-diff lower bound" - IO.println "timing: head logit-diff lower bound start" - flushStdout + 
timingPrint "timing: head logit-diff lower bound start" + timingFlush let logitDiffLB? ← timePure "head: logit-diff lower bound" (fun () => Sound.logitDiffLowerBoundFromCert cert) logTiming "done: head logit-diff lower bound" @@ -1037,7 +1046,9 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} /-- Build and check induction certificates from exact head inputs. -/ def runInductionCertifyHead (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + configureTiming timing? heartbeatMs? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1066,7 +1077,9 @@ def runInductionCertifyHead (inputsPath : System.FilePath) /-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + configureTiming timing? heartbeatMs? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1096,7 +1109,9 @@ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) def runInductionCertifyHeadModel (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? 
: Option String) : IO UInt32 := do + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + configureTiming timing? heartbeatMs? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1114,8 +1129,8 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) let minMargin := minMargin?.getD (0 : Rat) let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) logTiming "start: read model file" - IO.println "timing: read model file start" - flushStdout + timingPrint "timing: read model file start" + timingFlush let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath let headerE ← timePure "parse model header" (fun () => NfptPure.parseHeader data) @@ -1138,7 +1153,9 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + configureTiming timing? heartbeatMs? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
@@ -1156,8 +1173,8 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) let minMargin := minMargin?.getD (0 : Rat) let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) logTiming "start: read model file" - IO.println "timing: read model file start" - flushStdout + timingPrint "timing: read model file start" + timingFlush let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath let headerE ← timePure "parse model header" (fun () => NfptPure.parseHeader data) @@ -1209,7 +1226,9 @@ prompt sequence. -/ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) (layer head : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + configureTiming timing? heartbeatMs? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1227,8 +1246,8 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) let minMargin := minMargin?.getD (0 : Rat) let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) logTiming "start: read model file" - IO.println "timing: read model file start" - flushStdout + timingPrint "timing: read model file start" + timingFlush let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath let headerE ← timePure "parse model header" (fun () => NfptPure.parseHeader data) @@ -1267,7 +1286,9 @@ direction tokens from the prompt sequence. -/ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) (layer head : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + (minMarginStr? 
: Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + configureTiming timing? heartbeatMs? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1285,8 +1306,8 @@ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) let minMargin := minMargin?.getD (0 : Rat) let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) logTiming "start: read model file" - IO.println "timing: read model file start" - flushStdout + timingPrint "timing: read model file start" + timingFlush let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath let headerE ← timePure "parse model header" (fun () => NfptPure.parseHeader data) diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean index 49006a1..eebeb29 100644 --- a/Nfp/IO/Timing.lean +++ b/Nfp/IO/Timing.lean @@ -19,6 +19,64 @@ def monoUsNow : IO Nat := do let t ← IO.monoNanosNow return t / 1000 +/-! Timing configuration -/ + +/-- Runtime configuration for timing output. -/ +structure TimingConfig where + /-- Optional stdout override for timing output. -/ + stdout? : Option Bool + /-- Optional heartbeat interval override (ms). -/ + heartbeatMs? : Option UInt32 + +initialize timingConfig : IO.Ref TimingConfig ← + IO.mkRef { stdout? := none, heartbeatMs? := none } + +/-- Enable or disable timing stdout output. -/ +def setTimingStdout (enabled : Bool) : IO Unit := do + timingConfig.modify (fun cfg => { cfg with stdout? := some enabled }) + +/-- Override the heartbeat interval (ms). -/ +def setTimingHeartbeatMs (ms : UInt32) : IO Unit := do + timingConfig.modify (fun cfg => { cfg with heartbeatMs? := some ms }) + +/-- Resolve whether timing output should be printed. -/ +def timingStdoutEnabled : IO Bool := do + let cfg ← timingConfig.get + match cfg.stdout? 
with + | some enabled => return enabled + | none => + match (← IO.getEnv "NFP_TIMING_STDOUT") with + | some "1" => return true + | some "true" => return true + | some "yes" => return true + | _ => return false + +/-- Resolve the heartbeat interval (ms), respecting overrides. -/ +def timingHeartbeatMs : IO UInt32 := do + let cfg ← timingConfig.get + match cfg.heartbeatMs? with + | some ms => return ms + | none => + let defaultMs : Nat := 0 + let ms := + (← IO.getEnv "NFP_TIMING_HEARTBEAT_MS").bind String.toNat? |>.getD defaultMs + return UInt32.ofNat ms + +/-- Print a timing line only when stdout timing is enabled. -/ +def timingPrint (line : String) : IO Unit := do + if (← timingStdoutEnabled) then + IO.println line + else + pure () + +/-- Flush stdout only when timing output is enabled. -/ +def timingFlush : IO Unit := do + if (← timingStdoutEnabled) then + let h ← IO.getStdout + h.flush + else + pure () + /-- Append a timing log line to `NFP_TIMING_LOG` when set. -/ def logTiming (line : String) : IO Unit := do match (← IO.getEnv "NFP_TIMING_LOG") with @@ -28,34 +86,34 @@ def logTiming (line : String) : IO Unit := do h.flush | none => pure () -/-- Time an IO phase and print the duration in microseconds. -/ +/-- Time an IO phase and print the duration when timing output is enabled. -/ def timePhase {α : Type} (label : String) (act : IO α) : IO α := do logTiming s!"start: {label}" let t0 ← monoUsNow let res ← act let t1 ← monoUsNow logTiming s!"done: {label} {t1 - t0} us" - IO.println s!"timing: {label} {t1 - t0} us" + timingPrint s!"timing: {label} {t1 - t0} us" return res -/-- Time an IO phase supplied as a thunk and print the duration in microseconds. -/ +/-- Time an IO phase supplied as a thunk and print the duration when timing output is enabled. 
-/ def timePhaseThunk {α : Type} (label : String) (act : Unit → IO α) : IO α := do logTiming s!"start: {label}" let t0 ← monoUsNow let res ← act () let t1 ← monoUsNow logTiming s!"done: {label} {t1 - t0} us" - IO.println s!"timing: {label} {t1 - t0} us" + timingPrint s!"timing: {label} {t1 - t0} us" return res -/-- Time a pure thunk and print the duration in microseconds. -/ +/-- Time a pure thunk and print the duration when timing output is enabled. -/ def timePure {α : Type} (label : String) (f : Unit → α) : IO α := do logTiming s!"start: {label}" let t0 ← monoUsNow let res := f () let t1 ← monoUsNow logTiming s!"done: {label} {t1 - t0} us" - IO.println s!"timing: {label} {t1 - t0} us" + timingPrint s!"timing: {label} {t1 - t0} us" return res /-- Flush stdout immediately for interleaved timing output. -/ @@ -66,7 +124,7 @@ def flushStdout : IO Unit := do /-- Measure task spawn/get overhead on this machine. -/ def taskBench (n : Nat) : IO Unit := do if n = 0 then - IO.println "timing: task bench skipped (n=0)" + timingPrint "timing: task bench skipped (n=0)" return let t0 ← monoUsNow let tasks := (List.range n).map (fun _ => Task.spawn (fun _ => ())) @@ -76,48 +134,48 @@ def taskBench (n : Nat) : IO Unit := do let t1 ← monoUsNow let total := t1 - t0 let avg := total / n - IO.println s!"timing: task bench n={n} total={total} us avg={avg} us" + timingPrint s!"timing: task bench n={n} total={total} us avg={avg} us" /-- Force a sample score-gap computation for timing. 
-/ def timeHeadScoreSampleGap {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do - IO.println "timing: head score sample gap start" - (← IO.getStdout).flush + timingPrint "timing: head score sample gap start" + timingFlush let t0 ← monoUsNow match List.finRange seq with | [] => - IO.println "timing: head score sample gap skipped (empty seq)" + timingPrint "timing: head score sample gap skipped (empty seq)" | q :: _ => let _ := score.scoreLo q (inputs.prev q) let _ := score.scoreHi q (inputs.prev q) let _ := score.scoreLo q (inputs.prev q) - score.scoreHi q (inputs.prev q) pure () let t1 ← monoUsNow - IO.println s!"timing: head score sample gap {t1 - t0} us" - (← IO.getStdout).flush + timingPrint s!"timing: head score sample gap {t1 - t0} us" + timingFlush /-- Force marginAt evaluation over the active list for timing. -/ def timeHeadScoreMarginList {seq dModel dHead : Nat} (activeList : List (Fin seq)) (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do - IO.println "timing: head score marginAt list start" - (← IO.getStdout).flush + timingPrint "timing: head score marginAt list start" + timingFlush let t0 ← monoUsNow for q in activeList do let _ := score.marginAt q pure () let t1 ← monoUsNow - IO.println s!"timing: head score marginAt list {t1 - t0} us" - (← IO.getStdout).flush + timingPrint s!"timing: head score marginAt list {t1 - t0} us" + timingFlush /-- Force marginAt evaluation without constructing the full score bounds record. 
-/ def timeHeadScoreMarginRaw {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (dotAbs : Fin seq → Fin seq → Rat) (activeList : List (Fin seq)) : IO Unit := do - IO.println "timing: head score marginRaw list start" - (← IO.getStdout).flush + timingPrint "timing: head score marginRaw list start" + timingFlush let t0 ← monoUsNow let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k @@ -164,29 +222,29 @@ def timeHeadScoreMarginRaw {seq dModel dHead : Nat} let _ := marginAtRaw q pure () let t1 ← monoUsNow - IO.println s!"timing: head score marginRaw list {t1 - t0} us" - (← IO.getStdout).flush + timingPrint s!"timing: head score marginRaw list {t1 - t0} us" + timingFlush /-- Force individual score-bound fields to locate slow evaluations. -/ def timeHeadScoreFieldForces {seq dModel dHead : Nat} (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do - IO.println "timing: head score field force start" - (← IO.getStdout).flush + timingPrint "timing: head score field force start" + timingFlush let timeOne (label : String) (f : Unit → IO Unit) : IO Unit := do let t0 ← monoUsNow f () let t1 ← monoUsNow - IO.println s!"timing: head score field {label} {t1 - t0} us" - (← IO.getStdout).flush + timingPrint s!"timing: head score field {label} {t1 - t0} us" + timingFlush match List.finRange seq with | [] => - IO.println "timing: head score field force skipped (empty seq)" - (← IO.getStdout).flush + timingPrint "timing: head score field force skipped (empty seq)" + timingFlush | q :: _ => match List.finRange seq with | [] => - IO.println "timing: head score field force skipped (empty seq)" - (← IO.getStdout).flush + timingPrint "timing: head score field force skipped (empty seq)" + timingFlush | k :: _ => timeOne "scoreBaseAbs" (fun _ => do let _ := score.scoreBaseAbs q k; pure ()) timeOne "scoreAbs" (fun _ => do let _ := score.scoreAbs q k; pure ()) @@ -196,8 +254,8 @@ def timeHeadScoreFieldForces {seq dModel 
dHead : Nat} timeOne "epsAt" (fun _ => do let _ := score.epsAt q; pure ()) timeOne "margin" (fun _ => do let _ := score.margin; pure ()) timeOne "eps" (fun _ => do let _ := score.eps; pure ()) - IO.println "timing: head score field force done" - (← IO.getStdout).flush + timingPrint "timing: head score field force done" + timingFlush end IO diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 1ea11e2..8649e11 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -477,10 +477,13 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} Task.spawn (fun _ => let dimsQ := splitDimsQ q ⟨Array.ofFn (fun k : Fin seq => - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), + if masked q k then + (0, 0) + else + let dimsK := splitDimsK q k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), by simp⟩)) let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index 4c8fbff..6b18354 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -250,10 +250,13 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} Task.spawn (fun _ => let dimsQ := splitDimsQ q ⟨Array.ofFn (fun k : Fin seq => - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), + if masked q k then + (0, 0) + else + let dimsK := splitDimsK q k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d 
=> kHi k d)), by simp⟩)) let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := Array.ofFn (fun q : Fin seq => @@ -654,7 +657,7 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} let base := (inputs.scale : Real) * dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - have hdot_bounds : + have hdot_bounds (hnot : ¬ masked q k) : (dotLo q k : Real) ≤ dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) ∧ @@ -682,12 +685,12 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} (dotLo q k : Real) ≤ dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) := by - simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn] + simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] using hspec.1 have hhigh' : dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn] + simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] using hspec.2 exact ⟨hlow', hhigh'⟩ have hscore_base_bounds (hnot : ¬ masked q k) : @@ -695,8 +698,9 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} by_cases hscale : 0 ≤ inputs.scale · have hscale_real : 0 ≤ (inputs.scale : Real) := ratToReal_nonneg_of_nonneg hscale - have hlow := mul_le_mul_of_nonneg_left hdot_bounds.1 hscale_real - have hhigh := mul_le_mul_of_nonneg_left hdot_bounds.2 hscale_real + have hdot := hdot_bounds hnot + have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real + have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real constructor · simpa [scoreLo, masked, hnot, hscale, base] using hlow · simpa [scoreHi, masked, hnot, hscale, base] using hhigh @@ -704,8 +708,9 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} le_of_lt (lt_of_not_ge hscale) have 
hscale_real : (inputs.scale : Real) ≤ 0 := (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hlow := mul_le_mul_of_nonpos_left hdot_bounds.2 hscale_real - have hhigh := mul_le_mul_of_nonpos_left hdot_bounds.1 hscale_real + have hdot := hdot_bounds hnot + have hlow := mul_le_mul_of_nonpos_left hdot.2 hscale_real + have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real constructor · simpa [scoreLo, masked, hnot, hscale, base] using hlow · simpa [scoreHi, masked, hnot, hscale, base] using hhigh From 4c03ba45fa265cb3a49b3284f68f295646fa635f Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 01:12:08 +0100 Subject: [PATCH 148/244] Add split-budget config for induction certs --- Nfp/Cli.lean | 66 ++++++++++++++++++ Nfp/IO/InductionHead.lean | 76 +++++++++++++++------ Nfp/IO/Timing.lean | 2 + Nfp/Sound/Induction/Core.lean | 102 ++++++++++++++++++++++++---- Nfp/Sound/Induction/CoreDefs.lean | 26 +++++++ Nfp/Sound/Induction/CoreSound.lean | 46 +++++++++---- Nfp/Sound/Induction/HeadOutput.lean | 14 +++- 7 files changed, 283 insertions(+), 49 deletions(-) diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index d2086f5..505acb9 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -178,8 +178,13 @@ def runInductionCertifyHead (p : Parsed) : IO UInt32 := do let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) let timing? := (p.flag? "timing").map (·.as! Nat) let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) + let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) + let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) + let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_nonvacuous` subcommand. 
-/ def runInductionCertifyHeadNonvacuous (p : Parsed) : IO UInt32 := do @@ -190,8 +195,13 @@ def runInductionCertifyHeadNonvacuous (p : Parsed) : IO UInt32 := do let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) let timing? := (p.flag? "timing").map (·.as! Nat) let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) + let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) + let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) + let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) IO.runInductionCertifyHeadNonvacuous inputsPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head` subcommand. -/ def inductionCertifyHeadCmd : Cmd := `[Cli| @@ -207,6 +217,12 @@ def inductionCertifyHeadCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." timing : Nat; "Emit timing output to stdout (0=off, 1=on)." "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." + "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." + "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." + "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ + (default: 0)." + "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ + bounds (default: 12)." ] /-- `nfp induction certify_head_nonvacuous` subcommand. -/ @@ -223,6 +239,12 @@ def inductionCertifyHeadNonvacuousCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." timing : Nat; "Emit timing output to stdout (0=off, 1=on)." "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." 
+ "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." + "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." + "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ + (default: 0)." + "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ + bounds (default: 12)." ] /-- `nfp induction certify_head_model` subcommand. -/ @@ -239,8 +261,13 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) let timing? := (p.flag? "timing").map (·.as! Nat) let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) + let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) + let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) + let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) IO.runInductionCertifyHeadModel modelPath layer head dirTarget dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do @@ -256,8 +283,13 @@ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) let timing? := (p.flag? "timing").map (·.as! Nat) let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) + let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) + let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) + let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! 
Nat) IO.runInductionCertifyHeadModelNonvacuous modelPath layer head dirTarget dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model_auto` subcommand. -/ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do @@ -271,8 +303,13 @@ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) let timing? := (p.flag? "timing").map (·.as! Nat) let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) + let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) + let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) + let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) IO.runInductionCertifyHeadModelAuto modelPath layer head period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do @@ -286,8 +323,13 @@ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) let timing? := (p.flag? "timing").map (·.as! Nat) let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) + let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) + let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) + let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer head period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? 
heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model` subcommand. -/ def inductionCertifyHeadModelCmd : Cmd := `[Cli| @@ -308,6 +350,12 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." timing : Nat; "Emit timing output to stdout (0=off, 1=on)." "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." + "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." + "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." + "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ + (default: 0)." + "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ + bounds (default: 12)." ] /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ @@ -329,6 +377,12 @@ def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." timing : Nat; "Emit timing output to stdout (0=off, 1=on)." "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." + "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." + "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." + "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ + (default: 0)." + "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ + bounds (default: 12)." ] /-- `nfp induction certify_head_model_auto` subcommand. -/ @@ -349,6 +403,12 @@ def inductionCertifyHeadModelAutoCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." timing : Nat; "Emit timing output to stdout (0=off, 1=on)." 
"heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." + "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." + "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." + "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ + (default: 0)." + "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ + bounds (default: 12)." ] /-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ @@ -369,6 +429,12 @@ def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." timing : Nat; "Emit timing output to stdout (0=off, 1=on)." "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." + "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." + "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." + "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ + (default: 0)." + "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ + bounds (default: 12)." ] /-- `nfp induction head_interval` subcommand. -/ diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index b0ec922..12cc169 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -37,6 +37,16 @@ private def configureTiming (timing? : Option Nat) (heartbeatMs? : Option Nat) : setTimingStdout true | none => pure () +private def splitConfigFromOptions + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
: Option Nat) : + Sound.InductionHeadSplitConfig := + let base := Sound.defaultInductionHeadSplitConfig + { base with + splitBudgetQ := splitBudgetQ?.getD base.splitBudgetQ + splitBudgetK := splitBudgetK?.getD base.splitBudgetK + splitBudgetDiffBase := splitBudgetDiffBase?.getD base.splitBudgetDiffBase + splitBudgetDiffRefined := splitBudgetDiffRefined?.getD base.splitBudgetDiffRefined } + open Nfp.Circuit private def valueBoundsModeFromEnv : IO (Option Bool) := do @@ -332,6 +342,7 @@ private def headScoreBoundsFromQAbsKAbsTimed {seq dModel dHead : Nat} [NeZero se private def checkInductionHeadInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) + (cfg : Sound.InductionHeadSplitConfig) (minActive? : Option Nat) (minLogitDiff? : Option Rat) (minMargin maxEps : Rat) : IO UInt32 := do match seq with @@ -720,7 +731,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let tValsHi1 ← monoUsNow timingPrint s!"timing: breakdown value valsHi force {tValsHi1 - tValsHi0} us" timingFlush - let heartbeatMs ← heartbeatMs + let heartbeatMsProgress ← heartbeatMs let taskMin (t1 t2 : Task Rat) : Task Rat := Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) let taskMax (t1 t2 : Task Rat) : Task Rat := @@ -744,11 +755,11 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} else let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) rest.foldl (fun acc i => taskMin acc (tasks.getD i defaultTask)) init) - if heartbeatMs ≠ 0 then + if heartbeatMsProgress ≠ 0 then let mut finished := 0 let mut remaining := chunkTasks.size while finished < remaining do - IO.sleep heartbeatMs + IO.sleep heartbeatMsProgress let mut count := 0 for t in chunkTasks do if (← IO.hasFinished t) then @@ -780,11 +791,11 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} else let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) rest.foldl (fun acc i => taskMax acc (tasks.getD i defaultTask)) init) 
- if heartbeatMs ≠ 0 then + if heartbeatMsProgress ≠ 0 then let mut finished := 0 let mut remaining := chunkTasks.size while finished < remaining do - IO.sleep heartbeatMs + IO.sleep heartbeatMsProgress let mut count := 0 for t in chunkTasks do if (← IO.hasFinished t) then @@ -861,7 +872,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} (Option { c : Sound.InductionHeadCert seq // Sound.InductionHeadCertSound inputs c }) := Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHead? inputs with + match Sound.buildInductionCertFromHeadWith? cfg inputs with | none => none | some ⟨cert, hcert⟩ => let _ := cert.active.card @@ -952,6 +963,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) + (cfg : Sound.InductionHeadSplitConfig) (minActive? : Option Nat) (minLogitDiff? : Option Rat) (minMargin maxEps : Rat) : IO UInt32 := do match seq with @@ -970,7 +982,7 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} (Option { c : Sound.InductionHeadCert seq // Sound.InductionHeadCertSound inputs c }) := Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHead? inputs with + match Sound.buildInductionCertFromHeadWith? cfg inputs with | none => none | some ⟨cert, hcert⟩ => let _ := cert.active.card @@ -1047,8 +1059,12 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} def runInductionCertifyHead (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do configureTiming timing? 
heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1072,14 +1088,18 @@ def runInductionCertifyHead (inputsPath : System.FilePath) IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps + checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps /-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1103,15 +1123,20 @@ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) IO.eprintln s!"error: {msg}" return 1 | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? minMargin maxEps + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? + minMargin maxEps /-- Build and check induction certificates from a model binary. 
-/ def runInductionCertifyHeadModel (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1147,15 +1172,19 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) IO.eprintln s!"error: {msg}" return 1 | Except.ok inputs => - checkInductionHeadInputs inputs minActive? minLogitDiff? minMargin maxEps + checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps /-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? 
let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1191,7 +1220,8 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) IO.eprintln s!"error: {msg}" return 1 | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? minMargin maxEps + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? + minMargin maxEps /-- Heuristic logit-diff direction derived from prompt tokens. -/ private def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : @@ -1227,8 +1257,12 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) (layer head : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1278,7 +1312,7 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) IO.eprintln s!"error: {msg}" return 1 | Except.ok inputs => - checkInductionHeadInputs inputs minActive? minLogitDiff? + checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps /-- Build and check a strictly positive induction logit-diff bound from a model binary, deriving @@ -1287,8 +1321,12 @@ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) (layer head : Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? 
: Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO UInt32 := do + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -1338,7 +1376,7 @@ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) IO.eprintln s!"error: {msg}" return 1 | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs minActive? minLogitDiff? + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? minMargin maxEps /-- Build head-output interval bounds from exact head inputs. -/ diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean index eebeb29..7eccf95 100644 --- a/Nfp/IO/Timing.lean +++ b/Nfp/IO/Timing.lean @@ -27,7 +27,9 @@ structure TimingConfig where stdout? : Option Bool /-- Optional heartbeat interval override (ms). -/ heartbeatMs? : Option UInt32 + deriving Inhabited +/-- Mutable timing configuration (overrides environment defaults). -/ initialize timingConfig : IO.Ref TimingConfig ← IO.mkRef { stdout? := none, heartbeatMs? := none } diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 8649e11..f5ef393 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -241,7 +241,8 @@ structure InductionHeadCoreCache (seq dModel dHead : Nat) where cert : InductionHeadCert seq /-- Build cached core quantities for induction-head certificates. 
-/ -def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} +def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) (inputs : Model.InductionHeadInputs seq dModel dHead) : InductionHeadCoreCache seq dModel dHead := by classical @@ -376,10 +377,10 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} simp [hsize]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 2 - let splitBudgetK : Nat := 2 - let splitBudgetDiffBase : Nat := 0 - let splitBudgetDiffRefined : Nat := 12 + let splitBudgetQ : Nat := cfg.splitBudgetQ + let splitBudgetK : Nat := cfg.splitBudgetK + let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase + let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined let splitDimsQ : Fin seq → List (Fin dHead) := fun q => let ambig := (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) @@ -758,6 +759,12 @@ def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} valCert := valCert cert := cert } +/-- Build cached core quantities for induction-head certificates using the default split budgets. -/ +def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + InductionHeadCoreCache seq dModel dHead := + buildInductionHeadCoreCacheWith defaultInductionHeadSplitConfig inputs + /-- The cached certificate is built from cache fields. -/ theorem buildInductionHeadCoreCache_cert_eq [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : @@ -771,7 +778,8 @@ theorem buildInductionHeadCoreCache_cert_eq [NeZero seq] {dModel dHead : Nat} values := (buildInductionHeadCoreCache inputs).valCert } := by rfl /-- Build induction certificates from exact head inputs (core computation). -/ -def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} +def buildInductionCertFromHeadCoreWith? 
[NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) (inputs : Model.InductionHeadInputs seq dModel dHead) : Option (InductionHeadCert seq) := by classical @@ -780,11 +788,70 @@ def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} · by_cases hmodel : dModel = 0 · exact none · by_cases hactive : inputs.active.Nonempty - · exact some (buildInductionHeadCoreCache inputs).cert + · exact some (buildInductionHeadCoreCacheWith cfg inputs).cert · exact none · exact none · exact none +/-- Build induction certificates from exact head inputs using the default split budgets. -/ +def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option (InductionHeadCert seq) := + buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs + +/-- `buildInductionCertFromHeadCoreWith?` succeeds under the guard conditions. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_some [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) (hactive : inputs.active.Nonempty) : + buildInductionCertFromHeadCoreWith? cfg inputs = + some (buildInductionHeadCoreCacheWith cfg inputs).cert := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] + +/-- `buildInductionCertFromHeadCoreWith?` fails when `dModel = 0`. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel = 0) : + buildInductionCertFromHeadCoreWith? 
cfg inputs = none := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel] + +/-- `buildInductionCertFromHeadCoreWith?` fails when `active` is empty. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_active + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : + buildInductionCertFromHeadCoreWith? cfg inputs = none := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] + +/-- `buildInductionCertFromHeadCoreWith?` fails when the sqrt lower bound is nonpositive. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : + buildInductionCertFromHeadCoreWith? cfg inputs = none := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt] + +/-- `buildInductionCertFromHeadCoreWith?` fails when `lnEps` is nonpositive. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : ¬0 < inputs.lnEps) : + buildInductionCertFromHeadCoreWith? cfg inputs = none := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps] + /-- `buildInductionCertFromHeadCore?` succeeds under the guard conditions. -/ theorem buildInductionCertFromHeadCore?_eq_some [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -793,7 +860,10 @@ theorem buildInductionCertFromHeadCore?_eq_some [NeZero seq] {dModel dHead : Nat buildInductionCertFromHeadCore? 
inputs = some (buildInductionHeadCoreCache inputs).cert := by classical - simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] + simpa [buildInductionCertFromHeadCore?, buildInductionHeadCoreCache] using + (buildInductionCertFromHeadCoreWith?_eq_some + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) + hEps hSqrt hmodel hactive) /-- `buildInductionCertFromHeadCore?` fails when `dModel = 0`. -/ theorem buildInductionCertFromHeadCore?_eq_none_of_model_eq_zero [NeZero seq] {dModel dHead : Nat} @@ -802,7 +872,9 @@ theorem buildInductionCertFromHeadCore?_eq_none_of_model_eq_zero [NeZero seq] {d (hmodel : dModel = 0) : buildInductionCertFromHeadCore? inputs = none := by classical - simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel] + simpa [buildInductionCertFromHeadCore?] using + (buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt hmodel) /-- `buildInductionCertFromHeadCore?` fails when `active` is empty. -/ theorem buildInductionCertFromHeadCore?_eq_none_of_not_active [NeZero seq] {dModel dHead : Nat} @@ -811,7 +883,9 @@ theorem buildInductionCertFromHeadCore?_eq_none_of_not_active [NeZero seq] {dMod (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : buildInductionCertFromHeadCore? inputs = none := by classical - simp [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] + simpa [buildInductionCertFromHeadCore?] using + (buildInductionCertFromHeadCoreWith?_eq_none_of_not_active + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt hmodel hactive) /-- `buildInductionCertFromHeadCore?` fails when the sqrt lower bound is nonpositive. 
-/ theorem buildInductionCertFromHeadCore?_eq_none_of_not_sqrt [NeZero seq] {dModel dHead : Nat} @@ -819,7 +893,9 @@ theorem buildInductionCertFromHeadCore?_eq_none_of_not_sqrt [NeZero seq] {dModel (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : buildInductionCertFromHeadCore? inputs = none := by classical - simp [buildInductionCertFromHeadCore?, hEps, hSqrt] + simpa [buildInductionCertFromHeadCore?] using + (buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt) /-- `buildInductionCertFromHeadCore?` fails when `lnEps` is nonpositive. -/ theorem buildInductionCertFromHeadCore?_eq_none_of_not_eps [NeZero seq] {dModel dHead : Nat} @@ -827,7 +903,9 @@ theorem buildInductionCertFromHeadCore?_eq_none_of_not_eps [NeZero seq] {dModel (hEps : ¬0 < inputs.lnEps) : buildInductionCertFromHeadCore? inputs = none := by classical - simp [buildInductionCertFromHeadCore?, hEps] + simpa [buildInductionCertFromHeadCore?] using + (buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps) end Sound end Nfp diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index fe0e217..9831e57 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -165,6 +165,32 @@ structure ValueIntervalBounds {seq : Nat} (vals : Fin seq → Real) /-- `hi` is above every upper bound. -/ valsHi_le_hi : ∀ k, (c.valsHi k : Real) ≤ (c.hi : Real) +/-- Split-budget knobs for sign-splitting bounds in induction-head certificates. -/ +structure InductionHeadSplitConfig where + /-- Split budget for query dims. -/ + splitBudgetQ : Nat + /-- Split budget for key dims. -/ + splitBudgetK : Nat + /-- Split budget for base diff dims. -/ + splitBudgetDiffBase : Nat + /-- Split budget for refined diff dims. -/ + splitBudgetDiffRefined : Nat + +/-- Default split budgets for induction-head sign-splitting bounds. 
-/ +def defaultInductionHeadSplitConfig : InductionHeadSplitConfig := + { splitBudgetQ := 2 + splitBudgetK := 2 + splitBudgetDiffBase := 0 + splitBudgetDiffRefined := 12 } + +/-- Unfolding lemma for `defaultInductionHeadSplitConfig`. -/ +theorem defaultInductionHeadSplitConfig_def : + defaultInductionHeadSplitConfig = + { splitBudgetQ := 2 + splitBudgetK := 2 + splitBudgetDiffBase := 0 + splitBudgetDiffRefined := 12 } := rfl + /-- Sound induction-certificate payload built from exact head inputs. -/ structure InductionHeadCert (seq : Nat) where /-- Weight tolerance. -/ diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index 6b18354..b6e3992 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -11,10 +11,11 @@ set_option maxHeartbeats 5000000 in -- The soundness proof expands many cached bounds; extra heartbeats avoid spurious timeouts. set_option synthInstance.maxHeartbeats 200000 in -- Instance search also touches the expanded caches; allow more room to avoid timeouts. -/-- Soundness for `buildInductionCertFromHeadCore?`. -/ -theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} +/-- Soundness for `buildInductionCertFromHeadCoreWith?`. -/ +theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) - (hcore : buildInductionCertFromHeadCore? inputs = some c) : + (hcore : buildInductionCertFromHeadCoreWith? 
cfg inputs = some c) : InductionHeadCertSound inputs c := by classical by_cases hEps : 0 < inputs.lnEps @@ -22,8 +23,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} · by_cases hmodel : dModel = 0 · have : False := by have hnone := - buildInductionCertFromHeadCore?_eq_none_of_model_eq_zero - (inputs := inputs) hEps hSqrt hmodel + buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel have hcore' : (none : Option (InductionHeadCert seq)) = some c := by exact hnone.symm.trans hcore @@ -167,10 +168,10 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} simp [kAbsMaxArr]) let masked : Fin seq → Fin seq → Prop := fun q k => inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := 2 - let splitBudgetK : Nat := 2 - let splitBudgetDiffBase : Nat := 0 - let splitBudgetDiffRefined : Nat := 12 + let splitBudgetQ : Nat := cfg.splitBudgetQ + let splitBudgetK : Nat := cfg.splitBudgetK + let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase + let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined let top2ByScore : (Fin dHead → Rat) → List (Fin dHead) → List (Fin dHead) := fun score ambig => let step @@ -458,9 +459,9 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} active := inputs.active prev := inputs.prev values := valCert } - have hcore' : buildInductionCertFromHeadCore? inputs = some cert := by + have hcore' : buildInductionCertFromHeadCoreWith? 
cfg inputs = some cert := by simp (config := { zeta := false }) only - [buildInductionCertFromHeadCore?, hEps, hSqrt, hmodel, hactive] + [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] rfl have hc : c = cert := by have hcert : cert = c := by @@ -1254,8 +1255,8 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} value_bounds := hvals_bounds } · have : False := by have hnone := - buildInductionCertFromHeadCore?_eq_none_of_not_active - (inputs := inputs) hEps hSqrt hmodel hactive + buildInductionCertFromHeadCoreWith?_eq_none_of_not_active + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive have hcore' : (none : Option (InductionHeadCert seq)) = some c := by exact hnone.symm.trans hcore @@ -1263,18 +1264,33 @@ theorem buildInductionCertFromHeadCore?_sound [NeZero seq] {dModel dHead : Nat} exact this.elim · have : False := by have hnone := - buildInductionCertFromHeadCore?_eq_none_of_not_sqrt (inputs := inputs) hEps hSqrt + buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt + (cfg := cfg) (inputs := inputs) hEps hSqrt have hcore' : (none : Option (InductionHeadCert seq)) = some c := by exact hnone.symm.trans hcore cases hcore' exact this.elim · have : False := by - have hnone := buildInductionCertFromHeadCore?_eq_none_of_not_eps (inputs := inputs) hEps + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps + (cfg := cfg) (inputs := inputs) hEps have hcore' : (none : Option (InductionHeadCert seq)) = some c := by exact hnone.symm.trans hcore cases hcore' exact this.elim + +/-- Soundness for `buildInductionCertFromHeadCore?`. -/ +theorem buildInductionCertFromHeadCore?_sound + [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) + (hcore : buildInductionCertFromHeadCore? inputs = some c) : + InductionHeadCertSound inputs c := by + simpa [buildInductionCertFromHeadCore?] 
using + (buildInductionCertFromHeadCoreWith?_sound + (cfg := defaultInductionHeadSplitConfig) inputs c + (by + simpa [buildInductionCertFromHeadCore?] using hcore)) end Sound end Nfp diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index d67341a..e6e15dc 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -19,14 +19,22 @@ open Nfp.Sound.Bounds variable {seq : Nat} /-- Build and certify induction certificates from exact head inputs. -/ -def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} +def buildInductionCertFromHeadWith? [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) (inputs : Model.InductionHeadInputs seq dModel dHead) : Option {c : InductionHeadCert seq // InductionHeadCertSound inputs c} := by classical - cases hcore : buildInductionCertFromHeadCore? inputs with + cases hcore : buildInductionCertFromHeadCoreWith? cfg inputs with | none => exact none | some c => - exact some ⟨c, buildInductionCertFromHeadCore?_sound inputs c hcore⟩ + exact some ⟨c, buildInductionCertFromHeadCoreWith?_sound (cfg := cfg) inputs c hcore⟩ + +/-- Build and certify induction certificates from exact head inputs using the default split +budgets. -/ +def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option {c : InductionHeadCert seq // InductionHeadCertSound inputs c} := + buildInductionCertFromHeadWith? 
defaultInductionHeadSplitConfig inputs section HeadOutputInterval From 4361a71b5801da10a8048245e0319949e5f3f8f6 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 04:05:44 +0100 Subject: [PATCH 149/244] Optimize induction head certification --- Nfp/IO/InductionHead.lean | 83 ++++++++--- Nfp/Sound/Induction/Core.lean | 183 ++++++++++++----------- Nfp/Sound/Induction/CoreSound.lean | 107 ++++++++------ Nfp/Sound/Induction/HeadBounds.lean | 152 +++++++++++++++++++ Nfp/Sound/Induction/LogitDiff.lean | 220 ++++++++++++++++++++++++++++ 5 files changed, 591 insertions(+), 154 deletions(-) diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index 12cc169..724a751 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -965,7 +965,7 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (cfg : Sound.InductionHeadSplitConfig) (minActive? : Option Nat) (minLogitDiff? : Option Rat) - (minMargin maxEps : Rat) : IO UInt32 := do + (minMargin? : Option Rat) (maxEps : Rat) : IO UInt32 := do match seq with | 0 => IO.eprintln "error: seq must be positive" @@ -1015,20 +1015,58 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} IO.eprintln s!"error: active queries {activeCount} below minimum {minActive}" return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {ratToString cert.margin} \ - below minimum {ratToString minMargin}" - return 2 if maxEps < cert.eps then IO.eprintln s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" return 2 + let marginViolation? : Option Rat := + match minMargin? with + | none => none + | some minMargin => + if cert.margin < minMargin then + some minMargin + else + none + match marginViolation? 
with + | some minMargin => + IO.eprintln + s!"error: margin {ratToString cert.margin} \ + below minimum {ratToString minMargin}" + return 2 + | none => pure () logTiming "start: head logit-diff lower bound" timingPrint "timing: head logit-diff lower bound start" timingFlush - let logitDiffLB? ← timePure "head: logit-diff lower bound" (fun () => + let logitDiffLB0? ← timePure "head: logit-diff lower bound" (fun () => Sound.logitDiffLowerBoundFromCert cert) + let needsWeighted : Bool := + match logitDiffLB0? with + | none => true + | some lb0 => + if lb0 ≤ 0 then + true + else + match minLogitDiff? with + | some minLogitDiff => lb0 < minLogitDiff + | none => false + let logitDiffWeighted? ← + if needsWeighted then + timePure "head: logit-diff lower bound weighted" (fun () => + Sound.logitDiffLowerBoundFromCertWeighted cert) + else + pure none + let logitDiffLB? : Option Rat := + match logitDiffLB0?, logitDiffWeighted? with + | some lb0, some lb1 => some (max lb0 lb1) + | some lb0, none => some lb0 + | none, some lb1 => some lb1 + | none, none => none + let boundLabel : String := + match logitDiffLB0?, logitDiffWeighted? with + | some _, some _ => "max" + | none, some _ => "weighted" + | some _, none => "eps" + | none, none => "none" logTiming "done: head logit-diff lower bound" match logitDiffLB? with | none => @@ -1040,19 +1078,27 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} s!"error: logitDiffLB {ratToString logitDiffLB} \ is not strictly positive" return 2 - match minLogitDiff? with + let violation? : Option Rat := + match minLogitDiff? with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? 
with | some minLogitDiff => - if logitDiffLB < minLogitDiff then - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - below minimum {ratToString minLogitDiff}" - return 2 + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" + return 2 | none => pure () let tol := cert.eps * (cert.values.hi - cert.values.lo) IO.println s!"ok: nonvacuous induction bound certified \ (seq={seq}, active={activeCount}, \ - tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB})" + tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB}, \ + bound={boundLabel})" return 0 /-- Build and check induction certificates from exact head inputs. -/ @@ -1114,7 +1160,6 @@ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) let parsedInputs ← timePhase "load head inputs" <| loadInductionHeadInputs inputsPath @@ -1124,7 +1169,7 @@ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) return 1 | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin maxEps + minMargin? maxEps /-- Build and check induction certificates from a model binary. -/ def runInductionCertifyHeadModel (modelPath : System.FilePath) @@ -1199,7 +1244,6 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> - let minMargin := minMargin?.getD (0 : Rat) let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) logTiming "start: read model file" timingPrint "timing: read model file start" @@ -1221,7 +1265,7 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) return 1 | Except.ok inputs => checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin maxEps + minMargin? maxEps /-- Heuristic logit-diff direction derived from prompt tokens. -/ private def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : @@ -1341,7 +1385,6 @@ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) IO.eprintln s!"error: {msg}" return 2 | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Rat) let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) logTiming "start: read model file" timingPrint "timing: read model file start" @@ -1377,7 +1420,7 @@ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) return 1 | Except.ok inputs => checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin maxEps + minMargin? maxEps /-- Build head-output interval bounds from exact head inputs. 
-/ def runInductionHeadInterval (inputsPath : System.FilePath) diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index f5ef393..e481c48 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -381,94 +381,105 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} let splitBudgetK : Nat := cfg.splitBudgetK let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined + let finRangeHead : List (Fin dHead) := List.finRange dHead + let finRangeSeq : List (Fin seq) := List.finRange seq let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - let ambig := - (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ + if splitBudgetQ = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin 
dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetQ let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - let ambig := - (List.finRange dHead).filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK + if splitBudgetK = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide 
(kLo k d < 0 ∧ 0 < kHi k d)) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetK let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => - let prev := inputs.prev q - let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d - let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d - let ambig := - (List.finRange dHead).filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) - let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := 
fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 : List (Fin dHead) := top2 ambig - let dims2 : List (Fin dHead) := - top2 (ambig.filter (fun d => decide (d ∉ dims1))) - let memDims2 : Fin dHead → Bool := fun d => - dims2.any (fun d' => decide (d' = d)) - let dims3 : List (Fin dHead) := - top2 - ((ambig.filter (fun d => decide (d ∉ dims1))).filter - (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take budget + if budget = 0 then + [] + else + let prev := inputs.prev q + let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d + let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d + let ambig := + finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) + let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 : List (Fin dHead) := top2 ambig + let dims2 : List (Fin dHead) := + top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let memDims2 : Fin dHead → Bool := fun d => + dims2.any (fun d' => decide (d' = d)) + let dims3 : List (Fin dHead) := + top2 + ((ambig.filter (fun d => decide (d ∉ dims1))).filter + (fun d => !memDims2 d)) + (dims1 ++ dims2 ++ dims3).take 
budget let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := splitDimsDiffCore splitBudgetDiffBase let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := @@ -564,7 +575,7 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => if hq : q ∈ inputs.active then - let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) match ks with | [] => none | k :: ks => diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index b6e3992..0dc5c6d 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -192,56 +192,67 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N | (some b1, some b2) => [b1.2, b2.2] | (some b1, none) => [b1.2] | (none, _) => [] + let finRangeHead : List (Fin dHead) := List.finRange dHead + let finRangeSeq : List (Fin seq) := List.finRange seq let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - let ambig := - (List.finRange dHead).filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let dims1 := top2ByScore score ambig - let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ + if splitBudgetQ = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let dims1 := top2ByScore score ambig + let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetQ let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - let ambig := - (List.finRange dHead).filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) - let score : Fin 
dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let dims1 := top2ByScore score ambig - let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK + if splitBudgetK = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d + let dims1 := top2ByScore score ambig + let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetK let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => - let prev := inputs.prev q - let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d - let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d - let ambig := - (List.finRange dHead).filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) - let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 : List (Fin dHead) := top2 ambig - let dims2 : List (Fin dHead) := - top2 (ambig.filter (fun d => decide (d ∉ dims1))) - let memDims2 : Fin dHead → Bool := fun d => - dims2.any (fun d' => decide (d' = d)) - let dims3 : List (Fin dHead) := - top2 - ((ambig.filter 
(fun d => decide (d ∉ dims1))).filter - (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take budget + if budget = 0 then + [] + else + let prev := inputs.prev q + let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d + let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d + let ambig := + finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) + let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 : List (Fin dHead) := top2 ambig + let dims2 : List (Fin dHead) := + top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let memDims2 : Fin dHead → Bool := fun d => + dims2.any (fun d' => decide (d' = d)) + let dims3 : List (Fin dHead) := + top2 + ((ambig.filter (fun d => decide (d ∉ dims1))).filter + (fun d => !memDims2 d)) + (dims1 ++ dims2 ++ dims3).take budget let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := splitDimsDiffCore splitBudgetDiffBase let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := @@ -337,7 +348,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => if hq : q ∈ inputs.active then - 
let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) match ks with | [] => none | k :: ks => diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean index 11f529c..037ed18 100644 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -30,6 +30,63 @@ private def taskMin (t1 t2 : Task Rat) : Task Rat := private def taskMax (t1 t2 : Task Rat) : Task Rat := Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) +/-! Small lemmas for extracting `get` from task folds. -/ + +/-- `taskMin` exposes its `get` as a plain `min` on task results. -/ +private theorem taskMin_get (t1 t2 : Task Rat) : + (taskMin t1 t2).get = min t1.get t2.get := by + rfl + +/-- `taskMax` exposes its `get` as a plain `max` on task results. -/ +private theorem taskMax_get (t1 t2 : Task Rat) : + (taskMax t1 t2).get = max t1.get t2.get := by + rfl + +/-- Pull `get` through a `List.foldl` when the step is `get`-compatible. -/ +private theorem foldl_task_get_eq {α β : Type} (step : Task β → α → Task β) (step' : β → α → β) + (hstep : ∀ acc a, (step acc a).get = step' acc.get a) : + ∀ (xs : List α) (acc : Task β), + (List.foldl step acc xs).get = List.foldl step' acc.get xs + | [], acc => rfl + | x :: xs, acc => by + simpa [List.foldl, hstep] using foldl_task_get_eq step step' hstep xs (step acc x) + +/-- `List.foldl` over `taskMin` exposes a fold over `min` on task results. 
-/ +private theorem foldl_taskMin_get_eq {α : Type} (f : α → Task Rat) (xs : List α) + (init : Task Rat) : + (List.foldl (fun acc a => taskMin acc (f a)) init xs).get = + List.foldl (fun acc a => min acc (f a).get) init.get xs := by + refine + foldl_task_get_eq + (step := fun acc a => taskMin acc (f a)) + (step' := fun acc a => min acc (f a).get) + (hstep := ?_) + xs init + intro acc a + simp [taskMin_get] + +/-- `List.foldl` over `taskMax` exposes a fold over `max` on task results. -/ +private theorem foldl_taskMax_get_eq {α : Type} (f : α → Task Rat) (xs : List α) + (init : Task Rat) : + (List.foldl (fun acc a => taskMax acc (f a)) init xs).get = + List.foldl (fun acc a => max acc (f a).get) init.get xs := by + refine + foldl_task_get_eq + (step := fun acc a => taskMax acc (f a)) + (step' := fun acc a => max acc (f a).get) + (hstep := ?_) + xs init + intro acc a + simp [taskMax_get] + +/-- `Array.get?` + `Option.getD` followed by `Task.get` agrees with `getD` on values. -/ +private theorem task_getD_ofFn {n : Nat} (f : Fin n → Rat) (i : Nat) : + ((Array.ofFn fun c => ({ get := f c } : Task Rat))[i]?.getD { get := (0 : Rat) }).get = + (Array.ofFn f)[i]?.getD (0 : Rat) := by + by_cases h : i < n + · simp [h, Array.size_ofFn] + · simp [h, Array.size_ofFn] + /-! Helpers for reducing cached arrays without extra allocation. -/ /-- Reduce an array of rational bounds to its minimum (defaulting to `0` on empty arrays). -/ @@ -42,6 +99,59 @@ private def reduceMaxArray (arr : Array Rat) : Rat := let init := arr.getD 0 (0 : Rat) arr.foldl (fun acc x => max acc x) init +/-- Reduce a `Fin seq`-indexed function using the chunked sequential algorithm. 
-/ +private def reduceFnChunked [NeZero seq] (vals : Fin seq → Rat) + (combine : Rat → Rat → Rat) : Rat := + let n := seq + if n = 0 then + (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let chunkVals : Array Rat := + Array.ofFn (fun c : Fin chunks => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init) + let init := chunkVals.getD 0 (0 : Rat) + let rest := (List.range (chunkVals.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init + +/-- Unfold `reduceFnChunked` to its chunked sequential definition. 
-/ +theorem reduceFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) + (combine : Rat → Rat → Rat) : + reduceFnChunked (seq := seq) vals combine = + let n := seq + if n = 0 then + (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let chunkVals : Array Rat := + Array.ofFn (fun c : Fin chunks => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init) + let init := chunkVals.getD 0 (0 : Rat) + let rest := (List.range (chunkVals.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init := rfl + /-- Reduce a `Fin seq`-indexed function in parallel using chunked tasks. -/ private def reduceFnTask [NeZero seq] (vals : Fin seq → Rat) (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : Task Rat := @@ -105,6 +215,38 @@ private def reduceMinFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := private def reduceMaxFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := reduceFnTask vals max taskMax +/-- Chunked sequential minimum over a `Fin seq`-indexed function. -/ +private def reduceMinFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := + reduceFnChunked vals min + +/-- Unfold `reduceMinFnChunked` to `reduceFnChunked` with `min`. -/ +theorem reduceMinFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : + reduceMinFnChunked vals = reduceFnChunked vals min := rfl + +/-- Chunked sequential maximum over a `Fin seq`-indexed function. 
-/ +private def reduceMaxFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := + reduceFnChunked vals max + +/-- Unfold `reduceMaxFnChunked` to `reduceFnChunked` with `max`. -/ +theorem reduceMaxFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : + reduceMaxFnChunked vals = reduceFnChunked vals max := rfl + +/-- The chunked parallel min-reduction task returns the sequential chunked result. -/ +theorem reduceMinFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : + (reduceMinFnTask vals).get = reduceMinFnChunked vals := by + classical + have hseq : seq ≠ 0 := NeZero.ne (n := seq) + simp [reduceMinFnTask, reduceMinFnChunked, reduceFnTask, reduceFnChunked, hseq, + Task.spawn, foldl_taskMin_get_eq, task_getD_ofFn] + +/-- The chunked parallel max-reduction task returns the sequential chunked result. -/ +theorem reduceMaxFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : + (reduceMaxFnTask vals).get = reduceMaxFnChunked vals := by + classical + have hseq : seq ≠ 0 := NeZero.ne (n := seq) + simp [reduceMaxFnTask, reduceMaxFnChunked, reduceFnTask, reduceFnChunked, hseq, + Task.spawn, foldl_taskMax_get_eq, task_getD_ofFn] + /-- Cached direction head for head inputs. -/ private def dirHeadVecOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := @@ -937,6 +1079,11 @@ def headValueLoTask [NeZero seq] (valsLo : Fin seq → Rat) : Task Rat := theorem headValueLoTask_spec [NeZero seq] (valsLo : Fin seq → Rat) : headValueLoTask valsLo = reduceMinFnTask valsLo := rfl +/-- Chunked task reduction agrees with the sequential chunked value bound. -/ +theorem headValueLoTask_get_eq [NeZero seq] (valsLo : Fin seq → Rat) : + (headValueLoTask valsLo).get = reduceMinFnChunked valsLo := by + simp [headValueLoTask_spec, reduceMinFnTask_get_eq] + /-- Global upper value bound from an array of per-key values. 
-/ def headValueHiArray (valsHi : Array Rat) : Rat := reduceMaxArray valsHi @@ -959,6 +1106,11 @@ def headValueHiTask [NeZero seq] (valsHi : Fin seq → Rat) : Task Rat := theorem headValueHiTask_spec [NeZero seq] (valsHi : Fin seq → Rat) : headValueHiTask valsHi = reduceMaxFnTask valsHi := rfl +/-- Chunked task reduction agrees with the sequential chunked value bound. -/ +theorem headValueHiTask_get_eq [NeZero seq] (valsHi : Fin seq → Rat) : + (headValueHiTask valsHi).get = reduceMaxFnChunked valsHi := by + simp [headValueHiTask_spec, reduceMaxFnTask_get_eq] + /-- Build `HeadValueBounds` from precomputed arrays. -/ private def headValueBoundsOfArrays {seq dModel dHead : Nat} (valsLoArr valsHiArr : Array Rat) : HeadValueBounds seq dModel dHead := diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index bd1a6ba..8db9932 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -34,6 +34,10 @@ def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := Circuit.logitDiffLowerBoundAtLo c.active c.prev c.epsAt c.values.lo c.values.valsLo +/-- Lower bound computed from per-key weight bounds in an induction certificate. -/ +def logitDiffLowerBoundFromCertWeighted (c : InductionHeadCert seq) : Option Rat := + Circuit.logitDiffLowerBoundWeightedAt c.active c.prev c.weightBoundAt c.values.valsLo + theorem logitDiffLowerBoundFromCert_le (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) @@ -185,6 +189,222 @@ theorem logitDiffLowerBoundFromCert_le le_trans hboundReal hdot_lower simpa [headLogitDiff, weights, vals] using hle +/-- The weighted per-key logit-diff lower bound is sound on active queries. 
-/ +theorem logitDiffLowerBoundFromCertWeighted_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) + {lb : Rat} (hbound : logitDiffLowerBoundFromCertWeighted c = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let valsLoPrevRat : Rat := c.values.valsLo (c.prev q) + let valsLoPrev : Real := (valsLoPrevRat : Real) + have hboundRat : + lb ≤ valsLoPrevRat - + (others.sum (fun k => + c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - c.values.valsLo k))) := by + refine + Circuit.logitDiffLowerBoundWeightedAt_le + (active := c.active) + (prev := c.prev) + (weightBoundAt := c.weightBoundAt) + (valsLo := c.values.valsLo) + q hq lb ?_ + simpa [logitDiffLowerBoundFromCertWeighted] using hbound + have hboundReal : + (lb : Real) ≤ + valsLoPrev - + (others.sum (fun k => + (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)))) := by + simpa [valsLoPrevRat, valsLoPrev, ratToReal_sub, ratToReal_mul, ratToReal_max, + ratToReal, Rat.cast_sum] using ratToReal_le_of_le hboundRat + have hweights_nonneg : ∀ k, 0 ≤ weights q k := + hsound.softmax_bounds.nonneg q hq + have hweights := hsound.oneHot_bounds_at q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hweights.sum_one q rfl + have hvalsLo_prev : 
valsLoPrev ≤ vals (c.prev q) := by + exact (hsound.value_bounds.vals_bounds (c.prev q)).1 + have hvals_lower : + ∀ k, + valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ vals k := by + intro k + by_cases hle : valsLoPrev ≤ (c.values.valsLo k : Real) + · have hdiff : valsLoPrev - (c.values.valsLo k : Real) ≤ 0 := sub_nonpos.mpr hle + have hmax : + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) = 0 := by + simp [hdiff] + have hvals_lo : (c.values.valsLo k : Real) ≤ vals k := + (hsound.value_bounds.vals_bounds k).1 + have hvals_prev : valsLoPrev ≤ vals k := le_trans hle hvals_lo + simpa [hmax] using hvals_prev + · have hlt : (c.values.valsLo k : Real) < valsLoPrev := lt_of_not_ge hle + have hdiff_nonneg : 0 ≤ valsLoPrev - (c.values.valsLo k : Real) := + le_of_lt (sub_pos.mpr hlt) + have hmax : + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) = + valsLoPrev - (c.values.valsLo k : Real) := by + simp [hdiff_nonneg] + have hvals_lo : (c.values.valsLo k : Real) ≤ vals k := + (hsound.value_bounds.vals_bounds k).1 + simpa [hmax] using hvals_lo + have hsum_vals_ge : + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) ≤ + ∑ k ∈ others, weights q k * vals k := by + refine Finset.sum_le_sum ?_ + intro k hk + exact + mul_le_mul_of_nonneg_left + (hvals_lower k) + (hweights_nonneg k) + have hsum_prod : + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k = + ∑ k, weights q k * vals k := by + simp [others] + have hout_eq : + dotProduct (weights q) vals = + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [dotProduct] using hsum_prod.symm + have hprev_lo : + weights q (c.prev q) * valsLoPrev ≤ + weights q (c.prev q) * vals (c.prev q) := by + exact mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) + have hdot_ge' : + weights q (c.prev q) * valsLoPrev + + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : 
Real) (valsLoPrev - (c.values.valsLo k : Real))) ≤ + dotProduct (weights q) vals := by + have hle : + weights q (c.prev q) * valsLoPrev + + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) ≤ + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + exact add_le_add hprev_lo hsum_vals_ge + simpa [hout_eq, add_comm, add_left_comm, add_assoc] using hle + have hsum_weights : + (∑ k ∈ others, weights q k * valsLoPrev) = + (∑ k ∈ others, weights q k) * valsLoPrev := by + have hsum_mul : + (∑ k ∈ others, weights q k) * valsLoPrev = + ∑ k ∈ others, weights q k * valsLoPrev := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := valsLoPrev)) + exact hsum_mul.symm + have hsum_split : + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) = + (∑ k ∈ others, weights q k) * valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + calc + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) = + ∑ k ∈ others, + (weights q k * valsLoPrev - + weights q k * max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) := by + simp [mul_sub] + _ = + (∑ k ∈ others, weights q k * valsLoPrev) - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + simp [Finset.sum_sub_distrib] + _ = + (∑ k ∈ others, weights q k) * valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + simp [hsum_weights] + have hsplit : + weights q (c.prev q) * valsLoPrev + + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) = + valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + calc + weights q (c.prev q) * valsLoPrev + + ∑ k ∈ others, weights q k * + 
(valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) = + weights q (c.prev q) * valsLoPrev + + ((∑ k ∈ others, weights q k) * valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) := by + simp [hsum_split] + _ = + (weights q (c.prev q) + ∑ k ∈ others, weights q k) * valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + ring + _ = + valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + simp [hsum] + have hweight_bound : + ∀ k ∈ others, weights q k ≤ (c.weightBoundAt q k : Real) := by + intro k hk + have hk' : k ≠ c.prev q := (Finset.mem_erase.mp hk).1 + exact hsound.weight_bounds_at q hq k hk' + have hsum_gap_le : + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ + ∑ k ∈ others, (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + refine Finset.sum_le_sum ?_ + intro k hk + have hnonneg : + 0 ≤ max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := + le_max_left _ _ + exact mul_le_mul_of_nonneg_right (hweight_bound k hk) hnonneg + have hsub_le : + valsLoPrev - + ∑ k ∈ others, (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ + valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + exact sub_le_sub_left hsum_gap_le valsLoPrev + have hdot_lower : + valsLoPrev - + ∑ k ∈ others, (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - + ∑ k ∈ others, (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ + valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := hsub_le + _ = + weights q (c.prev q) * valsLoPrev + + ∑ k ∈ others, 
weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) := by + simpa using hsplit.symm + _ ≤ dotProduct (weights q) vals := hdot_ge' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal hdot_lower + simpa [headLogitDiff, weights, vals] using hle + /-- Certified logit-diff lower bound derived from exact head inputs. -/ structure InductionLogitLowerBoundResult (inputs : Model.InductionHeadInputs seq dModel dHead) where From e1d85177307cdaeab0374cd578a66723a9958c72 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 08:00:49 +0100 Subject: [PATCH 150/244] Streamline induction CLI --- Nfp/Cli.lean | 221 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 217 insertions(+), 4 deletions(-) diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 505acb9..1b44f35 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -25,6 +25,208 @@ def versionCmd : Cmd := `[Cli| "Print the NFP version." ] +private def parseDirectionSpec (raw : String) : Except String (Nat × Nat) := do + let partsComma := raw.splitOn "," + let parts := if partsComma.length = 2 then partsComma else raw.splitOn ":" + match parts with + | [targetRaw, negativeRaw] => + match targetRaw.toNat?, negativeRaw.toNat? 
with + | some target, some negative => pure (target, negative) + | _, _ => throw s!"direction must be two natural numbers (got '{raw}')" + | _ => + throw s!"direction must look like \"target,negative\" (got '{raw}')" + +private def parseSplitPreset (raw : String) : + Except String (Option Nat × Option Nat × Option Nat × Option Nat) := do + let key := raw.toLower + match key with + | "balanced" | "default" => pure (none, none, none, none) + | "fast" => pure (some 0, some 0, some 0, some 0) + | "tight" => pure (some 4, some 4, some 2, some 16) + | _ => + throw s!"unknown preset '{raw}' (expected: fast, balanced, tight)" + +private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : IO UInt32 := do + let inputsPath? := (p.flag? "inputs").map (·.as! String) + let modelPath? := (p.flag? "model").map (·.as! String) + let layer? := (p.flag? "layer").map (·.as! Nat) + let head? := (p.flag? "head").map (·.as! Nat) + let period? := (p.flag? "period").map (·.as! Nat) + let directionStr? := (p.flag? "direction").map (·.as! String) + let presetStr? := (p.flag? "preset").map (·.as! String) + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let timing? := (p.flag? "timing").map (·.as! Nat) + let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let fail (msg : String) : IO UInt32 := do + IO.eprintln s!"error: {msg}" + return 2 + let presetE := + match presetStr? with + | none => Except.ok (none, none, none, none) + | some raw => parseSplitPreset raw + let directionE := + match directionStr? 
with + | none => Except.ok none + | some raw => (parseDirectionSpec raw).map some + match presetE, directionE with + | Except.error msg, _ => fail msg + | _, Except.error msg => fail msg + | Except.ok ⟨splitBudgetQ?, splitBudgetK?, splitBudgetDiffBase?, splitBudgetDiffRefined?⟩, + Except.ok direction? => + match inputsPath?, modelPath? with + | some inputsPath, none => + if layer?.isSome || head?.isSome || period?.isSome then + fail "--layer/--head/--period are only valid with --model" + else if direction?.isSome then + fail "--direction is only valid with --model" + else if requireNonvacuous then + IO.runInductionCertifyHeadNonvacuous inputsPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + else + IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + | none, some modelPath => + match layer?, head? with + | some layer, some head => + match direction? with + | some ⟨dirTarget, dirNegative⟩ => + if requireNonvacuous then + IO.runInductionCertifyHeadModelNonvacuous modelPath layer head dirTarget + dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + else + IO.runInductionCertifyHeadModel modelPath layer head dirTarget dirNegative + period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + | none => + if requireNonvacuous then + IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer head period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
+ else + IO.runInductionCertifyHeadModelAuto modelPath layer head period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + | _, _ => + fail "--layer and --head are required with --model" + | none, none => + fail "provide exactly one of --inputs or --model" + | some _, some _ => + fail "provide exactly one of --inputs or --model" + +private def runInductionCertifySimple (p : Parsed) : IO UInt32 := + runInductionCertifyUnified false p + +private def runInductionCertifyNonvacuousSimple (p : Parsed) : IO UInt32 := + runInductionCertifyUnified true p + +private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do + let inputsPath? := (p.flag? "inputs").map (·.as! String) + let modelPath? := (p.flag? "model").map (·.as! String) + let layer? := (p.flag? "layer").map (·.as! Nat) + let head? := (p.flag? "head").map (·.as! Nat) + let period? := (p.flag? "period").map (·.as! Nat) + let directionStr? := (p.flag? "direction").map (·.as! String) + let outPath? := (p.flag? "out").map (·.as! String) + let fail (msg : String) : IO UInt32 := do + IO.eprintln s!"error: {msg}" + return 2 + let directionE := + match directionStr? with + | none => Except.ok none + | some raw => (parseDirectionSpec raw).map some + match directionE with + | Except.error msg => fail msg + | Except.ok direction? => + match inputsPath?, modelPath? with + | some inputsPath, none => + if layer?.isSome || head?.isSome || period?.isSome then + fail "--layer/--head/--period are only valid with --model" + else if direction?.isSome then + fail "--direction is only valid with --model" + else + IO.runInductionHeadInterval inputsPath outPath? + | none, some modelPath => + match layer?, head?, direction? with + | some layer, some head, some ⟨dirTarget, dirNegative⟩ => + IO.runInductionHeadIntervalModel modelPath layer head dirTarget dirNegative period? + outPath? 
+ | _, _, none => + fail "--direction is required with --model (use \"target,negative\")" + | _, _, _ => + fail "--layer and --head are required with --model" + | none, none => + fail "provide exactly one of --inputs or --model" + | some _, some _ => + fail "provide exactly one of --inputs or --model" + +/-- `nfp induction certify` subcommand (streamlined). -/ +def inductionCertifySimpleCmd : Cmd := `[Cli| + certify VIA runInductionCertifySimple; + "Check induction head certificates from inputs or a model file." + FLAGS: + inputs : String; "Path to the induction head input file (use either --inputs or --model)." + model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." + layer : Nat; "Layer index for the induction head (required with --model)." + head : Nat; "Head index for the induction head (required with --model)." + period : Nat; "Optional prompt period override (model only; default: derive from tokens)." + direction : String; "Optional logit-diff direction as \"target,negative\" (model only). \ + When omitted with --model, direction is derived from tokens." + preset : String; "Split-budget preset: fast | balanced | tight (default: balanced)." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal; default: 0)." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." +] + +/-- `nfp induction certify_nonvacuous` subcommand (streamlined). -/ +def inductionCertifyNonvacuousSimpleCmd : Cmd := `[Cli| + certify_nonvacuous VIA runInductionCertifyNonvacuousSimple; + "Require a strictly positive logit-diff bound from inputs or a model file." 
+ FLAGS: + inputs : String; "Path to the induction head input file (use either --inputs or --model)." + model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." + layer : Nat; "Layer index for the induction head (required with --model)." + head : Nat; "Head index for the induction head (required with --model)." + period : Nat; "Optional prompt period override (model only; default: derive from tokens)." + direction : String; "Optional logit-diff direction as \"target,negative\" (model only). \ + When omitted with --model, direction is derived from tokens." + preset : String; "Split-budget preset: fast | balanced | tight (default: balanced)." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal; default: 0)." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." +] + +/-- `nfp induction interval` subcommand (streamlined). -/ +def inductionIntervalSimpleCmd : Cmd := `[Cli| + interval VIA runInductionIntervalSimple; + "Build head-output interval bounds from inputs or a model file." + FLAGS: + inputs : String; "Path to the induction head input file (use either --inputs or --model)." + model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." + layer : Nat; "Layer index for the induction head (required with --model)." + head : Nat; "Head index for the induction head (required with --model)." + period : Nat; "Optional prompt period override (model only; default: derive from tokens)." + direction : String; "Required logit-diff direction as \"target,negative\" (model only)." 
+ out : String; "Optional path to write the residual-interval certificate." +] + /-- Check induction certificates for induction heads. -/ def runInductionCertify (p : Parsed) : IO UInt32 := do let scoresPath := p.flag! "scores" |>.as! String @@ -477,10 +679,10 @@ def inductionHeadIntervalModelCmd : Cmd := `[Cli| out : String; "Optional path to write the residual-interval certificate." ] -/-- Induction-head subcommands. -/ -def inductionCmd : Cmd := `[Cli| - induction NOOP; - "Induction-head utilities." +/-- Advanced induction-head subcommands (full flag surface). -/ +def inductionAdvancedCmd : Cmd := `[Cli| + advanced NOOP; + "Advanced induction-head utilities (full flag set)." SUBCOMMANDS: inductionCertifyCmd; inductionCertifySoundCmd; @@ -497,6 +699,17 @@ def inductionCmd : Cmd := `[Cli| inductionHeadIntervalModelCmd ] +/-- Induction-head subcommands. -/ +def inductionCmd : Cmd := `[Cli| + induction NOOP; + "Induction-head utilities (streamlined). Use `nfp induction advanced --help` for full options." + SUBCOMMANDS: + inductionCertifySimpleCmd; + inductionCertifyNonvacuousSimpleCmd; + inductionIntervalSimpleCmd; + inductionAdvancedCmd +] + /-- The root CLI command. -/ def nfpCmd : Cmd := `[Cli| nfp NOOP; From 1dcff504fdd1f265fa3fbf52a2f87b49de511b2b Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 08:10:26 +0100 Subject: [PATCH 151/244] Add induction head certification audit --- docs/induction_cert_audit.md | 80 ++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 docs/induction_cert_audit.md diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md new file mode 100644 index 0000000..dafc18c --- /dev/null +++ b/docs/induction_cert_audit.md @@ -0,0 +1,80 @@ +# Induction Head Certification Audit + +Goal: assess whether the current Lean proofs justify the claim that we can certify +induction heads, and spell out the scope and limitations of that claim. 
+ +## Formal proof chain (Lean) + +- `buildInductionCertFromHeadCoreWith?` returns a certificate under explicit guards + (`lnEps > 0`, `sqrtLower lnEps > 0`, `dModel ≠ 0`, `active.Nonempty`), so the + computation is only claimed when these preconditions hold + (`Nfp/Sound/Induction/Core.lean`). +- `buildInductionCertFromHeadWith?` wraps the core computation and returns + a proof-carrying certificate `⟨c, InductionHeadCertSound inputs c⟩` + (`Nfp/Sound/Induction/HeadOutput.lean`). +- `buildInductionCertFromHeadCoreWith?_sound` proves that any returned certificate + satisfies `InductionHeadCertSound`, i.e. the softmax-margin bounds, one-hot + bounds, and value-interval bounds that define the head-level certificate + (`Nfp/Sound/Induction/CoreSound.lean`). +- `buildInductionLogitLowerBoundFromHead?` and + `buildInductionLogitLowerBoundNonvacuous?` lift the head certificate to a + logit-diff lower bound; the key lemma `logitDiffLowerBoundFromCert_le` shows + the bound is sound on active queries (`Nfp/Sound/Induction/LogitDiff.lean`). + +## Mechanistic mapping (Transformer Circuits) + +The mechanistic induction-head story is a QK/OV decomposition: +- QK: identify a matching prior token (prefix-matching attention). +- OV: write the continuation token (or logit-diff direction) into the residual stream. + +The certificate aligns to that decomposition: +- The softmax-margin bounds constrain the QK pattern so that attention to the + chosen `prev` index dominates other keys (mechanistic “prefix match”). +- The value-interval bounds and logit-diff lower bound constrain the OV path in + the chosen direction, so the head’s contribution increases the target logit + relative to the negative logit. + +This is direct mechanistic evidence in the Transformer Circuits sense: it ties +parameters (Q/K/V/O + LayerNorm) to certified bounds on attention and value +contributions, but only for the specific inputs and direction supplied. 
+ +Sources referenced for the mechanistic framing: +- `transformer-circuits-framework.md` (QK/OV decomposition). +- `induction-heads.md` (induction head behavior definition). +- `foundations.md` (reverse-engineering framing and feature decomposition). + +## Preconditions and scope limits + +These proofs are sufficient for a **conditional** certification claim: +if the inputs are correct and the builder returns a certificate, then the +head-level bounds hold. They are **not** sufficient for a global claim that a +head “is an induction head” without additional assumptions. + +Key assumptions and limitations: +- `prev` is an input, not a derived fact: the proofs assume it is the intended + prefix-matching index. There is no theorem linking `prev` to token identity. +- `directionSpec` is metadata only. The certificate proves a logit-diff bound + for the provided `direction` vector; it does not prove that vector equals the + model’s actual target-minus-negative unembedding direction. +- The active set is user-supplied and can be strict; bounds only hold for + `q ∈ active`, not all positions. +- The head-level certificate does not imply end-to-end behavior across blocks; + there is no formal bridge to full-model logits. + +## Conclusion + +Yes—**within the formal scope** of the current definitions, the proofs are +enough to claim that we can certify induction-head behavior at the head level: +they certify attention to a specified `prev` index and a logit-diff lower bound +along a specified direction. What is still missing is a proof that those inputs +correspond to the behavioral induction-head definition on actual sequences and +that the certified head contribution scales to end-to-end model logits. + +## Next steps + +- Formalize the relationship between `directionSpec` and the logit-diff vector + derived from unembedding (so the certified direction matches token-level claims). 
+- Add a proof or verified derivation that the `prev` mapping corresponds to the + induction pattern for a given prompt sequence. +- Build a bridge theorem that lifts head-level certificates to block-level or + end-to-end logit bounds. From d1992365d87315cfa4720836db565c99c94a3325 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 08:13:37 +0100 Subject: [PATCH 152/244] Clarify induction certification claims --- CLAIMS.md | 43 ++++++++++++++++++++++------------------ SOUNDNESS_LIMITATIONS.md | 14 ++++++++----- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index c77870a..8105a41 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -9,7 +9,10 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - Softmax-margin certificate soundness: `checkSoftmaxMarginCert` implies `SoftmaxMarginBoundsOn`. - Value-range certificate soundness: `checkValueRangeCert` implies `ValueRangeBounds`. -- Logit-diff lower bound lemma: `logitDiffLowerBound_le`. +- Induction-head certificate soundness: `InductionHeadCertSound` holds whenever + `buildInductionCertFromHeadCoreWith?` returns a certificate for the given inputs. +- Logit-diff lower bound lemmas: `logitDiffLowerBound_le` and + `logitDiffLowerBoundFromCert_le`. - Downstream linear certificate soundness: `checkDownstreamLinearCert` implies `DownstreamLinearBounds`. - Residual-interval certificate soundness: `checkResidualIntervalCert` implies @@ -21,25 +24,27 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri ## Soundly checked by the trusted CLI -- `nfp induction certify` verifies softmax-margin certificates, value-range certificates, - and computes a logit-diff lower bound. -- `nfp induction certify_sound` recomputes `eps`/`margin` and `lo`/`hi` from raw entries - and verifies the resulting certificates. 
-- `nfp induction certify_head` recomputes scores/values from exact head inputs and verifies - the resulting induction certificate (experimental, potentially slow). -- `nfp induction certify_head_model` reads a model binary, derives head inputs in Lean, - and verifies the resulting induction certificate (includes attention projection biases - and derives `prev`/active from the stored token sequence by default). -- `nfp induction certify_head_model_auto` derives the logit-diff direction from the prompt - tokens stored in the model file before running the same head-input checker. -- `nfp induction certify_end_to_end` composes a head-level logit-diff lower bound with a - downstream error certificate (arithmetic consistency only). -- `nfp induction certify_end_to_end_matrix` computes a downstream bound from a matrix payload - using verified row-sum norms, then composes it with the head-level logit-diff lower bound. -- `nfp induction certify_end_to_end_model` derives the unembedding direction from an +- `nfp induction certify` verifies head-level induction certificates from either a head-input + file or a model binary, and can compute a logit-diff lower bound. +- `nfp induction certify_nonvacuous` requires a strictly positive logit-diff lower bound. +- `nfp induction advanced certify_sound` recomputes `eps`/`margin` and `lo`/`hi` from raw + entries and verifies the resulting certificates. +- `nfp induction advanced certify_head` recomputes scores/values from exact head inputs and + verifies the resulting induction certificate (experimental, potentially slow). +- `nfp induction advanced certify_head_model` reads a model binary, derives head inputs in Lean, + and verifies the resulting induction certificate (includes attention projection biases and + derives `prev`/active from the stored token sequence by default). 
+- `nfp induction advanced certify_head_model_auto` derives the logit-diff direction from the + prompt tokens stored in the model file before running the same head-input checker. +- `nfp induction advanced certify_end_to_end` composes a head-level logit-diff lower bound with + a downstream error certificate (arithmetic consistency only). +- `nfp induction advanced certify_end_to_end_matrix` computes a downstream bound from a matrix + payload using verified row-sum norms, then composes it with the head-level logit-diff lower + bound. +- `nfp induction advanced certify_end_to_end_model` derives the unembedding direction from an `NFP_BINARY_V1` model file, computes a downstream error bound from either a supplied - residual-interval certificate or a verified model-derived interval, and composes it with - the head-level logit-diff lower bound. + residual-interval certificate or a verified model-derived interval, and composes it with the + head-level logit-diff lower bound. ## Untrusted / heuristic diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 67fcb45..775eee5 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -6,21 +6,24 @@ It is intentionally brief and focused on the soundness boundary. ## Current limitations - The trusted CLI only **checks certificates**; it does not search for witnesses or run a model. -- Induction certificates are **head-level** (softmax-margin + value-range + logit-diff lower bound). - They do **not** yet imply end-to-end model behavior. +- Induction certificates are **head-level** (softmax-margin + value-range + logit-diff lower bound), + and they are conditional on the supplied `prev`, `active`, and `direction` inputs. They do **not** + yet imply end-to-end model behavior. - Downstream error bounds can be computed from a **matrix payload** inside Lean. 
A model-based path exists, but it currently uses only the unembedding direction and derives residual intervals via conservative interval propagation (ignoring attention-score structure), which can be loose. - The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor now includes attention projection biases and LayerNorm metadata, but the Lean-side computation - still ignores LayerNorm and the shared attention output bias. + still ignores the shared attention output bias. - The `certify_head_model` path derives head inputs from the model binary in Lean, includes - attention projection biases, and derives `prev`/active from the stored token sequence by - default, but still ignores LayerNorm and the shared attention output bias. It currently + attention projection biases and LayerNorm metadata, and derives `prev`/active from the stored + token sequence by default, but still ignores the shared attention output bias. It currently requires `head_dim` to be a perfect square to represent the scale as an exact rational. - The `certify_head_model_auto` path derives the logit-diff direction from the stored prompt tokens using a heuristic; use explicit direction tokens for fixed claims. +- The certification does not prove that `prev` corresponds to the behavioral induction pattern, + nor that the chosen direction matches a task-specific target/negative logit semantics. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. - There is no bridge theorem connecting certificate validity to a full circuit/model semantics statement (for example, a formal statement about logits under a transformer block stack). @@ -30,6 +33,7 @@ It is intentionally brief and focused on the soundness boundary. - Tighten model-derived residual intervals (e.g., use attention-weight certificates or score-aware bounds) to avoid vacuity. - Replace untrusted extraction with a verified parser for model weight slices. 
+- Prove or verify that `prev` and `direction` are derived from token-level semantics. - Add a formal bridge from certificates to circuit semantics and (eventually) to end-to-end transformer claims. - Improve performance for the exact head-input path without weakening soundness. From 121923217d5da5232bee51c17bc1c8291fddb92d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 08:23:10 +0100 Subject: [PATCH 153/244] Document prevOfTokens correctness --- Nfp/Model/InductionPrompt.lean | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index df0ac2f..52b6de2 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -50,6 +50,40 @@ theorem mem_activeOfTokens {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} q ∈ activeOfTokens tokens ↔ ∃ k, k.val < q.val ∧ tokens k = tokens q := by simp [activeOfTokens] +/-- If a prior matching token exists, `prevOfTokens` picks a matching index and is maximal. 
-/ +theorem prevOfTokens_spec {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} + (h : ∃ k, k < q ∧ tokens k = tokens q) : + let p := prevOfTokens tokens q + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + classical + let candidates : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun k => + k < q ∧ tokens k = tokens q) + have hnonempty : candidates.Nonempty := by + rcases h with ⟨k, hk, htok⟩ + exact ⟨k, by simp [candidates, hk, htok]⟩ + by_cases h' : candidates.Nonempty + · have hmem : Finset.max' candidates h' ∈ candidates := + Finset.max'_mem candidates h' + have hcond : + Finset.max' candidates h' < q ∧ + tokens (Finset.max' candidates h') = tokens q := by + have hmem' := (Finset.mem_filter.1 hmem).2 + simpa using hmem' + have hmax : + ∀ k, k < q → tokens k = tokens q → + k ≤ Finset.max' candidates h' := by + intro k hk htok + have hk_mem : k ∈ candidates := by + simp [candidates, hk, htok] + have hk_mem' : k ∈ (candidates : Set (Fin seq)) := by + simpa using hk_mem + exact (Finset.isGreatest_max' (s := candidates) h').2 hk_mem' + simpa [prevOfTokens, candidates, h'] using + And.intro hcond.1 (And.intro hcond.2 hmax) + · exact (h' hnonempty).elim + end Model end Nfp From 360f3c721a327a17b91d607b36f4f40da9dfb3db Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 08:26:14 +0100 Subject: [PATCH 154/244] Relate active tokens to prev mapping --- Nfp/Model/InductionPrompt.lean | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index 52b6de2..c49c4ed 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -84,6 +84,18 @@ theorem prevOfTokens_spec {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} And.intro hcond.1 (And.intro hcond.2 hmax) · exact (h' hnonempty).elim +/-- Active queries imply the `prevOfTokens` maximal-match specification. 
-/ +theorem prevOfTokens_spec_of_active {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} + (hq : q ∈ activeOfTokens tokens) : + let p := prevOfTokens tokens q + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + have h := (mem_activeOfTokens (tokens := tokens) (q := q)).1 hq + rcases h with ⟨k, hk, htok⟩ + have hk' : k < q := by + exact (Fin.lt_def).2 hk + exact prevOfTokens_spec (tokens := tokens) (q := q) ⟨k, hk', htok⟩ + end Model end Nfp From e18fc72f5243bd52817efb9cd8b4626d79a18728 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 09:29:04 +0100 Subject: [PATCH 155/244] Clarify model-derived induction inputs --- CLAIMS.md | 6 +- Nfp/IO/NfptPure.lean | 170 ++++++++++++++++++++++++++++------- SOUNDNESS_LIMITATIONS.md | 6 +- docs/induction_cert_audit.md | 20 +++-- 4 files changed, 163 insertions(+), 39 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index 8105a41..f1065ad 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -33,9 +33,11 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri verifies the resulting induction certificate (experimental, potentially slow). - `nfp induction advanced certify_head_model` reads a model binary, derives head inputs in Lean, and verifies the resulting induction certificate (includes attention projection biases and - derives `prev`/active from the stored token sequence by default). + derives `prev`/active from the stored token sequence by default, and builds the logit-diff + direction vector from the target/negative unembedding columns). - `nfp induction advanced certify_head_model_auto` derives the logit-diff direction from the - prompt tokens stored in the model file before running the same head-input checker. + prompt tokens stored in the model file before running the same head-input checker (the + direction vector still uses the unembedding columns). 
- `nfp induction advanced certify_end_to_end` composes a head-level logit-diff lower bound with a downstream error certificate (arithmetic consistency only). - `nfp induction advanced certify_end_to_end_matrix` computes a downstream bound from a matrix diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index d2143b0..83b4446 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -589,16 +589,15 @@ def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : N throw s!"column out of range: {col}" /-- Read induction-head inputs directly from the model binary. -/ -def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) : - Except String (Model.InductionHeadInputs h.seqLen h.modelDim h.headDim) := do - let scale ← scaleOfHeadDim h.headDim - let tokens ← readTokens data start h - let embed ← readEmbeddings data start h - let weights ← readHeadWeights data start h layer head - let (attnBias, ln1Gamma, ln1Beta) ← readLayerAttnBiasLn1 data start h layer - let colTarget ← readUnembedColumn data start h dirTarget - let colNegative ← readUnembedColumn data start h dirNegative +def buildInductionHeadInputs (h : NfptHeader) (scale : Rat) + (tokens : Fin h.seqLen → Nat) + (embed : Fin h.seqLen → Fin h.modelDim → Rat) + (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) + (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) + (dirTarget dirNegative : Nat) + (colTarget colNegative : Fin h.modelDim → Rat) + (period? : Option Nat) : + Model.InductionHeadInputs h.seqLen h.modelDim h.headDim := let direction : Fin h.modelDim → Rat := fun i => colTarget i - colNegative i let directionSpec : Circuit.DirectionSpec := { target := dirTarget, negative := dirNegative } @@ -610,26 +609,137 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) match period? 
with | some period => Model.prevOfPeriod (seq := h.seqLen) period | none => Model.prevOfTokens (seq := h.seqLen) tokens - pure - { scale := scale - active := active - prev := prev - embed := embed - lnEps := h.layerNormEps - ln1Gamma := ln1Gamma - ln1Beta := ln1Beta - wq := weights.wq - bq := weights.bq - wk := weights.wk - bk := weights.bk - wv := weights.wv - bv := weights.bv - wo := weights.wo - attnBias := attnBias - maskCausal := true - maskValue := (-10000 : Rat) - directionSpec := directionSpec - direction := direction } + { scale := scale + active := active + prev := prev + embed := embed + lnEps := h.layerNormEps + ln1Gamma := ln1Gamma + ln1Beta := ln1Beta + wq := weights.wq + bq := weights.bq + wk := weights.wk + bk := weights.bk + wv := weights.wv + bv := weights.bv + wo := weights.wo + attnBias := attnBias + maskCausal := true + maskValue := (-10000 : Rat) + directionSpec := directionSpec + direction := direction } + +/-- Definitional characterization of `buildInductionHeadInputs`. -/ +theorem buildInductionHeadInputs_def (h : NfptHeader) (scale : Rat) + (tokens : Fin h.seqLen → Nat) + (embed : Fin h.seqLen → Fin h.modelDim → Rat) + (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) + (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) + (dirTarget dirNegative : Nat) + (colTarget colNegative : Fin h.modelDim → Rat) + (period? : Option Nat) : + buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta + dirTarget dirNegative colTarget colNegative period? = + { scale := scale + active := + match period? with + | some period => Model.activeOfPeriod (seq := h.seqLen) period + | none => Model.activeOfTokens (seq := h.seqLen) tokens + prev := + match period? 
with + | some period => Model.prevOfPeriod (seq := h.seqLen) period + | none => Model.prevOfTokens (seq := h.seqLen) tokens + embed := embed + lnEps := h.layerNormEps + ln1Gamma := ln1Gamma + ln1Beta := ln1Beta + wq := weights.wq + bq := weights.bq + wk := weights.wk + bk := weights.bk + wv := weights.wv + bv := weights.bv + wo := weights.wo + attnBias := attnBias + maskCausal := true + maskValue := (-10000 : Rat) + directionSpec := { target := dirTarget, negative := dirNegative } + direction := fun i => colTarget i - colNegative i } := rfl + +/-- `buildInductionHeadInputs` uses the supplied direction ids and columns. -/ +theorem buildInductionHeadInputs_direction_def (h : NfptHeader) (scale : Rat) + (tokens : Fin h.seqLen → Nat) + (embed : Fin h.seqLen → Fin h.modelDim → Rat) + (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) + (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) + (dirTarget dirNegative : Nat) + (colTarget colNegative : Fin h.modelDim → Rat) + (period? : Option Nat) : + let inputs := + buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta + dirTarget dirNegative colTarget colNegative period? + inputs.directionSpec = { target := dirTarget, negative := dirNegative } ∧ + inputs.direction = fun i => colTarget i - colNegative i := by + simp [buildInductionHeadInputs] + +/-- `buildInductionHeadInputs` derives `prev`/`active` from tokens or a fixed period. -/ +theorem buildInductionHeadInputs_prev_active_def (h : NfptHeader) (scale : Rat) + (tokens : Fin h.seqLen → Nat) + (embed : Fin h.seqLen → Fin h.modelDim → Rat) + (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) + (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) + (dirTarget dirNegative : Nat) + (colTarget colNegative : Fin h.modelDim → Rat) + (period? : Option Nat) : + let inputs := + buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta + dirTarget dirNegative colTarget colNegative period? + inputs.active = + (match period? 
with + | some period => Model.activeOfPeriod (seq := h.seqLen) period + | none => Model.activeOfTokens (seq := h.seqLen) tokens) ∧ + inputs.prev = + (match period? with + | some period => Model.prevOfPeriod (seq := h.seqLen) period + | none => Model.prevOfTokens (seq := h.seqLen) tokens) := by + simp [buildInductionHeadInputs] + +/-- Active queries pick the maximal matching prior token when `period? = none`. -/ +theorem buildInductionHeadInputs_prev_spec_of_active (h : NfptHeader) (scale : Rat) + (tokens : Fin h.seqLen → Nat) + (embed : Fin h.seqLen → Fin h.modelDim → Rat) + (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) + (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) + (dirTarget dirNegative : Nat) + (colTarget colNegative : Fin h.modelDim → Rat) : + ∀ {q}, + q ∈ (buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta + dirTarget dirNegative colTarget colNegative none).active → + let p := + (buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta + dirTarget dirNegative colTarget colNegative none).prev q + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + intro q hq + have hq' : q ∈ Model.activeOfTokens (seq := h.seqLen) tokens := by + simpa [buildInductionHeadInputs] using hq + have hspec := Model.prevOfTokens_spec_of_active (tokens := tokens) (q := q) hq' + simpa [buildInductionHeadInputs] using hspec + +/-- Read induction-head inputs directly from the model binary. -/ +def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) + (layer head dirTarget dirNegative : Nat) (period? 
: Option Nat) : + Except String (Model.InductionHeadInputs h.seqLen h.modelDim h.headDim) := do + let scale ← scaleOfHeadDim h.headDim + let tokens ← readTokens data start h + let embed ← readEmbeddings data start h + let weights ← readHeadWeights data start h layer head + let (attnBias, ln1Gamma, ln1Beta) ← readLayerAttnBiasLn1 data start h layer + let colTarget ← readUnembedColumn data start h dirTarget + let colNegative ← readUnembedColumn data start h dirNegative + pure <| + buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta + dirTarget dirNegative colTarget colNegative period? end NfptPure diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 775eee5..1407347 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -22,8 +22,10 @@ It is intentionally brief and focused on the soundness boundary. requires `head_dim` to be a perfect square to represent the scale as an exact rational. - The `certify_head_model_auto` path derives the logit-diff direction from the stored prompt tokens using a heuristic; use explicit direction tokens for fixed claims. -- The certification does not prove that `prev` corresponds to the behavioral induction pattern, - nor that the chosen direction matches a task-specific target/negative logit semantics. +- The certification does not yet prove end-to-end behavioral induction claims. For + `certify_head_model` with `period? = none`, `prev` is derived from tokens and is the maximal + prior match, but other inputs (head-input files or explicit periods) still rely on supplied + `prev` maps. The chosen direction still assumes the unembedding columns encode token logits. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. - There is no bridge theorem connecting certificate validity to a full circuit/model semantics statement (for example, a formal statement about logits under a transformer block stack). 
diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index dafc18c..55eaff3 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -9,6 +9,14 @@ induction heads, and spell out the scope and limitations of that claim. (`lnEps > 0`, `sqrtLower lnEps > 0`, `dModel ≠ 0`, `active.Nonempty`), so the computation is only claimed when these preconditions hold (`Nfp/Sound/Induction/Core.lean`). +- `buildInductionHeadInputs_def` shows the model-derived head inputs are + definitional: `prev`/`active` are computed from tokens (or a fixed period), + and the `direction` vector is the unembedding-column difference for the + provided target/negative token ids (`Nfp/IO/NfptPure.lean`). +- `buildInductionHeadInputs_prev_spec_of_active` and + `prevOfTokens_spec_of_active` prove that when `period? = none`, + every active query has a maximal prior matching token in `prev` + (`Nfp/IO/NfptPure.lean`, `Nfp/Model/InductionPrompt.lean`). - `buildInductionCertFromHeadWith?` wraps the core computation and returns a proof-carrying certificate `⟨c, InductionHeadCertSound inputs c⟩` (`Nfp/Sound/Induction/HeadOutput.lean`). @@ -51,11 +59,13 @@ head-level bounds hold. They are **not** sufficient for a global claim that a head “is an induction head” without additional assumptions. Key assumptions and limitations: -- `prev` is an input, not a derived fact: the proofs assume it is the intended - prefix-matching index. There is no theorem linking `prev` to token identity. -- `directionSpec` is metadata only. The certificate proves a logit-diff bound - for the provided `direction` vector; it does not prove that vector equals the - model’s actual target-minus-negative unembedding direction. +- For `certify_head_model` with `period? = none`, `prev`/`active` are derived + from tokens and `prev` is the maximal prior match. For head-input files or + when `period?` is set explicitly, `prev` remains a user-supplied input. 
+- The certificate proves a logit-diff bound along the supplied `direction` + vector. For model-derived inputs, this vector is the target-minus-negative + unembedding column difference, but we still assume that the unembedding + matrix represents the model’s logit map. - The active set is user-supplied and can be strict; bounds only hold for `q ∈ active`, not all positions. - The head-level certificate does not imply end-to-end behavior across blocks; From 8f50953e95a410d3b139f60e81649386b2259201 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 12:19:00 +0100 Subject: [PATCH 156/244] Add head logit-diff bridge with residual bounds --- CLAIMS.md | 6 +- Nfp/Sound/Induction/LogitDiff.lean | 198 ++++++++++++++++++++++++++++- docs/induction_cert_audit.md | 18 ++- 3 files changed, 210 insertions(+), 12 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index f1065ad..231da81 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -13,6 +13,9 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri `buildInductionCertFromHeadCoreWith?` returns a certificate for the given inputs. - Logit-diff lower bound lemmas: `logitDiffLowerBound_le` and `logitDiffLowerBoundFromCert_le`. +- Bridge lemmas composing head logit-diff bounds with head outputs and residual + interval bounds: `headLogitDiff_eq_direction_dot_headOutput` and + `logitDiffLowerBound_with_residual`. - Downstream linear certificate soundness: `checkDownstreamLinearCert` implies `DownstreamLinearBounds`. - Residual-interval certificate soundness: `checkResidualIntervalCert` implies @@ -62,4 +65,5 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - End-to-end claims about GPT-2 logits or Jacobians derived from certificates. - Sound, verified downstream bounds computed from GPT-2 weights inside Lean. -- A bridge theorem connecting certificate validity to full circuit/model semantics. 
+- A full end-to-end bridge from head certificates to full-model logit bounds + (beyond the head-output + residual-interval composition). diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 8db9932..ab6bcbb 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Aesop +import Mathlib.Data.Vector.Basic import Nfp.Circuit.Cert.LogitDiff +import Nfp.Sound.Bounds.MatrixNorm.Interval import Nfp.Sound.Induction /-! @@ -16,11 +18,81 @@ open Nfp.Circuit variable {seq : Nat} -section LogitDiffLowerBound +section Direction -variable {seq dModel dHead : Nat} [NeZero seq] +variable {seq dModel dHead : Nat} -section +/-- Direction projection of a single-key head output. -/ +theorem direction_dot_headValue_eq_valsReal + (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) : + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headValueRealOfInputs inputs k i) = + valsRealOfInputs inputs k := by + classical + let dir : Fin dModel → Real := fun i => (inputs.direction i : Real) + let v : Fin dHead → Real := fun d => vRealOfInputs inputs k d + have hswap : + ∑ i, dir i * ∑ d, (inputs.wo i d : Real) * v d = + ∑ d, (∑ i, dir i * (inputs.wo i d : Real)) * v d := by + calc + ∑ i, dir i * ∑ d, (inputs.wo i d : Real) * v d + = ∑ i, ∑ d, dir i * ((inputs.wo i d : Real) * v d) := by + simp [Finset.mul_sum] + _ = ∑ d, ∑ i, dir i * ((inputs.wo i d : Real) * v d) := by + simpa using + (Finset.sum_comm (s := (Finset.univ : Finset (Fin dModel))) + (t := (Finset.univ : Finset (Fin dHead))) + (f := fun i d => dir i * ((inputs.wo i d : Real) * v d))) + _ = ∑ d, (∑ i, dir i * (inputs.wo i d : Real)) * v d := by + refine Finset.sum_congr rfl ?_ + intro d _ + simp [mul_assoc, Finset.sum_mul] + have hdirHead : + ∀ d, ((dirHeadVecOfInputs inputs).get d : Real) = + ∑ i, (inputs.wo i d : Real) * (inputs.direction i : Real) := by + intro d + 
calc + ((dirHeadVecOfInputs inputs).get d : Real) = + ratToReal (Linear.dotFin dModel (fun i => inputs.wo i d) + (fun i => inputs.direction i)) := by + simp [dirHeadVecOfInputs, Vector.get, Vector.ofFn, ratToReal] + _ = + ratToReal (dotProduct (fun i => inputs.wo i d) (fun i => inputs.direction i)) := by + simp [Linear.dotFin_eq_dotProduct] + _ = + ∑ i, ratToReal (inputs.wo i d * inputs.direction i) := by + simp [dotProduct, Linear.ratToReal_sum_univ] + _ = + ∑ i, (inputs.wo i d : Real) * (inputs.direction i : Real) := by + simp [ratToReal] + calc + dotProduct dir (fun i => headValueRealOfInputs inputs k i) + = ∑ i, dir i * + ∑ d, (inputs.wo i d : Real) * v d := by + simp [dir, v, headValueRealOfInputs, dotProduct] + _ = ∑ d, (∑ i, dir i * (inputs.wo i d : Real)) * v d := by + simp [hswap] + _ = ∑ d, ((dirHeadVecOfInputs inputs).get d : Real) * v d := by + refine Finset.sum_congr rfl ?_ + intro d _ + have hdir : + ∑ i, dir i * (inputs.wo i d : Real) = + ((dirHeadVecOfInputs inputs).get d : Real) := by + calc + ∑ i, dir i * (inputs.wo i d : Real) + = ∑ i, (inputs.wo i d : Real) * (inputs.direction i : Real) := by + simp [dir, mul_comm] + _ = ((dirHeadVecOfInputs inputs).get d : Real) := by + simpa using (hdirHead d).symm + simp [hdir] + _ = valsRealOfInputs inputs k := by + simp [valsRealOfInputs, v, dotProduct] + +end Direction + +section LogitDiffLowerBound + +variable {seq dModel dHead : Nat} /-- Real-valued logit-diff contribution for a query. 
-/ noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -38,6 +110,10 @@ def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := def logitDiffLowerBoundFromCertWeighted (c : InductionHeadCert seq) : Option Rat := Circuit.logitDiffLowerBoundWeightedAt c.active c.prev c.weightBoundAt c.values.valsLo +section WithNeZero + +variable [NeZero seq] + theorem logitDiffLowerBoundFromCert_le (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) @@ -469,7 +545,121 @@ def buildInductionLogitLowerBoundNonvacuous? · exact some ⟨base, hpos⟩ · exact none -end +end WithNeZero + +/-! End-to-end lower bounds from head certificates plus residual intervals. -/ + +/-- The head logit-diff equals the direction dot product of the head output. -/ +theorem headLogitDiff_eq_direction_dot_headOutput + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) : + headLogitDiff inputs q = + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i) := by + classical + let dir : Fin dModel → Real := fun i => (inputs.direction i : Real) + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + have hswap : + dotProduct dir (fun i => headOutput inputs q i) = + ∑ k, weights q k * dotProduct dir (fun i => headValueRealOfInputs inputs k i) := by + calc + dotProduct dir (fun i => headOutput inputs q i) + = ∑ i, dir i * ∑ k, weights q k * headValueRealOfInputs inputs k i := by + simp [dir, headOutput, headOutputWithScores, weights, dotProduct] + _ = ∑ i, ∑ k, dir i * (weights q k * headValueRealOfInputs inputs k i) := by + simp [Finset.mul_sum] + _ = ∑ k, ∑ i, dir i * (weights q k * headValueRealOfInputs inputs k i) := by + simpa using + (Finset.sum_comm (s := (Finset.univ : Finset (Fin dModel))) + (t := (Finset.univ : Finset (Fin seq))) + (f := fun i k => dir i * (weights q k * 
headValueRealOfInputs inputs k i))) + _ = ∑ k, weights q k * ∑ i, dir i * headValueRealOfInputs inputs k i := by + refine Finset.sum_congr rfl ?_ + intro k _ + calc + ∑ i, dir i * (weights q k * headValueRealOfInputs inputs k i) = + ∑ i, weights q k * (dir i * headValueRealOfInputs inputs k i) := by + refine Finset.sum_congr rfl ?_ + intro i _ + simp [mul_assoc, mul_left_comm, mul_comm] + _ = weights q k * ∑ i, dir i * headValueRealOfInputs inputs k i := by + simp [Finset.mul_sum] + have hsum : + dotProduct dir (fun i => headOutput inputs q i) = + ∑ k, weights q k * valsRealOfInputs inputs k := by + calc + dotProduct dir (fun i => headOutput inputs q i) = + ∑ k, weights q k * dotProduct dir (fun i => headValueRealOfInputs inputs k i) := hswap + _ = ∑ k, weights q k * valsRealOfInputs inputs k := by + refine Finset.sum_congr rfl ?_ + intro k _ + have hdir := direction_dot_headValue_eq_valsReal (inputs := inputs) (k := k) + exact congrArg (fun x => weights q k * x) (by simpa [dir] using hdir) + calc + headLogitDiff inputs q = + dotProduct (weights q) (valsRealOfInputs inputs) := by + simp [headLogitDiff, weights] + _ = ∑ k, weights q k * valsRealOfInputs inputs k := by + simp [dotProduct] + _ = dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i) := by + simpa [dir] using hsum.symm + +/-- Combine a head logit-diff bound with residual interval bounds. 
-/ +theorem logitDiffLowerBound_with_residual + (inputs : Model.InductionHeadInputs seq dModel dHead) + (lb : Rat) + (hlb : ∀ q, q ∈ inputs.active → (lb : Real) ≤ headLogitDiff inputs q) + (residual : Fin seq → Fin dModel → Real) + (lo hi : Fin dModel → Rat) + (hres : ∀ q, q ∈ inputs.active → ∀ i, + (lo i : Real) ≤ residual q i ∧ residual q i ≤ (hi i : Real)) : + ∀ q, q ∈ inputs.active → + (lb : Real) - (Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i + residual q i) := by + intro q hq + have hhead := hlb q hq + have hres' : + |dotProduct (fun i => (inputs.direction i : Real)) (residual q)| ≤ + (Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) := by + have hlo : ∀ i, (lo i : Real) ≤ residual q i := fun i => (hres q hq i).1 + have hhi : ∀ i, residual q i ≤ (hi i : Real) := fun i => (hres q hq i).2 + simpa using + (Bounds.abs_dotProduct_le_dotIntervalAbsBound_real + (v := inputs.direction) (lo := lo) (hi := hi) (x := residual q) hlo hhi) + have hres_lower : + -(Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ + dotProduct (fun i => (inputs.direction i : Real)) (residual q) := by + exact (abs_le.mp hres').1 + have hsum : + (lb : Real) - (Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ + headLogitDiff inputs q + + dotProduct (fun i => (inputs.direction i : Real)) (residual q) := by + have hsum' : (lb : Real) + -(Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ + headLogitDiff inputs q + + dotProduct (fun i => (inputs.direction i : Real)) (residual q) := + add_le_add hhead hres_lower + simpa [sub_eq_add_neg] using hsum' + calc + (lb : Real) - (Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ + headLogitDiff inputs q + + dotProduct (fun i => (inputs.direction i : Real)) (residual q) := hsum + _ = + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i + residual q i) := by + have hdot : + 
dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i + residual q i) = + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i) + + dotProduct (fun i => (inputs.direction i : Real)) (residual q) := by + simpa using + (Linear.dotProduct_add_right + (x := fun i => (inputs.direction i : Real)) + (y := fun i => headOutput inputs q i) + (z := residual q)) + simp [headLogitDiff_eq_direction_dot_headOutput, hdot] end LogitDiffLowerBound diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 55eaff3..a5030c5 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -68,17 +68,21 @@ Key assumptions and limitations: matrix represents the model’s logit map. - The active set is user-supplied and can be strict; bounds only hold for `q ∈ active`, not all positions. -- The head-level certificate does not imply end-to-end behavior across blocks; - there is no formal bridge to full-model logits. +- There is now a formal bridge from head logit-diff bounds plus residual interval + bounds to a direction lower bound on `headOutput + residual`, but full + end-to-end model logits still require verified residual bounds through the + rest of the stack. ## Conclusion Yes—**within the formal scope** of the current definitions, the proofs are enough to claim that we can certify induction-head behavior at the head level: they certify attention to a specified `prev` index and a logit-diff lower bound -along a specified direction. What is still missing is a proof that those inputs -correspond to the behavioral induction-head definition on actual sequences and -that the certified head contribution scales to end-to-end model logits. +along a specified direction. 
We now have a bridge that composes those bounds +with residual interval bounds to certify `headOutput + residual`, but we still +need a proof that the inputs correspond to the behavioral induction-head +definition on actual sequences and that residual bounds are derived from full +model semantics. ## Next steps @@ -86,5 +90,5 @@ that the certified head contribution scales to end-to-end model logits. derived from unembedding (so the certified direction matches token-level claims). - Add a proof or verified derivation that the `prev` mapping corresponds to the induction pattern for a given prompt sequence. -- Build a bridge theorem that lifts head-level certificates to block-level or - end-to-end logit bounds. +- Extend the bridge to full transformer stacks by deriving residual interval + bounds from verified layer/block semantics. From 31fc251e9e0ecebbfd75cf25a4d1795c49f04df4 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 12:31:14 +0100 Subject: [PATCH 157/244] Package GPT-2 residual bounds as certified intervals --- CLAIMS.md | 2 ++ Nfp/Sound/Bounds/Transformer.lean | 45 +++++++++++++++++++++++++++++++ docs/induction_cert_audit.md | 3 +++ 3 files changed, 50 insertions(+) diff --git a/CLAIMS.md b/CLAIMS.md index 231da81..21c0499 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -20,6 +20,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri `DownstreamLinearBounds`. - Residual-interval certificate soundness: `checkResidualIntervalCert` implies `ResidualIntervalBounds`. +- GPT-2 residual interval bounds from model slices are sound for + `transformerStackFinalReal` on active positions (`gpt2ResidualIntervalBoundsActive_sound`). - Row-sum matrix norm bounds for `mulVec` under uniform input magnitude. - Tanh-GELU bounds and interval propagation through MLP layers. - Interval bounds for multi-head attention and full transformer-layer residual blocks. 
diff --git a/Nfp/Sound/Bounds/Transformer.lean b/Nfp/Sound/Bounds/Transformer.lean index 6544275..87fef74 100644 --- a/Nfp/Sound/Bounds/Transformer.lean +++ b/Nfp/Sound/Bounds/Transformer.lean @@ -3,6 +3,7 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Mathlib.Data.List.Range import Mathlib.Data.Real.Basic +import Nfp.Circuit.Cert.ResidualInterval import Nfp.Model.Gpt2 import Nfp.Sound.Bounds.Attention import Nfp.Sound.Bounds.LayerNorm @@ -512,6 +513,50 @@ theorem gpt2ResidualIntervalBoundsActive_spec simpa [bounds, gpt2ResidualIntervalBoundsActive, final, baseLo, baseHi] using hbounds q hq i +/-- Package GPT-2 residual bounds into a residual-interval certificate. -/ +theorem gpt2ResidualIntervalBoundsActive_sound + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (embed : Fin seq → Fin dModel → Rat) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed + let cert : Circuit.ResidualIntervalCert dModel := { lo := bounds.1, hi := bounds.2 } + Circuit.ResidualIntervalBounds cert ∧ + ∀ q, q ∈ active → ∀ i, + (cert.lo i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ∧ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (cert.hi i : Real) := by + classical + intro bounds cert + have hspec : + ∀ q, q ∈ active → ∀ i, + (bounds.1 i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ∧ + transformerStackFinalReal eps finalLn layers heads scores + (fun 
q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by + simpa [bounds] using + (gpt2ResidualIntervalBoundsActive_spec (active := active) (hactive := hactive) + (eps := eps) (layers := layers) (heads := heads) (finalLn := finalLn) + (scores := scores) (embed := embed) (hne := hne) (heps := heps) (hsqrt := hsqrt)) + have hbounds : Circuit.ResidualIntervalBounds cert := by + refine { lo_le_hi := ?_ } + intro i + rcases hactive with ⟨q0, hq0⟩ + have hq := hspec q0 hq0 i + have hreal : (bounds.1 i : Real) ≤ (bounds.2 i : Real) := hq.1.trans hq.2 + exact (ratToReal_le_iff (x := bounds.1 i) (y := bounds.2 i)).1 hreal + refine And.intro hbounds ?_ + intro q hq i + have hq' := hspec q hq i + simpa [cert] using hq' + end Bounds end Sound diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index a5030c5..3bf7ea1 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -72,6 +72,9 @@ Key assumptions and limitations: bounds to a direction lower bound on `headOutput + residual`, but full end-to-end model logits still require verified residual bounds through the rest of the stack. + We now have a theorem packaging GPT-2 residual interval bounds derived from + model slices into a sound `ResidualIntervalCert`, but it is not yet connected + to the head-level logit-diff contribution inside the full stack. 
## Conclusion From 955f056962e24c7fdbbaeba033d357e501145150 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 12:42:39 +0100 Subject: [PATCH 158/244] Compose head logit-diff with output intervals --- CLAIMS.md | 3 +- Nfp/Sound/Induction/LogitDiff.lean | 63 ++++++++++++++++++++++++++++++ SOUNDNESS_LIMITATIONS.md | 3 ++ docs/induction_cert_audit.md | 3 ++ 4 files changed, 71 insertions(+), 1 deletion(-) diff --git a/CLAIMS.md b/CLAIMS.md index 21c0499..245cf62 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -15,7 +15,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri `logitDiffLowerBoundFromCert_le`. - Bridge lemmas composing head logit-diff bounds with head outputs and residual interval bounds: `headLogitDiff_eq_direction_dot_headOutput` and - `logitDiffLowerBound_with_residual`. + `logitDiffLowerBound_with_residual`, plus interval-composition + `logitDiffLowerBound_with_output_intervals`. - Downstream linear certificate soundness: `checkDownstreamLinearCert` implies `DownstreamLinearBounds`. - Residual-interval certificate soundness: `checkResidualIntervalCert` implies diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index ab6bcbb..8440ff5 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -661,6 +661,69 @@ theorem logitDiffLowerBound_with_residual (z := residual q)) simp [headLogitDiff_eq_direction_dot_headOutput, hdot] +/-- Combine a head logit-diff bound with intervals on head output and a downstream output. 
-/ +theorem logitDiffLowerBound_with_output_intervals + (inputs : Model.InductionHeadInputs seq dModel dHead) + (lb : Rat) + (hlb : ∀ q, q ∈ inputs.active → (lb : Real) ≤ headLogitDiff inputs q) + (output : Fin seq → Fin dModel → Real) + (outLo outHi : Fin dModel → Rat) + (hout : ∀ q, q ∈ inputs.active → ∀ i, + (outLo i : Real) ≤ output q i ∧ output q i ≤ (outHi i : Real)) + (headLo headHi : Fin dModel → Rat) + (hhead : ∀ q, q ∈ inputs.active → ∀ i, + (headLo i : Real) ≤ headOutput inputs q i ∧ + headOutput inputs q i ≤ (headHi i : Real)) : + ∀ q, q ∈ inputs.active → + (lb : Real) - + (Bounds.dotIntervalAbsBound inputs.direction + (fun i => outLo i - headHi i) (fun i => outHi i - headLo i) : Real) ≤ + dotProduct (fun i => (inputs.direction i : Real)) (fun i => output q i) := by + intro q hq + let residual : Fin seq → Fin dModel → Real := + fun q i => output q i - headOutput inputs q i + let lo : Fin dModel → Rat := fun i => outLo i - headHi i + let hi : Fin dModel → Rat := fun i => outHi i - headLo i + have hres : ∀ q, q ∈ inputs.active → ∀ i, + (lo i : Real) ≤ residual q i ∧ residual q i ≤ (hi i : Real) := by + intro q hq i + have hout_q := hout q hq i + have hhead_q := hhead q hq i + have hlow : + (outLo i : Real) - (headHi i : Real) ≤ + output q i - headOutput inputs q i := by + exact sub_le_sub hout_q.1 hhead_q.2 + have hhigh : + output q i - headOutput inputs q i ≤ + (outHi i : Real) - (headLo i : Real) := by + exact sub_le_sub hout_q.2 hhead_q.1 + constructor + · simpa [lo, residual, ratToReal_sub] using hlow + · simpa [hi, residual, ratToReal_sub] using hhigh + have hbound := + logitDiffLowerBound_with_residual + (inputs := inputs) + (lb := lb) + (hlb := hlb) + (residual := residual) + (lo := lo) + (hi := hi) + hres + q + hq + have hdot : + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i + residual q i) = + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => output q i) := by + refine Finset.sum_congr rfl ?_ + 
intro i _ + have hsum : + headOutput inputs q i + residual q i = output q i := by + simp [residual, sub_eq_add_neg, add_left_comm] + simp [hsum] + simpa [lo, hi, hdot] using hbound + end LogitDiffLowerBound end Sound diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 1407347..66020e1 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -26,6 +26,9 @@ It is intentionally brief and focused on the soundness boundary. `certify_head_model` with `period? = none`, `prev` is derived from tokens and is the maximal prior match, but other inputs (head-input files or explicit periods) still rely on supplied `prev` maps. The chosen direction still assumes the unembedding columns encode token logits. +- There is now a sound interval-composition lemma that combines head logit-diff bounds with + head/output intervals via subtraction, but it does not model how head outputs propagate + through subsequent LN/MLP blocks (so tight end-to-end claims remain open). - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. - There is no bridge theorem connecting certificate validity to a full circuit/model semantics statement (for example, a formal statement about logits under a transformer block stack). diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 3bf7ea1..3e13b1d 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -75,6 +75,9 @@ Key assumptions and limitations: We now have a theorem packaging GPT-2 residual interval bounds derived from model slices into a sound `ResidualIntervalCert`, but it is not yet connected to the head-level logit-diff contribution inside the full stack. + A new lemma composes head logit-diff bounds with *both* head-output intervals + and downstream output intervals, yielding a sound lower bound on the direction + dot of the downstream output (via interval subtraction). 
## Conclusion From 0563d22cf5e192b6aef55cf705878ab7f6c7c607 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 12:50:10 +0100 Subject: [PATCH 159/244] Add GPT-2 end-to-end logit-diff bound --- CLAIMS.md | 3 ++ MODULE_MAP.md | 2 + Nfp/Sound/Induction.lean | 1 + Nfp/Sound/Induction/EndToEnd.lean | 82 ++++++++++++++++++++++++++++++ Nfp/Sound/Induction/LogitDiff.lean | 2 +- SOUNDNESS_LIMITATIONS.md | 2 + docs/induction_cert_audit.md | 7 ++- 7 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 Nfp/Sound/Induction/EndToEnd.lean diff --git a/CLAIMS.md b/CLAIMS.md index 245cf62..be2f5c2 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -23,6 +23,9 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri `ResidualIntervalBounds`. - GPT-2 residual interval bounds from model slices are sound for `transformerStackFinalReal` on active positions (`gpt2ResidualIntervalBoundsActive_sound`). +- End-to-end direction-dot lower bounds on `transformerStackFinalReal` can be derived by + composing head logit-diff bounds with head/output intervals + (`logitDiffLowerBound_end_to_end_gpt2`). - Row-sum matrix norm bounds for `mulVec` under uniform input magnitude. - Tanh-GELU bounds and interval propagation through MLP layers. - Interval bounds for multi-head attention and full transformer-layer residual blocks. diff --git a/MODULE_MAP.md b/MODULE_MAP.md index 93d67fe..9f07346 100644 --- a/MODULE_MAP.md +++ b/MODULE_MAP.md @@ -165,6 +165,8 @@ but you **must** update this list in the same commit. - Helper lemmas for value-direction projections in the core soundness proof. - `Nfp/Sound/Induction/CoreDefs.lean` - Core definitions and soundness predicates for induction certificates. +- `Nfp/Sound/Induction/EndToEnd.lean` + - End-to-end induction bounds combining head certificates with transformer-stack intervals. - `Nfp/Sound/Induction/HeadOutput.lean` - Head-output interval certificates built from induction head inputs. 
- `Nfp/Sound/Induction/HeadBounds.lean` diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index 7376f8e..ea2bcfb 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -2,6 +2,7 @@ import Nfp.Sound.Induction.Core import Nfp.Sound.Induction.CoreSound +import Nfp.Sound.Induction.EndToEnd import Nfp.Sound.Induction.HeadOutput /-! diff --git a/Nfp/Sound/Induction/EndToEnd.lean b/Nfp/Sound/Induction/EndToEnd.lean new file mode 100644 index 0000000..35ad055 --- /dev/null +++ b/Nfp/Sound/Induction/EndToEnd.lean @@ -0,0 +1,82 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Sound.Bounds.Transformer +import Nfp.Sound.Induction.HeadOutput +import Nfp.Sound.Induction.LogitDiff + +/-! +End-to-end induction bounds that combine head certificates with transformer-stack intervals. +-/ + +namespace Nfp + +namespace Sound + +/-- Compose head logit-diff bounds with GPT-2 stack output intervals. -/ +theorem logitDiffLowerBound_end_to_end_gpt2 + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (lb : Rat) + (hlb : ∀ q, q ∈ inputs.active → (lb : Real) ≤ headLogitDiff inputs q) + (headCert : Circuit.ResidualIntervalCert dModel) + (hhead : HeadOutputIntervalSound inputs inputs.active headCert) + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < Bounds.sqrtLower eps) + (hactive : inputs.active.Nonempty) : + let bounds := + Bounds.gpt2ResidualIntervalBoundsActive inputs.active hactive eps layers heads finalLn + inputs.embed + let output : Fin seq → Fin dModel → Real := + fun q i => Bounds.transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (inputs.embed q i : Real)) q i + ∀ q, q ∈ 
inputs.active → + (lb : Real) - + (Bounds.dotIntervalAbsBound inputs.direction + (fun i => bounds.1 i - headCert.hi i) + (fun i => bounds.2 i - headCert.lo i) : Real) ≤ + dotProduct (fun i => (inputs.direction i : Real)) (fun i => output q i) := by + classical + intro bounds output q hq + let cert : Circuit.ResidualIntervalCert dModel := { lo := bounds.1, hi := bounds.2 } + have hbounds : + ∀ q, q ∈ inputs.active → ∀ i, + (bounds.1 i : Real) ≤ output q i ∧ output q i ≤ (bounds.2 i : Real) := by + have hsound := + Bounds.gpt2ResidualIntervalBoundsActive_sound + (active := inputs.active) + (hactive := hactive) + (eps := eps) + (layers := layers) + (heads := heads) + (finalLn := finalLn) + (scores := scores) + (embed := inputs.embed) + (hne := hne) + (heps := heps) + (hsqrt := hsqrt) + rcases (by simpa [bounds, cert, output] using hsound) with ⟨_, hmem⟩ + exact hmem + have hhead_out := hhead.output_mem + have h := + logitDiffLowerBound_with_output_intervals + (inputs := inputs) + (lb := lb) + (hlb := hlb) + (output := output) + (outLo := bounds.1) + (outHi := bounds.2) + (hout := hbounds) + (headLo := headCert.lo) + (headHi := headCert.hi) + (hhead := hhead_out) + q + hq + simpa [bounds] using h + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 8440ff5..be5b705 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -4,7 +4,7 @@ import Aesop import Mathlib.Data.Vector.Basic import Nfp.Circuit.Cert.LogitDiff import Nfp.Sound.Bounds.MatrixNorm.Interval -import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadOutput /-! Logit-diff bounds derived from induction certificates. diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 66020e1..2ea573f 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -29,6 +29,8 @@ It is intentionally brief and focused on the soundness boundary. 
- There is now a sound interval-composition lemma that combines head logit-diff bounds with head/output intervals via subtraction, but it does not model how head outputs propagate through subsequent LN/MLP blocks (so tight end-to-end claims remain open). +- The GPT-2 end-to-end bound currently relies on these coarse intervals, so it can be + conservative or vacuous unless the downstream intervals are tightened. - Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. - There is no bridge theorem connecting certificate validity to a full circuit/model semantics statement (for example, a formal statement about logits under a transformer block stack). diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 3e13b1d..df7c88f 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -28,6 +28,10 @@ induction heads, and spell out the scope and limitations of that claim. `buildInductionLogitLowerBoundNonvacuous?` lift the head certificate to a logit-diff lower bound; the key lemma `logitDiffLowerBoundFromCert_le` shows the bound is sound on active queries (`Nfp/Sound/Induction/LogitDiff.lean`). +- `logitDiffLowerBound_end_to_end_gpt2` combines head logit-diff bounds, head + output intervals, and GPT-2 stack output intervals to give a direction lower + bound on `transformerStackFinalReal` + (`Nfp/Sound/Induction/EndToEnd.lean`, `Nfp/Sound/Bounds/Transformer.lean`). ## Mechanistic mapping (Transformer Circuits) @@ -77,7 +81,8 @@ Key assumptions and limitations: to the head-level logit-diff contribution inside the full stack. A new lemma composes head logit-diff bounds with *both* head-output intervals and downstream output intervals, yielding a sound lower bound on the direction - dot of the downstream output (via interval subtraction). + dot of the downstream output (via interval subtraction), and we now instantiate + this for GPT-2 stack outputs via `logitDiffLowerBound_end_to_end_gpt2`. 
## Conclusion From ddba480820f1a65bdd1a62a204ef01a03ef8f509 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 13:22:07 +0100 Subject: [PATCH 160/244] Add head interval option to end-to-end model cert --- CLAIMS.md | 3 +- Nfp/Cli.lean | 8 +++- Nfp/IO.lean | 133 +++++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 128 insertions(+), 16 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index be2f5c2..eccb279 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -55,7 +55,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - `nfp induction advanced certify_end_to_end_model` derives the unembedding direction from an `NFP_BINARY_V1` model file, computes a downstream error bound from either a supplied residual-interval certificate or a verified model-derived interval, and composes it with the - head-level logit-diff lower bound. + head-level logit-diff lower bound (optionally using `--layer/--head` to add head-output + interval bounds for a tighter end-to-end check). ## Untrusted / heuristic diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 1b44f35..1e4f60d 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -345,12 +345,15 @@ def runInductionCertifyEndToEndModel (p : Parsed) : IO UInt32 := do let valuesPath := p.flag! "values" |>.as! String let modelPath := p.flag! "model" |>.as! String let residualIntervalPath? := (p.flag? "residual-interval").map (·.as! String) + let layer? := (p.flag? "layer").map (·.as! Nat) + let head? := (p.flag? "head").map (·.as! Nat) + let period? := (p.flag? "period").map (·.as! Nat) let minActive? := (p.flag? "min-active").map (·.as! Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath residualIntervalPath? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + layer? head? period? 
minActive? minLogitDiffStr? minMarginStr? maxEpsStr? /-- `nfp induction certify_end_to_end_model` subcommand. -/ def inductionCertifyEndToEndModelCmd : Cmd := `[Cli| @@ -362,6 +365,9 @@ def inductionCertifyEndToEndModelCmd : Cmd := `[Cli| model : String; "Path to the NFP_BINARY_V1 model file." "residual-interval" : String; "Optional path to a residual-interval certificate file \ (defaults to deriving from the model)." + layer : Nat; "Optional layer index for a head-output interval bound (requires --head)." + head : Nat; "Optional head index for a head-output interval bound (requires --layer)." + period : Nat; "Optional prompt period override when reading head inputs." "min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 2278526..344d7a8 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -491,9 +491,10 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) (loaded from disk or derived from the model). -/ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) (valuesPath : System.FilePath) (modelPath : System.FilePath) - (residualIntervalPath? : Option System.FilePath) (minActive? : Option Nat) - (minLogitDiffStr? : Option String) (minMarginStr? : Option String) - (maxEpsStr? : Option String) : IO UInt32 := do + (residualIntervalPath? : Option System.FilePath) + (layer? : Option Nat) (head? : Option Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
@@ -629,6 +630,19 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) if residualOk then let dirPos := dirSpec.target let dirNeg := dirSpec.negative + if layer?.isSome != head?.isSome then + IO.eprintln + "error: --layer and --head must be provided \ + together" + return 2 + let headChoice? : Option (Nat × Nat) := + match layer?, head? with + | some layer, some head => some (layer, head) + | _, _ => none + if period?.isSome && headChoice?.isNone then + IO.eprintln + "warning: --period ignored without \ + --layer/--head" let colTargetE ← timePure "read unembed column target" (fun () => NfptPure.readUnembedColumn @@ -650,36 +664,127 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) let dirVec : Fin header.modelDim → Rat := fun i => colTarget i - colNeg i + let dotIntervalAbs := + Sound.Bounds.dotIntervalAbsBound + let intervalErrorFromHead? : + Model.InductionHeadInputs + seq header.modelDim header.headDim → + ResidualIntervalCert header.modelDim → + Option Rat := + fun inputs residual => by + classical + match hseq0 : seq with + | 0 => exact none + | Nat.succ n => + let _ : NeZero seq := by + exact ⟨by simp [hseq0]⟩ + match + Sound.buildHeadOutputIntervalFromHead? + inputs with + | none => exact none + | some result => + exact some + (dotIntervalAbs + dirVec + (fun i => + residual.lo i - + result.cert.hi i) + (fun i => + residual.hi i - + result.cert.lo i)) let downstreamError ← timePure "downstream error" (fun () => - Sound.Bounds.dotIntervalAbsBound + dotIntervalAbs dirVec residualCert'.lo residualCert'.hi) let finalLB := logitDiffLB - downstreamError + let intervalError? ← + match headChoice? with + | none => pure none + | some (layer, head) => do + let inputsE ← + timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head + dirPos dirNeg period?) 
+ match inputsE with + | Except.error msg => + IO.eprintln s!"warning: {msg}" + pure none + | Except.ok inputs => + let inputs' : + Model.InductionHeadInputs + seq header.modelDim + header.headDim := by + simpa [hseq] using inputs + let inputsAligned : + Model.InductionHeadInputs + seq header.modelDim + header.headDim := + { inputs' with + active := cert.active + prev := cert.prev } + let intervalError? ← + timePure + "head output interval" + (fun () => + intervalErrorFromHead? + inputsAligned + residualCert') + match intervalError? with + | none => + IO.eprintln + "warning: head output interval \ + rejected" + pure none + | some intervalError => + pure (some intervalError) + let intervalLB? := + intervalError?.map (fun err => + logitDiffLB - err) + let effectiveLB := + match intervalLB? with + | some intervalLB => max finalLB intervalLB + | none => finalLB let violation? : Option Rat := match effectiveMinLogitDiff with | none => none | some minLogitDiff => - if finalLB < minLogitDiff then + if effectiveLB < minLogitDiff then some minLogitDiff else none match violation? with | some minLogitDiff => IO.eprintln - s!"error: end-to-end logitDiffLB \ - {finalLB} below minimum \ + s!"error: end-to-end bound \ + {effectiveLB} below minimum \ {minLogitDiff}" return (2 : UInt32) | none => - IO.println - s!"ok: end-to-end induction \ - bound certified (seq={seq}, \ - active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstreamError}, \ - finalLB={finalLB})" + match intervalLB? 
with + | none => + IO.println + s!"ok: end-to-end induction \ + bound certified (seq={seq}, \ + active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstreamError}, \ + finalLB={finalLB})" + | some intervalLB => + let intervalError := + logitDiffLB - intervalLB + IO.println + s!"ok: end-to-end induction \ + bound certified (seq={seq}, \ + active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstreamError}, \ + finalLB={finalLB}, \ + intervalError={intervalError}, \ + intervalLB={intervalLB}, \ + effectiveLB={effectiveLB})" return 0 else IO.eprintln From bcbbafa5c5b7fa9697b0bb14c30c8c4cf5ba9388 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 17:00:12 +0100 Subject: [PATCH 161/244] Avoid stack overflow in softmax-margin parsing --- Nfp/Circuit/Cert.lean | 94 +++++++++++++++++++++++++-- Nfp/IO/Pure/SoftmaxMargin/Cert.lean | 16 ++--- Nfp/IO/Pure/SoftmaxMargin/Raw.lean | 16 ++--- Nfp/IO/Pure/SoftmaxMargin/Shared.lean | 53 +++++++-------- 4 files changed, 125 insertions(+), 54 deletions(-) diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index 8ca7959..d4d3705 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -55,20 +55,100 @@ section local instance : Std.Commutative (α := Bool) (· && ·) := ⟨Bool.and_comm⟩ local instance : Std.Associative (α := Bool) (· && ·) := ⟨Bool.and_assoc⟩ -/-- Boolean `all` over a finset. -/ -def finsetAll {β : Type v} (s : Finset β) (p : β → Bool) : Bool := - s.fold (· && ·) true p +/-- Boolean `all` over a finset (tail-recursive fold over the multiset). 
-/ +def finsetAll {β : Type v} (s : Finset β) (p : β → Bool) : Bool := by + classical + let f : Bool → β → Bool := fun acc a => acc && p a + have hf : RightCommutative f := by + refine ⟨?_⟩ + intro b a c + calc + f (f b a) c = ((b && p a) && p c) := rfl + _ = (b && (p a && p c)) := by simp [Bool.and_assoc] + _ = (b && (p c && p a)) := by simp [Bool.and_comm] + _ = ((b && p c) && p a) := by simp [Bool.and_assoc] + _ = f (f b c) a := rfl + let _ : RightCommutative f := hf + exact Multiset.foldl (f := f) (b := true) s.1 theorem finsetAll_eq_true_iff {β : Type v} {s : Finset β} {p : β → Bool} : finsetAll s p = true ↔ ∀ a ∈ s, p a = true := by classical + let f : Bool → β → Bool := fun acc a => acc && p a + have hf : RightCommutative f := by + refine ⟨?_⟩ + intro b a c + calc + f (f b a) c = ((b && p a) && p c) := rfl + _ = (b && (p a && p c)) := by simp [Bool.and_assoc] + _ = (b && (p c && p a)) := by simp [Bool.and_comm] + _ = ((b && p c) && p a) := by simp [Bool.and_assoc] + _ = f (f b c) a := rfl + let _ : RightCommutative f := hf + have hfoldl : + ∀ (s : Multiset β) (acc : Bool), + Multiset.foldl (f := f) (b := acc) s = true ↔ + acc = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + intro s acc + revert acc + refine Multiset.induction_on s ?h0 ?hcons + · intro acc + simp [Multiset.foldl_zero] + · intro a s ih acc + have ih_acc : + Multiset.foldl (f := f) (b := acc && p a) s = true ↔ + (acc && p a) = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + simpa using (ih (acc := acc && p a)) + have ih_pa : + Multiset.foldl (f := f) (b := p a) s = true ↔ + p a = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + simpa using (ih (acc := p a)) + have hgoal : + Multiset.foldl (f := f) (b := acc && p a) s = true ↔ + acc = true ∧ Multiset.foldl (f := f) (b := p a) s = true := by + constructor + · intro h + have haccpa := ih_acc.mp h + have haccpa' : acc = true ∧ p a = true := by + simpa [Bool.and_eq_true] using haccpa.1 + have hacc : acc = true := 
haccpa'.1 + have hpa : p a = true := haccpa'.2 + have hfold : Multiset.foldl (f := f) (b := p a) s = true := + ih_pa.mpr ⟨hpa, haccpa.2⟩ + exact ⟨hacc, hfold⟩ + · intro h + rcases h with ⟨hacc, hfold⟩ + have hpa := ih_pa.mp hfold + have haccpa : (acc && p a) = true := by + simpa [Bool.and_eq_true] using And.intro hacc hpa.1 + exact ih_acc.mpr ⟨haccpa, hpa.2⟩ + simpa [Multiset.foldl_cons, f] using hgoal induction s using Finset.induction_on with | empty => - simp [finsetAll] + simp [finsetAll, Multiset.foldl_zero] | @insert a s ha ih => - have hfold : finsetAll (insert a s) p = true ↔ p a = true ∧ finsetAll s p = true := by - simp [finsetAll, ha, Bool.and_eq_true] - simp [hfold, ih] + have hfold : + finsetAll (insert a s) p = true ↔ + p a = true ∧ finsetAll s p = true := by + have hval : (insert a s).1 = a ::ₘ s.1 := by + simpa using (Finset.insert_val_of_notMem (a := a) (s := s) ha) + calc + finsetAll (insert a s) p = true ↔ + Multiset.foldl (f := f) (b := true) (insert a s).1 = true := by + simp [finsetAll, f] + _ ↔ Multiset.foldl (f := f) (b := true) (a ::ₘ s.1) = true := by + simp [hval] + _ ↔ Multiset.foldl (f := f) (b := f true a) s.1 = true := by + simp [Multiset.foldl_cons] + _ ↔ Multiset.foldl (f := f) (b := p a) s.1 = true := by + simp [f] + _ ↔ p a = true ∧ Multiset.foldl (f := f) (b := true) s.1 = true := by + simpa using (hfoldl (s := s.1) (acc := p a)) + _ ↔ p a = true ∧ finsetAll s p = true := by + simp [finsetAll, f] + have hfold' : + finsetAll (insert a s) p = true ↔ p a = true ∧ finsetAll s p = true := hfold + simpa [Finset.forall_mem_insert, ih] using hfold' /-- Boolean check for interface equality. 
-/ def sameInterface (C₁ C₂ : Circuit ι α) : Bool := diff --git a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean b/Nfp/IO/Pure/SoftmaxMargin/Cert.lean index e43e378..8354fd5 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Cert.lean @@ -25,21 +25,21 @@ private def finalizeState {seq : Nat} (hpos : 0 < seq) match st.margin with | some v => pure v | none => throw "missing margin entry" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then + if !st.prev.all Option.isSome then throw "missing prev entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.scores q k).isSome)) then + if !st.scores.all (fun row => row.all Option.isSome) then throw "missing score entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.weights q k).isSome)) then + if !st.weights.all (fun row => row.all Option.isSome) then throw "missing weight entries" let defaultPrev : Fin seq := ⟨0, hpos⟩ let prevFun : Fin seq → Fin seq := fun q => - (st.prev q).getD defaultPrev + (st.prev[q.1]!).getD defaultPrev let scoresFun : Fin seq → Fin seq → Rat := fun q k => - (st.scores q k).getD 0 + let row := st.scores[q.1]! + (row[k.1]!).getD 0 let weightsFun : Fin seq → Fin seq → Rat := fun q k => - (st.weights q k).getD 0 + let row := st.weights[q.1]! 
+ (row[k.1]!).getD 0 let active := if st.activeSeen then st.active diff --git a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean b/Nfp/IO/Pure/SoftmaxMargin/Raw.lean index 35d787f..7988511 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Raw.lean @@ -28,21 +28,21 @@ structure SoftmaxMarginRaw (seq : Nat) where private def finalizeRawState {seq : Nat} (hpos : 0 < seq) (st : SoftmaxMargin.ParseState seq) : Except String (SoftmaxMarginRaw seq) := do - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => (st.prev q).isSome) then + if !st.prev.all Option.isSome then throw "missing prev entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.scores q k).isSome)) then + if !st.scores.all (fun row => row.all Option.isSome) then throw "missing score entries" - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun q => - finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.weights q k).isSome)) then + if !st.weights.all (fun row => row.all Option.isSome) then throw "missing weight entries" let defaultPrev : Fin seq := ⟨0, hpos⟩ let prevFun : Fin seq → Fin seq := fun q => - (st.prev q).getD defaultPrev + (st.prev[q.1]!).getD defaultPrev let scoresFun : Fin seq → Fin seq → Rat := fun q k => - (st.scores q k).getD 0 + let row := st.scores[q.1]! + (row[k.1]!).getD 0 let weightsFun : Fin seq → Fin seq → Rat := fun q k => - (st.weights q k).getD 0 + let row := st.weights[q.1]! + (row[k.1]!).getD 0 let active := if st.activeSeen then st.active diff --git a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean index 2939c0a..ee421bd 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean @@ -26,37 +26,33 @@ structure ParseState (seq : Nat) where /-- Whether any active entries were parsed. -/ activeSeen : Bool /-- Optional predecessor pointer per query. 
-/ - prev : Fin seq → Option (Fin seq) + prev : Array (Option (Fin seq)) /-- Optional score matrix entries. -/ - scores : Fin seq → Fin seq → Option Rat + scores : Array (Array (Option Rat)) /-- Optional weight matrix entries. -/ - weights : Fin seq → Fin seq → Option Rat + weights : Array (Array (Option Rat)) /-- Initialize a softmax-margin parse state. -/ def initState (seq : Nat) : ParseState seq := + let row : Array (Option Rat) := Array.replicate seq none { eps := none margin := none active := ∅ activeSeen := false - prev := fun _ => none - scores := fun _ _ => none - weights := fun _ _ => none } + prev := Array.replicate seq none + scores := Array.replicate seq row + weights := Array.replicate seq row } /-- Set a predecessor entry from `(q, k)` tokens. -/ def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : Except String (ParseState seq) := do - if hq : q < seq then + if q < seq then if hk : k < seq then - let qFin : Fin seq := ⟨q, hq⟩ let kFin : Fin seq := ⟨k, hk⟩ - match st.prev qFin with + match st.prev[q]! with | some _ => throw s!"duplicate prev entry for q={q}" | none => - let prev' : Fin seq → Option (Fin seq) := fun q' => - if q' = qFin then - some kFin - else - st.prev q' + let prev' := st.prev.set! q (some kFin) return { st with prev := prev' } else throw s!"prev index out of range: k={k}" @@ -75,24 +71,17 @@ def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : Except String (Parse throw s!"active index out of range: q={q}" /-- Insert a matrix entry for scores/weights. 
-/ -def setMatrixEntry {seq : Nat} (mat : Fin seq → Fin seq → Option Rat) - (q k : Nat) (v : Rat) : Except String (Fin seq → Fin seq → Option Rat) := do - if hq : q < seq then - if hk : k < seq then - let qFin : Fin seq := ⟨q, hq⟩ - let kFin : Fin seq := ⟨k, hk⟩ - match mat qFin kFin with +def setMatrixEntry {seq : Nat} (mat : Array (Array (Option Rat))) + (q k : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do + if q < seq then + if k < seq then + let row := mat[q]! + match row[k]! with | some _ => throw s!"duplicate matrix entry at ({q}, {k})" | none => - let mat' : Fin seq → Fin seq → Option Rat := fun q' k' => - if q' = qFin then - if k' = kFin then - some v - else - mat q' k' - else - mat q' k' + let row' := row.set! k (some v) + let mat' := mat.set! q row' return mat' else throw s!"index out of range: k={k}" @@ -118,10 +107,12 @@ def parseLine {seq : Nat} (st : ParseState seq) | ["prev", q, k] => setPrev st (← parseNat q) (← parseNat k) | ["score", q, k, val] => - let mat ← setMatrixEntry st.scores (← parseNat q) (← parseNat k) (← parseRat val) + let mat ← setMatrixEntry (seq := seq) st.scores (← parseNat q) (← parseNat k) + (← parseRat val) return { st with scores := mat } | ["weight", q, k, val] => - let mat ← setMatrixEntry st.weights (← parseNat q) (← parseNat k) (← parseRat val) + let mat ← setMatrixEntry (seq := seq) st.weights (← parseNat q) (← parseNat k) + (← parseRat val) return { st with weights := mat } | _ => throw s!"unrecognized line: '{String.intercalate " " tokens}'" From cf9e238da81a12a8a27764cca56a17383ad2b5f9 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 18:07:58 +0100 Subject: [PATCH 162/244] Remove module map references --- AGENTS.md | 9 +- MODULE_MAP.md | 224 -------------------------------------------------- 2 files changed, 1 insertion(+), 232 deletions(-) delete mode 100644 MODULE_MAP.md diff --git a/AGENTS.md b/AGENTS.md index 9b257ed..1c26020 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -103,7 
+103,7 @@ prefer the **clean redesign**, but do it consciously and document the rationale. ## 3. Workflow Expectations (How to make changes) ### 3.1 Before coding -- Identify the right module (see `MODULE_MAP.md`). +- Identify the right module. - Skim the top docstring / main definitions in that module. - Look for existing lemmas and naming patterns to match. @@ -172,12 +172,6 @@ prefer the **clean redesign**, but do it consciously and document the rationale. --- -## 5. Module Map (Where Things Live) - -The module map lives in `MODULE_MAP.md`. - ---- - ## 6. Axioms & Trust Boundary This repo treats “axioms creep” as a serious regression. @@ -202,7 +196,6 @@ This repo treats “axioms creep” as a serious regression. - [ ] New nontrivial definitions/theorems have short, accurate docstrings. - [ ] Core invariants (nonnegativity, normalization, finiteness, acyclicity) are preserved and, where possible, explicitly proved. -- [ ] Module map in `MODULE_MAP.md` is accurate (updated in the same commit if needed). - [ ] If CLI behavior changed: `lake build nfp --wfail` succeeds and basic `nfp ... --help` works. ## Landing the Plane (Session Completion) diff --git a/MODULE_MAP.md b/MODULE_MAP.md deleted file mode 100644 index 9f07346..0000000 --- a/MODULE_MAP.md +++ /dev/null @@ -1,224 +0,0 @@ -# Module Map (Where Things Live) - -This is a *map*, not a prison. You may reshuffle if a better design emerges, -but you **must** update this list in the same commit. - -## Core types -- `Nfp/Core/Basic.lean` - - `Mass` alias for nonnegative weights used throughout the rewrite. -- `Nfp/Core.lean` - - Aggregator for core shared definitions. - -## Probability vectors -- `Nfp/Prob/Basic.lean` - - `ProbVec` definition + invariants. -- `Nfp/Prob/Operations.lean` - - `pure`, `mix`, and basic lemmas. -- `Nfp/Prob.lean` - - Aggregator for probability modules. - -## Mixers -- `Nfp/Mixer/Basic.lean` - - `Mixer` structure and row-stochastic invariant. 
-- `Nfp/Mixer/Operations.lean` - - `push`, `comp`, and `id` mixers. -- `Nfp/Mixer.lean` - - Aggregator for mixer modules. - -## Systems (DAG + local mixing) -- `Nfp/System/Dag.lean` - - DAG relation + parent/child sets. -- `Nfp/System/LocalSystem.lean` - - `LocalSystem` with edge support, row-stochastic predicate, and evaluation semantics. -- `Nfp/System.lean` - - Aggregator for system modules. - -## Circuits (certification core) -- `Nfp/Circuit/Basic.lean` - - DAG-based circuit structure with inputs/outputs and gate semantics. -- `Nfp/Circuit/Combinators.lean` - - Core circuit combinators (relabeling, interface transport). -- `Nfp/Circuit/Interface.lean` - - Typed input/output interfaces and interface-based evaluation. -- `Nfp/Circuit/Semantics.lean` - - Well-founded evaluation semantics for circuits. -- `Nfp/Circuit/WellFormed.lean` - - Basic well-formedness conditions for circuit inputs. -- `Nfp/Circuit/Cert.lean` - - Equivalence definition and finite checker. -- `Nfp/Circuit/Cert/SoftmaxMargin.lean` - - Softmax-margin certificate payloads and checker soundness. -- `Nfp/Circuit/Cert/ValueRange.lean` - - Value-range certificate payloads and checker soundness. -- `Nfp/Circuit/Cert/LogitDiff.lean` - - Logit-diff lower-bound computation for induction certificates. -- `Nfp/Circuit/Cert/DownstreamLinear.lean` - - Downstream linear error certificates for end-to-end induction bounds. -- `Nfp/Circuit/Cert/ResidualBound.lean` - - Residual-stream bound certificates for downstream error computation. -- `Nfp/Circuit/Cert/ResidualInterval.lean` - - Residual-stream interval certificates for downstream dot-product bounds. -- `Nfp/Circuit/Typed.lean` - - Typed circuit wrapper and interface-level equivalence checker. -- `Nfp/Circuit/Compose.lean` - - Sequential composition and residual wiring for typed circuits. -- `Nfp/Circuit/Gates/Basic.lean` - - Basic gate combinators for aggregating parent values. 
-- `Nfp/Circuit/Gates/Linear.lean` - - Linear and affine gate combinators built from `Matrix.mulVec`. -- `Nfp/Circuit/Gates.lean` - - Aggregator for gate combinator modules. -- `Nfp/Circuit/Tensor.lean` - - Typed tensor indices and tensor aliases. -- `Nfp/Circuit/Layers/Linear.lean` - - Linear/affine layer circuits with typed interfaces. -- `Nfp/Circuit/Layers/Tensor.lean` - - Batched linear/affine layer circuits for tensor-shaped data. -- `Nfp/Circuit/Layers/Reshape.lean` - - Reshape combinators for product-typed circuit interfaces. -- `Nfp/Circuit/Layers/Heads.lean` - - Head split/merge combinators for transformer-shaped indices. -- `Nfp/Circuit/Layers/Softmax.lean` - - Softmax helpers and margin-based bounds for layer reasoning. -- `Nfp/Circuit/Layers/Attention.lean` - - Q/K/V, output projection wiring, and attention score/mixing core. -- `Nfp/Circuit/Layers/Induction.lean` - - Induction-head weight specs and attention-core output lemmas. -- `Nfp/Circuit/Layers/TransformerBlock.lean` - - GPT-style transformer block wiring from LN/attention/MLP circuits. -- `Nfp/Circuit/Layers.lean` - - Aggregator for circuit layer modules. -- `Nfp/Circuit.lean` - - Aggregator for circuit modules. - -## CLI surface -- `Nfp/IO/Pure.lean` - - Aggregator for pure parsing helpers. -- `Nfp/IO/Pure/Basic.lean` - - Shared parsing helpers (`Nat`/`Int`/`Rat`, token cleanup). -- `Nfp/IO/Pure/InductionHead.lean` - - Induction-head input payload parsing from text/bytes. -- `Nfp/IO/Pure/InductionHead/Bytes.lean` - - Byte-level parser for induction-head input payloads. -- `Nfp/IO/Pure/SoftmaxMargin.lean` - - Aggregator for softmax-margin parsing helpers. -- `Nfp/IO/Pure/SoftmaxMargin/Shared.lean` - - Shared parsing helpers for softmax-margin payloads. -- `Nfp/IO/Pure/SoftmaxMargin/Cert.lean` - - Softmax-margin certificate parser. -- `Nfp/IO/Pure/SoftmaxMargin/Raw.lean` - - Softmax-margin raw-input parser. -- `Nfp/IO/Pure/ValueRange.lean` - - Aggregator for value-range parsing helpers. 
-- `Nfp/IO/Pure/ValueRange/Shared.lean` - - Shared parsing helpers for value-range payloads. -- `Nfp/IO/Pure/ValueRange/Cert.lean` - - Value-range certificate parser. -- `Nfp/IO/Pure/ValueRange/Raw.lean` - - Value-range raw-input parser. -- `Nfp/IO/Pure/Downstream.lean` - - Downstream linear and matrix payload parsers. -- `Nfp/IO/Pure/Residual.lean` - - Residual-bound and residual-interval payload parsers. -- `Nfp/IO/NfptPure.lean` - - Pure parsing helpers for `NFP_BINARY_V1` model slices. -- `Nfp/IO/HeadScore.lean` - - Pure task-based cache builder for head score dot-abs bounds. -- `Nfp/IO/Loaders.lean` - - IO loaders for certificates and raw inputs. -- `Nfp/IO/Checks.lean` - - IO checks for certificate validity. -- `Nfp/IO/Derive.lean` - - IO derivations building certificates from model binaries. -- `Nfp/IO/Timing.lean` - - IO timing helpers with microsecond reporting and phase wrappers. -- `Nfp/IO/Util.lean` - - Small CLI parsing utilities shared across IO entrypoints. -- `Nfp/IO/InductionHead.lean` - - Induction-head IO pipeline with timing instrumentation. -- `Nfp/IO/Bench/Rational.lean` - - Microbenchmarks for rational arithmetic and caching. -- `Nfp/IO/Bench/InductionCore.lean` - - Benchmark helpers for induction-head core certification. -- `Nfp/IO/Bench/InductionCounts.lean` - - Call-count instrumentation for induction-head computations. -- `Nfp/IO.lean` - - IO-only wrappers for loading inputs and running checks. -- `Nfp/Cli.lean` - - CLI commands and `main` implementation. -- `Main.lean` - - Thin entrypoint delegating to `Nfp.Cli.main`. - - Benchmark entrypoint for rational microbenchmarks. -- `Nfp.lean` - - Top-level reexports. -- `TheoremAxioms.lean` - - Axiom dashboard for `theorem-axioms` build target (`#print axioms`). - -## Sound certification -- `Nfp/Sound/Induction.lean` - - Aggregator for induction soundness modules. -- `Nfp/Sound/Induction/Core.lean` - - Sound builders and core proofs for induction certificates from exact inputs. 
-- `Nfp/Sound/Induction/CoreSound.lean` - - Soundness proof for `buildInductionCertFromHeadCore?`. -- `Nfp/Sound/Induction/CoreSound/Values.lean` - - Helper lemmas for value-direction projections in the core soundness proof. -- `Nfp/Sound/Induction/CoreDefs.lean` - - Core definitions and soundness predicates for induction certificates. -- `Nfp/Sound/Induction/EndToEnd.lean` - - End-to-end induction bounds combining head certificates with transformer-stack intervals. -- `Nfp/Sound/Induction/HeadOutput.lean` - - Head-output interval certificates built from induction head inputs. -- `Nfp/Sound/Induction/HeadBounds.lean` - - Helper bounds used to stage head-induction certificate construction. -- `Nfp/Sound/Induction/LogitDiff.lean` - - Logit-diff bounds derived from induction certificates. -- `Nfp/Sound/Induction/OneHot.lean` - - Per-query one-hot bounds derived from score margins. -- `Nfp/Sound/Bounds/Cache.lean` - - Cached bound evaluators (thunk/task backed) for interval computations. -- `Nfp/Sound/Bounds/MatrixNorm.lean` - - Row-sum matrix norms and downstream linear certificate builders. -- `Nfp/Sound/Bounds/MatrixNorm/Interval.lean` - - Dot-product and matrix-vector interval bounds (rational and real). -- `Nfp/Sound/Bounds/LayerNorm.lean` - - LayerNorm interval bounds and end-to-end soundness lemmas. -- `Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean` - - Mean/variance helpers for LayerNorm bounds. -- `Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean` - - Square-root bounds (rational + real) used by LayerNorm and invStd bounds. -- `Nfp/Sound/Bounds/LayerNorm/InvStd.lean` - - Inverse-standard-deviation bounds for LayerNorm. -- `Nfp/Sound/Bounds/UnnormRat.lean` - - Unnormalized rational helpers for deferred normalization in bounds kernels. -- `Nfp/Sound/Bounds/Gelu.lean` - - Tanh-GELU bounds for interval propagation through MLPs. -- `Nfp/Sound/Bounds/Mlp.lean` - - Interval bounds for GPT-2 MLP blocks and LayerNorm composition. 
-- `Nfp/Sound/Bounds/Attention.lean` - - Interval bounds for multi-head attention and transformer layers. -- `Nfp/Sound/Bounds/Transformer.lean` - - Interval bounds for transformer stacks and final LayerNorm outputs. -- `Nfp/Sound/Bounds/Transformer/Embedding.lean` - - Embedding interval bounds and position-restricted bounds. -- `Nfp/Sound/Linear/FinFold.lean` - - Tail-recursive folds and sums for sound linear computations. -- `Nfp/Sound/Gpt2/HeadInputs.lean` - - Sound construction of GPT-2 induction head inputs. -- `Nfp/Sound.lean` - - Aggregator for sound certification modules. - -## Model inputs -- `Nfp/Model/InductionHead.lean` - - Exact induction-head input payloads (embeddings and projection weights). -- `Nfp/Model/InductionPrompt.lean` - - Prompt utilities (`prev` map and active set for periodic prompts). -- `Nfp/Model/Gpt2.lean` - - Exact GPT-2 head-slice data, layer/MLP/LayerNorm parameters, and embedding helpers. -- `Nfp/Model.lean` - - Aggregator for model input modules. - -If you introduce a new conceptual layer: -- either extend the closest existing file, -- or add a new module with a clear name + top docstring, -- and update this map in the same commit. 
From 865875dab878ce44d172165ada7a5b9dc3a29e0a Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 18:31:47 +0100 Subject: [PATCH 163/244] Refactor module structure and remove bench tooling --- Nfp/Circuit.lean | 6 - Nfp/Circuit/Cert.lean | 218 +------ Nfp/Circuit/Cert/Basic.lean | 215 +++++++ Nfp/Circuit/Cert/DownstreamLinear.lean | 2 +- Nfp/Circuit/Cert/ResidualBound.lean | 2 +- Nfp/Circuit/Cert/ResidualInterval.lean | 2 +- Nfp/Circuit/Cert/SoftmaxMargin.lean | 2 +- Nfp/Circuit/Cert/ValueRange.lean | 2 +- Nfp/Circuit/Typed.lean | 2 +- Nfp/IO.lean | 800 +----------------------- Nfp/IO/Bench/InductionCore.lean | 229 ------- Nfp/IO/Bench/InductionCounts.lean | 72 --- Nfp/IO/Bench/Rational.lean | 362 ----------- Nfp/IO/InductionHead.lean | 4 - Nfp/IO/Run.lean | 805 +++++++++++++++++++++++++ Nfp/IO/Timing.lean | 17 +- Nfp/Sound.lean | 8 +- Nfp/Sound/Bounds.lean | 15 + Nfp/Sound/Induction.lean | 6 +- lakefile.toml | 12 - 20 files changed, 1064 insertions(+), 1717 deletions(-) create mode 100644 Nfp/Circuit/Cert/Basic.lean delete mode 100644 Nfp/IO/Bench/InductionCore.lean delete mode 100644 Nfp/IO/Bench/InductionCounts.lean delete mode 100644 Nfp/IO/Bench/Rational.lean create mode 100644 Nfp/IO/Run.lean create mode 100644 Nfp/Sound/Bounds.lean diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index dc663ad..51917ea 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -6,12 +6,6 @@ import Nfp.Circuit.Interface import Nfp.Circuit.Semantics import Nfp.Circuit.WellFormed import Nfp.Circuit.Cert -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.Circuit.Cert.ValueRange -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.ResidualBound -import Nfp.Circuit.Cert.ResidualInterval import Nfp.Circuit.Typed import Nfp.Circuit.Compose import Nfp.Circuit.Gates diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index d4d3705..7e21afd 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -1,215 +1,13 
@@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Fold -import Mathlib.Data.Finset.Insert -import Mathlib.Data.Fintype.Pi -import Nfp.Circuit.Interface -import Nfp.Circuit.Semantics +import Nfp.Circuit.Cert.Basic +import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Circuit.Cert.ResidualBound +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.Circuit.Cert.ValueRange /-! -Circuit equivalence and a finite checker. +Certificate definitions and checkers for circuits. -/ - -namespace Nfp - -universe u v u' u_in u_out - -namespace Circuit - -variable {ι : Type u} [Fintype ι] [DecidableEq ι] -variable {α : Type v} - -/-- Circuits share the same input/output interface. -/ -def SameInterface (C₁ C₂ : Circuit ι α) : Prop := - C₁.inputs = C₂.inputs ∧ C₁.outputs = C₂.outputs - -/-- `SameInterface` is decidable. -/ -instance (C₁ C₂ : Circuit ι α) : Decidable (SameInterface C₁ C₂) := by - dsimp [SameInterface] - infer_instance - -/-- Circuits agree on outputs for all input assignments on a fixed interface. -/ -def EquivOn (C₁ C₂ : Circuit ι α) (h : SameInterface C₁ C₂) : Prop := - ∀ input : C₁.InputAssignment, ∀ i ∈ C₁.outputs, - evalInput C₁ input i = evalInput C₂ (InputAssignment.cast h.1 input) i - -/-- Circuits are equivalent if they share an interface and agree on all inputs. -/ -def Equiv (C₁ C₂ : Circuit ι α) : Prop := - ∃ h : SameInterface C₁ C₂, EquivOn C₁ C₂ h - -section Interface - -variable {ι₁ : Type u} [Fintype ι₁] [DecidableEq ι₁] -variable {ι₂ : Type u'} [Fintype ι₂] [DecidableEq ι₂] -variable {ι_in : Type u_in} {ι_out : Type u_out} - -/-- Circuits agree on outputs for all typed inputs on a shared interface. 
-/ -def EquivOnInterface (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) - (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) : Prop := - ∀ input : ι_in → α, ∀ o : ι_out, I₁.eval input o = I₂.eval input o - -end Interface - -section - -local instance : Std.Commutative (α := Bool) (· && ·) := ⟨Bool.and_comm⟩ -local instance : Std.Associative (α := Bool) (· && ·) := ⟨Bool.and_assoc⟩ - -/-- Boolean `all` over a finset (tail-recursive fold over the multiset). -/ -def finsetAll {β : Type v} (s : Finset β) (p : β → Bool) : Bool := by - classical - let f : Bool → β → Bool := fun acc a => acc && p a - have hf : RightCommutative f := by - refine ⟨?_⟩ - intro b a c - calc - f (f b a) c = ((b && p a) && p c) := rfl - _ = (b && (p a && p c)) := by simp [Bool.and_assoc] - _ = (b && (p c && p a)) := by simp [Bool.and_comm] - _ = ((b && p c) && p a) := by simp [Bool.and_assoc] - _ = f (f b c) a := rfl - let _ : RightCommutative f := hf - exact Multiset.foldl (f := f) (b := true) s.1 - -theorem finsetAll_eq_true_iff {β : Type v} {s : Finset β} {p : β → Bool} : - finsetAll s p = true ↔ ∀ a ∈ s, p a = true := by - classical - let f : Bool → β → Bool := fun acc a => acc && p a - have hf : RightCommutative f := by - refine ⟨?_⟩ - intro b a c - calc - f (f b a) c = ((b && p a) && p c) := rfl - _ = (b && (p a && p c)) := by simp [Bool.and_assoc] - _ = (b && (p c && p a)) := by simp [Bool.and_comm] - _ = ((b && p c) && p a) := by simp [Bool.and_assoc] - _ = f (f b c) a := rfl - let _ : RightCommutative f := hf - have hfoldl : - ∀ (s : Multiset β) (acc : Bool), - Multiset.foldl (f := f) (b := acc) s = true ↔ - acc = true ∧ Multiset.foldl (f := f) (b := true) s = true := by - intro s acc - revert acc - refine Multiset.induction_on s ?h0 ?hcons - · intro acc - simp [Multiset.foldl_zero] - · intro a s ih acc - have ih_acc : - Multiset.foldl (f := f) (b := acc && p a) s = true ↔ - (acc && p a) = true ∧ Multiset.foldl (f := f) (b := true) s = true := by - simpa using (ih (acc := acc && p 
a)) - have ih_pa : - Multiset.foldl (f := f) (b := p a) s = true ↔ - p a = true ∧ Multiset.foldl (f := f) (b := true) s = true := by - simpa using (ih (acc := p a)) - have hgoal : - Multiset.foldl (f := f) (b := acc && p a) s = true ↔ - acc = true ∧ Multiset.foldl (f := f) (b := p a) s = true := by - constructor - · intro h - have haccpa := ih_acc.mp h - have haccpa' : acc = true ∧ p a = true := by - simpa [Bool.and_eq_true] using haccpa.1 - have hacc : acc = true := haccpa'.1 - have hpa : p a = true := haccpa'.2 - have hfold : Multiset.foldl (f := f) (b := p a) s = true := - ih_pa.mpr ⟨hpa, haccpa.2⟩ - exact ⟨hacc, hfold⟩ - · intro h - rcases h with ⟨hacc, hfold⟩ - have hpa := ih_pa.mp hfold - have haccpa : (acc && p a) = true := by - simpa [Bool.and_eq_true] using And.intro hacc hpa.1 - exact ih_acc.mpr ⟨haccpa, hpa.2⟩ - simpa [Multiset.foldl_cons, f] using hgoal - induction s using Finset.induction_on with - | empty => - simp [finsetAll, Multiset.foldl_zero] - | @insert a s ha ih => - have hfold : - finsetAll (insert a s) p = true ↔ - p a = true ∧ finsetAll s p = true := by - have hval : (insert a s).1 = a ::ₘ s.1 := by - simpa using (Finset.insert_val_of_notMem (a := a) (s := s) ha) - calc - finsetAll (insert a s) p = true ↔ - Multiset.foldl (f := f) (b := true) (insert a s).1 = true := by - simp [finsetAll, f] - _ ↔ Multiset.foldl (f := f) (b := true) (a ::ₘ s.1) = true := by - simp [hval] - _ ↔ Multiset.foldl (f := f) (b := f true a) s.1 = true := by - simp [Multiset.foldl_cons] - _ ↔ Multiset.foldl (f := f) (b := p a) s.1 = true := by - simp [f] - _ ↔ p a = true ∧ Multiset.foldl (f := f) (b := true) s.1 = true := by - simpa using (hfoldl (s := s.1) (acc := p a)) - _ ↔ p a = true ∧ finsetAll s p = true := by - simp [finsetAll, f] - have hfold' : - finsetAll (insert a s) p = true ↔ p a = true ∧ finsetAll s p = true := hfold - simpa [Finset.forall_mem_insert, ih] using hfold' - -/-- Boolean check for interface equality. 
-/ -def sameInterface (C₁ C₂ : Circuit ι α) : Bool := - decide (C₁.inputs = C₂.inputs) && decide (C₁.outputs = C₂.outputs) - -theorem sameInterface_eq_true_iff (C₁ C₂ : Circuit ι α) : - sameInterface C₁ C₂ = true ↔ SameInterface C₁ C₂ := by - simp [sameInterface, SameInterface, Bool.and_eq_true] - -/-- Decide equivalence by enumerating all input assignments on a finite value type. -/ -def checkEquiv (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : Bool := - if h : SameInterface C₁ C₂ then - finsetAll (Finset.univ : Finset C₁.InputAssignment) (fun input => - finsetAll C₁.outputs (fun i => - decide (evalInput C₁ input i = evalInput C₂ (InputAssignment.cast h.1 input) i))) - else - false - -/-- `checkEquiv` is sound and complete for `Equiv`. -/ -theorem checkEquiv_eq_true_iff (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : - checkEquiv C₁ C₂ = true ↔ Equiv C₁ C₂ := by - classical - by_cases h : SameInterface C₁ C₂ - · have hcheck : checkEquiv C₁ C₂ = true ↔ EquivOn C₁ C₂ h := by - simp [checkEquiv, h, EquivOn, finsetAll_eq_true_iff] - constructor - · intro hc - exact ⟨h, hcheck.mp hc⟩ - · intro hEquiv - rcases hEquiv with ⟨h', hEq⟩ - have hh : h' = h := Subsingleton.elim _ _ - exact hcheck.mpr (by simpa [hh] using hEq) - · simp [checkEquiv, h, Equiv] - -end - -section InterfaceCheck - -variable {ι₁ : Type u} [Fintype ι₁] [DecidableEq ι₁] -variable {ι₂ : Type u'} [Fintype ι₂] [DecidableEq ι₂] -variable {ι_in : Type u_in} [Fintype ι_in] [DecidableEq ι_in] -variable {ι_out : Type u_out} [Fintype ι_out] - -/-- Decide interface-based equivalence by enumerating typed inputs. 
-/ -def checkEquivOnInterface (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) - (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) - [Fintype α] [DecidableEq α] : Bool := - finsetAll (Finset.univ : Finset (ι_in → α)) (fun input => - finsetAll (Finset.univ : Finset ι_out) (fun o => - decide (I₁.eval input o = I₂.eval input o))) - -/-- `checkEquivOnInterface` is sound and complete for `EquivOnInterface`. -/ -theorem checkEquivOnInterface_eq_true_iff (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) - (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) - [Fintype α] [DecidableEq α] : - checkEquivOnInterface C₁ C₂ I₁ I₂ = true ↔ EquivOnInterface C₁ C₂ I₁ I₂ := by - classical - simp [checkEquivOnInterface, EquivOnInterface, finsetAll_eq_true_iff] - -end InterfaceCheck - -end Circuit - -end Nfp diff --git a/Nfp/Circuit/Cert/Basic.lean b/Nfp/Circuit/Cert/Basic.lean new file mode 100644 index 0000000..d4d3705 --- /dev/null +++ b/Nfp/Circuit/Cert/Basic.lean @@ -0,0 +1,215 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.Finset.Fold +import Mathlib.Data.Finset.Insert +import Mathlib.Data.Fintype.Pi +import Nfp.Circuit.Interface +import Nfp.Circuit.Semantics + +/-! +Circuit equivalence and a finite checker. +-/ + +namespace Nfp + +universe u v u' u_in u_out + +namespace Circuit + +variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {α : Type v} + +/-- Circuits share the same input/output interface. -/ +def SameInterface (C₁ C₂ : Circuit ι α) : Prop := + C₁.inputs = C₂.inputs ∧ C₁.outputs = C₂.outputs + +/-- `SameInterface` is decidable. -/ +instance (C₁ C₂ : Circuit ι α) : Decidable (SameInterface C₁ C₂) := by + dsimp [SameInterface] + infer_instance + +/-- Circuits agree on outputs for all input assignments on a fixed interface. 
-/ +def EquivOn (C₁ C₂ : Circuit ι α) (h : SameInterface C₁ C₂) : Prop := + ∀ input : C₁.InputAssignment, ∀ i ∈ C₁.outputs, + evalInput C₁ input i = evalInput C₂ (InputAssignment.cast h.1 input) i + +/-- Circuits are equivalent if they share an interface and agree on all inputs. -/ +def Equiv (C₁ C₂ : Circuit ι α) : Prop := + ∃ h : SameInterface C₁ C₂, EquivOn C₁ C₂ h + +section Interface + +variable {ι₁ : Type u} [Fintype ι₁] [DecidableEq ι₁] +variable {ι₂ : Type u'} [Fintype ι₂] [DecidableEq ι₂] +variable {ι_in : Type u_in} {ι_out : Type u_out} + +/-- Circuits agree on outputs for all typed inputs on a shared interface. -/ +def EquivOnInterface (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) : Prop := + ∀ input : ι_in → α, ∀ o : ι_out, I₁.eval input o = I₂.eval input o + +end Interface + +section + +local instance : Std.Commutative (α := Bool) (· && ·) := ⟨Bool.and_comm⟩ +local instance : Std.Associative (α := Bool) (· && ·) := ⟨Bool.and_assoc⟩ + +/-- Boolean `all` over a finset (tail-recursive fold over the multiset). 
-/ +def finsetAll {β : Type v} (s : Finset β) (p : β → Bool) : Bool := by + classical + let f : Bool → β → Bool := fun acc a => acc && p a + have hf : RightCommutative f := by + refine ⟨?_⟩ + intro b a c + calc + f (f b a) c = ((b && p a) && p c) := rfl + _ = (b && (p a && p c)) := by simp [Bool.and_assoc] + _ = (b && (p c && p a)) := by simp [Bool.and_comm] + _ = ((b && p c) && p a) := by simp [Bool.and_assoc] + _ = f (f b c) a := rfl + let _ : RightCommutative f := hf + exact Multiset.foldl (f := f) (b := true) s.1 + +theorem finsetAll_eq_true_iff {β : Type v} {s : Finset β} {p : β → Bool} : + finsetAll s p = true ↔ ∀ a ∈ s, p a = true := by + classical + let f : Bool → β → Bool := fun acc a => acc && p a + have hf : RightCommutative f := by + refine ⟨?_⟩ + intro b a c + calc + f (f b a) c = ((b && p a) && p c) := rfl + _ = (b && (p a && p c)) := by simp [Bool.and_assoc] + _ = (b && (p c && p a)) := by simp [Bool.and_comm] + _ = ((b && p c) && p a) := by simp [Bool.and_assoc] + _ = f (f b c) a := rfl + let _ : RightCommutative f := hf + have hfoldl : + ∀ (s : Multiset β) (acc : Bool), + Multiset.foldl (f := f) (b := acc) s = true ↔ + acc = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + intro s acc + revert acc + refine Multiset.induction_on s ?h0 ?hcons + · intro acc + simp [Multiset.foldl_zero] + · intro a s ih acc + have ih_acc : + Multiset.foldl (f := f) (b := acc && p a) s = true ↔ + (acc && p a) = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + simpa using (ih (acc := acc && p a)) + have ih_pa : + Multiset.foldl (f := f) (b := p a) s = true ↔ + p a = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + simpa using (ih (acc := p a)) + have hgoal : + Multiset.foldl (f := f) (b := acc && p a) s = true ↔ + acc = true ∧ Multiset.foldl (f := f) (b := p a) s = true := by + constructor + · intro h + have haccpa := ih_acc.mp h + have haccpa' : acc = true ∧ p a = true := by + simpa [Bool.and_eq_true] using haccpa.1 + have hacc : acc = 
true := haccpa'.1 + have hpa : p a = true := haccpa'.2 + have hfold : Multiset.foldl (f := f) (b := p a) s = true := + ih_pa.mpr ⟨hpa, haccpa.2⟩ + exact ⟨hacc, hfold⟩ + · intro h + rcases h with ⟨hacc, hfold⟩ + have hpa := ih_pa.mp hfold + have haccpa : (acc && p a) = true := by + simpa [Bool.and_eq_true] using And.intro hacc hpa.1 + exact ih_acc.mpr ⟨haccpa, hpa.2⟩ + simpa [Multiset.foldl_cons, f] using hgoal + induction s using Finset.induction_on with + | empty => + simp [finsetAll, Multiset.foldl_zero] + | @insert a s ha ih => + have hfold : + finsetAll (insert a s) p = true ↔ + p a = true ∧ finsetAll s p = true := by + have hval : (insert a s).1 = a ::ₘ s.1 := by + simpa using (Finset.insert_val_of_notMem (a := a) (s := s) ha) + calc + finsetAll (insert a s) p = true ↔ + Multiset.foldl (f := f) (b := true) (insert a s).1 = true := by + simp [finsetAll, f] + _ ↔ Multiset.foldl (f := f) (b := true) (a ::ₘ s.1) = true := by + simp [hval] + _ ↔ Multiset.foldl (f := f) (b := f true a) s.1 = true := by + simp [Multiset.foldl_cons] + _ ↔ Multiset.foldl (f := f) (b := p a) s.1 = true := by + simp [f] + _ ↔ p a = true ∧ Multiset.foldl (f := f) (b := true) s.1 = true := by + simpa using (hfoldl (s := s.1) (acc := p a)) + _ ↔ p a = true ∧ finsetAll s p = true := by + simp [finsetAll, f] + have hfold' : + finsetAll (insert a s) p = true ↔ p a = true ∧ finsetAll s p = true := hfold + simpa [Finset.forall_mem_insert, ih] using hfold' + +/-- Boolean check for interface equality. -/ +def sameInterface (C₁ C₂ : Circuit ι α) : Bool := + decide (C₁.inputs = C₂.inputs) && decide (C₁.outputs = C₂.outputs) + +theorem sameInterface_eq_true_iff (C₁ C₂ : Circuit ι α) : + sameInterface C₁ C₂ = true ↔ SameInterface C₁ C₂ := by + simp [sameInterface, SameInterface, Bool.and_eq_true] + +/-- Decide equivalence by enumerating all input assignments on a finite value type. 
-/ +def checkEquiv (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : Bool := + if h : SameInterface C₁ C₂ then + finsetAll (Finset.univ : Finset C₁.InputAssignment) (fun input => + finsetAll C₁.outputs (fun i => + decide (evalInput C₁ input i = evalInput C₂ (InputAssignment.cast h.1 input) i))) + else + false + +/-- `checkEquiv` is sound and complete for `Equiv`. -/ +theorem checkEquiv_eq_true_iff (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : + checkEquiv C₁ C₂ = true ↔ Equiv C₁ C₂ := by + classical + by_cases h : SameInterface C₁ C₂ + · have hcheck : checkEquiv C₁ C₂ = true ↔ EquivOn C₁ C₂ h := by + simp [checkEquiv, h, EquivOn, finsetAll_eq_true_iff] + constructor + · intro hc + exact ⟨h, hcheck.mp hc⟩ + · intro hEquiv + rcases hEquiv with ⟨h', hEq⟩ + have hh : h' = h := Subsingleton.elim _ _ + exact hcheck.mpr (by simpa [hh] using hEq) + · simp [checkEquiv, h, Equiv] + +end + +section InterfaceCheck + +variable {ι₁ : Type u} [Fintype ι₁] [DecidableEq ι₁] +variable {ι₂ : Type u'} [Fintype ι₂] [DecidableEq ι₂] +variable {ι_in : Type u_in} [Fintype ι_in] [DecidableEq ι_in] +variable {ι_out : Type u_out} [Fintype ι_out] + +/-- Decide interface-based equivalence by enumerating typed inputs. -/ +def checkEquivOnInterface (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) + [Fintype α] [DecidableEq α] : Bool := + finsetAll (Finset.univ : Finset (ι_in → α)) (fun input => + finsetAll (Finset.univ : Finset ι_out) (fun o => + decide (I₁.eval input o = I₂.eval input o))) + +/-- `checkEquivOnInterface` is sound and complete for `EquivOnInterface`. 
-/ +theorem checkEquivOnInterface_eq_true_iff (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) + [Fintype α] [DecidableEq α] : + checkEquivOnInterface C₁ C₂ I₁ I₂ = true ↔ EquivOnInterface C₁ C₂ I₁ I₂ := by + classical + simp [checkEquivOnInterface, EquivOnInterface, finsetAll_eq_true_iff] + +end InterfaceCheck + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Cert/DownstreamLinear.lean b/Nfp/Circuit/Cert/DownstreamLinear.lean index 85e1d96..e612c1e 100644 --- a/Nfp/Circuit/Cert/DownstreamLinear.lean +++ b/Nfp/Circuit/Cert/DownstreamLinear.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Core.Basic -import Nfp.Circuit.Cert +import Nfp.Circuit.Cert.Basic /-! Downstream linear certificates for end-to-end induction bounds. diff --git a/Nfp/Circuit/Cert/ResidualBound.lean b/Nfp/Circuit/Cert/ResidualBound.lean index e7aa6bd..09cf83e 100644 --- a/Nfp/Circuit/Cert/ResidualBound.lean +++ b/Nfp/Circuit/Cert/ResidualBound.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Core.Basic -import Nfp.Circuit.Cert +import Nfp.Circuit.Cert.Basic /-! Residual-stream bound certificates. diff --git a/Nfp/Circuit/Cert/ResidualInterval.lean b/Nfp/Circuit/Cert/ResidualInterval.lean index aa36547..b1d74c7 100644 --- a/Nfp/Circuit/Cert/ResidualInterval.lean +++ b/Nfp/Circuit/Cert/ResidualInterval.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Core.Basic -import Nfp.Circuit.Cert +import Nfp.Circuit.Cert.Basic /-! Residual-stream interval certificates. diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean b/Nfp/Circuit/Cert/SoftmaxMargin.lean index 3dc3784..5987f6c 100644 --- a/Nfp/Circuit/Cert/SoftmaxMargin.lean +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -2,7 +2,7 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Nfp.Core.Basic -import Nfp.Circuit.Cert +import Nfp.Circuit.Cert.Basic import Nfp.Circuit.Layers.Induction /-! 
diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean index ad20a04..342f93f 100644 --- a/Nfp/Circuit/Cert/ValueRange.lean +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -2,7 +2,7 @@ import Mathlib.Algebra.BigOperators.Group.Finset.Basic import Nfp.Core.Basic -import Nfp.Circuit.Cert +import Nfp.Circuit.Cert.Basic import Nfp.Circuit.Layers.Induction /-! diff --git a/Nfp/Circuit/Typed.lean b/Nfp/Circuit/Typed.lean index 590ee4c..b45d63d 100644 --- a/Nfp/Circuit/Typed.lean +++ b/Nfp/Circuit/Typed.lean @@ -1,7 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later import Nfp.Circuit.Combinators -import Nfp.Circuit.Cert +import Nfp.Circuit.Cert.Basic /-! Typed circuit wrappers and typed equivalence checking. diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 344d7a8..a2b17d2 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -1,799 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later + import Nfp.IO.Checks import Nfp.IO.Derive -import Nfp.IO.Loaders -import Nfp.IO.NfptPure import Nfp.IO.HeadScore import Nfp.IO.InductionHead -import Nfp.IO.Util -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.ResidualBound -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Bounds.Transformer -import Nfp.Sound.Induction -import Nfp.Sound.Induction.HeadBounds -import Nfp.Sound.Induction.LogitDiff -import Nfp.Sound.Linear.FinFold +import Nfp.IO.Loaders +import Nfp.IO.NfptPure +import Nfp.IO.Run import Nfp.IO.Timing -namespace Nfp -namespace IO -open Nfp.Circuit +import Nfp.IO.Util -/-- Check induction certificates and print a short status line. -/ -def runInductionCertify (scoresPath : System.FilePath) - (valuesPath? : Option System.FilePath) (minActive? : Option Nat) - (minLogitDiffStr? : Option String) (minMarginStr? : Option String) - (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? 
- let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - if minLogitDiff?.isSome && valuesPath?.isNone then - IO.eprintln "error: min-logit-diff requires --values" - return 2 - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - match valuesPath? 
with - | none => - IO.println - s!"ok: softmax-margin certificate accepted \ - (seq={seq}, active={activeCount})" - return 0 - | some valuesPath => - let parsedValues ← loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let tol := cert.eps * (certVals'.hi - certVals'.lo) - let logitDiffLB? := - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, tol={tol}, \ - logitDiffLB={logitDiffLB})" - return 0 -/-- Build and check induction certificates from raw scores/values. -/ -def runInductionCertifySound (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (minActive? : Option Nat) - (minLogitDiffStr? : Option String) (minMarginStr? : Option String) - (maxEpsStr? 
: Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← loadSoftmaxMarginRaw scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, raw⟩ => - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildSoftmaxMarginCert? 
raw.active raw.prev raw.scores raw.weights with - | none => - IO.eprintln "error: softmax-margin inputs rejected" - return 2 - | some ⟨cert, _⟩ => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← loadValueRangeRaw valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, rawVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln - s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let rawVals' : Pure.ValueRangeRaw seq := by - simpa [hseq'] using rawVals - match Sound.buildValueRangeCert? rawVals'.vals rawVals'.direction with - | none => - IO.eprintln "error: value-range inputs rejected" - return 2 - | some ⟨certVals, _⟩ => - let tol := cert.eps * (certVals.hi - certVals.lo) - let logitDiffLB? := - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals.lo certVals.hi certVals.vals - let effectiveMinLogitDiff := - match minLogitDiff?, certVals.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? 
with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return 2 - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={tol}, logitDiffLB={logitDiffLB})" - return 0 -/-- Check end-to-end induction certificates with a downstream error bound. -/ -def runInductionCertifyEndToEnd (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (downstreamPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let logitDiffLB? 
← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let parsedDownstream ← loadDownstreamLinearCert downstreamPath - match parsedDownstream with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok downstream => - let downstreamOk := Circuit.checkDownstreamLinearCert downstream - if downstreamOk then - let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if finalLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end logitDiffLB {finalLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: end-to-end induction bound certified \ - (seq={seq}, active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstream.error}, \ - finalLB={finalLB})" - return 0 - else - IO.eprintln "error: downstream certificate rejected" - return 2 -/-- Check end-to-end induction certificates with a downstream matrix. -/ -def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (matrixPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
- match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let 
logitDiffLB? ← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let parsedMatrix ← loadDownstreamMatrixRaw matrixPath - match parsedMatrix with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨rows, ⟨cols, raw⟩⟩ => - let inputBound := raw.inputBound - if hneg : inputBound < 0 then - IO.eprintln - s!"error: input-bound {inputBound} must be nonnegative" - return 2 - else - have hinput : 0 ≤ inputBound := by - exact le_of_not_gt hneg - let W : Matrix (Fin rows) (Fin cols) Rat := raw.entries - let downstream := - (Sound.Bounds.buildDownstreamLinearCert W inputBound hinput).1 - let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if finalLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end logitDiffLB {finalLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: end-to-end induction bound certified \ - (seq={seq}, active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstream.error}, \ - finalLB={finalLB})" - return 0 -/-- Check end-to-end induction certificates using a model file and residual bounds - (loaded from disk or derived from the model). -/ -def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (modelPath : System.FilePath) - (residualIntervalPath? : Option System.FilePath) - (layer? : Option Nat) (head? : Option Nat) (period? 
: Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - 
else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let logitDiffLB? ← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - match certVals'.direction with - | none => - IO.eprintln - "error: value-range certificate missing direction \ - metadata" - return 2 - | some dirSpec => - let data ← timePhase "read model file" <| - IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - if hseq : header.seqLen = seq then - let active? : Option (Finset (Fin header.seqLen)) := - if hactive : cert.active.Nonempty then - some (by simpa [hseq] using cert.active) - else - none - let residualCertE : Except String - (ResidualIntervalCert header.modelDim) ← - match residualIntervalPath? 
with - | some residualIntervalPath => do - let parsedResidual ← - timePhase "load residual interval" <| - loadResidualIntervalCert residualIntervalPath - match parsedResidual with - | Except.error msg => pure (Except.error msg) - | Except.ok ⟨dim, residualCert⟩ => - if hdim : dim = header.modelDim then - let residualCert' : - ResidualIntervalCert header.modelDim := by - simpa [hdim] using residualCert - pure (Except.ok residualCert') - else - pure (Except.error - s!"residual interval dim {dim} \ - does not match model dim {header.modelDim}") - | none => - deriveResidualIntervalFromModel data start header - active? - match residualCertE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok residualCert' => - let residualOk ← - timePure "check residual interval" (fun () => - Circuit.checkResidualIntervalCert residualCert') - if residualOk then - let dirPos := dirSpec.target - let dirNeg := dirSpec.negative - if layer?.isSome != head?.isSome then - IO.eprintln - "error: --layer and --head must be provided \ - together" - return 2 - let headChoice? : Option (Nat × Nat) := - match layer?, head? with - | some layer, some head => some (layer, head) - | _, _ => none - if period?.isSome && headChoice?.isNone then - IO.eprintln - "warning: --period ignored without \ - --layer/--head" - let colTargetE ← - timePure "read unembed column target" (fun () => - NfptPure.readUnembedColumn - data start header dirPos) - match colTargetE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok colTarget => - let colNegE ← - timePure "read unembed column negative" (fun () => - NfptPure.readUnembedColumn - data start header dirNeg) - match colNegE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok colNeg => - let dirVec : - Fin header.modelDim → Rat := - fun i => colTarget i - colNeg i - let dotIntervalAbs := - Sound.Bounds.dotIntervalAbsBound - let intervalErrorFromHead? 
: - Model.InductionHeadInputs - seq header.modelDim header.headDim → - ResidualIntervalCert header.modelDim → - Option Rat := - fun inputs residual => by - classical - match hseq0 : seq with - | 0 => exact none - | Nat.succ n => - let _ : NeZero seq := by - exact ⟨by simp [hseq0]⟩ - match - Sound.buildHeadOutputIntervalFromHead? - inputs with - | none => exact none - | some result => - exact some - (dotIntervalAbs - dirVec - (fun i => - residual.lo i - - result.cert.hi i) - (fun i => - residual.hi i - - result.cert.lo i)) - let downstreamError ← - timePure "downstream error" (fun () => - dotIntervalAbs - dirVec - residualCert'.lo - residualCert'.hi) - let finalLB := logitDiffLB - downstreamError - let intervalError? ← - match headChoice? with - | none => pure none - | some (layer, head) => do - let inputsE ← - timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head - dirPos dirNeg period?) - match inputsE with - | Except.error msg => - IO.eprintln s!"warning: {msg}" - pure none - | Except.ok inputs => - let inputs' : - Model.InductionHeadInputs - seq header.modelDim - header.headDim := by - simpa [hseq] using inputs - let inputsAligned : - Model.InductionHeadInputs - seq header.modelDim - header.headDim := - { inputs' with - active := cert.active - prev := cert.prev } - let intervalError? ← - timePure - "head output interval" - (fun () => - intervalErrorFromHead? - inputsAligned - residualCert') - match intervalError? with - | none => - IO.eprintln - "warning: head output interval \ - rejected" - pure none - | some intervalError => - pure (some intervalError) - let intervalLB? := - intervalError?.map (fun err => - logitDiffLB - err) - let effectiveLB := - match intervalLB? with - | some intervalLB => max finalLB intervalLB - | none => finalLB - let violation? 
: Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if effectiveLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end bound \ - {effectiveLB} below minimum \ - {minLogitDiff}" - return (2 : UInt32) - | none => - match intervalLB? with - | none => - IO.println - s!"ok: end-to-end induction \ - bound certified (seq={seq}, \ - active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstreamError}, \ - finalLB={finalLB})" - | some intervalLB => - let intervalError := - logitDiffLB - intervalLB - IO.println - s!"ok: end-to-end induction \ - bound certified (seq={seq}, \ - active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstreamError}, \ - finalLB={finalLB}, \ - intervalError={intervalError}, \ - intervalLB={intervalLB}, \ - effectiveLB={effectiveLB})" - return 0 - else - IO.eprintln - "error: residual-interval certificate rejected" - return 2 - else - IO.eprintln - s!"error: model seq {header.seqLen} \ - does not match cert seq {seq}" - return 2 -end IO -end Nfp +/-! +IO-only wrappers for loading inputs and running checks. +-/ diff --git a/Nfp/IO/Bench/InductionCore.lean b/Nfp/IO/Bench/InductionCore.lean deleted file mode 100644 index 7ba8c29..0000000 --- a/Nfp/IO/Bench/InductionCore.lean +++ /dev/null @@ -1,229 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Nfp.IO.Timing -import Nfp.Sound.Induction -import Nfp.Sound.Induction.HeadBounds - -/-! -Benchmark helpers for induction-head core certification. 
--/ - -namespace Nfp - -namespace IO - -open Sound -open scoped BigOperators - -private def benchPhasePure {α : Type} (label : String) (act : Unit → α) : IO α := do - IO.println s!"bench: {label} start" - flushStdout - timePhase label (pure (act ())) - -private def forceScore {seq dModel dHead : Nat} - (score : Sound.HeadScoreBounds seq dModel dHead) : Rat := - score.margin + score.eps - -private def forceValues {seq dModel dHead : Nat} - (vals : Sound.HeadValueBounds seq dModel dHead) : Rat := - vals.lo + vals.hi - -private def forceQAbs {seq dHead : Nat} - (qAbs : Fin seq → Fin dHead → Rat) : Rat := - (Finset.univ : Finset (Fin seq)).sum (fun q => - (Finset.univ : Finset (Fin dHead)).sum (fun d => qAbs q d)) - -private def forceLn {seq dModel : Nat} - (ln : Fin seq → Fin dModel → Rat) : Rat := - (Finset.univ : Finset (Fin seq)).sum (fun q => - (Finset.univ : Finset (Fin dModel)).sum (fun i => ln q i)) - -private def forceKAbs {seq dHead : Nat} - (kAbs : Fin seq → Fin dHead → Rat) : Rat := - (Finset.univ : Finset (Fin seq)).sum (fun q => - (Finset.univ : Finset (Fin dHead)).sum (fun d => kAbs q d)) - -private def forceDotAbs {seq dHead : Nat} - (qAbs kAbs : Fin seq → Fin dHead → Rat) : Rat := - (Finset.univ : Finset (Fin seq)).sum (fun q => - (Finset.univ : Finset (Fin seq)).sum (fun k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d))) - -private def forceDotAbsTasksReduce {seq dHead : Nat} - (qAbs kAbs : Fin seq → Fin dHead → Rat) : Rat := - let tasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - (Finset.univ : Finset (Fin seq)).sum (fun k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)))) - (Finset.univ : Finset (Fin seq)).sum (fun q => - (tasks[q.1]'(by simp [tasks, q.isLt])).get) - -private def isPow2 (n : Nat) : Bool := - if n = 0 then - false - else - decide (Nat.pow 2 (Nat.log2 n) = n) - -private def isPow2Den (q : Rat) : Bool := - isPow2 q.den - -private def countPow2Den {seq dHead : Nat} - 
(qs : List (Fin seq)) (ds : List (Fin dHead)) - (f : Fin seq → Fin dHead → Rat) : Nat := - qs.foldl (fun acc q => - ds.foldl (fun acc' d => acc' + (if isPow2Den (f q d) then 1 else 0)) acc) 0 - -private def pow2DenSampleReport {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qkv : Sound.HeadQKVBounds seq dModel dHead) : String := - let qs := (List.finRange seq).take (min seq 2) - let ds := (List.finRange dHead).take (min dHead 8) - let total := qs.length * ds.length - let qLoDy := countPow2Den qs ds qkv.qLo - let qHiDy := countPow2Den qs ds qkv.qHi - let qAbsDy := countPow2Den qs ds qkv.qAbs - let kAbsDy := countPow2Den qs ds qkv.kAbs - let epsDy := if isPow2Den inputs.lnEps then 1 else 0 - s!"pow2-den sample: total={total} qLo={qLoDy} qHi={qHiDy} qAbs={qAbsDy} " ++ - s!"kAbs={kAbsDy} lnEps={epsDy}" - -private def pow2DenSanityReport : String := - let rat := ratRoundDown (Rat.divInt 1 8) - let powChecks := - s!"pow2(1)={isPow2 1} pow2(2)={isPow2 2} pow2(3)={isPow2 3} " ++ - s!"pow2(4)={isPow2 4} pow2(8)={isPow2 8}" - let ratCheck := s!"rat(1/8).den={rat.den} pow2den={isPow2Den rat}" - s!"pow2-den sanity: {powChecks} {ratCheck}" - -private def forceQRowTasks {seq dHead : Nat} - (q0 : Fin seq) (qLo : Fin seq → Fin dHead → Rat) : Int := - let tasks : Array (Task Rat) := - Array.ofFn (fun d : Fin dHead => - Task.spawn (fun _ => qLo q0 d)) - let total := - (Finset.univ : Finset (Fin dHead)).sum (fun d => - (tasks[d.1]'(by simp [tasks, d.isLt])).get) - total.num - -private def qAbsRowChunk {seq dHead : Nat} - (q0 : Fin seq) (qAbs : Fin seq → Fin dHead → Rat) (start stop : Nat) : Rat := - let chunk : Finset (Fin dHead) := - (Finset.univ : Finset (Fin dHead)).filter (fun d => start ≤ d.1 ∧ d.1 < stop) - chunk.sum (fun d => qAbs q0 d) - -private def forceQAbsRowTasksReduce {seq dHead : Nat} - (q0 : Fin seq) (qAbs : Fin seq → Fin dHead → Rat) (chunkSize : Nat) : Rat := - if chunkSize = 0 then - (0 : Rat) - else - let chunks : Nat := (dHead 
+ chunkSize - 1) / chunkSize - let tasks : Array (Task Rat) := - Array.ofFn (fun i : Fin chunks => - Task.spawn (fun _ => - let start := i.1 * chunkSize - let stop := min dHead (start + chunkSize) - qAbsRowChunk q0 qAbs start stop)) - (Finset.univ : Finset (Fin chunks)).sum (fun i => - (tasks[i.1]'(by simp [tasks, i.isLt])).get) - -private def forceQAbsAllTasksReduce {seq dHead : Nat} - (qAbs : Fin seq → Fin dHead → Rat) (chunkSize : Nat) : Rat := - let tasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => forceQAbsRowTasksReduce q qAbs chunkSize)) - (Finset.univ : Finset (Fin seq)).sum (fun q => - (tasks[q.1]'(by simp [tasks, q.isLt])).get) - -private def forceQAbsActiveTasksReduce {seq dHead : Nat} - (active : Finset (Fin seq)) (qAbs : Fin seq → Fin dHead → Rat) (chunkSize : Nat) : Rat := - if hactive : active.Nonempty then - let tasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if q ∈ active then - forceQAbsRowTasksReduce q qAbs chunkSize - else - (0 : Rat))) - active.sum (fun q => - (tasks[q.1]'(by simp [tasks, q.isLt])).get) - else - (0 : Rat) - -/-- Run a core benchmark from already-parsed head inputs. 
-/ -def runCoreBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do - let lnBounds ← benchPhasePure "ln bounds" (fun () => Sound.headLnBounds inputs) - let lnLo := lnBounds.1 - let lnHi := lnBounds.2 - let _ ← benchPhasePure "lnLo force" (fun () => forceLn lnLo) - let _ ← benchPhasePure "lnHi force" (fun () => forceLn lnHi) - let qkv ← benchPhasePure "qkv bounds" (fun () => Sound.headQKVBounds inputs lnLo lnHi) - let _ ← timePhase "pow2-den sample" (do - IO.println (pow2DenSampleReport inputs qkv) - IO.println pow2DenSanityReport - pure ()) - let _ ← benchPhasePure "qLo single" (fun () => - match h : dHead with - | 0 => (0 : Rat) - | Nat.succ _ => - let q0 : Fin seq := - ⟨0, Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ - let d0 : Fin dHead := ⟨0, by simp [h]⟩ - qkv.qLo q0 d0) - let _ ← benchPhasePure "qLo row tasks" (fun () => - let q0 : Fin seq := - ⟨0, by - exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ - forceQRowTasks q0 qkv.qLo) - let _ ← benchPhasePure "qLo row" (fun () => - let q0 : Fin seq := - ⟨0, by - exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ - let total := - (Finset.univ : Finset (Fin dHead)).sum (fun d => qkv.qLo q0 d) - total.num) - let _ ← benchPhasePure "qLo force" (fun () => forceQAbs qkv.qLo) - let _ ← benchPhasePure "qHi force" (fun () => forceQAbs qkv.qHi) - let _ ← benchPhasePure "qAbs single" (fun () => - let q0 : Fin seq := - ⟨0, by - exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ - match h : dHead with - | 0 => (0 : Rat) - | Nat.succ _ => - let d0 : Fin dHead := ⟨0, by simp [h]⟩ - qkv.qAbs q0 d0) - let _ ← benchPhasePure "qAbs row tasks reduce" (fun () => - let q0 : Fin seq := - ⟨0, by - exact Nat.pos_of_ne_zero (NeZero.ne (n := seq))⟩ - forceQAbsRowTasksReduce q0 qkv.qAbs 1) - let _ ← benchPhasePure "qAbs force tasks reduce active" (fun () => - forceQAbsActiveTasksReduce inputs.active qkv.qAbs 1) - let _ ← benchPhasePure "qAbs force tasks reduce" (fun () => - 
forceQAbsAllTasksReduce qkv.qAbs 1) - let _ ← benchPhasePure "kAbs force tasks reduce active" (fun () => - forceQAbsActiveTasksReduce inputs.active qkv.kAbs 1) - let _ ← benchPhasePure "kAbs force tasks reduce" (fun () => - forceQAbsAllTasksReduce qkv.kAbs 1) - let _ ← benchPhasePure "kAbs force tasks reduce (bench)" (fun () => - forceQAbsAllTasksReduce qkv.kAbs 1) - let _ ← benchPhasePure "dotAbs force tasks reduce" (fun () => - forceDotAbsTasksReduce qkv.qAbs qkv.kAbs) - let _ ← benchPhasePure "dotAbs force" (fun () => forceDotAbs qkv.qAbs qkv.kAbs) - let score ← benchPhasePure "score bounds" (fun () => - Sound.headScoreBounds inputs qkv.qAbs qkv.kAbs) - let _ ← benchPhasePure "score force" (fun () => forceScore score) - let vals ← benchPhasePure "value bounds" (fun () => - Sound.headValueBounds inputs qkv.vLo qkv.vHi) - let _ ← benchPhasePure "value force" (fun () => forceValues vals) - let cert ← benchPhasePure "core cert" (fun () => - Sound.buildInductionCertFromHeadCore? inputs) - match cert with - | none => IO.println "bench: core cert none" - | some _ => IO.println "bench: core cert some" - -end IO - -end Nfp diff --git a/Nfp/IO/Bench/InductionCounts.lean b/Nfp/IO/Bench/InductionCounts.lean deleted file mode 100644 index 3571f18..0000000 --- a/Nfp/IO/Bench/InductionCounts.lean +++ /dev/null @@ -1,72 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std.Data.HashMap -import Nfp.Model.InductionHead - -/-! -Call-count instrumentation for induction-head computations. - -This is a placeholder-only benchmark: it records how often key functions would be -called in a score-bound pass without performing heavy arithmetic. --/ - -namespace Nfp - -namespace IO - -open scoped BigOperators - -private def bumpCount (ref : IO.Ref (Std.HashMap String Nat)) (key : String) (n : Nat) : - IO Unit := do - ref.modify (fun m => - let cur := (m.get? 
key).getD 0 - m.insert key (cur + n)) - -private def printCounts (ref : IO.Ref (Std.HashMap String Nat)) : IO Unit := do - let m ← ref.get - let entries := m.toList - IO.println "counts:" - for (k, v) in entries do - IO.println s!" {k}: {v}" - -private def countScoreCalls {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (ref : IO.Ref (Std.HashMap String Nat)) : IO Unit := do - let activeCount := inputs.active.card - let otherCount := seq - 1 - let rowCount := activeCount * otherCount - let elemCount := rowCount * dHead - bumpCount ref "scoreBounds:scoreLo" rowCount - bumpCount ref "scoreBounds:scoreHi" rowCount - bumpCount ref "scoreBounds:qAbs" elemCount - bumpCount ref "scoreBounds:qLo" elemCount - bumpCount ref "scoreBounds:qHi" elemCount - bumpCount ref "scoreBounds:kAbs" elemCount - bumpCount ref "scoreBounds:kLo" elemCount - bumpCount ref "scoreBounds:kHi" elemCount - -private def countQKVCalls {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (ref : IO.Ref (Std.HashMap String Nat)) : IO Unit := do - let activeCount := inputs.active.card - let elemCount := activeCount * dHead - bumpCount ref "qkvBounds:qLo" elemCount - bumpCount ref "qkvBounds:qHi" elemCount - bumpCount ref "qkvBounds:kLo" elemCount - bumpCount ref "qkvBounds:kHi" elemCount - bumpCount ref "qkvBounds:vLo" elemCount - bumpCount ref "qkvBounds:vHi" elemCount - bumpCount ref "qkvBounds:qAbs" elemCount - bumpCount ref "qkvBounds:kAbs" elemCount - -/-- Count calls used by score/QKV bounds on the active set. 
-/ -def countInductionCalls {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do - let ref ← IO.mkRef (∅ : Std.HashMap String Nat) - countQKVCalls inputs ref - countScoreCalls inputs ref - printCounts ref - -end IO - -end Nfp diff --git a/Nfp/IO/Bench/Rational.lean b/Nfp/IO/Bench/Rational.lean deleted file mode 100644 index 657d56a..0000000 --- a/Nfp/IO/Bench/Rational.lean +++ /dev/null @@ -1,362 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.List.Range -import Nfp.IO.Timing -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Linear.FinFold - -/-! -Microbenchmarks for rational arithmetic and caching strategies. --/ - -namespace Nfp - -namespace IO - -open Sound - -private def benchItersFor (base : Nat) (n : Nat) : Nat := - let scale := max 1 (n / 64) - max 1 (base / scale) - -private def mkRat (num den : Nat) (neg : Bool) : Rat := - let n : Int := Int.ofNat (num + 1) - let d : Int := Int.ofNat (den + 1) - let q : Rat := Rat.divInt (if neg then -n else n) d - ratRoundDown q - -private def mkVecRat (n : Nat) (seed : Nat) (salt : Nat) (negEvery : Nat) : Fin n → Rat := fun i => - let idx := i.1 + seed + salt - let neg := (idx % negEvery) = 0 - mkRat (idx % 97) (idx % 89) neg - -private def mkInterval (n : Nat) (seed : Nat) : - (Fin n → Rat) × (Fin n → Rat) × (Fin n → Rat) := - let v : Fin n → Rat := mkVecRat n seed 0 2 - let base : Fin n → Rat := mkVecRat n seed 13 3 - let lo : Fin n → Rat := fun i => base i - 1 - let hi : Fin n → Rat := fun i => base i + 1 - (v, lo, hi) - -private def benchLoop (label : String) (iters : Nat) (act : Unit → Rat) : IO Unit := do - let t0 ← monoUsNow - let mut last : Rat := 0 - for _ in List.range iters do - last := act () - let t1 ← monoUsNow - let total := t1 - t0 - let avg := total / max 1 iters - IO.println s!"bench: {label} iters={iters} total={total} us avg={avg} us last={last}" - -private def benchDotInterval (n iters seed : Nat) : IO Unit := do - let (v, 
lo, hi) := mkInterval n seed - let labelBase := s!"n={n}" - benchLoop s!"dotIntervalLower {labelBase}" iters (fun () => - Sound.Bounds.dotIntervalLower v lo hi) - benchLoop s!"dotIntervalLowerCommonDen {labelBase}" iters (fun () => - Sound.Bounds.dotIntervalLowerCommonDen v lo hi) - benchLoop s!"dotIntervalLowerCachedRat {labelBase}" iters (fun () => - Sound.Bounds.dotIntervalLowerCachedRat v lo hi) - benchLoop s!"dotIntervalUpper {labelBase}" iters (fun () => - Sound.Bounds.dotIntervalUpper v lo hi) - benchLoop s!"dotIntervalUpperCommonDen {labelBase}" iters (fun () => - Sound.Bounds.dotIntervalUpperCommonDen v lo hi) - benchLoop s!"dotIntervalUpperCachedRat {labelBase}" iters (fun () => - Sound.Bounds.dotIntervalUpperCachedRat v lo hi) - -private def dotIntervalLowerCachedCore {n : Nat} - (vArr loArr hiArr : Array Rat) - (hv : vArr.size = n) (hlo : loArr.size = n) (hhi : hiArr.size = n) : Rat := - let term : Fin n → Rat := fun j => - let vj := vArr[j.1]'(by - simp [hv, j.isLt]) - let loj := loArr[j.1]'(by - simp [hlo, j.isLt]) - let hij := hiArr[j.1]'(by - simp [hhi, j.isLt]) - if 0 ≤ vj then - vj * loj - else - vj * hij - Sound.Linear.sumFin n term - -private def dotIntervalUpperCachedCore {n : Nat} - (vArr loArr hiArr : Array Rat) - (hv : vArr.size = n) (hlo : loArr.size = n) (hhi : hiArr.size = n) : Rat := - let term : Fin n → Rat := fun j => - let vj := vArr[j.1]'(by - simp [hv, j.isLt]) - let loj := loArr[j.1]'(by - simp [hlo, j.isLt]) - let hij := hiArr[j.1]'(by - simp [hhi, j.isLt]) - if 0 ≤ vj then - vj * hij - else - vj * loj - Sound.Linear.sumFin n term - -private def benchDotIntervalCachedParts (n iters seed : Nat) : IO Unit := do - let (v, lo, hi) := mkInterval n seed - let vArr := Array.ofFn v - let loArr := Array.ofFn lo - let hiArr := Array.ofFn hi - have hv : vArr.size = n := by simp [vArr] - have hlo : loArr.size = n := by simp [loArr] - have hhi : hiArr.size = n := by simp [hiArr] - let labelBase := s!"n={n}" - benchLoop 
s!"dotIntervalLowerCachedRat arrays {labelBase}" iters (fun () => - let vArr' := Array.ofFn v - let loArr' := Array.ofFn lo - let hiArr' := Array.ofFn hi - vArr'.size + loArr'.size + hiArr'.size) - benchLoop s!"dotIntervalLowerCachedRat sum {labelBase}" iters (fun () => - dotIntervalLowerCachedCore vArr loArr hiArr hv hlo hhi) - benchLoop s!"dotIntervalUpperCachedRat sum {labelBase}" iters (fun () => - dotIntervalUpperCachedCore vArr loArr hiArr hv hlo hhi) - -private def benchDotFin (n iters seed : Nat) : IO Unit := do - let x : Fin n → Rat := mkVecRat n seed 7 4 - let y : Fin n → Rat := mkVecRat n seed 19 5 - let labelBase := s!"n={n}" - benchLoop s!"dotFin {labelBase}" iters (fun () => - Sound.Linear.dotFin n x y) - -private def headShapeIters (base : Nat) : Nat := - max 1 (base / 10) - -private def mkHeadAbs (seq dHead : Nat) (seed : Nat) (salt : Nat) : Fin seq → Fin dHead → Rat := - fun q d => - mkRat (q.1 * 31 + d.1 + seed + salt) - (q.1 + d.1 + 7 + seed + salt) (((q.1 + d.1) % 3) = 0) - -private def mkHeadVal (seq dHead : Nat) (seed : Nat) (salt : Nat) : Fin seq → Fin dHead → Rat := - fun q d => - mkRat (q.1 * 17 + d.1 + seed + salt) - (q.1 + d.1 + 11 + seed + salt) (((q.1 + d.1) % 5) = 0) - -private def mkHeadDir (dHead : Nat) (seed : Nat) (salt : Nat) : Fin dHead → Rat := fun d => - mkRat (d.1 + seed + salt) (d.1 + 3 + seed + salt) ((d.1 % 2) = 0) - -private def benchHeadDotAbs (iters seed : Nat) : IO Unit := do - let seq := 8 - let dHead := 64 - let qAbs : Fin seq → Fin dHead → Rat := mkHeadAbs seq dHead seed 3 - let kAbs : Fin seq → Fin dHead → Rat := mkHeadAbs seq dHead seed 19 - benchLoop "head dotAbs dotFin" iters (fun () => - (List.finRange seq).foldl (fun acc q => - (List.finRange seq).foldl (fun acc' k => - acc' + Sound.Linear.dotFin dHead (qAbs q) (kAbs k)) acc) 0) - -private def benchHeadValueBounds (iters seed : Nat) : IO Unit := do - let seq := 8 - let dHead := 64 - let dirHead : Fin dHead → Rat := mkHeadDir dHead seed 5 - let vLo : Fin seq → 
Fin dHead → Rat := mkHeadVal seq dHead seed 11 - let vHi : Fin seq → Fin dHead → Rat := mkHeadVal seq dHead seed 23 - let dirArr := Array.ofFn dirHead - have hdir : dirArr.size = dHead := by simp [dirArr] - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by - simp [univ] - benchLoop "head value bounds (cached)" iters (fun () => - let valsLo : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalLowerCachedRat dirHead (vLo k) (vHi k) - let valsHi : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalUpperCachedRat dirHead (vLo k) (vHi k) - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - lo + hi) - benchLoop "head value bounds (common den)" iters (fun () => - let valsLo : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k) - let valsHi : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k) - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - lo + hi) - benchLoop "head value bounds (direct)" iters (fun () => - let valsLo : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalLower dirHead (vLo k) (vHi k) - let valsHi : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalUpper dirHead (vLo k) (vHi k) - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - lo + hi) - benchLoop "head value bounds (cached, reuse dir)" iters (fun () => - let valsLo : Fin seq → Rat := fun k => - let loArr := Array.ofFn (vLo k) - let hiArr := Array.ofFn (vHi k) - have hlo : loArr.size = dHead := by simp [loArr] - have hhi : hiArr.size = dHead := by simp [hiArr] - dotIntervalLowerCachedCore dirArr loArr hiArr hdir hlo hhi - let valsHi : Fin seq → Rat := fun k => - let loArr := Array.ofFn (vLo k) - let hiArr := Array.ofFn (vHi k) - have hlo : loArr.size = dHead := by simp [loArr] - have hhi : hiArr.size = dHead := by simp [hiArr] - dotIntervalUpperCachedCore dirArr loArr hiArr hdir hlo hhi - let lo := 
univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - lo + hi) - -private def benchRatDivInt (iters seed : Nat) : IO Unit := do - let bigNum : Int := - Int.ofNat (2 ^ 200) * Int.ofNat (3 ^ 120) + Int.ofNat (5 ^ 90) + Int.ofNat seed - let bigDen : Int := - Int.ofNat (2 ^ 150) * Int.ofNat (3 ^ 80) + (Int.ofNat seed) + 1 - benchLoop "ratRoundDown divInt big" iters (fun () => - ratRoundDown (Rat.divInt bigNum bigDen)) - -private def forceQkvSumLimited {seq dModel dHead : Nat} - (qkv : Sound.HeadQKVBounds seq dModel dHead) (qLimit dLimit : Nat) : Rat := - let qs := (List.finRange seq).take qLimit - let ds := (List.finRange dHead).take dLimit - qs.foldl (fun acc q => - ds.foldl (fun acc' d => - acc' + qkv.qLo q d + qkv.qHi q d + - qkv.kLo q d + qkv.kHi q d + - qkv.vLo q d + qkv.vHi q d + - qkv.qAbs q d + qkv.kAbs q d) acc) 0 - -private def forceQkvSumDirect {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (lnLo lnHi : Fin seq → Fin dModel → Rat) - (qLimit dLimit : Nat) - (dotLower dotUpper : (Fin dModel → Rat) → (Fin dModel → Rat) → (Fin dModel → Rat) → Rat) : - Rat := - let qs := (List.finRange seq).take qLimit - let ds := (List.finRange dHead).take dLimit - qs.foldl (fun acc q => - ds.foldl (fun acc' d => - let qLo := dotLower (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d - let qHi := dotUpper (fun j => inputs.wq j d) (lnLo q) (lnHi q) + inputs.bq d - let kLo := dotLower (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d - let kHi := dotUpper (fun j => inputs.wk j d) (lnLo q) (lnHi q) + inputs.bk d - let vLo := dotLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let vHi := dotUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let qAbs := max |qLo| |qHi| - let kAbs := max |kLo| |kHi| - acc' + qLo + qHi + kLo + kHi + vLo + vHi + qAbs + kAbs) acc) 0 - -private def benchHeadInputs {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) (iters 
: Nat) : IO Unit := do - IO.println "bench: head ln bounds start" - (← IO.getStdout).flush - let lnBounds ← Nfp.IO.timePure "bench: head ln bounds" (fun () => - Sound.headLnBounds inputs) - IO.println "bench: head qkv bounds start" - (← IO.getStdout).flush - let qLimit := - match (← IO.getEnv "NFP_BENCH_QKV_Q") with - | some raw => raw.toNat?.getD seq - | none => seq - let dLimit := - match (← IO.getEnv "NFP_BENCH_QKV_D") with - | some raw => raw.toNat?.getD dHead - | none => dHead - let skipCache := (← IO.getEnv "NFP_BENCH_SKIP_QKV_CACHE").isSome - if !skipCache then - IO.println s!"bench: head qkv bounds (cachedRat) start q={qLimit} d={dLimit}" - (← IO.getStdout).flush - let _sumRat ← Nfp.IO.timePure "bench: head qkv bounds (cachedRat)" (fun () => - forceQkvSumLimited (Sound.headQKVBounds inputs lnBounds.1 lnBounds.2) qLimit dLimit) - pure () - IO.println s!"bench: head qkv bounds (directRat) start q={qLimit} d={dLimit}" - (← IO.getStdout).flush - let _sumDirectRat ← Nfp.IO.timePure "bench: head qkv bounds (directRat)" (fun () => - forceQkvSumDirect inputs lnBounds.1 lnBounds.2 qLimit dLimit - Sound.Bounds.dotIntervalLowerCachedRat Sound.Bounds.dotIntervalUpperCachedRat) - IO.println s!"bench: head qkv bounds (directRatNoCache) start q={qLimit} d={dLimit}" - (← IO.getStdout).flush - let _sumDirectRatNoCache ← Nfp.IO.timePure "bench: head qkv bounds (directRatNoCache)" (fun () => - forceQkvSumDirect inputs lnBounds.1 lnBounds.2 qLimit dLimit - Sound.Bounds.dotIntervalLower Sound.Bounds.dotIntervalUpper) - let qkv := Sound.headQKVBounds inputs lnBounds.1 lnBounds.2 - let qAbs := qkv.qAbs - let kAbs := qkv.kAbs - benchLoop "head inputs dotAbs dotFin" iters (fun () => - (List.finRange seq).foldl (fun acc q => - (List.finRange seq).foldl (fun acc' k => - acc' + Sound.Linear.dotFin dHead (qAbs q) (kAbs k)) acc) 0) - IO.println "bench: head value dir start" - (← IO.getStdout).flush - let dirHead ← Nfp.IO.timePure "bench: head value dir" (fun () => - 
Sound.headValueDirHead inputs) - let dirArr := Array.ofFn dirHead - have hdir : dirArr.size = dHead := by simp [dirArr] - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - benchLoop "head inputs value bounds (cached)" iters (fun () => - let valsLo : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalLowerCachedRat dirHead (qkv.vLo k) (qkv.vHi k) - let valsHi : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalUpperCachedRat dirHead (qkv.vLo k) (qkv.vHi k) - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - lo + hi) - benchLoop "head inputs value bounds (direct)" iters (fun () => - let valsLo : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalLower dirHead (qkv.vLo k) (qkv.vHi k) - let valsHi : Fin seq → Rat := fun k => - Sound.Bounds.dotIntervalUpper dirHead (qkv.vLo k) (qkv.vHi k) - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - lo + hi) - benchLoop "head inputs value bounds (cached, reuse dir)" iters (fun () => - let valsLo : Fin seq → Rat := fun k => - let loArr := Array.ofFn (qkv.vLo k) - let hiArr := Array.ofFn (qkv.vHi k) - have hlo : loArr.size = dHead := by simp [loArr] - have hhi : hiArr.size = dHead := by simp [hiArr] - dotIntervalLowerCachedCore dirArr loArr hiArr hdir hlo hhi - let valsHi : Fin seq → Rat := fun k => - let loArr := Array.ofFn (qkv.vLo k) - let hiArr := Array.ofFn (qkv.vHi k) - have hlo : loArr.size = dHead := by simp [loArr] - have hhi : hiArr.size = dHead := by simp [hiArr] - dotIntervalUpperCachedCore dirArr loArr hiArr hdir hlo hhi - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - lo + hi) - -/-- Run rational microbenchmarks for several vector sizes. 
-/ -def runRatBench (seed : Nat) : IO Unit := do - let baseIters := - match (← IO.getEnv "NFP_BENCH_ITERS") with - | some raw => raw.toNat?.getD 200 - | none => 200 - let sizes : List Nat := [8, 64, 256, 768] - for n in sizes do - let iters := benchItersFor baseIters n - IO.println s!"bench: start n={n} iters={iters}" - benchDotInterval n iters seed - benchDotIntervalCachedParts n iters seed - benchDotFin n iters seed - let headIters := headShapeIters baseIters - IO.println s!"bench: start head-shape iters={headIters}" - benchHeadDotAbs headIters seed - benchHeadValueBounds headIters seed - benchRatDivInt headIters seed - -/-- Run benchmarks using a real induction-head input payload. -/ -def runRatBenchFromInputs {seq dModel dHead : Nat} [NeZero seq] - (seed : Nat) (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do - let skipSynth := (← IO.getEnv "NFP_BENCH_SKIP_SYNTH").isSome - if !skipSynth then - runRatBench seed - let baseIters := - match (← IO.getEnv "NFP_BENCH_ITERS") with - | some raw => raw.toNat?.getD 200 - | none => 200 - let headIters := - match (← IO.getEnv "NFP_BENCH_HEAD_ITERS") with - | some raw => raw.toNat?.getD (headShapeIters baseIters) - | none => headShapeIters baseIters - IO.println s!"bench: start head-inputs iters={headIters}" - (← IO.getStdout).flush - benchHeadInputs inputs headIters - -end IO - -end Nfp diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index 724a751..8922aa7 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -356,10 +356,6 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} timingPrint "timing: head build induction cert start" timingFlush let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" - let taskBenchEnv ← IO.getEnv "NFP_TASK_BENCH" - if taskBenchEnv.isSome then - let n := taskBenchEnv.bind String.toNat? 
|>.getD 1000 - Nfp.IO.taskBench n if verboseTiming.isSome then timingPrint s!"timing: head dims seq={seq} dModel={dModel} dHead={dHead}" timingPrint s!"timing: head active card={inputs.active.card}" diff --git a/Nfp/IO/Run.lean b/Nfp/IO/Run.lean new file mode 100644 index 0000000..9e9a598 --- /dev/null +++ b/Nfp/IO/Run.lean @@ -0,0 +1,805 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Checks +import Nfp.IO.Derive +import Nfp.IO.HeadScore +import Nfp.IO.InductionHead +import Nfp.IO.Loaders +import Nfp.IO.NfptPure +import Nfp.IO.Timing +import Nfp.IO.Util +import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Circuit.Cert.ResidualBound +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Bounds.Transformer +import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadBounds +import Nfp.Sound.Induction.LogitDiff +import Nfp.Sound.Linear.FinFold + +/-! +IO entrypoints used by the CLI. +-/ + +namespace Nfp +namespace IO +open Nfp.Circuit + +/-- Check induction certificates and print a short status line. -/ +def runInductionCertify (scoresPath : System.FilePath) + (valuesPath? : Option System.FilePath) (minActive? : Option Nat) + (minLogitDiffStr? : Option String) (minMarginStr? : Option String) + (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + if minLogitDiff?.isSome && valuesPath?.isNone then + IO.eprintln "error: min-logit-diff requires --values" + return 2 + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + match valuesPath? with + | none => + IO.println + s!"ok: softmax-margin certificate accepted \ + (seq={seq}, active={activeCount})" + return 0 + | some valuesPath => + let parsedValues ← loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let tol := cert.eps * (certVals'.hi - certVals'.lo) + let logitDiffLB? 
:= + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, tol={tol}, \ + logitDiffLB={logitDiffLB})" + return 0 +/-- Build and check induction certificates from raw scores/values. -/ +def runInductionCertifySound (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (minActive? : Option Nat) + (minLogitDiffStr? : Option String) (minMarginStr? : Option String) + (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedScores ← loadSoftmaxMarginRaw scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, raw⟩ => + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildSoftmaxMarginCert? raw.active raw.prev raw.scores raw.weights with + | none => + IO.eprintln "error: softmax-margin inputs rejected" + return 2 + | some ⟨cert, _⟩ => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← loadValueRangeRaw valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, rawVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln + s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let rawVals' : Pure.ValueRangeRaw seq := by + simpa [hseq'] using rawVals + match Sound.buildValueRangeCert? rawVals'.vals rawVals'.direction with + | none => + IO.eprintln "error: value-range inputs rejected" + return 2 + | some ⟨certVals, _⟩ => + let tol := cert.eps * (certVals.hi - certVals.lo) + let logitDiffLB? 
:= + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals.lo certVals.hi certVals.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return 2 + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={tol}, logitDiffLB={logitDiffLB})" + return 0 +/-- Check end-to-end induction certificates with a downstream error bound. -/ +def runInductionCertifyEndToEnd (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (downstreamPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let logitDiffLB? 
← timePure "logit-diff lower bound" (fun () => + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals) + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let parsedDownstream ← loadDownstreamLinearCert downstreamPath + match parsedDownstream with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok downstream => + let downstreamOk := Circuit.checkDownstreamLinearCert downstream + if downstreamOk then + let finalLB := logitDiffLB - downstream.error + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if finalLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end logitDiffLB {finalLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: end-to-end induction bound certified \ + (seq={seq}, active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstream.error}, \ + finalLB={finalLB})" + return 0 + else + IO.eprintln "error: downstream certificate rejected" + return 2 +/-- Check end-to-end induction certificates with a downstream matrix. -/ +def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (matrixPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let 
logitDiffLB? ← timePure "logit-diff lower bound" (fun () => + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals) + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let parsedMatrix ← loadDownstreamMatrixRaw matrixPath + match parsedMatrix with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨rows, ⟨cols, raw⟩⟩ => + let inputBound := raw.inputBound + if hneg : inputBound < 0 then + IO.eprintln + s!"error: input-bound {inputBound} must be nonnegative" + return 2 + else + have hinput : 0 ≤ inputBound := by + exact le_of_not_gt hneg + let W : Matrix (Fin rows) (Fin cols) Rat := raw.entries + let downstream := + (Sound.Bounds.buildDownstreamLinearCert W inputBound hinput).1 + let finalLB := logitDiffLB - downstream.error + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if finalLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end logitDiffLB {finalLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: end-to-end induction bound certified \ + (seq={seq}, active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstream.error}, \ + finalLB={finalLB})" + return 0 +/-- Check end-to-end induction certificates using a model file and residual bounds + (loaded from disk or derived from the model). -/ +def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (modelPath : System.FilePath) + (residualIntervalPath? : Option System.FilePath) + (layer? : Option Nat) (head? : Option Nat) (period? 
: Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + 
else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let logitDiffLB? ← timePure "logit-diff lower bound" (fun () => + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals) + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + match certVals'.direction with + | none => + IO.eprintln + "error: value-range certificate missing direction \ + metadata" + return 2 + | some dirSpec => + let data ← timePhase "read model file" <| + IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + if hseq : header.seqLen = seq then + let active? : Option (Finset (Fin header.seqLen)) := + if hactive : cert.active.Nonempty then + some (by simpa [hseq] using cert.active) + else + none + let residualCertE : Except String + (ResidualIntervalCert header.modelDim) ← + match residualIntervalPath? 
with + | some residualIntervalPath => do + let parsedResidual ← + timePhase "load residual interval" <| + loadResidualIntervalCert residualIntervalPath + match parsedResidual with + | Except.error msg => pure (Except.error msg) + | Except.ok ⟨dim, residualCert⟩ => + if hdim : dim = header.modelDim then + let residualCert' : + ResidualIntervalCert header.modelDim := by + simpa [hdim] using residualCert + pure (Except.ok residualCert') + else + pure (Except.error + s!"residual interval dim {dim} \ + does not match model dim {header.modelDim}") + | none => + deriveResidualIntervalFromModel data start header + active? + match residualCertE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok residualCert' => + let residualOk ← + timePure "check residual interval" (fun () => + Circuit.checkResidualIntervalCert residualCert') + if residualOk then + let dirPos := dirSpec.target + let dirNeg := dirSpec.negative + if layer?.isSome != head?.isSome then + IO.eprintln + "error: --layer and --head must be provided \ + together" + return 2 + let headChoice? : Option (Nat × Nat) := + match layer?, head? with + | some layer, some head => some (layer, head) + | _, _ => none + if period?.isSome && headChoice?.isNone then + IO.eprintln + "warning: --period ignored without \ + --layer/--head" + let colTargetE ← + timePure "read unembed column target" (fun () => + NfptPure.readUnembedColumn + data start header dirPos) + match colTargetE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok colTarget => + let colNegE ← + timePure "read unembed column negative" (fun () => + NfptPure.readUnembedColumn + data start header dirNeg) + match colNegE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok colNeg => + let dirVec : + Fin header.modelDim → Rat := + fun i => colTarget i - colNeg i + let dotIntervalAbs := + Sound.Bounds.dotIntervalAbsBound + let intervalErrorFromHead? 
: + Model.InductionHeadInputs + seq header.modelDim header.headDim → + ResidualIntervalCert header.modelDim → + Option Rat := + fun inputs residual => by + classical + match hseq0 : seq with + | 0 => exact none + | Nat.succ n => + let _ : NeZero seq := by + exact ⟨by simp [hseq0]⟩ + match + Sound.buildHeadOutputIntervalFromHead? + inputs with + | none => exact none + | some result => + exact some + (dotIntervalAbs + dirVec + (fun i => + residual.lo i - + result.cert.hi i) + (fun i => + residual.hi i - + result.cert.lo i)) + let downstreamError ← + timePure "downstream error" (fun () => + dotIntervalAbs + dirVec + residualCert'.lo + residualCert'.hi) + let finalLB := logitDiffLB - downstreamError + let intervalError? ← + match headChoice? with + | none => pure none + | some (layer, head) => do + let inputsE ← + timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head + dirPos dirNeg period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"warning: {msg}" + pure none + | Except.ok inputs => + let inputs' : + Model.InductionHeadInputs + seq header.modelDim + header.headDim := by + simpa [hseq] using inputs + let inputsAligned : + Model.InductionHeadInputs + seq header.modelDim + header.headDim := + { inputs' with + active := cert.active + prev := cert.prev } + let intervalError? ← + timePure + "head output interval" + (fun () => + intervalErrorFromHead? + inputsAligned + residualCert') + match intervalError? with + | none => + IO.eprintln + "warning: head output interval \ + rejected" + pure none + | some intervalError => + pure (some intervalError) + let intervalLB? := + intervalError?.map (fun err => + logitDiffLB - err) + let effectiveLB := + match intervalLB? with + | some intervalLB => max finalLB intervalLB + | none => finalLB + let violation? 
: Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if effectiveLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end bound \ + {effectiveLB} below minimum \ + {minLogitDiff}" + return (2 : UInt32) + | none => + match intervalLB? with + | none => + IO.println + s!"ok: end-to-end induction \ + bound certified (seq={seq}, \ + active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstreamError}, \ + finalLB={finalLB})" + | some intervalLB => + let intervalError := + logitDiffLB - intervalLB + IO.println + s!"ok: end-to-end induction \ + bound certified (seq={seq}, \ + active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstreamError}, \ + finalLB={finalLB}, \ + intervalError={intervalError}, \ + intervalLB={intervalLB}, \ + effectiveLB={effectiveLB})" + return 0 + else + IO.eprintln + "error: residual-interval certificate rejected" + return 2 + else + IO.eprintln + s!"error: model seq {header.seqLen} \ + does not match cert seq {seq}" + return 2 +end IO +end Nfp diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean index 7eccf95..41f037e 100644 --- a/Nfp/IO/Timing.lean +++ b/Nfp/IO/Timing.lean @@ -5,7 +5,7 @@ import Nfp.Model.InductionHead import Nfp.Sound.Induction.HeadBounds /-! -Small IO helpers for benchmarking task overhead and profiling slow phases. +Small IO helpers for profiling slow phases. -/ namespace Nfp @@ -123,21 +123,6 @@ def flushStdout : IO Unit := do let h ← IO.getStdout h.flush -/-- Measure task spawn/get overhead on this machine. 
-/ -def taskBench (n : Nat) : IO Unit := do - if n = 0 then - timingPrint "timing: task bench skipped (n=0)" - return - let t0 ← monoUsNow - let tasks := (List.range n).map (fun _ => Task.spawn (fun _ => ())) - for t in tasks do - let _ := t.get - pure () - let t1 ← monoUsNow - let total := t1 - t0 - let avg := total / n - timingPrint s!"timing: task bench n={n} total={total} us avg={avg} us" - /-- Force a sample score-gap computation for timing. -/ def timeHeadScoreSampleGap {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index dac2283..0a49652 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -2,13 +2,7 @@ import Nfp.Sound.Gpt2.HeadInputs import Nfp.Sound.Induction -import Nfp.Sound.Induction.HeadBounds -import Nfp.Sound.Induction.LogitDiff -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Bounds.Gelu -import Nfp.Sound.Bounds.Mlp -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.Transformer +import Nfp.Sound.Bounds import Nfp.Sound.Linear.FinFold /-! diff --git a/Nfp/Sound/Bounds.lean b/Nfp/Sound/Bounds.lean new file mode 100644 index 0000000..21a9f24 --- /dev/null +++ b/Nfp/Sound/Bounds.lean @@ -0,0 +1,15 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.Cache +import Nfp.Sound.Bounds.Gelu +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.LayerNorm.InvStd +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Bounds.Mlp +import Nfp.Sound.Bounds.Transformer +import Nfp.Sound.Bounds.UnnormRat + +/-! +Aggregator for sound interval bounds. 
+-/ diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index ea2bcfb..fa7ef37 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -3,10 +3,14 @@ import Nfp.Sound.Induction.Core import Nfp.Sound.Induction.CoreSound import Nfp.Sound.Induction.EndToEnd +import Nfp.Sound.Induction.HeadBounds import Nfp.Sound.Induction.HeadOutput +import Nfp.Sound.Induction.LogitDiff +import Nfp.Sound.Induction.OneHot /-! Sound builders for induction certificates. -This module re-exports the core constructions and head-output interval bounds. +This module re-exports the core constructions, head-output interval bounds, +and logit-diff helpers. -/ diff --git a/lakefile.toml b/lakefile.toml index e627101..29a1b9d 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -37,18 +37,6 @@ roots = ["Nfp"] name = "nfp" root = "Main" -[[lean_exe]] -name = "bench-rational" -root = "BenchRational" - -[[lean_exe]] -name = "bench-induction-core" -root = "BenchInductionCore" - -[[lean_exe]] -name = "bench-induction-counts" -root = "BenchInductionCounts" - [[lean_exe]] name = "theorem-axioms" root = "TheoremAxioms" From c46601c04c80fbbc9ec403832be583411354c401 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 18:37:25 +0100 Subject: [PATCH 164/244] Split bounds basics into submodules --- Nfp/Sound/Bounds/LayerNorm.lean | 809 +---------------------- Nfp/Sound/Bounds/LayerNorm/Basic.lean | 813 ++++++++++++++++++++++++ Nfp/Sound/Bounds/LayerNorm/InvStd.lean | 2 +- Nfp/Sound/Bounds/MatrixNorm.lean | 179 +----- Nfp/Sound/Bounds/MatrixNorm/Basic.lean | 183 ++++++ Nfp/Sound/Bounds/Transformer.lean | 560 +--------------- Nfp/Sound/Bounds/Transformer/Basic.lean | 564 ++++++++++++++++ 7 files changed, 1568 insertions(+), 1542 deletions(-) create mode 100644 Nfp/Sound/Bounds/LayerNorm/Basic.lean create mode 100644 Nfp/Sound/Bounds/MatrixNorm/Basic.lean create mode 100644 Nfp/Sound/Bounds/Transformer/Basic.lean diff --git a/Nfp/Sound/Bounds/LayerNorm.lean 
b/Nfp/Sound/Bounds/LayerNorm.lean index f5e11ae..7e17dd5 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -1,813 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Fin -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Field.Basic -import Mathlib.Algebra.Order.Ring.Basic -import Mathlib.Data.Real.Sqrt -import Mathlib.Data.Rat.BigOperators -import Mathlib.Data.Rat.Cast.Order -import Nfp.Core.Basic +import Nfp.Sound.Bounds.LayerNorm.Basic +import Nfp.Sound.Bounds.LayerNorm.InvStd import Nfp.Sound.Bounds.LayerNorm.MeanVariance import Nfp.Sound.Bounds.LayerNorm.SqrtBounds -import Nfp.Sound.Linear.FinFold /-! -LayerNorm interval bounds for rational inputs. - -This module computes rational interval bounds for LayerNorm outputs and proves -those bounds sound for real-valued LayerNorm semantics. --/ - -namespace Nfp - -namespace Sound - -namespace Bounds - -open scoped BigOperators - -/-- Bounds for multiplying a scalar by a bounded value. -/ -def scaleInterval (x lo hi : Rat) : Rat × Rat := - if 0 ≤ x then - (x * lo, x * hi) - else - (x * hi, x * lo) - -/-- `scaleInterval` bounds a product. -/ -theorem scaleInterval_bounds {x lo hi y : Rat} - (hlo : lo ≤ y) (hhi : y ≤ hi) : - let bounds := scaleInterval x lo hi - bounds.1 ≤ x * y ∧ x * y ≤ bounds.2 := by - by_cases hx : 0 ≤ x - · have hbounds : x * lo ≤ x * y ∧ x * y ≤ x * hi := by - exact ⟨mul_le_mul_of_nonneg_left hlo hx, mul_le_mul_of_nonneg_left hhi hx⟩ - simpa [scaleInterval, hx] using hbounds - · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have hbounds : x * hi ≤ x * y ∧ x * y ≤ x * lo := by - exact ⟨mul_le_mul_of_nonpos_left hhi hx', mul_le_mul_of_nonpos_left hlo hx'⟩ - simpa [scaleInterval, hx] using hbounds - -/-- `scaleInterval` bounds interpreted in the reals. 
-/ -theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} - (hlo : (lo : Real) ≤ y) (hhi : y ≤ (hi : Real)) : - let bounds := scaleInterval x lo hi - (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by - by_cases hx : 0 ≤ x - · have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx - have hbounds : (x : Real) * (lo : Real) ≤ (x : Real) * y ∧ - (x : Real) * y ≤ (x : Real) * (hi : Real) := by - exact ⟨mul_le_mul_of_nonneg_left hlo hx', mul_le_mul_of_nonneg_left hhi hx'⟩ - simpa [scaleInterval, hx] using hbounds - · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' - have hbounds : (x : Real) * (hi : Real) ≤ (x : Real) * y ∧ - (x : Real) * y ≤ (x : Real) * (lo : Real) := by - exact ⟨mul_le_mul_of_nonpos_left hhi hx'', mul_le_mul_of_nonpos_left hlo hx''⟩ - simpa [scaleInterval, hx] using hbounds - -/-- Real-valued LayerNorm output for a vector. -/ -noncomputable def layerNormReal {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : Fin n → Real := - if n = 0 then - fun _ => 0 - else - let μ : Real := meanRat x - let varEps : Real := (varianceRat x : Real) + (eps : Real) - let invStd : Real := (Real.sqrt varEps)⁻¹ - fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) - -/-- Real-valued LayerNorm output for a real vector. -/ -noncomputable def layerNormRealOfReal {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Real) : Fin n → Real := - if n = 0 then - fun _ => 0 - else - let μ : Real := meanReal x - let varEps : Real := varianceReal x + (eps : Real) - let invStd : Real := (Real.sqrt varEps)⁻¹ - fun i => (gamma i : Real) * (x i - μ) * invStd + (beta i : Real) - -/-- Interval bounds for LayerNorm outputs. 
-/ -def layerNormBounds {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : - (Fin n → Rat) × (Fin n → Rat) := - if n = 0 then - (fun _ => 0, fun _ => 0) - else - let μ : Rat := mean x - let centered : Fin n → Rat := fun i => x i - μ - let var : Rat := variance x - let varEps : Rat := var + eps - let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEps) - let sqrtUpperBound : Rat := sqrtUpper varEps - let invStdLower : Rat := ratDivDown 1 sqrtUpperBound - let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound - let coeff : Fin n → Rat := fun i => gamma i * centered i - let lo : Fin n → Rat := fun i => - if 0 ≤ coeff i then - beta i + coeff i * invStdLower - else - beta i + coeff i * invStdUpper - let hi : Fin n → Rat := fun i => - if 0 ≤ coeff i then - beta i + coeff i * invStdUpper - else - beta i + coeff i * invStdLower - (lo, hi) - -/-- `layerNormBounds` soundness for real LayerNorm outputs. -/ -theorem layerNormBounds_spec {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) - (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - let bounds := layerNormBounds eps gamma beta x - ∀ i, - (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ - layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - classical - intro bounds i - let μRat : Rat := mean x - let varRat : Rat := variance x - let varEpsRat : Rat := varRat + eps - let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsRat) - let sqrtUpperBound : Rat := sqrtUpper varEpsRat - let invStdLower : Rat := ratDivDown 1 sqrtUpperBound - let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound - let centered : Rat := x i - μRat - let coeff : Rat := gamma i * centered - let μ : Real := meanRat x - let varEps : Real := (varianceRat x : Real) + (eps : Real) - let invStd : Real := (Real.sqrt varEps)⁻¹ - have hmu : (μRat : Real) = μ := by - simp [μRat, μ, mean_def, hne, ratRoundDown] - have hvar : (varRat : Real) = (varianceRat x : Real) := by - simp [varRat, 
variance_def, hne, ratRoundDown] - have hvarEps : (varEpsRat : Real) = varEps := by - simp [varEpsRat, varEps, hvar] - have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne - have hvar_nonneg_rat : 0 ≤ varianceRat x := by - exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg - have hvarRat_nonneg : 0 ≤ varRat := by - have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat - simpa [varRat, variance_def x hne] using h - have hvarEps_nonneg : 0 ≤ varEpsRat := by - exact add_nonneg hvarRat_nonneg (le_of_lt heps) - have hsqrt_lower : - (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by - have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by - have hsqrt_eps' : (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by - have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) - simpa using h - have hle : (eps : Real) ≤ varEps := by - have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := - le_add_of_nonneg_left hvar_nonneg - simpa [varEps] using hle' - exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) - have hsqrt_var : (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := by - have hsqrt_var' : - (sqrtLower varEpsRat : Real) ≤ Real.sqrt (varEpsRat : Real) := by - have h := sqrtLower_le_real_sqrt (q := varEpsRat) hvarEps_nonneg - simpa using h - have hle : (varEpsRat : Real) ≤ varEps := by - simp [hvarEps] - exact le_trans hsqrt_var' (Real.sqrt_le_sqrt hle) - have hmax : - max (sqrtLower eps : Real) (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := - (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ - simpa [sqrtLowerBound, ratToReal_max] using hmax - have hsqrt_upper : - Real.sqrt varEps ≤ (sqrtUpperBound : Real) := by - have h := real_sqrt_le_sqrtUpper (q := varEpsRat) hvarEps_nonneg - simpa [sqrtUpperBound, hvarEps] using h - have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by - have hpos : 0 < sqrtLower eps := hsqrt - have hpos' : 0 < max (sqrtLower eps) (sqrtLower varEpsRat) := - lt_of_lt_of_le hpos 
(le_max_left _ _) - simpa [sqrtLowerBound] using hpos' - have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtLowerBound)).2 hsqrt_lower_pos_rat - have hsqrt_upper_pos_rat : 0 < sqrtUpperBound := by - simpa [sqrtUpperBound] using sqrtUpper_pos varEpsRat - have hsqrt_upper_pos : 0 < (sqrtUpperBound : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtUpperBound)).2 hsqrt_upper_pos_rat - have hvarEps_pos : 0 < varEps := by - have heps_real : 0 < (eps : Real) := by - exact_mod_cast heps - have hpos := add_pos_of_nonneg_of_pos hvar_nonneg heps_real - simpa [varEps] using hpos - have hsqrt_pos : 0 < Real.sqrt varEps := Real.sqrt_pos.2 hvarEps_pos - have hinv_lower_real : - (sqrtUpperBound : Real)⁻¹ ≤ invStd := by - have hle := inv_anti₀ hsqrt_pos hsqrt_upper - simpa [invStd] using hle - have hinv_upper_real : - invStd ≤ (sqrtLowerBound : Real)⁻¹ := by - have hle := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd] using hle - have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat - have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat - have hinv_lower : (invStdLower : Real) ≤ invStd := by - simpa [invStdLower, ratDivDown, hupper_ne, one_div] using hinv_lower_real - have hinv_upper : invStd ≤ (invStdUpper : Real) := by - simpa [invStdUpper, ratDivUp, hlower_ne, one_div] using hinv_upper_real - have hlayer : - layerNormReal eps gamma beta x i = - (beta i : Real) + (coeff : Real) * invStd := by - simp [layerNormReal, hne, coeff, centered, μ, hmu, invStd, varEps, add_comm, mul_assoc] - by_cases hcoeff : 0 ≤ coeff - · have hcoeff_real : 0 ≤ (coeff : Real) := - ratToReal_nonneg_of_nonneg hcoeff - have hlow_raw : - (beta i : Real) + (coeff : Real) * (invStdLower : Real) ≤ - (beta i : Real) + (coeff : Real) * invStd := by - have hmul := mul_le_mul_of_nonneg_left hinv_lower hcoeff_real - simpa only [add_comm] using add_le_add_left hmul (beta i : Real) - have hhigh_raw : - (beta i : Real) + 
(coeff : Real) * invStd ≤ - (beta i : Real) + (coeff : Real) * (invStdUpper : Real) := by - have hmul := mul_le_mul_of_nonneg_left hinv_upper hcoeff_real - simpa only [add_comm] using add_le_add_left hmul (beta i : Real) - have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by - simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, - sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] - using hlow_raw - have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, - sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] - using hhigh_raw - exact And.intro hlo hhi - · have hcoeff_lt : coeff < 0 := lt_of_not_ge hcoeff - have hcoeff_real : (coeff : Real) ≤ 0 := by - exact_mod_cast (le_of_lt hcoeff_lt) - have hlow_raw : - (beta i : Real) + (coeff : Real) * (invStdUpper : Real) ≤ - (beta i : Real) + (coeff : Real) * invStd := by - have hmul := mul_le_mul_of_nonpos_left hinv_upper hcoeff_real - simpa only [add_comm] using add_le_add_left hmul (beta i : Real) - have hhigh_raw : - (beta i : Real) + (coeff : Real) * invStd ≤ - (beta i : Real) + (coeff : Real) * (invStdLower : Real) := by - have hmul := mul_le_mul_of_nonpos_left hinv_lower hcoeff_real - simpa only [add_comm] using add_le_add_left hmul (beta i : Real) - have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by - simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, - sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] - using hlow_raw - have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, - sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] - using hhigh_raw - exact And.intro hlo hhi - -/-! -Local bounds for monotone multiplication in real-valued bounds. 
+LayerNorm bounds and supporting lemmas. -/ - -/-- Lower sqrt bound against the variance-plus-eps term. -/ -theorem sqrtLower_le_real_sqrt_varEps {n : Nat} (eps : Rat) (x : Fin n → Rat) - (hne : n ≠ 0) (heps : 0 < eps) : - let varEps : Real := (varianceRat x : Real) + (eps : Real) - (sqrtLower eps : Real) ≤ Real.sqrt varEps := by - intro varEps - have hsqrt_eps : - (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by - have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) - simpa using h - have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne - have hle : (eps : Real) ≤ varEps := by - have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := - le_add_of_nonneg_left hvar_nonneg - simpa [varEps] using hle' - have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by - exact Real.sqrt_le_sqrt hle - exact le_trans hsqrt_eps hsqrt_eps' - -/-- Inverse-std upper bound from the lower sqrt bound. -/ -theorem invStd_le_invStdBound {n : Nat} (eps : Rat) (x : Fin n → Rat) - (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - let varEps : Real := (varianceRat x : Real) + (eps : Real) - let invStd : Real := (Real.sqrt varEps)⁻¹ - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) - invStd ≤ (invStdBound : Real) := by - intro varEps invStd invStdBound - have hsqrt_lower : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by - simpa [varEps] using - (sqrtLower_le_real_sqrt_varEps (eps := eps) (x := x) hne heps) - have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt - have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by - have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd] using h - have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by - have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy - simpa [invStdBound, one_div] using hdiv - exact le_trans hinv_sqrt 
hinv_bound - -/-- Inverse-std is nonnegative. -/ -theorem invStd_nonneg {n : Nat} (eps : Rat) (x : Fin n → Rat) : - let varEps : Real := (varianceRat x : Real) + (eps : Real) - let invStd : Real := (Real.sqrt varEps)⁻¹ - 0 ≤ invStd := by - intro varEps invStd - have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by - exact Real.sqrt_nonneg _ - exact inv_nonneg.2 hsqrt_nonneg - -/-- Interval bounds for LayerNorm outputs from per-coordinate intervals. -/ -def layerNormIntervalBounds {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) : - (Fin n → Rat) × (Fin n → Rat) := - if n = 0 then - (fun _ => 0, fun _ => 0) - else - let μLo := mean lo - let μHi := meanUpper hi - let centeredBound : Fin n → Rat := fun i => - max |lo i - μHi| |hi i - μLo| - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) - let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound - (fun i => beta i - radius i, fun i => beta i + radius i) - -/-- `layerNormIntervalBounds` soundness for real LayerNorm outputs. 
-/ -theorem layerNormIntervalBounds_spec {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Rat) - (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) : - let bounds := layerNormIntervalBounds eps gamma beta lo hi - ∀ i, - (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ - layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - classical - intro bounds i - let μLo : Rat := mean lo - let μHi : Rat := meanUpper hi - let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) - let varEps : Real := (varianceRat x : Real) + (eps : Real) - let μ : Real := meanRat x - let invStd : Real := (Real.sqrt varEps)⁻¹ - have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by - have h0 : 0 ≤ centeredBound i := by - dsimp [centeredBound] - exact le_trans (abs_nonneg _) (le_max_left _ _) - exact ratToReal_nonneg_of_nonneg h0 - have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by - have hmean_lo_real : (μLo : Real) ≤ μ := by - have hmean_rat : (meanRat lo : Real) ≤ (meanRat x : Real) := - meanRat_le_meanRat_real lo x hne hlo - have hdown : (μLo : Real) ≤ (meanRat lo : Real) := by - simpa [μLo, mean_def lo hne] using ratRoundDown_le_real (meanRat lo) - exact le_trans hdown hmean_rat - have hmean_hi_real : μ ≤ (μHi : Real) := by - have hmean_rat : (meanRat x : Real) ≤ (meanRat hi : Real) := - meanRat_le_meanRat_real x hi hne hhi - have hup : (meanRat hi : Real) ≤ (μHi : Real) := by - simpa [μHi, meanUpper_def hi hne] using real_le_ratRoundUp (meanRat hi) - exact le_trans hmean_rat hup - have hlo' : (lo i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by - have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by - exact sub_le_sub_left hmean_hi_real (lo i : Real) - have h2 : (lo i : Real) - μ ≤ (x i : Real) - μ := by - exact sub_le_sub_right - (by - exact ratToReal_le_of_le (hlo i)) - μ 
- exact le_trans h1 h2 - have hhi' : (x i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by - have h1 : (x i : Real) - μ ≤ (hi i : Real) - μ := by - exact sub_le_sub_right - (by - exact ratToReal_le_of_le (hhi i)) - μ - have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by - exact sub_le_sub_left hmean_lo_real (hi i : Real) - exact le_trans h1 h2 - have hbound := abs_le_max_of_bounds hlo' hhi' - simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, - ratToReal_max] using hbound - have hinv : invStd ≤ (invStdBound : Real) := by - simpa [varEps, invStd, invStdBound] using - (invStd_le_invStdBound (eps := eps) (x := x) hne heps hsqrt) - have hinv_nonneg : 0 ≤ invStd := by - simp [varEps, invStd] - have hmul1 : |(x i : Real) - μ| * invStd ≤ - (centeredBound i : Real) * (invStdBound : Real) := by - have hleft : - |(x i : Real) - μ| * invStd ≤ (centeredBound i : Real) * invStd := by - exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg - have hright : - (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by - exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg - exact le_trans hleft hright - have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ - |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by - have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ - have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ - |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by - exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa only [mul_assoc] using hmul2' - let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd - have ht_abs : - |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by - have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by - have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg - simp [t, abs_mul, hinv_abs, mul_assoc] - simpa [ht] using hmul2 - let radius : Fin n → 
Rat := fun j => |gamma j| * centeredBound j * invStdBound - have ht_abs' : |t| ≤ (radius i : Real) := by - simpa [radius, centeredBound, invStdBound] using ht_abs - have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by - exact abs_le.mp ht_abs' - have hlow : - (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by - simpa only [add_comm, add_left_comm, add_assoc] using - add_le_add_left hbounds.1 (beta i : Real) - simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h - have hhigh : - t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by - simpa only [add_comm, add_left_comm, add_assoc] using - add_le_add_left hbounds.2 (beta i : Real) - simpa only [add_comm, add_left_comm, add_assoc] using h - have hreal : - layerNormReal eps gamma beta x i = t + (beta i : Real) := by - simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] - have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by - simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, - hreal] using hlow - have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, - hreal] using hhigh - exact And.intro hlo hhi - -/-- Interval bounds for LayerNorm outputs from an absolute input bound. -/ -def layerNormAbsBounds {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) : - (Fin n → Rat) × (Fin n → Rat) := - let centeredBound : Rat := 2 * absBound - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) - let radius : Fin n → Rat := fun i => |gamma i| * centeredBound * invStdBound - (fun i => beta i - radius i, fun i => beta i + radius i) - -/-- Bound a centered value by double the absolute bound. 
-/ -private theorem abs_sub_le_double_bound {a b bound : Real} - (ha : |a| ≤ bound) (hb : |b| ≤ bound) : - |a - b| ≤ bound + bound := by - have h1 : |a - b| ≤ |a| + |b| := by - simpa [sub_eq_add_neg, abs_neg] using abs_add_le a (-b) - exact le_trans h1 (add_le_add ha hb) - -/-- `layerNormAbsBounds` soundness for real LayerNorm outputs under absolute input bounds. -/ -theorem layerNormAbsBounds_spec {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Rat) - (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (habs : ∀ i, |x i| ≤ absBound) : - let bounds := layerNormAbsBounds eps gamma beta absBound - ∀ i, - (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ - layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - classical - intro bounds i - have hmean_abs_real : |(meanRat x : Real)| ≤ (absBound : Real) := by - have h := - meanReal_abs_le_bound (x := fun j => (x j : Real)) (bound := absBound) hne - (by - intro j - exact ratToReal_abs_le_of_le (habs j)) - simpa [meanReal_eq_meanRat] using h - have hbound_nonneg : 0 ≤ absBound := by - have hposn : 0 < n := Nat.pos_of_ne_zero hne - let i0 : Fin n := ⟨0, hposn⟩ - have h0 : 0 ≤ |x i0| := abs_nonneg _ - exact le_trans h0 (habs i0) - let centeredBound : Rat := 2 * absBound - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) - let varEps : Real := (varianceRat x : Real) + (eps : Real) - let μ : Real := meanRat x - let invStd : Real := (Real.sqrt varEps)⁻¹ - have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound : Real) := by - have hx : |(x i : Real)| ≤ (absBound : Real) := by - exact ratToReal_abs_le_of_le (habs i) - have hmu : |μ| ≤ (absBound : Real) := by - simpa [μ] using hmean_abs_real - have h12 : |(x i : Real) - μ| ≤ (absBound : Real) + (absBound : Real) := - abs_sub_le_double_bound hx hmu - simpa [centeredBound, two_mul] using h12 - have hbound_nonneg_real : 0 ≤ (absBound : Real) := by - exact ratToReal_nonneg_of_nonneg hbound_nonneg - have hcentered_nonneg : 0 
≤ (centeredBound : Real) := by - have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real - simpa [centeredBound, two_mul] using hsum - have hinv : invStd ≤ (invStdBound : Real) := by - simpa [varEps, invStd, invStdBound] using - (invStd_le_invStdBound (eps := eps) (x := x) hne heps hsqrt) - have hinv_nonneg : 0 ≤ invStd := by - simp [varEps, invStd] - have hmul1 : |(x i : Real) - μ| * invStd ≤ - (centeredBound : Real) * (invStdBound : Real) := by - have hleft : - |(x i : Real) - μ| * invStd ≤ (centeredBound : Real) * invStd := by - exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg - have hright : - (centeredBound : Real) * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by - exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg - exact le_trans hleft hright - have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ - |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by - have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ - have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ - |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by - exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa only [mul_assoc] using hmul2' - let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd - have ht_abs : - |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by - have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by - have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg - simp [t, abs_mul, hinv_abs, mul_assoc] - simpa [ht] using hmul2 - let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound - have ht_abs' : |t| ≤ (radius i : Real) := by - simpa [radius, centeredBound, invStdBound] using ht_abs - have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by - exact abs_le.mp ht_abs' - have hlow : - (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h : (beta i : Real) + -(radius i 
: Real) ≤ (beta i : Real) + t := by - simpa only [add_comm, add_left_comm, add_assoc] using - add_le_add_left hbounds.1 (beta i : Real) - simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h - have hhigh : - t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by - simpa only [add_comm, add_left_comm, add_assoc] using - add_le_add_left hbounds.2 (beta i : Real) - simpa only [add_comm, add_left_comm, add_assoc] using h - have hreal : - layerNormReal eps gamma beta x i = t + (beta i : Real) := by - simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] - have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by - simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hlow - have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh - exact And.intro hlo hhi - -/-- `layerNormAbsBounds` soundness for real LayerNorm outputs on real inputs. 
-/ -theorem layerNormAbsBounds_spec_real {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Real) - (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (habs : ∀ i, |x i| ≤ (absBound : Real)) : - let bounds := layerNormAbsBounds eps gamma beta absBound - ∀ i, - (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ - layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - classical - intro bounds i - have hmean_abs : |meanReal x| ≤ (absBound : Real) := - meanReal_abs_le_bound x absBound hne habs - have hbound_nonneg_real : 0 ≤ (absBound : Real) := by - have hposn : 0 < n := Nat.pos_of_ne_zero hne - let i0 : Fin n := ⟨0, hposn⟩ - have h0 : 0 ≤ |x i0| := abs_nonneg _ - exact le_trans h0 (habs i0) - let centeredBound : Rat := 2 * absBound - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) - let varEps : Real := varianceReal x + (eps : Real) - let μ : Real := meanReal x - let invStd : Real := (Real.sqrt varEps)⁻¹ - have hcentered_abs : |x i - μ| ≤ (centeredBound : Real) := by - have hx : |x i| ≤ (absBound : Real) := habs i - have hmu : |μ| ≤ (absBound : Real) := by - simpa using hmean_abs - have h12 : |x i - μ| ≤ (absBound : Real) + (absBound : Real) := - abs_sub_le_double_bound hx hmu - simpa [centeredBound, two_mul] using h12 - have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by - have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real - simpa [centeredBound, two_mul] using hsum - have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne - have hsqrt_lower : - (sqrtLower eps : Real) ≤ Real.sqrt varEps := by - have hsqrt_eps : - (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by - have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) - simpa using h - have hle : (eps : Real) ≤ varEps := by - have hle' : (eps : Real) ≤ varianceReal x + (eps : Real) := by - exact le_add_of_nonneg_left hvar_nonneg - simpa [varEps] using hle' - have hsqrt_eps' : Real.sqrt (eps : Real) ≤ 
Real.sqrt varEps := by - exact Real.sqrt_le_sqrt hle - exact le_trans hsqrt_eps hsqrt_eps' - have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt - have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by - have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd] using h - have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by - have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy - simpa [invStdBound, one_div] using hdiv - have hinv : invStd ≤ (invStdBound : Real) := by - exact le_trans hinv_sqrt hinv_bound - have hinv_nonneg : 0 ≤ invStd := by - have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by - exact Real.sqrt_nonneg _ - exact inv_nonneg.2 hsqrt_nonneg - have hmul1 : |x i - μ| * invStd ≤ - (centeredBound : Real) * (invStdBound : Real) := by - have hleft : - |x i - μ| * invStd ≤ (centeredBound : Real) * invStd := by - exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg - have hright : - (centeredBound : Real) * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by - exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg - exact le_trans hleft hright - have hmul2 : |(gamma i : Real)| * |x i - μ| * invStd ≤ - |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by - have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ - have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ - |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by - exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa only [mul_assoc] using hmul2' - let t : Real := (gamma i : Real) * (x i - μ) * invStd - have ht_abs : - |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by - have ht : |t| = |(gamma i : Real)| * |x i - μ| * invStd := by - have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg - simp [t, abs_mul, hinv_abs, mul_assoc] - simpa [ht] using hmul2 
- let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound - have ht_abs' : |t| ≤ (radius i : Real) := by - simpa [radius, centeredBound, invStdBound] using ht_abs - have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by - exact abs_le.mp ht_abs' - have hlow : - (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by - simpa only [add_comm, add_left_comm, add_assoc] using - add_le_add_left hbounds.1 (beta i : Real) - simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h - have hhigh : - t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by - simpa only [add_comm, add_left_comm, add_assoc] using - add_le_add_left hbounds.2 (beta i : Real) - simpa only [add_comm, add_left_comm, add_assoc] using h - have hreal : - layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by - simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] - have hlo : (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i := by - simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hlow - have hhi : layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh - exact And.intro hlo hhi - -/-- `layerNormIntervalBounds` soundness for real LayerNorm outputs on real inputs. 
-/ -theorem layerNormIntervalBounds_spec_real {n : Nat} - (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Real) - (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : - let bounds := layerNormIntervalBounds eps gamma beta lo hi - ∀ i, - (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ - layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - classical - intro bounds i - have hmean_lo : (mean lo : Real) ≤ meanReal x := by - have h := - meanReal_le_meanReal (x := fun j => (lo j : Real)) (y := x) hne - (fun j => hlo j) - have hrat : (meanRat lo : Real) ≤ meanReal x := by - simpa [meanReal_eq_meanRat] using h - have hdown : (mean lo : Real) ≤ (meanRat lo : Real) := by - simpa [mean_def lo hne] using ratRoundDown_le_real (meanRat lo) - exact le_trans hdown hrat - have hmean_hi : meanReal x ≤ (meanUpper hi : Real) := by - have h := - meanReal_le_meanReal (x := x) (y := fun j => (hi j : Real)) hne - (fun j => hhi j) - have hrat : meanReal x ≤ (meanRat hi : Real) := by - simpa [meanReal_eq_meanRat] using h - have hup : (meanRat hi : Real) ≤ (meanUpper hi : Real) := by - simpa [meanUpper_def hi hne] using real_le_ratRoundUp (meanRat hi) - exact le_trans hrat hup - let μLo : Rat := mean lo - let μHi : Rat := meanUpper hi - let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| - let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) - let varEps : Real := varianceReal x + (eps : Real) - let μ : Real := meanReal x - let invStd : Real := (Real.sqrt varEps)⁻¹ - have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by - have h0 : 0 ≤ centeredBound i := by - dsimp [centeredBound] - exact le_trans (abs_nonneg _) (le_max_left _ _) - exact ratToReal_nonneg_of_nonneg h0 - have hcentered_abs : |x i - μ| ≤ (centeredBound i : Real) := by - have hmean_lo_real : (μLo : Real) ≤ μ := by - simpa [μLo, μ] using hmean_lo - have hmean_hi_real : μ ≤ 
(μHi : Real) := by - simpa [μHi, μ] using hmean_hi - have hlo' : (lo i : Real) - (μHi : Real) ≤ x i - μ := by - have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by - exact sub_le_sub_left hmean_hi_real (lo i : Real) - have h2 : (lo i : Real) - μ ≤ x i - μ := by - exact sub_le_sub_right (hlo i) μ - exact le_trans h1 h2 - have hhi' : x i - μ ≤ (hi i : Real) - (μLo : Real) := by - have h1 : x i - μ ≤ (hi i : Real) - μ := by - exact sub_le_sub_right (hhi i) μ - have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by - exact sub_le_sub_left hmean_lo_real (hi i : Real) - exact le_trans h1 h2 - have hbound := abs_le_max_of_bounds hlo' hhi' - simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, - ratToReal_max] using hbound - have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne - have hsqrt_lower : - (sqrtLower eps : Real) ≤ Real.sqrt varEps := by - have hsqrt_eps : - (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by - have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) - simpa using h - have hle : (eps : Real) ≤ varEps := by - have hle' : (eps : Real) ≤ varianceReal x + (eps : Real) := by - exact le_add_of_nonneg_left hvar_nonneg - simpa [varEps] using hle' - have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by - exact Real.sqrt_le_sqrt hle - exact le_trans hsqrt_eps hsqrt_eps' - have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by - exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt - have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by - have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower - simpa [invStd] using h - have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by - have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy - simpa [invStdBound, one_div] using hdiv - have hinv : invStd ≤ (invStdBound : Real) := by - exact le_trans hinv_sqrt hinv_bound - have hinv_nonneg : 0 ≤ invStd := by - have 
hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by - exact Real.sqrt_nonneg _ - exact inv_nonneg.2 hsqrt_nonneg - have hmul1 : |x i - μ| * invStd ≤ - (centeredBound i : Real) * (invStdBound : Real) := by - have hleft : |x i - μ| * invStd ≤ (centeredBound i : Real) * invStd := by - exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg - have hright : - (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by - exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg - exact le_trans hleft hright - have hmul2 : |(gamma i : Real)| * |x i - μ| * invStd ≤ - |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by - have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ - have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ - |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by - exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg - simpa only [mul_assoc] using hmul2' - let t : Real := (gamma i : Real) * (x i - μ) * invStd - have ht_abs : - |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by - have ht : |t| = |(gamma i : Real)| * |x i - μ| * invStd := by - have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg - simp [t, abs_mul, hinv_abs, mul_assoc] - simpa [ht] using hmul2 - let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound - have ht_abs' : |t| ≤ (radius i : Real) := by - simpa [radius, centeredBound, invStdBound] using ht_abs - have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by - exact abs_le.mp ht_abs' - have hlow : - (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by - have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by - simpa only [add_comm, add_left_comm, add_assoc] using - add_le_add_left hbounds.1 (beta i : Real) - simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h - have hhigh : - t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by - 
have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by - simpa only [add_comm, add_left_comm, add_assoc] using - add_le_add_left hbounds.2 (beta i : Real) - simpa only [add_comm, add_left_comm, add_assoc] using h - have hreal : - layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by - simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] - have hlo : (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i := by - simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, - hreal] using hlow - have hhi : layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by - simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, - hreal] using hhigh - exact And.intro hlo hhi - -end Bounds - -end Sound - -end Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm/Basic.lean b/Nfp/Sound/Bounds/LayerNorm/Basic.lean new file mode 100644 index 0000000..f5e11ae --- /dev/null +++ b/Nfp/Sound/Bounds/LayerNorm/Basic.lean @@ -0,0 +1,813 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Fin +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Field.Basic +import Mathlib.Algebra.Order.Ring.Basic +import Mathlib.Data.Real.Sqrt +import Mathlib.Data.Rat.BigOperators +import Mathlib.Data.Rat.Cast.Order +import Nfp.Core.Basic +import Nfp.Sound.Bounds.LayerNorm.MeanVariance +import Nfp.Sound.Bounds.LayerNorm.SqrtBounds +import Nfp.Sound.Linear.FinFold + +/-! +LayerNorm interval bounds for rational inputs. + +This module computes rational interval bounds for LayerNorm outputs and proves +those bounds sound for real-valued LayerNorm semantics. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +/-- Bounds for multiplying a scalar by a bounded value. 
-/ +def scaleInterval (x lo hi : Rat) : Rat × Rat := + if 0 ≤ x then + (x * lo, x * hi) + else + (x * hi, x * lo) + +/-- `scaleInterval` bounds a product. -/ +theorem scaleInterval_bounds {x lo hi y : Rat} + (hlo : lo ≤ y) (hhi : y ≤ hi) : + let bounds := scaleInterval x lo hi + bounds.1 ≤ x * y ∧ x * y ≤ bounds.2 := by + by_cases hx : 0 ≤ x + · have hbounds : x * lo ≤ x * y ∧ x * y ≤ x * hi := by + exact ⟨mul_le_mul_of_nonneg_left hlo hx, mul_le_mul_of_nonneg_left hhi hx⟩ + simpa [scaleInterval, hx] using hbounds + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have hbounds : x * hi ≤ x * y ∧ x * y ≤ x * lo := by + exact ⟨mul_le_mul_of_nonpos_left hhi hx', mul_le_mul_of_nonpos_left hlo hx'⟩ + simpa [scaleInterval, hx] using hbounds + +/-- `scaleInterval` bounds interpreted in the reals. -/ +theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} + (hlo : (lo : Real) ≤ y) (hhi : y ≤ (hi : Real)) : + let bounds := scaleInterval x lo hi + (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by + by_cases hx : 0 ≤ x + · have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx + have hbounds : (x : Real) * (lo : Real) ≤ (x : Real) * y ∧ + (x : Real) * y ≤ (x : Real) * (hi : Real) := by + exact ⟨mul_le_mul_of_nonneg_left hlo hx', mul_le_mul_of_nonneg_left hhi hx'⟩ + simpa [scaleInterval, hx] using hbounds + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' + have hbounds : (x : Real) * (hi : Real) ≤ (x : Real) * y ∧ + (x : Real) * y ≤ (x : Real) * (lo : Real) := by + exact ⟨mul_le_mul_of_nonpos_left hhi hx'', mul_le_mul_of_nonpos_left hlo hx''⟩ + simpa [scaleInterval, hx] using hbounds + +/-- Real-valued LayerNorm output for a vector. 
-/ +noncomputable def layerNormReal {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : Fin n → Real := + if n = 0 then + fun _ => 0 + else + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) + +/-- Real-valued LayerNorm output for a real vector. -/ +noncomputable def layerNormRealOfReal {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Real) : Fin n → Real := + if n = 0 then + fun _ => 0 + else + let μ : Real := meanReal x + let varEps : Real := varianceReal x + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + fun i => (gamma i : Real) * (x i - μ) * invStd + (beta i : Real) + +/-- Interval bounds for LayerNorm outputs. -/ +def layerNormBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if n = 0 then + (fun _ => 0, fun _ => 0) + else + let μ : Rat := mean x + let centered : Fin n → Rat := fun i => x i - μ + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEps) + let sqrtUpperBound : Rat := sqrtUpper varEps + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let coeff : Fin n → Rat := fun i => gamma i * centered i + let lo : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdLower + else + beta i + coeff i * invStdUpper + let hi : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdUpper + else + beta i + coeff i * invStdLower + (lo, hi) + +/-- `layerNormBounds` soundness for real LayerNorm outputs. 
-/ +theorem layerNormBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := layerNormBounds eps gamma beta x + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let μRat : Rat := mean x + let varRat : Rat := variance x + let varEpsRat : Rat := varRat + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsRat) + let sqrtUpperBound : Rat := sqrtUpper varEpsRat + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let centered : Rat := x i - μRat + let coeff : Rat := gamma i * centered + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hmu : (μRat : Real) = μ := by + simp [μRat, μ, mean_def, hne, ratRoundDown] + have hvar : (varRat : Real) = (varianceRat x : Real) := by + simp [varRat, variance_def, hne, ratRoundDown] + have hvarEps : (varEpsRat : Real) = varEps := by + simp [varEpsRat, varEps, hvar] + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hvar_nonneg_rat : 0 ≤ varianceRat x := by + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg + have hvarRat_nonneg : 0 ≤ varRat := by + have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat + simpa [varRat, variance_def x hne] using h + have hvarEps_nonneg : 0 ≤ varEpsRat := by + exact add_nonneg hvarRat_nonneg (le_of_lt heps) + have hsqrt_lower : + (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps' : (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ 
(varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) + have hsqrt_var : (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := by + have hsqrt_var' : + (sqrtLower varEpsRat : Real) ≤ Real.sqrt (varEpsRat : Real) := by + have h := sqrtLower_le_real_sqrt (q := varEpsRat) hvarEps_nonneg + simpa using h + have hle : (varEpsRat : Real) ≤ varEps := by + simp [hvarEps] + exact le_trans hsqrt_var' (Real.sqrt_le_sqrt hle) + have hmax : + max (sqrtLower eps : Real) (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := + (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ + simpa [sqrtLowerBound, ratToReal_max] using hmax + have hsqrt_upper : + Real.sqrt varEps ≤ (sqrtUpperBound : Real) := by + have h := real_sqrt_le_sqrtUpper (q := varEpsRat) hvarEps_nonneg + simpa [sqrtUpperBound, hvarEps] using h + have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by + have hpos : 0 < sqrtLower eps := hsqrt + have hpos' : 0 < max (sqrtLower eps) (sqrtLower varEpsRat) := + lt_of_lt_of_le hpos (le_max_left _ _) + simpa [sqrtLowerBound] using hpos' + have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLowerBound)).2 hsqrt_lower_pos_rat + have hsqrt_upper_pos_rat : 0 < sqrtUpperBound := by + simpa [sqrtUpperBound] using sqrtUpper_pos varEpsRat + have hsqrt_upper_pos : 0 < (sqrtUpperBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtUpperBound)).2 hsqrt_upper_pos_rat + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + have hsqrt_pos : 0 < Real.sqrt varEps := Real.sqrt_pos.2 hvarEps_pos + have hinv_lower_real : + (sqrtUpperBound : Real)⁻¹ ≤ invStd := by + have hle := inv_anti₀ hsqrt_pos hsqrt_upper + simpa [invStd] using hle + have hinv_upper_real : + invStd ≤ (sqrtLowerBound : Real)⁻¹ := by + have hle := inv_anti₀ 
hsqrt_lower_pos hsqrt_lower + simpa [invStd] using hle + have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat + have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat + have hinv_lower : (invStdLower : Real) ≤ invStd := by + simpa [invStdLower, ratDivDown, hupper_ne, one_div] using hinv_lower_real + have hinv_upper : invStd ≤ (invStdUpper : Real) := by + simpa [invStdUpper, ratDivUp, hlower_ne, one_div] using hinv_upper_real + have hlayer : + layerNormReal eps gamma beta x i = + (beta i : Real) + (coeff : Real) * invStd := by + simp [layerNormReal, hne, coeff, centered, μ, hmu, invStd, varEps, add_comm, mul_assoc] + by_cases hcoeff : 0 ≤ coeff + · have hcoeff_real : 0 ≤ (coeff : Real) := + ratToReal_nonneg_of_nonneg hcoeff + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdLower : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonneg_left hinv_lower hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : Real) + (coeff : Real) * (invStdUpper : Real) := by + have hmul := mul_le_mul_of_nonneg_left hinv_upper hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + · have hcoeff_lt : coeff < 0 := lt_of_not_ge hcoeff + have hcoeff_real : (coeff : Real) ≤ 0 := by + exact_mod_cast (le_of_lt hcoeff_lt) + have hlow_raw : + (beta i : Real) + (coeff : Real) * 
(invStdUpper : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonpos_left hinv_upper hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : Real) + (coeff : Real) * (invStdLower : Real) := by + have hmul := mul_le_mul_of_nonpos_left hinv_lower hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + +/-! +Local bounds for monotone multiplication in real-valued bounds. +-/ + +/-- Lower sqrt bound against the variance-plus-eps term. -/ +theorem sqrtLower_le_real_sqrt_varEps {n : Nat} (eps : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + intro varEps + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + +/-- Inverse-std upper bound from the lower sqrt bound. 
-/ +theorem invStd_le_invStdBound {n : Nat} (eps : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + invStd ≤ (invStdBound : Real) := by + intro varEps invStd invStdBound + have hsqrt_lower : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + simpa [varEps] using + (sqrtLower_le_real_sqrt_varEps (eps := eps) (x := x) hne heps) + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + exact le_trans hinv_sqrt hinv_bound + +/-- Inverse-std is nonnegative. -/ +theorem invStd_nonneg {n : Nat} (eps : Rat) (x : Fin n → Rat) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + 0 ≤ invStd := by + intro varEps invStd + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + +/-- Interval bounds for LayerNorm outputs from per-coordinate intervals. 
-/ +def layerNormIntervalBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if n = 0 then + (fun _ => 0, fun _ => 0) + else + let μLo := mean lo + let μHi := meanUpper hi + let centeredBound : Fin n → Rat := fun i => + max |lo i - μHi| |hi i - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound + (fun i => beta i - radius i, fun i => beta i + radius i) + +/-- `layerNormIntervalBounds` soundness for real LayerNorm outputs. -/ +theorem layerNormIntervalBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) : + let bounds := layerNormIntervalBounds eps gamma beta lo hi + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let μLo : Rat := mean lo + let μHi : Rat := meanUpper hi + let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let μ : Real := meanRat x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by + have h0 : 0 ≤ centeredBound i := by + dsimp [centeredBound] + exact le_trans (abs_nonneg _) (le_max_left _ _) + exact ratToReal_nonneg_of_nonneg h0 + have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by + have hmean_lo_real : (μLo : Real) ≤ μ := by + have hmean_rat : (meanRat lo : Real) ≤ (meanRat x : Real) := + meanRat_le_meanRat_real lo x hne hlo + have hdown : (μLo : Real) ≤ (meanRat lo : Real) := by + simpa [μLo, mean_def lo hne] using ratRoundDown_le_real (meanRat lo) + exact le_trans hdown hmean_rat + have hmean_hi_real : μ ≤ (μHi : 
Real) := by + have hmean_rat : (meanRat x : Real) ≤ (meanRat hi : Real) := + meanRat_le_meanRat_real x hi hne hhi + have hup : (meanRat hi : Real) ≤ (μHi : Real) := by + simpa [μHi, meanUpper_def hi hne] using real_le_ratRoundUp (meanRat hi) + exact le_trans hmean_rat hup + have hlo' : (lo i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by + have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by + exact sub_le_sub_left hmean_hi_real (lo i : Real) + have h2 : (lo i : Real) - μ ≤ (x i : Real) - μ := by + exact sub_le_sub_right + (by + exact ratToReal_le_of_le (hlo i)) + μ + exact le_trans h1 h2 + have hhi' : (x i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + have h1 : (x i : Real) - μ ≤ (hi i : Real) - μ := by + exact sub_le_sub_right + (by + exact ratToReal_le_of_le (hhi i)) + μ + have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + exact sub_le_sub_left hmean_lo_real (hi i : Real) + exact le_trans h1 h2 + have hbound := abs_le_max_of_bounds hlo' hhi' + simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, + ratToReal_max] using hbound + have hinv : invStd ≤ (invStdBound : Real) := by + simpa [varEps, invStd, invStdBound] using + (invStd_le_invStdBound (eps := eps) (x := x) hne heps hsqrt) + have hinv_nonneg : 0 ≤ invStd := by + simp [varEps, invStd] + have hmul1 : |(x i : Real) - μ| * invStd ≤ + (centeredBound i : Real) * (invStdBound : Real) := by + have hleft : + |(x i : Real) - μ| * invStd ≤ (centeredBound i : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|(x i : 
Real) - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa only [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hlow + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using 
hhigh + exact And.intro hlo hhi + +/-- Interval bounds for LayerNorm outputs from an absolute input bound. -/ +def layerNormAbsBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) : + (Fin n → Rat) × (Fin n → Rat) := + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound * invStdBound + (fun i => beta i - radius i, fun i => beta i + radius i) + +/-- Bound a centered value by double the absolute bound. -/ +private theorem abs_sub_le_double_bound {a b bound : Real} + (ha : |a| ≤ bound) (hb : |b| ≤ bound) : + |a - b| ≤ bound + bound := by + have h1 : |a - b| ≤ |a| + |b| := by + simpa [sub_eq_add_neg, abs_neg] using abs_add_le a (-b) + exact le_trans h1 (add_le_add ha hb) + +/-- `layerNormAbsBounds` soundness for real LayerNorm outputs under absolute input bounds. -/ +theorem layerNormAbsBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (habs : ∀ i, |x i| ≤ absBound) : + let bounds := layerNormAbsBounds eps gamma beta absBound + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_abs_real : |(meanRat x : Real)| ≤ (absBound : Real) := by + have h := + meanReal_abs_le_bound (x := fun j => (x j : Real)) (bound := absBound) hne + (by + intro j + exact ratToReal_abs_le_of_le (habs j)) + simpa [meanReal_eq_meanRat] using h + have hbound_nonneg : 0 ≤ absBound := by + have hposn : 0 < n := Nat.pos_of_ne_zero hne + let i0 : Fin n := ⟨0, hposn⟩ + have h0 : 0 ≤ |x i0| := abs_nonneg _ + exact le_trans h0 (habs i0) + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let μ : Real := meanRat x + let invStd : Real := (Real.sqrt 
varEps)⁻¹ + have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound : Real) := by + have hx : |(x i : Real)| ≤ (absBound : Real) := by + exact ratToReal_abs_le_of_le (habs i) + have hmu : |μ| ≤ (absBound : Real) := by + simpa [μ] using hmean_abs_real + have h12 : |(x i : Real) - μ| ≤ (absBound : Real) + (absBound : Real) := + abs_sub_le_double_bound hx hmu + simpa [centeredBound, two_mul] using h12 + have hbound_nonneg_real : 0 ≤ (absBound : Real) := by + exact ratToReal_nonneg_of_nonneg hbound_nonneg + have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by + have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real + simpa [centeredBound, two_mul] using hsum + have hinv : invStd ≤ (invStdBound : Real) := by + simpa [varEps, invStd, invStdBound] using + (invStd_le_invStdBound (eps := eps) (x := x) hne heps hsqrt) + have hinv_nonneg : 0 ≤ invStd := by + simp [varEps, invStd] + have hmul1 : |(x i : Real) - μ| * invStd ≤ + (centeredBound : Real) * (invStdBound : Real) := by + have hleft : + |(x i : Real) - μ| * invStd ≤ (centeredBound : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound : Real) * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa only [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd 
:= by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hlow + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh + exact And.intro hlo hhi + +/-- `layerNormAbsBounds` soundness for real LayerNorm outputs on real inputs. 
-/ +theorem layerNormAbsBounds_spec_real {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (habs : ∀ i, |x i| ≤ (absBound : Real)) : + let bounds := layerNormAbsBounds eps gamma beta absBound + ∀ i, + (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ + layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_abs : |meanReal x| ≤ (absBound : Real) := + meanReal_abs_le_bound x absBound hne habs + have hbound_nonneg_real : 0 ≤ (absBound : Real) := by + have hposn : 0 < n := Nat.pos_of_ne_zero hne + let i0 : Fin n := ⟨0, hposn⟩ + have h0 : 0 ≤ |x i0| := abs_nonneg _ + exact le_trans h0 (habs i0) + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varEps : Real := varianceReal x + (eps : Real) + let μ : Real := meanReal x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_abs : |x i - μ| ≤ (centeredBound : Real) := by + have hx : |x i| ≤ (absBound : Real) := habs i + have hmu : |μ| ≤ (absBound : Real) := by + simpa using hmean_abs + have h12 : |x i - μ| ≤ (absBound : Real) + (absBound : Real) := + abs_sub_le_double_bound hx hmu + simpa [centeredBound, two_mul] using h12 + have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by + have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real + simpa [centeredBound, two_mul] using hsum + have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ varianceReal x + (eps : Real) := by + exact le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ 
Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ (invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound + have hinv_nonneg : 0 ≤ invStd := by + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |x i - μ| * invStd ≤ + (centeredBound : Real) * (invStdBound : Real) := by + have hleft : + |x i - μ| * invStd ≤ (centeredBound : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound : Real) * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |x i - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa only [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * (x i - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |x i - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 
+ let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hlow + have hhi : layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh + exact And.intro hlo hhi + +/-- `layerNormIntervalBounds` soundness for real LayerNorm outputs on real inputs. 
-/ +theorem layerNormIntervalBounds_spec_real {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : + let bounds := layerNormIntervalBounds eps gamma beta lo hi + ∀ i, + (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ + layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_lo : (mean lo : Real) ≤ meanReal x := by + have h := + meanReal_le_meanReal (x := fun j => (lo j : Real)) (y := x) hne + (fun j => hlo j) + have hrat : (meanRat lo : Real) ≤ meanReal x := by + simpa [meanReal_eq_meanRat] using h + have hdown : (mean lo : Real) ≤ (meanRat lo : Real) := by + simpa [mean_def lo hne] using ratRoundDown_le_real (meanRat lo) + exact le_trans hdown hrat + have hmean_hi : meanReal x ≤ (meanUpper hi : Real) := by + have h := + meanReal_le_meanReal (x := x) (y := fun j => (hi j : Real)) hne + (fun j => hhi j) + have hrat : meanReal x ≤ (meanRat hi : Real) := by + simpa [meanReal_eq_meanRat] using h + have hup : (meanRat hi : Real) ≤ (meanUpper hi : Real) := by + simpa [meanUpper_def hi hne] using real_le_ratRoundUp (meanRat hi) + exact le_trans hrat hup + let μLo : Rat := mean lo + let μHi : Rat := meanUpper hi + let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varEps : Real := varianceReal x + (eps : Real) + let μ : Real := meanReal x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by + have h0 : 0 ≤ centeredBound i := by + dsimp [centeredBound] + exact le_trans (abs_nonneg _) (le_max_left _ _) + exact ratToReal_nonneg_of_nonneg h0 + have hcentered_abs : |x i - μ| ≤ (centeredBound i : Real) := by + have hmean_lo_real : (μLo : Real) ≤ μ := by + simpa [μLo, μ] using hmean_lo + have hmean_hi_real : μ ≤ 
(μHi : Real) := by + simpa [μHi, μ] using hmean_hi + have hlo' : (lo i : Real) - (μHi : Real) ≤ x i - μ := by + have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by + exact sub_le_sub_left hmean_hi_real (lo i : Real) + have h2 : (lo i : Real) - μ ≤ x i - μ := by + exact sub_le_sub_right (hlo i) μ + exact le_trans h1 h2 + have hhi' : x i - μ ≤ (hi i : Real) - (μLo : Real) := by + have h1 : x i - μ ≤ (hi i : Real) - μ := by + exact sub_le_sub_right (hhi i) μ + have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + exact sub_le_sub_left hmean_lo_real (hi i : Real) + exact le_trans h1 h2 + have hbound := abs_le_max_of_bounds hlo' hhi' + simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, + ratToReal_max] using hbound + have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ varianceReal x + (eps : Real) := by + exact le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ (invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound + have hinv_nonneg : 0 ≤ invStd := by + have 
hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |x i - μ| * invStd ≤ + (centeredBound i : Real) * (invStdBound : Real) := by + have hleft : |x i - μ| * invStd ≤ (centeredBound i : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |x i - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa only [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * (x i - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |x i - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + 
have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hlow + have hhi : layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hhigh + exact And.intro hlo hhi + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean index 120d3c8..aa88768 100644 --- a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean +++ b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean @@ -7,7 +7,7 @@ import Nfp.Sound.Bounds.LayerNorm.SqrtBounds Inverse-standard-deviation bounds for LayerNorm. This module isolates invStd bounds and their soundness proof to keep -`LayerNorm.lean` below the style linter's file-length limit. +`LayerNorm/Basic.lean` below the style linter's file-length limit. 
-/ namespace Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index 3e29e40..324dde6 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -1,183 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Fin -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Ring.Abs -import Mathlib.Data.Fintype.Basic -import Mathlib.Data.Matrix.Mul -import Mathlib.Data.Real.Basic -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Core.Basic +import Nfp.Sound.Bounds.MatrixNorm.Basic import Nfp.Sound.Bounds.MatrixNorm.Interval -import Nfp.Sound.Linear.FinFold /-! -Row-sum matrix norms for downstream linear certificates. - -These bounds are used to compute verified downstream error certificates -from explicit Rat matrices. +Matrix norm and interval bound helpers for downstream certificates. -/ - -namespace Nfp - -namespace Sound - -namespace Bounds - -open scoped BigOperators - -/-- Row-sum of absolute values for a matrix row. -/ -def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : Rat := - Linear.sumFin n (fun j => |W i j|) - -/-- Weighted row-sum using per-coordinate bounds. -/ -def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) : Rat := - Linear.sumFin n (fun j => |W i j| * bound j) - -/-- Maximum row-sum norm (defaults to `0` on empty matrices). -/ -def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := - Linear.foldlFin m (fun acc i => max acc (rowSum W i)) 0 - -/-- Maximum weighted row-sum (defaults to `0` on empty matrices). -/ -def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : Rat := - Linear.foldlFin m (fun acc i => max acc (rowSumWeighted W bound i)) 0 - -/-- Row-sums are nonnegative. 
-/ -theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : - 0 ≤ rowSum W i := by - simpa [rowSum, Linear.sumFin_eq_sum_univ] using - (Finset.sum_nonneg (fun j _ => abs_nonneg (W i j))) - -/-- Weighted row-sums are nonnegative under nonnegative bounds. -/ -theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : - 0 ≤ rowSumWeighted W bound i := by - classical - have hsum : 0 ≤ ∑ j, |W i j| * bound j := by - refine Finset.sum_nonneg ?_ - intro j _ - exact mul_nonneg (abs_nonneg (W i j)) (hbound j) - simpa [rowSumWeighted, Linear.sumFin_eq_sum_univ] using hsum - -/-- Each row-sum is bounded by the row-sum norm. -/ -theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : - rowSum W i ≤ rowSumNorm W := by - simpa [rowSumNorm] using - (foldlFin_max_ge (f := fun j => rowSum W j) i) - -/-- Each weighted row-sum is bounded by the weighted row-sum norm. -/ -theorem rowSumWeighted_le_rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) : - rowSumWeighted W bound i ≤ rowSumWeightedNorm W bound := by - simpa [rowSumWeightedNorm] using - (foldlFin_max_ge (f := fun j => rowSumWeighted W bound j) i) - -/-- The row-sum norm is nonnegative. -/ -theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : - 0 ≤ rowSumNorm W := by - simpa [rowSumNorm] using - (foldlFin_max_ge_init (f := fun i => rowSum W i) (init := (0 : Rat))) - -/-- Weighted row-sum norm is nonnegative. -/ -theorem rowSumWeightedNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : - 0 ≤ rowSumWeightedNorm W bound := by - simpa [rowSumWeightedNorm] using - (foldlFin_max_ge_init (f := fun i => rowSumWeighted W bound i) (init := (0 : Rat))) - -/-- Downstream error from per-coordinate residual bounds. 
-/ -def downstreamErrorFromBounds {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : Rat := - rowSumWeightedNorm W bound - -/-- `downstreamErrorFromBounds` is nonnegative. -/ -theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : - 0 ≤ downstreamErrorFromBounds W bound := by - simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound - -/-- Build a residual-interval certificate by applying a matrix to an input interval. -/ -def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : - {c : Circuit.ResidualIntervalCert m // Circuit.ResidualIntervalBounds c} := by - let lo' := mulVecIntervalLower W lo hi - let hi' := mulVecIntervalUpper W lo hi - refine ⟨{ lo := lo', hi := hi' }, ?_⟩ - refine { lo_le_hi := ?_ } - intro i - exact mulVecIntervalLower_le_upper W lo hi hlohi i - -/-- Summed absolute row entries factor out a scalar bound. -/ -theorem sum_abs_row_mul_eq_rowSum_mul {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (i : Fin m) (inputBound : Rat) : - (∑ j, |W i j| * inputBound) = rowSum W i * inputBound := by - have hsum : - (∑ j, |W i j|) * inputBound = ∑ j, |W i j| * inputBound := by - simpa using - (Finset.sum_mul - (s := (Finset.univ : Finset (Fin n))) - (f := fun j => |W i j|) - (a := inputBound)) - simpa [rowSum, Linear.sumFin_eq_sum_univ] using hsum.symm - -/-- Row-sum norm bounds a matrix-vector product under a uniform input bound. 
-/ -theorem abs_mulVec_le_rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (x : Fin n → Rat) (inputBound : Rat) - (hx : ∀ j, |x j| ≤ inputBound) : - ∀ i, |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by - intro i - have h1 : |∑ j, W i j * x j| ≤ ∑ j, |W i j * x j| := by - simpa using - (Finset.abs_sum_le_sum_abs - (f := fun j => W i j * x j) - (s := (Finset.univ : Finset (Fin n)))) - have h2 : ∑ j, |W i j * x j| ≤ ∑ j, |W i j| * inputBound := by - refine Finset.sum_le_sum ?_ - intro j _ - have hnonneg : 0 ≤ |W i j| := abs_nonneg (W i j) - calc - |W i j * x j| = |W i j| * |x j| := by - simp [abs_mul] - _ ≤ |W i j| * inputBound := by - exact mul_le_mul_of_nonneg_left (hx j) hnonneg - have h3 : ∑ j, |W i j| * inputBound = rowSum W i * inputBound := - sum_abs_row_mul_eq_rowSum_mul W i inputBound - simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) - -/-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ -theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (x : Fin n → Rat) (inputBound : Rat) - (hx : ∀ j, |x j| ≤ inputBound) (hinput : 0 ≤ inputBound) : - ∀ i, |Matrix.mulVec W x i| ≤ rowSumNorm W * inputBound := by - intro i - have hrow : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := - abs_mulVec_le_rowSum W x inputBound hx i - have hle : rowSum W i ≤ rowSumNorm W := rowSum_le_rowSumNorm W i - have hmul : rowSum W i * inputBound ≤ rowSumNorm W * inputBound := - mul_le_mul_of_nonneg_right hle hinput - exact hrow.trans hmul - -/-- Build a downstream linear certificate from a matrix and input bound. 
-/ -def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (inputBound : Rat) (hinput : 0 ≤ inputBound) : - {c : Circuit.DownstreamLinearCert // Circuit.DownstreamLinearBounds c} := by - let gain := rowSumNorm W - let error := gain * inputBound - refine ⟨{ error := error, gain := gain, inputBound := inputBound }, ?_⟩ - refine - { error_nonneg := ?_ - gain_nonneg := ?_ - input_nonneg := hinput - error_eq := rfl } - · exact mul_nonneg (rowSumNorm_nonneg W) hinput - · exact rowSumNorm_nonneg W - - -end Bounds - -end Sound - -end Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean b/Nfp/Sound/Bounds/MatrixNorm/Basic.lean new file mode 100644 index 0000000..3e29e40 --- /dev/null +++ b/Nfp/Sound/Bounds/MatrixNorm/Basic.lean @@ -0,0 +1,183 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Fin +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Ring.Abs +import Mathlib.Data.Fintype.Basic +import Mathlib.Data.Matrix.Mul +import Mathlib.Data.Real.Basic +import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Core.Basic +import Nfp.Sound.Bounds.MatrixNorm.Interval +import Nfp.Sound.Linear.FinFold + +/-! +Row-sum matrix norms for downstream linear certificates. + +These bounds are used to compute verified downstream error certificates +from explicit Rat matrices. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +/-- Row-sum of absolute values for a matrix row. -/ +def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : Rat := + Linear.sumFin n (fun j => |W i j|) + +/-- Weighted row-sum using per-coordinate bounds. -/ +def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) : Rat := + Linear.sumFin n (fun j => |W i j| * bound j) + +/-- Maximum row-sum norm (defaults to `0` on empty matrices). 
-/ +def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := + Linear.foldlFin m (fun acc i => max acc (rowSum W i)) 0 + +/-- Maximum weighted row-sum (defaults to `0` on empty matrices). -/ +def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : Rat := + Linear.foldlFin m (fun acc i => max acc (rowSumWeighted W bound i)) 0 + +/-- Row-sums are nonnegative. -/ +theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : + 0 ≤ rowSum W i := by + simpa [rowSum, Linear.sumFin_eq_sum_univ] using + (Finset.sum_nonneg (fun j _ => abs_nonneg (W i j))) + +/-- Weighted row-sums are nonnegative under nonnegative bounds. -/ +theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : + 0 ≤ rowSumWeighted W bound i := by + classical + have hsum : 0 ≤ ∑ j, |W i j| * bound j := by + refine Finset.sum_nonneg ?_ + intro j _ + exact mul_nonneg (abs_nonneg (W i j)) (hbound j) + simpa [rowSumWeighted, Linear.sumFin_eq_sum_univ] using hsum + +/-- Each row-sum is bounded by the row-sum norm. -/ +theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : + rowSum W i ≤ rowSumNorm W := by + simpa [rowSumNorm] using + (foldlFin_max_ge (f := fun j => rowSum W j) i) + +/-- Each weighted row-sum is bounded by the weighted row-sum norm. -/ +theorem rowSumWeighted_le_rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) (i : Fin m) : + rowSumWeighted W bound i ≤ rowSumWeightedNorm W bound := by + simpa [rowSumWeightedNorm] using + (foldlFin_max_ge (f := fun j => rowSumWeighted W bound j) i) + +/-- The row-sum norm is nonnegative. -/ +theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : + 0 ≤ rowSumNorm W := by + simpa [rowSumNorm] using + (foldlFin_max_ge_init (f := fun i => rowSum W i) (init := (0 : Rat))) + +/-- Weighted row-sum norm is nonnegative. 
-/ +theorem rowSumWeightedNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : + 0 ≤ rowSumWeightedNorm W bound := by + simpa [rowSumWeightedNorm] using + (foldlFin_max_ge_init (f := fun i => rowSumWeighted W bound i) (init := (0 : Rat))) + +/-- Downstream error from per-coordinate residual bounds. -/ +def downstreamErrorFromBounds {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : Rat := + rowSumWeightedNorm W bound + +/-- `downstreamErrorFromBounds` is nonnegative. -/ +theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (bound : Fin n → Rat) : + 0 ≤ downstreamErrorFromBounds W bound := by + simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound + +/-- Build a residual-interval certificate by applying a matrix to an input interval. -/ +def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : + {c : Circuit.ResidualIntervalCert m // Circuit.ResidualIntervalBounds c} := by + let lo' := mulVecIntervalLower W lo hi + let hi' := mulVecIntervalUpper W lo hi + refine ⟨{ lo := lo', hi := hi' }, ?_⟩ + refine { lo_le_hi := ?_ } + intro i + exact mulVecIntervalLower_le_upper W lo hi hlohi i + +/-- Summed absolute row entries factor out a scalar bound. -/ +theorem sum_abs_row_mul_eq_rowSum_mul {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (i : Fin m) (inputBound : Rat) : + (∑ j, |W i j| * inputBound) = rowSum W i * inputBound := by + have hsum : + (∑ j, |W i j|) * inputBound = ∑ j, |W i j| * inputBound := by + simpa using + (Finset.sum_mul + (s := (Finset.univ : Finset (Fin n))) + (f := fun j => |W i j|) + (a := inputBound)) + simpa [rowSum, Linear.sumFin_eq_sum_univ] using hsum.symm + +/-- Row-sum norm bounds a matrix-vector product under a uniform input bound. 
-/ +theorem abs_mulVec_le_rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (x : Fin n → Rat) (inputBound : Rat) + (hx : ∀ j, |x j| ≤ inputBound) : + ∀ i, |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by + intro i + have h1 : |∑ j, W i j * x j| ≤ ∑ j, |W i j * x j| := by + simpa using + (Finset.abs_sum_le_sum_abs + (f := fun j => W i j * x j) + (s := (Finset.univ : Finset (Fin n)))) + have h2 : ∑ j, |W i j * x j| ≤ ∑ j, |W i j| * inputBound := by + refine Finset.sum_le_sum ?_ + intro j _ + have hnonneg : 0 ≤ |W i j| := abs_nonneg (W i j) + calc + |W i j * x j| = |W i j| * |x j| := by + simp [abs_mul] + _ ≤ |W i j| * inputBound := by + exact mul_le_mul_of_nonneg_left (hx j) hnonneg + have h3 : ∑ j, |W i j| * inputBound = rowSum W i * inputBound := + sum_abs_row_mul_eq_rowSum_mul W i inputBound + simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) + +/-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ +theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (x : Fin n → Rat) (inputBound : Rat) + (hx : ∀ j, |x j| ≤ inputBound) (hinput : 0 ≤ inputBound) : + ∀ i, |Matrix.mulVec W x i| ≤ rowSumNorm W * inputBound := by + intro i + have hrow : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := + abs_mulVec_le_rowSum W x inputBound hx i + have hle : rowSum W i ≤ rowSumNorm W := rowSum_le_rowSumNorm W i + have hmul : rowSum W i * inputBound ≤ rowSumNorm W * inputBound := + mul_le_mul_of_nonneg_right hle hinput + exact hrow.trans hmul + +/-- Build a downstream linear certificate from a matrix and input bound. 
-/ +def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) + (inputBound : Rat) (hinput : 0 ≤ inputBound) : + {c : Circuit.DownstreamLinearCert // Circuit.DownstreamLinearBounds c} := by + let gain := rowSumNorm W + let error := gain * inputBound + refine ⟨{ error := error, gain := gain, inputBound := inputBound }, ?_⟩ + refine + { error_nonneg := ?_ + gain_nonneg := ?_ + input_nonneg := hinput + error_eq := rfl } + · exact mul_nonneg (rowSumNorm_nonneg W) hinput + · exact rowSumNorm_nonneg W + + +end Bounds + +end Sound + +end Nfp diff --git a/Nfp/Sound/Bounds/Transformer.lean b/Nfp/Sound/Bounds/Transformer.lean index 87fef74..d76e3f9 100644 --- a/Nfp/Sound/Bounds/Transformer.lean +++ b/Nfp/Sound/Bounds/Transformer.lean @@ -1,564 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.List.Range -import Mathlib.Data.Real.Basic -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Model.Gpt2 -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.Transformer.Basic import Nfp.Sound.Bounds.Transformer.Embedding -import Nfp.Sound.Linear.FinFold /-! -Interval bounds for transformer stacks and final LayerNorm outputs. +Transformer-stack interval bounds and supporting lemmas. -/ - -namespace Nfp - -namespace Sound - -namespace Bounds - -open scoped BigOperators - -/-- Real-valued output of a transformer layer. 
-/ -noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numHeads → Fin seq → Fin seq → Real) - (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := - x q i + - attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias scores x q i + - mlpReal layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut - (layerNormRealOfReal eps layer.ln2Gamma layer.ln2Beta - (fun j => - x q j + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores x q j)) i - -/-- `transformerLayerBounds` soundness for `transformerLayerReal`. -/ -theorem transformerLayerBounds_spec_real {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : - let bounds := transformerLayerBounds eps layer.ln1Gamma layer.ln1Beta layer.ln2Gamma - layer.ln2Beta heads layer.attnBias layer.mlpWIn layer.mlpBIn layer.mlpWOut - layer.mlpBOut lo hi - ∀ q i, - (bounds.1 i : Real) ≤ transformerLayerReal eps layer heads scores x q i ∧ - transformerLayerReal eps layer heads scores x q i ≤ (bounds.2 i : Real) := by - classical - simpa [transformerLayerReal] using - (transformerLayerBounds_spec (eps := eps) - (ln1Gamma := layer.ln1Gamma) (ln1Beta := layer.ln1Beta) - (ln2Gamma := layer.ln2Gamma) (ln2Beta := layer.ln2Beta) - (heads := heads) (attnBias := layer.attnBias) - (mlpWIn := layer.mlpWIn) (mlpBIn := layer.mlpBIn) - (mlpWOut := layer.mlpWOut) (mlpBOut := layer.mlpBOut) - (scores := scores) (lo := lo) (hi := hi) (x := 
x) - hne heps hsqrt hlo hhi) - -/-- Interval bounds for a transformer layer from per-position bounds. -/ -def transformerLayerBoundsPos {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Rat) : - (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := - let positions := (Finset.univ : Finset (Fin seq)) - let hpos : positions.Nonempty := by - simp [positions] - let loCached := cacheBound2 lo - let hiCached := cacheBound2 hi - let base := intervalBoundsOn positions hpos loCached hiCached - let baseLo := cacheBound base.1 - let baseHi := cacheBound base.2 - let attn := attentionOutputBounds eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias - baseLo baseHi - let attnLo := cacheBound attn.1 - let attnHi := cacheBound attn.2 - let yLo : Fin seq → Fin dModel → Rat := fun q i => loCached q i + attnLo i - let yHi : Fin seq → Fin dModel → Rat := fun q i => hiCached q i + attnHi i - let yLoCached := cacheBound2 yLo - let yHiCached := cacheBound2 yHi - let out := cacheBoundPair2 (fun q => - layerNormAbsMlpResidualBounds eps layer.ln2Gamma layer.ln2Beta - layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut - (yLoCached q) (yHiCached q)) - out - -/-- `transformerLayerBoundsPos` soundness for `transformerLayerReal`. 
-/ -theorem transformerLayerBoundsPos_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo q i : Real) ≤ x q i) - (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : - let bounds := transformerLayerBoundsPos eps layer heads lo hi - ∀ q i, - (bounds.1 q i : Real) ≤ transformerLayerReal eps layer heads scores x q i ∧ - transformerLayerReal eps layer heads scores x q i ≤ (bounds.2 q i : Real) := by - classical - intro bounds q i - let positions := (Finset.univ : Finset (Fin seq)) - have hpos : positions.Nonempty := by - simp [positions] - let loCached := cacheBound2 lo - let hiCached := cacheBound2 hi - have hloCached : ∀ q i, (loCached q i : Real) ≤ x q i := by - intro q i - simpa [loCached, cacheBound2_apply] using hlo q i - have hhiCached : ∀ q i, x q i ≤ (hiCached q i : Real) := by - intro q i - simpa [hiCached, cacheBound2_apply] using hhi q i - let base := intervalBoundsOn positions hpos loCached hiCached - have hbase := intervalBoundsOn_spec positions hpos loCached hiCached x - (fun q _ i => hloCached q i) (fun q _ i => hhiCached q i) - have hloBase : ∀ q i, (base.1 i : Real) ≤ x q i := fun q i => - (hbase q (by simp [positions]) i).1 - have hhiBase : ∀ q i, x q i ≤ (base.2 i : Real) := fun q i => - (hbase q (by simp [positions]) i).2 - let baseLo := cacheBound base.1 - let baseHi := cacheBound base.2 - have hloBaseCached : ∀ q i, (baseLo i : Real) ≤ x q i := by - intro q i - simpa [baseLo, cacheBound_apply] using hloBase q i - have hhiBaseCached : ∀ q i, x q i ≤ (baseHi i : Real) := by - intro q i - simpa [baseHi, cacheBound_apply] using hhiBase q i - let attn := attentionOutputBounds eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias 
- baseLo baseHi - have hattn := attentionOutputBounds_spec eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores baseLo baseHi x hne heps hsqrt hloBaseCached hhiBaseCached q - let attnLo := cacheBound attn.1 - let attnHi := cacheBound attn.2 - let y := fun j => - x q j + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores x q j - have yLo : ∀ j, (loCached q j : Real) + (attn.1 j : Real) ≤ y j := by - intro j - have hlow : - (loCached q j : Real) + (attn.1 j : Real) ≤ - x q j + - attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores x q j := by - exact add_le_add (hloCached q j) (hattn j).1 - simpa [y] using hlow - have yHi : ∀ j, y j ≤ (hiCached q j : Real) + (attn.2 j : Real) := by - intro j - have hhigh : - x q j + - attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores x q j ≤ - (hiCached q j : Real) + (attn.2 j : Real) := by - exact add_le_add (hhiCached q j) (hattn j).2 - simpa [y] using hhigh - let yLoCached := cacheBound2 (fun q i => loCached q i + attnLo i) - let yHiCached := cacheBound2 (fun q i => hiCached q i + attnHi i) - have yLoCached_bound : ∀ j, (yLoCached q j : Real) ≤ y j := by - intro j - simpa [yLoCached, attnLo, cacheBound_apply, cacheBound2_apply] using (yLo j) - have yHiCached_bound : ∀ j, y j ≤ (yHiCached q j : Real) := by - intro j - simpa [yHiCached, attnHi, cacheBound_apply, cacheBound2_apply] using (yHi j) - have hmlp := - layerNormAbsMlpResidualBounds_spec eps layer.ln2Gamma layer.ln2Beta - layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut - (yLoCached q) (yHiCached q) y hne heps hsqrt yLoCached_bound yHiCached_bound - have hmlp_i := hmlp i - simpa [bounds, transformerLayerBoundsPos, positions, base, loCached, hiCached, baseLo, baseHi, - attn, attnLo, attnHi, y, yLoCached, yHiCached, cacheBound2_apply, cacheBoundPair2_apply_left, - cacheBoundPair2_apply_right, transformerLayerReal, cacheBound_apply] using hmlp_i - -/-- Real-valued 
transformer stack output (folded left over layers). -/ -noncomputable def transformerStackReal - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (x : Fin seq → Fin dModel → Real) : Fin seq → Fin dModel → Real := - let step := fun x layerIdx => - transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x - Linear.foldlFin numLayers step x - -/-- Interval bounds for a transformer stack (folded left over layers). -/ -def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let step := fun bounds layerIdx => - transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta - (layers layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) - (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn - (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2 - Linear.foldlFin numLayers step (lo, hi) - -/-- Interval bounds for a transformer stack from per-position bounds. 
-/ -def transformerStackBoundsPos {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Rat) : - (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := - let step := fun bounds layerIdx => - transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2 - Linear.foldlFin numLayers step (lo, hi) - -private theorem transformerStackBoundsPos_spec_list - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - ∀ (ls : List (Fin numLayers)) (lo hi : Fin seq → Fin dModel → Rat) - (x : Fin seq → Fin dModel → Real), - (∀ q i, (lo q i : Real) ≤ x q i) → - (∀ q i, x q i ≤ (hi q i : Real)) → - let bounds := (ls.foldl - (fun bounds layerIdx => - transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2) - (lo, hi)) - let x' := (ls.foldl - (fun x layerIdx => - transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x) - x) - ∀ q i, - (bounds.1 q i : Real) ≤ x' q i ∧ - x' q i ≤ (bounds.2 q i : Real) := by - intro ls lo hi x hlo hhi - induction ls generalizing lo hi x hlo hhi with - | nil => - simpa using fun q i => And.intro (hlo q i) (hhi q i) - | cons l ls ih => - have hstep := - transformerLayerBoundsPos_spec eps (layers l) (heads l) (scores l) lo hi x - hne heps hsqrt hlo hhi - let bounds1 := transformerLayerBoundsPos eps (layers l) (heads l) lo hi - let x1 := transformerLayerReal eps (layers l) (heads l) (scores l) x - have hlo1 : ∀ q i, (bounds1.1 q i : Real) ≤ x1 q i := fun q i => (hstep q i).1 - have hhi1 
: ∀ q i, x1 q i ≤ (bounds1.2 q i : Real) := fun q i => (hstep q i).2 - have ih' := ih bounds1.1 bounds1.2 x1 hlo1 hhi1 - simpa [bounds1, x1] using ih' - -/-- `transformerStackBoundsPos` soundness for real transformer-stack outputs. -/ -theorem transformerStackBoundsPos_spec {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo q i : Real) ≤ x q i) - (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : - let bounds := transformerStackBoundsPos eps layers heads lo hi - ∀ q i, - (bounds.1 q i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ - transformerStackReal eps layers heads scores x q i ≤ (bounds.2 q i : Real) := by - classical - simpa [transformerStackBoundsPos, transformerStackReal, Linear.foldlFin_eq_foldl, - Fin.foldl_eq_foldl_finRange] using - transformerStackBoundsPos_spec_list eps layers heads scores hne heps hsqrt - (List.finRange numLayers) lo hi x hlo hhi - -private theorem transformerStackBounds_spec_list - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - ∀ (ls : List (Fin numLayers)) (lo hi : Fin dModel → Rat) - (x : Fin seq → Fin dModel → Real), - (∀ q i, (lo i : Real) ≤ x q i) → - (∀ q i, x q i ≤ (hi i : Real)) → - let bounds := (ls.foldl - (fun bounds layerIdx => - transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta - (layers 
layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) - (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn - (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2) - (lo, hi)) - let x' := (ls.foldl - (fun x layerIdx => - transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x) - x) - ∀ q i, - (bounds.1 i : Real) ≤ x' q i ∧ - x' q i ≤ (bounds.2 i : Real) := by - intro ls lo hi x hlo hhi - induction ls generalizing lo hi x hlo hhi with - | nil => - simpa using fun q i => And.intro (hlo q i) (hhi q i) - | cons l ls ih => - have hstep := - transformerLayerBounds_spec_real eps (layers l) (heads l) (scores l) lo hi x - hne heps hsqrt hlo hhi - let bounds1 := - transformerLayerBounds eps (layers l).ln1Gamma (layers l).ln1Beta (layers l).ln2Gamma - (layers l).ln2Beta (heads l) (layers l).attnBias (layers l).mlpWIn (layers l).mlpBIn - (layers l).mlpWOut (layers l).mlpBOut lo hi - let x1 := transformerLayerReal eps (layers l) (heads l) (scores l) x - have hlo1 : ∀ q i, (bounds1.1 i : Real) ≤ x1 q i := fun q i => (hstep q i).1 - have hhi1 : ∀ q i, x1 q i ≤ (bounds1.2 i : Real) := fun q i => (hstep q i).2 - have ih' := ih bounds1.1 bounds1.2 x1 hlo1 hhi1 - simpa [bounds1, x1] using ih' - -/-- `transformerStackBounds` soundness for real transformer-stack outputs. 
-/ -theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : - let bounds := transformerStackBounds eps layers heads lo hi - ∀ q i, - (bounds.1 i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ - transformerStackReal eps layers heads scores x q i ≤ (bounds.2 i : Real) := by - classical - simpa [transformerStackBounds, transformerStackReal, Linear.foldlFin_eq_foldl, - Fin.foldl_eq_foldl_finRange] using - transformerStackBounds_spec_list eps layers heads scores hne heps hsqrt - (List.finRange numLayers) lo hi x hlo hhi - -/-- Real-valued transformer stack output after the final LayerNorm. -/ -noncomputable def transformerStackFinalReal {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := - layerNormRealOfReal eps finalLn.gamma finalLn.beta - (fun j => transformerStackReal eps layers heads scores x q j) i - -/-- Interval bounds for transformer stack outputs after the final LayerNorm. 
-/ -def transformerStackFinalBounds {dModel dHead numHeads hidden numLayers : Nat} - (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let stack := transformerStackBounds eps layers heads lo hi - layerNormIntervalBounds eps finalLn.gamma finalLn.beta stack.1 stack.2 - -/-- `transformerStackFinalBounds` soundness for real outputs. -/ -theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : - let bounds := transformerStackFinalBounds eps finalLn layers heads lo hi - ∀ q i, - (bounds.1 i : Real) ≤ transformerStackFinalReal eps finalLn layers heads scores x q i ∧ - transformerStackFinalReal eps finalLn layers heads scores x q i ≤ (bounds.2 i : Real) := by - classical - intro bounds q i - let stack := transformerStackBounds eps layers heads lo hi - have hstack := - transformerStackBounds_spec eps layers heads scores lo hi x hne heps hsqrt hlo hhi q - have hlo' : ∀ k, (stack.1 k : Real) ≤ transformerStackReal eps layers heads scores x q k := - fun k => (hstack k).1 - have hhi' : ∀ k, transformerStackReal eps layers heads scores x q k ≤ (stack.2 k : Real) := - fun k => (hstack k).2 - have hln := - layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta stack.1 stack.2 - (fun j => transformerStackReal eps layers heads scores x q 
j) hne heps hsqrt hlo' hhi' - simpa [bounds, transformerStackFinalBounds, stack, transformerStackFinalReal] using hln i - -/-- Interval bounds for transformer stack outputs after the final LayerNorm (per-position). -/ -def transformerStackFinalBoundsPos - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Rat) : - (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := - let stack := transformerStackBoundsPos eps layers heads lo hi - let ln := fun q => - layerNormIntervalBounds eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) - (fun q i => (ln q).1 i, fun q i => (ln q).2 i) - -/-- `transformerStackFinalBoundsPos` soundness for real outputs. -/ -theorem transformerStackFinalBoundsPos_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo q i : Real) ≤ x q i) - (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : - let bounds := transformerStackFinalBoundsPos eps finalLn layers heads lo hi - ∀ q i, - (bounds.1 q i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores x q i ∧ - transformerStackFinalReal eps finalLn layers heads scores x q i ≤ - (bounds.2 q i : Real) := by - classical - intro bounds q i - let stack := transformerStackBoundsPos eps layers heads lo hi - have hstack := - transformerStackBoundsPos_spec eps layers heads scores lo hi x hne heps hsqrt hlo hhi q - 
have hlo' : ∀ j, (stack.1 q j : Real) ≤ transformerStackReal eps layers heads scores x q j := - fun j => (hstack j).1 - have hhi' : ∀ j, transformerStackReal eps layers heads scores x q j ≤ (stack.2 q j : Real) := - fun j => (hstack j).2 - have hln := - layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) - (fun j => transformerStackReal eps layers heads scores x q j) hne heps hsqrt hlo' hhi' - simpa [bounds, transformerStackFinalBoundsPos, stack, transformerStackFinalReal] using hln i - -/-- Residual interval bounds for a GPT-2 stack from exact embeddings. -/ -def gpt2ResidualIntervalBounds - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let base := embeddingIntervalBounds embed - transformerStackFinalBounds eps finalLn layers heads base.1 base.2 - -/-- `gpt2ResidualIntervalBounds` soundness for real GPT-2 outputs. 
-/ -theorem gpt2ResidualIntervalBounds_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Rat) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - let bounds := gpt2ResidualIntervalBounds eps layers heads finalLn embed - ∀ q i, - (bounds.1 i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ∧ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by - classical - intro bounds q i - let base := embeddingIntervalBounds embed - have hbase := embeddingIntervalBounds_spec embed - have hlo : ∀ q i, (base.1 i : Real) ≤ (embed q i : Real) := fun q i => (hbase q i).1 - have hhi : ∀ q i, (embed q i : Real) ≤ (base.2 i : Real) := fun q i => (hbase q i).2 - have hstack := - transformerStackFinalBounds_spec eps finalLn layers heads scores base.1 base.2 - (fun q i => (embed q i : Real)) hne heps hsqrt hlo hhi q i - simpa [bounds, gpt2ResidualIntervalBounds, base] using hstack - -/-- Residual interval bounds over an active set from exact embeddings. 
-/ -def gpt2ResidualIntervalBoundsActive - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let baseLo : Fin seq → Fin dModel → Rat := embed - let baseHi : Fin seq → Fin dModel → Rat := embed - let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi - intervalBoundsOn active hactive final.1 final.2 - -/-- `gpt2ResidualIntervalBoundsActive` soundness for real GPT-2 outputs. -/ -theorem gpt2ResidualIntervalBoundsActive_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Rat) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed - ∀ q, q ∈ active → ∀ i, - (bounds.1 i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ∧ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by - classical - intro bounds q hq i - let baseLo : Fin seq → Fin dModel → Rat := embed - let baseHi : Fin seq → Fin dModel → Rat := embed - let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi - have hfinal := - transformerStackFinalBoundsPos_spec eps finalLn layers heads scores baseLo 
baseHi - (fun q i => (embed q i : Real)) hne heps hsqrt - (fun q i => by simp [baseLo]) - (fun q i => by simp [baseHi]) - have hlo : ∀ q, q ∈ active → ∀ i, - (final.1 q i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i := by - intro q hq i - simpa [final] using (hfinal q i).1 - have hhi : ∀ q, q ∈ active → ∀ i, - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (final.2 q i : Real) := by - intro q hq i - simpa [final] using (hfinal q i).2 - have hbounds := intervalBoundsOn_spec active hactive final.1 final.2 - (fun q i => transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i) - hlo hhi - simpa [bounds, gpt2ResidualIntervalBoundsActive, final, baseLo, baseHi] using - hbounds q hq i - -/-- Package GPT-2 residual bounds into a residual-interval certificate. -/ -theorem gpt2ResidualIntervalBoundsActive_sound - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Rat) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed - let cert : Circuit.ResidualIntervalCert dModel := { lo := bounds.1, hi := bounds.2 } - Circuit.ResidualIntervalBounds cert ∧ - ∀ q, q ∈ active → ∀ i, - (cert.lo i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ∧ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (cert.hi i : Real) := by - classical - intro bounds cert - 
have hspec : - ∀ q, q ∈ active → ∀ i, - (bounds.1 i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ∧ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by - simpa [bounds] using - (gpt2ResidualIntervalBoundsActive_spec (active := active) (hactive := hactive) - (eps := eps) (layers := layers) (heads := heads) (finalLn := finalLn) - (scores := scores) (embed := embed) (hne := hne) (heps := heps) (hsqrt := hsqrt)) - have hbounds : Circuit.ResidualIntervalBounds cert := by - refine { lo_le_hi := ?_ } - intro i - rcases hactive with ⟨q0, hq0⟩ - have hq := hspec q0 hq0 i - have hreal : (bounds.1 i : Real) ≤ (bounds.2 i : Real) := hq.1.trans hq.2 - exact (ratToReal_le_iff (x := bounds.1 i) (y := bounds.2 i)).1 hreal - refine And.intro hbounds ?_ - intro q hq i - have hq' := hspec q hq i - simpa [cert] using hq' - -end Bounds - -end Sound - -end Nfp diff --git a/Nfp/Sound/Bounds/Transformer/Basic.lean b/Nfp/Sound/Bounds/Transformer/Basic.lean new file mode 100644 index 0000000..87fef74 --- /dev/null +++ b/Nfp/Sound/Bounds/Transformer/Basic.lean @@ -0,0 +1,564 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Data.List.Range +import Mathlib.Data.Real.Basic +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Model.Gpt2 +import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.Transformer.Embedding +import Nfp.Sound.Linear.FinFold + +/-! +Interval bounds for transformer stacks and final LayerNorm outputs. +-/ + +namespace Nfp + +namespace Sound + +namespace Bounds + +open scoped BigOperators + +/-- Real-valued output of a transformer layer. 
-/ +noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := + x q i + + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias scores x q i + + mlpReal layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut + (layerNormRealOfReal eps layer.ln2Gamma layer.ln2Beta + (fun j => + x q j + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores x q j)) i + +/-- `transformerLayerBounds` soundness for `transformerLayerReal`. -/ +theorem transformerLayerBounds_spec_real {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := transformerLayerBounds eps layer.ln1Gamma layer.ln1Beta layer.ln2Gamma + layer.ln2Beta heads layer.attnBias layer.mlpWIn layer.mlpBIn layer.mlpWOut + layer.mlpBOut lo hi + ∀ q i, + (bounds.1 i : Real) ≤ transformerLayerReal eps layer heads scores x q i ∧ + transformerLayerReal eps layer heads scores x q i ≤ (bounds.2 i : Real) := by + classical + simpa [transformerLayerReal] using + (transformerLayerBounds_spec (eps := eps) + (ln1Gamma := layer.ln1Gamma) (ln1Beta := layer.ln1Beta) + (ln2Gamma := layer.ln2Gamma) (ln2Beta := layer.ln2Beta) + (heads := heads) (attnBias := layer.attnBias) + (mlpWIn := layer.mlpWIn) (mlpBIn := layer.mlpBIn) + (mlpWOut := layer.mlpWOut) (mlpBOut := layer.mlpBOut) + (scores := scores) (lo := lo) (hi := hi) (x := 
x) + hne heps hsqrt hlo hhi) + +/-- Interval bounds for a transformer layer from per-position bounds. -/ +def transformerLayerBoundsPos {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin seq → Fin dModel → Rat) : + (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := + let positions := (Finset.univ : Finset (Fin seq)) + let hpos : positions.Nonempty := by + simp [positions] + let loCached := cacheBound2 lo + let hiCached := cacheBound2 hi + let base := intervalBoundsOn positions hpos loCached hiCached + let baseLo := cacheBound base.1 + let baseHi := cacheBound base.2 + let attn := attentionOutputBounds eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias + baseLo baseHi + let attnLo := cacheBound attn.1 + let attnHi := cacheBound attn.2 + let yLo : Fin seq → Fin dModel → Rat := fun q i => loCached q i + attnLo i + let yHi : Fin seq → Fin dModel → Rat := fun q i => hiCached q i + attnHi i + let yLoCached := cacheBound2 yLo + let yHiCached := cacheBound2 yHi + let out := cacheBoundPair2 (fun q => + layerNormAbsMlpResidualBounds eps layer.ln2Gamma layer.ln2Beta + layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut + (yLoCached q) (yHiCached q)) + out + +/-- `transformerLayerBoundsPos` soundness for `transformerLayerReal`. 
-/ +theorem transformerLayerBoundsPos_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo q i : Real) ≤ x q i) + (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : + let bounds := transformerLayerBoundsPos eps layer heads lo hi + ∀ q i, + (bounds.1 q i : Real) ≤ transformerLayerReal eps layer heads scores x q i ∧ + transformerLayerReal eps layer heads scores x q i ≤ (bounds.2 q i : Real) := by + classical + intro bounds q i + let positions := (Finset.univ : Finset (Fin seq)) + have hpos : positions.Nonempty := by + simp [positions] + let loCached := cacheBound2 lo + let hiCached := cacheBound2 hi + have hloCached : ∀ q i, (loCached q i : Real) ≤ x q i := by + intro q i + simpa [loCached, cacheBound2_apply] using hlo q i + have hhiCached : ∀ q i, x q i ≤ (hiCached q i : Real) := by + intro q i + simpa [hiCached, cacheBound2_apply] using hhi q i + let base := intervalBoundsOn positions hpos loCached hiCached + have hbase := intervalBoundsOn_spec positions hpos loCached hiCached x + (fun q _ i => hloCached q i) (fun q _ i => hhiCached q i) + have hloBase : ∀ q i, (base.1 i : Real) ≤ x q i := fun q i => + (hbase q (by simp [positions]) i).1 + have hhiBase : ∀ q i, x q i ≤ (base.2 i : Real) := fun q i => + (hbase q (by simp [positions]) i).2 + let baseLo := cacheBound base.1 + let baseHi := cacheBound base.2 + have hloBaseCached : ∀ q i, (baseLo i : Real) ≤ x q i := by + intro q i + simpa [baseLo, cacheBound_apply] using hloBase q i + have hhiBaseCached : ∀ q i, x q i ≤ (baseHi i : Real) := by + intro q i + simpa [baseHi, cacheBound_apply] using hhiBase q i + let attn := attentionOutputBounds eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias 
+ baseLo baseHi + have hattn := attentionOutputBounds_spec eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores baseLo baseHi x hne heps hsqrt hloBaseCached hhiBaseCached q + let attnLo := cacheBound attn.1 + let attnHi := cacheBound attn.2 + let y := fun j => + x q j + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores x q j + have yLo : ∀ j, (loCached q j : Real) + (attn.1 j : Real) ≤ y j := by + intro j + have hlow : + (loCached q j : Real) + (attn.1 j : Real) ≤ + x q j + + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores x q j := by + exact add_le_add (hloCached q j) (hattn j).1 + simpa [y] using hlow + have yHi : ∀ j, y j ≤ (hiCached q j : Real) + (attn.2 j : Real) := by + intro j + have hhigh : + x q j + + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads + layer.attnBias scores x q j ≤ + (hiCached q j : Real) + (attn.2 j : Real) := by + exact add_le_add (hhiCached q j) (hattn j).2 + simpa [y] using hhigh + let yLoCached := cacheBound2 (fun q i => loCached q i + attnLo i) + let yHiCached := cacheBound2 (fun q i => hiCached q i + attnHi i) + have yLoCached_bound : ∀ j, (yLoCached q j : Real) ≤ y j := by + intro j + simpa [yLoCached, attnLo, cacheBound_apply, cacheBound2_apply] using (yLo j) + have yHiCached_bound : ∀ j, y j ≤ (yHiCached q j : Real) := by + intro j + simpa [yHiCached, attnHi, cacheBound_apply, cacheBound2_apply] using (yHi j) + have hmlp := + layerNormAbsMlpResidualBounds_spec eps layer.ln2Gamma layer.ln2Beta + layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut + (yLoCached q) (yHiCached q) y hne heps hsqrt yLoCached_bound yHiCached_bound + have hmlp_i := hmlp i + simpa [bounds, transformerLayerBoundsPos, positions, base, loCached, hiCached, baseLo, baseHi, + attn, attnLo, attnHi, y, yLoCached, yHiCached, cacheBound2_apply, cacheBoundPair2_apply_left, + cacheBoundPair2_apply_right, transformerLayerReal, cacheBound_apply] using hmlp_i + +/-- Real-valued 
transformer stack output (folded left over layers). -/ +noncomputable def transformerStackReal + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) : Fin seq → Fin dModel → Real := + let step := fun x layerIdx => + transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x + Linear.foldlFin numLayers step x + +/-- Interval bounds for a transformer stack (folded left over layers). -/ +def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let step := fun bounds layerIdx => + transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta + (layers layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) + (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn + (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2 + Linear.foldlFin numLayers step (lo, hi) + +/-- Interval bounds for a transformer stack from per-position bounds. 
-/ +def transformerStackBoundsPos {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin seq → Fin dModel → Rat) : + (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := + let step := fun bounds layerIdx => + transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2 + Linear.foldlFin numLayers step (lo, hi) + +private theorem transformerStackBoundsPos_spec_list + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + ∀ (ls : List (Fin numLayers)) (lo hi : Fin seq → Fin dModel → Rat) + (x : Fin seq → Fin dModel → Real), + (∀ q i, (lo q i : Real) ≤ x q i) → + (∀ q i, x q i ≤ (hi q i : Real)) → + let bounds := (ls.foldl + (fun bounds layerIdx => + transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2) + (lo, hi)) + let x' := (ls.foldl + (fun x layerIdx => + transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x) + x) + ∀ q i, + (bounds.1 q i : Real) ≤ x' q i ∧ + x' q i ≤ (bounds.2 q i : Real) := by + intro ls lo hi x hlo hhi + induction ls generalizing lo hi x hlo hhi with + | nil => + simpa using fun q i => And.intro (hlo q i) (hhi q i) + | cons l ls ih => + have hstep := + transformerLayerBoundsPos_spec eps (layers l) (heads l) (scores l) lo hi x + hne heps hsqrt hlo hhi + let bounds1 := transformerLayerBoundsPos eps (layers l) (heads l) lo hi + let x1 := transformerLayerReal eps (layers l) (heads l) (scores l) x + have hlo1 : ∀ q i, (bounds1.1 q i : Real) ≤ x1 q i := fun q i => (hstep q i).1 + have hhi1 
: ∀ q i, x1 q i ≤ (bounds1.2 q i : Real) := fun q i => (hstep q i).2 + have ih' := ih bounds1.1 bounds1.2 x1 hlo1 hhi1 + simpa [bounds1, x1] using ih' + +/-- `transformerStackBoundsPos` soundness for real transformer-stack outputs. -/ +theorem transformerStackBoundsPos_spec {seq dModel dHead numHeads hidden numLayers : Nat} + [NeZero seq] (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo q i : Real) ≤ x q i) + (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : + let bounds := transformerStackBoundsPos eps layers heads lo hi + ∀ q i, + (bounds.1 q i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ + transformerStackReal eps layers heads scores x q i ≤ (bounds.2 q i : Real) := by + classical + simpa [transformerStackBoundsPos, transformerStackReal, Linear.foldlFin_eq_foldl, + Fin.foldl_eq_foldl_finRange] using + transformerStackBoundsPos_spec_list eps layers heads scores hne heps hsqrt + (List.finRange numLayers) lo hi x hlo hhi + +private theorem transformerStackBounds_spec_list + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + ∀ (ls : List (Fin numLayers)) (lo hi : Fin dModel → Rat) + (x : Fin seq → Fin dModel → Real), + (∀ q i, (lo i : Real) ≤ x q i) → + (∀ q i, x q i ≤ (hi i : Real)) → + let bounds := (ls.foldl + (fun bounds layerIdx => + transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta + (layers 
layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) + (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn + (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2) + (lo, hi)) + let x' := (ls.foldl + (fun x layerIdx => + transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x) + x) + ∀ q i, + (bounds.1 i : Real) ≤ x' q i ∧ + x' q i ≤ (bounds.2 i : Real) := by + intro ls lo hi x hlo hhi + induction ls generalizing lo hi x hlo hhi with + | nil => + simpa using fun q i => And.intro (hlo q i) (hhi q i) + | cons l ls ih => + have hstep := + transformerLayerBounds_spec_real eps (layers l) (heads l) (scores l) lo hi x + hne heps hsqrt hlo hhi + let bounds1 := + transformerLayerBounds eps (layers l).ln1Gamma (layers l).ln1Beta (layers l).ln2Gamma + (layers l).ln2Beta (heads l) (layers l).attnBias (layers l).mlpWIn (layers l).mlpBIn + (layers l).mlpWOut (layers l).mlpBOut lo hi + let x1 := transformerLayerReal eps (layers l) (heads l) (scores l) x + have hlo1 : ∀ q i, (bounds1.1 i : Real) ≤ x1 q i := fun q i => (hstep q i).1 + have hhi1 : ∀ q i, x1 q i ≤ (bounds1.2 i : Real) := fun q i => (hstep q i).2 + have ih' := ih bounds1.1 bounds1.2 x1 hlo1 hhi1 + simpa [bounds1, x1] using ih' + +/-- `transformerStackBounds` soundness for real transformer-stack outputs. 
-/ +theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := transformerStackBounds eps layers heads lo hi + ∀ q i, + (bounds.1 i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ + transformerStackReal eps layers heads scores x q i ≤ (bounds.2 i : Real) := by + classical + simpa [transformerStackBounds, transformerStackReal, Linear.foldlFin_eq_foldl, + Fin.foldl_eq_foldl_finRange] using + transformerStackBounds_spec_list eps layers heads scores hne heps hsqrt + (List.finRange numLayers) lo hi x hlo hhi + +/-- Real-valued transformer stack output after the final LayerNorm. -/ +noncomputable def transformerStackFinalReal {seq dModel dHead numHeads hidden numLayers : Nat} + [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := + layerNormRealOfReal eps finalLn.gamma finalLn.beta + (fun j => transformerStackReal eps layers heads scores x q j) i + +/-- Interval bounds for transformer stack outputs after the final LayerNorm. 
-/ +def transformerStackFinalBounds {dModel dHead numHeads hidden numLayers : Nat} + (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let stack := transformerStackBounds eps layers heads lo hi + layerNormIntervalBounds eps finalLn.gamma finalLn.beta stack.1 stack.2 + +/-- `transformerStackFinalBounds` soundness for real outputs. -/ +theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} + [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := transformerStackFinalBounds eps finalLn layers heads lo hi + ∀ q i, + (bounds.1 i : Real) ≤ transformerStackFinalReal eps finalLn layers heads scores x q i ∧ + transformerStackFinalReal eps finalLn layers heads scores x q i ≤ (bounds.2 i : Real) := by + classical + intro bounds q i + let stack := transformerStackBounds eps layers heads lo hi + have hstack := + transformerStackBounds_spec eps layers heads scores lo hi x hne heps hsqrt hlo hhi q + have hlo' : ∀ k, (stack.1 k : Real) ≤ transformerStackReal eps layers heads scores x q k := + fun k => (hstack k).1 + have hhi' : ∀ k, transformerStackReal eps layers heads scores x q k ≤ (stack.2 k : Real) := + fun k => (hstack k).2 + have hln := + layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta stack.1 stack.2 + (fun j => transformerStackReal eps layers heads scores x q 
j) hne heps hsqrt hlo' hhi' + simpa [bounds, transformerStackFinalBounds, stack, transformerStackFinalReal] using hln i + +/-- Interval bounds for transformer stack outputs after the final LayerNorm (per-position). -/ +def transformerStackFinalBoundsPos + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (lo hi : Fin seq → Fin dModel → Rat) : + (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := + let stack := transformerStackBoundsPos eps layers heads lo hi + let ln := fun q => + layerNormIntervalBounds eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) + (fun q i => (ln q).1 i, fun q i => (ln q).2 i) + +/-- `transformerStackFinalBoundsPos` soundness for real outputs. -/ +theorem transformerStackFinalBoundsPos_spec + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo q i : Real) ≤ x q i) + (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : + let bounds := transformerStackFinalBoundsPos eps finalLn layers heads lo hi + ∀ q i, + (bounds.1 q i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores x q i ∧ + transformerStackFinalReal eps finalLn layers heads scores x q i ≤ + (bounds.2 q i : Real) := by + classical + intro bounds q i + let stack := transformerStackBoundsPos eps layers heads lo hi + have hstack := + transformerStackBoundsPos_spec eps layers heads scores lo hi x hne heps hsqrt hlo hhi q + 
have hlo' : ∀ j, (stack.1 q j : Real) ≤ transformerStackReal eps layers heads scores x q j := + fun j => (hstack j).1 + have hhi' : ∀ j, transformerStackReal eps layers heads scores x q j ≤ (stack.2 q j : Real) := + fun j => (hstack j).2 + have hln := + layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) + (fun j => transformerStackReal eps layers heads scores x q j) hne heps hsqrt hlo' hhi' + simpa [bounds, transformerStackFinalBoundsPos, stack, transformerStackFinalReal] using hln i + +/-- Residual interval bounds for a GPT-2 stack from exact embeddings. -/ +def gpt2ResidualIntervalBounds + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let base := embeddingIntervalBounds embed + transformerStackFinalBounds eps finalLn layers heads base.1 base.2 + +/-- `gpt2ResidualIntervalBounds` soundness for real GPT-2 outputs. 
-/ +theorem gpt2ResidualIntervalBounds_spec + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (embed : Fin seq → Fin dModel → Rat) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := gpt2ResidualIntervalBounds eps layers heads finalLn embed + ∀ q i, + (bounds.1 i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ∧ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by + classical + intro bounds q i + let base := embeddingIntervalBounds embed + have hbase := embeddingIntervalBounds_spec embed + have hlo : ∀ q i, (base.1 i : Real) ≤ (embed q i : Real) := fun q i => (hbase q i).1 + have hhi : ∀ q i, (embed q i : Real) ≤ (base.2 i : Real) := fun q i => (hbase q i).2 + have hstack := + transformerStackFinalBounds_spec eps finalLn layers heads scores base.1 base.2 + (fun q i => (embed q i : Real)) hne heps hsqrt hlo hhi q i + simpa [bounds, gpt2ResidualIntervalBounds, base] using hstack + +/-- Residual interval bounds over an active set from exact embeddings. 
-/ +def gpt2ResidualIntervalBoundsActive + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let baseLo : Fin seq → Fin dModel → Rat := embed + let baseHi : Fin seq → Fin dModel → Rat := embed + let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi + intervalBoundsOn active hactive final.1 final.2 + +/-- `gpt2ResidualIntervalBoundsActive` soundness for real GPT-2 outputs. -/ +theorem gpt2ResidualIntervalBoundsActive_spec + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (embed : Fin seq → Fin dModel → Rat) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed + ∀ q, q ∈ active → ∀ i, + (bounds.1 i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ∧ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by + classical + intro bounds q hq i + let baseLo : Fin seq → Fin dModel → Rat := embed + let baseHi : Fin seq → Fin dModel → Rat := embed + let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi + have hfinal := + transformerStackFinalBoundsPos_spec eps finalLn layers heads scores baseLo 
baseHi + (fun q i => (embed q i : Real)) hne heps hsqrt + (fun q i => by simp [baseLo]) + (fun q i => by simp [baseHi]) + have hlo : ∀ q, q ∈ active → ∀ i, + (final.1 q i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i := by + intro q hq i + simpa [final] using (hfinal q i).1 + have hhi : ∀ q, q ∈ active → ∀ i, + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (final.2 q i : Real) := by + intro q hq i + simpa [final] using (hfinal q i).2 + have hbounds := intervalBoundsOn_spec active hactive final.1 final.2 + (fun q i => transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i) + hlo hhi + simpa [bounds, gpt2ResidualIntervalBoundsActive, final, baseLo, baseHi] using + hbounds q hq i + +/-- Package GPT-2 residual bounds into a residual-interval certificate. -/ +theorem gpt2ResidualIntervalBoundsActive_sound + {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] + (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) + (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) + (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (finalLn : Model.Gpt2FinalLayerNorm dModel) + (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) + (embed : Fin seq → Fin dModel → Rat) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed + let cert : Circuit.ResidualIntervalCert dModel := { lo := bounds.1, hi := bounds.2 } + Circuit.ResidualIntervalBounds cert ∧ + ∀ q, q ∈ active → ∀ i, + (cert.lo i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ∧ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (cert.hi i : Real) := by + classical + intro bounds cert + 
have hspec : + ∀ q, q ∈ active → ∀ i, + (bounds.1 i : Real) ≤ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ∧ + transformerStackFinalReal eps finalLn layers heads scores + (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by + simpa [bounds] using + (gpt2ResidualIntervalBoundsActive_spec (active := active) (hactive := hactive) + (eps := eps) (layers := layers) (heads := heads) (finalLn := finalLn) + (scores := scores) (embed := embed) (hne := hne) (heps := heps) (hsqrt := hsqrt)) + have hbounds : Circuit.ResidualIntervalBounds cert := by + refine { lo_le_hi := ?_ } + intro i + rcases hactive with ⟨q0, hq0⟩ + have hq := hspec q0 hq0 i + have hreal : (bounds.1 i : Real) ≤ (bounds.2 i : Real) := hq.1.trans hq.2 + exact (ratToReal_le_iff (x := bounds.1 i) (y := bounds.2 i)).1 hreal + refine And.intro hbounds ?_ + intro q hq i + have hq' := hspec q hq i + simpa [cert] using hq' + +end Bounds + +end Sound + +end Nfp From e21800cbdd1dc2cf5efadb29069ec441ab4c69d5 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 18:44:29 +0100 Subject: [PATCH 165/244] Add Basic submodules for induction IO and bounds --- Nfp/IO/InductionHead.lean | 1449 +------------------- Nfp/IO/InductionHead/Basic.lean | 1454 +++++++++++++++++++++ Nfp/Sound/Induction/CoreSound.lean | 1310 +------------------ Nfp/Sound/Induction/CoreSound/Basic.lean | 1307 ++++++++++++++++++ Nfp/Sound/Induction/HeadBounds.lean | 1233 +---------------- Nfp/Sound/Induction/HeadBounds/Basic.lean | 1238 ++++++++++++++++++ 6 files changed, 4006 insertions(+), 3985 deletions(-) create mode 100644 Nfp/IO/InductionHead/Basic.lean create mode 100644 Nfp/Sound/Induction/CoreSound/Basic.lean create mode 100644 Nfp/Sound/Induction/HeadBounds/Basic.lean diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index 8922aa7..42663f3 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -1,1454 +1,7 @@ -- SPDX-License-Identifier: 
AGPL-3.0-or-later -import Mathlib.Data.List.Range -import Nfp.IO.Pure -import Nfp.IO.NfptPure -import Nfp.IO.HeadScore -import Nfp.IO.Timing -import Nfp.IO.Util -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Sound.Induction -import Nfp.Sound.Induction.HeadBounds -import Nfp.Sound.Induction.LogitDiff -import Nfp.Sound.Linear.FinFold +import Nfp.IO.InductionHead.Basic /-! IO helpers for induction-head certificate construction. -/ - -namespace Nfp - -namespace IO - -private def unwrapTaskResult {α : Type} (res : Except IO.Error α) : IO α := - match res with - | .ok a => pure a - | .error e => throw e - -private def configureTiming (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO Unit := do - match timing? with - | some v => setTimingStdout (v ≠ 0) - | none => pure () - match heartbeatMs? with - | some v => - setTimingHeartbeatMs (UInt32.ofNat v) - if timing?.isNone && (v != 0) then - setTimingStdout true - | none => pure () - -private def splitConfigFromOptions - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - Sound.InductionHeadSplitConfig := - let base := Sound.defaultInductionHeadSplitConfig - { base with - splitBudgetQ := splitBudgetQ?.getD base.splitBudgetQ - splitBudgetK := splitBudgetK?.getD base.splitBudgetK - splitBudgetDiffBase := splitBudgetDiffBase?.getD base.splitBudgetDiffBase - splitBudgetDiffRefined := splitBudgetDiffRefined?.getD base.splitBudgetDiffRefined } - -open Nfp.Circuit - -private def valueBoundsModeFromEnv : IO (Option Bool) := do - match (← IO.getEnv "NFP_VALUE_BOUNDS_MODE") with - | some "common" => return some true - | some "cached" => return some false - | _ => return none - -/-- Resolve the heartbeat interval (ms) for long-running induction cert builds. 
-/ -private def heartbeatMs : IO UInt32 := - timingHeartbeatMs - -private def timePureWithHeartbeat {α : Type} (label : String) (f : Unit → α) : IO α := do - let t0 ← monoUsNow - timingPrint s!"timing: {label} start" - timingFlush - let task : Task α := Task.spawn (fun _ => f ()) - let heartbeatMs ← heartbeatMs - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished task) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished task) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: {label} running {now - t0} us" - timingFlush - let res ← IO.wait task - let t1 ← monoUsNow - timingPrint s!"timing: {label} {t1 - t0} us" - return res - -private def forceRat (x : Rat) : IO Unit := do - if x = x then - pure () - else - pure () - -/-- Profile the core induction-head bounds used by the sound certificate builder. -/ -private def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do - timingPrint "timing: core stages start" - timingFlush - let lnBounds ← timePureWithHeartbeat "core: ln bounds" (fun () => - Sound.headLnBounds inputs) - let lnLo := lnBounds.1 - let lnHi := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Sound.Bounds.cacheBoundTask (fun q => - Sound.Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr ← timePureWithHeartbeat "core: lnAbsMax force" (fun () => - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q)) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr.getD q.1 (0 : Rat) - let lnAbsMaxMax : Rat := - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by - simp [univ] - univ.sup' hnonempty (fun q => lnAbsMax q) - let qAbsRowTasks : Array (Task (Array Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - Array.ofFn (fun d : Fin dHead => - Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - |inputs.bq d|))) - let qAbsBaseArr ← 
timePureWithHeartbeat "core: qAbs base force" (fun () => - let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) - Array.ofFn (fun q : Fin seq => - (qAbsRowTasks.getD q.1 defaultTask).get)) - let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := qAbsBaseArr.getD q.1 #[] - row.getD d.1 (0 : Rat) - let kAbsRowTasks : Array (Task (Array Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - Array.ofFn (fun d : Fin dHead => - Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - |inputs.bk d|))) - let kAbsBaseArr ← timePureWithHeartbeat "core: kAbs base force" (fun () => - let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) - Array.ofFn (fun q : Fin seq => - (kAbsRowTasks.getD q.1 defaultTask).get)) - let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := kAbsBaseArr.getD q.1 #[] - row.getD d.1 (0 : Rat) - let dotAbs ← timePureWithHeartbeat "core: dotAbs tasks" (fun () => - dotAbsFromQKV qAbsBase kAbsBase) - let _ ← timePureWithHeartbeat "core: dotAbs force" (fun () => - match List.finRange seq with - | [] => (0 : Rat) - | q :: _ => - match List.finRange seq with - | [] => (0 : Rat) - | k :: _ => dotAbs q k) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else scoreBaseAbs q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreLoPrev q - scoreHi q k) - else - 
(0 : Rat) - else - (0 : Rat) - let margin ← timePureWithHeartbeat "core: margin" (fun () => - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat)) - let marginNeg ← timePureWithHeartbeat "core: margin < 0" (fun () => - decide (margin < 0)) - let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" - if verboseTiming.isSome then - timingPrint s!"timing: core: margin neg={marginNeg}" - let tEps0 ← monoUsNow - timingPrint "timing: core: eps start" - timingFlush - let eps := - if marginNeg then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - let tEps1 ← monoUsNow - timingPrint s!"timing: core: eps {tEps1 - tEps0} us" - timingFlush - let _ := marginAt - let dirHeadVec ← timePureWithHeartbeat "core: dir head vec" (fun () => - Sound.dirHeadVecOfInputs inputs) - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := - Sound.Bounds.cacheBoundTask (fun j => - Sound.Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let _ ← timePureWithHeartbeat "core: wvDir force" (fun () => - Array.ofFn (fun j : Fin dModel => wvDir j)) - let bDir ← timePureWithHeartbeat "core: bDir" (fun () => - Sound.Linear.dotFin dHead dirHead (fun d => inputs.bv d)) - let valsAbsBase ← timePureWithHeartbeat "core: valsAbsBase" (fun () => - Sound.Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax) - let valsLoBase := bDir - valsAbsBase - let valsHiBase := bDir + valsAbsBase - let valsLo : Fin seq → Rat := fun _ => valsLoBase - let valsHi : Fin seq → Rat := fun _ => valsHiBase - let _ ← timePureWithHeartbeat "core: value bounds" (fun () => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by - simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - (lo, hi)) - timingPrint "timing: core stages done" - timingFlush - -/-- Load induction head inputs from disk. 
-/ -def loadInductionHeadInputs (path : System.FilePath) : - IO (Except String (Sigma (fun seq => - Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead))))) := do - let t0 ← monoUsNow - let data ← IO.FS.readFile path - let t1 ← monoUsNow - timingPrint s!"timing: read head input file {t1 - t0} us" - let t2 ← monoUsNow - let parsed := - match Pure.parseInductionHeadInputs data with - | Except.error msg => Except.error msg - | Except.ok v => Except.ok v - let t3 ← monoUsNow - timingPrint s!"timing: parse head input file {t3 - t2} us" - return parsed - -private def ratToString (x : Rat) : String := - toString x - -private def renderResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) : String := - let header := s!"dim {n}" - let lines := - (List.finRange n).foldr (fun i acc => - s!"lo {i.val} {ratToString (c.lo i)}" :: - s!"hi {i.val} {ratToString (c.hi i)}" :: acc) [] - String.intercalate "\n" (header :: lines) - -private def emitResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) - (outPath? : Option System.FilePath) : IO Unit := do - let payload := renderResidualIntervalCert c - match outPath? with - | some path => IO.FS.writeFile path (payload ++ "\n") - | none => IO.println payload - -private def buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (outPath? : Option System.FilePath) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildHeadOutputIntervalFromHead? inputs with - | none => - IO.eprintln "error: head output interval rejected" - return 2 - | some result => - emitResidualIntervalCert result.cert outPath? 
- if outPath?.isSome then - let activeCount := result.active.card - IO.println - s!"ok: head output interval built (seq={seq}, dim={dModel}, active={activeCount})" - return 0 - -private def headScoreBoundsFromDotAbsTimed {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) : - IO (Sound.HeadScoreBounds seq dModel dHead) := do - timePure "head: score bounds" (fun () => - Sound.headScoreBoundsFromDotAbs inputs dotAbs) - -private def headScoreBoundsFromQAbsKAbsTimed {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Rat) - (dotAbs : Fin seq → Fin seq → Rat) : - IO (Sound.HeadScoreBounds seq dModel dHead) := do - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else scoreBaseAbs q k - let kAbsMax : Fin dHead → Rat := fun d => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d) - let dotAbsUpper : Fin seq → Rat := fun q => - Sound.Linear.dotFin dHead (fun d => qAbs q d) kAbsMax - let scoreHiUpper : Fin seq → Rat := fun q => - max inputs.maskValue (|inputs.scale| * dotAbsUpper q) - let marginTasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if q ∈ inputs.active then - let prev := inputs.prev q - let scoreLoPrev := scoreLo q prev - scoreLoPrev - scoreHiUpper q - else - (0 : Rat))) - let marginAt : Fin seq → Rat := fun q => - (marginTasks[q.1]'(by - simp [marginTasks, q.isLt])).get - let epsTasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - 
(marginTasks[q.1]'(by - simp [marginTasks, q.isLt])).map (fun m => - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m))) - let epsAt : Fin seq → Rat := fun q => - (epsTasks[q.1]'(by - simp [epsTasks, q.isLt])).get - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - let result : Sound.HeadScoreBounds seq dModel dHead := - { dotAbs := dotAbs - scoreBaseAbs := scoreBaseAbs - scoreAbs := fun q k => if masked q k then |inputs.maskValue| else scoreBaseAbs q k - scoreLo := scoreLo - scoreHi := scoreHi - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } - return result - -private def checkInductionHeadInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cfg : Sound.InductionHeadSplitConfig) - (minActive? : Option Nat) (minLogitDiff? : Option Rat) - (minMargin maxEps : Rat) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - logTiming "start: head build induction cert" - timingPrint "timing: head build induction cert start" - timingFlush - let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" - if verboseTiming.isSome then - timingPrint s!"timing: head dims seq={seq} dModel={dModel} dHead={dHead}" - timingPrint s!"timing: head active card={inputs.active.card}" - timingFlush - let precompute := (← IO.getEnv "NFP_TIMING_PRECOMPUTE").isSome - if precompute then - timingPrint "timing: head ln bounds start" - timingFlush - let lnBounds ← timePure "head: ln bounds" (fun () => - Sound.headLnBounds inputs) - timingPrint "timing: head ln bounds done" - timingFlush - timingPrint "timing: head qkv bounds start" - timingFlush - let lnLo := lnBounds.1 - let lnHi := lnBounds.2 - let qkv ← timePure "head: qkv bounds" (fun () => - 
Sound.headQKVBounds inputs lnLo lnHi) - timingPrint "timing: head qkv bounds done" - timingFlush - if verboseTiming.isSome then - timingPrint "timing: head qkv abs force start" - timingFlush - let tAbs0 ← monoUsNow - for q in List.finRange seq do - for d in List.finRange dHead do - let _ := qkv.qAbs q d - let _ := qkv.kAbs q d - pure () - let tAbs1 ← monoUsNow - timingPrint s!"timing: head qkv abs force {tAbs1 - tAbs0} us" - timingFlush - timingPrint "timing: head score/value bounds spawn start" - timingFlush - let tSpawn0 ← monoUsNow - if verboseTiming.isSome then - timingPrint "timing: head score dotAbs tasks start" - timingFlush - let dotAbs ← timePure "head: score dotAbs tasks" (fun () => - dotAbsFromQKV qkv.qAbs qkv.kAbs) - if verboseTiming.isSome then - timingPrint "timing: head score dotAbs tasks done" - timingFlush - if verboseTiming.isSome then - timingPrint "timing: head score dotAbs force start" - timingFlush - let tForce0 ← monoUsNow - match List.finRange seq with - | [] => - timingPrint "timing: head score dotAbs force skipped (empty seq)" - | q :: _ => - match List.finRange seq with - | [] => - timingPrint "timing: head score dotAbs force skipped (empty seq)" - | k :: _ => - let _ := dotAbs q k - pure () - let tForce1 ← monoUsNow - timingPrint s!"timing: head score dotAbs force {tForce1 - tForce0} us" - timingFlush - let inlineVals := (← IO.getEnv "NFP_TIMING_VALUE_INLINE").isSome - let valueMode? ← valueBoundsModeFromEnv - let useCommon := valueMode?.getD false - let (valsInline?, valsTask?) 
:= - if inlineVals then - let vals := - if useCommon then - Sound.headValueBoundsCommonDen inputs qkv.vLo qkv.vHi - else - Sound.headValueBounds inputs qkv.vLo qkv.vHi - (some vals, none) - else - let task := - if useCommon then - Sound.headValueBoundsCommonDenTask inputs qkv.vLo qkv.vHi - else - Sound.headValueBoundsTask inputs qkv.vLo qkv.vHi - (none, some task) - let activeList := (List.finRange seq).filter (fun q => q ∈ inputs.active) - if verboseTiming.isSome then - timeHeadScoreMarginRaw inputs dotAbs activeList - let tSpawn1 ← monoUsNow - timingPrint s!"timing: head score/value bounds spawn {tSpawn1 - tSpawn0} us" - timingFlush - let skipScoreBounds := (← IO.getEnv "NFP_TIMING_SKIP_SCORE_BOUNDS").isSome - let scoreTaskOpt ← - if skipScoreBounds then - timingPrint "timing: head score bounds skipped" - pure none - else - timingPrint "timing: head score bounds from dotAbs start" - timingFlush - let exactMargin := (← IO.getEnv "NFP_TIMING_EXACT_MARGIN").isSome - let action := - if exactMargin then - headScoreBoundsFromDotAbsTimed inputs dotAbs - else - headScoreBoundsFromQAbsKAbsTimed inputs qkv.qAbs qkv.kAbs dotAbs - let t ← action.asTask - pure (some t) - if verboseTiming.isSome then - timingPrint "timing: head value parts start" - timingFlush - timingPrint "timing: head value dirHead start" - timingFlush - let tDir0 ← monoUsNow - let dirHead := Sound.headValueDirHead inputs - match List.finRange dHead with - | [] => - timingPrint "timing: head value dirHead forced skipped (empty dHead)" - | d :: _ => - let _ := dirHead d - pure () - let tDir1 ← monoUsNow - timingPrint s!"timing: head value dirHead {tDir1 - tDir0} us" - timingFlush - timingPrint "timing: head value valsLo start" - timingFlush - let tLo0 ← monoUsNow - let valsLo := Sound.headValueValsLo inputs qkv.vLo qkv.vHi - match List.finRange seq with - | [] => - timingPrint "timing: head value valsLo forced skipped (empty seq)" - | k :: _ => - let _ := valsLo k - pure () - let tLo1 ← monoUsNow - 
timingPrint s!"timing: head value valsLo {tLo1 - tLo0} us" - timingFlush - timingPrint "timing: head value valsHi start" - timingFlush - let tHi0 ← monoUsNow - let valsHi := Sound.headValueValsHi inputs qkv.vLo qkv.vHi - match List.finRange seq with - | [] => - timingPrint "timing: head value valsHi forced skipped (empty seq)" - | k :: _ => - let _ := valsHi k - pure () - let tHi1 ← monoUsNow - timingPrint s!"timing: head value valsHi {tHi1 - tHi0} us" - timingFlush - timingPrint "timing: head value lo start" - timingFlush - let tLo2 ← monoUsNow - let _ := Sound.headValueLo valsLo - let tLo3 ← monoUsNow - timingPrint s!"timing: head value lo {tLo3 - tLo2} us" - timingFlush - timingPrint "timing: head value hi start" - timingFlush - let tHi2 ← monoUsNow - let _ := Sound.headValueHi valsHi - let tHi3 ← monoUsNow - timingPrint s!"timing: head value hi {tHi3 - tHi2} us" - timingFlush - timingPrint "timing: head value parts done" - timingFlush - timingPrint "timing: head value bounds start" - timingFlush - let tVals0 ← monoUsNow - let vals ← - match valsInline?, valsTask? 
with - | some vals, _ => - timePure "head: value bounds inline" (fun () => vals) - | none, some valsTask => - timePure "head: value bounds wait" (fun () => valsTask.get) - | none, none => - timePure "head: value bounds inline" (fun () => - Sound.headValueBounds inputs qkv.vLo qkv.vHi) - let tVals1 ← monoUsNow - timingPrint s!"timing: head value bounds {tVals1 - tVals0} us" - timingFlush - let scoreOpt ← - match scoreTaskOpt with - | none => pure none - | some scoreTask => do - let res ← IO.wait scoreTask - let score ← unwrapTaskResult res - timingPrint "timing: head score bounds from dotAbs done" - timingFlush - pure (some score) - match scoreOpt with - | none => pure () - | some score => - if verboseTiming.isSome then - timeHeadScoreSampleGap inputs score - if verboseTiming.isSome then - timeHeadScoreMarginList activeList score - if verboseTiming.isSome then - timeHeadScoreFieldForces score - if verboseTiming.isSome then - timingPrint "timing: head score bounds force start" - timingFlush - let tScore0 ← monoUsNow - let _ := score.margin - let _ := score.eps - let tScore1 ← monoUsNow - timingPrint s!"timing: head score bounds force {tScore1 - tScore0} us" - timingFlush - let coreStages := (← IO.getEnv "NFP_TIMING_CORE_STAGES").isSome - let coreStagesOnly := (← IO.getEnv "NFP_TIMING_CORE_STAGES_ONLY").isSome - if coreStages then - timeInductionHeadCoreStages inputs - if coreStagesOnly then - return 0 - let breakdown := (← IO.getEnv "NFP_TIMING_BREAKDOWN").isSome - if breakdown then - let lnBounds ← timePureWithHeartbeat "breakdown: ln bounds" (fun () => - Sound.headLnBounds inputs) - timingPrint "timing: breakdown ln bounds force start" - timingFlush - let tLn0 ← monoUsNow - for q in List.finRange seq do - for i in List.finRange dModel do - let _ := lnBounds.1 q i - let _ := lnBounds.2 q i - pure () - let tLn1 ← monoUsNow - timingPrint s!"timing: breakdown ln bounds force {tLn1 - tLn0} us" - timingFlush - let qkv ← timePureWithHeartbeat "breakdown: qkv bounds" (fun 
() => - Sound.headQKVBounds inputs lnBounds.1 lnBounds.2) - timingPrint "timing: breakdown qkv bounds force start" - timingFlush - let tQkv0 ← monoUsNow - for q in List.finRange seq do - for d in List.finRange dHead do - let _ := qkv.qLo q d - let _ := qkv.qHi q d - let _ := qkv.kLo q d - let _ := qkv.kHi q d - let _ := qkv.vLo q d - let _ := qkv.vHi q d - let _ := qkv.qAbs q d - let _ := qkv.kAbs q d - pure () - let tQkv1 ← monoUsNow - timingPrint s!"timing: breakdown qkv bounds force {tQkv1 - tQkv0} us" - timingFlush - let dotAbs : Fin seq → Fin seq → Rat := fun q k => - Sound.Linear.dotFin dHead (fun d => qkv.qAbs q d) (fun d => qkv.kAbs k d) - let dotAbsRowTasks : - Array (Task { row : Array Rat // row.size = seq }) ← - timePureWithHeartbeat "breakdown: score dotAbs rows" (fun () => - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩))) - let dotAbsRowDefault : Task { row : Array Rat // row.size = seq } := - Task.spawn (fun _ => ⟨Array.ofFn (fun _ : Fin seq => (0 : Rat)), by simp⟩) - timingPrint "timing: breakdown score dotAbs force start" - timingFlush - let tDot0 ← monoUsNow - for q in List.finRange seq do - let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get - let _ := row - pure () - let tDot1 ← monoUsNow - timingPrint s!"timing: breakdown score dotAbs force {tDot1 - tDot0} us" - timingFlush - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scaleAbs : Rat := |inputs.scale| - let marginAtRaw : Fin seq → Rat := fun q => - let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get - if q ∈ inputs.active then - let rowArr := row.1 - let prev := inputs.prev q - let dotAbsPrev := rowArr.getD prev.1 0 - if masked q prev then - let scoreLoPrev := inputs.maskValue - let scoreHiAt : Fin seq → Rat := fun k => - if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × 
Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let v := scoreLoPrev - scoreHiAt k - match acc.1 with - | none => (some v, acc.2) - | some cur => (some (min cur v), acc.2) - let acc := Sound.Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min unmaskedMin maskedGap - | some unmaskedMin, false => unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - let scoreLoPrev := -(scaleAbs * dotAbsPrev) - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let raw := -(dotAbsPrev + rowArr.getD k.1 0) - match acc.1 with - | none => (some raw, acc.2) - | some cur => (some (min cur raw), acc.2) - let acc := Sound.Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap - | some unmaskedMin, false => scaleAbs * unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - (0 : Rat) - let marginAtCached : Fin seq → Rat ← - timePureWithHeartbeat "breakdown: score margin cache" (fun () => - Sound.Bounds.cacheBoundThunk marginAtRaw) - timingPrint "timing: breakdown score margin force start" - timingFlush - let tMargin0 ← monoUsNow - for q in List.finRange seq do - let m := marginAtCached q - forceRat m - pure () - let tMargin1 ← monoUsNow - timingPrint s!"timing: breakdown score margin force {tMargin1 - tMargin0} us" - timingFlush - let epsAtRaw : Fin seq → Rat := fun q => - let m := marginAtCached q - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m) - let epsAtCached : Fin seq → Rat ← - timePureWithHeartbeat "breakdown: score eps cache" (fun () => - Sound.Bounds.cacheBoundThunk epsAtRaw) - timingPrint "timing: breakdown score eps force start" - timingFlush - let tEps0 ← 
monoUsNow - for q in List.finRange seq do - let e := epsAtCached q - forceRat e - pure () - let tEps1 ← monoUsNow - timingPrint s!"timing: breakdown score eps force {tEps1 - tEps0} us" - timingFlush - let valsLo ← timePureWithHeartbeat "breakdown: value valsLo" (fun () => - Sound.headValueValsLo inputs qkv.vLo qkv.vHi) - timingPrint "timing: breakdown value valsLo force start" - timingFlush - let tValsLo0 ← monoUsNow - for k in List.finRange seq do - let v := valsLo k - forceRat v - pure () - let tValsLo1 ← monoUsNow - timingPrint s!"timing: breakdown value valsLo force {tValsLo1 - tValsLo0} us" - timingFlush - let valsHi ← timePureWithHeartbeat "breakdown: value valsHi" (fun () => - Sound.headValueValsHi inputs qkv.vLo qkv.vHi) - timingPrint "timing: breakdown value valsHi force start" - timingFlush - let tValsHi0 ← monoUsNow - for k in List.finRange seq do - let v := valsHi k - forceRat v - pure () - let tValsHi1 ← monoUsNow - timingPrint s!"timing: breakdown value valsHi force {tValsHi1 - tValsHi0} us" - timingFlush - let heartbeatMsProgress ← heartbeatMs - let taskMin (t1 t2 : Task Rat) : Task Rat := - Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) - let taskMax (t1 t2 : Task Rat) : Task Rat := - Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) - let reduceMinTasksWithProgress (tasks : Array (Task Rat)) : - IO Rat := do - let n := tasks.size - if n = 0 then - pure (0 : Rat) - else - let chunkSize : Nat := 16 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := tasks.getD start defaultTask - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => taskMin acc (tasks.getD i defaultTask)) init) - if heartbeatMsProgress ≠ 0 then - let mut 
finished := 0 - let mut remaining := chunkTasks.size - while finished < remaining do - IO.sleep heartbeatMsProgress - let mut count := 0 - for t in chunkTasks do - if (← IO.hasFinished t) then - count := count + 1 - finished := count - remaining := chunkTasks.size - if finished < remaining then - timingPrint s!"timing: breakdown value lo progress {finished}/{remaining}" - timingFlush - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - pure ((rest.foldl (fun acc i => taskMin acc (chunkTasks.getD i defaultTask)) init).get) - let reduceMaxTasksWithProgress (tasks : Array (Task Rat)) : - IO Rat := do - let n := tasks.size - if n = 0 then - pure (0 : Rat) - else - let chunkSize : Nat := 16 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := tasks.getD start defaultTask - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => taskMax acc (tasks.getD i defaultTask)) init) - if heartbeatMsProgress ≠ 0 then - let mut finished := 0 - let mut remaining := chunkTasks.size - while finished < remaining do - IO.sleep heartbeatMsProgress - let mut count := 0 - for t in chunkTasks do - if (← IO.hasFinished t) then - count := count + 1 - finished := count - remaining := chunkTasks.size - if finished < remaining then - timingPrint s!"timing: breakdown value hi progress {finished}/{remaining}" - timingFlush - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - pure ((rest.foldl (fun acc i => taskMax acc (chunkTasks.getD i defaultTask)) init).get) - if (← IO.getEnv "NFP_TIMING_TASK_PROGRESS").isSome then - let tasksLo := - (List.finRange seq).map (fun k => 
Task.spawn (fun _ => valsLo k)) - let tasksHi := - (List.finRange seq).map (fun k => Task.spawn (fun _ => valsHi k)) - let _ ← timePureWithHeartbeat "breakdown: value lo progress" (fun () => - reduceMinTasksWithProgress tasksLo.toArray) - let _ ← timePureWithHeartbeat "breakdown: value hi progress" (fun () => - reduceMaxTasksWithProgress tasksHi.toArray) - else - let loTask := Sound.headValueLoTask valsLo - let hiTask := Sound.headValueHiTask valsHi - let heartbeatMs ← heartbeatMs - let tLo0 ← monoUsNow - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished loTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished loTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: breakdown: value lo running {now - tLo0} us" - timingFlush - let lo := loTask.get - let tLo1 ← monoUsNow - timingPrint s!"timing: breakdown: value lo {tLo1 - tLo0} us" - timingFlush - let tHi0 ← monoUsNow - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished hiTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished hiTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: breakdown: value hi running {now - tHi0} us" - timingFlush - let hi := hiTask.get - let tHi1 ← monoUsNow - timingPrint s!"timing: breakdown: value hi {tHi1 - tHi0} us" - timingFlush - let _ := lo - let _ := hi - if (← IO.getEnv "NFP_TIMING_SEQ_REDUCE").isSome then - let loSeq ← timePureWithHeartbeat "breakdown: value lo seq" (fun () => - match List.finRange seq with - | [] => (0 : Rat) - | k :: ks => - let init := valsLo k - ks.foldl (fun acc k => min acc (valsLo k)) init) - let hiSeq ← timePureWithHeartbeat "breakdown: value hi seq" (fun () => - match List.finRange seq with - | [] => (0 : Rat) - | k :: ks => - let init := valsHi k - ks.foldl (fun acc k => max acc (valsHi k)) init) - let _ := loSeq - let _ := hiSeq - let tCert0 ← monoUsNow - let certTask : - Task - (Option { c : Sound.InductionHeadCert seq // - 
Sound.InductionHeadCertSound inputs c }) := - Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHeadWith? cfg inputs with - | none => none - | some ⟨cert, hcert⟩ => - let _ := cert.active.card - some ⟨cert, hcert⟩) - let heartbeatMs ← heartbeatMs - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished certTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished certTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: head build induction cert running {now - tCert0} us" - timingFlush - let certOpt ← IO.wait certTask - let tCert1 ← monoUsNow - logTiming s!"done: head build induction cert {tCert1 - tCert0} us" - timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" - timingPrint "timing: head build induction cert returned" - timingFlush - match certOpt with - | none => - IO.eprintln "error: head inputs rejected" - return 2 - | some ⟨cert, _hcert⟩ => - timingPrint "timing: head active count start" - timingFlush - let activeCount := cert.active.card - timingPrint "timing: head active count done" - timingFlush - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {ratToString cert.margin} \ - below minimum {ratToString minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {ratToString cert.eps} \ - above maximum {ratToString maxEps}" - return 2 - timingPrint "timing: head tol start" - timingFlush - let tol := cert.eps * (cert.values.hi - cert.values.lo) - timingPrint "timing: head tol done" - timingFlush - logTiming "start: head logit-diff lower bound" - timingPrint "timing: head logit-diff lower bound start" - timingFlush - let logitDiffLB? 
← timePure "head: logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - cert.values.lo cert.values.hi cert.values.valsLo) - logTiming "done: head logit-diff lower bound" - let effectiveMinLogitDiff := - match minLogitDiff? with - | some v => some v - | none => some (0 : Rat) - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - below minimum {ratToString minLogitDiff}" - return 2 - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB})" - return 0 - -private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cfg : Sound.InductionHeadSplitConfig) - (minActive? : Option Nat) (minLogitDiff? : Option Rat) - (minMargin? : Option Rat) (maxEps : Rat) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - logTiming "start: head build induction cert" - timingPrint "timing: head build induction cert start" - timingFlush - let tCert0 ← monoUsNow - let certTask : - Task - (Option { c : Sound.InductionHeadCert seq // - Sound.InductionHeadCertSound inputs c }) := - Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHeadWith? 
cfg inputs with - | none => none - | some ⟨cert, hcert⟩ => - let _ := cert.active.card - some ⟨cert, hcert⟩) - let heartbeatMs ← heartbeatMs - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished certTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished certTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: head build induction cert running {now - tCert0} us" - timingFlush - let certOpt ← IO.wait certTask - let tCert1 ← monoUsNow - logTiming s!"done: head build induction cert {tCert1 - tCert0} us" - timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" - timingPrint "timing: head build induction cert returned" - timingFlush - match certOpt with - | none => - IO.eprintln "error: head inputs rejected" - return 2 - | some ⟨cert, _hcert⟩ => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" - return 2 - let marginViolation? : Option Rat := - match minMargin? with - | none => none - | some minMargin => - if cert.margin < minMargin then - some minMargin - else - none - match marginViolation? with - | some minMargin => - IO.eprintln - s!"error: margin {ratToString cert.margin} \ - below minimum {ratToString minMargin}" - return 2 - | none => pure () - logTiming "start: head logit-diff lower bound" - timingPrint "timing: head logit-diff lower bound start" - timingFlush - let logitDiffLB0? ← timePure "head: logit-diff lower bound" (fun () => - Sound.logitDiffLowerBoundFromCert cert) - let needsWeighted : Bool := - match logitDiffLB0? with - | none => true - | some lb0 => - if lb0 ≤ 0 then - true - else - match minLogitDiff? 
with - | some minLogitDiff => lb0 < minLogitDiff - | none => false - let logitDiffWeighted? ← - if needsWeighted then - timePure "head: logit-diff lower bound weighted" (fun () => - Sound.logitDiffLowerBoundFromCertWeighted cert) - else - pure none - let logitDiffLB? : Option Rat := - match logitDiffLB0?, logitDiffWeighted? with - | some lb0, some lb1 => some (max lb0 lb1) - | some lb0, none => some lb0 - | none, some lb1 => some lb1 - | none, none => none - let boundLabel : String := - match logitDiffLB0?, logitDiffWeighted? with - | some _, some _ => "max" - | none, some _ => "weighted" - | some _, none => "eps" - | none, none => "none" - logTiming "done: head logit-diff lower bound" - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - if logitDiffLB ≤ 0 then - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - is not strictly positive" - return 2 - let violation? : Option Rat := - match minLogitDiff? with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - below minimum {ratToString minLogitDiff}" - return 2 - | none => pure () - let tol := cert.eps * (cert.values.hi - cert.values.lo) - IO.println - s!"ok: nonvacuous induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB}, \ - bound={boundLabel})" - return 0 - -/-- Build and check induction certificates from exact head inputs. -/ -def runInductionCertifyHead (inputsPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
: Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedInputs ← timePhase "load head inputs" <| - loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps - -/-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ -def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
- match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedInputs ← timePhase "load head inputs" <| - loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - -/-- Build and check induction certificates from a model binary. -/ -def runInductionCertifyHeadModel (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps - -/-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ -def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
- match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - -/-- Heuristic logit-diff direction derived from prompt tokens. -/ -private def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : - Except String (Nat × Nat) := do - let tokenArr : Array Nat := Array.ofFn (fun i : Fin seq => tokens i) - let n := tokenArr.size - if n < 2 then - throw "token sequence must have length at least 2" - let lastTok := tokenArr.getD (n - 1) 0 - let prevIdx? := - (List.range (n - 1)).reverse.find? (fun i => - tokenArr.getD i lastTok = lastTok) - let targetTok := - match prevIdx? 
with - | some i => tokenArr.getD (i + 1) lastTok - | none => lastTok - let neg0 := tokenArr.getD (n - 2) lastTok - let neg := - if neg0 = targetTok then - if lastTok ≠ targetTok then - lastTok - else if targetTok ≠ 0 then - 0 - else - 1 - else - neg0 - return (targetTok, neg) - -/-- Build and check induction certificates from a model binary, deriving direction tokens from the -prompt sequence. -/ -def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) - (layer head : Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let tokensE ← timePure "read prompt tokens" (fun () => - NfptPure.readTokens data start header) - match tokensE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok tokens => - match deriveDirectionFromTokens tokens with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨dirTarget, dirNegative⟩ => - IO.println - s!"info: direction-target={dirTarget} direction-negative={dirNegative}" - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? - minMargin maxEps - -/-- Build and check a strictly positive induction logit-diff bound from a model binary, deriving -direction tokens from the prompt sequence. -/ -def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) - (layer head : Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
- let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let tokensE ← timePure "read prompt tokens" (fun () => - NfptPure.readTokens data start header) - match tokensE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok tokens => - match deriveDirectionFromTokens tokens with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨dirTarget, dirNegative⟩ => - IO.println - s!"info: direction-target={dirTarget} direction-negative={dirNegative}" - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - -/-- Build head-output interval bounds from exact head inputs. -/ -def runInductionHeadInterval (inputsPath : System.FilePath) - (outPath? 
: Option System.FilePath) : IO UInt32 := do - let parsedInputs ← loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - buildHeadOutputIntervalFromInputs inputs outPath? - -/-- Build head-output interval bounds from a model binary. -/ -def runInductionHeadIntervalModel (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) - (outPath? : Option System.FilePath) : IO UInt32 := do - let data ← IO.FS.readBinFile modelPath - match NfptPure.parseHeader data with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - match - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? - with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - buildHeadOutputIntervalFromInputs inputs outPath? - -end IO - -end Nfp diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean new file mode 100644 index 0000000..8922aa7 --- /dev/null +++ b/Nfp/IO/InductionHead/Basic.lean @@ -0,0 +1,1454 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Data.List.Range +import Nfp.IO.Pure +import Nfp.IO.NfptPure +import Nfp.IO.HeadScore +import Nfp.IO.Timing +import Nfp.IO.Util +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadBounds +import Nfp.Sound.Induction.LogitDiff +import Nfp.Sound.Linear.FinFold + +/-! +IO helpers for induction-head certificate construction. +-/ + +namespace Nfp + +namespace IO + +private def unwrapTaskResult {α : Type} (res : Except IO.Error α) : IO α := + match res with + | .ok a => pure a + | .error e => throw e + +private def configureTiming (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO Unit := do + match timing? 
with + | some v => setTimingStdout (v ≠ 0) + | none => pure () + match heartbeatMs? with + | some v => + setTimingHeartbeatMs (UInt32.ofNat v) + if timing?.isNone && (v != 0) then + setTimingStdout true + | none => pure () + +private def splitConfigFromOptions + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + Sound.InductionHeadSplitConfig := + let base := Sound.defaultInductionHeadSplitConfig + { base with + splitBudgetQ := splitBudgetQ?.getD base.splitBudgetQ + splitBudgetK := splitBudgetK?.getD base.splitBudgetK + splitBudgetDiffBase := splitBudgetDiffBase?.getD base.splitBudgetDiffBase + splitBudgetDiffRefined := splitBudgetDiffRefined?.getD base.splitBudgetDiffRefined } + +open Nfp.Circuit + +private def valueBoundsModeFromEnv : IO (Option Bool) := do + match (← IO.getEnv "NFP_VALUE_BOUNDS_MODE") with + | some "common" => return some true + | some "cached" => return some false + | _ => return none + +/-- Resolve the heartbeat interval (ms) for long-running induction cert builds. -/ +private def heartbeatMs : IO UInt32 := + timingHeartbeatMs + +private def timePureWithHeartbeat {α : Type} (label : String) (f : Unit → α) : IO α := do + let t0 ← monoUsNow + timingPrint s!"timing: {label} start" + timingFlush + let task : Task α := Task.spawn (fun _ => f ()) + let heartbeatMs ← heartbeatMs + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished task) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished task) + if !finished then + let now ← monoUsNow + timingPrint s!"timing: {label} running {now - t0} us" + timingFlush + let res ← IO.wait task + let t1 ← monoUsNow + timingPrint s!"timing: {label} {t1 - t0} us" + return res + +private def forceRat (x : Rat) : IO Unit := do + if x = x then + pure () + else + pure () + +/-- Profile the core induction-head bounds used by the sound certificate builder. 
-/ +private def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do + timingPrint "timing: core stages start" + timingFlush + let lnBounds ← timePureWithHeartbeat "core: ln bounds" (fun () => + Sound.headLnBounds inputs) + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Sound.Bounds.cacheBoundTask (fun q => + Sound.Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr ← timePureWithHeartbeat "core: lnAbsMax force" (fun () => + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q)) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr.getD q.1 (0 : Rat) + let lnAbsMaxMax : Rat := + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by + simp [univ] + univ.sup' hnonempty (fun q => lnAbsMax q) + let qAbsRowTasks : Array (Task (Array Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + Array.ofFn (fun d : Fin dHead => + Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + |inputs.bq d|))) + let qAbsBaseArr ← timePureWithHeartbeat "core: qAbs base force" (fun () => + let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) + Array.ofFn (fun q : Fin seq => + (qAbsRowTasks.getD q.1 defaultTask).get)) + let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := qAbsBaseArr.getD q.1 #[] + row.getD d.1 (0 : Rat) + let kAbsRowTasks : Array (Task (Array Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + Array.ofFn (fun d : Fin dHead => + Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + |inputs.bk d|))) + let kAbsBaseArr ← timePureWithHeartbeat "core: kAbs base force" (fun () => + let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) + Array.ofFn (fun q : Fin seq => + (kAbsRowTasks.getD q.1 defaultTask).get)) + let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := kAbsBaseArr.getD q.1 #[] + 
row.getD d.1 (0 : Rat) + let dotAbs ← timePureWithHeartbeat "core: dotAbs tasks" (fun () => + dotAbsFromQKV qAbsBase kAbsBase) + let _ ← timePureWithHeartbeat "core: dotAbs force" (fun () => + match List.finRange seq with + | [] => (0 : Rat) + | q :: _ => + match List.finRange seq with + | [] => (0 : Rat) + | k :: _ => dotAbs q k) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then inputs.maskValue else -scoreBaseAbs q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then inputs.maskValue else scoreBaseAbs q k + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreLoPrev q - scoreHi q k) + else + (0 : Rat) + else + (0 : Rat) + let margin ← timePureWithHeartbeat "core: margin" (fun () => + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat)) + let marginNeg ← timePureWithHeartbeat "core: margin < 0" (fun () => + decide (margin < 0)) + let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" + if verboseTiming.isSome then + timingPrint s!"timing: core: margin neg={marginNeg}" + let tEps0 ← monoUsNow + timingPrint "timing: core: eps start" + timingFlush + let eps := + if marginNeg then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + let tEps1 ← monoUsNow + timingPrint s!"timing: core: eps {tEps1 - tEps0} us" + timingFlush + let _ := marginAt + let dirHeadVec ← timePureWithHeartbeat "core: dir head vec" (fun () => + Sound.dirHeadVecOfInputs inputs) + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin 
dModel → Rat := + Sound.Bounds.cacheBoundTask (fun j => + Sound.Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let _ ← timePureWithHeartbeat "core: wvDir force" (fun () => + Array.ofFn (fun j : Fin dModel => wvDir j)) + let bDir ← timePureWithHeartbeat "core: bDir" (fun () => + Sound.Linear.dotFin dHead dirHead (fun d => inputs.bv d)) + let valsAbsBase ← timePureWithHeartbeat "core: valsAbsBase" (fun () => + Sound.Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax) + let valsLoBase := bDir - valsAbsBase + let valsHiBase := bDir + valsAbsBase + let valsLo : Fin seq → Rat := fun _ => valsLoBase + let valsHi : Fin seq → Rat := fun _ => valsHiBase + let _ ← timePureWithHeartbeat "core: value bounds" (fun () => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by + simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + (lo, hi)) + timingPrint "timing: core stages done" + timingFlush + +/-- Load induction head inputs from disk. 
-/ +def loadInductionHeadInputs (path : System.FilePath) : + IO (Except String (Sigma (fun seq => + Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead))))) := do + let t0 ← monoUsNow + let data ← IO.FS.readFile path + let t1 ← monoUsNow + timingPrint s!"timing: read head input file {t1 - t0} us" + let t2 ← monoUsNow + let parsed := + match Pure.parseInductionHeadInputs data with + | Except.error msg => Except.error msg + | Except.ok v => Except.ok v + let t3 ← monoUsNow + timingPrint s!"timing: parse head input file {t3 - t2} us" + return parsed + +private def ratToString (x : Rat) : String := + toString x + +private def renderResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) : String := + let header := s!"dim {n}" + let lines := + (List.finRange n).foldr (fun i acc => + s!"lo {i.val} {ratToString (c.lo i)}" :: + s!"hi {i.val} {ratToString (c.hi i)}" :: acc) [] + String.intercalate "\n" (header :: lines) + +private def emitResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) + (outPath? : Option System.FilePath) : IO Unit := do + let payload := renderResidualIntervalCert c + match outPath? with + | some path => IO.FS.writeFile path (payload ++ "\n") + | none => IO.println payload + +private def buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (outPath? : Option System.FilePath) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildHeadOutputIntervalFromHead? inputs with + | none => + IO.eprintln "error: head output interval rejected" + return 2 + | some result => + emitResidualIntervalCert result.cert outPath? 
+ if outPath?.isSome then + let activeCount := result.active.card + IO.println + s!"ok: head output interval built (seq={seq}, dim={dModel}, active={activeCount})" + return 0 + +private def headScoreBoundsFromDotAbsTimed {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dotAbs : Fin seq → Fin seq → Rat) : + IO (Sound.HeadScoreBounds seq dModel dHead) := do + timePure "head: score bounds" (fun () => + Sound.headScoreBoundsFromDotAbs inputs dotAbs) + +private def headScoreBoundsFromQAbsKAbsTimed {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qAbs kAbs : Fin seq → Fin dHead → Rat) + (dotAbs : Fin seq → Fin seq → Rat) : + IO (Sound.HeadScoreBounds seq dModel dHead) := do + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then inputs.maskValue else -scoreBaseAbs q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then inputs.maskValue else scoreBaseAbs q k + let kAbsMax : Fin dHead → Rat := fun d => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d) + let dotAbsUpper : Fin seq → Rat := fun q => + Sound.Linear.dotFin dHead (fun d => qAbs q d) kAbsMax + let scoreHiUpper : Fin seq → Rat := fun q => + max inputs.maskValue (|inputs.scale| * dotAbsUpper q) + let marginTasks : Array (Task Rat) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if q ∈ inputs.active then + let prev := inputs.prev q + let scoreLoPrev := scoreLo q prev + scoreLoPrev - scoreHiUpper q + else + (0 : Rat))) + let marginAt : Fin seq → Rat := fun q => + (marginTasks[q.1]'(by + simp [marginTasks, q.isLt])).get + let epsTasks : Array (Task Rat) := + Array.ofFn (fun q : Fin seq => + 
(marginTasks[q.1]'(by + simp [marginTasks, q.isLt])).map (fun m => + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m))) + let epsAt : Fin seq → Rat := fun q => + (epsTasks[q.1]'(by + simp [epsTasks, q.isLt])).get + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + let result : Sound.HeadScoreBounds seq dModel dHead := + { dotAbs := dotAbs + scoreBaseAbs := scoreBaseAbs + scoreAbs := fun q k => if masked q k then |inputs.maskValue| else scoreBaseAbs q k + scoreLo := scoreLo + scoreHi := scoreHi + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } + return result + +private def checkInductionHeadInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cfg : Sound.InductionHeadSplitConfig) + (minActive? : Option Nat) (minLogitDiff? : Option Rat) + (minMargin maxEps : Rat) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + logTiming "start: head build induction cert" + timingPrint "timing: head build induction cert start" + timingFlush + let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" + if verboseTiming.isSome then + timingPrint s!"timing: head dims seq={seq} dModel={dModel} dHead={dHead}" + timingPrint s!"timing: head active card={inputs.active.card}" + timingFlush + let precompute := (← IO.getEnv "NFP_TIMING_PRECOMPUTE").isSome + if precompute then + timingPrint "timing: head ln bounds start" + timingFlush + let lnBounds ← timePure "head: ln bounds" (fun () => + Sound.headLnBounds inputs) + timingPrint "timing: head ln bounds done" + timingFlush + timingPrint "timing: head qkv bounds start" + timingFlush + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + let qkv ← timePure "head: qkv bounds" (fun () => + 
Sound.headQKVBounds inputs lnLo lnHi) + timingPrint "timing: head qkv bounds done" + timingFlush + if verboseTiming.isSome then + timingPrint "timing: head qkv abs force start" + timingFlush + let tAbs0 ← monoUsNow + for q in List.finRange seq do + for d in List.finRange dHead do + let _ := qkv.qAbs q d + let _ := qkv.kAbs q d + pure () + let tAbs1 ← monoUsNow + timingPrint s!"timing: head qkv abs force {tAbs1 - tAbs0} us" + timingFlush + timingPrint "timing: head score/value bounds spawn start" + timingFlush + let tSpawn0 ← monoUsNow + if verboseTiming.isSome then + timingPrint "timing: head score dotAbs tasks start" + timingFlush + let dotAbs ← timePure "head: score dotAbs tasks" (fun () => + dotAbsFromQKV qkv.qAbs qkv.kAbs) + if verboseTiming.isSome then + timingPrint "timing: head score dotAbs tasks done" + timingFlush + if verboseTiming.isSome then + timingPrint "timing: head score dotAbs force start" + timingFlush + let tForce0 ← monoUsNow + match List.finRange seq with + | [] => + timingPrint "timing: head score dotAbs force skipped (empty seq)" + | q :: _ => + match List.finRange seq with + | [] => + timingPrint "timing: head score dotAbs force skipped (empty seq)" + | k :: _ => + let _ := dotAbs q k + pure () + let tForce1 ← monoUsNow + timingPrint s!"timing: head score dotAbs force {tForce1 - tForce0} us" + timingFlush + let inlineVals := (← IO.getEnv "NFP_TIMING_VALUE_INLINE").isSome + let valueMode? ← valueBoundsModeFromEnv + let useCommon := valueMode?.getD false + let (valsInline?, valsTask?) 
:= + if inlineVals then + let vals := + if useCommon then + Sound.headValueBoundsCommonDen inputs qkv.vLo qkv.vHi + else + Sound.headValueBounds inputs qkv.vLo qkv.vHi + (some vals, none) + else + let task := + if useCommon then + Sound.headValueBoundsCommonDenTask inputs qkv.vLo qkv.vHi + else + Sound.headValueBoundsTask inputs qkv.vLo qkv.vHi + (none, some task) + let activeList := (List.finRange seq).filter (fun q => q ∈ inputs.active) + if verboseTiming.isSome then + timeHeadScoreMarginRaw inputs dotAbs activeList + let tSpawn1 ← monoUsNow + timingPrint s!"timing: head score/value bounds spawn {tSpawn1 - tSpawn0} us" + timingFlush + let skipScoreBounds := (← IO.getEnv "NFP_TIMING_SKIP_SCORE_BOUNDS").isSome + let scoreTaskOpt ← + if skipScoreBounds then + timingPrint "timing: head score bounds skipped" + pure none + else + timingPrint "timing: head score bounds from dotAbs start" + timingFlush + let exactMargin := (← IO.getEnv "NFP_TIMING_EXACT_MARGIN").isSome + let action := + if exactMargin then + headScoreBoundsFromDotAbsTimed inputs dotAbs + else + headScoreBoundsFromQAbsKAbsTimed inputs qkv.qAbs qkv.kAbs dotAbs + let t ← action.asTask + pure (some t) + if verboseTiming.isSome then + timingPrint "timing: head value parts start" + timingFlush + timingPrint "timing: head value dirHead start" + timingFlush + let tDir0 ← monoUsNow + let dirHead := Sound.headValueDirHead inputs + match List.finRange dHead with + | [] => + timingPrint "timing: head value dirHead forced skipped (empty dHead)" + | d :: _ => + let _ := dirHead d + pure () + let tDir1 ← monoUsNow + timingPrint s!"timing: head value dirHead {tDir1 - tDir0} us" + timingFlush + timingPrint "timing: head value valsLo start" + timingFlush + let tLo0 ← monoUsNow + let valsLo := Sound.headValueValsLo inputs qkv.vLo qkv.vHi + match List.finRange seq with + | [] => + timingPrint "timing: head value valsLo forced skipped (empty seq)" + | k :: _ => + let _ := valsLo k + pure () + let tLo1 ← monoUsNow + 
timingPrint s!"timing: head value valsLo {tLo1 - tLo0} us" + timingFlush + timingPrint "timing: head value valsHi start" + timingFlush + let tHi0 ← monoUsNow + let valsHi := Sound.headValueValsHi inputs qkv.vLo qkv.vHi + match List.finRange seq with + | [] => + timingPrint "timing: head value valsHi forced skipped (empty seq)" + | k :: _ => + let _ := valsHi k + pure () + let tHi1 ← monoUsNow + timingPrint s!"timing: head value valsHi {tHi1 - tHi0} us" + timingFlush + timingPrint "timing: head value lo start" + timingFlush + let tLo2 ← monoUsNow + let _ := Sound.headValueLo valsLo + let tLo3 ← monoUsNow + timingPrint s!"timing: head value lo {tLo3 - tLo2} us" + timingFlush + timingPrint "timing: head value hi start" + timingFlush + let tHi2 ← monoUsNow + let _ := Sound.headValueHi valsHi + let tHi3 ← monoUsNow + timingPrint s!"timing: head value hi {tHi3 - tHi2} us" + timingFlush + timingPrint "timing: head value parts done" + timingFlush + timingPrint "timing: head value bounds start" + timingFlush + let tVals0 ← monoUsNow + let vals ← + match valsInline?, valsTask? 
with + | some vals, _ => + timePure "head: value bounds inline" (fun () => vals) + | none, some valsTask => + timePure "head: value bounds wait" (fun () => valsTask.get) + | none, none => + timePure "head: value bounds inline" (fun () => + Sound.headValueBounds inputs qkv.vLo qkv.vHi) + let tVals1 ← monoUsNow + timingPrint s!"timing: head value bounds {tVals1 - tVals0} us" + timingFlush + let scoreOpt ← + match scoreTaskOpt with + | none => pure none + | some scoreTask => do + let res ← IO.wait scoreTask + let score ← unwrapTaskResult res + timingPrint "timing: head score bounds from dotAbs done" + timingFlush + pure (some score) + match scoreOpt with + | none => pure () + | some score => + if verboseTiming.isSome then + timeHeadScoreSampleGap inputs score + if verboseTiming.isSome then + timeHeadScoreMarginList activeList score + if verboseTiming.isSome then + timeHeadScoreFieldForces score + if verboseTiming.isSome then + timingPrint "timing: head score bounds force start" + timingFlush + let tScore0 ← monoUsNow + let _ := score.margin + let _ := score.eps + let tScore1 ← monoUsNow + timingPrint s!"timing: head score bounds force {tScore1 - tScore0} us" + timingFlush + let coreStages := (← IO.getEnv "NFP_TIMING_CORE_STAGES").isSome + let coreStagesOnly := (← IO.getEnv "NFP_TIMING_CORE_STAGES_ONLY").isSome + if coreStages then + timeInductionHeadCoreStages inputs + if coreStagesOnly then + return 0 + let breakdown := (← IO.getEnv "NFP_TIMING_BREAKDOWN").isSome + if breakdown then + let lnBounds ← timePureWithHeartbeat "breakdown: ln bounds" (fun () => + Sound.headLnBounds inputs) + timingPrint "timing: breakdown ln bounds force start" + timingFlush + let tLn0 ← monoUsNow + for q in List.finRange seq do + for i in List.finRange dModel do + let _ := lnBounds.1 q i + let _ := lnBounds.2 q i + pure () + let tLn1 ← monoUsNow + timingPrint s!"timing: breakdown ln bounds force {tLn1 - tLn0} us" + timingFlush + let qkv ← timePureWithHeartbeat "breakdown: qkv bounds" (fun 
() => + Sound.headQKVBounds inputs lnBounds.1 lnBounds.2) + timingPrint "timing: breakdown qkv bounds force start" + timingFlush + let tQkv0 ← monoUsNow + for q in List.finRange seq do + for d in List.finRange dHead do + let _ := qkv.qLo q d + let _ := qkv.qHi q d + let _ := qkv.kLo q d + let _ := qkv.kHi q d + let _ := qkv.vLo q d + let _ := qkv.vHi q d + let _ := qkv.qAbs q d + let _ := qkv.kAbs q d + pure () + let tQkv1 ← monoUsNow + timingPrint s!"timing: breakdown qkv bounds force {tQkv1 - tQkv0} us" + timingFlush + let dotAbs : Fin seq → Fin seq → Rat := fun q k => + Sound.Linear.dotFin dHead (fun d => qkv.qAbs q d) (fun d => qkv.kAbs k d) + let dotAbsRowTasks : + Array (Task { row : Array Rat // row.size = seq }) ← + timePureWithHeartbeat "breakdown: score dotAbs rows" (fun () => + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩))) + let dotAbsRowDefault : Task { row : Array Rat // row.size = seq } := + Task.spawn (fun _ => ⟨Array.ofFn (fun _ : Fin seq => (0 : Rat)), by simp⟩) + timingPrint "timing: breakdown score dotAbs force start" + timingFlush + let tDot0 ← monoUsNow + for q in List.finRange seq do + let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get + let _ := row + pure () + let tDot1 ← monoUsNow + timingPrint s!"timing: breakdown score dotAbs force {tDot1 - tDot0} us" + timingFlush + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scaleAbs : Rat := |inputs.scale| + let marginAtRaw : Fin seq → Rat := fun q => + let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get + if q ∈ inputs.active then + let rowArr := row.1 + let prev := inputs.prev q + let dotAbsPrev := rowArr.getD prev.1 0 + if masked q prev then + let scoreLoPrev := inputs.maskValue + let scoreHiAt : Fin seq → Rat := fun k => + if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × 
Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let v := scoreLoPrev - scoreHiAt k + match acc.1 with + | none => (some v, acc.2) + | some cur => (some (min cur v), acc.2) + let acc := Sound.Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min unmaskedMin maskedGap + | some unmaskedMin, false => unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + let scoreLoPrev := -(scaleAbs * dotAbsPrev) + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let raw := -(dotAbsPrev + rowArr.getD k.1 0) + match acc.1 with + | none => (some raw, acc.2) + | some cur => (some (min cur raw), acc.2) + let acc := Sound.Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap + | some unmaskedMin, false => scaleAbs * unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + (0 : Rat) + let marginAtCached : Fin seq → Rat ← + timePureWithHeartbeat "breakdown: score margin cache" (fun () => + Sound.Bounds.cacheBoundThunk marginAtRaw) + timingPrint "timing: breakdown score margin force start" + timingFlush + let tMargin0 ← monoUsNow + for q in List.finRange seq do + let m := marginAtCached q + forceRat m + pure () + let tMargin1 ← monoUsNow + timingPrint s!"timing: breakdown score margin force {tMargin1 - tMargin0} us" + timingFlush + let epsAtRaw : Fin seq → Rat := fun q => + let m := marginAtCached q + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m) + let epsAtCached : Fin seq → Rat ← + timePureWithHeartbeat "breakdown: score eps cache" (fun () => + Sound.Bounds.cacheBoundThunk epsAtRaw) + timingPrint "timing: breakdown score eps force start" + timingFlush + let tEps0 ← 
monoUsNow + for q in List.finRange seq do + let e := epsAtCached q + forceRat e + pure () + let tEps1 ← monoUsNow + timingPrint s!"timing: breakdown score eps force {tEps1 - tEps0} us" + timingFlush + let valsLo ← timePureWithHeartbeat "breakdown: value valsLo" (fun () => + Sound.headValueValsLo inputs qkv.vLo qkv.vHi) + timingPrint "timing: breakdown value valsLo force start" + timingFlush + let tValsLo0 ← monoUsNow + for k in List.finRange seq do + let v := valsLo k + forceRat v + pure () + let tValsLo1 ← monoUsNow + timingPrint s!"timing: breakdown value valsLo force {tValsLo1 - tValsLo0} us" + timingFlush + let valsHi ← timePureWithHeartbeat "breakdown: value valsHi" (fun () => + Sound.headValueValsHi inputs qkv.vLo qkv.vHi) + timingPrint "timing: breakdown value valsHi force start" + timingFlush + let tValsHi0 ← monoUsNow + for k in List.finRange seq do + let v := valsHi k + forceRat v + pure () + let tValsHi1 ← monoUsNow + timingPrint s!"timing: breakdown value valsHi force {tValsHi1 - tValsHi0} us" + timingFlush + let heartbeatMsProgress ← heartbeatMs + let taskMin (t1 t2 : Task Rat) : Task Rat := + Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) + let taskMax (t1 t2 : Task Rat) : Task Rat := + Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) + let reduceMinTasksWithProgress (tasks : Array (Task Rat)) : + IO Rat := do + let n := tasks.size + if n = 0 then + pure (0 : Rat) + else + let chunkSize : Nat := 16 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := tasks.getD start defaultTask + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => taskMin acc (tasks.getD i defaultTask)) init) + if heartbeatMsProgress ≠ 0 then + let mut 
finished := 0 + let mut remaining := chunkTasks.size + while finished < remaining do + IO.sleep heartbeatMsProgress + let mut count := 0 + for t in chunkTasks do + if (← IO.hasFinished t) then + count := count + 1 + finished := count + remaining := chunkTasks.size + if finished < remaining then + timingPrint s!"timing: breakdown value lo progress {finished}/{remaining}" + timingFlush + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + pure ((rest.foldl (fun acc i => taskMin acc (chunkTasks.getD i defaultTask)) init).get) + let reduceMaxTasksWithProgress (tasks : Array (Task Rat)) : + IO Rat := do + let n := tasks.size + if n = 0 then + pure (0 : Rat) + else + let chunkSize : Nat := 16 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := tasks.getD start defaultTask + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => taskMax acc (tasks.getD i defaultTask)) init) + if heartbeatMsProgress ≠ 0 then + let mut finished := 0 + let mut remaining := chunkTasks.size + while finished < remaining do + IO.sleep heartbeatMsProgress + let mut count := 0 + for t in chunkTasks do + if (← IO.hasFinished t) then + count := count + 1 + finished := count + remaining := chunkTasks.size + if finished < remaining then + timingPrint s!"timing: breakdown value hi progress {finished}/{remaining}" + timingFlush + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + pure ((rest.foldl (fun acc i => taskMax acc (chunkTasks.getD i defaultTask)) init).get) + if (← IO.getEnv "NFP_TIMING_TASK_PROGRESS").isSome then + let tasksLo := + (List.finRange seq).map (fun k => 
Task.spawn (fun _ => valsLo k)) + let tasksHi := + (List.finRange seq).map (fun k => Task.spawn (fun _ => valsHi k)) + let _ ← timePureWithHeartbeat "breakdown: value lo progress" (fun () => + reduceMinTasksWithProgress tasksLo.toArray) + let _ ← timePureWithHeartbeat "breakdown: value hi progress" (fun () => + reduceMaxTasksWithProgress tasksHi.toArray) + else + let loTask := Sound.headValueLoTask valsLo + let hiTask := Sound.headValueHiTask valsHi + let heartbeatMs ← heartbeatMs + let tLo0 ← monoUsNow + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished loTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished loTask) + if !finished then + let now ← monoUsNow + timingPrint s!"timing: breakdown: value lo running {now - tLo0} us" + timingFlush + let lo := loTask.get + let tLo1 ← monoUsNow + timingPrint s!"timing: breakdown: value lo {tLo1 - tLo0} us" + timingFlush + let tHi0 ← monoUsNow + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished hiTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished hiTask) + if !finished then + let now ← monoUsNow + timingPrint s!"timing: breakdown: value hi running {now - tHi0} us" + timingFlush + let hi := hiTask.get + let tHi1 ← monoUsNow + timingPrint s!"timing: breakdown: value hi {tHi1 - tHi0} us" + timingFlush + let _ := lo + let _ := hi + if (← IO.getEnv "NFP_TIMING_SEQ_REDUCE").isSome then + let loSeq ← timePureWithHeartbeat "breakdown: value lo seq" (fun () => + match List.finRange seq with + | [] => (0 : Rat) + | k :: ks => + let init := valsLo k + ks.foldl (fun acc k => min acc (valsLo k)) init) + let hiSeq ← timePureWithHeartbeat "breakdown: value hi seq" (fun () => + match List.finRange seq with + | [] => (0 : Rat) + | k :: ks => + let init := valsHi k + ks.foldl (fun acc k => max acc (valsHi k)) init) + let _ := loSeq + let _ := hiSeq + let tCert0 ← monoUsNow + let certTask : + Task + (Option { c : Sound.InductionHeadCert seq // + 
Sound.InductionHeadCertSound inputs c }) := + Task.spawn (prio := Task.Priority.dedicated) (fun _ => + match Sound.buildInductionCertFromHeadWith? cfg inputs with + | none => none + | some ⟨cert, hcert⟩ => + let _ := cert.active.card + some ⟨cert, hcert⟩) + let heartbeatMs ← heartbeatMs + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished certTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished certTask) + if !finished then + let now ← monoUsNow + timingPrint s!"timing: head build induction cert running {now - tCert0} us" + timingFlush + let certOpt ← IO.wait certTask + let tCert1 ← monoUsNow + logTiming s!"done: head build induction cert {tCert1 - tCert0} us" + timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" + timingPrint "timing: head build induction cert returned" + timingFlush + match certOpt with + | none => + IO.eprintln "error: head inputs rejected" + return 2 + | some ⟨cert, _hcert⟩ => + timingPrint "timing: head active count start" + timingFlush + let activeCount := cert.active.card + timingPrint "timing: head active count done" + timingFlush + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {ratToString cert.margin} \ + below minimum {ratToString minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {ratToString cert.eps} \ + above maximum {ratToString maxEps}" + return 2 + timingPrint "timing: head tol start" + timingFlush + let tol := cert.eps * (cert.values.hi - cert.values.lo) + timingPrint "timing: head tol done" + timingFlush + logTiming "start: head logit-diff lower bound" + timingPrint "timing: head logit-diff lower bound start" + timingFlush + let logitDiffLB? 
← timePure "head: logit-diff lower bound" (fun () => + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + cert.values.lo cert.values.hi cert.values.valsLo) + logTiming "done: head logit-diff lower bound" + let effectiveMinLogitDiff := + match minLogitDiff? with + | some v => some v + | none => some (0 : Rat) + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" + return 2 + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB})" + return 0 + +private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cfg : Sound.InductionHeadSplitConfig) + (minActive? : Option Nat) (minLogitDiff? : Option Rat) + (minMargin? : Option Rat) (maxEps : Rat) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + logTiming "start: head build induction cert" + timingPrint "timing: head build induction cert start" + timingFlush + let tCert0 ← monoUsNow + let certTask : + Task + (Option { c : Sound.InductionHeadCert seq // + Sound.InductionHeadCertSound inputs c }) := + Task.spawn (prio := Task.Priority.dedicated) (fun _ => + match Sound.buildInductionCertFromHeadWith? 
cfg inputs with + | none => none + | some ⟨cert, hcert⟩ => + let _ := cert.active.card + some ⟨cert, hcert⟩) + let heartbeatMs ← heartbeatMs + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished certTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished certTask) + if !finished then + let now ← monoUsNow + timingPrint s!"timing: head build induction cert running {now - tCert0} us" + timingFlush + let certOpt ← IO.wait certTask + let tCert1 ← monoUsNow + logTiming s!"done: head build induction cert {tCert1 - tCert0} us" + timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" + timingPrint "timing: head build induction cert returned" + timingFlush + match certOpt with + | none => + IO.eprintln "error: head inputs rejected" + return 2 + | some ⟨cert, _hcert⟩ => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" + return 2 + let marginViolation? : Option Rat := + match minMargin? with + | none => none + | some minMargin => + if cert.margin < minMargin then + some minMargin + else + none + match marginViolation? with + | some minMargin => + IO.eprintln + s!"error: margin {ratToString cert.margin} \ + below minimum {ratToString minMargin}" + return 2 + | none => pure () + logTiming "start: head logit-diff lower bound" + timingPrint "timing: head logit-diff lower bound start" + timingFlush + let logitDiffLB0? ← timePure "head: logit-diff lower bound" (fun () => + Sound.logitDiffLowerBoundFromCert cert) + let needsWeighted : Bool := + match logitDiffLB0? with + | none => true + | some lb0 => + if lb0 ≤ 0 then + true + else + match minLogitDiff? 
with + | some minLogitDiff => lb0 < minLogitDiff + | none => false + let logitDiffWeighted? ← + if needsWeighted then + timePure "head: logit-diff lower bound weighted" (fun () => + Sound.logitDiffLowerBoundFromCertWeighted cert) + else + pure none + let logitDiffLB? : Option Rat := + match logitDiffLB0?, logitDiffWeighted? with + | some lb0, some lb1 => some (max lb0 lb1) + | some lb0, none => some lb0 + | none, some lb1 => some lb1 + | none, none => none + let boundLabel : String := + match logitDiffLB0?, logitDiffWeighted? with + | some _, some _ => "max" + | none, some _ => "weighted" + | some _, none => "eps" + | none, none => "none" + logTiming "done: head logit-diff lower bound" + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + if logitDiffLB ≤ 0 then + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + is not strictly positive" + return 2 + let violation? : Option Rat := + match minLogitDiff? with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" + return 2 + | none => pure () + let tol := cert.eps * (cert.values.hi - cert.values.lo) + IO.println + s!"ok: nonvacuous induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB}, \ + bound={boundLabel})" + return 0 + +/-- Build and check induction certificates from exact head inputs. -/ +def runInductionCertifyHead (inputsPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
: Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedInputs ← timePhase "load head inputs" <| + loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps + +/-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ +def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedInputs ← timePhase "load head inputs" <| + loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? + minMargin? maxEps + +/-- Build and check induction certificates from a model binary. -/ +def runInductionCertifyHeadModel (modelPath : System.FilePath) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + logTiming "start: read model file" + timingPrint "timing: read model file start" + timingFlush + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps + +/-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ +def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + logTiming "start: read model file" + timingPrint "timing: read model file start" + timingFlush + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? + minMargin? maxEps + +/-- Heuristic logit-diff direction derived from prompt tokens. -/ +private def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : + Except String (Nat × Nat) := do + let tokenArr : Array Nat := Array.ofFn (fun i : Fin seq => tokens i) + let n := tokenArr.size + if n < 2 then + throw "token sequence must have length at least 2" + let lastTok := tokenArr.getD (n - 1) 0 + let prevIdx? := + (List.range (n - 1)).reverse.find? (fun i => + tokenArr.getD i lastTok = lastTok) + let targetTok := + match prevIdx? 
with + | some i => tokenArr.getD (i + 1) lastTok + | none => lastTok + let neg0 := tokenArr.getD (n - 2) lastTok + let neg := + if neg0 = targetTok then + if lastTok ≠ targetTok then + lastTok + else if targetTok ≠ 0 then + 0 + else + 1 + else + neg0 + return (targetTok, neg) + +/-- Build and check induction certificates from a model binary, deriving direction tokens from the +prompt sequence. -/ +def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) + (layer head : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + logTiming "start: read model file" + timingPrint "timing: read model file start" + timingFlush + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let tokensE ← timePure "read prompt tokens" (fun () => + NfptPure.readTokens data start header) + match tokensE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok tokens => + match deriveDirectionFromTokens tokens with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨dirTarget, dirNegative⟩ => + IO.println + s!"info: direction-target={dirTarget} direction-negative={dirNegative}" + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? + minMargin maxEps + +/-- Build and check a strictly positive induction logit-diff bound from a model binary, deriving +direction tokens from the prompt sequence. -/ +def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) + (layer head : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
+ let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + logTiming "start: read model file" + timingPrint "timing: read model file start" + timingFlush + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let tokensE ← timePure "read prompt tokens" (fun () => + NfptPure.readTokens data start header) + match tokensE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok tokens => + match deriveDirectionFromTokens tokens with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨dirTarget, dirNegative⟩ => + IO.println + s!"info: direction-target={dirTarget} direction-negative={dirNegative}" + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? + minMargin? maxEps + +/-- Build head-output interval bounds from exact head inputs. -/ +def runInductionHeadInterval (inputsPath : System.FilePath) + (outPath? 
: Option System.FilePath) : IO UInt32 := do + let parsedInputs ← loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + buildHeadOutputIntervalFromInputs inputs outPath? + +/-- Build head-output interval bounds from a model binary. -/ +def runInductionHeadIntervalModel (modelPath : System.FilePath) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (outPath? : Option System.FilePath) : IO UInt32 := do + let data ← IO.FS.readBinFile modelPath + match NfptPure.parseHeader data with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + match + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period? + with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + buildHeadOutputIntervalFromInputs inputs outPath? + +end IO + +end Nfp diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index 0dc5c6d..281343f 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -1,1307 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Induction.Core -import Nfp.Sound.Induction.CoreSound.Values -namespace Nfp -namespace Sound -open scoped BigOperators -open Nfp.Circuit -open Nfp.Sound.Bounds -variable {seq : Nat} -set_option maxHeartbeats 5000000 in --- The soundness proof expands many cached bounds; extra heartbeats avoid spurious timeouts. -set_option synthInstance.maxHeartbeats 200000 in --- Instance search also touches the expanded caches; allow more room to avoid timeouts. -/-- Soundness for `buildInductionCertFromHeadCoreWith?`. 
-/ -theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) - (hcore : buildInductionCertFromHeadCoreWith? cfg inputs = some c) : - InductionHeadCertSound inputs c := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : 0 < sqrtLower inputs.lnEps - · by_cases hmodel : dModel = 0 - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero - (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · by_cases hactive : inputs.active.Nonempty - · let lnBounds := Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Bounds.cacheBoundTask (fun q => - Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr : Array Rat := - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr[q.1]'(by - simp [lnAbsMaxArr]) - let invStdBoundsTasks : Array (Task (Rat × Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) - let invStdBoundsArr : Array (Rat × Rat) := - Array.ofFn (fun q : Fin seq => - (invStdBoundsTasks[q.1]'(by - simp [invStdBoundsTasks])).get) - let invStdLo : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - simp [invStdBoundsArr])).1 - let invStdHi : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - simp [invStdBoundsArr])).2 - let lnCoeff : Fin seq → Fin dModel → Rat := fun q j => - inputs.ln1Gamma j * (inputs.embed q j - mean (inputs.embed q)) - let invStd : Fin seq → Real := fun q => - (Real.sqrt 
((varianceRat (inputs.embed q) : Real) + (inputs.lnEps : Real)))⁻¹ - have hmeanRat : - ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by - intro q - have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by - simp [mean_def, hmodel, ratRoundDown] - simpa [ratToReal] using congrArg ratToReal hmu_rat - have hln_affine : - ∀ q j, - lnRealOfInputs inputs q j = - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - intro q j - have hmu := hmeanRat q - simp [lnRealOfInputs, Bounds.layerNormReal, hmodel, lnCoeff, hmu, invStd, add_comm, - mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] - have hln_fun : - ∀ q, - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - intro q - funext j - exact hln_affine q j - have hinv_bounds : - ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by - intro q - simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, - invStdBounds, Task.spawn, Array.getElem_ofFn] using - (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) - hmodel hEps hSqrt) - let qBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + - inputs.bq d) - let qBase : Fin dHead → Rat := fun d => - qBaseArr[d.1]'(by - simp [qBaseArr]) - let kBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + - inputs.bk d) - let kBase : Fin dHead → Rat := fun d => - kBaseArr[d.1]'(by - simp [kBaseArr]) - let coeffRowTasks : - (Fin dModel → Fin dHead → Rat) → - Array (Task { row : Array Rat // row.size = dHead }) := - fun w => - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => w j d) 
coeff), - by simp⟩)) - let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - coeffRowTasks inputs.wq - let qCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qCoeffRowTasks[q.1]'(by - simp [qCoeffRowTasks, coeffRowTasks])).get) - let qCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := qCoeffArr[q.1]'(by - simp [qCoeffArr]) - row.1[d.1]'(by - simp [row.2]) - let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - coeffRowTasks inputs.wk - let kCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kCoeffRowTasks[q.1]'(by - simp [kCoeffRowTasks, coeffRowTasks])).get) - let kCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := kCoeffArr[q.1]'(by - simp [kCoeffArr]) - row.1[d.1]'(by - simp [row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.1 - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.2 - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.1 - let kHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.2 - let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| - let qAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => qAbs q d)) - let qAbsMax : Fin dHead → Rat := fun d => - qAbsMaxArr[d.1]'(by - simp [qAbsMaxArr]) - let kAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have 
hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d)) - let kAbsMax : Fin dHead → Rat := fun d => - kAbsMaxArr[d.1]'(by - simp [kAbsMaxArr]) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := cfg.splitBudgetQ - let splitBudgetK : Nat := cfg.splitBudgetK - let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase - let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined - let top2ByScore : - (Fin dHead → Rat) → List (Fin dHead) → List (Fin dHead) := fun score ambig => - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let finRangeHead : List (Fin dHead) := List.finRange dHead - let finRangeSeq : List (Fin seq) := List.finRange seq - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - if splitBudgetQ = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let dims1 := top2ByScore score ambig - let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ - let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - if splitBudgetK = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) - let score : 
Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let dims1 := top2ByScore score ambig - let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK - let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => - if budget = 0 then - [] - else - let prev := inputs.prev q - let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d - let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d - let ambig := - finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) - let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 : List (Fin dHead) := top2 ambig - let dims2 : List (Fin dHead) := - top2 (ambig.filter (fun d => decide (d ∉ dims1))) - let memDims2 : Fin dHead → Bool := fun d => - dims2.any (fun d' => decide (d' = d)) - let dims3 : List (Fin dHead) := - top2 - ((ambig.filter (fun d => decide (d ∉ dims1))).filter - (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take budget - let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffBase - let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffRefined 
- let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if hq : q ∈ inputs.active then - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsDiff := splitDimsDiffBase q k - let prev := inputs.prev q - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)), - by simp⟩ - else - ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs 
q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLoBase q k - else - inputs.scale * dotDiffHiBase q k - let scoreGapLoBase : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoBaseRaw - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => - if hq : q ∈ inputs.active then - let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (scoreGapLoBase q k, k)).2 - else - none - let worstKeyArr : Array (Thunk (Option (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) - let worstKey : Fin seq → Option (Fin seq) := fun q => - let t := worstKeyArr[q.1]'(by - simp [worstKeyArr, q.isLt]) - Thunk.get t - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => 
kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - | none => dotDiffLoBase q k - let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 - else - dotDiffHiBase q k - | none => dotDiffHiBase q k - let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLo q k - else - inputs.scale * dotDiffHi q k - let scoreGapLo : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoRaw - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreGapLo q k) - else - (0 : Rat) - else - (0 : Rat) - let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => - if hk : k = inputs.prev q then - (0 : Rat) - else - let gap := scoreGapLo q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap) - let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 weightBoundAtBase - let epsAtBase : Fin seq → Rat := fun q => - let other := otherKeys q - let total := other.sum (fun k => weightBoundAt q k) - min (1 : Rat) total - let epsAt : Fin seq → Rat := - Bounds.cacheBoundThunk epsAtBase - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if h : inputs.active.Nonempty then - inputs.active.sup' h epsAt - else - (0 : Rat) - have hseq : (1 : Nat) ≤ seq := - Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) - let dirHeadVec := dirHeadVecOfInputs inputs - let 
dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := - Bounds.cacheBoundTask (fun j => - Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsLo : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalLower (fun j => wvDir j) (lnLo q) (lnHi q) - let valsHi : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo q) (lnHi q) - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec } - let cert : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } - have hcore' : buildInductionCertFromHeadCoreWith? 
cfg inputs = some cert := by - simp (config := { zeta := false }) only - [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] - rfl - have hc : c = cert := by - have hcert : cert = c := by - exact Option.some.inj (hcore'.symm.trans hcore) - simpa using hcert.symm - subst hc - have hln_bounds : - ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ - lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by - intro q i - have hln := - Bounds.layerNormBounds_spec (eps := inputs.lnEps) - (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) - (x := inputs.embed q) hmodel hEps hSqrt - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs, Bounds.cacheBoundPair2_apply_left, - Bounds.cacheBoundPair2_apply_right] using hln i - have dotFin_cast {n : Nat} (f g : Fin n → Rat) : - (Linear.dotFin n f g : Real) = - dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by - simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] - have proj_bounds - (w : Fin dModel → Fin dHead → Rat) - (b base : Fin dHead → Rat) - (coeff : Fin seq → Fin dHead → Rat) - (hbase : ∀ d, - (base d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (b d : Real)) - (hcoeff : ∀ q d, - (coeff q d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real))) : - ∀ q d, - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ∧ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ≤ - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by - intro q d - have hinv : (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := - hinv_bounds q - have hln_fun_q : - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - exact hln_fun q - have hdot_add : - dotProduct (fun j => (w j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) 
+ (lnCoeff q j : Real) * invStd q) = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real) * invStd q) := by - simpa using - (Nfp.Sound.Linear.dotProduct_add_right - (x := fun j => (w j d : Real)) - (y := fun j => (inputs.ln1Beta j : Real)) - (z := fun j => (lnCoeff q j : Real) * invStd q)) - have hdot_coeff : - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real) * invStd q) = - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q := by - simpa using - (Nfp.Sound.Linear.dotProduct_mul_right - (x := fun j => (w j d : Real)) - (y := fun j => (lnCoeff q j : Real)) - (a := invStd q)) - have hreal : - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) = - (base d : Real) + (coeff q d : Real) * invStd q := by - calc - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) + - (b d : Real) := by - simp [hln_fun_q] - _ = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q + - (b d : Real) := by - simp [hdot_add, hdot_coeff, add_assoc] - _ = - (dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (b d : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q := by ac_rfl - _ = (base d : Real) + (coeff q d : Real) * invStd q := by - simp [hbase, hcoeff] - have hscale : - let bounds := scaleInterval (coeff q d) (invStdLo q) (invStdHi q) - (bounds.1 : Real) ≤ (coeff q d : Real) * invStd q ∧ - (coeff q d : Real) * invStd q ≤ (bounds.2 : Real) := by - exact scaleInterval_bounds_real (x := coeff q d) (lo := invStdLo q) - (hi := invStdHi q) (y := invStd q) hinv.1 hinv.2 - have hlow : - (base d 
+ (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) := by - simpa [hreal] using add_le_add_left hscale.1 (base d : Real) - have hhigh : - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ≤ - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by - simpa [hreal] using add_le_add_left hscale.2 (base d : Real) - exact ⟨hlow, hhigh⟩ - have hq_bounds : - ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by - intro q d - have hbase : - ∀ d, - (qBase d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bq d : Real) := by - intro d - simp [qBase, qBaseArr, dotFin_cast] - have hcoeff : - ∀ q' d, - (qCoeff q' d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (lnCoeff q' j : Real)) := by - intro q' d - simpa [qCoeff, qCoeffArr, qCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using - (dotFin_cast (f := fun j => inputs.wq j d) - (g := fun j => - inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) - have h := proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) - (coeff := qCoeff) hbase hcoeff q d - simpa [qLo, qHi, qRealOfInputs] using h - have hk_bounds : - ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ (kHi q d : Real) := by - intro q d - have hbase : - ∀ d, - (kBase d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bk d : Real) := by - intro d - simp [kBase, kBaseArr, dotFin_cast] - have hcoeff : - ∀ q' d, - (kCoeff q' d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (lnCoeff q' j : Real)) := by - intro q' d - simpa [kCoeff, kCoeffArr, kCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using - (dotFin_cast (f := fun j => inputs.wk j d) - (g := fun j => - 
inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) - have h := proj_bounds (w := inputs.wk) (b := inputs.bk) (base := kBase) - (coeff := kCoeff) hbase hcoeff q d - simpa [kLo, kHi, kRealOfInputs] using h - let scoresReal := scoresRealOfInputs inputs - have scoresReal_eq_base_of_not_masked : - ∀ q k, ¬ masked q k → - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - intro q k hnot - by_cases hcausal : inputs.maskCausal - · have hnot_lt : ¬ q < k := by - intro hlt - exact hnot ⟨hcausal, hlt⟩ - have hle : k ≤ q := le_of_not_gt hnot_lt - simp [scoresReal, scoresRealOfInputs, hcausal, hle] - · simp [scoresReal, scoresRealOfInputs, hcausal] - have scoresReal_eq_masked : - ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by - intro q k hmask - have hmask' : inputs.maskCausal = true ∧ q < k := by - simpa [masked] using hmask - have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 - simp [scoresReal, scoresRealOfInputs, hmask'.1, hle] - have hscore_bounds : - ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ - scoresReal q k ≤ (scoreHi q k : Real) := by - intro q k - let base := - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - have hdot_bounds (hnot : ¬ masked q k) : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - have hq := hq_bounds q - have hk := hk_bounds k - have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq d).1 - have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq d).2 - have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => - (hk d).1 - have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => - (hk d).2 - have hspec := - 
_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := splitDimsQ q) (dims2 := splitDimsK q k) - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hlow' : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] - using hspec.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] - using hspec.2 - exact ⟨hlow', hhigh'⟩ - have hscore_base_bounds (hnot : ¬ masked q k) : - (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - ratToReal_nonneg_of_nonneg hscale - have hdot := hdot_bounds hnot - have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real - have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real - constructor - · simpa [scoreLo, masked, hnot, hscale, base] using hlow - · simpa [scoreHi, masked, hnot, hscale, base] using hhigh - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hdot := hdot_bounds hnot - have hlow := mul_le_mul_of_nonpos_left hdot.2 hscale_real - have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real - constructor - · simpa [scoreLo, masked, hnot, hscale, base] using hlow - · simpa [scoreHi, masked, hnot, hscale, base] using hhigh - by_cases hcausal : inputs.maskCausal - · by_cases hle : k ≤ q - · have hnot : ¬ q < k := not_lt_of_ge hle - have hnot_masked : ¬ masked q k := fun hmk => hnot hmk.2 - have hscore_eq : scoresReal q k = base 
:= - scoresReal_eq_base_of_not_masked q k hnot_masked - have hbase := hscore_base_bounds hnot_masked - constructor - · simpa [hscore_eq] using hbase.1 - · simpa [hscore_eq] using hbase.2 - · have hlt : q < k := lt_of_not_ge hle - have hmask : masked q k := ⟨hcausal, hlt⟩ - have hscore : scoresReal q k = (inputs.maskValue : Real) := - scoresReal_eq_masked q k hmask - constructor - · simp [hscore, scoreLo, hmask] - · simp [hscore, scoreHi, hmask] - · have hnot_masked : ¬ masked q k := by - simp [masked, hcausal] - have hscore_eq : scoresReal q k = base := - scoresReal_eq_base_of_not_masked q k hnot_masked - have hbase := hscore_base_bounds hnot_masked - constructor - · simpa [hscore_eq] using hbase.1 - · simpa [hscore_eq] using hbase.2 - have hdot_diff_bounds : - ∀ q, q ∈ inputs.active → ∀ k, ¬ masked q k → - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - intro q hq k hmask - have hq_bounds' := hq_bounds q - have hkprev := hk_bounds (inputs.prev q) - have hk := hk_bounds k - have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq_bounds' d).1 - have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq_bounds' d).2 - have hlo2 : - ∀ d, - (kLo (inputs.prev q) d - kHi k d : Rat) ≤ - (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by - intro d - have hprev_lo := (hkprev d).1 - have hk_hi := (hk d).2 - have h := sub_le_sub hprev_lo hk_hi - simpa [ratToReal_sub] using h - have hhi2 : - ∀ d, - (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ - (kHi (inputs.prev q) d - kLo k d : Rat) := by - intro d - have hprev_hi := (hkprev d).2 - have hk_lo := (hk d).1 - have h := sub_le_sub hprev_hi hk_lo - simpa [ratToReal_sub] 
using h - have hspec (dimsDiff : List (Fin dHead)) := - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := splitDimsQ q) (dims2 := dimsDiff) - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo (inputs.prev q) d - kHi k d) - (hi2 := fun d => kHi (inputs.prev q) d - kLo k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => - kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hspecBase := hspec (splitDimsDiffBase q k) - have hspecRef := hspec (splitDimsDiffRefined q k) - have hspecBase_bounds : - (dotDiffLoBase q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHiBase q k : Real) := by - refine ⟨?_, ?_⟩ - · simpa [dotDiffLoBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, - Array.getElem_ofFn] using hspecBase.1 - · simpa [dotDiffHiBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, - Array.getElem_ofFn] using hspecBase.2 - cases hkey : worstKey q with - | none => - simpa [dotDiffLo, dotDiffHi, hkey] using hspecBase_bounds - | some k' => - by_cases hk : k = k' - · have hlow' : - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, hkey, hk] using hspecRef.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, hkey, hk] using hspecRef.2 - exact ⟨hlow', hhigh'⟩ - · have hlow' : - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, hkey, 
hk] using hspecBase_bounds.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, hkey, hk] using hspecBase_bounds.2 - exact ⟨hlow', hhigh'⟩ - have hmarginAt_le : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - marginAt q ≤ scoreGapLo q k := by - intro q hq k hk - have hmem : k ∈ otherKeys q := by simp [otherKeys, hk] - have hnonempty : (otherKeys q).Nonempty := ⟨k, hmem⟩ - have hle : - (otherKeys q).inf' hnonempty (fun k => scoreGapLo q k) ≤ scoreGapLo q k := by - exact (Finset.inf'_le_iff (s := otherKeys q) (H := hnonempty) - (f := fun k => scoreGapLo q k) (a := scoreGapLo q k)).2 - ⟨k, hmem, le_rfl⟩ - simpa [marginAt, hq, hnonempty] using hle - have hscore_gap_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - by_cases hprevmask : masked q (inputs.prev q) - · have hscore_hi : scoresReal q k ≤ (scoreHi q k : Real) := - (hscore_bounds q k).2 - have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by - have hprev_bounds := hscore_bounds q (inputs.prev q) - simpa [scoreLoPrev] using hprev_bounds.1 - have hsum_le' : - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ - (scoreLoPrev q : Real) := by - have hsub : - (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ - (scoreLoPrev q : Real) - scoresReal q k := - sub_le_sub_left hscore_hi (scoreLoPrev q : Real) - calc - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k - ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by - simpa [add_comm, add_left_comm, add_assoc] using - (add_le_add_left hsub (scoresReal q k)) - _ = (scoreLoPrev q : Real) := by - simp [sub_add_cancel] - calc - scoresReal q k + (scoreGapLo q k : Real) - = (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by - simp 
[scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, add_comm] - _ ≤ (scoreLoPrev q : Real) := hsum_le' - _ ≤ scoresReal q (inputs.prev q) := hscore_prev - · by_cases hmask : masked q k - · have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by - have hprev_bounds := hscore_bounds q (inputs.prev q) - simpa [scoreLoPrev] using hprev_bounds.1 - have hscore_k : scoresReal q k = (inputs.maskValue : Real) := - scoresReal_eq_masked q k hmask - calc - scoresReal q k + (scoreGapLo q k : Real) - = (inputs.maskValue : Real) + (scoreLoPrev q : Real) - - (inputs.maskValue : Real) := by - simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscore_k] - _ = (scoreLoPrev q : Real) := by - simp [add_sub_cancel_left] - _ ≤ scoresReal q (inputs.prev q) := hscore_prev - · have hdiff := hdot_diff_bounds q hq k hmask - have hgap_le : - (scoreGapLo q k : Real) ≤ - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - ratToReal_nonneg_of_nonneg hscale - have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real - simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscale] using hle - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real - simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscale] using hle - have hscore_prev : - scoresReal q (inputs.prev q) = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) := by - simpa using - (scoresReal_eq_base_of_not_masked q (inputs.prev q) hprevmask) - have hscore_k : - 
scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa using (scoresReal_eq_base_of_not_masked q k hmask) - have hdot_sub : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) = - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - classical - simpa using - (Nfp.Sound.Linear.dotProduct_sub_right - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs (inputs.prev q) d) - (z := fun d => kRealOfInputs inputs k d)) - have hscore_diff : - scoresReal q (inputs.prev q) - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - calc - scoresReal q (inputs.prev q) - scoresReal q k - = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simp [hscore_prev, hscore_k] - _ = - (inputs.scale : Real) * - (dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)) := by - simp [mul_sub] - _ = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simp [hdot_sub] - have hgap_le' : - (scoreGapLo q k : Real) ≤ - scoresReal q (inputs.prev q) - scoresReal q k := by - simpa [hscore_diff] using hgap_le - have hgap_add := - add_le_add_right hgap_le' (scoresReal q k) - have hgap_add' : - scoresReal q k + 
(scoreGapLo q k : Real) ≤ - scoresReal q (inputs.prev q) := by - have hcancel : - scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) = - scoresReal q (inputs.prev q) := by - calc - scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) - = - scoresReal q k + scoresReal q (inputs.prev q) - - scoresReal q k := by - symm - exact add_sub_assoc (scoresReal q k) - (scoresReal q (inputs.prev q)) (scoresReal q k) - _ = scoresReal q (inputs.prev q) := by - simp [add_sub_cancel_left] - calc - scoresReal q k + (scoreGapLo q k : Real) - ≤ scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) := - hgap_add - _ = scoresReal q (inputs.prev q) := hcancel - exact hgap_add' - let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax (scoresReal q) k - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - have hscore_margin_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hmargin_le : marginAt q ≤ scoreGapLo q k := - hmarginAt_le q hq k hk - have hmargin_le_real : (marginAt q : Real) ≤ (scoreGapLo q k : Real) := - ratToReal_le_of_le hmargin_le - have hscore_gap := hscore_gap_real_at q hq k hk - have hstep := add_le_add_right hmargin_le_real (scoresReal q k) - have hstep' : - scoresReal q k + (marginAt q : Real) ≤ - scoresReal q k + (scoreGapLo q k : Real) := by - exact hstep - exact hstep'.trans hscore_gap - have hscore_margin_real : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hmargin_le : margin ≤ marginAt q := by - have hmem : q ∈ inputs.active := hq - have hnonempty : inputs.active.Nonempty := hactive - have hle := - (Finset.inf'_le_iff (s := inputs.active) (H := hnonempty) - (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ - simpa [margin, hnonempty] using 
hle - have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := - ratToReal_le_of_le hmargin_le - have hscore := hscore_margin_real_at q hq k hk - have hscore' : - (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by - simpa [add_comm] using hscore - have hstep := add_le_add_right hmargin_le_real (scoresReal q k) - have hstep' : - scoresReal q k + (margin : Real) ≤ (marginAt q : Real) + scoresReal q k := by - calc - scoresReal q k + (margin : Real) ≤ scoresReal q k + (marginAt q : Real) := hstep - _ = (marginAt q : Real) + scoresReal q k := by - simp [add_comm] - exact hstep'.trans hscore' - have hweightBoundAt : - ∀ q k, k ≠ inputs.prev q → - weightBoundAt q k = - if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k) := by - intro q k hk - simpa [weightBoundAt, weightBoundAtBase, hk] using - (Bounds.cacheBound2_apply (f := weightBoundAtBase) q k) - have hepsAt : - ∀ q, epsAt q = - min (1 : Rat) - ((otherKeys q).sum (fun k => - if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k))) := by - intro q - have hsum : - (otherKeys q).sum (fun k => weightBoundAt q k) = - (otherKeys q).sum (fun k => - if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k)) := by - refine Finset.sum_congr rfl ?_ - intro k hk - have hk' : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 - simp [hweightBoundAt q k hk'] - simpa [epsAt, epsAtBase, hsum] using - (Bounds.cacheBoundThunk_apply (f := epsAtBase) q) - have oneHot_bounds_at : - ∀ q, q ∈ inputs.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) - (fun q' => q' = q) inputs.prev weights := by - intro q hq - exact - Sound.oneHot_bounds_at_of_scoreGapLo - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (scoreGapLo := scoreGapLo) - (epsAt := epsAt) - (hepsAt := hepsAt) - (hscore_gap_real_at := hscore_gap_real_at) - q hq - have weight_bounds_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ 
inputs.prev q → - weights q k ≤ (weightBoundAt q k : Real) := by - intro q hq k hk - exact - Sound.weight_bound_at_of_scoreGapLo - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (scoreGapLo := scoreGapLo) - (weightBoundAt := weightBoundAt) - (hweightBoundAt := hweightBoundAt) - (hscore_gap_real_at := hscore_gap_real_at) - q hq k hk - have hepsAt_le_eps : - ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by - intro q hq - have hle : - epsAt q ≤ inputs.active.sup' hactive epsAt := by - exact - (Finset.le_sup'_iff (s := inputs.active) (H := hactive) - (f := epsAt) (a := epsAt q)).2 ⟨q, hq, le_rfl⟩ - simpa [eps, hactive] using hle - have hepsAt_le_eps_real : - ∀ q, q ∈ inputs.active → (epsAt q : Real) ≤ (eps : Real) := by - intro q hq - exact ratToReal_le_of_le (hepsAt_le_eps q hq) - have hsoftmax_bounds : - Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) - (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by - classical - refine - { score_margin := ?_ - nonneg := ?_ - sum_one := ?_ - prev_large := ?_ - other_le := ?_ } - · intro q hq k hk - exact hscore_margin_real q hq k hk - · intro q _ k - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) - · intro q _ - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) - · intro q hq - have honehot := oneHot_bounds_at q hq - have hprev := honehot.prev_large q rfl - have hle : - weights q (inputs.prev q) + (epsAt q : Real) ≤ - weights q (inputs.prev q) + (eps : Real) := by - simpa [add_comm] using - (add_le_add_right (hepsAt_le_eps_real q hq) (weights q (inputs.prev q))) - exact hprev.trans hle - · intro q hq k hk - have honehot := oneHot_bounds_at q hq - have hother := honehot.other_le q rfl k hk - exact hother.trans (hepsAt_le_eps_real q hq) - have hdirHead : - dirHead = fun d => (dirHeadVecOfInputs inputs).get d := by - simp [dirHead, dirHeadVec] - have hwvDir : - ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d 
=> inputs.wv j d) := by - intro j - simp [wvDir, Bounds.cacheBoundTask_apply] - have hbDir : - bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d) := by - rfl - have hdir_wv : - ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := - wvDir_real_eq_sum inputs dirHead wvDir hwvDir - have hdir_bv : - (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := - bDir_real_eq_sum inputs dirHead bDir hbDir - have hvals_eq : - ∀ k, - valsRealOfInputs inputs k = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - (bDir : Real) := - valsReal_eq_of_dir inputs dirHead hdirHead wvDir bDir hdir_wv hdir_bv - have hvals_bounds_at : - ∀ k, - (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k ∧ - valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - intro k - have hln := hln_bounds k - have hlo : ∀ j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j := fun j => - (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs k j ≤ (lnHi k j : Real) := fun j => - (hln j).2 - have hlow' : - (Bounds.dotIntervalLower (fun j => wvDir j) (lnLo k) (lnHi k) : Real) + - (bDir : Real) ≤ - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - (bDir : Real) := by - simpa using - (Bounds.dotIntervalLower_le_dotProduct_real_add - (v := fun j => wvDir j) - (lo := lnLo k) (hi := lnHi k) - (x := lnRealOfInputs inputs k) (b := (bDir : Real)) hlo hhi) - have hhigh' : - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - (bDir : Real) ≤ - (Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo k) (lnHi k) : Real) + - (bDir : Real) := by - simpa using - (Bounds.dotProduct_le_dotIntervalUpper_real_add - (v := fun j => wvDir j) - (lo := lnLo k) (hi := lnHi k) - (x := lnRealOfInputs inputs k) (b := (bDir : Real)) hlo hhi) - have hlow : - (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by - simpa [valCert, valsLo, hvals_eq k, ratToReal_add, add_comm, add_left_comm, - add_assoc] using hlow' - have hhigh : - valsRealOfInputs 
inputs k ≤ (valCert.valsHi k : Real) := by - simpa [valCert, valsHi, hvals_eq k, ratToReal_add, add_comm, add_left_comm, - add_assoc] using hhigh' - exact ⟨hlow, hhigh⟩ - have hvals_bounds : - ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by - refine - { lo_le_hi := ?_ - lo_le_valsLo := ?_ - vals_bounds := ?_ - valsHi_le_hi := ?_ } - · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ - have hmem0 : k0 ∈ univ := hk0 - have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by - have hloRat : valCert.lo ≤ valCert.valsLo k0 := by - change lo ≤ valsLo k0 - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k0)).2 ⟨k0, hmem0, le_rfl⟩ - exact ratToReal_le_of_le hloRat - have hvals : - (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ - valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by - exact hvals_bounds_at k0 - have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by - have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by - change valsHi k0 ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k0)).2 ⟨k0, ⟨hmem0, le_rfl⟩⟩ - exact ratToReal_le_of_le hhiRat - have hreal : - (valCert.lo : Real) ≤ (valCert.hi : Real) := - le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) - exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal - · intro k - have hloRat : valCert.lo ≤ valCert.valsLo k := by - change lo ≤ valsLo k - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k)).2 ⟨k, by simp [univ], le_rfl⟩ - exact ratToReal_le_of_le hloRat - · intro k - exact hvals_bounds_at k - · intro k - have hhiRat : valCert.valsHi k ≤ valCert.hi := by - change valsHi k ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k)).2 ⟨k, by simp [univ], le_rfl⟩ - exact ratToReal_le_of_le hhiRat - exact - { softmax_bounds := 
hsoftmax_bounds - oneHot_bounds_at := oneHot_bounds_at - weight_bounds_at := weight_bounds_at - value_bounds := hvals_bounds } - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_active - (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt - (cfg := cfg) (inputs := inputs) hEps hSqrt - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps - (cfg := cfg) (inputs := inputs) hEps - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim -/-- Soundness for `buildInductionCertFromHeadCore?`. -/ -theorem buildInductionCertFromHeadCore?_sound - [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) - (hcore : buildInductionCertFromHeadCore? inputs = some c) : - InductionHeadCertSound inputs c := by - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_sound - (cfg := defaultInductionHeadSplitConfig) inputs c - (by - simpa [buildInductionCertFromHeadCore?] using hcore)) -end Sound -end Nfp +import Nfp.Sound.Induction.CoreSound.Basic + +/-! +Soundness proofs for induction-head core certificates. 
+-/ diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean new file mode 100644 index 0000000..0dc5c6d --- /dev/null +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -0,0 +1,1307 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later +import Nfp.Sound.Induction.Core +import Nfp.Sound.Induction.CoreSound.Values +namespace Nfp +namespace Sound +open scoped BigOperators +open Nfp.Circuit +open Nfp.Sound.Bounds +variable {seq : Nat} +set_option maxHeartbeats 5000000 in +-- The soundness proof expands many cached bounds; extra heartbeats avoid spurious timeouts. +set_option synthInstance.maxHeartbeats 200000 in +-- Instance search also touches the expanded caches; allow more room to avoid timeouts. +/-- Soundness for `buildInductionCertFromHeadCoreWith?`. -/ +theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) + (hcore : buildInductionCertFromHeadCoreWith? 
cfg inputs = some c) : + InductionHeadCertSound inputs c := by + classical + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : 0 < sqrtLower inputs.lnEps + · by_cases hmodel : dModel = 0 + · have : False := by + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' + exact this.elim + · by_cases hactive : inputs.active.Nonempty + · let lnBounds := Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Bounds.cacheBoundTask (fun q => + Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr : Array Rat := + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr[q.1]'(by + simp [lnAbsMaxArr]) + let invStdBoundsTasks : Array (Task (Rat × Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) + let invStdBoundsArr : Array (Rat × Rat) := + Array.ofFn (fun q : Fin seq => + (invStdBoundsTasks[q.1]'(by + simp [invStdBoundsTasks])).get) + let invStdLo : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + simp [invStdBoundsArr])).1 + let invStdHi : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + simp [invStdBoundsArr])).2 + let lnCoeff : Fin seq → Fin dModel → Rat := fun q j => + inputs.ln1Gamma j * (inputs.embed q j - mean (inputs.embed q)) + let invStd : Fin seq → Real := fun q => + (Real.sqrt ((varianceRat (inputs.embed q) : Real) + (inputs.lnEps : Real)))⁻¹ + have hmeanRat : + ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by + intro q + have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by + simp [mean_def, 
hmodel, ratRoundDown] + simpa [ratToReal] using congrArg ratToReal hmu_rat + have hln_affine : + ∀ q j, + lnRealOfInputs inputs q j = + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + intro q j + have hmu := hmeanRat q + simp [lnRealOfInputs, Bounds.layerNormReal, hmodel, lnCoeff, hmu, invStd, add_comm, + mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] + have hln_fun : + ∀ q, + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + intro q + funext j + exact hln_affine q j + have hinv_bounds : + ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by + intro q + simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, + invStdBounds, Task.spawn, Array.getElem_ofFn] using + (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) + hmodel hEps hSqrt) + let qBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + + inputs.bq d) + let qBase : Fin dHead → Rat := fun d => + qBaseArr[d.1]'(by + simp [qBaseArr]) + let kBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + + inputs.bk d) + let kBase : Fin dHead → Rat := fun d => + kBaseArr[d.1]'(by + simp [kBaseArr]) + let coeffRowTasks : + (Fin dModel → Fin dHead → Rat) → + Array (Task { row : Array Rat // row.size = dHead }) := + fun w => + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => w j d) coeff), + by simp⟩)) + let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + coeffRowTasks inputs.wq + let qCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qCoeffRowTasks[q.1]'(by + 
simp [qCoeffRowTasks, coeffRowTasks])).get) + let qCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := qCoeffArr[q.1]'(by + simp [qCoeffArr]) + row.1[d.1]'(by + simp [row.2]) + let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + coeffRowTasks inputs.wk + let kCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kCoeffRowTasks[q.1]'(by + simp [kCoeffRowTasks, coeffRowTasks])).get) + let kCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := kCoeffArr[q.1]'(by + simp [kCoeffArr]) + row.1[d.1]'(by + simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.1 + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.2 + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.1 + let kHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.2 + let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let qAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => qAbs q d)) + let qAbsMax : Fin dHead → Rat := fun d => + qAbsMaxArr[d.1]'(by + simp [qAbsMaxArr]) + let kAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d)) + let kAbsMax : Fin dHead → Rat := fun d => + kAbsMaxArr[d.1]'(by + simp [kAbsMaxArr]) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = 
true ∧ q < k + let splitBudgetQ : Nat := cfg.splitBudgetQ + let splitBudgetK : Nat := cfg.splitBudgetK + let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase + let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined + let top2ByScore : + (Fin dHead → Rat) → List (Fin dHead) → List (Fin dHead) := fun score ambig => + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let finRangeHead : List (Fin dHead) := List.finRange dHead + let finRangeSeq : List (Fin seq) := List.finRange seq + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => + if splitBudgetQ = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let dims1 := top2ByScore score ambig + let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetQ + let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => + if splitBudgetK = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d + let dims1 := top2ByScore score ambig + let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetK + let splitDimsDiffCore : Nat → Fin seq → 
Fin seq → List (Fin dHead) := fun budget q k => + if budget = 0 then + [] + else + let prev := inputs.prev q + let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d + let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d + let ambig := + finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) + let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 : List (Fin dHead) := top2 ambig + let dims2 : List (Fin dHead) := + top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let memDims2 : Fin dHead → Bool := fun d => + dims2.any (fun d' => decide (d' = d)) + let dims3 : List (Fin dHead) := + top2 + ((ambig.filter (fun d => decide (d ∉ dims1))).filter + (fun d => !memDims2 d)) + (dims1 ++ dims2 ++ dims3).take budget + let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffBase + let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffRefined + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsK := 
splitDimsK q k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if hq : q ∈ inputs.active then + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsDiff := splitDimsDiffBase q k + let prev := inputs.prev q + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)), + by simp⟩ + else + ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k 
=> + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) + let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLoBase q k + else + inputs.scale * dotDiffHiBase q k + let scoreGapLoBase : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoBaseRaw + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => + if hq : q ∈ inputs.active then + let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (scoreGapLoBase q k, k)).2 + else + none + let worstKeyArr : Array (Thunk (Option (Fin seq))) := + Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) + let worstKey : Fin seq → Option (Fin seq) := fun q => + let t := worstKeyArr[q.1]'(by + simp [worstKeyArr, q.isLt]) + Thunk.get t + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k + | none => dotDiffLoBase q k + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := 
splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k + | none => dotDiffHiBase q k + let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLo q k + else + inputs.scale * dotDiffHi q k + let scoreGapLo : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoRaw + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreGapLo q k) + else + (0 : Rat) + else + (0 : Rat) + let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => + if hk : k = inputs.prev q then + (0 : Rat) + else + let gap := scoreGapLo q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap) + let weightBoundAt : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 weightBoundAtBase + let epsAtBase : Fin seq → Rat := fun q => + let other := otherKeys q + let total := other.sum (fun k => weightBoundAt q k) + min (1 : Rat) total + let epsAt : Fin seq → Rat := + Bounds.cacheBoundThunk epsAtBase + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if h : inputs.active.Nonempty then + inputs.active.sup' h epsAt + else + (0 : Rat) + have hseq : (1 : Nat) ≤ seq := + Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) + let dirHeadVec := dirHeadVecOfInputs inputs + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin dModel → Rat := + Bounds.cacheBoundTask (fun j => + Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv 
d) + let valsLo : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalLower (fun j => wvDir j) (lnLo q) (lnHi q) + let valsHi : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo q) (lnHi q) + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + let valCert : ValueInterval seq := + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec } + let cert : InductionHeadCert seq := + { eps := eps + epsAt := epsAt + weightBoundAt := weightBoundAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } + have hcore' : buildInductionCertFromHeadCoreWith? cfg inputs = some cert := by + simp (config := { zeta := false }) only + [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] + rfl + have hc : c = cert := by + have hcert : cert = c := by + exact Option.some.inj (hcore'.symm.trans hcore) + simpa using hcert.symm + subst hc + have hln_bounds : + ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ + lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by + intro q i + have hln := + Bounds.layerNormBounds_spec (eps := inputs.lnEps) + (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) + (x := inputs.embed q) hmodel hEps hSqrt + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs, Bounds.cacheBoundPair2_apply_left, + Bounds.cacheBoundPair2_apply_right] using hln i + have dotFin_cast {n : Nat} (f g : Fin n → Rat) : + (Linear.dotFin n f g : Real) = + dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by + simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] + have proj_bounds + (w : Fin dModel → Fin dHead → Rat) + (b base : Fin dHead → Rat) + (coeff : Fin seq → Fin dHead → Rat) + (hbase : ∀ d, + (base d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (b d : Real)) + (hcoeff : ∀ q 
d, + (coeff q d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real))) : + ∀ q d, + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ∧ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ≤ + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by + intro q d + have hinv : (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := + hinv_bounds q + have hln_fun_q : + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + exact hln_fun q + have hdot_add : + dotProduct (fun j => (w j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real) * invStd q) := by + simpa using + (Nfp.Sound.Linear.dotProduct_add_right + (x := fun j => (w j d : Real)) + (y := fun j => (inputs.ln1Beta j : Real)) + (z := fun j => (lnCoeff q j : Real) * invStd q)) + have hdot_coeff : + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real) * invStd q) = + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q := by + simpa using + (Nfp.Sound.Linear.dotProduct_mul_right + (x := fun j => (w j d : Real)) + (y := fun j => (lnCoeff q j : Real)) + (a := invStd q)) + have hreal : + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) = + (base d : Real) + (coeff q d : Real) * invStd q := by + calc + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) + + (b d : Real) := by + simp [hln_fun_q] + _ = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : 
Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q + + (b d : Real) := by + simp [hdot_add, hdot_coeff, add_assoc] + _ = + (dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (b d : Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q := by ac_rfl + _ = (base d : Real) + (coeff q d : Real) * invStd q := by + simp [hbase, hcoeff] + have hscale : + let bounds := scaleInterval (coeff q d) (invStdLo q) (invStdHi q) + (bounds.1 : Real) ≤ (coeff q d : Real) * invStd q ∧ + (coeff q d : Real) * invStd q ≤ (bounds.2 : Real) := by + exact scaleInterval_bounds_real (x := coeff q d) (lo := invStdLo q) + (hi := invStdHi q) (y := invStd q) hinv.1 hinv.2 + have hlow : + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) := by + simpa [hreal] using add_le_add_left hscale.1 (base d : Real) + have hhigh : + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ≤ + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by + simpa [hreal] using add_le_add_left hscale.2 (base d : Real) + exact ⟨hlow, hhigh⟩ + have hq_bounds : + ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + intro q d + have hbase : + ∀ d, + (qBase d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bq d : Real) := by + intro d + simp [qBase, qBaseArr, dotFin_cast] + have hcoeff : + ∀ q' d, + (qCoeff q' d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (lnCoeff q' j : Real)) := by + intro q' d + simpa [qCoeff, qCoeffArr, qCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using + (dotFin_cast (f := fun j => inputs.wq j d) + (g := fun j => + inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) + have h := 
proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) + (coeff := qCoeff) hbase hcoeff q d + simpa [qLo, qHi, qRealOfInputs] using h + have hk_bounds : + ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ (kHi q d : Real) := by + intro q d + have hbase : + ∀ d, + (kBase d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bk d : Real) := by + intro d + simp [kBase, kBaseArr, dotFin_cast] + have hcoeff : + ∀ q' d, + (kCoeff q' d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (lnCoeff q' j : Real)) := by + intro q' d + simpa [kCoeff, kCoeffArr, kCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using + (dotFin_cast (f := fun j => inputs.wk j d) + (g := fun j => + inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) + have h := proj_bounds (w := inputs.wk) (b := inputs.bk) (base := kBase) + (coeff := kCoeff) hbase hcoeff q d + simpa [kLo, kHi, kRealOfInputs] using h + let scoresReal := scoresRealOfInputs inputs + have scoresReal_eq_base_of_not_masked : + ∀ q k, ¬ masked q k → + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + intro q k hnot + by_cases hcausal : inputs.maskCausal + · have hnot_lt : ¬ q < k := by + intro hlt + exact hnot ⟨hcausal, hlt⟩ + have hle : k ≤ q := le_of_not_gt hnot_lt + simp [scoresReal, scoresRealOfInputs, hcausal, hle] + · simp [scoresReal, scoresRealOfInputs, hcausal] + have scoresReal_eq_masked : + ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by + intro q k hmask + have hmask' : inputs.maskCausal = true ∧ q < k := by + simpa [masked] using hmask + have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 + simp [scoresReal, scoresRealOfInputs, hmask'.1, hle] + have hscore_bounds : + ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ + scoresReal q k ≤ (scoreHi q k : Real) := by + intro q k + let base := + (inputs.scale 
: Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + have hdot_bounds (hnot : ¬ masked q k) : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + have hq := hq_bounds q + have hk := hk_bounds k + have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq d).1 + have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => + (hq d).2 + have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => + (hk d).1 + have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => + (hk d).2 + have hspec := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := splitDimsQ q) (dims2 := splitDimsK q k) + (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hlow' : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] + using hspec.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] + using hspec.2 + exact ⟨hlow', hhigh'⟩ + have hscore_base_bounds (hnot : ¬ masked q k) : + (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + ratToReal_nonneg_of_nonneg hscale + have hdot := hdot_bounds hnot + have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real + have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real + constructor + · simpa [scoreLo, 
masked, hnot, hscale, base] using hlow + · simpa [scoreHi, masked, hnot, hscale, base] using hhigh + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hdot := hdot_bounds hnot + have hlow := mul_le_mul_of_nonpos_left hdot.2 hscale_real + have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real + constructor + · simpa [scoreLo, masked, hnot, hscale, base] using hlow + · simpa [scoreHi, masked, hnot, hscale, base] using hhigh + by_cases hcausal : inputs.maskCausal + · by_cases hle : k ≤ q + · have hnot : ¬ q < k := not_lt_of_ge hle + have hnot_masked : ¬ masked q k := fun hmk => hnot hmk.2 + have hscore_eq : scoresReal q k = base := + scoresReal_eq_base_of_not_masked q k hnot_masked + have hbase := hscore_base_bounds hnot_masked + constructor + · simpa [hscore_eq] using hbase.1 + · simpa [hscore_eq] using hbase.2 + · have hlt : q < k := lt_of_not_ge hle + have hmask : masked q k := ⟨hcausal, hlt⟩ + have hscore : scoresReal q k = (inputs.maskValue : Real) := + scoresReal_eq_masked q k hmask + constructor + · simp [hscore, scoreLo, hmask] + · simp [hscore, scoreHi, hmask] + · have hnot_masked : ¬ masked q k := by + simp [masked, hcausal] + have hscore_eq : scoresReal q k = base := + scoresReal_eq_base_of_not_masked q k hnot_masked + have hbase := hscore_base_bounds hnot_masked + constructor + · simpa [hscore_eq] using hbase.1 + · simpa [hscore_eq] using hbase.2 + have hdot_diff_bounds : + ∀ q, q ∈ inputs.active → ∀ k, ¬ masked q k → + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + intro q hq k hmask + have hq_bounds' := hq_bounds q + have hkprev := hk_bounds 
(inputs.prev q) + have hk := hk_bounds k + have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq_bounds' d).1 + have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => + (hq_bounds' d).2 + have hlo2 : + ∀ d, + (kLo (inputs.prev q) d - kHi k d : Rat) ≤ + (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by + intro d + have hprev_lo := (hkprev d).1 + have hk_hi := (hk d).2 + have h := sub_le_sub hprev_lo hk_hi + simpa [ratToReal_sub] using h + have hhi2 : + ∀ d, + (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ + (kHi (inputs.prev q) d - kLo k d : Rat) := by + intro d + have hprev_hi := (hkprev d).2 + have hk_lo := (hk d).1 + have h := sub_le_sub hprev_hi hk_lo + simpa [ratToReal_sub] using h + have hspec (dimsDiff : List (Fin dHead)) := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := splitDimsQ q) (dims2 := dimsDiff) + (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo (inputs.prev q) d - kHi k d) + (hi2 := fun d => kHi (inputs.prev q) d - kLo k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => + kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hspecBase := hspec (splitDimsDiffBase q k) + have hspecRef := hspec (splitDimsDiffRefined q k) + have hspecBase_bounds : + (dotDiffLoBase q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHiBase q k : Real) := by + refine ⟨?_, ?_⟩ + · simpa [dotDiffLoBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, + Array.getElem_ofFn] using hspecBase.1 + · simpa [dotDiffHiBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, + Array.getElem_ofFn] using hspecBase.2 + cases hkey : worstKey q with + | none => + 
simpa [dotDiffLo, dotDiffHi, hkey] using hspecBase_bounds + | some k' => + by_cases hk : k = k' + · have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, hkey, hk] using hspecRef.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, hkey, hk] using hspecRef.2 + exact ⟨hlow', hhigh'⟩ + · have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, hkey, hk] using hspecBase_bounds.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, hkey, hk] using hspecBase_bounds.2 + exact ⟨hlow', hhigh'⟩ + have hmarginAt_le : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + marginAt q ≤ scoreGapLo q k := by + intro q hq k hk + have hmem : k ∈ otherKeys q := by simp [otherKeys, hk] + have hnonempty : (otherKeys q).Nonempty := ⟨k, hmem⟩ + have hle : + (otherKeys q).inf' hnonempty (fun k => scoreGapLo q k) ≤ scoreGapLo q k := by + exact (Finset.inf'_le_iff (s := otherKeys q) (H := hnonempty) + (f := fun k => scoreGapLo q k) (a := scoreGapLo q k)).2 + ⟨k, hmem, le_rfl⟩ + simpa [marginAt, hq, hnonempty] using hle + have hscore_gap_real_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + by_cases hprevmask : masked q (inputs.prev q) + · have hscore_hi : scoresReal q k ≤ (scoreHi q k : Real) := + (hscore_bounds q k).2 + have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by + have hprev_bounds := 
hscore_bounds q (inputs.prev q) + simpa [scoreLoPrev] using hprev_bounds.1 + have hsum_le' : + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ + (scoreLoPrev q : Real) := by + have hsub : + (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ + (scoreLoPrev q : Real) - scoresReal q k := + sub_le_sub_left hscore_hi (scoreLoPrev q : Real) + calc + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k + ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsub (scoresReal q k)) + _ = (scoreLoPrev q : Real) := by + simp [sub_add_cancel] + calc + scoresReal q k + (scoreGapLo q k : Real) + = (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by + simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, add_comm] + _ ≤ (scoreLoPrev q : Real) := hsum_le' + _ ≤ scoresReal q (inputs.prev q) := hscore_prev + · by_cases hmask : masked q k + · have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by + have hprev_bounds := hscore_bounds q (inputs.prev q) + simpa [scoreLoPrev] using hprev_bounds.1 + have hscore_k : scoresReal q k = (inputs.maskValue : Real) := + scoresReal_eq_masked q k hmask + calc + scoresReal q k + (scoreGapLo q k : Real) + = (inputs.maskValue : Real) + (scoreLoPrev q : Real) - + (inputs.maskValue : Real) := by + simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscore_k] + _ = (scoreLoPrev q : Real) := by + simp [add_sub_cancel_left] + _ ≤ scoresReal q (inputs.prev q) := hscore_prev + · have hdiff := hdot_diff_bounds q hq k hmask + have hgap_le : + (scoreGapLo q k : Real) ≤ + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + ratToReal_nonneg_of_nonneg hscale + have hle := 
mul_le_mul_of_nonneg_left hdiff.1 hscale_real + simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscale] using hle + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real + simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscale] using hle + have hscore_prev : + scoresReal q (inputs.prev q) = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) := by + simpa using + (scoresReal_eq_base_of_not_masked q (inputs.prev q) hprevmask) + have hscore_k : + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simpa using (scoresReal_eq_base_of_not_masked q k hmask) + have hdot_sub : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) = + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + classical + simpa using + (Nfp.Sound.Linear.dotProduct_sub_right + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs (inputs.prev q) d) + (z := fun d => kRealOfInputs inputs k d)) + have hscore_diff : + scoresReal q (inputs.prev q) - scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + calc + scoresReal q (inputs.prev q) - scoresReal q k + = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + (inputs.scale : Real) * + dotProduct (fun d 
=> qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simp [hscore_prev, hscore_k] + _ = + (inputs.scale : Real) * + (dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)) := by + simp [mul_sub] + _ = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simp [hdot_sub] + have hgap_le' : + (scoreGapLo q k : Real) ≤ + scoresReal q (inputs.prev q) - scoresReal q k := by + simpa [hscore_diff] using hgap_le + have hgap_add := + add_le_add_right hgap_le' (scoresReal q k) + have hgap_add' : + scoresReal q k + (scoreGapLo q k : Real) ≤ + scoresReal q (inputs.prev q) := by + have hcancel : + scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) = + scoresReal q (inputs.prev q) := by + calc + scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) + = + scoresReal q k + scoresReal q (inputs.prev q) - + scoresReal q k := by + symm + exact add_sub_assoc (scoresReal q k) + (scoresReal q (inputs.prev q)) (scoresReal q k) + _ = scoresReal q (inputs.prev q) := by + simp [add_sub_cancel_left] + calc + scoresReal q k + (scoreGapLo q k : Real) + ≤ scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) := + hgap_add + _ = scoresReal q (inputs.prev q) := hcancel + exact hgap_add' + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + have hscore_margin_real_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + have hmargin_le : marginAt q ≤ scoreGapLo q k := + hmarginAt_le q hq k hk + have hmargin_le_real : (marginAt q : Real) ≤ (scoreGapLo q k : Real) 
:= + ratToReal_le_of_le hmargin_le + have hscore_gap := hscore_gap_real_at q hq k hk + have hstep := add_le_add_right hmargin_le_real (scoresReal q k) + have hstep' : + scoresReal q k + (marginAt q : Real) ≤ + scoresReal q k + (scoreGapLo q k : Real) := by + exact hstep + exact hstep'.trans hscore_gap + have hscore_margin_real : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + have hmargin_le : margin ≤ marginAt q := by + have hmem : q ∈ inputs.active := hq + have hnonempty : inputs.active.Nonempty := hactive + have hle := + (Finset.inf'_le_iff (s := inputs.active) (H := hnonempty) + (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ + simpa [margin, hnonempty] using hle + have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := + ratToReal_le_of_le hmargin_le + have hscore := hscore_margin_real_at q hq k hk + have hscore' : + (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by + simpa [add_comm] using hscore + have hstep := add_le_add_right hmargin_le_real (scoresReal q k) + have hstep' : + scoresReal q k + (margin : Real) ≤ (marginAt q : Real) + scoresReal q k := by + calc + scoresReal q k + (margin : Real) ≤ scoresReal q k + (marginAt q : Real) := hstep + _ = (marginAt q : Real) + scoresReal q k := by + simp [add_comm] + exact hstep'.trans hscore' + have hweightBoundAt : + ∀ q k, k ≠ inputs.prev q → + weightBoundAt q k = + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k) := by + intro q k hk + simpa [weightBoundAt, weightBoundAtBase, hk] using + (Bounds.cacheBound2_apply (f := weightBoundAtBase) q k) + have hepsAt : + ∀ q, epsAt q = + min (1 : Rat) + ((otherKeys q).sum (fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k))) := by + intro q + have hsum : + (otherKeys q).sum (fun k => weightBoundAt q k) = + (otherKeys q).sum (fun k => + if scoreGapLo q k < 0 then + (1 : 
Rat) + else + ratDivUp 1 (1 + scoreGapLo q k)) := by + refine Finset.sum_congr rfl ?_ + intro k hk + have hk' : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 + simp [hweightBoundAt q k hk'] + simpa [epsAt, epsAtBase, hsum] using + (Bounds.cacheBoundThunk_apply (f := epsAtBase) q) + have oneHot_bounds_at : + ∀ q, q ∈ inputs.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) inputs.prev weights := by + intro q hq + exact + Sound.oneHot_bounds_at_of_scoreGapLo + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (scoreGapLo := scoreGapLo) + (epsAt := epsAt) + (hepsAt := hepsAt) + (hscore_gap_real_at := hscore_gap_real_at) + q hq + have weight_bounds_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + weights q k ≤ (weightBoundAt q k : Real) := by + intro q hq k hk + exact + Sound.weight_bound_at_of_scoreGapLo + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (scoreGapLo := scoreGapLo) + (weightBoundAt := weightBoundAt) + (hweightBoundAt := hweightBoundAt) + (hscore_gap_real_at := hscore_gap_real_at) + q hq k hk + have hepsAt_le_eps : + ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by + intro q hq + have hle : + epsAt q ≤ inputs.active.sup' hactive epsAt := by + exact + (Finset.le_sup'_iff (s := inputs.active) (H := hactive) + (f := epsAt) (a := epsAt q)).2 ⟨q, hq, le_rfl⟩ + simpa [eps, hactive] using hle + have hepsAt_le_eps_real : + ∀ q, q ∈ inputs.active → (epsAt q : Real) ≤ (eps : Real) := by + intro q hq + exact ratToReal_le_of_le (hepsAt_le_eps q hq) + have hsoftmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) + (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by + classical + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + exact hscore_margin_real q hq k hk + · intro q _ k + simpa [weights] using + (Circuit.softmax_nonneg 
(scores := scoresReal q) k) + · intro q _ + simpa [weights] using + (Circuit.softmax_sum_one (scores := scoresReal q)) + · intro q hq + have honehot := oneHot_bounds_at q hq + have hprev := honehot.prev_large q rfl + have hle : + weights q (inputs.prev q) + (epsAt q : Real) ≤ + weights q (inputs.prev q) + (eps : Real) := by + simpa [add_comm] using + (add_le_add_right (hepsAt_le_eps_real q hq) (weights q (inputs.prev q))) + exact hprev.trans hle + · intro q hq k hk + have honehot := oneHot_bounds_at q hq + have hother := honehot.other_le q rfl k hk + exact hother.trans (hepsAt_le_eps_real q hq) + have hdirHead : + dirHead = fun d => (dirHeadVecOfInputs inputs).get d := by + simp [dirHead, dirHeadVec] + have hwvDir : + ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := by + intro j + simp [wvDir, Bounds.cacheBoundTask_apply] + have hbDir : + bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d) := by + rfl + have hdir_wv : + ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := + wvDir_real_eq_sum inputs dirHead wvDir hwvDir + have hdir_bv : + (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := + bDir_real_eq_sum inputs dirHead bDir hbDir + have hvals_eq : + ∀ k, + valsRealOfInputs inputs k = + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + (bDir : Real) := + valsReal_eq_of_dir inputs dirHead hdirHead wvDir bDir hdir_wv hdir_bv + have hvals_bounds_at : + ∀ k, + (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k ∧ + valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by + intro k + have hln := hln_bounds k + have hlo : ∀ j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j := fun j => + (hln j).1 + have hhi : ∀ j, lnRealOfInputs inputs k j ≤ (lnHi k j : Real) := fun j => + (hln j).2 + have hlow' : + (Bounds.dotIntervalLower (fun j => wvDir j) (lnLo k) (lnHi k) : Real) + + (bDir : Real) ≤ + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + (bDir : Real) := by + 
simpa using + (Bounds.dotIntervalLower_le_dotProduct_real_add + (v := fun j => wvDir j) + (lo := lnLo k) (hi := lnHi k) + (x := lnRealOfInputs inputs k) (b := (bDir : Real)) hlo hhi) + have hhigh' : + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + + (bDir : Real) ≤ + (Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo k) (lnHi k) : Real) + + (bDir : Real) := by + simpa using + (Bounds.dotProduct_le_dotIntervalUpper_real_add + (v := fun j => wvDir j) + (lo := lnLo k) (hi := lnHi k) + (x := lnRealOfInputs inputs k) (b := (bDir : Real)) hlo hhi) + have hlow : + (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by + simpa [valCert, valsLo, hvals_eq k, ratToReal_add, add_comm, add_left_comm, + add_assoc] using hlow' + have hhigh : + valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by + simpa [valCert, valsHi, hvals_eq k, ratToReal_add, add_comm, add_left_comm, + add_assoc] using hhigh' + exact ⟨hlow, hhigh⟩ + have hvals_bounds : + ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by + refine + { lo_le_hi := ?_ + lo_le_valsLo := ?_ + vals_bounds := ?_ + valsHi_le_hi := ?_ } + · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ + have hmem0 : k0 ∈ univ := hk0 + have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by + have hloRat : valCert.lo ≤ valCert.valsLo k0 := by + change lo ≤ valsLo k0 + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k0)).2 ⟨k0, hmem0, le_rfl⟩ + exact ratToReal_le_of_le hloRat + have hvals : + (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ + valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by + exact hvals_bounds_at k0 + have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by + have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by + change valsHi k0 ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k0)).2 ⟨k0, ⟨hmem0, le_rfl⟩⟩ + exact 
ratToReal_le_of_le hhiRat + have hreal : + (valCert.lo : Real) ≤ (valCert.hi : Real) := + le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) + exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal + · intro k + have hloRat : valCert.lo ≤ valCert.valsLo k := by + change lo ≤ valsLo k + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k)).2 ⟨k, by simp [univ], le_rfl⟩ + exact ratToReal_le_of_le hloRat + · intro k + exact hvals_bounds_at k + · intro k + have hhiRat : valCert.valsHi k ≤ valCert.hi := by + change valsHi k ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k)).2 ⟨k, by simp [univ], le_rfl⟩ + exact ratToReal_le_of_le hhiRat + exact + { softmax_bounds := hsoftmax_bounds + oneHot_bounds_at := oneHot_bounds_at + weight_bounds_at := weight_bounds_at + value_bounds := hvals_bounds } + · have : False := by + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_not_active + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' + exact this.elim + · have : False := by + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt + (cfg := cfg) (inputs := inputs) hEps hSqrt + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' + exact this.elim + · have : False := by + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps + (cfg := cfg) (inputs := inputs) hEps + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' + exact this.elim + +/-- Soundness for `buildInductionCertFromHeadCore?`. 
-/ +theorem buildInductionCertFromHeadCore?_sound + [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) + (hcore : buildInductionCertFromHeadCore? inputs = some c) : + InductionHeadCertSound inputs c := by + simpa [buildInductionCertFromHeadCore?] using + (buildInductionCertFromHeadCoreWith?_sound + (cfg := defaultInductionHeadSplitConfig) inputs c + (by + simpa [buildInductionCertFromHeadCore?] using hcore)) +end Sound +end Nfp diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean index 037ed18..5e2f30d 100644 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -1,1238 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Range -import Mathlib.Data.Vector.Defs -import Nfp.Model.InductionHead -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Linear.FinFold +import Nfp.Sound.Induction.HeadBounds.Basic /-! Helper bounds for head-induction certificate construction. - -These are pure precomputations that are useful for profiling and staging. -/ - -namespace Nfp - -namespace Sound - -open Nfp.Sound.Bounds - -variable {seq : Nat} - -private def taskMin (t1 t2 : Task Rat) : Task Rat := - Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) - -private def taskMax (t1 t2 : Task Rat) : Task Rat := - Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) - -/-! Small lemmas for extracting `get` from task folds. -/ - -/-- `taskMin` exposes its `get` as a plain `min` on task results. -/ -private theorem taskMin_get (t1 t2 : Task Rat) : - (taskMin t1 t2).get = min t1.get t2.get := by - rfl - -/-- `taskMax` exposes its `get` as a plain `max` on task results. 
-/ -private theorem taskMax_get (t1 t2 : Task Rat) : - (taskMax t1 t2).get = max t1.get t2.get := by - rfl - -/-- Pull `get` through a `List.foldl` when the step is `get`-compatible. -/ -private theorem foldl_task_get_eq {α β : Type} (step : Task β → α → Task β) (step' : β → α → β) - (hstep : ∀ acc a, (step acc a).get = step' acc.get a) : - ∀ (xs : List α) (acc : Task β), - (List.foldl step acc xs).get = List.foldl step' acc.get xs - | [], acc => rfl - | x :: xs, acc => by - simpa [List.foldl, hstep] using foldl_task_get_eq step step' hstep xs (step acc x) - -/-- `List.foldl` over `taskMin` exposes a fold over `min` on task results. -/ -private theorem foldl_taskMin_get_eq {α : Type} (f : α → Task Rat) (xs : List α) - (init : Task Rat) : - (List.foldl (fun acc a => taskMin acc (f a)) init xs).get = - List.foldl (fun acc a => min acc (f a).get) init.get xs := by - refine - foldl_task_get_eq - (step := fun acc a => taskMin acc (f a)) - (step' := fun acc a => min acc (f a).get) - (hstep := ?_) - xs init - intro acc a - simp [taskMin_get] - -/-- `List.foldl` over `taskMax` exposes a fold over `max` on task results. -/ -private theorem foldl_taskMax_get_eq {α : Type} (f : α → Task Rat) (xs : List α) - (init : Task Rat) : - (List.foldl (fun acc a => taskMax acc (f a)) init xs).get = - List.foldl (fun acc a => max acc (f a).get) init.get xs := by - refine - foldl_task_get_eq - (step := fun acc a => taskMax acc (f a)) - (step' := fun acc a => max acc (f a).get) - (hstep := ?_) - xs init - intro acc a - simp [taskMax_get] - -/-- `Array.get?` + `Option.getD` followed by `Task.get` agrees with `getD` on values. -/ -private theorem task_getD_ofFn {n : Nat} (f : Fin n → Rat) (i : Nat) : - ((Array.ofFn fun c => ({ get := f c } : Task Rat))[i]?.getD { get := (0 : Rat) }).get = - (Array.ofFn f)[i]?.getD (0 : Rat) := by - by_cases h : i < n - · simp [h, Array.size_ofFn] - · simp [h, Array.size_ofFn] - -/-! Helpers for reducing cached arrays without extra allocation. 
-/ - -/-- Reduce an array of rational bounds to its minimum (defaulting to `0` on empty arrays). -/ -private def reduceMinArray (arr : Array Rat) : Rat := - let init := arr.getD 0 (0 : Rat) - arr.foldl (fun acc x => min acc x) init - -/-- Reduce an array of rational bounds to its maximum (defaulting to `0` on empty arrays). -/ -private def reduceMaxArray (arr : Array Rat) : Rat := - let init := arr.getD 0 (0 : Rat) - arr.foldl (fun acc x => max acc x) init - -/-- Reduce a `Fin seq`-indexed function using the chunked sequential algorithm. -/ -private def reduceFnChunked [NeZero seq] (vals : Fin seq → Rat) - (combine : Rat → Rat → Rat) : Rat := - let n := seq - if n = 0 then - (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let chunkVals : Array Rat := - Array.ofFn (fun c : Fin chunks => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init) - let init := chunkVals.getD 0 (0 : Rat) - let rest := (List.range (chunkVals.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init - -/-- Unfold `reduceFnChunked` to its chunked sequential definition. 
-/ -theorem reduceFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) - (combine : Rat → Rat → Rat) : - reduceFnChunked (seq := seq) vals combine = - let n := seq - if n = 0 then - (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let chunkVals : Array Rat := - Array.ofFn (fun c : Fin chunks => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init) - let init := chunkVals.getD 0 (0 : Rat) - let rest := (List.range (chunkVals.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init := rfl - -/-- Reduce a `Fin seq`-indexed function in parallel using chunked tasks. 
-/ -private def reduceFnTask [NeZero seq] (vals : Fin seq → Rat) - (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : Task Rat := - let n := seq - if n = 0 then - Task.pure (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - Task.spawn (fun _ => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init)) - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init - -/-- Unfold `reduceFnTask` to its chunked-task definition. 
-/ -theorem reduceFnTask_spec [NeZero seq] (vals : Fin seq → Rat) - (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : - reduceFnTask (seq := seq) vals combine combineTask = - let n := seq - if n = 0 then - Task.pure (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - Task.spawn (fun _ => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init)) - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init := rfl - -private def reduceMinFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := - reduceFnTask vals min taskMin - -private def reduceMaxFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := - reduceFnTask vals max taskMax - -/-- Chunked sequential minimum over a `Fin seq`-indexed function. -/ -private def reduceMinFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := - reduceFnChunked vals min - -/-- Unfold `reduceMinFnChunked` to `reduceFnChunked` with `min`. -/ -theorem reduceMinFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : - reduceMinFnChunked vals = reduceFnChunked vals min := rfl - -/-- Chunked sequential maximum over a `Fin seq`-indexed function. 
-/ -private def reduceMaxFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := - reduceFnChunked vals max - -/-- Unfold `reduceMaxFnChunked` to `reduceFnChunked` with `max`. -/ -theorem reduceMaxFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : - reduceMaxFnChunked vals = reduceFnChunked vals max := rfl - -/-- The chunked parallel min-reduction task returns the sequential chunked result. -/ -theorem reduceMinFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : - (reduceMinFnTask vals).get = reduceMinFnChunked vals := by - classical - have hseq : seq ≠ 0 := NeZero.ne (n := seq) - simp [reduceMinFnTask, reduceMinFnChunked, reduceFnTask, reduceFnChunked, hseq, - Task.spawn, foldl_taskMin_get_eq, task_getD_ofFn] - -/-- The chunked parallel max-reduction task returns the sequential chunked result. -/ -theorem reduceMaxFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : - (reduceMaxFnTask vals).get = reduceMaxFnChunked vals := by - classical - have hseq : seq ≠ 0 := NeZero.ne (n := seq) - simp [reduceMaxFnTask, reduceMaxFnChunked, reduceFnTask, reduceFnChunked, hseq, - Task.spawn, foldl_taskMax_get_eq, task_getD_ofFn] - -/-- Cached direction head for head inputs. -/ -private def dirHeadVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := - Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) - -/-- LayerNorm bounds used by the induction-head builder. 
-/ -def headLnBounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := - Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - -theorem headLnBounds_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - headLnBounds inputs = - Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) := rfl - -/-- Q/K/V bounds used by the induction-head builder. -/ -structure HeadQKVBounds (seq dModel dHead : Nat) where - /-- Q lower bounds. -/ - qLo : Fin seq → Fin dHead → Rat - /-- Q upper bounds. -/ - qHi : Fin seq → Fin dHead → Rat - /-- K lower bounds. -/ - kLo : Fin seq → Fin dHead → Rat - /-- K upper bounds. -/ - kHi : Fin seq → Fin dHead → Rat - /-- V lower bounds. -/ - vLo : Fin seq → Fin dHead → Rat - /-- V upper bounds. -/ - vHi : Fin seq → Fin dHead → Rat - /-- Q absolute bounds. -/ - qAbs : Fin seq → Fin dHead → Rat - /-- K absolute bounds. -/ - kAbs : Fin seq → Fin dHead → Rat - -/-- Compute Q/K/V bounds from LayerNorm bounds. 
-/ -def headQKVBounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (lnLo lnHi : Fin seq → Fin dModel → Rat) : - HeadQKVBounds seq dModel dHead := - let qLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let qHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let kLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let kHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let vLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let vHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let qAbs := - Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) - let kAbs := - Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) - { qLo := qLo - qHi := qHi - kLo := kLo - kHi := kHi - vLo := vLo - vHi := vHi - qAbs := qAbs - kAbs := kAbs } - -theorem headQKVBounds_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (lnLo lnHi : Fin seq → Fin dModel → Rat) : - headQKVBounds inputs lnLo lnHi = - let qLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let qHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let kLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let kHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo 
q) (lnHi q) + - inputs.bk d) - let vLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let vHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let qAbs := - Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) - let kAbs := - Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) - { qLo := qLo - qHi := qHi - kLo := kLo - kHi := kHi - vLo := vLo - vHi := vHi - qAbs := qAbs - kAbs := kAbs } := rfl - -/-- Score and margin bounds used by the induction-head builder. -/ -structure HeadScoreBounds (seq dModel dHead : Nat) where - /-- Absolute dot-product bound. -/ - dotAbs : Fin seq → Fin seq → Rat - /-- Base score absolute bound. -/ - scoreBaseAbs : Fin seq → Fin seq → Rat - /-- Score absolute bound with causal masking. -/ - scoreAbs : Fin seq → Fin seq → Rat - /-- Score lower bound. -/ - scoreLo : Fin seq → Fin seq → Rat - /-- Score upper bound. -/ - scoreHi : Fin seq → Fin seq → Rat - /-- Margin per query. -/ - marginAt : Fin seq → Rat - /-- Epsilon per query. -/ - epsAt : Fin seq → Rat - /-- Global margin. -/ - margin : Rat - /-- Global epsilon. -/ - eps : Rat - -/-- Compute score and margin bounds from cached score lower/upper bounds. 
-/ -def headScoreBoundsFromCaches [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) - (scoreLo scoreHi : Fin seq → Fin seq → Rat) : - HeadScoreBounds seq dModel dHead := - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreAbs : Fin seq → Fin seq → Rat := fun q k => - if masked q k then |inputs.maskValue| else scoreBaseAbs q k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let maskedKeys : Fin seq → Finset (Fin seq) := fun q => - if inputs.maskCausal = true then - (otherKeys q).filter (fun k => q < k) - else - (∅ : Finset (Fin seq)) - let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => - (otherKeys q) \ (maskedKeys q) - let maskedGap : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - inputs.maskValue - let marginTasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if q ∈ inputs.active then - let other := unmaskedKeys q - let masked := maskedKeys q - if hunmasked : other.Nonempty then - let unmaskedMin := other.inf' hunmasked (fun k => - scoreLo q (inputs.prev q) - scoreHi q k) - if hmasked : masked.Nonempty then - min unmaskedMin (maskedGap q) - else - unmaskedMin - else - if hmasked : masked.Nonempty then - maskedGap q - else - (0 : Rat) - else - (0 : Rat))) - let marginAt : Fin seq → Rat := fun q => - (marginTasks[q.1]'(by - simp [marginTasks, q.isLt])).get - let epsTasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - (marginTasks[q.1]'(by - simp [marginTasks, q.isLt])).map (fun m => - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m))) - let epsAt : Fin seq → Rat := fun q => - (epsTasks[q.1]'(by - simp [epsTasks, q.isLt])).get - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - 
else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - { dotAbs := dotAbs - scoreBaseAbs := scoreBaseAbs - scoreAbs := scoreAbs - scoreLo := scoreLo - scoreHi := scoreHi - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } - -/-- Compute score and margin bounds from dot-product absolute bounds. -/ -def headScoreBoundsFromDotAbs [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) : HeadScoreBounds seq dModel dHead := - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scaleAbs : Rat := |inputs.scale| - let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else -base - let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else base - let marginAtRaw : Fin seq → Rat := fun q => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - if q ∈ inputs.active then - let rowArr := row.1 - let prev := inputs.prev q - let dotAbsPrev := rowArr.getD prev.1 0 - if masked q prev then - let scoreLoPrev := inputs.maskValue - let scoreHiAt : Fin seq → Rat := fun k => - if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let v := scoreLoPrev - 
scoreHiAt k - match acc.1 with - | none => (some v, acc.2) - | some cur => (some (min cur v), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min unmaskedMin maskedGap - | some unmaskedMin, false => unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - let scoreLoPrev := -(scaleAbs * dotAbsPrev) - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let raw := -(dotAbsPrev + rowArr.getD k.1 0) - match acc.1 with - | none => (some raw, acc.2) - | some cur => (some (min cur raw), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap - | some unmaskedMin, false => scaleAbs * unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - (0 : Rat) - let marginAtCached := Bounds.cacheBoundThunk marginAtRaw - let marginAt : Fin seq → Rat := fun q => - marginAtCached q - let epsAtRaw : Fin seq → Rat := fun q => - let m := marginAt q - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m) - let epsAtCached := Bounds.cacheBoundThunk epsAtRaw - let epsAt : Fin seq → Rat := fun q => - epsAtCached q - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - { dotAbs := dotAbs - scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k - scoreAbs := fun q k => - if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k - scoreLo := scoreLoCached - scoreHi := scoreHiCached - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } - -theorem headScoreBoundsFromDotAbs_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq 
dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) : - headScoreBoundsFromDotAbs inputs dotAbs = - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scaleAbs : Rat := |inputs.scale| - let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else -base - let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else base - let marginAtRaw : Fin seq → Rat := fun q => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - if q ∈ inputs.active then - let rowArr := row.1 - let prev := inputs.prev q - let dotAbsPrev := rowArr.getD prev.1 0 - if masked q prev then - let scoreLoPrev := inputs.maskValue - let scoreHiAt : Fin seq → Rat := fun k => - if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let v := scoreLoPrev - scoreHiAt k - match acc.1 with - | none => (some v, acc.2) - | some cur => (some (min cur v), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min unmaskedMin maskedGap - | some unmaskedMin, false => unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - let scoreLoPrev := -(scaleAbs * dotAbsPrev) - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → 
Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let raw := -(dotAbsPrev + rowArr.getD k.1 0) - match acc.1 with - | none => (some raw, acc.2) - | some cur => (some (min cur raw), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap - | some unmaskedMin, false => scaleAbs * unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - (0 : Rat) - let marginAtCached := Bounds.cacheBoundThunk marginAtRaw - let marginAt : Fin seq → Rat := fun q => - marginAtCached q - let epsAtRaw : Fin seq → Rat := fun q => - let m := marginAt q - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m) - let epsAtCached := Bounds.cacheBoundThunk epsAtRaw - let epsAt : Fin seq → Rat := fun q => - epsAtCached q - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - { dotAbs := dotAbs - scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k - scoreAbs := fun q k => - if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k - scoreLo := scoreLoCached - scoreHi := scoreHiCached - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } := rfl - -/-- Compute score and margin bounds from Q/K interval bounds. 
-/ -def headScoreBoundsFromIntervals [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : - HeadScoreBounds seq dModel dHead := - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => - dotIntervalLowerUpper2CommonDen (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi - -theorem headScoreBoundsFromIntervals_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : - headScoreBoundsFromIntervals inputs qLo qHi kLo kHi = - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => - 
dotIntervalLowerUpper2CommonDen (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi := rfl - -/-- Compute score and margin bounds from Q/K absolute bounds. 
-/ -def headScoreBounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Rat) : - HeadScoreBounds seq dModel dHead := - headScoreBoundsFromDotAbs inputs (fun q k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) - -theorem headScoreBounds_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Rat) : - headScoreBounds inputs qAbs kAbs = - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotAbs : Fin seq → Fin seq → Rat := fun q k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d) - let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scaleAbs : Rat := |inputs.scale| - let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else -base - let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else base - let marginAtRaw : Fin seq → Rat := fun q => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - if q ∈ inputs.active then - let rowArr := row.1 - let prev := inputs.prev q - let dotAbsPrev := rowArr.getD prev.1 0 - if masked q prev then - let scoreLoPrev := inputs.maskValue - let scoreHiAt : Fin seq → Rat := fun k => - if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k 
then - (acc.1, true) - else - let v := scoreLoPrev - scoreHiAt k - match acc.1 with - | none => (some v, acc.2) - | some cur => (some (min cur v), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min unmaskedMin maskedGap - | some unmaskedMin, false => unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - let scoreLoPrev := -(scaleAbs * dotAbsPrev) - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let raw := -(dotAbsPrev + rowArr.getD k.1 0) - match acc.1 with - | none => (some raw, acc.2) - | some cur => (some (min cur raw), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap - | some unmaskedMin, false => scaleAbs * unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - (0 : Rat) - let marginAtCached := Bounds.cacheBoundThunk marginAtRaw - let marginAt : Fin seq → Rat := fun q => - marginAtCached q - let epsAtRaw : Fin seq → Rat := fun q => - let m := marginAt q - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m) - let epsAtCached := Bounds.cacheBoundThunk epsAtRaw - let epsAt : Fin seq → Rat := fun q => - epsAtCached q - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - { dotAbs := dotAbs - scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k - scoreAbs := fun q k => - if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k - scoreLo := scoreLoCached - scoreHi := scoreHiCached - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } := rfl - -/-- Value bounds used by the induction-head builder. 
-/ -structure HeadValueBounds (seq dModel dHead : Nat) where - /-- Value lower bounds. -/ - valsLo : Fin seq → Rat - /-- Value upper bounds. -/ - valsHi : Fin seq → Rat - /-- Global value lower bound. -/ - lo : Rat - /-- Global value upper bound. -/ - hi : Rat - -/-- Cached direction vector for value bounds. -/ -def headValueDirHead {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin dHead → Rat := - let dirHeadVec := dirHeadVecOfInputs inputs - fun d => dirHeadVec.get d - -theorem headValueDirHead_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - headValueDirHead inputs = - let dirHeadVec := dirHeadVecOfInputs inputs - fun d => dirHeadVec.get d := rfl - -/-- Cached lower value bounds from V intervals. -/ -def headValueValsLoArray {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - -/-- Unfold `headValueValsLoArray` to its `Array.ofFn` definition. -/ -theorem headValueValsLoArray_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoArray inputs vLo vHi = - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl - -/-- Cached lower value bounds from V intervals. 
-/ -def headValueValsLo {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := - let arr := headValueValsLoArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) - -theorem headValueValsLo_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLo inputs vLo vHi = - let arr := headValueValsLoArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) := rfl - -/-- Cached lower value bounds from V intervals using a common-denominator sum. -/ -def headValueValsLoCommonDenArray {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - headValueValsLoArray inputs vLo vHi - -/-- Unfold `headValueValsLoCommonDenArray` to its `Array.ofFn` definition. -/ -theorem headValueValsLoCommonDenArray_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoCommonDenArray inputs vLo vHi = - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl - -/-- Cached lower value bounds from V intervals using a common-denominator sum. -/ -def headValueValsLoCommonDen {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := - let arr := headValueValsLoCommonDenArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) - -theorem headValueValsLoCommonDen_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoCommonDen inputs vLo vHi = - let arr := headValueValsLoCommonDenArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) := rfl - -/-- Common-denominator lower bounds agree with cached rational bounds pointwise. 
-/ -theorem headValueValsLoCommonDenArray_eq {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoCommonDenArray inputs vLo vHi = headValueValsLoArray inputs vLo vHi := by - rfl - -theorem headValueValsLoCommonDen_eq {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoCommonDen inputs vLo vHi = headValueValsLo inputs vLo vHi := by - funext k - simp [headValueValsLoCommonDen, headValueValsLo, headValueValsLoCommonDenArray_eq] - -/-- Cached upper value bounds from V intervals. -/ -def headValueValsHiArray {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - -/-- Unfold `headValueValsHiArray` to its `Array.ofFn` definition. -/ -theorem headValueValsHiArray_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiArray inputs vLo vHi = - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl - -/-- Cached upper value bounds from V intervals. 
-/ -def headValueValsHi {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := - let arr := headValueValsHiArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) - -theorem headValueValsHi_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHi inputs vLo vHi = - let arr := headValueValsHiArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) := rfl - -/-- Cached upper value bounds from V intervals using a common-denominator sum. -/ -def headValueValsHiCommonDenArray {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - headValueValsHiArray inputs vLo vHi - -/-- Unfold `headValueValsHiCommonDenArray` to its `Array.ofFn` definition. -/ -theorem headValueValsHiCommonDenArray_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiCommonDenArray inputs vLo vHi = - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl - -/-- Cached upper value bounds from V intervals using a common-denominator sum. -/ -def headValueValsHiCommonDen {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := - let arr := headValueValsHiCommonDenArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) - -theorem headValueValsHiCommonDen_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiCommonDen inputs vLo vHi = - let arr := headValueValsHiCommonDenArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) := rfl - -/-- Common-denominator upper bounds agree with cached rational bounds pointwise. 
-/ -theorem headValueValsHiCommonDenArray_eq {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiCommonDenArray inputs vLo vHi = headValueValsHiArray inputs vLo vHi := by - rfl - -theorem headValueValsHiCommonDen_eq {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiCommonDen inputs vLo vHi = headValueValsHi inputs vLo vHi := by - funext k - simp [headValueValsHiCommonDen, headValueValsHi, headValueValsHiCommonDenArray_eq] - -/-- Global lower value bound from an array of per-key values. -/ -def headValueLoArray (valsLo : Array Rat) : Rat := - reduceMinArray valsLo - -/-- Unfold `headValueLoArray` to its reduction helper. -/ -theorem headValueLoArray_spec (valsLo : Array Rat) : - headValueLoArray valsLo = reduceMinArray valsLo := rfl - -/-- Global lower value bound from cached per-key values. -/ -def headValueLo [NeZero seq] (valsLo : Fin seq → Rat) : Rat := - headValueLoArray (Array.ofFn valsLo) - -theorem headValueLo_spec [NeZero seq] (valsLo : Fin seq → Rat) : - headValueLo valsLo = headValueLoArray (Array.ofFn valsLo) := rfl - -/-- Task wrapper for `headValueLo`. -/ -def headValueLoTask [NeZero seq] (valsLo : Fin seq → Rat) : Task Rat := - reduceMinFnTask valsLo - -theorem headValueLoTask_spec [NeZero seq] (valsLo : Fin seq → Rat) : - headValueLoTask valsLo = reduceMinFnTask valsLo := rfl - -/-- Chunked task reduction agrees with the sequential chunked value bound. -/ -theorem headValueLoTask_get_eq [NeZero seq] (valsLo : Fin seq → Rat) : - (headValueLoTask valsLo).get = reduceMinFnChunked valsLo := by - simp [headValueLoTask_spec, reduceMinFnTask_get_eq] - -/-- Global upper value bound from an array of per-key values. -/ -def headValueHiArray (valsHi : Array Rat) : Rat := - reduceMaxArray valsHi - -/-- Unfold `headValueHiArray` to its reduction helper. 
-/ -theorem headValueHiArray_spec (valsHi : Array Rat) : - headValueHiArray valsHi = reduceMaxArray valsHi := rfl - -/-- Global upper value bound from cached per-key values. -/ -def headValueHi [NeZero seq] (valsHi : Fin seq → Rat) : Rat := - headValueHiArray (Array.ofFn valsHi) - -theorem headValueHi_spec [NeZero seq] (valsHi : Fin seq → Rat) : - headValueHi valsHi = headValueHiArray (Array.ofFn valsHi) := rfl - -/-- Task wrapper for `headValueHi`. -/ -def headValueHiTask [NeZero seq] (valsHi : Fin seq → Rat) : Task Rat := - reduceMaxFnTask valsHi - -theorem headValueHiTask_spec [NeZero seq] (valsHi : Fin seq → Rat) : - headValueHiTask valsHi = reduceMaxFnTask valsHi := rfl - -/-- Chunked task reduction agrees with the sequential chunked value bound. -/ -theorem headValueHiTask_get_eq [NeZero seq] (valsHi : Fin seq → Rat) : - (headValueHiTask valsHi).get = reduceMaxFnChunked valsHi := by - simp [headValueHiTask_spec, reduceMaxFnTask_get_eq] - -/-- Build `HeadValueBounds` from precomputed arrays. -/ -private def headValueBoundsOfArrays {seq dModel dHead : Nat} - (valsLoArr valsHiArr : Array Rat) : HeadValueBounds seq dModel dHead := - let valsLo : Fin seq → Rat := fun k => valsLoArr.getD k.1 (0 : Rat) - let valsHi : Fin seq → Rat := fun k => valsHiArr.getD k.1 (0 : Rat) - let lo := headValueLoArray valsLoArr - let hi := headValueHiArray valsHiArr - { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } - -/-- Build a cached bounds array in parallel from a per-key computation. 
-/ -private def buildBoundArrayTask [NeZero seq] (f : Fin seq → Rat) : Task (Array Rat) := - let n := seq - let chunkSize : Nat := 64 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let chunkTasks : List (Task (Array Rat)) := - (List.range chunks).map (fun c => - Task.spawn (fun _ => - let start := c * chunkSize - let stop := Nat.min n (start + chunkSize) - let vals := - (List.range (stop - start)).map (fun i => - f (idxs.getD (start + i) defaultIdx)) - vals.toArray)) - Task.mapList (fun xs => xs.foldl (fun acc arr => acc ++ arr) #[]) chunkTasks - -/-- Compute value bounds from V interval bounds. -/ -def headValueBounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - HeadValueBounds seq dModel dHead := - let valsLoArr := headValueValsLoArray inputs vLo vHi - let valsHiArr := headValueValsHiArray inputs vLo vHi - headValueBoundsOfArrays valsLoArr valsHiArr - -theorem headValueBounds_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBounds inputs vLo vHi = - let valsLoArr := headValueValsLoArray inputs vLo vHi - let valsHiArr := headValueValsHiArray inputs vLo vHi - headValueBoundsOfArrays valsLoArr valsHiArr := rfl - -/-- Compute value bounds from V interval bounds in parallel. 
-/ -def headValueBoundsTask [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - Task (HeadValueBounds seq dModel dHead) := - let dirHead := headValueDirHead inputs - let valsLoTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - let valsHiTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - Task.bind valsLoTask (fun valsLoArr => - Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) - -/-- Unfold `headValueBoundsTask` to its task graph. -/ -theorem headValueBoundsTask_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBoundsTask inputs vLo vHi = - let dirHead := headValueDirHead inputs - let valsLoTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - let valsHiTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - Task.bind valsLoTask (fun valsLoArr => - Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) := rfl - -/-- Compute value bounds from V interval bounds using a common-denominator sum. 
-/ -def headValueBoundsCommonDen [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - HeadValueBounds seq dModel dHead := - let valsLoArr := headValueValsLoCommonDenArray inputs vLo vHi - let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi - headValueBoundsOfArrays valsLoArr valsHiArr - -theorem headValueBoundsCommonDen_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBoundsCommonDen inputs vLo vHi = - let valsLoArr := headValueValsLoCommonDenArray inputs vLo vHi - let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi - headValueBoundsOfArrays valsLoArr valsHiArr := rfl - -/-- Compute value bounds from V intervals using a common-denominator sum in parallel. -/ -def headValueBoundsCommonDenTask [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - Task (HeadValueBounds seq dModel dHead) := - let dirHead := headValueDirHead inputs - let valsLoTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - let valsHiTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - Task.bind valsLoTask (fun valsLoArr => - Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) - -/-- Unfold `headValueBoundsCommonDenTask` to its task graph. 
-/ -theorem headValueBoundsCommonDenTask_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBoundsCommonDenTask inputs vLo vHi = - let dirHead := headValueDirHead inputs - let valsLoTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - let valsHiTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - Task.bind valsLoTask (fun valsLoArr => - Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) := rfl - -theorem headValueBoundsCommonDen_eq [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBoundsCommonDen inputs vLo vHi = headValueBounds inputs vLo vHi := by - classical - simp [headValueBoundsCommonDen, headValueBounds, headValueValsLoCommonDenArray_eq, - headValueValsHiCommonDenArray_eq] - -end Sound - -end Nfp diff --git a/Nfp/Sound/Induction/HeadBounds/Basic.lean b/Nfp/Sound/Induction/HeadBounds/Basic.lean new file mode 100644 index 0000000..037ed18 --- /dev/null +++ b/Nfp/Sound/Induction/HeadBounds/Basic.lean @@ -0,0 +1,1238 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.Core.Basic +import Mathlib.Data.Finset.Basic +import Mathlib.Data.List.Range +import Mathlib.Data.Vector.Defs +import Nfp.Model.InductionHead +import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Linear.FinFold + +/-! +Helper bounds for head-induction certificate construction. + +These are pure precomputations that are useful for profiling and staging. 
+-/ + +namespace Nfp + +namespace Sound + +open Nfp.Sound.Bounds + +variable {seq : Nat} + +private def taskMin (t1 t2 : Task Rat) : Task Rat := + Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) + +private def taskMax (t1 t2 : Task Rat) : Task Rat := + Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) + +/-! Small lemmas for extracting `get` from task folds. -/ + +/-- `taskMin` exposes its `get` as a plain `min` on task results. -/ +private theorem taskMin_get (t1 t2 : Task Rat) : + (taskMin t1 t2).get = min t1.get t2.get := by + rfl + +/-- `taskMax` exposes its `get` as a plain `max` on task results. -/ +private theorem taskMax_get (t1 t2 : Task Rat) : + (taskMax t1 t2).get = max t1.get t2.get := by + rfl + +/-- Pull `get` through a `List.foldl` when the step is `get`-compatible. -/ +private theorem foldl_task_get_eq {α β : Type} (step : Task β → α → Task β) (step' : β → α → β) + (hstep : ∀ acc a, (step acc a).get = step' acc.get a) : + ∀ (xs : List α) (acc : Task β), + (List.foldl step acc xs).get = List.foldl step' acc.get xs + | [], acc => rfl + | x :: xs, acc => by + simpa [List.foldl, hstep] using foldl_task_get_eq step step' hstep xs (step acc x) + +/-- `List.foldl` over `taskMin` exposes a fold over `min` on task results. -/ +private theorem foldl_taskMin_get_eq {α : Type} (f : α → Task Rat) (xs : List α) + (init : Task Rat) : + (List.foldl (fun acc a => taskMin acc (f a)) init xs).get = + List.foldl (fun acc a => min acc (f a).get) init.get xs := by + refine + foldl_task_get_eq + (step := fun acc a => taskMin acc (f a)) + (step' := fun acc a => min acc (f a).get) + (hstep := ?_) + xs init + intro acc a + simp [taskMin_get] + +/-- `List.foldl` over `taskMax` exposes a fold over `max` on task results. 
-/ +private theorem foldl_taskMax_get_eq {α : Type} (f : α → Task Rat) (xs : List α) + (init : Task Rat) : + (List.foldl (fun acc a => taskMax acc (f a)) init xs).get = + List.foldl (fun acc a => max acc (f a).get) init.get xs := by + refine + foldl_task_get_eq + (step := fun acc a => taskMax acc (f a)) + (step' := fun acc a => max acc (f a).get) + (hstep := ?_) + xs init + intro acc a + simp [taskMax_get] + +/-- `Array.get?` + `Option.getD` followed by `Task.get` agrees with `getD` on values. -/ +private theorem task_getD_ofFn {n : Nat} (f : Fin n → Rat) (i : Nat) : + ((Array.ofFn fun c => ({ get := f c } : Task Rat))[i]?.getD { get := (0 : Rat) }).get = + (Array.ofFn f)[i]?.getD (0 : Rat) := by + by_cases h : i < n + · simp [h, Array.size_ofFn] + · simp [h, Array.size_ofFn] + +/-! Helpers for reducing cached arrays without extra allocation. -/ + +/-- Reduce an array of rational bounds to its minimum (defaulting to `0` on empty arrays). -/ +private def reduceMinArray (arr : Array Rat) : Rat := + let init := arr.getD 0 (0 : Rat) + arr.foldl (fun acc x => min acc x) init + +/-- Reduce an array of rational bounds to its maximum (defaulting to `0` on empty arrays). -/ +private def reduceMaxArray (arr : Array Rat) : Rat := + let init := arr.getD 0 (0 : Rat) + arr.foldl (fun acc x => max acc x) init + +/-- Reduce a `Fin seq`-indexed function using the chunked sequential algorithm. 
-/ +private def reduceFnChunked [NeZero seq] (vals : Fin seq → Rat) + (combine : Rat → Rat → Rat) : Rat := + let n := seq + if n = 0 then + (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let chunkVals : Array Rat := + Array.ofFn (fun c : Fin chunks => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init) + let init := chunkVals.getD 0 (0 : Rat) + let rest := (List.range (chunkVals.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init + +/-- Unfold `reduceFnChunked` to its chunked sequential definition. 
-/ +theorem reduceFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) + (combine : Rat → Rat → Rat) : + reduceFnChunked (seq := seq) vals combine = + let n := seq + if n = 0 then + (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let chunkVals : Array Rat := + Array.ofFn (fun c : Fin chunks => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init) + let init := chunkVals.getD 0 (0 : Rat) + let rest := (List.range (chunkVals.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init := rfl + +/-- Reduce a `Fin seq`-indexed function in parallel using chunked tasks. 
-/ +private def reduceFnTask [NeZero seq] (vals : Fin seq → Rat) + (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : Task Rat := + let n := seq + if n = 0 then + Task.pure (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + Task.spawn (fun _ => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init)) + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init + +/-- Unfold `reduceFnTask` to its chunked-task definition. 
-/ +theorem reduceFnTask_spec [NeZero seq] (vals : Fin seq → Rat) + (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : + reduceFnTask (seq := seq) vals combine combineTask = + let n := seq + if n = 0 then + Task.pure (0 : Rat) + else + let chunkSize : Nat := 256 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let defaultTask : Task Rat := Task.pure (0 : Rat) + let chunkTasks : Array (Task Rat) := + Array.ofFn (fun c : Fin chunks => + Task.spawn (fun _ => + let start := c.val * chunkSize + let stop := Nat.min n (start + chunkSize) + let init := vals (idxs.getD start defaultIdx) + if stop ≤ start + 1 then + init + else + let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) + rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init)) + let init := chunkTasks.getD 0 defaultTask + let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) + rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init := rfl + +private def reduceMinFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := + reduceFnTask vals min taskMin + +private def reduceMaxFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := + reduceFnTask vals max taskMax + +/-- Chunked sequential minimum over a `Fin seq`-indexed function. -/ +private def reduceMinFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := + reduceFnChunked vals min + +/-- Unfold `reduceMinFnChunked` to `reduceFnChunked` with `min`. -/ +theorem reduceMinFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : + reduceMinFnChunked vals = reduceFnChunked vals min := rfl + +/-- Chunked sequential maximum over a `Fin seq`-indexed function. 
-/ +private def reduceMaxFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := + reduceFnChunked vals max + +/-- Unfold `reduceMaxFnChunked` to `reduceFnChunked` with `max`. -/ +theorem reduceMaxFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : + reduceMaxFnChunked vals = reduceFnChunked vals max := rfl + +/-- The chunked parallel min-reduction task returns the sequential chunked result. -/ +theorem reduceMinFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : + (reduceMinFnTask vals).get = reduceMinFnChunked vals := by + classical + have hseq : seq ≠ 0 := NeZero.ne (n := seq) + simp [reduceMinFnTask, reduceMinFnChunked, reduceFnTask, reduceFnChunked, hseq, + Task.spawn, foldl_taskMin_get_eq, task_getD_ofFn] + +/-- The chunked parallel max-reduction task returns the sequential chunked result. -/ +theorem reduceMaxFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : + (reduceMaxFnTask vals).get = reduceMaxFnChunked vals := by + classical + have hseq : seq ≠ 0 := NeZero.ne (n := seq) + simp [reduceMaxFnTask, reduceMaxFnChunked, reduceFnTask, reduceFnChunked, hseq, + Task.spawn, foldl_taskMax_get_eq, task_getD_ofFn] + +/-- Cached direction head for head inputs. -/ +private def dirHeadVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := + Vector.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) + +/-- LayerNorm bounds used by the induction-head builder. 
-/ +def headLnBounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := + Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + +theorem headLnBounds_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + headLnBounds inputs = + Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) := rfl + +/-- Q/K/V bounds used by the induction-head builder. -/ +structure HeadQKVBounds (seq dModel dHead : Nat) where + /-- Q lower bounds. -/ + qLo : Fin seq → Fin dHead → Rat + /-- Q upper bounds. -/ + qHi : Fin seq → Fin dHead → Rat + /-- K lower bounds. -/ + kLo : Fin seq → Fin dHead → Rat + /-- K upper bounds. -/ + kHi : Fin seq → Fin dHead → Rat + /-- V lower bounds. -/ + vLo : Fin seq → Fin dHead → Rat + /-- V upper bounds. -/ + vHi : Fin seq → Fin dHead → Rat + /-- Q absolute bounds. -/ + qAbs : Fin seq → Fin dHead → Rat + /-- K absolute bounds. -/ + kAbs : Fin seq → Fin dHead → Rat + +/-- Compute Q/K/V bounds from LayerNorm bounds. 
-/ +def headQKVBounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (lnLo lnHi : Fin seq → Fin dModel → Rat) : + HeadQKVBounds seq dModel dHead := + let qLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let qHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let kLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let kHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let vLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let vHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let qAbs := + Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) + let kAbs := + Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) + { qLo := qLo + qHi := qHi + kLo := kLo + kHi := kHi + vLo := vLo + vHi := vHi + qAbs := qAbs + kAbs := kAbs } + +theorem headQKVBounds_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (lnLo lnHi : Fin seq → Fin dModel → Rat) : + headQKVBounds inputs lnLo lnHi = + let qLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let qHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + inputs.bq d) + let kLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + inputs.bk d) + let kHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo 
q) (lnHi q) + + inputs.bk d) + let vLo := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let vHi := + Bounds.cacheBound2 (fun q d => + Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + + inputs.bv d) + let qAbs := + Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) + let kAbs := + Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) + { qLo := qLo + qHi := qHi + kLo := kLo + kHi := kHi + vLo := vLo + vHi := vHi + qAbs := qAbs + kAbs := kAbs } := rfl + +/-- Score and margin bounds used by the induction-head builder. -/ +structure HeadScoreBounds (seq dModel dHead : Nat) where + /-- Absolute dot-product bound. -/ + dotAbs : Fin seq → Fin seq → Rat + /-- Base score absolute bound. -/ + scoreBaseAbs : Fin seq → Fin seq → Rat + /-- Score absolute bound with causal masking. -/ + scoreAbs : Fin seq → Fin seq → Rat + /-- Score lower bound. -/ + scoreLo : Fin seq → Fin seq → Rat + /-- Score upper bound. -/ + scoreHi : Fin seq → Fin seq → Rat + /-- Margin per query. -/ + marginAt : Fin seq → Rat + /-- Epsilon per query. -/ + epsAt : Fin seq → Rat + /-- Global margin. -/ + margin : Rat + /-- Global epsilon. -/ + eps : Rat + +/-- Compute score and margin bounds from cached score lower/upper bounds. 
-/ +def headScoreBoundsFromCaches [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dotAbs : Fin seq → Fin seq → Rat) + (scoreLo scoreHi : Fin seq → Fin seq → Rat) : + HeadScoreBounds seq dModel dHead := + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreAbs : Fin seq → Fin seq → Rat := fun q k => + if masked q k then |inputs.maskValue| else scoreBaseAbs q k + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let maskedKeys : Fin seq → Finset (Fin seq) := fun q => + if inputs.maskCausal = true then + (otherKeys q).filter (fun k => q < k) + else + (∅ : Finset (Fin seq)) + let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => + (otherKeys q) \ (maskedKeys q) + let maskedGap : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) - inputs.maskValue + let marginTasks : Array (Task Rat) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if q ∈ inputs.active then + let other := unmaskedKeys q + let masked := maskedKeys q + if hunmasked : other.Nonempty then + let unmaskedMin := other.inf' hunmasked (fun k => + scoreLo q (inputs.prev q) - scoreHi q k) + if hmasked : masked.Nonempty then + min unmaskedMin (maskedGap q) + else + unmaskedMin + else + if hmasked : masked.Nonempty then + maskedGap q + else + (0 : Rat) + else + (0 : Rat))) + let marginAt : Fin seq → Rat := fun q => + (marginTasks[q.1]'(by + simp [marginTasks, q.isLt])).get + let epsTasks : Array (Task Rat) := + Array.ofFn (fun q : Fin seq => + (marginTasks[q.1]'(by + simp [marginTasks, q.isLt])).map (fun m => + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m))) + let epsAt : Fin seq → Rat := fun q => + (epsTasks[q.1]'(by + simp [epsTasks, q.isLt])).get + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + 
else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + { dotAbs := dotAbs + scoreBaseAbs := scoreBaseAbs + scoreAbs := scoreAbs + scoreLo := scoreLo + scoreHi := scoreHi + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } + +/-- Compute score and margin bounds from dot-product absolute bounds. -/ +def headScoreBoundsFromDotAbs [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dotAbs : Fin seq → Fin seq → Rat) : HeadScoreBounds seq dModel dHead := + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) + let scaleAbs : Rat := |inputs.scale| + let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else -base + let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else base + let marginAtRaw : Fin seq → Rat := fun q => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + if q ∈ inputs.active then + let rowArr := row.1 + let prev := inputs.prev q + let dotAbsPrev := rowArr.getD prev.1 0 + if masked q prev then + let scoreLoPrev := inputs.maskValue + let scoreHiAt : Fin seq → Rat := fun k => + if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let v := scoreLoPrev - 
scoreHiAt k + match acc.1 with + | none => (some v, acc.2) + | some cur => (some (min cur v), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min unmaskedMin maskedGap + | some unmaskedMin, false => unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + let scoreLoPrev := -(scaleAbs * dotAbsPrev) + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let raw := -(dotAbsPrev + rowArr.getD k.1 0) + match acc.1 with + | none => (some raw, acc.2) + | some cur => (some (min cur raw), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap + | some unmaskedMin, false => scaleAbs * unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + (0 : Rat) + let marginAtCached := Bounds.cacheBoundThunk marginAtRaw + let marginAt : Fin seq → Rat := fun q => + marginAtCached q + let epsAtRaw : Fin seq → Rat := fun q => + let m := marginAt q + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m) + let epsAtCached := Bounds.cacheBoundThunk epsAtRaw + let epsAt : Fin seq → Rat := fun q => + epsAtCached q + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + { dotAbs := dotAbs + scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k + scoreAbs := fun q k => + if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k + scoreLo := scoreLoCached + scoreHi := scoreHiCached + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } + +theorem headScoreBoundsFromDotAbs_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq 
dModel dHead) + (dotAbs : Fin seq → Fin seq → Rat) : + headScoreBoundsFromDotAbs inputs dotAbs = + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) + let scaleAbs : Rat := |inputs.scale| + let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else -base + let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else base + let marginAtRaw : Fin seq → Rat := fun q => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + if q ∈ inputs.active then + let rowArr := row.1 + let prev := inputs.prev q + let dotAbsPrev := rowArr.getD prev.1 0 + if masked q prev then + let scoreLoPrev := inputs.maskValue + let scoreHiAt : Fin seq → Rat := fun k => + if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let v := scoreLoPrev - scoreHiAt k + match acc.1 with + | none => (some v, acc.2) + | some cur => (some (min cur v), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min unmaskedMin maskedGap + | some unmaskedMin, false => unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + let scoreLoPrev := -(scaleAbs * dotAbsPrev) + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → 
Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let raw := -(dotAbsPrev + rowArr.getD k.1 0) + match acc.1 with + | none => (some raw, acc.2) + | some cur => (some (min cur raw), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap + | some unmaskedMin, false => scaleAbs * unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + (0 : Rat) + let marginAtCached := Bounds.cacheBoundThunk marginAtRaw + let marginAt : Fin seq → Rat := fun q => + marginAtCached q + let epsAtRaw : Fin seq → Rat := fun q => + let m := marginAt q + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m) + let epsAtCached := Bounds.cacheBoundThunk epsAtRaw + let epsAt : Fin seq → Rat := fun q => + epsAtCached q + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + { dotAbs := dotAbs + scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k + scoreAbs := fun q k => + if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k + scoreLo := scoreLoCached + scoreHi := scoreHiCached + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } := rfl + +/-- Compute score and margin bounds from Q/K interval bounds. 
-/ +def headScoreBoundsFromIntervals [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : + HeadScoreBounds seq dModel dHead := + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => + dotIntervalLowerUpper2CommonDen (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi + +theorem headScoreBoundsFromIntervals_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : + headScoreBoundsFromIntervals inputs qLo qHi kLo kHi = + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => + 
dotIntervalLowerUpper2CommonDen (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi := rfl + +/-- Compute score and margin bounds from Q/K absolute bounds. 
-/ +def headScoreBounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qAbs kAbs : Fin seq → Fin dHead → Rat) : + HeadScoreBounds seq dModel dHead := + headScoreBoundsFromDotAbs inputs (fun q k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) + +theorem headScoreBounds_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (qAbs kAbs : Fin seq → Fin dHead → Rat) : + headScoreBounds inputs qAbs kAbs = + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let dotAbs : Fin seq → Fin seq → Rat := fun q k => + Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d) + let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) + let scaleAbs : Rat := |inputs.scale| + let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else -base + let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + let base := scaleAbs * row.1.getD k.1 0 + if masked q k then inputs.maskValue else base + let marginAtRaw : Fin seq → Rat := fun q => + let row := (dotAbsRowTasks[q.1]'(by + simp [dotAbsRowTasks, q.isLt])).get + if q ∈ inputs.active then + let rowArr := row.1 + let prev := inputs.prev q + let dotAbsPrev := rowArr.getD prev.1 0 + if masked q prev then + let scoreLoPrev := inputs.maskValue + let scoreHiAt : Fin seq → Rat := fun k => + if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k 
then + (acc.1, true) + else + let v := scoreLoPrev - scoreHiAt k + match acc.1 with + | none => (some v, acc.2) + | some cur => (some (min cur v), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min unmaskedMin maskedGap + | some unmaskedMin, false => unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + let scoreLoPrev := -(scaleAbs * dotAbsPrev) + let maskedGap := scoreLoPrev - inputs.maskValue + let step : + (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := + fun acc k => + if k = prev then + acc + else if masked q k then + (acc.1, true) + else + let raw := -(dotAbsPrev + rowArr.getD k.1 0) + match acc.1 with + | none => (some raw, acc.2) + | some cur => (some (min cur raw), acc.2) + let acc := Linear.foldlFin seq step (none, false) + match acc.1, acc.2 with + | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap + | some unmaskedMin, false => scaleAbs * unmaskedMin + | none, true => maskedGap + | none, false => (0 : Rat) + else + (0 : Rat) + let marginAtCached := Bounds.cacheBoundThunk marginAtRaw + let marginAt : Fin seq → Rat := fun q => + marginAtCached q + let epsAtRaw : Fin seq → Rat := fun q => + let m := marginAt q + if m < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + m) + let epsAtCached := Bounds.cacheBoundThunk epsAtRaw + let epsAt : Fin seq → Rat := fun q => + epsAtCached q + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if margin < 0 then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + { dotAbs := dotAbs + scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k + scoreAbs := fun q k => + if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k + scoreLo := scoreLoCached + scoreHi := scoreHiCached + marginAt := marginAt + epsAt := epsAt + margin := margin + eps := eps } := rfl + +/-- Value bounds used by the induction-head builder. 
-/ +structure HeadValueBounds (seq dModel dHead : Nat) where + /-- Value lower bounds. -/ + valsLo : Fin seq → Rat + /-- Value upper bounds. -/ + valsHi : Fin seq → Rat + /-- Global value lower bound. -/ + lo : Rat + /-- Global value upper bound. -/ + hi : Rat + +/-- Cached direction vector for value bounds. -/ +def headValueDirHead {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin dHead → Rat := + let dirHeadVec := dirHeadVecOfInputs inputs + fun d => dirHeadVec.get d + +theorem headValueDirHead_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + headValueDirHead inputs = + let dirHeadVec := dirHeadVecOfInputs inputs + fun d => dirHeadVec.get d := rfl + +/-- Cached lower value bounds from V intervals. -/ +def headValueValsLoArray {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := + let dirHead := headValueDirHead inputs + Array.ofFn (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + +/-- Unfold `headValueValsLoArray` to its `Array.ofFn` definition. -/ +theorem headValueValsLoArray_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoArray inputs vLo vHi = + let dirHead := headValueDirHead inputs + Array.ofFn (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl + +/-- Cached lower value bounds from V intervals. 
-/ +def headValueValsLo {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := + let arr := headValueValsLoArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) + +theorem headValueValsLo_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLo inputs vLo vHi = + let arr := headValueValsLoArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) := rfl + +/-- Cached lower value bounds from V intervals using a common-denominator sum. -/ +def headValueValsLoCommonDenArray {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := + headValueValsLoArray inputs vLo vHi + +/-- Unfold `headValueValsLoCommonDenArray` to its `Array.ofFn` definition. -/ +theorem headValueValsLoCommonDenArray_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoCommonDenArray inputs vLo vHi = + let dirHead := headValueDirHead inputs + Array.ofFn (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl + +/-- Cached lower value bounds from V intervals using a common-denominator sum. -/ +def headValueValsLoCommonDen {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := + let arr := headValueValsLoCommonDenArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) + +theorem headValueValsLoCommonDen_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoCommonDen inputs vLo vHi = + let arr := headValueValsLoCommonDenArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) := rfl + +/-- Common-denominator lower bounds agree with cached rational bounds pointwise. 
-/ +theorem headValueValsLoCommonDenArray_eq {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoCommonDenArray inputs vLo vHi = headValueValsLoArray inputs vLo vHi := by + rfl + +theorem headValueValsLoCommonDen_eq {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsLoCommonDen inputs vLo vHi = headValueValsLo inputs vLo vHi := by + funext k + simp [headValueValsLoCommonDen, headValueValsLo, headValueValsLoCommonDenArray_eq] + +/-- Cached upper value bounds from V intervals. -/ +def headValueValsHiArray {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := + let dirHead := headValueDirHead inputs + Array.ofFn (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + +/-- Unfold `headValueValsHiArray` to its `Array.ofFn` definition. -/ +theorem headValueValsHiArray_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiArray inputs vLo vHi = + let dirHead := headValueDirHead inputs + Array.ofFn (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl + +/-- Cached upper value bounds from V intervals. 
-/ +def headValueValsHi {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := + let arr := headValueValsHiArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) + +theorem headValueValsHi_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHi inputs vLo vHi = + let arr := headValueValsHiArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) := rfl + +/-- Cached upper value bounds from V intervals using a common-denominator sum. -/ +def headValueValsHiCommonDenArray {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := + headValueValsHiArray inputs vLo vHi + +/-- Unfold `headValueValsHiCommonDenArray` to its `Array.ofFn` definition. -/ +theorem headValueValsHiCommonDenArray_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiCommonDenArray inputs vLo vHi = + let dirHead := headValueDirHead inputs + Array.ofFn (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl + +/-- Cached upper value bounds from V intervals using a common-denominator sum. -/ +def headValueValsHiCommonDen {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := + let arr := headValueValsHiCommonDenArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) + +theorem headValueValsHiCommonDen_spec {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiCommonDen inputs vLo vHi = + let arr := headValueValsHiCommonDenArray inputs vLo vHi + fun k => arr.getD k.1 (0 : Rat) := rfl + +/-- Common-denominator upper bounds agree with cached rational bounds pointwise. 
-/ +theorem headValueValsHiCommonDenArray_eq {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiCommonDenArray inputs vLo vHi = headValueValsHiArray inputs vLo vHi := by + rfl + +theorem headValueValsHiCommonDen_eq {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueValsHiCommonDen inputs vLo vHi = headValueValsHi inputs vLo vHi := by + funext k + simp [headValueValsHiCommonDen, headValueValsHi, headValueValsHiCommonDenArray_eq] + +/-- Global lower value bound from an array of per-key values. -/ +def headValueLoArray (valsLo : Array Rat) : Rat := + reduceMinArray valsLo + +/-- Unfold `headValueLoArray` to its reduction helper. -/ +theorem headValueLoArray_spec (valsLo : Array Rat) : + headValueLoArray valsLo = reduceMinArray valsLo := rfl + +/-- Global lower value bound from cached per-key values. -/ +def headValueLo [NeZero seq] (valsLo : Fin seq → Rat) : Rat := + headValueLoArray (Array.ofFn valsLo) + +theorem headValueLo_spec [NeZero seq] (valsLo : Fin seq → Rat) : + headValueLo valsLo = headValueLoArray (Array.ofFn valsLo) := rfl + +/-- Task wrapper for `headValueLo`. -/ +def headValueLoTask [NeZero seq] (valsLo : Fin seq → Rat) : Task Rat := + reduceMinFnTask valsLo + +theorem headValueLoTask_spec [NeZero seq] (valsLo : Fin seq → Rat) : + headValueLoTask valsLo = reduceMinFnTask valsLo := rfl + +/-- Chunked task reduction agrees with the sequential chunked value bound. -/ +theorem headValueLoTask_get_eq [NeZero seq] (valsLo : Fin seq → Rat) : + (headValueLoTask valsLo).get = reduceMinFnChunked valsLo := by + simp [headValueLoTask_spec, reduceMinFnTask_get_eq] + +/-- Global upper value bound from an array of per-key values. -/ +def headValueHiArray (valsHi : Array Rat) : Rat := + reduceMaxArray valsHi + +/-- Unfold `headValueHiArray` to its reduction helper. 
-/ +theorem headValueHiArray_spec (valsHi : Array Rat) : + headValueHiArray valsHi = reduceMaxArray valsHi := rfl + +/-- Global upper value bound from cached per-key values. -/ +def headValueHi [NeZero seq] (valsHi : Fin seq → Rat) : Rat := + headValueHiArray (Array.ofFn valsHi) + +theorem headValueHi_spec [NeZero seq] (valsHi : Fin seq → Rat) : + headValueHi valsHi = headValueHiArray (Array.ofFn valsHi) := rfl + +/-- Task wrapper for `headValueHi`. -/ +def headValueHiTask [NeZero seq] (valsHi : Fin seq → Rat) : Task Rat := + reduceMaxFnTask valsHi + +theorem headValueHiTask_spec [NeZero seq] (valsHi : Fin seq → Rat) : + headValueHiTask valsHi = reduceMaxFnTask valsHi := rfl + +/-- Chunked task reduction agrees with the sequential chunked value bound. -/ +theorem headValueHiTask_get_eq [NeZero seq] (valsHi : Fin seq → Rat) : + (headValueHiTask valsHi).get = reduceMaxFnChunked valsHi := by + simp [headValueHiTask_spec, reduceMaxFnTask_get_eq] + +/-- Build `HeadValueBounds` from precomputed arrays. -/ +private def headValueBoundsOfArrays {seq dModel dHead : Nat} + (valsLoArr valsHiArr : Array Rat) : HeadValueBounds seq dModel dHead := + let valsLo : Fin seq → Rat := fun k => valsLoArr.getD k.1 (0 : Rat) + let valsHi : Fin seq → Rat := fun k => valsHiArr.getD k.1 (0 : Rat) + let lo := headValueLoArray valsLoArr + let hi := headValueHiArray valsHiArr + { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } + +/-- Build a cached bounds array in parallel from a per-key computation. 
-/ +private def buildBoundArrayTask [NeZero seq] (f : Fin seq → Rat) : Task (Array Rat) := + let n := seq + let chunkSize : Nat := 64 + let chunks : Nat := (n + chunkSize - 1) / chunkSize + let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) + let defaultIdx : Fin seq := ⟨0, hpos⟩ + let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) + let chunkTasks : List (Task (Array Rat)) := + (List.range chunks).map (fun c => + Task.spawn (fun _ => + let start := c * chunkSize + let stop := Nat.min n (start + chunkSize) + let vals := + (List.range (stop - start)).map (fun i => + f (idxs.getD (start + i) defaultIdx)) + vals.toArray)) + Task.mapList (fun xs => xs.foldl (fun acc arr => acc ++ arr) #[]) chunkTasks + +/-- Compute value bounds from V interval bounds. -/ +def headValueBounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + HeadValueBounds seq dModel dHead := + let valsLoArr := headValueValsLoArray inputs vLo vHi + let valsHiArr := headValueValsHiArray inputs vLo vHi + headValueBoundsOfArrays valsLoArr valsHiArr + +theorem headValueBounds_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueBounds inputs vLo vHi = + let valsLoArr := headValueValsLoArray inputs vLo vHi + let valsHiArr := headValueValsHiArray inputs vLo vHi + headValueBoundsOfArrays valsLoArr valsHiArr := rfl + +/-- Compute value bounds from V interval bounds in parallel. 
-/ +def headValueBoundsTask [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + Task (HeadValueBounds seq dModel dHead) := + let dirHead := headValueDirHead inputs + let valsLoTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + let valsHiTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + Task.bind valsLoTask (fun valsLoArr => + Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) + +/-- Unfold `headValueBoundsTask` to its task graph. -/ +theorem headValueBoundsTask_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueBoundsTask inputs vLo vHi = + let dirHead := headValueDirHead inputs + let valsLoTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + let valsHiTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + Task.bind valsLoTask (fun valsLoArr => + Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) := rfl + +/-- Compute value bounds from V interval bounds using a common-denominator sum. 
-/ +def headValueBoundsCommonDen [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + HeadValueBounds seq dModel dHead := + let valsLoArr := headValueValsLoCommonDenArray inputs vLo vHi + let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi + headValueBoundsOfArrays valsLoArr valsHiArr + +theorem headValueBoundsCommonDen_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueBoundsCommonDen inputs vLo vHi = + let valsLoArr := headValueValsLoCommonDenArray inputs vLo vHi + let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi + headValueBoundsOfArrays valsLoArr valsHiArr := rfl + +/-- Compute value bounds from V intervals using a common-denominator sum in parallel. -/ +def headValueBoundsCommonDenTask [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + Task (HeadValueBounds seq dModel dHead) := + let dirHead := headValueDirHead inputs + let valsLoTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + let valsHiTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + Task.bind valsLoTask (fun valsLoArr => + Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) + +/-- Unfold `headValueBoundsCommonDenTask` to its task graph. 
-/ +theorem headValueBoundsCommonDenTask_spec [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueBoundsCommonDenTask inputs vLo vHi = + let dirHead := headValueDirHead inputs + let valsLoTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) + let valsHiTask := buildBoundArrayTask (fun k => + Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) + Task.bind valsLoTask (fun valsLoArr => + Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) := rfl + +theorem headValueBoundsCommonDen_eq [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (vLo vHi : Fin seq → Fin dHead → Rat) : + headValueBoundsCommonDen inputs vLo vHi = headValueBounds inputs vLo vHi := by + classical + simp [headValueBoundsCommonDen, headValueBounds, headValueValsLoCommonDenArray_eq, + headValueValsHiCommonDenArray_eq] + +end Sound + +end Nfp From ae25636b17cc305a2b4d285fa6a6cddf82975d1a Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 18:48:24 +0100 Subject: [PATCH 166/244] Add Basic submodules for induction core and run IO --- Nfp/Circuit/Layers/Induction.lean | 955 +---------------------- Nfp/Circuit/Layers/Induction/Basic.lean | 958 ++++++++++++++++++++++++ Nfp/IO/Run.lean | 800 +------------------- Nfp/IO/Run/Basic.lean | 805 ++++++++++++++++++++ Nfp/Sound/Induction/Core.lean | 923 +---------------------- Nfp/Sound/Induction/Core/Basic.lean | 922 +++++++++++++++++++++++ 6 files changed, 2692 insertions(+), 2671 deletions(-) create mode 100644 Nfp/Circuit/Layers/Induction/Basic.lean create mode 100644 Nfp/IO/Run/Basic.lean create mode 100644 Nfp/Sound/Induction/Core/Basic.lean diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index 8e1c3b6..7072cad 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ 
-1,958 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.Order.Monoid.Unbundled.Basic -import Mathlib.Algebra.Order.Ring.Defs -import Nfp.Circuit.Layers.Attention +import Nfp.Circuit.Layers.Induction.Basic /-! -Induction-head specifications for attention cores. +Induction-head layer wiring and helper lemmas. -/ - -namespace Nfp - -namespace Circuit - -namespace Layers - -universe v - -open scoped BigOperators - -section Weights - -variable {Val : Type v} [NonAssocSemiring Val] -variable {seq : Nat} - -/-- Induction weights are one-hot at `prev` for each query position. -/ -def InductionWeights (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) : Prop := - ∀ q, weights q = Pi.single (prev q) 1 - -/-- A one-hot weight vector selects the corresponding value in a dot product. -/ -theorem dotProduct_eq_of_oneHot (k : Fin seq) (vals : Fin seq → Val) : - dotProduct (Pi.single k 1) vals = vals k := by - simp - -/-- Induction weights select the `prev` value in each dot product. -/ -theorem dotProduct_eq_prev (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) (vals : Fin seq → Fin seq → Val) - (hweights : InductionWeights (Val := Val) prev weights) (q : Fin seq) : - dotProduct (weights q) (vals q) = vals q (prev q) := by - have hq : weights q = Pi.single (prev q) 1 := hweights q - simp [hq] - -end Weights - -section Spec - -variable {Val : Type v} -variable {n : Nat} - -/-- Induction-head spec: for nonzero queries, outputs copy `prev` values. -/ -def InductionSpec (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) - (out vals : Fin (Nat.succ n) → Val) : Prop := - ∀ q, q ≠ 0 → out q = vals (prev q) - -/-- Concrete `prev` map on `Fin (n + 1)` (with `0 ↦ 0`). 
-/ -def prevIndex : Fin (Nat.succ n) → Fin (Nat.succ n) - | ⟨0, _⟩ => 0 - | ⟨Nat.succ k, hk⟩ => - ⟨k, Nat.lt_trans (Nat.lt_of_succ_lt_succ hk) (Nat.lt_succ_self n)⟩ - -end Spec - -section ApproxSpec - -variable {Val : Type v} [AddCommMonoid Val] [PartialOrder Val] [IsOrderedAddMonoid Val] -variable {n : Nat} - -/-- Approximate induction-head spec: outputs are within `ε` of `prev` values. -/ -def InductionSpecApprox (ε : Val) - (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) - (out vals : Fin (Nat.succ n) → Val) : Prop := - ∀ q, q ≠ 0 → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε - -/-- Approximate induction-head spec restricted to active queries. -/ -def InductionSpecApproxOn (ε : Val) (active : Fin (Nat.succ n) → Prop) - (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) - (out vals : Fin (Nat.succ n) → Val) : Prop := - ∀ q, active q → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε - -/-- Exact induction spec implies the approximate spec for any nonnegative tolerance. -/ -theorem inductionSpecApprox_of_spec (ε : Val) (hε : 0 ≤ ε) - (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) - (out vals : Fin (Nat.succ n) → Val) - (h : InductionSpec prev out vals) : - InductionSpecApprox (Val := Val) (n := n) ε prev out vals := by - intro q hq - have hq' : out q = vals (prev q) := h q hq - constructor <;> - simpa [hq'] using - (le_add_of_nonneg_right hε : - vals (prev q) ≤ vals (prev q) + ε) - -end ApproxSpec - -section ValueRange - -variable {Val : Type v} [PartialOrder Val] -variable {seq : Nat} - -/-- Value-range bounds for a vector of attention values. -/ -structure ValueRangeBounds (lo hi : Val) (vals : Fin seq → Val) : Prop where - /-- Lower and upper bounds are ordered. -/ - lo_le_hi : lo ≤ hi - /-- All values are at least `lo`. -/ - lo_le : ∀ k, lo ≤ vals k - /-- All values are at most `hi`. 
-/ - le_hi : ∀ k, vals k ≤ hi - -end ValueRange - -section Bounds - -variable {Val : Type v} [Semiring Val] [PartialOrder Val] -variable {seq : Nat} [NeZero seq] - -/-- Numeric bounds certifying one-hot weights on nonzero queries. -/ -structure OneHotBoundsOn (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) : Prop where - /-- All weights are nonnegative on nonzero queries. -/ - nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k - /-- Weights sum to one on nonzero queries. -/ - sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 - /-- Non-prev weights are nonpositive on nonzero queries. -/ - other_le_zero : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ 0 - -/-- Certified bounds imply one-hot weights on nonzero queries. -/ -theorem oneHot_of_boundsOn (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) [DecidableEq (Fin seq)] - (h : OneHotBoundsOn prev weights) : - ∀ q, q ≠ 0 → weights q = Pi.single (prev q) 1 := by - intro q hq - funext k - by_cases hk : k = prev q - · subst hk - have hzero : - (∑ k ∈ (Finset.univ.erase (prev q)), weights q k) = 0 := by - refine Finset.sum_eq_zero ?_ - intro k hk' - have hkne : k ≠ prev q := (Finset.mem_erase.1 hk').1 - have hle : weights q k ≤ 0 := h.other_le_zero q hq k hkne - have hge : 0 ≤ weights q k := h.nonneg q hq k - exact le_antisymm hle hge - have hsum : - weights q (prev q) + - ∑ k ∈ (Finset.univ.erase (prev q)), weights q k = - ∑ k, weights q k := by - simpa using - (Finset.add_sum_erase - (s := (Finset.univ : Finset (Fin seq))) - (f := weights q) (a := prev q) (by simp)) - have hprev : weights q (prev q) = 1 := by - have hsum' : - weights q (prev q) + 0 = 1 := by - simpa [hzero, h.sum_one q hq] using hsum - simpa using hsum' - simp [Pi.single, hprev] - · have hle : weights q k ≤ 0 := h.other_le_zero q hq k hk - have hge : 0 ≤ weights q k := h.nonneg q hq k - have hzero : weights q k = 0 := le_antisymm hle hge - simp [Pi.single, hk, hzero] - -end Bounds - -section ApproxBounds - -variable {Val : Type v} 
[Semiring Val] [PartialOrder Val] -variable {seq : Nat} [NeZero seq] - -/-- Approximate one-hot bounds for attention weights on nonzero queries. -/ -structure OneHotApproxBoundsOn (ε : Val) (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) : Prop where - /-- All weights are nonnegative on nonzero queries. -/ - nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k - /-- Weights sum to one on nonzero queries. -/ - sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 - /-- The `prev` weight is within `ε` of one on nonzero queries. -/ - prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε - /-- Non-prev weights are at most `ε` on nonzero queries. -/ - other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε - -/-- Approximate one-hot bounds for attention weights on active queries. -/ -structure OneHotApproxBoundsOnActive (ε : Val) (active : Fin seq → Prop) - (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) : Prop where - /-- All weights are nonnegative on active queries. -/ - nonneg : ∀ q, active q → ∀ k, 0 ≤ weights q k - /-- Weights sum to one on active queries. -/ - sum_one : ∀ q, active q → (∑ k, weights q k) = 1 - /-- The `prev` weight is within `ε` of one on active queries. -/ - prev_large : ∀ q, active q → 1 ≤ weights q (prev q) + ε - /-- Non-prev weights are at most `ε` on active queries. -/ - other_le : ∀ q, active q → ∀ k, k ≠ prev q → weights q k ≤ ε - -/-- Lift global approximate bounds to an active-set version. 
-/ -theorem oneHotApproxBoundsOnActive_of_on (ε : Val) (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) - (h : OneHotApproxBoundsOn (Val := Val) ε prev weights) : - OneHotApproxBoundsOnActive (Val := Val) ε (fun q => q ≠ 0) prev weights := by - refine { nonneg := ?_, sum_one := ?_, prev_large := ?_, other_le := ?_ } - · intro q hq k - exact h.nonneg q hq k - · intro q hq - exact h.sum_one q hq - · intro q hq - exact h.prev_large q hq - · intro q hq k hk - exact h.other_le q hq k hk - -/-- Approximate induction weights: prev weight near one, others at most `ε`. -/ -def InductionWeightsApprox (ε : Val) (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) : Prop := - ∀ q, q ≠ 0 → - 1 ≤ weights q (prev q) + ε ∧ - ∀ k, k ≠ prev q → weights q k ≤ ε - -/-- Approximate bounds imply approximate induction weights. -/ -theorem inductionWeightsApprox_of_boundsOn (ε : Val) (prev : Fin seq → Fin seq) - (weights : Fin seq → Fin seq → Val) - (h : OneHotApproxBoundsOn ε prev weights) : - InductionWeightsApprox (Val := Val) ε prev weights := by - intro q hq - exact ⟨h.prev_large q hq, h.other_le q hq⟩ - -end ApproxBounds - -section ApproxOutput - -variable {Val : Type v} [Ring Val] [LinearOrder Val] [IsOrderedRing Val] -variable {n : Nat} - -local instance : NeZero (Nat.succ n) := ⟨by simp⟩ - -/-- Approximate one-hot weights plus bounded values yield an approximate induction spec - on active queries. 
-/ -theorem inductionSpecApproxOn_of_oneHotApprox_valueRange - (ε lo hi : Val) (active : Fin (Nat.succ n) → Prop) - (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) - (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) - (vals : Fin (Nat.succ n) → Val) - (hweights : OneHotApproxBoundsOnActive (Val := Val) ε active prev weights) - (hvals : ValueRangeBounds (Val := Val) lo hi vals) : - InductionSpecApproxOn (Val := Val) (n := n) (ε * (hi - lo)) active prev - (fun q => dotProduct (weights q) vals) vals := by - classical - intro q hq - let others : Finset (Fin (Nat.succ n)) := - (Finset.univ : Finset (Fin (Nat.succ n))).erase (prev q) - have hsum_decomp : - weights q (prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by - simp [others] - have hsum : - weights q (prev q) + ∑ k ∈ others, weights q k = 1 := by - simpa [hweights.sum_one q hq] using hsum_decomp - have hsum_others_le : (∑ k ∈ others, weights q k) ≤ ε := by - have hprev : 1 ≤ weights q (prev q) + ε := hweights.prev_large q hq - have hprev' : - weights q (prev q) + ∑ k ∈ others, weights q k ≤ weights q (prev q) + ε := by - simpa [hsum] using hprev - exact (add_le_add_iff_left (weights q (prev q))).1 hprev' - have hsum_others_nonneg : 0 ≤ ∑ k ∈ others, weights q k := by - refine Finset.sum_nonneg ?_ - intro k hk - exact hweights.nonneg q hq k - have hvals_hi : ∀ k, vals k ≤ hi := hvals.le_hi - have hvals_lo : ∀ k, lo ≤ vals k := hvals.lo_le - have hdiff_nonneg : 0 ≤ hi - lo := sub_nonneg.mpr hvals.lo_le_hi - have hsum_vals_le : - (∑ k ∈ others, weights q k * vals k) ≤ (∑ k ∈ others, weights q k) * hi := by - have hle : ∀ k ∈ others, weights q k * vals k ≤ weights q k * hi := by - intro k hk - have hval : vals k ≤ hi := hvals_hi k - have hnonneg : 0 ≤ weights q k := hweights.nonneg q hq k - exact mul_le_mul_of_nonneg_left hval hnonneg - calc - ∑ k ∈ others, weights q k * vals k - ≤ ∑ k ∈ others, weights q k * hi := Finset.sum_le_sum hle - _ = (∑ k ∈ others, weights q k) * hi := by - simpa using - 
(Finset.sum_mul (s := others) (f := fun k => weights q k) (a := hi)).symm - have hsum_vals_ge : - (∑ k ∈ others, weights q k) * lo ≤ (∑ k ∈ others, weights q k * vals k) := by - have hle : ∀ k ∈ others, weights q k * lo ≤ weights q k * vals k := by - intro k hk - have hval : lo ≤ vals k := hvals_lo k - have hnonneg : 0 ≤ weights q k := hweights.nonneg q hq k - exact mul_le_mul_of_nonneg_left hval hnonneg - calc - (∑ k ∈ others, weights q k) * lo - = ∑ k ∈ others, weights q k * lo := by - exact - (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := lo)) - _ ≤ ∑ k ∈ others, weights q k * vals k := Finset.sum_le_sum hle - have hsum_prod : - weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k = - ∑ k, weights q k * vals k := by - simp [others] - have hout_eq : - dotProduct (weights q) vals = - weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := by - simpa [dotProduct] using hsum_prod.symm - have hsum_val_prev : - weights q (prev q) * vals (prev q) + - (∑ k ∈ others, weights q k) * vals (prev q) = - vals (prev q) := by - calc - weights q (prev q) * vals (prev q) + - (∑ k ∈ others, weights q k) * vals (prev q) = - (weights q (prev q) + ∑ k ∈ others, weights q k) * vals (prev q) := by - simpa using - (add_mul (weights q (prev q)) (∑ k ∈ others, weights q k) (vals (prev q))).symm - _ = 1 * vals (prev q) := by - simp [hsum] - _ = vals (prev q) := by simp - have hsplit : - (∑ k ∈ others, weights q k) * hi = - (∑ k ∈ others, weights q k) * lo + - (∑ k ∈ others, weights q k) * (hi - lo) := by - calc - (∑ k ∈ others, weights q k) * hi = - (∑ k ∈ others, weights q k) * lo + - (∑ k ∈ others, weights q k) * hi - - (∑ k ∈ others, weights q k) * lo := by - exact - (add_sub_cancel_left - ((∑ k ∈ others, weights q k) * lo) ((∑ k ∈ others, weights q k) * hi)).symm - _ = (∑ k ∈ others, weights q k) * lo + - ((∑ k ∈ others, weights q k) * hi - - (∑ k ∈ others, weights q k) * lo) := by - simp [sub_eq_add_neg, add_assoc] - _ = (∑ k ∈ 
others, weights q k) * lo + - (∑ k ∈ others, weights q k) * (hi - lo) := by - simp [mul_sub] - have hsum_prev_le : - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo ≤ - vals (prev q) := by - have hmul : (∑ k ∈ others, weights q k) * lo ≤ - (∑ k ∈ others, weights q k) * vals (prev q) := - mul_le_mul_of_nonneg_left (hvals_lo (prev q)) hsum_others_nonneg - calc - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo - ≤ weights q (prev q) * vals (prev q) + - (∑ k ∈ others, weights q k) * vals (prev q) := by - have h := - add_le_add_left hmul (weights q (prev q) * vals (prev q)) - simpa [add_comm, add_left_comm, add_assoc] using h - _ = vals (prev q) := hsum_val_prev - have hupper_mid : - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi ≤ - vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by - calc - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi = - weights q (prev q) * vals (prev q) + - ((∑ k ∈ others, weights q k) * lo + - (∑ k ∈ others, weights q k) * (hi - lo)) := by - simp [hsplit] - _ = weights q (prev q) * vals (prev q) + - (∑ k ∈ others, weights q k) * lo + - (∑ k ∈ others, weights q k) * (hi - lo) := by - simp [add_assoc] - _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by - have h := - add_le_add_right hsum_prev_le - ((∑ k ∈ others, weights q k) * (hi - lo)) - simpa [add_comm, add_left_comm, add_assoc] using h - have hupper : - dotProduct (weights q) vals ≤ vals (prev q) + ε * (hi - lo) := by - have hmul : - (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := by - exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg - calc - dotProduct (weights q) vals = - weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := hout_eq - _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by - have h := - add_le_add_left hsum_vals_le (weights q (prev q) * vals (prev q)) - simpa [add_comm, add_left_comm, 
add_assoc] using h - _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := hupper_mid - _ ≤ vals (prev q) + ε * (hi - lo) := by - have h := - add_le_add_left hmul (vals (prev q)) - simpa [add_comm, add_left_comm, add_assoc] using h - have hprev_le : - vals (prev q) ≤ - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by - have hmul : (∑ k ∈ others, weights q k) * vals (prev q) ≤ - (∑ k ∈ others, weights q k) * hi := - mul_le_mul_of_nonneg_left (hvals_hi (prev q)) hsum_others_nonneg - have hmul' : - weights q (prev q) * vals (prev q) + - (∑ k ∈ others, weights q k) * vals (prev q) ≤ - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by - have h := - add_le_add_left hmul (weights q (prev q) * vals (prev q)) - simpa [add_comm, add_left_comm, add_assoc] using h - calc - vals (prev q) = - weights q (prev q) * vals (prev q) + - (∑ k ∈ others, weights q k) * vals (prev q) := by - simpa using hsum_val_prev.symm - _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := hmul' - have hprev_le' : - vals (prev q) ≤ - weights q (prev q) * vals (prev q) + - (∑ k ∈ others, weights q k) * lo + - (∑ k ∈ others, weights q k) * (hi - lo) := by - calc - vals (prev q) ≤ - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := hprev_le - _ = - weights q (prev q) * vals (prev q) + - (∑ k ∈ others, weights q k) * lo + - (∑ k ∈ others, weights q k) * (hi - lo) := by - simp [hsplit, add_assoc] - have hsub : - vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) ≤ - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo := by - exact (sub_le_iff_le_add).2 hprev_le' - have hlowershift : - vals (prev q) - ε * (hi - lo) ≤ - vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := by - have hmul : - (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := by - exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg - exact sub_le_sub_left hmul (vals (prev q)) - have hlow : - 
vals (prev q) - ε * (hi - lo) ≤ dotProduct (weights q) vals := by - calc - vals (prev q) - ε * (hi - lo) ≤ - vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := hlowershift - _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo := hsub - _ ≤ dotProduct (weights q) vals := by - calc - weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo - ≤ weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := by - have h := - add_le_add_left hsum_vals_ge (weights q (prev q) * vals (prev q)) - simpa [add_comm, add_left_comm, add_assoc] using h - _ = dotProduct (weights q) vals := by - simp [hout_eq] - have hlower : - vals (prev q) ≤ dotProduct (weights q) vals + ε * (hi - lo) := by - exact (sub_le_iff_le_add).1 hlow - exact ⟨hupper, hlower⟩ - -/-- Approximate one-hot weights plus bounded values yield an approximate induction spec. -/ -theorem inductionSpecApprox_of_oneHotApprox_valueRange - (ε lo hi : Val) - (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) - (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) - (vals : Fin (Nat.succ n) → Val) - (hweights : OneHotApproxBoundsOn (Val := Val) ε prev weights) - (hvals : ValueRangeBounds (Val := Val) lo hi vals) : - InductionSpecApprox (Val := Val) (n := n) (ε * (hi - lo)) prev - (fun q => dotProduct (weights q) vals) vals := by - have hweights' : - OneHotApproxBoundsOnActive (Val := Val) ε (fun q => q ≠ 0) prev weights := - oneHotApproxBoundsOnActive_of_on (Val := Val) (seq := Nat.succ n) - (ε := ε) (prev := prev) (weights := weights) hweights - exact - inductionSpecApproxOn_of_oneHotApprox_valueRange - (Val := Val) - (n := n) - (ε := ε) - (lo := lo) - (hi := hi) - (active := fun q => q ≠ 0) - (prev := prev) - (weights := weights) - (vals := vals) - (hweights := hweights') - (hvals := hvals) - -end ApproxOutput - -section SoftmaxMargin - -variable {Val : Type v} [Semiring Val] [PartialOrder Val] -variable {seq : Nat} [NeZero seq] - -/-- Softmax margin certificates for 
approximate one-hot weights. -/ -structure SoftmaxMarginBounds (ε margin : Val) (prev : Fin seq → Fin seq) - (scores weights : Fin seq → Fin seq → Val) : Prop where - /-- Score gap between `prev` and other keys on nonzero queries. -/ - score_margin : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) - /-- All weights are nonnegative on nonzero queries. -/ - nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k - /-- Weights sum to one on nonzero queries. -/ - sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 - /-- The `prev` weight is within `ε` of one on nonzero queries. -/ - prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε - /-- Non-prev weights are at most `ε` on nonzero queries. -/ - other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε - -/-- Softmax margin certificates for approximate one-hot weights on active queries. -/ -structure SoftmaxMarginBoundsOn (ε margin : Val) (active : Fin seq → Prop) - (prev : Fin seq → Fin seq) - (scores weights : Fin seq → Fin seq → Val) : Prop where - /-- Score gap between `prev` and other keys on active queries. -/ - score_margin : ∀ q, active q → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) - /-- All weights are nonnegative on active queries. -/ - nonneg : ∀ q, active q → ∀ k, 0 ≤ weights q k - /-- Weights sum to one on active queries. -/ - sum_one : ∀ q, active q → (∑ k, weights q k) = 1 - /-- The `prev` weight is within `ε` of one on active queries. -/ - prev_large : ∀ q, active q → 1 ≤ weights q (prev q) + ε - /-- Non-prev weights are at most `ε` on active queries. -/ - other_le : ∀ q, active q → ∀ k, k ≠ prev q → weights q k ≤ ε - -/-- Lift global softmax-margin bounds to an active-set version. 
-/ -theorem softmaxMarginBoundsOn_of_on (ε margin : Val) (prev : Fin seq → Fin seq) - (scores weights : Fin seq → Fin seq → Val) - (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : - SoftmaxMarginBoundsOn (Val := Val) ε margin (fun q => q ≠ 0) prev scores weights := by - refine - { score_margin := ?_ - nonneg := ?_ - sum_one := ?_ - prev_large := ?_ - other_le := ?_ } - · intro q hq k hk - exact h.score_margin q hq k hk - · intro q hq k - exact h.nonneg q hq k - · intro q hq - exact h.sum_one q hq - · intro q hq - exact h.prev_large q hq - · intro q hq k hk - exact h.other_le q hq k hk - -/-- Margin certificates yield approximate one-hot bounds for the weights. -/ -theorem oneHotApproxBounds_of_softmaxMargin (ε margin : Val) (prev : Fin seq → Fin seq) - (scores weights : Fin seq → Fin seq → Val) - (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : - OneHotApproxBoundsOn (Val := Val) ε prev weights := by - exact - { nonneg := h.nonneg - sum_one := h.sum_one - prev_large := h.prev_large - other_le := h.other_le } - -/-- Margin certificates imply approximate induction-weight bounds. -/ -theorem inductionWeightsApprox_of_softmaxMargin (ε margin : Val) - (prev : Fin seq → Fin seq) - (scores weights : Fin seq → Fin seq → Val) - (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : - InductionWeightsApprox (Val := Val) ε prev weights := by - exact inductionWeightsApprox_of_boundsOn - (Val := Val) - (seq := seq) - (ε := ε) - (prev := prev) - (weights := weights) - (h := oneHotApproxBounds_of_softmaxMargin - (Val := Val) - (seq := seq) - (ε := ε) - (margin := margin) - (prev := prev) - (scores := scores) - (weights := weights) - h) - -end SoftmaxMargin - -section SoftmaxMarginActive - -variable {Val : Type v} [Semiring Val] [PartialOrder Val] -variable {seq : Nat} - -/-- Margin certificates yield approximate one-hot bounds on active queries. 
-/ -theorem oneHotApproxBoundsOnActive_of_softmaxMargin (ε margin : Val) - (active : Fin seq → Prop) - (prev : Fin seq → Fin seq) - (scores weights : Fin seq → Fin seq → Val) - (h : SoftmaxMarginBoundsOn (Val := Val) ε margin active prev scores weights) : - OneHotApproxBoundsOnActive (Val := Val) ε active prev weights := by - exact - { nonneg := h.nonneg - sum_one := h.sum_one - prev_large := h.prev_large - other_le := h.other_le } - -end SoftmaxMarginActive - -section Attention - -variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] -variable {seq heads dim : Nat} -variable {Val : Type v} [NonAssocSemiring Val] - -/-- Typed V-input label for attention cores. -/ -abbrev attnInputV (v : QkvIndex Batch seq heads dim) : - AttentionInput Batch seq heads dim := - Sum.inr (Sum.inr v) - -/-- Weight function feeding an attention output node. -/ -def attentionOutWeights (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) - (rec : - ∀ j, - (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j - (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → - Val) : - Fin seq → Val := - fun k => - rec (attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) - (attentionDag_rel_weight_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q k d) - -/-- Value function feeding an attention output node. -/ -def attentionOutValues (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) - (rec : - ∀ j, - (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j - (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → - Val) : - Fin seq → Val := - fun k => - rec (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d)) - (attentionDag_rel_v_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b k h d q) - -/-- One-hot attention weights force the output to copy the selected value. 
-/ -theorem attentionGate_out_eq_of_oneHot (scale : Val) - (softmax : (Fin seq → Val) → Fin seq → Val) (prev : Fin seq → Fin seq) - (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) - (rec : - ∀ j, - (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j - (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → - Val) - (hweights : - attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d rec = - Pi.single (prev q) 1) : - attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax - (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) rec = - attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d rec (prev q) := by - simp only [attentionGate] - change - dotProduct - (attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d rec) - (attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d rec) = - attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d rec (prev q) - rw [hweights] - exact dotProduct_eq_of_oneHot (Val := Val) (seq := seq) (k := prev q) - (vals := attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d rec) - -section Typed - -variable (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) - -/-- Attention output equals the selected V input when weights are one-hot. 
-/ -theorem attentionTyped_eval_out_eq_of_oneHot (prev : Fin seq → Fin seq) - (input : AttentionInput Batch seq heads dim → Val) - (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) - (hweights : - attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d - (fun j _ => - Circuit.evalInput - (attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - scale softmax) - ((attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - scale softmax).toInputAssignment input) j) = - Pi.single (prev q) 1) : - (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval - input (b, q, h, d) = - input - (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - (b, prev q, h, d)) := by - let C := - attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax - let I := - attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax - let inputAssign := I.toInputAssignment input - have hnot : - attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d) ∉ - attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by - simpa using - (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - (s := Sum.inr (Sum.inr (b, q, h, d)))) - have hgate : - Circuit.evalInput C inputAssign - (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) = - attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax - (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) - (fun j _ => Circuit.evalInput C inputAssign j) := by - exact Circuit.evalInput_eq_gate (C := C) (input := inputAssign) - (i := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) - hnot - have hcopy : - Circuit.evalInput C inputAssign - (attnOut (Batch := 
Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) = - attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := by - have hgate' : - attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax - (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) - (fun j _ => Circuit.evalInput C inputAssign j) = - attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := - attentionGate_out_eq_of_oneHot (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - (scale := scale) (softmax := softmax) (prev := prev) (b := b) (h := h) (q := q) (d := d) - (rec := fun j _ => Circuit.evalInput C inputAssign j) hweights - exact hgate.trans hgate' - have hmem : - attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d) ∈ - attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by - refine (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) - (dim := dim)).2 ?_ - exact ⟨attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - (b, prev q, h, d), rfl⟩ - have hinput : - Circuit.evalInput C inputAssign - (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d)) = - input (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - (b, prev q, h, d)) := by - have h := - Circuit.evalInput_eq_input (C := C) (input := inputAssign) - (i := attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - (b, prev q, h, d)) hmem - simpa [inputAssign, I, attentionInterface, attnInputV] using h - have hvals : - attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) = - Circuit.evalInput C inputAssign - (attnV (Batch := Batch) (seq := 
seq) (heads := heads) (dim := dim) - (b, prev q, h, d)) := rfl - calc - (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval - input (b, q, h, d) = - Circuit.evalInput C inputAssign (I.outputs (b, q, h, d)).1 := by - simp [TypedCircuit.eval, Interface.eval, C, I, inputAssign, attentionTyped] - _ = Circuit.evalInput C inputAssign - (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by - rfl - _ = attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := hcopy - _ = Circuit.evalInput C inputAssign - (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - (b, prev q, h, d)) := hvals - _ = input (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) - (b, prev q, h, d)) := hinput - -end Typed - -end Attention - -section InductionSpecTyped - -variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] -variable {heads dim n : Nat} -variable {Val : Type v} [NonAssocSemiring Val] - -variable (scale : Val) -variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) - -/-- One-hot weights on nonzero queries imply the induction spec for typed evaluation. 
-/ -theorem attentionTyped_eval_inductionSpec_of_oneHot - (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) - (input : AttentionInput Batch (Nat.succ n) heads dim → Val) - (b : Batch) (h : Fin heads) (d : Fin dim) - (hweights : - ∀ q, q ≠ 0 → - attentionOutWeights - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - b h q d - (fun j _ => - Circuit.evalInput - (attentionCircuit - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax) - ((attentionInterface - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax).toInputAssignment input) j) = - Pi.single (prev q) 1) : - InductionSpec (n := n) prev - (fun q => - (attentionTyped - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax).eval input (b, q, h, d)) - (fun k => - input (attnInputV - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - (b, k, h, d))) := by - intro q hq - have hweights_q := hweights q hq - exact attentionTyped_eval_out_eq_of_oneHot - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - (scale := scale) - (softmax := softmax) - (prev := prev) - (input := input) - (b := b) - (h := h) - (q := q) - (d := d) - hweights_q - -/-- Induction spec for `prevIndex` under one-hot weight hypotheses. 
-/ -theorem attentionTyped_eval_inductionSpec_prevIndex - (input : AttentionInput Batch (Nat.succ n) heads dim → Val) - (b : Batch) (h : Fin heads) (d : Fin dim) - (hweights : - ∀ q, q ≠ 0 → - attentionOutWeights - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - b h q d - (fun j _ => - Circuit.evalInput - (attentionCircuit - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax) - ((attentionInterface - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax).toInputAssignment input) j) = - Pi.single (prevIndex (n := n) q) 1) : - InductionSpec (n := n) (prevIndex (n := n)) - (fun q => - (attentionTyped - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax).eval input (b, q, h, d)) - (fun k => - input (attnInputV - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - (b, k, h, d))) := by - exact attentionTyped_eval_inductionSpec_of_oneHot - (Batch := Batch) - (heads := heads) - (dim := dim) - (n := n) - (scale := scale) - (softmax := softmax) - (prev := prevIndex (n := n)) - (input := input) - (b := b) - (h := h) - (d := d) - hweights - -end InductionSpecTyped - -section InductionSpecApproxTyped - -variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] -variable {heads dim n : Nat} -variable {Val : Type v} [NonAssocSemiring Val] [PartialOrder Val] [IsOrderedAddMonoid Val] - -variable (scale : Val) -variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) - -/-- One-hot weights imply the approximate induction spec for any nonnegative tolerance. 
-/ -theorem attentionTyped_eval_inductionSpecApprox_of_oneHot (ε : Val) (hε : 0 ≤ ε) - (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) - (input : AttentionInput Batch (Nat.succ n) heads dim → Val) - (b : Batch) (h : Fin heads) (d : Fin dim) - (hweights : - ∀ q, q ≠ 0 → - attentionOutWeights - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - b h q d - (fun j _ => - Circuit.evalInput - (attentionCircuit - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax) - ((attentionInterface - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax).toInputAssignment input) j) = - Pi.single (prev q) 1) : - InductionSpecApprox (Val := Val) (n := n) ε prev - (fun q => - (attentionTyped - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - scale softmax).eval input (b, q, h, d)) - (fun k => - input (attnInputV - (Batch := Batch) - (seq := Nat.succ n) - (heads := heads) - (dim := dim) - (b, k, h, d))) := by - apply inductionSpecApprox_of_spec (Val := Val) (n := n) (ε := ε) hε - exact attentionTyped_eval_inductionSpec_of_oneHot - (Batch := Batch) - (heads := heads) - (dim := dim) - (n := n) - (scale := scale) - (softmax := softmax) - (prev := prev) - (input := input) - (b := b) - (h := h) - (d := d) - hweights - -end InductionSpecApproxTyped - -end Layers - -end Circuit - -end Nfp diff --git a/Nfp/Circuit/Layers/Induction/Basic.lean b/Nfp/Circuit/Layers/Induction/Basic.lean new file mode 100644 index 0000000..8e1c3b6 --- /dev/null +++ b/Nfp/Circuit/Layers/Induction/Basic.lean @@ -0,0 +1,958 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.BigOperators.Ring.Finset +import Mathlib.Algebra.Order.Monoid.Unbundled.Basic +import Mathlib.Algebra.Order.Ring.Defs +import Nfp.Circuit.Layers.Attention + +/-! +Induction-head specifications for attention cores. 
+-/ + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe v + +open scoped BigOperators + +section Weights + +variable {Val : Type v} [NonAssocSemiring Val] +variable {seq : Nat} + +/-- Induction weights are one-hot at `prev` for each query position. -/ +def InductionWeights (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop := + ∀ q, weights q = Pi.single (prev q) 1 + +/-- A one-hot weight vector selects the corresponding value in a dot product. -/ +theorem dotProduct_eq_of_oneHot (k : Fin seq) (vals : Fin seq → Val) : + dotProduct (Pi.single k 1) vals = vals k := by + simp + +/-- Induction weights select the `prev` value in each dot product. -/ +theorem dotProduct_eq_prev (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) (vals : Fin seq → Fin seq → Val) + (hweights : InductionWeights (Val := Val) prev weights) (q : Fin seq) : + dotProduct (weights q) (vals q) = vals q (prev q) := by + have hq : weights q = Pi.single (prev q) 1 := hweights q + simp [hq] + +end Weights + +section Spec + +variable {Val : Type v} +variable {n : Nat} + +/-- Induction-head spec: for nonzero queries, outputs copy `prev` values. -/ +def InductionSpec (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, q ≠ 0 → out q = vals (prev q) + +/-- Concrete `prev` map on `Fin (n + 1)` (with `0 ↦ 0`). -/ +def prevIndex : Fin (Nat.succ n) → Fin (Nat.succ n) + | ⟨0, _⟩ => 0 + | ⟨Nat.succ k, hk⟩ => + ⟨k, Nat.lt_trans (Nat.lt_of_succ_lt_succ hk) (Nat.lt_succ_self n)⟩ + +end Spec + +section ApproxSpec + +variable {Val : Type v} [AddCommMonoid Val] [PartialOrder Val] [IsOrderedAddMonoid Val] +variable {n : Nat} + +/-- Approximate induction-head spec: outputs are within `ε` of `prev` values. 
-/ +def InductionSpecApprox (ε : Val) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, q ≠ 0 → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε + +/-- Approximate induction-head spec restricted to active queries. -/ +def InductionSpecApproxOn (ε : Val) (active : Fin (Nat.succ n) → Prop) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, active q → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε + +/-- Exact induction spec implies the approximate spec for any nonnegative tolerance. -/ +theorem inductionSpecApprox_of_spec (ε : Val) (hε : 0 ≤ ε) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) + (h : InductionSpec prev out vals) : + InductionSpecApprox (Val := Val) (n := n) ε prev out vals := by + intro q hq + have hq' : out q = vals (prev q) := h q hq + constructor <;> + simpa [hq'] using + (le_add_of_nonneg_right hε : + vals (prev q) ≤ vals (prev q) + ε) + +end ApproxSpec + +section ValueRange + +variable {Val : Type v} [PartialOrder Val] +variable {seq : Nat} + +/-- Value-range bounds for a vector of attention values. -/ +structure ValueRangeBounds (lo hi : Val) (vals : Fin seq → Val) : Prop where + /-- Lower and upper bounds are ordered. -/ + lo_le_hi : lo ≤ hi + /-- All values are at least `lo`. -/ + lo_le : ∀ k, lo ≤ vals k + /-- All values are at most `hi`. -/ + le_hi : ∀ k, vals k ≤ hi + +end ValueRange + +section Bounds + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Numeric bounds certifying one-hot weights on nonzero queries. -/ +structure OneHotBoundsOn (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on nonzero queries. -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on nonzero queries. 
-/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- Non-prev weights are nonpositive on nonzero queries. -/ + other_le_zero : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ 0 + +/-- Certified bounds imply one-hot weights on nonzero queries. -/ +theorem oneHot_of_boundsOn (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) [DecidableEq (Fin seq)] + (h : OneHotBoundsOn prev weights) : + ∀ q, q ≠ 0 → weights q = Pi.single (prev q) 1 := by + intro q hq + funext k + by_cases hk : k = prev q + · subst hk + have hzero : + (∑ k ∈ (Finset.univ.erase (prev q)), weights q k) = 0 := by + refine Finset.sum_eq_zero ?_ + intro k hk' + have hkne : k ≠ prev q := (Finset.mem_erase.1 hk').1 + have hle : weights q k ≤ 0 := h.other_le_zero q hq k hkne + have hge : 0 ≤ weights q k := h.nonneg q hq k + exact le_antisymm hle hge + have hsum : + weights q (prev q) + + ∑ k ∈ (Finset.univ.erase (prev q)), weights q k = + ∑ k, weights q k := by + simpa using + (Finset.add_sum_erase + (s := (Finset.univ : Finset (Fin seq))) + (f := weights q) (a := prev q) (by simp)) + have hprev : weights q (prev q) = 1 := by + have hsum' : + weights q (prev q) + 0 = 1 := by + simpa [hzero, h.sum_one q hq] using hsum + simpa using hsum' + simp [Pi.single, hprev] + · have hle : weights q k ≤ 0 := h.other_le_zero q hq k hk + have hge : 0 ≤ weights q k := h.nonneg q hq k + have hzero : weights q k = 0 := le_antisymm hle hge + simp [Pi.single, hk, hzero] + +end Bounds + +section ApproxBounds + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Approximate one-hot bounds for attention weights on nonzero queries. -/ +structure OneHotApproxBoundsOn (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on nonzero queries. -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on nonzero queries. 
-/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on nonzero queries. -/ + prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on nonzero queries. -/ + other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Approximate one-hot bounds for attention weights on active queries. -/ +structure OneHotApproxBoundsOnActive (ε : Val) (active : Fin seq → Prop) + (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on active queries. -/ + nonneg : ∀ q, active q → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on active queries. -/ + sum_one : ∀ q, active q → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on active queries. -/ + prev_large : ∀ q, active q → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on active queries. -/ + other_le : ∀ q, active q → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Lift global approximate bounds to an active-set version. -/ +theorem oneHotApproxBoundsOnActive_of_on (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) + (h : OneHotApproxBoundsOn (Val := Val) ε prev weights) : + OneHotApproxBoundsOnActive (Val := Val) ε (fun q => q ≠ 0) prev weights := by + refine { nonneg := ?_, sum_one := ?_, prev_large := ?_, other_le := ?_ } + · intro q hq k + exact h.nonneg q hq k + · intro q hq + exact h.sum_one q hq + · intro q hq + exact h.prev_large q hq + · intro q hq k hk + exact h.other_le q hq k hk + +/-- Approximate induction weights: prev weight near one, others at most `ε`. -/ +def InductionWeightsApprox (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop := + ∀ q, q ≠ 0 → + 1 ≤ weights q (prev q) + ε ∧ + ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Approximate bounds imply approximate induction weights. 
-/ +theorem inductionWeightsApprox_of_boundsOn (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) + (h : OneHotApproxBoundsOn ε prev weights) : + InductionWeightsApprox (Val := Val) ε prev weights := by + intro q hq + exact ⟨h.prev_large q hq, h.other_le q hq⟩ + +end ApproxBounds + +section ApproxOutput + +variable {Val : Type v} [Ring Val] [LinearOrder Val] [IsOrderedRing Val] +variable {n : Nat} + +local instance : NeZero (Nat.succ n) := ⟨by simp⟩ + +/-- Approximate one-hot weights plus bounded values yield an approximate induction spec + on active queries. -/ +theorem inductionSpecApproxOn_of_oneHotApprox_valueRange + (ε lo hi : Val) (active : Fin (Nat.succ n) → Prop) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) + (vals : Fin (Nat.succ n) → Val) + (hweights : OneHotApproxBoundsOnActive (Val := Val) ε active prev weights) + (hvals : ValueRangeBounds (Val := Val) lo hi vals) : + InductionSpecApproxOn (Val := Val) (n := n) (ε * (hi - lo)) active prev + (fun q => dotProduct (weights q) vals) vals := by + classical + intro q hq + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (prev q) + have hsum_decomp : + weights q (prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (prev q) + ∑ k ∈ others, weights q k = 1 := by + simpa [hweights.sum_one q hq] using hsum_decomp + have hsum_others_le : (∑ k ∈ others, weights q k) ≤ ε := by + have hprev : 1 ≤ weights q (prev q) + ε := hweights.prev_large q hq + have hprev' : + weights q (prev q) + ∑ k ∈ others, weights q k ≤ weights q (prev q) + ε := by + simpa [hsum] using hprev + exact (add_le_add_iff_left (weights q (prev q))).1 hprev' + have hsum_others_nonneg : 0 ≤ ∑ k ∈ others, weights q k := by + refine Finset.sum_nonneg ?_ + intro k hk + exact hweights.nonneg q hq k + have hvals_hi : ∀ k, vals k ≤ hi := hvals.le_hi + have hvals_lo : ∀ k, lo ≤ vals k 
:= hvals.lo_le + have hdiff_nonneg : 0 ≤ hi - lo := sub_nonneg.mpr hvals.lo_le_hi + have hsum_vals_le : + (∑ k ∈ others, weights q k * vals k) ≤ (∑ k ∈ others, weights q k) * hi := by + have hle : ∀ k ∈ others, weights q k * vals k ≤ weights q k * hi := by + intro k hk + have hval : vals k ≤ hi := hvals_hi k + have hnonneg : 0 ≤ weights q k := hweights.nonneg q hq k + exact mul_le_mul_of_nonneg_left hval hnonneg + calc + ∑ k ∈ others, weights q k * vals k + ≤ ∑ k ∈ others, weights q k * hi := Finset.sum_le_sum hle + _ = (∑ k ∈ others, weights q k) * hi := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := hi)).symm + have hsum_vals_ge : + (∑ k ∈ others, weights q k) * lo ≤ (∑ k ∈ others, weights q k * vals k) := by + have hle : ∀ k ∈ others, weights q k * lo ≤ weights q k * vals k := by + intro k hk + have hval : lo ≤ vals k := hvals_lo k + have hnonneg : 0 ≤ weights q k := hweights.nonneg q hq k + exact mul_le_mul_of_nonneg_left hval hnonneg + calc + (∑ k ∈ others, weights q k) * lo + = ∑ k ∈ others, weights q k * lo := by + exact + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := lo)) + _ ≤ ∑ k ∈ others, weights q k * vals k := Finset.sum_le_sum hle + have hsum_prod : + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k = + ∑ k, weights q k * vals k := by + simp [others] + have hout_eq : + dotProduct (weights q) vals = + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := by + simpa [dotProduct] using hsum_prod.symm + have hsum_val_prev : + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) = + vals (prev q) := by + calc + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) = + (weights q (prev q) + ∑ k ∈ others, weights q k) * vals (prev q) := by + simpa using + (add_mul (weights q (prev q)) (∑ k ∈ others, weights q k) (vals (prev q))).symm + _ = 1 * vals (prev q) := by + simp [hsum] + _ = vals (prev q) := by 
simp + have hsplit : + (∑ k ∈ others, weights q k) * hi = + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + (∑ k ∈ others, weights q k) * hi = + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * hi - + (∑ k ∈ others, weights q k) * lo := by + exact + (add_sub_cancel_left + ((∑ k ∈ others, weights q k) * lo) ((∑ k ∈ others, weights q k) * hi)).symm + _ = (∑ k ∈ others, weights q k) * lo + + ((∑ k ∈ others, weights q k) * hi - + (∑ k ∈ others, weights q k) * lo) := by + simp [sub_eq_add_neg, add_assoc] + _ = (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + simp [mul_sub] + have hsum_prev_le : + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo ≤ + vals (prev q) := by + have hmul : (∑ k ∈ others, weights q k) * lo ≤ + (∑ k ∈ others, weights q k) * vals (prev q) := + mul_le_mul_of_nonneg_left (hvals_lo (prev q)) hsum_others_nonneg + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo + ≤ weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) := by + have h := + add_le_add_left hmul (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ = vals (prev q) := hsum_val_prev + have hupper_mid : + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi ≤ + vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi = + weights q (prev q) * vals (prev q) + + ((∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo)) := by + simp [hsplit] + _ = weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + simp [add_assoc] + _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by + have h := + add_le_add_right hsum_prev_le + ((∑ k ∈ others, weights q k) * (hi - lo)) + simpa [add_comm, 
add_left_comm, add_assoc] using h + have hupper : + dotProduct (weights q) vals ≤ vals (prev q) + ε * (hi - lo) := by + have hmul : + (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + calc + dotProduct (weights q) vals = + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := hout_eq + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have h := + add_le_add_left hsum_vals_le (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := hupper_mid + _ ≤ vals (prev q) + ε * (hi - lo) := by + have h := + add_le_add_left hmul (vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + have hprev_le : + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have hmul : (∑ k ∈ others, weights q k) * vals (prev q) ≤ + (∑ k ∈ others, weights q k) * hi := + mul_le_mul_of_nonneg_left (hvals_hi (prev q)) hsum_others_nonneg + have hmul' : + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have h := + add_le_add_left hmul (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + calc + vals (prev q) = + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) := by + simpa using hsum_val_prev.symm + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := hmul' + have hprev_le' : + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := hprev_le + _ = + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ 
others, weights q k) * (hi - lo) := by + simp [hsplit, add_assoc] + have hsub : + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo := by + exact (sub_le_iff_le_add).2 hprev_le' + have hlowershift : + vals (prev q) - ε * (hi - lo) ≤ + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := by + have hmul : + (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + exact sub_le_sub_left hmul (vals (prev q)) + have hlow : + vals (prev q) - ε * (hi - lo) ≤ dotProduct (weights q) vals := by + calc + vals (prev q) - ε * (hi - lo) ≤ + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := hlowershift + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo := hsub + _ ≤ dotProduct (weights q) vals := by + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo + ≤ weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := by + have h := + add_le_add_left hsum_vals_ge (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ = dotProduct (weights q) vals := by + simp [hout_eq] + have hlower : + vals (prev q) ≤ dotProduct (weights q) vals + ε * (hi - lo) := by + exact (sub_le_iff_le_add).1 hlow + exact ⟨hupper, hlower⟩ + +/-- Approximate one-hot weights plus bounded values yield an approximate induction spec. 
-/ +theorem inductionSpecApprox_of_oneHotApprox_valueRange + (ε lo hi : Val) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) + (vals : Fin (Nat.succ n) → Val) + (hweights : OneHotApproxBoundsOn (Val := Val) ε prev weights) + (hvals : ValueRangeBounds (Val := Val) lo hi vals) : + InductionSpecApprox (Val := Val) (n := n) (ε * (hi - lo)) prev + (fun q => dotProduct (weights q) vals) vals := by + have hweights' : + OneHotApproxBoundsOnActive (Val := Val) ε (fun q => q ≠ 0) prev weights := + oneHotApproxBoundsOnActive_of_on (Val := Val) (seq := Nat.succ n) + (ε := ε) (prev := prev) (weights := weights) hweights + exact + inductionSpecApproxOn_of_oneHotApprox_valueRange + (Val := Val) + (n := n) + (ε := ε) + (lo := lo) + (hi := hi) + (active := fun q => q ≠ 0) + (prev := prev) + (weights := weights) + (vals := vals) + (hweights := hweights') + (hvals := hvals) + +end ApproxOutput + +section SoftmaxMargin + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Softmax margin certificates for approximate one-hot weights. -/ +structure SoftmaxMarginBounds (ε margin : Val) (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) : Prop where + /-- Score gap between `prev` and other keys on nonzero queries. -/ + score_margin : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) + /-- All weights are nonnegative on nonzero queries. -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on nonzero queries. -/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on nonzero queries. -/ + prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on nonzero queries. -/ + other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Softmax margin certificates for approximate one-hot weights on active queries. 
-/ +structure SoftmaxMarginBoundsOn (ε margin : Val) (active : Fin seq → Prop) + (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) : Prop where + /-- Score gap between `prev` and other keys on active queries. -/ + score_margin : ∀ q, active q → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) + /-- All weights are nonnegative on active queries. -/ + nonneg : ∀ q, active q → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on active queries. -/ + sum_one : ∀ q, active q → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on active queries. -/ + prev_large : ∀ q, active q → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on active queries. -/ + other_le : ∀ q, active q → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Lift global softmax-margin bounds to an active-set version. -/ +theorem softmaxMarginBoundsOn_of_on (ε margin : Val) (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : + SoftmaxMarginBoundsOn (Val := Val) ε margin (fun q => q ≠ 0) prev scores weights := by + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + exact h.score_margin q hq k hk + · intro q hq k + exact h.nonneg q hq k + · intro q hq + exact h.sum_one q hq + · intro q hq + exact h.prev_large q hq + · intro q hq k hk + exact h.other_le q hq k hk + +/-- Margin certificates yield approximate one-hot bounds for the weights. -/ +theorem oneHotApproxBounds_of_softmaxMargin (ε margin : Val) (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : + OneHotApproxBoundsOn (Val := Val) ε prev weights := by + exact + { nonneg := h.nonneg + sum_one := h.sum_one + prev_large := h.prev_large + other_le := h.other_le } + +/-- Margin certificates imply approximate induction-weight bounds. 
-/ +theorem inductionWeightsApprox_of_softmaxMargin (ε margin : Val) + (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : + InductionWeightsApprox (Val := Val) ε prev weights := by + exact inductionWeightsApprox_of_boundsOn + (Val := Val) + (seq := seq) + (ε := ε) + (prev := prev) + (weights := weights) + (h := oneHotApproxBounds_of_softmaxMargin + (Val := Val) + (seq := seq) + (ε := ε) + (margin := margin) + (prev := prev) + (scores := scores) + (weights := weights) + h) + +end SoftmaxMargin + +section SoftmaxMarginActive + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} + +/-- Margin certificates yield approximate one-hot bounds on active queries. -/ +theorem oneHotApproxBoundsOnActive_of_softmaxMargin (ε margin : Val) + (active : Fin seq → Prop) + (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBoundsOn (Val := Val) ε margin active prev scores weights) : + OneHotApproxBoundsOnActive (Val := Val) ε active prev weights := by + exact + { nonneg := h.nonneg + sum_one := h.sum_one + prev_large := h.prev_large + other_le := h.other_le } + +end SoftmaxMarginActive + +section Attention + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {seq heads dim : Nat} +variable {Val : Type v} [NonAssocSemiring Val] + +/-- Typed V-input label for attention cores. -/ +abbrev attnInputV (v : QkvIndex Batch seq heads dim) : + AttentionInput Batch seq heads dim := + Sum.inr (Sum.inr v) + +/-- Weight function feeding an attention output node. 
-/ +def attentionOutWeights (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) : + Fin seq → Val := + fun k => + rec (attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) + (attentionDag_rel_weight_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q k d) + +/-- Value function feeding an attention output node. -/ +def attentionOutValues (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) : + Fin seq → Val := + fun k => + rec (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d)) + (attentionDag_rel_v_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b k h d q) + +/-- One-hot attention weights force the output to copy the selected value. 
-/ +theorem attentionGate_out_eq_of_oneHot (scale : Val) + (softmax : (Fin seq → Val) → Fin seq → Val) (prev : Fin seq → Fin seq) + (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) + (hweights : + attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec = + Pi.single (prev q) 1) : + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) rec = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec (prev q) := by + simp only [attentionGate] + change + dotProduct + (attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) + (attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec (prev q) + rw [hweights] + exact dotProduct_eq_of_oneHot (Val := Val) (seq := seq) (k := prev q) + (vals := attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) + +section Typed + +variable (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) + +/-- Attention output equals the selected V input when weights are one-hot. 
-/ +theorem attentionTyped_eval_out_eq_of_oneHot (prev : Fin seq → Fin seq) + (input : AttentionInput Batch seq heads dim → Val) + (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (hweights : + attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax) + ((attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval + input (b, q, h, d) = + input + (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := by + let C := + attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + let I := + attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + let inputAssign := I.toInputAssignment input + have hnot : + attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d) ∉ + attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + simpa using + (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (s := Sum.inr (Sum.inr (b, q, h, d)))) + have hgate : + Circuit.evalInput C inputAssign + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) = + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + (fun j _ => Circuit.evalInput C inputAssign j) := by + exact Circuit.evalInput_eq_gate (C := C) (input := inputAssign) + (i := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + hnot + have hcopy : + Circuit.evalInput C inputAssign + (attnOut (Batch := 
Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := by + have hgate' : + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + (fun j _ => Circuit.evalInput C inputAssign j) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := + attentionGate_out_eq_of_oneHot (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (scale := scale) (softmax := softmax) (prev := prev) (b := b) (h := h) (q := q) (d := d) + (rec := fun j _ => Circuit.evalInput C inputAssign j) hweights + exact hgate.trans hgate' + have hmem : + attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d) ∈ + attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + refine (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim)).2 ?_ + exact ⟨attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d), rfl⟩ + have hinput : + Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d)) = + input (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := by + have h := + Circuit.evalInput_eq_input (C := C) (input := inputAssign) + (i := attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) hmem + simpa [inputAssign, I, attentionInterface, attnInputV] using h + have hvals : + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) = + Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := 
seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := rfl + calc + (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval + input (b, q, h, d) = + Circuit.evalInput C inputAssign (I.outputs (b, q, h, d)).1 := by + simp [TypedCircuit.eval, Interface.eval, C, I, inputAssign, attentionTyped] + _ = Circuit.evalInput C inputAssign + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by + rfl + _ = attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := hcopy + _ = Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := hvals + _ = input (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := hinput + +end Typed + +end Attention + +section InductionSpecTyped + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {heads dim n : Nat} +variable {Val : Type v} [NonAssocSemiring Val] + +variable (scale : Val) +variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) + +/-- One-hot weights on nonzero queries imply the induction spec for typed evaluation. 
-/ +theorem attentionTyped_eval_inductionSpec_of_oneHot + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + InductionSpec (n := n) prev + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + intro q hq + have hweights_q := hweights q hq + exact attentionTyped_eval_out_eq_of_oneHot + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (scale := scale) + (softmax := softmax) + (prev := prev) + (input := input) + (b := b) + (h := h) + (q := q) + (d := d) + hweights_q + +/-- Induction spec for `prevIndex` under one-hot weight hypotheses. 
-/ +theorem attentionTyped_eval_inductionSpec_prevIndex + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prevIndex (n := n) q) 1) : + InductionSpec (n := n) (prevIndex (n := n)) + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + exact attentionTyped_eval_inductionSpec_of_oneHot + (Batch := Batch) + (heads := heads) + (dim := dim) + (n := n) + (scale := scale) + (softmax := softmax) + (prev := prevIndex (n := n)) + (input := input) + (b := b) + (h := h) + (d := d) + hweights + +end InductionSpecTyped + +section InductionSpecApproxTyped + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {heads dim n : Nat} +variable {Val : Type v} [NonAssocSemiring Val] [PartialOrder Val] [IsOrderedAddMonoid Val] + +variable (scale : Val) +variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) + +/-- One-hot weights imply the approximate induction spec for any nonnegative tolerance. 
-/ +theorem attentionTyped_eval_inductionSpecApprox_of_oneHot (ε : Val) (hε : 0 ≤ ε) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + InductionSpecApprox (Val := Val) (n := n) ε prev + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + apply inductionSpecApprox_of_spec (Val := Val) (n := n) (ε := ε) hε + exact attentionTyped_eval_inductionSpec_of_oneHot + (Batch := Batch) + (heads := heads) + (dim := dim) + (n := n) + (scale := scale) + (softmax := softmax) + (prev := prev) + (input := input) + (b := b) + (h := h) + (d := d) + hweights + +end InductionSpecApproxTyped + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/IO/Run.lean b/Nfp/IO/Run.lean index 9e9a598..2ea7c46 100644 --- a/Nfp/IO/Run.lean +++ b/Nfp/IO/Run.lean @@ -1,805 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Checks -import Nfp.IO.Derive -import Nfp.IO.HeadScore -import Nfp.IO.InductionHead -import Nfp.IO.Loaders -import Nfp.IO.NfptPure -import Nfp.IO.Timing -import Nfp.IO.Util -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Circuit.Cert.ResidualBound -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Sound.Bounds.MatrixNorm -import 
Nfp.Sound.Bounds.Transformer -import Nfp.Sound.Induction -import Nfp.Sound.Induction.HeadBounds -import Nfp.Sound.Induction.LogitDiff -import Nfp.Sound.Linear.FinFold +import Nfp.IO.Run.Basic /-! IO entrypoints used by the CLI. -/ - -namespace Nfp -namespace IO -open Nfp.Circuit - -/-- Check induction certificates and print a short status line. -/ -def runInductionCertify (scoresPath : System.FilePath) - (valuesPath? : Option System.FilePath) (minActive? : Option Nat) - (minLogitDiffStr? : Option String) (minMarginStr? : Option String) - (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - if minLogitDiff?.isSome && valuesPath?.isNone then - IO.eprintln "error: min-logit-diff requires --values" - return 2 - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - match valuesPath? with - | none => - IO.println - s!"ok: softmax-margin certificate accepted \ - (seq={seq}, active={activeCount})" - return 0 - | some valuesPath => - let parsedValues ← loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let tol := cert.eps * (certVals'.hi - certVals'.lo) - let logitDiffLB? 
:= - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, tol={tol}, \ - logitDiffLB={logitDiffLB})" - return 0 -/-- Build and check induction certificates from raw scores/values. -/ -def runInductionCertifySound (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (minActive? : Option Nat) - (minLogitDiffStr? : Option String) (minMarginStr? : Option String) - (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← loadSoftmaxMarginRaw scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, raw⟩ => - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildSoftmaxMarginCert? raw.active raw.prev raw.scores raw.weights with - | none => - IO.eprintln "error: softmax-margin inputs rejected" - return 2 - | some ⟨cert, _⟩ => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← loadValueRangeRaw valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, rawVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln - s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let rawVals' : Pure.ValueRangeRaw seq := by - simpa [hseq'] using rawVals - match Sound.buildValueRangeCert? rawVals'.vals rawVals'.direction with - | none => - IO.eprintln "error: value-range inputs rejected" - return 2 - | some ⟨certVals, _⟩ => - let tol := cert.eps * (certVals.hi - certVals.lo) - let logitDiffLB? 
:= - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals.lo certVals.hi certVals.vals - let effectiveMinLogitDiff := - match minLogitDiff?, certVals.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return 2 - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={tol}, logitDiffLB={logitDiffLB})" - return 0 -/-- Check end-to-end induction certificates with a downstream error bound. -/ -def runInductionCertifyEndToEnd (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (downstreamPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let logitDiffLB? 
← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let parsedDownstream ← loadDownstreamLinearCert downstreamPath - match parsedDownstream with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok downstream => - let downstreamOk := Circuit.checkDownstreamLinearCert downstream - if downstreamOk then - let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if finalLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end logitDiffLB {finalLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: end-to-end induction bound certified \ - (seq={seq}, active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstream.error}, \ - finalLB={finalLB})" - return 0 - else - IO.eprintln "error: downstream certificate rejected" - return 2 -/-- Check end-to-end induction certificates with a downstream matrix. -/ -def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (matrixPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
- match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let 
logitDiffLB? ← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let parsedMatrix ← loadDownstreamMatrixRaw matrixPath - match parsedMatrix with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨rows, ⟨cols, raw⟩⟩ => - let inputBound := raw.inputBound - if hneg : inputBound < 0 then - IO.eprintln - s!"error: input-bound {inputBound} must be nonnegative" - return 2 - else - have hinput : 0 ≤ inputBound := by - exact le_of_not_gt hneg - let W : Matrix (Fin rows) (Fin cols) Rat := raw.entries - let downstream := - (Sound.Bounds.buildDownstreamLinearCert W inputBound hinput).1 - let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if finalLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end logitDiffLB {finalLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: end-to-end induction bound certified \ - (seq={seq}, active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstream.error}, \ - finalLB={finalLB})" - return 0 -/-- Check end-to-end induction certificates using a model file and residual bounds - (loaded from disk or derived from the model). -/ -def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (modelPath : System.FilePath) - (residualIntervalPath? : Option System.FilePath) - (layer? : Option Nat) (head? : Option Nat) (period? 
: Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - 
else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let logitDiffLB? ← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - match certVals'.direction with - | none => - IO.eprintln - "error: value-range certificate missing direction \ - metadata" - return 2 - | some dirSpec => - let data ← timePhase "read model file" <| - IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - if hseq : header.seqLen = seq then - let active? : Option (Finset (Fin header.seqLen)) := - if hactive : cert.active.Nonempty then - some (by simpa [hseq] using cert.active) - else - none - let residualCertE : Except String - (ResidualIntervalCert header.modelDim) ← - match residualIntervalPath? 
with - | some residualIntervalPath => do - let parsedResidual ← - timePhase "load residual interval" <| - loadResidualIntervalCert residualIntervalPath - match parsedResidual with - | Except.error msg => pure (Except.error msg) - | Except.ok ⟨dim, residualCert⟩ => - if hdim : dim = header.modelDim then - let residualCert' : - ResidualIntervalCert header.modelDim := by - simpa [hdim] using residualCert - pure (Except.ok residualCert') - else - pure (Except.error - s!"residual interval dim {dim} \ - does not match model dim {header.modelDim}") - | none => - deriveResidualIntervalFromModel data start header - active? - match residualCertE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok residualCert' => - let residualOk ← - timePure "check residual interval" (fun () => - Circuit.checkResidualIntervalCert residualCert') - if residualOk then - let dirPos := dirSpec.target - let dirNeg := dirSpec.negative - if layer?.isSome != head?.isSome then - IO.eprintln - "error: --layer and --head must be provided \ - together" - return 2 - let headChoice? : Option (Nat × Nat) := - match layer?, head? with - | some layer, some head => some (layer, head) - | _, _ => none - if period?.isSome && headChoice?.isNone then - IO.eprintln - "warning: --period ignored without \ - --layer/--head" - let colTargetE ← - timePure "read unembed column target" (fun () => - NfptPure.readUnembedColumn - data start header dirPos) - match colTargetE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok colTarget => - let colNegE ← - timePure "read unembed column negative" (fun () => - NfptPure.readUnembedColumn - data start header dirNeg) - match colNegE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok colNeg => - let dirVec : - Fin header.modelDim → Rat := - fun i => colTarget i - colNeg i - let dotIntervalAbs := - Sound.Bounds.dotIntervalAbsBound - let intervalErrorFromHead? 
: - Model.InductionHeadInputs - seq header.modelDim header.headDim → - ResidualIntervalCert header.modelDim → - Option Rat := - fun inputs residual => by - classical - match hseq0 : seq with - | 0 => exact none - | Nat.succ n => - let _ : NeZero seq := by - exact ⟨by simp [hseq0]⟩ - match - Sound.buildHeadOutputIntervalFromHead? - inputs with - | none => exact none - | some result => - exact some - (dotIntervalAbs - dirVec - (fun i => - residual.lo i - - result.cert.hi i) - (fun i => - residual.hi i - - result.cert.lo i)) - let downstreamError ← - timePure "downstream error" (fun () => - dotIntervalAbs - dirVec - residualCert'.lo - residualCert'.hi) - let finalLB := logitDiffLB - downstreamError - let intervalError? ← - match headChoice? with - | none => pure none - | some (layer, head) => do - let inputsE ← - timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head - dirPos dirNeg period?) - match inputsE with - | Except.error msg => - IO.eprintln s!"warning: {msg}" - pure none - | Except.ok inputs => - let inputs' : - Model.InductionHeadInputs - seq header.modelDim - header.headDim := by - simpa [hseq] using inputs - let inputsAligned : - Model.InductionHeadInputs - seq header.modelDim - header.headDim := - { inputs' with - active := cert.active - prev := cert.prev } - let intervalError? ← - timePure - "head output interval" - (fun () => - intervalErrorFromHead? - inputsAligned - residualCert') - match intervalError? with - | none => - IO.eprintln - "warning: head output interval \ - rejected" - pure none - | some intervalError => - pure (some intervalError) - let intervalLB? := - intervalError?.map (fun err => - logitDiffLB - err) - let effectiveLB := - match intervalLB? with - | some intervalLB => max finalLB intervalLB - | none => finalLB - let violation? 
: Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if effectiveLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end bound \ - {effectiveLB} below minimum \ - {minLogitDiff}" - return (2 : UInt32) - | none => - match intervalLB? with - | none => - IO.println - s!"ok: end-to-end induction \ - bound certified (seq={seq}, \ - active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstreamError}, \ - finalLB={finalLB})" - | some intervalLB => - let intervalError := - logitDiffLB - intervalLB - IO.println - s!"ok: end-to-end induction \ - bound certified (seq={seq}, \ - active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstreamError}, \ - finalLB={finalLB}, \ - intervalError={intervalError}, \ - intervalLB={intervalLB}, \ - effectiveLB={effectiveLB})" - return 0 - else - IO.eprintln - "error: residual-interval certificate rejected" - return 2 - else - IO.eprintln - s!"error: model seq {header.seqLen} \ - does not match cert seq {seq}" - return 2 -end IO -end Nfp diff --git a/Nfp/IO/Run/Basic.lean b/Nfp/IO/Run/Basic.lean new file mode 100644 index 0000000..9e9a598 --- /dev/null +++ b/Nfp/IO/Run/Basic.lean @@ -0,0 +1,805 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +import Nfp.IO.Checks +import Nfp.IO.Derive +import Nfp.IO.HeadScore +import Nfp.IO.InductionHead +import Nfp.IO.Loaders +import Nfp.IO.NfptPure +import Nfp.IO.Timing +import Nfp.IO.Util +import Nfp.Circuit.Cert.DownstreamLinear +import Nfp.Circuit.Cert.LogitDiff +import Nfp.Circuit.Cert.ResidualBound +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Bounds.Transformer +import Nfp.Sound.Induction +import Nfp.Sound.Induction.HeadBounds +import Nfp.Sound.Induction.LogitDiff +import Nfp.Sound.Linear.FinFold + +/-! +IO entrypoints used by the CLI. 
+-/ + +namespace Nfp +namespace IO +open Nfp.Circuit + +/-- Check induction certificates and print a short status line. -/ +def runInductionCertify (scoresPath : System.FilePath) + (valuesPath? : Option System.FilePath) (minActive? : Option Nat) + (minLogitDiffStr? : Option String) (minMarginStr? : Option String) + (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + if minLogitDiff?.isSome && valuesPath?.isNone then + IO.eprintln "error: min-logit-diff requires --values" + return 2 + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + 
match valuesPath? with + | none => + IO.println + s!"ok: softmax-margin certificate accepted \ + (seq={seq}, active={activeCount})" + return 0 + | some valuesPath => + let parsedValues ← loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let tol := cert.eps * (certVals'.hi - certVals'.lo) + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, tol={tol}, \ + logitDiffLB={logitDiffLB})" + return 0 +/-- Build and check induction certificates from raw scores/values. -/ +def runInductionCertifySound (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (minActive? : Option Nat) + (minLogitDiffStr? : Option String) (minMarginStr? 
: Option String) + (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedScores ← loadSoftmaxMarginRaw scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, raw⟩ => + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + match Sound.buildSoftmaxMarginCert? 
raw.active raw.prev raw.scores raw.weights with + | none => + IO.eprintln "error: softmax-margin inputs rejected" + return 2 + | some ⟨cert, _⟩ => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← loadValueRangeRaw valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, rawVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln + s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let rawVals' : Pure.ValueRangeRaw seq := by + simpa [hseq'] using rawVals + match Sound.buildValueRangeCert? rawVals'.vals rawVals'.direction with + | none => + IO.eprintln "error: value-range inputs rejected" + return 2 + | some ⟨certVals, _⟩ => + let tol := cert.eps * (certVals.hi - certVals.lo) + let logitDiffLB? := + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals.lo certVals.hi certVals.vals + let effectiveMinLogitDiff := + match minLogitDiff?, certVals.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? 
with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {logitDiffLB} \ + below minimum {minLogitDiff}" + return 2 + | none => + IO.println + s!"ok: induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={tol}, logitDiffLB={logitDiffLB})" + return 0 +/-- Check end-to-end induction certificates with a downstream error bound. -/ +def runInductionCertifyEndToEnd (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (downstreamPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let logitDiffLB? 
← timePure "logit-diff lower bound" (fun () => + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals) + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let parsedDownstream ← loadDownstreamLinearCert downstreamPath + match parsedDownstream with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok downstream => + let downstreamOk := Circuit.checkDownstreamLinearCert downstream + if downstreamOk then + let finalLB := logitDiffLB - downstream.error + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if finalLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end logitDiffLB {finalLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: end-to-end induction bound certified \ + (seq={seq}, active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstream.error}, \ + finalLB={finalLB})" + return 0 + else + IO.eprintln "error: downstream certificate rejected" + return 2 +/-- Check end-to-end induction certificates with a downstream matrix. -/ +def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (matrixPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let 
logitDiffLB? ← timePure "logit-diff lower bound" (fun () => + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals) + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + let parsedMatrix ← loadDownstreamMatrixRaw matrixPath + match parsedMatrix with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨rows, ⟨cols, raw⟩⟩ => + let inputBound := raw.inputBound + if hneg : inputBound < 0 then + IO.eprintln + s!"error: input-bound {inputBound} must be nonnegative" + return 2 + else + have hinput : 0 ≤ inputBound := by + exact le_of_not_gt hneg + let W : Matrix (Fin rows) (Fin cols) Rat := raw.entries + let downstream := + (Sound.Bounds.buildDownstreamLinearCert W inputBound hinput).1 + let finalLB := logitDiffLB - downstream.error + let violation? : Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if finalLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end logitDiffLB {finalLB} \ + below minimum {minLogitDiff}" + return (2 : UInt32) + | none => + IO.println + s!"ok: end-to-end induction bound certified \ + (seq={seq}, active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstream.error}, \ + finalLB={finalLB})" + return 0 +/-- Check end-to-end induction certificates using a model file and residual bounds + (loaded from disk or derived from the model). -/ +def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) + (valuesPath : System.FilePath) (modelPath : System.FilePath) + (residualIntervalPath? : Option System.FilePath) + (layer? : Option Nat) (head? : Option Nat) (period? 
: Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedScores ← timePhase "load softmax cert" <| + loadSoftmaxMarginCert scoresPath + match parsedScores with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + let scoresOk ← timePhase "check softmax cert" <| + checkSoftmaxMargin seq cert + match scoresOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {cert.margin} below minimum {minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {cert.eps} above maximum {maxEps}" + return 2 + let parsedValues ← timePhase "load value cert" <| + loadValueRangeCert valuesPath + match parsedValues with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seqVals, certVals⟩ => + if hseq : seqVals ≠ seq then + IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" + return 2 + 
else + have hseq' : seqVals = seq := by + exact (not_ne_iff).1 hseq + let certVals' : ValueRangeCert seq := by + simpa [hseq'] using certVals + let valuesOk ← timePhase "check value cert" <| + checkValueRange seq certVals' + match valuesOk with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok () => + let logitDiffLB? ← timePure "logit-diff lower bound" (fun () => + Circuit.logitDiffLowerBound cert.active cert.prev cert.eps + certVals'.lo certVals'.hi certVals'.vals) + let effectiveMinLogitDiff := + match minLogitDiff?, certVals'.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match logitDiffLB? with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return (2 : UInt32) + | some logitDiffLB => + match certVals'.direction with + | none => + IO.eprintln + "error: value-range certificate missing direction \ + metadata" + return 2 + | some dirSpec => + let data ← timePhase "read model file" <| + IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + if hseq : header.seqLen = seq then + let active? : Option (Finset (Fin header.seqLen)) := + if hactive : cert.active.Nonempty then + some (by simpa [hseq] using cert.active) + else + none + let residualCertE : Except String + (ResidualIntervalCert header.modelDim) ← + match residualIntervalPath? 
with + | some residualIntervalPath => do + let parsedResidual ← + timePhase "load residual interval" <| + loadResidualIntervalCert residualIntervalPath + match parsedResidual with + | Except.error msg => pure (Except.error msg) + | Except.ok ⟨dim, residualCert⟩ => + if hdim : dim = header.modelDim then + let residualCert' : + ResidualIntervalCert header.modelDim := by + simpa [hdim] using residualCert + pure (Except.ok residualCert') + else + pure (Except.error + s!"residual interval dim {dim} \ + does not match model dim {header.modelDim}") + | none => + deriveResidualIntervalFromModel data start header + active? + match residualCertE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok residualCert' => + let residualOk ← + timePure "check residual interval" (fun () => + Circuit.checkResidualIntervalCert residualCert') + if residualOk then + let dirPos := dirSpec.target + let dirNeg := dirSpec.negative + if layer?.isSome != head?.isSome then + IO.eprintln + "error: --layer and --head must be provided \ + together" + return 2 + let headChoice? : Option (Nat × Nat) := + match layer?, head? with + | some layer, some head => some (layer, head) + | _, _ => none + if period?.isSome && headChoice?.isNone then + IO.eprintln + "warning: --period ignored without \ + --layer/--head" + let colTargetE ← + timePure "read unembed column target" (fun () => + NfptPure.readUnembedColumn + data start header dirPos) + match colTargetE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok colTarget => + let colNegE ← + timePure "read unembed column negative" (fun () => + NfptPure.readUnembedColumn + data start header dirNeg) + match colNegE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok colNeg => + let dirVec : + Fin header.modelDim → Rat := + fun i => colTarget i - colNeg i + let dotIntervalAbs := + Sound.Bounds.dotIntervalAbsBound + let intervalErrorFromHead? 
: + Model.InductionHeadInputs + seq header.modelDim header.headDim → + ResidualIntervalCert header.modelDim → + Option Rat := + fun inputs residual => by + classical + match hseq0 : seq with + | 0 => exact none + | Nat.succ n => + let _ : NeZero seq := by + exact ⟨by simp [hseq0]⟩ + match + Sound.buildHeadOutputIntervalFromHead? + inputs with + | none => exact none + | some result => + exact some + (dotIntervalAbs + dirVec + (fun i => + residual.lo i - + result.cert.hi i) + (fun i => + residual.hi i - + result.cert.lo i)) + let downstreamError ← + timePure "downstream error" (fun () => + dotIntervalAbs + dirVec + residualCert'.lo + residualCert'.hi) + let finalLB := logitDiffLB - downstreamError + let intervalError? ← + match headChoice? with + | none => pure none + | some (layer, head) => do + let inputsE ← + timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head + dirPos dirNeg period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"warning: {msg}" + pure none + | Except.ok inputs => + let inputs' : + Model.InductionHeadInputs + seq header.modelDim + header.headDim := by + simpa [hseq] using inputs + let inputsAligned : + Model.InductionHeadInputs + seq header.modelDim + header.headDim := + { inputs' with + active := cert.active + prev := cert.prev } + let intervalError? ← + timePure + "head output interval" + (fun () => + intervalErrorFromHead? + inputsAligned + residualCert') + match intervalError? with + | none => + IO.eprintln + "warning: head output interval \ + rejected" + pure none + | some intervalError => + pure (some intervalError) + let intervalLB? := + intervalError?.map (fun err => + logitDiffLB - err) + let effectiveLB := + match intervalLB? with + | some intervalLB => max finalLB intervalLB + | none => finalLB + let violation? 
: Option Rat := + match effectiveMinLogitDiff with + | none => none + | some minLogitDiff => + if effectiveLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: end-to-end bound \ + {effectiveLB} below minimum \ + {minLogitDiff}" + return (2 : UInt32) + | none => + match intervalLB? with + | none => + IO.println + s!"ok: end-to-end induction \ + bound certified (seq={seq}, \ + active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstreamError}, \ + finalLB={finalLB})" + | some intervalLB => + let intervalError := + logitDiffLB - intervalLB + IO.println + s!"ok: end-to-end induction \ + bound certified (seq={seq}, \ + active={activeCount}, \ + logitDiffLB={logitDiffLB}, \ + downstreamError={downstreamError}, \ + finalLB={finalLB}, \ + intervalError={intervalError}, \ + intervalLB={intervalLB}, \ + effectiveLB={effectiveLB})" + return 0 + else + IO.eprintln + "error: residual-interval certificate rejected" + return 2 + else + IO.eprintln + s!"error: model seq {header.seqLen} \ + does not match cert seq {seq}" + return 2 +end IO +end Nfp diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index e481c48..3d62976 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -1,922 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Field.Basic -import Nfp.Core.Basic -import Mathlib.Data.Finset.Lattice.Fold -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.Circuit.Cert.ValueRange -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.Cache -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.LayerNorm.InvStd -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Induction.CoreDefs -import Nfp.Sound.Induction.OneHot -import 
Nfp.Sound.Linear.FinFold -/-! Sound builders for induction certificates; recompute bounds inside Lean from exact inputs and -derive softmax tolerances from score margins rather than trusting external weight dumps. -/ -namespace Nfp -namespace Sound -open scoped BigOperators -open Nfp.Circuit -open Nfp.Sound.Bounds -variable {seq : Nat} -/-- Build and certify a softmax-margin certificate from exact scores/weights. -/ -def buildSoftmaxMarginCert? [NeZero seq] - (active : Finset (Fin seq)) - (prev : Fin seq → Fin seq) - (scores : Fin seq → Fin seq → Rat) - (weights : Fin seq → Fin seq → Rat) : - Option {c : SoftmaxMarginCert seq // checkSoftmaxMarginCert c = true} := by - classical - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (prev q) - let epsAt : Fin seq → Rat := fun q => - let other := otherKeys q - let maxOther := - if h : other.Nonempty then - other.sup' h (fun k => weights q k) - else - (0 : Rat) - let deficit := (1 : Rat) - weights q (prev q) - max maxOther deficit - let marginAt : Fin seq → Rat := fun q => - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scores q (prev q) - scores q k) - else - (0 : Rat) - let eps := - if h : active.Nonempty then - active.sup' h epsAt - else - (0 : Rat) - let margin := - if h : active.Nonempty then - active.inf' h marginAt - else - (0 : Rat) - let cert : SoftmaxMarginCert seq := - { eps := eps - margin := margin - active := active - prev := prev - scores := scores - weights := weights } - if h : checkSoftmaxMarginCert cert = true then - exact some ⟨cert, h⟩ - else - exact none -/-- Build and certify a value-range certificate from exact values. -/ -def buildValueRangeCert? 
[NeZero seq] - (vals : Fin seq → Rat) - (direction : Option DirectionSpec) : - Option {c : ValueRangeCert seq // checkValueRangeCert c = true} := by - classical - let _ : Nonempty (Fin seq) := by - refine ⟨⟨0, ?_⟩⟩ - exact Nat.pos_of_ne_zero (NeZero.ne seq) - let univ : Finset (Fin seq) := Finset.univ - let hnonempty : univ.Nonempty := Finset.univ_nonempty - let lo := univ.inf' hnonempty vals - let hi := univ.sup' hnonempty vals - let cert : ValueRangeCert seq := - { lo := lo - hi := hi - vals := vals - direction := direction } - if h : checkValueRangeCert cert = true then - exact some ⟨cert, h⟩ - else - exact none -/-- Cached bounds and derived quantities for induction-head core certificates. -/ -structure InductionHeadCoreCache (seq dModel dHead : Nat) where - /-- Cached LayerNorm bound pair. -/ - lnBounds : (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) - /-- LayerNorm lower bounds. -/ - lnLo : Fin seq → Fin dModel → Rat - /-- LayerNorm upper bounds. -/ - lnHi : Fin seq → Fin dModel → Rat - /-- Tasks for LayerNorm absolute maxima. -/ - lnAbsMaxTask : Fin seq → Rat - /-- Cached LayerNorm absolute maxima. -/ - lnAbsMaxArr : Array Rat - /-- LayerNorm absolute-max lookup. -/ - lnAbsMax : Fin seq → Rat - /-- Tasks for inverse-std bounds. -/ - invStdBoundsTasks : Array (Task (Rat × Rat)) - /-- Cached inverse-std bounds. -/ - invStdBoundsArr : Array (Rat × Rat) - /-- Inverse-std lower bounds. -/ - invStdLo : Fin seq → Rat - /-- Inverse-std upper bounds. -/ - invStdHi : Fin seq → Rat - /-- Cached query base terms. -/ - qBaseArr : Array Rat - /-- Query base lookup. -/ - qBase : Fin dHead → Rat - /-- Cached key base terms. -/ - kBaseArr : Array Rat - /-- Key base lookup. -/ - kBase : Fin dHead → Rat - /-- Tasks for query coefficient rows. -/ - qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) - /-- Cached query coefficient rows. -/ - qCoeffArr : Array { row : Array Rat // row.size = dHead } - /-- Query coefficient lookup. 
-/ - qCoeff : Fin seq → Fin dHead → Rat - /-- Tasks for key coefficient rows. -/ - kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) - /-- Cached key coefficient rows. -/ - kCoeffArr : Array { row : Array Rat // row.size = dHead } - /-- Key coefficient lookup. -/ - kCoeff : Fin seq → Fin dHead → Rat - /-- Query lower bounds. -/ - qLo : Fin seq → Fin dHead → Rat - /-- Query upper bounds. -/ - qHi : Fin seq → Fin dHead → Rat - /-- Key lower bounds. -/ - kLo : Fin seq → Fin dHead → Rat - /-- Key upper bounds. -/ - kHi : Fin seq → Fin dHead → Rat - /-- Query absolute bounds. -/ - qAbs : Fin seq → Fin dHead → Rat - /-- Key absolute bounds. -/ - kAbs : Fin seq → Fin dHead → Rat - /-- Cached max query abs bounds. -/ - qAbsMaxArr : Array Rat - /-- Max query abs bound lookup. -/ - qAbsMax : Fin dHead → Rat - /-- Cached max key abs bounds. -/ - kAbsMaxArr : Array Rat - /-- Max key abs bound lookup. -/ - kAbsMax : Fin dHead → Rat - /-- Causal mask predicate. -/ - masked : Fin seq → Fin seq → Prop - /-- Split budget for query dims. -/ - splitBudgetQ : Nat - /-- Split budget for key dims. -/ - splitBudgetK : Nat - /-- Split budget for base diff dims. -/ - splitBudgetDiffBase : Nat - /-- Split budget for refined diff dims. -/ - splitBudgetDiffRefined : Nat - /-- Split dims for query bounds. -/ - splitDimsQ : Fin seq → List (Fin dHead) - /-- Split dims for key bounds. -/ - splitDimsK : Fin seq → Fin seq → List (Fin dHead) - /-- Split dims for diff bounds with budget. -/ - splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) - /-- Split dims for base diff bounds. -/ - splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) - /-- Split dims for refined diff bounds. -/ - splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) - /-- Tasks for dot-product interval rows. -/ - dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) - /-- Tasks for base diff dot rows. 
-/ - dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) - /-- Dot-product lower bounds. -/ - dotLo : Fin seq → Fin seq → Rat - /-- Dot-product upper bounds. -/ - dotHi : Fin seq → Fin seq → Rat - /-- Base diff dot-product lower bounds. -/ - dotDiffLoBase : Fin seq → Fin seq → Rat - /-- Base diff dot-product upper bounds. -/ - dotDiffHiBase : Fin seq → Fin seq → Rat - /-- Dot-product absolute bounds. -/ - dotAbs : Fin seq → Fin seq → Rat - /-- Base score absolute bounds. -/ - scoreBaseAbs : Fin seq → Fin seq → Rat - /-- Score lower bounds. -/ - scoreLo : Fin seq → Fin seq → Rat - /-- Score upper bounds. -/ - scoreHi : Fin seq → Fin seq → Rat - /-- Score lower bounds at prev key. -/ - scoreLoPrev : Fin seq → Rat - /-- Base score-gap lower bounds. -/ - scoreGapLoBase : Fin seq → Fin seq → Rat - /-- Other-key set for each query. -/ - otherKeys : Fin seq → Finset (Fin seq) - /-- Worst key candidate per query. -/ - worstKey : Fin seq → Option (Fin seq) - /-- Refined diff dot-product lower bounds. -/ - dotDiffLo : Fin seq → Fin seq → Rat - /-- Refined diff dot-product upper bounds. -/ - dotDiffHi : Fin seq → Fin seq → Rat - /-- Score-gap lower bounds. -/ - scoreGapLo : Fin seq → Fin seq → Rat - /-- Margin per query. -/ - marginAt : Fin seq → Rat - /-- Epsilon per query. -/ - epsAt : Fin seq → Rat - /-- Per-key weight bounds derived from score gaps. -/ - weightBoundAt : Fin seq → Fin seq → Rat - /-- Global margin. -/ - margin : Rat - /-- Global epsilon. -/ - eps : Rat - /-- Cached direction head vector. -/ - dirHeadVec : Vector Rat dHead - /-- Direction head lookup. -/ - dirHead : Fin dHead → Rat - /-- Value-direction weight dot products. -/ - wvDir : Fin dModel → Rat - /-- Direction bias term. -/ - bDir : Rat - /-- Value lower bounds. -/ - valsLo : Fin seq → Rat - /-- Value upper bounds. -/ - valsHi : Fin seq → Rat - /-- Universe of query indices. -/ - univ : Finset (Fin seq) - /-- Global value lower bound. 
-/ - lo : Rat - /-- Global value upper bound. -/ - hi : Rat - /-- Value-interval certificate. -/ - valCert : ValueInterval seq - /-- Induction-head certificate. -/ - cert : InductionHeadCert seq -/-- Build cached core quantities for induction-head certificates. -/ -def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) : - InductionHeadCoreCache seq dModel dHead := by - classical - let lnBounds := Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Bounds.cacheBoundTask (fun q => - Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr : Array Rat := - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr[q.1]'(by - have hsize : lnAbsMaxArr.size = seq := by - simp [lnAbsMaxArr] - simp [hsize]) - let invStdBoundsTasks : Array (Task (Rat × Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) - let invStdBoundsArr : Array (Rat × Rat) := - Array.ofFn (fun q : Fin seq => - (invStdBoundsTasks[q.1]'(by - have hsize : invStdBoundsTasks.size = seq := by - simp [invStdBoundsTasks] - simp [hsize])).get) - let invStdLo : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - have hsize : invStdBoundsArr.size = seq := by - simp [invStdBoundsArr] - simp [hsize])).1 - let invStdHi : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - have hsize : invStdBoundsArr.size = seq := by - simp [invStdBoundsArr] - simp [hsize])).2 - let qBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + - inputs.bq d) - let qBase : Fin dHead → Rat := fun d => - qBaseArr[d.1]'(by - have 
hsize : qBaseArr.size = dHead := by - simp [qBaseArr] - simp [hsize]) - let kBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + - inputs.bk d) - let kBase : Fin dHead → Rat := fun d => - kBaseArr[d.1]'(by - have hsize : kBaseArr.size = dHead := by - simp [kBaseArr] - simp [hsize]) - let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) coeff), - by simp⟩)) - let qCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qCoeffRowTasks[q.1]'(by - have hsize : qCoeffRowTasks.size = seq := by - simp [qCoeffRowTasks] - simp [hsize])).get) - let qCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := qCoeffArr[q.1]'(by - have hsize : qCoeffArr.size = seq := by - simp [qCoeffArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) coeff), - by simp⟩)) - let kCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kCoeffRowTasks[q.1]'(by - have hsize : kCoeffRowTasks.size = seq := by - simp [kCoeffRowTasks] - simp [hsize])).get) - let kCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := kCoeffArr[q.1]'(by - have hsize : kCoeffArr.size = seq := by - simp [kCoeffArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := 
scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.1 - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.2 - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.1 - let kHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.2 - let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| - let qAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => qAbs q d)) - let qAbsMax : Fin dHead → Rat := fun d => - qAbsMaxArr[d.1]'(by - have hsize : qAbsMaxArr.size = dHead := by - simp [qAbsMaxArr] - simp [hsize]) - let kAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d)) - let kAbsMax : Fin dHead → Rat := fun d => - kAbsMaxArr[d.1]'(by - have hsize : kAbsMaxArr.size = dHead := by - simp [kAbsMaxArr] - simp [hsize]) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := cfg.splitBudgetQ - let splitBudgetK : Nat := cfg.splitBudgetK - let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase - let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined - let finRangeHead : List (Fin dHead) := List.finRange dHead - let finRangeSeq : List (Fin seq) := List.finRange seq - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - if splitBudgetQ = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → 
Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ - let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - if splitBudgetK = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 
(ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK - let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => - if budget = 0 then - [] - else - let prev := inputs.prev q - let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d - let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d - let ambig := - finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) - let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 : List (Fin dHead) := top2 ambig - let dims2 : List (Fin dHead) := - top2 (ambig.filter (fun d => decide (d ∉ dims1))) - let memDims2 : Fin dHead → Bool := fun d => - dims2.any (fun d' => decide (d' = d)) - let dims3 : List (Fin dHead) := - top2 - ((ambig.filter (fun d => decide (d ∉ dims1))).filter - (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take budget - let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffBase - let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffRefined - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn 
(fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if hq : q ∈ inputs.active then - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsDiff := splitDimsDiffBase q k - let prev := inputs.prev q - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)), - by simp⟩ - else - ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ 
inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLoBase q k - else - inputs.scale * dotDiffHiBase q k - let scoreGapLoBase : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoBaseRaw - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => - if hq : q ∈ inputs.active then - let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (scoreGapLoBase q k, k)).2 - else - none - let worstKeyArr : Array (Thunk (Option (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) - let worstKey : Fin seq → Option (Fin seq) := fun q => - let t := worstKeyArr[q.1]'(by - simp [worstKeyArr, q.isLt]) - Thunk.get t - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - | none => dotDiffLoBase q k - let dotDiffHi : Fin seq → Fin seq → 
Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 - else - dotDiffHiBase q k - | none => dotDiffHiBase q k - let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLo q k - else - inputs.scale * dotDiffHi q k - let scoreGapLo : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoRaw - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreGapLo q k) - else - (0 : Rat) - else - (0 : Rat) - let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => - if hk : k = inputs.prev q then - (0 : Rat) - else - let gap := scoreGapLo q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap) - let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 weightBoundAtBase - let epsAtBase : Fin seq → Rat := fun q => - let other := otherKeys q - let total := other.sum (fun k => weightBoundAt q k) - min (1 : Rat) total - let epsAt : Fin seq → Rat := - Bounds.cacheBoundThunk epsAtBase - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if h : inputs.active.Nonempty then - inputs.active.sup' h epsAt - else - (0 : Rat) - let dirHeadVec := dirHeadVecOfInputs inputs - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := - Bounds.cacheBoundTask (fun j => - Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let bDir : Rat := - Linear.dotFin 
dHead dirHead (fun d => inputs.bv d) - let valsLo : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalLower (fun j => wvDir j) (lnLo q) (lnHi q) - let valsHi : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo q) (lnHi q) - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec } - let cert : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } - exact - { lnBounds := lnBounds - lnLo := lnLo - lnHi := lnHi - lnAbsMaxTask := lnAbsMaxTask - lnAbsMaxArr := lnAbsMaxArr - lnAbsMax := lnAbsMax - invStdBoundsTasks := invStdBoundsTasks - invStdBoundsArr := invStdBoundsArr - invStdLo := invStdLo - invStdHi := invStdHi - qBaseArr := qBaseArr - qBase := qBase - kBaseArr := kBaseArr - kBase := kBase - qCoeffRowTasks := qCoeffRowTasks - qCoeffArr := qCoeffArr - qCoeff := qCoeff - kCoeffRowTasks := kCoeffRowTasks - kCoeffArr := kCoeffArr - kCoeff := kCoeff - qLo := qLo - qHi := qHi - kLo := kLo - kHi := kHi - qAbs := qAbs - kAbs := kAbs - qAbsMaxArr := qAbsMaxArr - qAbsMax := qAbsMax - kAbsMaxArr := kAbsMaxArr - kAbsMax := kAbsMax - masked := masked - splitBudgetQ := splitBudgetQ - splitBudgetK := splitBudgetK - splitBudgetDiffBase := splitBudgetDiffBase - splitBudgetDiffRefined := splitBudgetDiffRefined - splitDimsQ := splitDimsQ - splitDimsK := splitDimsK - splitDimsDiffCore := splitDimsDiffCore - splitDimsDiffBase := splitDimsDiffBase - splitDimsDiffRefined := splitDimsDiffRefined - dotRowTasks := dotRowTasks - dotDiffRowTasksBase := dotDiffRowTasksBase - dotLo := dotLo - dotHi := dotHi - dotDiffLoBase := dotDiffLoBase - dotDiffHiBase := dotDiffHiBase - dotAbs := 
dotAbs - scoreBaseAbs := scoreBaseAbs - scoreLo := scoreLo - scoreHi := scoreHi - scoreLoPrev := scoreLoPrev - scoreGapLoBase := scoreGapLoBase - otherKeys := otherKeys - worstKey := worstKey - dotDiffLo := dotDiffLo - dotDiffHi := dotDiffHi - scoreGapLo := scoreGapLo - marginAt := marginAt - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - eps := eps - dirHeadVec := dirHeadVec - dirHead := dirHead - wvDir := wvDir - bDir := bDir - valsLo := valsLo - valsHi := valsHi - univ := univ - lo := lo - hi := hi - valCert := valCert - cert := cert } +import Nfp.Sound.Induction.Core.Basic -/-- Build cached core quantities for induction-head certificates using the default split budgets. -/ -def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - InductionHeadCoreCache seq dModel dHead := - buildInductionHeadCoreCacheWith defaultInductionHeadSplitConfig inputs - -/-- The cached certificate is built from cache fields. -/ -theorem buildInductionHeadCoreCache_cert_eq [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - (buildInductionHeadCoreCache inputs).cert = - { eps := (buildInductionHeadCoreCache inputs).eps - epsAt := (buildInductionHeadCoreCache inputs).epsAt - weightBoundAt := (buildInductionHeadCoreCache inputs).weightBoundAt - margin := (buildInductionHeadCoreCache inputs).margin - active := inputs.active - prev := inputs.prev - values := (buildInductionHeadCoreCache inputs).valCert } := by - rfl -/-- Build induction certificates from exact head inputs (core computation). -/ -def buildInductionCertFromHeadCoreWith? 
[NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option (InductionHeadCert seq) := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : 0 < sqrtLower inputs.lnEps - · by_cases hmodel : dModel = 0 - · exact none - · by_cases hactive : inputs.active.Nonempty - · exact some (buildInductionHeadCoreCacheWith cfg inputs).cert - · exact none - · exact none - · exact none - -/-- Build induction certificates from exact head inputs using the default split budgets. -/ -def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option (InductionHeadCert seq) := - buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs - -/-- `buildInductionCertFromHeadCoreWith?` succeeds under the guard conditions. -/ -theorem buildInductionCertFromHeadCoreWith?_eq_some [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) (hactive : inputs.active.Nonempty) : - buildInductionCertFromHeadCoreWith? cfg inputs = - some (buildInductionHeadCoreCacheWith cfg inputs).cert := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] - -/-- `buildInductionCertFromHeadCoreWith?` fails when `dModel = 0`. -/ -theorem buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel = 0) : - buildInductionCertFromHeadCoreWith? cfg inputs = none := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel] - -/-- `buildInductionCertFromHeadCoreWith?` fails when `active` is empty. 
-/ -theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_active - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : - buildInductionCertFromHeadCoreWith? cfg inputs = none := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] - -/-- `buildInductionCertFromHeadCoreWith?` fails when the sqrt lower bound is nonpositive. -/ -theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : - buildInductionCertFromHeadCoreWith? cfg inputs = none := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt] - -/-- `buildInductionCertFromHeadCoreWith?` fails when `lnEps` is nonpositive. -/ -theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : ¬0 < inputs.lnEps) : - buildInductionCertFromHeadCoreWith? cfg inputs = none := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps] - -/-- `buildInductionCertFromHeadCore?` succeeds under the guard conditions. -/ -theorem buildInductionCertFromHeadCore?_eq_some [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) (hactive : inputs.active.Nonempty) : - buildInductionCertFromHeadCore? 
inputs = - some (buildInductionHeadCoreCache inputs).cert := by - classical - simpa [buildInductionCertFromHeadCore?, buildInductionHeadCoreCache] using - (buildInductionCertFromHeadCoreWith?_eq_some - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) - hEps hSqrt hmodel hactive) - -/-- `buildInductionCertFromHeadCore?` fails when `dModel = 0`. -/ -theorem buildInductionCertFromHeadCore?_eq_none_of_model_eq_zero [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel = 0) : - buildInductionCertFromHeadCore? inputs = none := by - classical - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt hmodel) - -/-- `buildInductionCertFromHeadCore?` fails when `active` is empty. -/ -theorem buildInductionCertFromHeadCore?_eq_none_of_not_active [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : - buildInductionCertFromHeadCore? inputs = none := by - classical - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_eq_none_of_not_active - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt hmodel hactive) - -/-- `buildInductionCertFromHeadCore?` fails when the sqrt lower bound is nonpositive. -/ -theorem buildInductionCertFromHeadCore?_eq_none_of_not_sqrt [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : - buildInductionCertFromHeadCore? inputs = none := by - classical - simpa [buildInductionCertFromHeadCore?] 
using - (buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt) - -/-- `buildInductionCertFromHeadCore?` fails when `lnEps` is nonpositive. -/ -theorem buildInductionCertFromHeadCore?_eq_none_of_not_eps [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : ¬0 < inputs.lnEps) : - buildInductionCertFromHeadCore? inputs = none := by - classical - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps) - -end Sound -end Nfp +/-! +Core definitions and constructors for induction certificates. +-/ diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean new file mode 100644 index 0000000..e481c48 --- /dev/null +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -0,0 +1,922 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later +import Mathlib.Algebra.BigOperators.Group.Finset.Basic +import Mathlib.Algebra.Order.BigOperators.Group.Finset +import Mathlib.Algebra.Order.Field.Basic +import Nfp.Core.Basic +import Mathlib.Data.Finset.Lattice.Fold +import Nfp.Circuit.Cert.ResidualInterval +import Nfp.Circuit.Cert.SoftmaxMargin +import Nfp.Circuit.Cert.ValueRange +import Nfp.Sound.Bounds.Attention +import Nfp.Sound.Bounds.Cache +import Nfp.Sound.Bounds.LayerNorm +import Nfp.Sound.Bounds.LayerNorm.InvStd +import Nfp.Sound.Bounds.MatrixNorm +import Nfp.Sound.Induction.CoreDefs +import Nfp.Sound.Induction.OneHot +import Nfp.Sound.Linear.FinFold +/-! Sound builders for induction certificates; recompute bounds inside Lean from exact inputs and +derive softmax tolerances from score margins rather than trusting external weight dumps. -/ +namespace Nfp +namespace Sound +open scoped BigOperators +open Nfp.Circuit +open Nfp.Sound.Bounds +variable {seq : Nat} +/-- Build and certify a softmax-margin certificate from exact scores/weights. 
-/ +def buildSoftmaxMarginCert? [NeZero seq] + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scores : Fin seq → Fin seq → Rat) + (weights : Fin seq → Fin seq → Rat) : + Option {c : SoftmaxMarginCert seq // checkSoftmaxMarginCert c = true} := by + classical + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + let epsAt : Fin seq → Rat := fun q => + let other := otherKeys q + let maxOther := + if h : other.Nonempty then + other.sup' h (fun k => weights q k) + else + (0 : Rat) + let deficit := (1 : Rat) - weights q (prev q) + max maxOther deficit + let marginAt : Fin seq → Rat := fun q => + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scores q (prev q) - scores q k) + else + (0 : Rat) + let eps := + if h : active.Nonempty then + active.sup' h epsAt + else + (0 : Rat) + let margin := + if h : active.Nonempty then + active.inf' h marginAt + else + (0 : Rat) + let cert : SoftmaxMarginCert seq := + { eps := eps + margin := margin + active := active + prev := prev + scores := scores + weights := weights } + if h : checkSoftmaxMarginCert cert = true then + exact some ⟨cert, h⟩ + else + exact none +/-- Build and certify a value-range certificate from exact values. -/ +def buildValueRangeCert? 
[NeZero seq] + (vals : Fin seq → Rat) + (direction : Option DirectionSpec) : + Option {c : ValueRangeCert seq // checkValueRangeCert c = true} := by + classical + let _ : Nonempty (Fin seq) := by + refine ⟨⟨0, ?_⟩⟩ + exact Nat.pos_of_ne_zero (NeZero.ne seq) + let univ : Finset (Fin seq) := Finset.univ + let hnonempty : univ.Nonempty := Finset.univ_nonempty + let lo := univ.inf' hnonempty vals + let hi := univ.sup' hnonempty vals + let cert : ValueRangeCert seq := + { lo := lo + hi := hi + vals := vals + direction := direction } + if h : checkValueRangeCert cert = true then + exact some ⟨cert, h⟩ + else + exact none +/-- Cached bounds and derived quantities for induction-head core certificates. -/ +structure InductionHeadCoreCache (seq dModel dHead : Nat) where + /-- Cached LayerNorm bound pair. -/ + lnBounds : (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) + /-- LayerNorm lower bounds. -/ + lnLo : Fin seq → Fin dModel → Rat + /-- LayerNorm upper bounds. -/ + lnHi : Fin seq → Fin dModel → Rat + /-- Tasks for LayerNorm absolute maxima. -/ + lnAbsMaxTask : Fin seq → Rat + /-- Cached LayerNorm absolute maxima. -/ + lnAbsMaxArr : Array Rat + /-- LayerNorm absolute-max lookup. -/ + lnAbsMax : Fin seq → Rat + /-- Tasks for inverse-std bounds. -/ + invStdBoundsTasks : Array (Task (Rat × Rat)) + /-- Cached inverse-std bounds. -/ + invStdBoundsArr : Array (Rat × Rat) + /-- Inverse-std lower bounds. -/ + invStdLo : Fin seq → Rat + /-- Inverse-std upper bounds. -/ + invStdHi : Fin seq → Rat + /-- Cached query base terms. -/ + qBaseArr : Array Rat + /-- Query base lookup. -/ + qBase : Fin dHead → Rat + /-- Cached key base terms. -/ + kBaseArr : Array Rat + /-- Key base lookup. -/ + kBase : Fin dHead → Rat + /-- Tasks for query coefficient rows. -/ + qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) + /-- Cached query coefficient rows. -/ + qCoeffArr : Array { row : Array Rat // row.size = dHead } + /-- Query coefficient lookup. 
-/ + qCoeff : Fin seq → Fin dHead → Rat + /-- Tasks for key coefficient rows. -/ + kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) + /-- Cached key coefficient rows. -/ + kCoeffArr : Array { row : Array Rat // row.size = dHead } + /-- Key coefficient lookup. -/ + kCoeff : Fin seq → Fin dHead → Rat + /-- Query lower bounds. -/ + qLo : Fin seq → Fin dHead → Rat + /-- Query upper bounds. -/ + qHi : Fin seq → Fin dHead → Rat + /-- Key lower bounds. -/ + kLo : Fin seq → Fin dHead → Rat + /-- Key upper bounds. -/ + kHi : Fin seq → Fin dHead → Rat + /-- Query absolute bounds. -/ + qAbs : Fin seq → Fin dHead → Rat + /-- Key absolute bounds. -/ + kAbs : Fin seq → Fin dHead → Rat + /-- Cached max query abs bounds. -/ + qAbsMaxArr : Array Rat + /-- Max query abs bound lookup. -/ + qAbsMax : Fin dHead → Rat + /-- Cached max key abs bounds. -/ + kAbsMaxArr : Array Rat + /-- Max key abs bound lookup. -/ + kAbsMax : Fin dHead → Rat + /-- Causal mask predicate. -/ + masked : Fin seq → Fin seq → Prop + /-- Split budget for query dims. -/ + splitBudgetQ : Nat + /-- Split budget for key dims. -/ + splitBudgetK : Nat + /-- Split budget for base diff dims. -/ + splitBudgetDiffBase : Nat + /-- Split budget for refined diff dims. -/ + splitBudgetDiffRefined : Nat + /-- Split dims for query bounds. -/ + splitDimsQ : Fin seq → List (Fin dHead) + /-- Split dims for key bounds. -/ + splitDimsK : Fin seq → Fin seq → List (Fin dHead) + /-- Split dims for diff bounds with budget. -/ + splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) + /-- Split dims for base diff bounds. -/ + splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) + /-- Split dims for refined diff bounds. -/ + splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) + /-- Tasks for dot-product interval rows. -/ + dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) + /-- Tasks for base diff dot rows. 
-/ + dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) + /-- Dot-product lower bounds. -/ + dotLo : Fin seq → Fin seq → Rat + /-- Dot-product upper bounds. -/ + dotHi : Fin seq → Fin seq → Rat + /-- Base diff dot-product lower bounds. -/ + dotDiffLoBase : Fin seq → Fin seq → Rat + /-- Base diff dot-product upper bounds. -/ + dotDiffHiBase : Fin seq → Fin seq → Rat + /-- Dot-product absolute bounds. -/ + dotAbs : Fin seq → Fin seq → Rat + /-- Base score absolute bounds. -/ + scoreBaseAbs : Fin seq → Fin seq → Rat + /-- Score lower bounds. -/ + scoreLo : Fin seq → Fin seq → Rat + /-- Score upper bounds. -/ + scoreHi : Fin seq → Fin seq → Rat + /-- Score lower bounds at prev key. -/ + scoreLoPrev : Fin seq → Rat + /-- Base score-gap lower bounds. -/ + scoreGapLoBase : Fin seq → Fin seq → Rat + /-- Other-key set for each query. -/ + otherKeys : Fin seq → Finset (Fin seq) + /-- Worst key candidate per query. -/ + worstKey : Fin seq → Option (Fin seq) + /-- Refined diff dot-product lower bounds. -/ + dotDiffLo : Fin seq → Fin seq → Rat + /-- Refined diff dot-product upper bounds. -/ + dotDiffHi : Fin seq → Fin seq → Rat + /-- Score-gap lower bounds. -/ + scoreGapLo : Fin seq → Fin seq → Rat + /-- Margin per query. -/ + marginAt : Fin seq → Rat + /-- Epsilon per query. -/ + epsAt : Fin seq → Rat + /-- Per-key weight bounds derived from score gaps. -/ + weightBoundAt : Fin seq → Fin seq → Rat + /-- Global margin. -/ + margin : Rat + /-- Global epsilon. -/ + eps : Rat + /-- Cached direction head vector. -/ + dirHeadVec : Vector Rat dHead + /-- Direction head lookup. -/ + dirHead : Fin dHead → Rat + /-- Value-direction weight dot products. -/ + wvDir : Fin dModel → Rat + /-- Direction bias term. -/ + bDir : Rat + /-- Value lower bounds. -/ + valsLo : Fin seq → Rat + /-- Value upper bounds. -/ + valsHi : Fin seq → Rat + /-- Universe of query indices. -/ + univ : Finset (Fin seq) + /-- Global value lower bound. 
-/ + lo : Rat + /-- Global value upper bound. -/ + hi : Rat + /-- Value-interval certificate. -/ + valCert : ValueInterval seq + /-- Induction-head certificate. -/ + cert : InductionHeadCert seq + +/-- Build cached core quantities for induction-head certificates. -/ +def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) : + InductionHeadCoreCache seq dModel dHead := by + classical + let lnBounds := Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Bounds.cacheBoundTask (fun q => + Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr : Array Rat := + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr[q.1]'(by + have hsize : lnAbsMaxArr.size = seq := by + simp [lnAbsMaxArr] + simp [hsize]) + let invStdBoundsTasks : Array (Task (Rat × Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) + let invStdBoundsArr : Array (Rat × Rat) := + Array.ofFn (fun q : Fin seq => + (invStdBoundsTasks[q.1]'(by + have hsize : invStdBoundsTasks.size = seq := by + simp [invStdBoundsTasks] + simp [hsize])).get) + let invStdLo : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + have hsize : invStdBoundsArr.size = seq := by + simp [invStdBoundsArr] + simp [hsize])).1 + let invStdHi : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + have hsize : invStdBoundsArr.size = seq := by + simp [invStdBoundsArr] + simp [hsize])).2 + let qBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + + inputs.bq d) + let qBase : Fin dHead → Rat := fun d => + qBaseArr[d.1]'(by + have 
hsize : qBaseArr.size = dHead := by + simp [qBaseArr] + simp [hsize]) + let kBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + + inputs.bk d) + let kBase : Fin dHead → Rat := fun d => + kBaseArr[d.1]'(by + have hsize : kBaseArr.size = dHead := by + simp [kBaseArr] + simp [hsize]) + let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) coeff), + by simp⟩)) + let qCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qCoeffRowTasks[q.1]'(by + have hsize : qCoeffRowTasks.size = seq := by + simp [qCoeffRowTasks] + simp [hsize])).get) + let qCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := qCoeffArr[q.1]'(by + have hsize : qCoeffArr.size = seq := by + simp [qCoeffArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) coeff), + by simp⟩)) + let kCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kCoeffRowTasks[q.1]'(by + have hsize : kCoeffRowTasks.size = seq := by + simp [kCoeffRowTasks] + simp [hsize])).get) + let kCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := kCoeffArr[q.1]'(by + have hsize : kCoeffArr.size = seq := by + simp [kCoeffArr] + simp [hsize]) + row.1[d.1]'(by + simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := 
scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.1 + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.2 + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.1 + let kHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.2 + let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let qAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => qAbs q d)) + let qAbsMax : Fin dHead → Rat := fun d => + qAbsMaxArr[d.1]'(by + have hsize : qAbsMaxArr.size = dHead := by + simp [qAbsMaxArr] + simp [hsize]) + let kAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d)) + let kAbsMax : Fin dHead → Rat := fun d => + kAbsMaxArr[d.1]'(by + have hsize : kAbsMaxArr.size = dHead := by + simp [kAbsMaxArr] + simp [hsize]) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let splitBudgetQ : Nat := cfg.splitBudgetQ + let splitBudgetK : Nat := cfg.splitBudgetK + let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase + let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined + let finRangeHead : List (Fin dHead) := List.finRange dHead + let finRangeSeq : List (Fin seq) := List.finRange seq + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => + if splitBudgetQ = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) + let score : Fin dHead → 
Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetQ + let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => + if splitBudgetK = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 
(ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetK + let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => + if budget = 0 then + [] + else + let prev := inputs.prev q + let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d + let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d + let ambig := + finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) + let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 : List (Fin dHead) := top2 ambig + let dims2 : List (Fin dHead) := + top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let memDims2 : Fin dHead → Bool := fun d => + dims2.any (fun d' => decide (d' = d)) + let dims3 : List (Fin dHead) := + top2 + ((ambig.filter (fun d => decide (d ∉ dims1))).filter + (fun d => !memDims2 d)) + (dims1 ++ dims2 ++ dims3).take budget + let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffBase + let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffRefined + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn 
(fun _ => + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsK := splitDimsK q k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if hq : q ∈ inputs.active then + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsDiff := splitDimsDiffBase q k + let prev := inputs.prev q + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)), + by simp⟩ + else + ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ 
inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) + let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLoBase q k + else + inputs.scale * dotDiffHiBase q k + let scoreGapLoBase : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoBaseRaw + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => + if hq : q ∈ inputs.active then + let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (scoreGapLoBase q k, k)).2 + else + none + let worstKeyArr : Array (Thunk (Option (Fin seq))) := + Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) + let worstKey : Fin seq → Option (Fin seq) := fun q => + let t := worstKeyArr[q.1]'(by + simp [worstKeyArr, q.isLt]) + Thunk.get t + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k + | none => dotDiffLoBase q k + let dotDiffHi : Fin seq → Fin seq → 
Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k + | none => dotDiffHiBase q k + let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLo q k + else + inputs.scale * dotDiffHi q k + let scoreGapLo : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoRaw + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreGapLo q k) + else + (0 : Rat) + else + (0 : Rat) + let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => + if hk : k = inputs.prev q then + (0 : Rat) + else + let gap := scoreGapLo q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap) + let weightBoundAt : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 weightBoundAtBase + let epsAtBase : Fin seq → Rat := fun q => + let other := otherKeys q + let total := other.sum (fun k => weightBoundAt q k) + min (1 : Rat) total + let epsAt : Fin seq → Rat := + Bounds.cacheBoundThunk epsAtBase + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if h : inputs.active.Nonempty then + inputs.active.sup' h epsAt + else + (0 : Rat) + let dirHeadVec := dirHeadVecOfInputs inputs + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin dModel → Rat := + Bounds.cacheBoundTask (fun j => + Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let bDir : Rat := + Linear.dotFin 
dHead dirHead (fun d => inputs.bv d) + let valsLo : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalLower (fun j => wvDir j) (lnLo q) (lnHi q) + let valsHi : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo q) (lnHi q) + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + let valCert : ValueInterval seq := + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec } + let cert : InductionHeadCert seq := + { eps := eps + epsAt := epsAt + weightBoundAt := weightBoundAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } + exact + { lnBounds := lnBounds + lnLo := lnLo + lnHi := lnHi + lnAbsMaxTask := lnAbsMaxTask + lnAbsMaxArr := lnAbsMaxArr + lnAbsMax := lnAbsMax + invStdBoundsTasks := invStdBoundsTasks + invStdBoundsArr := invStdBoundsArr + invStdLo := invStdLo + invStdHi := invStdHi + qBaseArr := qBaseArr + qBase := qBase + kBaseArr := kBaseArr + kBase := kBase + qCoeffRowTasks := qCoeffRowTasks + qCoeffArr := qCoeffArr + qCoeff := qCoeff + kCoeffRowTasks := kCoeffRowTasks + kCoeffArr := kCoeffArr + kCoeff := kCoeff + qLo := qLo + qHi := qHi + kLo := kLo + kHi := kHi + qAbs := qAbs + kAbs := kAbs + qAbsMaxArr := qAbsMaxArr + qAbsMax := qAbsMax + kAbsMaxArr := kAbsMaxArr + kAbsMax := kAbsMax + masked := masked + splitBudgetQ := splitBudgetQ + splitBudgetK := splitBudgetK + splitBudgetDiffBase := splitBudgetDiffBase + splitBudgetDiffRefined := splitBudgetDiffRefined + splitDimsQ := splitDimsQ + splitDimsK := splitDimsK + splitDimsDiffCore := splitDimsDiffCore + splitDimsDiffBase := splitDimsDiffBase + splitDimsDiffRefined := splitDimsDiffRefined + dotRowTasks := dotRowTasks + dotDiffRowTasksBase := dotDiffRowTasksBase + dotLo := dotLo + dotHi := dotHi + dotDiffLoBase := dotDiffLoBase + dotDiffHiBase := dotDiffHiBase + dotAbs := 
dotAbs + scoreBaseAbs := scoreBaseAbs + scoreLo := scoreLo + scoreHi := scoreHi + scoreLoPrev := scoreLoPrev + scoreGapLoBase := scoreGapLoBase + otherKeys := otherKeys + worstKey := worstKey + dotDiffLo := dotDiffLo + dotDiffHi := dotDiffHi + scoreGapLo := scoreGapLo + marginAt := marginAt + epsAt := epsAt + weightBoundAt := weightBoundAt + margin := margin + eps := eps + dirHeadVec := dirHeadVec + dirHead := dirHead + wvDir := wvDir + bDir := bDir + valsLo := valsLo + valsHi := valsHi + univ := univ + lo := lo + hi := hi + valCert := valCert + cert := cert } + +/-- Build cached core quantities for induction-head certificates using the default split budgets. -/ +def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + InductionHeadCoreCache seq dModel dHead := + buildInductionHeadCoreCacheWith defaultInductionHeadSplitConfig inputs + +/-- The cached certificate is built from cache fields. -/ +theorem buildInductionHeadCoreCache_cert_eq [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + (buildInductionHeadCoreCache inputs).cert = + { eps := (buildInductionHeadCoreCache inputs).eps + epsAt := (buildInductionHeadCoreCache inputs).epsAt + weightBoundAt := (buildInductionHeadCoreCache inputs).weightBoundAt + margin := (buildInductionHeadCoreCache inputs).margin + active := inputs.active + prev := inputs.prev + values := (buildInductionHeadCoreCache inputs).valCert } := by + rfl +/-- Build induction certificates from exact head inputs (core computation). -/ +def buildInductionCertFromHeadCoreWith? 
[NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option (InductionHeadCert seq) := by + classical + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : 0 < sqrtLower inputs.lnEps + · by_cases hmodel : dModel = 0 + · exact none + · by_cases hactive : inputs.active.Nonempty + · exact some (buildInductionHeadCoreCacheWith cfg inputs).cert + · exact none + · exact none + · exact none + +/-- Build induction certificates from exact head inputs using the default split budgets. -/ +def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option (InductionHeadCert seq) := + buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs + +/-- `buildInductionCertFromHeadCoreWith?` succeeds under the guard conditions. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_some [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) (hactive : inputs.active.Nonempty) : + buildInductionCertFromHeadCoreWith? cfg inputs = + some (buildInductionHeadCoreCacheWith cfg inputs).cert := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] + +/-- `buildInductionCertFromHeadCoreWith?` fails when `dModel = 0`. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel = 0) : + buildInductionCertFromHeadCoreWith? cfg inputs = none := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel] + +/-- `buildInductionCertFromHeadCoreWith?` fails when `active` is empty. 
-/ +theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_active + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : + buildInductionCertFromHeadCoreWith? cfg inputs = none := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] + +/-- `buildInductionCertFromHeadCoreWith?` fails when the sqrt lower bound is nonpositive. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : + buildInductionCertFromHeadCoreWith? cfg inputs = none := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt] + +/-- `buildInductionCertFromHeadCoreWith?` fails when `lnEps` is nonpositive. -/ +theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : ¬0 < inputs.lnEps) : + buildInductionCertFromHeadCoreWith? cfg inputs = none := by + classical + simp [buildInductionCertFromHeadCoreWith?, hEps] + +/-- `buildInductionCertFromHeadCore?` succeeds under the guard conditions. -/ +theorem buildInductionCertFromHeadCore?_eq_some [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) (hactive : inputs.active.Nonempty) : + buildInductionCertFromHeadCore? 
inputs = + some (buildInductionHeadCoreCache inputs).cert := by + classical + simpa [buildInductionCertFromHeadCore?, buildInductionHeadCoreCache] using + (buildInductionCertFromHeadCoreWith?_eq_some + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) + hEps hSqrt hmodel hactive) + +/-- `buildInductionCertFromHeadCore?` fails when `dModel = 0`. -/ +theorem buildInductionCertFromHeadCore?_eq_none_of_model_eq_zero [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel = 0) : + buildInductionCertFromHeadCore? inputs = none := by + classical + simpa [buildInductionCertFromHeadCore?] using + (buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt hmodel) + +/-- `buildInductionCertFromHeadCore?` fails when `active` is empty. -/ +theorem buildInductionCertFromHeadCore?_eq_none_of_not_active [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : + buildInductionCertFromHeadCore? inputs = none := by + classical + simpa [buildInductionCertFromHeadCore?] using + (buildInductionCertFromHeadCoreWith?_eq_none_of_not_active + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt hmodel hactive) + +/-- `buildInductionCertFromHeadCore?` fails when the sqrt lower bound is nonpositive. -/ +theorem buildInductionCertFromHeadCore?_eq_none_of_not_sqrt [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : + buildInductionCertFromHeadCore? inputs = none := by + classical + simpa [buildInductionCertFromHeadCore?] 
using + (buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt) + +/-- `buildInductionCertFromHeadCore?` fails when `lnEps` is nonpositive. -/ +theorem buildInductionCertFromHeadCore?_eq_none_of_not_eps [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : ¬0 < inputs.lnEps) : + buildInductionCertFromHeadCore? inputs = none := by + classical + simpa [buildInductionCertFromHeadCore?] using + (buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps + (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps) + +end Sound +end Nfp From 41b1246e266a37ffc0bbb265c2016fef01b9ba74 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 19:07:34 +0100 Subject: [PATCH 167/244] Pilot module-system conversion for core modules --- Nfp/Core.lean | 4 +++- Nfp/Core/Basic.lean | 15 ++++++++++++--- Nfp/Mixer.lean | 6 ++++-- Nfp/Mixer/Basic.lean | 8 +++++++- Nfp/Mixer/Operations.lean | 12 +++++++++--- Nfp/Prob.lean | 6 ++++-- Nfp/Prob/Basic.lean | 10 ++++++++-- Nfp/Prob/Operations.lean | 10 ++++++++-- Nfp/System.lean | 6 ++++-- Nfp/System/Dag.lean | 11 +++++++++-- Nfp/System/LocalSystem.lean | 12 +++++++++--- lakefile.toml | 1 + 12 files changed, 78 insertions(+), 23 deletions(-) diff --git a/Nfp/Core.lean b/Nfp/Core.lean index e9e6bc6..d4de670 100644 --- a/Nfp/Core.lean +++ b/Nfp/Core.lean @@ -1,6 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic +module + +public import Nfp.Core.Basic /-! Core shared definitions for the NFP rewrite. 
diff --git a/Nfp/Core/Basic.lean b/Nfp/Core/Basic.lean index 758f488..5d06aab 100644 --- a/Nfp/Core/Basic.lean +++ b/Nfp/Core/Basic.lean @@ -1,13 +1,20 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Rat.Cast.Lemmas -import Mathlib.Data.Rat.Cast.Order +module + +public import Mathlib.Algebra.Order.Group.Unbundled.Abs +public import Mathlib.Data.NNReal.Defs +public import Mathlib.Data.NNReal.Basic +public import Mathlib.Data.Rat.Cast.Lemmas +public import Mathlib.Data.Rat.Cast.Order +public import Mathlib.Data.Real.Basic /-! Basic shared definitions for the NFP rewrite. -/ +@[expose] public section + namespace Nfp /-- Nonnegative mass used for probabilities and weights. -/ @@ -155,3 +162,5 @@ theorem ratToReal_abs_le_of_le {x y : Rat} (h : |x| ≤ y) : simp [ratToReal] end Nfp + +end diff --git a/Nfp/Mixer.lean b/Nfp/Mixer.lean index 747fc6f..66abc23 100644 --- a/Nfp/Mixer.lean +++ b/Nfp/Mixer.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Mixer.Basic -import Nfp.Mixer.Operations +module + +public import Nfp.Mixer.Basic +public import Nfp.Mixer.Operations /-! Row-stochastic mixers. diff --git a/Nfp/Mixer/Basic.lean b/Nfp/Mixer/Basic.lean index 44c0317..3549e4c 100644 --- a/Nfp/Mixer/Basic.lean +++ b/Nfp/Mixer/Basic.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Prob.Basic +module + +public import Nfp.Prob.Basic /-! Row-stochastic mixers. 
-/ +@[expose] public section + open scoped BigOperators namespace Nfp @@ -35,3 +39,5 @@ def row (M : Mixer ι κ) (i : ι) : ProbVec κ := end Mixer end Nfp + +end diff --git a/Nfp/Mixer/Operations.lean b/Nfp/Mixer/Operations.lean index 5508ddf..bf70c7a 100644 --- a/Nfp/Mixer/Operations.lean +++ b/Nfp/Mixer/Operations.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Mixer.Basic -import Nfp.Prob.Operations -import Mathlib.Algebra.BigOperators.Ring.Finset +module + +public import Nfp.Mixer.Basic +public import Nfp.Prob.Operations +public import Mathlib.Algebra.BigOperators.Ring.Finset /-! Mixer operations (pushforward, composition, identity). -/ +@[expose] public section + open scoped BigOperators namespace Nfp @@ -55,3 +59,5 @@ def id (ι : Type u) [Fintype ι] [DecidableEq ι] : Mixer ι ι := end Mixer end Nfp + +end diff --git a/Nfp/Prob.lean b/Nfp/Prob.lean index 292da09..c4ffe1f 100644 --- a/Nfp/Prob.lean +++ b/Nfp/Prob.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Prob.Basic -import Nfp.Prob.Operations +module + +public import Nfp.Prob.Basic +public import Nfp.Prob.Operations /-! Probability vectors. diff --git a/Nfp/Prob/Basic.lean b/Nfp/Prob/Basic.lean index 92a3bb4..b37d658 100644 --- a/Nfp/Prob/Basic.lean +++ b/Nfp/Prob/Basic.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core -import Mathlib.Data.Fintype.BigOperators +module + +public import Nfp.Core +public import Mathlib.Data.Fintype.BigOperators /-! Probability vectors on finite types. 
-/ +@[expose] public section + open scoped BigOperators namespace Nfp @@ -31,3 +35,5 @@ instance : CoeFun (ProbVec ι) (fun _ => ι → Mass) := ⟨ProbVec.mass⟩ end ProbVec end Nfp + +end diff --git a/Nfp/Prob/Operations.lean b/Nfp/Prob/Operations.lean index 79440d4..ff957fa 100644 --- a/Nfp/Prob/Operations.lean +++ b/Nfp/Prob/Operations.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Prob.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset +module + +public import Nfp.Prob.Basic +public import Mathlib.Algebra.BigOperators.Ring.Finset /-! Basic constructions on probability vectors. -/ +@[expose] public section + open scoped BigOperators namespace Nfp @@ -43,3 +47,5 @@ def mix (a b : Mass) (h : a + b = 1) (p q : ProbVec ι) : ProbVec ι := end ProbVec end Nfp + +end diff --git a/Nfp/System.lean b/Nfp/System.lean index ab8e7ad..9ee4bdb 100644 --- a/Nfp/System.lean +++ b/Nfp/System.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.System.Dag -import Nfp.System.LocalSystem +module + +public import Nfp.System.Dag +public import Nfp.System.LocalSystem /-! DAG-based system foundations. diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean index 1049fe7..7cde304 100644 --- a/Nfp/System/Dag.lean +++ b/Nfp/System/Dag.lean @@ -1,12 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Combinatorics.Digraph.Basic -import Mathlib.Data.Finset.Basic +module + +public import Mathlib.Combinatorics.Digraph.Basic +public import Mathlib.Data.Fintype.Defs +public import Mathlib.Data.Finset.Basic /-! Directed acyclic graph foundations. 
-/ +@[expose] public section + namespace Nfp universe u u' @@ -68,3 +73,5 @@ end Relabel end Dag end Nfp + +end diff --git a/Nfp/System/LocalSystem.lean b/Nfp/System/LocalSystem.lean index 0ee1662..5fc3f9e 100644 --- a/Nfp/System/LocalSystem.lean +++ b/Nfp/System/LocalSystem.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Fintype.BigOperators -import Nfp.Mixer.Basic -import Nfp.System.Dag +module + +public import Mathlib.Data.Fintype.BigOperators +public import Nfp.Mixer.Basic +public import Nfp.System.Dag /-! Local mixing systems on finite DAGs. -/ +@[expose] public section + open scoped BigOperators namespace Nfp @@ -68,3 +72,5 @@ theorem eval_eq (L : LocalSystem ι) (input : ι → Mass) (i : ι) : end LocalSystem end Nfp + +end diff --git a/lakefile.toml b/lakefile.toml index 29a1b9d..d394613 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -7,6 +7,7 @@ defaultTargets = ["Nfp"] pp.unicode.fun = true # pretty-prints `fun a ↦ b` autoImplicit = false relaxedAutoImplicit = false +experimental.module = true weak.linter.mathlibStandardSet = true maxSynthPendingDepth = 3 linter.unusedVariables = true From 97179ce7feb94c2f72cff83061f99db57a1a7a28 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 19:33:15 +0100 Subject: [PATCH 168/244] Refactor modules across circuit, model, sound, and IO --- Nfp.lean | 16 ++-- Nfp/Circuit/Basic.lean | 6 +- Nfp/Circuit/Cert.lean | 16 ++-- Nfp/Circuit/Cert/Basic.lean | 14 ++-- Nfp/Circuit/Cert/DownstreamLinear.lean | 8 +- Nfp/Circuit/Cert/LogitDiff.lean | 10 ++- Nfp/Circuit/Cert/ResidualBound.lean | 8 +- Nfp/Circuit/Cert/ResidualInterval.lean | 8 +- Nfp/Circuit/Cert/SoftmaxMargin.lean | 12 ++- Nfp/Circuit/Cert/ValueRange.lean | 12 ++- Nfp/Circuit/Combinators.lean | 10 ++- Nfp/Circuit/Compose.lean | 14 ++-- Nfp/Circuit/Gates.lean | 6 +- Nfp/Circuit/Gates/Basic.lean | 10 ++- Nfp/Circuit/Gates/Linear.lean | 6 +- Nfp/Circuit/Interface.lean | 6 +- Nfp/Circuit/Layers.lean | 18 ++-- 
Nfp/Circuit/Layers/Attention.lean | 14 ++-- Nfp/Circuit/Layers/Heads.lean | 10 ++- Nfp/Circuit/Layers/Induction.lean | 4 +- Nfp/Circuit/Layers/Induction/Basic.lean | 14 ++-- Nfp/Circuit/Layers/Linear.lean | 14 ++-- Nfp/Circuit/Layers/Reshape.lean | 8 +- Nfp/Circuit/Layers/Softmax.lean | 12 ++- Nfp/Circuit/Layers/Tensor.lean | 6 +- Nfp/Circuit/Layers/TransformerBlock.lean | 8 +- Nfp/Circuit/Semantics.lean | 6 +- Nfp/Circuit/Tensor.lean | 6 +- Nfp/Circuit/Typed.lean | 8 +- Nfp/Circuit/WellFormed.lean | 6 +- Nfp/IO/Checks.lean | 8 +- Nfp/IO/Derive.lean | 22 +++-- Nfp/IO/HeadScore.lean | 10 ++- Nfp/IO/InductionHead.lean | 4 +- Nfp/IO/InductionHead/Basic.lean | 28 ++++--- Nfp/IO/Loaders.lean | 14 ++-- Nfp/IO/NfptPure.lean | 18 ++-- Nfp/IO/Pure.lean | 14 ++-- Nfp/IO/Pure/Basic.lean | 6 +- Nfp/IO/Pure/Downstream.lean | 8 +- Nfp/IO/Pure/InductionHead.lean | 6 +- Nfp/IO/Pure/InductionHead/Bytes.lean | 16 ++-- Nfp/IO/Pure/Residual.lean | 10 ++- Nfp/IO/Pure/SoftmaxMargin.lean | 6 +- Nfp/IO/Pure/SoftmaxMargin/Cert.lean | 8 +- Nfp/IO/Pure/SoftmaxMargin/Raw.lean | 8 +- Nfp/IO/Pure/SoftmaxMargin/Shared.lean | 8 +- Nfp/IO/Pure/ValueRange.lean | 6 +- Nfp/IO/Pure/ValueRange/Cert.lean | 8 +- Nfp/IO/Pure/ValueRange/Raw.lean | 8 +- Nfp/IO/Pure/ValueRange/Shared.lean | 8 +- Nfp/IO/Run.lean | 4 +- Nfp/IO/Run/Basic.lean | 40 +++++---- Nfp/IO/Timing.lean | 10 ++- Nfp/IO/Util.lean | 6 +- Nfp/Model/Gpt2.lean | 8 +- Nfp/Model/InductionHead.lean | 10 ++- Nfp/Model/InductionPrompt.lean | 8 +- Nfp/Sound/Bounds.lean | 20 +++-- Nfp/Sound/Bounds/Attention.lean | 26 +++--- Nfp/Sound/Bounds/Cache.lean | 6 +- Nfp/Sound/Bounds/Gelu.lean | 12 ++- Nfp/Sound/Bounds/LayerNorm.lean | 10 ++- Nfp/Sound/Bounds/LayerNorm/Basic.lean | 28 ++++--- Nfp/Sound/Bounds/LayerNorm/InvStd.lean | 8 +- Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean | 20 +++-- Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean | 16 ++-- Nfp/Sound/Bounds/MatrixNorm.lean | 6 +- Nfp/Sound/Bounds/MatrixNorm/Basic.lean | 26 +++--- 
Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 18 ++-- Nfp/Sound/Bounds/Mlp.lean | 14 ++-- Nfp/Sound/Bounds/Transformer.lean | 6 +- Nfp/Sound/Bounds/Transformer/Basic.lean | 22 +++-- Nfp/Sound/Bounds/Transformer/Embedding.lean | 8 +- Nfp/Sound/Bounds/UnnormRat.lean | 8 +- Nfp/Sound/Gpt2/HeadInputs.lean | 10 ++- Nfp/Sound/Induction.lean | 16 ++-- Nfp/Sound/Induction/Core.lean | 4 +- Nfp/Sound/Induction/Core/Basic.lean | 37 +++++---- Nfp/Sound/Induction/CoreDefs.lean | 22 +++-- Nfp/Sound/Induction/CoreSound.lean | 4 +- Nfp/Sound/Induction/CoreSound/Basic.lean | 9 +- Nfp/Sound/Induction/CoreSound/Values.lean | 10 ++- Nfp/Sound/Induction/EndToEnd.lean | 10 ++- Nfp/Sound/Induction/HeadBounds.lean | 4 +- Nfp/Sound/Induction/HeadBounds/Basic.lean | 86 ++++++++++---------- Nfp/Sound/Induction/HeadOutput.lean | 8 +- Nfp/Sound/Induction/LogitDiff.lean | 14 ++-- Nfp/Sound/Induction/OneHot.lean | 16 ++-- Nfp/Sound/Linear/FinFold.lean | 14 ++-- 90 files changed, 725 insertions(+), 399 deletions(-) diff --git a/Nfp.lean b/Nfp.lean index c99e563..2dae38d 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -1,12 +1,14 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core -import Nfp.Prob -import Nfp.Mixer -import Nfp.System -import Nfp.Circuit -import Nfp.Model -import Nfp.Sound +module + +public import Nfp.Core +public import Nfp.Prob +public import Nfp.Mixer +public import Nfp.System +public import Nfp.Circuit +public import Nfp.Model +public import Nfp.Sound /-! Top-level reexports and trust dashboard for the NFP rewrite. diff --git a/Nfp/Circuit/Basic.lean b/Nfp/Circuit/Basic.lean index 2570bc8..dd1f58a 100644 --- a/Nfp/Circuit/Basic.lean +++ b/Nfp/Circuit/Basic.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.System.Dag +module + +public import Nfp.System.Dag /-! Circuit foundations: a DAG with designated inputs/outputs and gate semantics. 
-/ +@[expose] public section + namespace Nfp universe u v diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index 7e21afd..b71e3c0 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -1,12 +1,14 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.Basic -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Circuit.Cert.ResidualBound -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.Circuit.Cert.ValueRange +module + +public import Nfp.Circuit.Cert.Basic +public import Nfp.Circuit.Cert.DownstreamLinear +public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.Circuit.Cert.ResidualBound +public import Nfp.Circuit.Cert.ResidualInterval +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.Circuit.Cert.ValueRange /-! Certificate definitions and checkers for circuits. diff --git a/Nfp/Circuit/Cert/Basic.lean b/Nfp/Circuit/Cert/Basic.lean index d4d3705..3b4186e 100644 --- a/Nfp/Circuit/Cert/Basic.lean +++ b/Nfp/Circuit/Cert/Basic.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Fold -import Mathlib.Data.Finset.Insert -import Mathlib.Data.Fintype.Pi -import Nfp.Circuit.Interface -import Nfp.Circuit.Semantics +module + +public import Mathlib.Data.Finset.Fold +public import Mathlib.Data.Finset.Insert +public import Mathlib.Data.Fintype.Pi +public import Nfp.Circuit.Interface +public import Nfp.Circuit.Semantics /-! Circuit equivalence and a finite checker. 
-/ +@[expose] public section + namespace Nfp universe u v u' u_in u_out diff --git a/Nfp/Circuit/Cert/DownstreamLinear.lean b/Nfp/Circuit/Cert/DownstreamLinear.lean index e612c1e..dfb8662 100644 --- a/Nfp/Circuit/Cert/DownstreamLinear.lean +++ b/Nfp/Circuit/Cert/DownstreamLinear.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Nfp.Circuit.Cert.Basic +module + +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic /-! Downstream linear certificates for end-to-end induction bounds. @@ -11,6 +13,8 @@ The checker only verifies arithmetic consistency (`error = gain * inputBound`) and nonnegativity of the reported quantities. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index c521724..6a43101 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Mathlib.Data.Finset.Image -import Nfp.Circuit.Layers.Induction +module + +public import Nfp.Core.Basic +public import Mathlib.Data.Finset.Image +public import Nfp.Circuit.Layers.Induction /-! Lower bounds for logit-diff contributions from induction-style heads. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Cert/ResidualBound.lean b/Nfp/Circuit/Cert/ResidualBound.lean index 09cf83e..1287511 100644 --- a/Nfp/Circuit/Cert/ResidualBound.lean +++ b/Nfp/Circuit/Cert/ResidualBound.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Nfp.Circuit.Cert.Basic +module + +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic /-! Residual-stream bound certificates. @@ -9,6 +11,8 @@ Residual-stream bound certificates. These certificates record per-coordinate absolute bounds for residual vectors. 
-/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Cert/ResidualInterval.lean b/Nfp/Circuit/Cert/ResidualInterval.lean index b1d74c7..d9f5f99 100644 --- a/Nfp/Circuit/Cert/ResidualInterval.lean +++ b/Nfp/Circuit/Cert/ResidualInterval.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Nfp.Circuit.Cert.Basic +module + +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic /-! Residual-stream interval certificates. @@ -9,6 +11,8 @@ Residual-stream interval certificates. These certificates record per-coordinate lower/upper bounds for residual vectors. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean b/Nfp/Circuit/Cert/SoftmaxMargin.lean index 5987f6c..263184c 100644 --- a/Nfp/Circuit/Cert/SoftmaxMargin.lean +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -1,14 +1,18 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.Core.Basic -import Nfp.Circuit.Cert.Basic -import Nfp.Circuit.Layers.Induction +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic +public import Nfp.Circuit.Layers.Induction /-! Softmax-margin certificates for approximate one-hot attention weights. 
-/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean index 342f93f..cbaf2a8 100644 --- a/Nfp/Circuit/Cert/ValueRange.lean +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -1,14 +1,18 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.Core.Basic -import Nfp.Circuit.Cert.Basic -import Nfp.Circuit.Layers.Induction +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic +public import Nfp.Circuit.Layers.Induction /-! Value-range certificates for attention value vectors. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Combinators.lean b/Nfp/Circuit/Combinators.lean index 02bd981..d3202f9 100644 --- a/Nfp/Circuit/Combinators.lean +++ b/Nfp/Circuit/Combinators.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Image -import Mathlib.Logic.Equiv.Basic -import Nfp.Circuit.Interface +module + +public import Mathlib.Data.Finset.Image +public import Mathlib.Logic.Equiv.Basic +public import Nfp.Circuit.Interface /-! Circuit combinators such as relabeling. -/ +@[expose] public section + namespace Nfp universe u v u' u_in u_out diff --git a/Nfp/Circuit/Compose.lean b/Nfp/Circuit/Compose.lean index 77fedb7..f5a25d5 100644 --- a/Nfp/Circuit/Compose.lean +++ b/Nfp/Circuit/Compose.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Disjoint -import Mathlib.Data.Fintype.Sum -import Mathlib.Data.Sum.Order -import Mathlib.Logic.Embedding.Basic -import Nfp.Circuit.Typed +module + +public import Mathlib.Data.Finset.Disjoint +public import Mathlib.Data.Fintype.Sum +public import Mathlib.Data.Sum.Order +public import Mathlib.Logic.Embedding.Basic +public import Nfp.Circuit.Typed /-! 
Combinators for composing typed circuits (sequential and residual wiring). -/ +@[expose] public section + namespace Nfp universe u v u' u_in u_mid u_out diff --git a/Nfp/Circuit/Gates.lean b/Nfp/Circuit/Gates.lean index 2e96c14..06b6d54 100644 --- a/Nfp/Circuit/Gates.lean +++ b/Nfp/Circuit/Gates.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Gates.Basic -import Nfp.Circuit.Gates.Linear +module + +public import Nfp.Circuit.Gates.Basic +public import Nfp.Circuit.Gates.Linear /-! Gate combinators for circuit semantics. diff --git a/Nfp/Circuit/Gates/Basic.lean b/Nfp/Circuit/Gates/Basic.lean index 0cb28d4..d8ec954 100644 --- a/Nfp/Circuit/Gates/Basic.lean +++ b/Nfp/Circuit/Gates/Basic.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Ring.Basic -import Mathlib.Data.Finset.Attach -import Mathlib.Data.Fintype.BigOperators +module + +public import Mathlib.Algebra.Ring.Basic +public import Mathlib.Data.Finset.Attach +public import Mathlib.Data.Fintype.BigOperators /-! Basic gate combinators for aggregating parent values. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Gates/Linear.lean b/Nfp/Circuit/Gates/Linear.lean index 4f42f7c..603c57b 100644 --- a/Nfp/Circuit/Gates/Linear.lean +++ b/Nfp/Circuit/Gates/Linear.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Matrix.Mul +module + +public import Mathlib.Data.Matrix.Mul /-! Linear and affine gate combinators built from `Matrix.mulVec`. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Interface.lean b/Nfp/Circuit/Interface.lean index 4ec091a..fca84b1 100644 --- a/Nfp/Circuit/Interface.lean +++ b/Nfp/Circuit/Interface.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Semantics +module + +public import Nfp.Circuit.Semantics /-! Typed input/output interfaces for circuits. 
-/ +@[expose] public section + namespace Nfp universe u v u_in u_out diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean index 5a54a6d..d0e0111 100644 --- a/Nfp/Circuit/Layers.lean +++ b/Nfp/Circuit/Layers.lean @@ -1,13 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Layers.Linear -import Nfp.Circuit.Layers.Tensor -import Nfp.Circuit.Layers.Reshape -import Nfp.Circuit.Layers.Heads -import Nfp.Circuit.Layers.Attention -import Nfp.Circuit.Layers.Softmax -import Nfp.Circuit.Layers.Induction -import Nfp.Circuit.Layers.TransformerBlock +module + +public import Nfp.Circuit.Layers.Linear +public import Nfp.Circuit.Layers.Tensor +public import Nfp.Circuit.Layers.Reshape +public import Nfp.Circuit.Layers.Heads +public import Nfp.Circuit.Layers.Attention +public import Nfp.Circuit.Layers.Softmax +public import Nfp.Circuit.Layers.Induction +public import Nfp.Circuit.Layers.TransformerBlock /-! Circuit layer combinators. diff --git a/Nfp/Circuit/Layers/Attention.lean b/Nfp/Circuit/Layers/Attention.lean index 2b4e478..8df9801 100644 --- a/Nfp/Circuit/Layers/Attention.lean +++ b/Nfp/Circuit/Layers/Attention.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Image -import Mathlib.Data.Matrix.Mul -import Mathlib.Logic.Embedding.Basic -import Nfp.Circuit.Layers.Heads -import Nfp.Circuit.Layers.Tensor +module + +public import Mathlib.Data.Finset.Image +public import Mathlib.Data.Matrix.Mul +public import Mathlib.Logic.Embedding.Basic +public import Nfp.Circuit.Layers.Heads +public import Nfp.Circuit.Layers.Tensor /-! QKV and output projection wiring for attention layers, plus attention score/mixing core. 
-/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Layers/Heads.lean b/Nfp/Circuit/Layers/Heads.lean index 84a7f4c..e12719b 100644 --- a/Nfp/Circuit/Layers/Heads.lean +++ b/Nfp/Circuit/Layers/Heads.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Logic.Equiv.Fin.Basic -import Mathlib.Logic.Equiv.Prod -import Nfp.Circuit.Layers.Reshape +module + +public import Mathlib.Logic.Equiv.Fin.Basic +public import Mathlib.Logic.Equiv.Prod +public import Nfp.Circuit.Layers.Reshape /-! Head split/merge combinators for transformer-style shapes. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index 7072cad..934c3dc 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ -1,6 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Layers.Induction.Basic +module + +public import Nfp.Circuit.Layers.Induction.Basic /-! Induction-head layer wiring and helper lemmas. diff --git a/Nfp/Circuit/Layers/Induction/Basic.lean b/Nfp/Circuit/Layers/Induction/Basic.lean index 8e1c3b6..da4fcd6 100644 --- a/Nfp/Circuit/Layers/Induction/Basic.lean +++ b/Nfp/Circuit/Layers/Induction/Basic.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.Order.Monoid.Unbundled.Basic -import Mathlib.Algebra.Order.Ring.Defs -import Nfp.Circuit.Layers.Attention +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.BigOperators.Ring.Finset +public import Mathlib.Algebra.Order.Monoid.Unbundled.Basic +public import Mathlib.Algebra.Order.Ring.Defs +public import Nfp.Circuit.Layers.Attention /-! Induction-head specifications for attention cores. 
-/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Layers/Linear.lean b/Nfp/Circuit/Layers/Linear.lean index 62c28e5..57b6b3a 100644 --- a/Nfp/Circuit/Layers/Linear.lean +++ b/Nfp/Circuit/Layers/Linear.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Image -import Mathlib.Logic.Embedding.Basic -import Nfp.Circuit.Basic -import Nfp.Circuit.Gates.Linear -import Nfp.Circuit.Typed +module + +public import Mathlib.Data.Finset.Image +public import Mathlib.Logic.Embedding.Basic +public import Nfp.Circuit.Basic +public import Nfp.Circuit.Gates.Linear +public import Nfp.Circuit.Typed /-! Linear and affine layer circuits. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Layers/Reshape.lean b/Nfp/Circuit/Layers/Reshape.lean index 0737007..e0180c1 100644 --- a/Nfp/Circuit/Layers/Reshape.lean +++ b/Nfp/Circuit/Layers/Reshape.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Logic.Equiv.Prod -import Nfp.Circuit.Typed +module + +public import Mathlib.Logic.Equiv.Prod +public import Nfp.Circuit.Typed /-! Reshape combinators for product-typed circuit interfaces. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Layers/Softmax.lean b/Nfp/Circuit/Layers/Softmax.lean index 3e7968f..d9c1bbc 100644 --- a/Nfp/Circuit/Layers/Softmax.lean +++ b/Nfp/Circuit/Layers/Softmax.lean @@ -1,9 +1,11 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Field -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Analysis.Complex.Exponential -import Mathlib.Data.Finset.Card +module + +public import Mathlib.Algebra.BigOperators.Field +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Analysis.Complex.Exponential +public import Mathlib.Data.Finset.Card /-! Real-valued softmax utilities and margin-based bounds. 
@@ -12,6 +14,8 @@ These lemmas provide the analytical bridge from score gaps to softmax weight upper bounds. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Layers/Tensor.lean b/Nfp/Circuit/Layers/Tensor.lean index 7d5a21d..768f329 100644 --- a/Nfp/Circuit/Layers/Tensor.lean +++ b/Nfp/Circuit/Layers/Tensor.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Layers.Linear +module + +public import Nfp.Circuit.Layers.Linear /-! Tensor-shaped layer builders (batched linear and affine layers). -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Layers/TransformerBlock.lean b/Nfp/Circuit/Layers/TransformerBlock.lean index 657cd59..3124023 100644 --- a/Nfp/Circuit/Layers/TransformerBlock.lean +++ b/Nfp/Circuit/Layers/TransformerBlock.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Compose -import Nfp.Circuit.Layers.Attention +module + +public import Nfp.Circuit.Compose +public import Nfp.Circuit.Layers.Attention /-! Transformer block wiring built from sequential composition and residual links. -/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Semantics.lean b/Nfp/Circuit/Semantics.lean index 300129f..f30d6d6 100644 --- a/Nfp/Circuit/Semantics.lean +++ b/Nfp/Circuit/Semantics.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Basic +module + +public import Nfp.Circuit.Basic /-! Evaluation semantics for finite circuits. -/ +@[expose] public section + namespace Nfp universe u v diff --git a/Nfp/Circuit/Tensor.lean b/Nfp/Circuit/Tensor.lean index 165cb26..f8b6917 100644 --- a/Nfp/Circuit/Tensor.lean +++ b/Nfp/Circuit/Tensor.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Matrix.Basic +module + +public import Mathlib.Data.Matrix.Basic /-! Typed tensor indices and tensor aliases. 
-/ +@[expose] public section + namespace Nfp namespace Circuit diff --git a/Nfp/Circuit/Typed.lean b/Nfp/Circuit/Typed.lean index b45d63d..1ef1330 100644 --- a/Nfp/Circuit/Typed.lean +++ b/Nfp/Circuit/Typed.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Combinators -import Nfp.Circuit.Cert.Basic +module + +public import Nfp.Circuit.Combinators +public import Nfp.Circuit.Cert.Basic /-! Typed circuit wrappers and typed equivalence checking. -/ +@[expose] public section + namespace Nfp universe u v u' u_in u_out diff --git a/Nfp/Circuit/WellFormed.lean b/Nfp/Circuit/WellFormed.lean index f4a0535..e04bdb9 100644 --- a/Nfp/Circuit/WellFormed.lean +++ b/Nfp/Circuit/WellFormed.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Basic +module + +public import Nfp.Circuit.Basic /-! Well-formedness conditions for circuits. -/ +@[expose] public section + namespace Nfp universe u v diff --git a/Nfp/IO/Checks.lean b/Nfp/IO/Checks.lean index 224b3ac..edae3a7 100644 --- a/Nfp/IO/Checks.lean +++ b/Nfp/IO/Checks.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.Circuit.Cert.ValueRange +module + +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.Circuit.Cert.ValueRange /-! IO checks for certificates. 
-/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Derive.lean b/Nfp/IO/Derive.lean index c092908..7158dc1 100644 --- a/Nfp/IO/Derive.lean +++ b/Nfp/IO/Derive.lean @@ -1,19 +1,23 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.List.Range -import Mathlib.Data.Matrix.Mul -import Mathlib.Data.Vector.Defs -import Nfp.IO.NfptPure -import Nfp.IO.Timing -import Nfp.Model.Gpt2 -import Nfp.Sound.Bounds.Transformer -import Nfp.Sound.Induction -import Nfp.Sound.Induction.HeadBounds +module + +public import Mathlib.Data.List.Range +public import Mathlib.Data.Matrix.Mul +public import Mathlib.Data.Vector.Defs +public import Nfp.IO.NfptPure +public import Nfp.IO.Timing +public import Nfp.Model.Gpt2 +public import Nfp.Sound.Bounds.Transformer +public import Nfp.Sound.Induction +public import Nfp.Sound.Induction.HeadBounds /-! IO derivations that build certificates from model binaries. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/HeadScore.lean b/Nfp/IO/HeadScore.lean index b5fd6a5..7bcec9a 100644 --- a/Nfp/IO/HeadScore.lean +++ b/Nfp/IO/HeadScore.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Nfp.Sound.Linear.FinFold +module + +public import Nfp.Core.Basic +public import Nfp.Sound.Linear.FinFold /-! Pure helpers for building cached dot-abs functions for head scoring. 
-/ +public section + namespace Nfp namespace IO @@ -31,7 +35,7 @@ def dotAbsFromQKV {seq dHead : Nat} simp [row, cache, rowTasks, Task.spawn] simp [hrow, k.isLt]) -theorem dotAbsFromQKV_spec {seq dHead : Nat} +private theorem dotAbsFromQKV_spec {seq dHead : Nat} (qAbs kAbs : Fin seq → Fin dHead → Rat) : dotAbsFromQKV qAbs kAbs = let rowTasks : Array (Task (Array Rat)) := diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index 42663f3..aa94960 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -1,6 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.InductionHead.Basic +module + +public import Nfp.IO.InductionHead.Basic /-! IO helpers for induction-head certificate construction. diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 8922aa7..1708797 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -1,22 +1,26 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.List.Range -import Nfp.IO.Pure -import Nfp.IO.NfptPure -import Nfp.IO.HeadScore -import Nfp.IO.Timing -import Nfp.IO.Util -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Sound.Induction -import Nfp.Sound.Induction.HeadBounds -import Nfp.Sound.Induction.LogitDiff -import Nfp.Sound.Linear.FinFold +module + +public import Mathlib.Data.List.Range +public import Nfp.IO.Pure +public import Nfp.IO.NfptPure +public import Nfp.IO.HeadScore +public import Nfp.IO.Timing +public import Nfp.IO.Util +public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.Circuit.Cert.ResidualInterval +public import Nfp.Sound.Induction +public import Nfp.Sound.Induction.HeadBounds +public import Nfp.Sound.Induction.LogitDiff +public import Nfp.Sound.Linear.FinFold /-! IO helpers for induction-head certificate construction. 
-/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Loaders.lean b/Nfp/IO/Loaders.lean index 7efc4ac..8d0ac1f 100644 --- a/Nfp/IO/Loaders.lean +++ b/Nfp/IO/Loaders.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Pure -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.ResidualBound -import Nfp.Circuit.Cert.ResidualInterval +module + +public import Nfp.IO.Pure +public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.Circuit.Cert.DownstreamLinear +public import Nfp.Circuit.Cert.ResidualBound +public import Nfp.Circuit.Cert.ResidualInterval /-! IO loaders for certificates and raw inputs. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index 83b4446..f33ee8d 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -1,10 +1,12 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.List.Range -import Nfp.Core.Basic -import Nfp.Model.Gpt2 -import Nfp.Model.InductionHead -import Nfp.Model.InductionPrompt +module + +public import Mathlib.Data.List.Range +public import Nfp.Core.Basic +public import Nfp.Model.Gpt2 +public import Nfp.Model.InductionHead +public import Nfp.Model.InductionPrompt /-! Pure parsing utilities for `NFP_BINARY_V1` model files. @@ -12,6 +14,8 @@ Pure parsing utilities for `NFP_BINARY_V1` model files. These helpers parse headers and extract selected weight slices as rational values. -/ +public section + namespace Nfp namespace IO @@ -630,7 +634,7 @@ def buildInductionHeadInputs (h : NfptHeader) (scale : Rat) direction := direction } /-- Definitional characterization of `buildInductionHeadInputs`. 
-/ -theorem buildInductionHeadInputs_def (h : NfptHeader) (scale : Rat) +private theorem buildInductionHeadInputs_def (h : NfptHeader) (scale : Rat) (tokens : Fin h.seqLen → Nat) (embed : Fin h.seqLen → Fin h.modelDim → Rat) (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) @@ -702,7 +706,7 @@ theorem buildInductionHeadInputs_prev_active_def (h : NfptHeader) (scale : Rat) (match period? with | some period => Model.prevOfPeriod (seq := h.seqLen) period | none => Model.prevOfTokens (seq := h.seqLen) tokens) := by - simp [buildInductionHeadInputs] + constructor <;> rfl /-- Active queries pick the maximal matching prior token when `period? = none`. -/ theorem buildInductionHeadInputs_prev_spec_of_active (h : NfptHeader) (scale : Rat) diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean index 9bb0758..0119f01 100644 --- a/Nfp/IO/Pure.lean +++ b/Nfp/IO/Pure.lean @@ -1,11 +1,13 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Pure.Basic -import Nfp.IO.Pure.Downstream -import Nfp.IO.Pure.InductionHead -import Nfp.IO.Pure.Residual -import Nfp.IO.Pure.SoftmaxMargin -import Nfp.IO.Pure.ValueRange +module + +public import Nfp.IO.Pure.Basic +public import Nfp.IO.Pure.Downstream +public import Nfp.IO.Pure.InductionHead +public import Nfp.IO.Pure.Residual +public import Nfp.IO.Pure.SoftmaxMargin +public import Nfp.IO.Pure.ValueRange /-! Aggregator for pure CLI parsing helpers. diff --git a/Nfp/IO/Pure/Basic.lean b/Nfp/IO/Pure/Basic.lean index 48481e9..81ccb4a 100644 --- a/Nfp/IO/Pure/Basic.lean +++ b/Nfp/IO/Pure/Basic.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic +module + +public import Nfp.Core.Basic /-! Shared parsing helpers for CLI inputs. 
-/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/Downstream.lean b/Nfp/IO/Pure/Downstream.lean index 353d3ca..cf6dba8 100644 --- a/Nfp/IO/Pure/Downstream.lean +++ b/Nfp/IO/Pure/Downstream.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.IO.Pure.Basic +module + +public import Nfp.Circuit.Cert.DownstreamLinear +public import Nfp.IO.Pure.Basic /-! Pure parsing helpers for downstream linear and matrix payloads. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/InductionHead.lean b/Nfp/IO/Pure/InductionHead.lean index 7881ed1..0c3adcf 100644 --- a/Nfp/IO/Pure/InductionHead.lean +++ b/Nfp/IO/Pure/InductionHead.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Pure.InductionHead.Bytes +module + +public import Nfp.IO.Pure.InductionHead.Bytes /-! Parsing helpers for induction-head input payloads. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/InductionHead/Bytes.lean b/Nfp/IO/Pure/InductionHead/Bytes.lean index 2333f70..9db33d4 100644 --- a/Nfp/IO/Pure/InductionHead/Bytes.lean +++ b/Nfp/IO/Pure/InductionHead/Bytes.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Insert -import Nfp.IO.Pure.Basic -import Nfp.Model.InductionHead +module + +public import Mathlib.Data.Finset.Insert +public import Nfp.IO.Pure.Basic +public import Nfp.Model.InductionHead /-! Parsing helpers for induction-head input payloads from UTF-8 bytes. 
-/ +public section + namespace Nfp namespace IO @@ -71,7 +75,7 @@ private def parseNatBytesSpec (data : ByteArray) (t : ByteToken) : Except String private def parseNatBytes (data : ByteArray) (t : ByteToken) : Except String Nat := parseNatBytesSpec data t -theorem parseNatBytes_eq_spec (data : ByteArray) (t : ByteToken) : +private theorem parseNatBytes_eq_spec (data : ByteArray) (t : ByteToken) : parseNatBytes data t = parseNatBytesSpec data t := by rfl @@ -90,7 +94,7 @@ private def parseIntBytesSpec (data : ByteArray) (t : ByteToken) : Except String private def parseIntBytes (data : ByteArray) (t : ByteToken) : Except String Int := parseIntBytesSpec data t -theorem parseIntBytes_eq_spec (data : ByteArray) (t : ByteToken) : +private theorem parseIntBytes_eq_spec (data : ByteArray) (t : ByteToken) : parseIntBytes data t = parseIntBytesSpec data t := by rfl @@ -121,7 +125,7 @@ private def parseRatBytesSpec (data : ByteArray) (t : ByteToken) : Except String private def parseRatBytes (data : ByteArray) (t : ByteToken) : Except String Rat := parseRatBytesSpec data t -theorem parseRatBytes_eq_spec (data : ByteArray) (t : ByteToken) : +private theorem parseRatBytes_eq_spec (data : ByteArray) (t : ByteToken) : parseRatBytes data t = parseRatBytesSpec data t := by rfl diff --git a/Nfp/IO/Pure/Residual.lean b/Nfp/IO/Pure/Residual.lean index d02ef07..97e67b9 100644 --- a/Nfp/IO/Pure/Residual.lean +++ b/Nfp/IO/Pure/Residual.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.ResidualBound -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.IO.Pure.Basic +module + +public import Nfp.Circuit.Cert.ResidualBound +public import Nfp.Circuit.Cert.ResidualInterval +public import Nfp.IO.Pure.Basic /-! Pure parsing helpers for residual-bound and residual-interval certificates. 
-/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/SoftmaxMargin.lean b/Nfp/IO/Pure/SoftmaxMargin.lean index 0b12370..c771a7d 100644 --- a/Nfp/IO/Pure/SoftmaxMargin.lean +++ b/Nfp/IO/Pure/SoftmaxMargin.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Pure.SoftmaxMargin.Cert -import Nfp.IO.Pure.SoftmaxMargin.Raw +module + +public import Nfp.IO.Pure.SoftmaxMargin.Cert +public import Nfp.IO.Pure.SoftmaxMargin.Raw /-! Aggregator for softmax-margin parsing helpers. diff --git a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean b/Nfp/IO/Pure/SoftmaxMargin/Cert.lean index 8354fd5..97ba361 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Cert.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.IO.Pure.SoftmaxMargin.Shared +module + +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.IO.Pure.SoftmaxMargin.Shared /-! Pure parsing helpers for softmax-margin certificates. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean b/Nfp/IO/Pure/SoftmaxMargin/Raw.lean index 7988511..6869e00 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Raw.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.IO.Pure.SoftmaxMargin.Shared +module + +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.IO.Pure.SoftmaxMargin.Shared /-! Pure parsing helpers for raw softmax-margin inputs. 
-/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean index ee421bd..4930a67 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean +++ b/Nfp/IO/Pure/SoftmaxMargin/Shared.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Insert -import Nfp.IO.Pure.Basic +module + +public import Mathlib.Data.Finset.Insert +public import Nfp.IO.Pure.Basic /-! Shared parsing helpers for softmax-margin payloads. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/ValueRange.lean b/Nfp/IO/Pure/ValueRange.lean index a6053d4..f41810b 100644 --- a/Nfp/IO/Pure/ValueRange.lean +++ b/Nfp/IO/Pure/ValueRange.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Pure.ValueRange.Cert -import Nfp.IO.Pure.ValueRange.Raw +module + +public import Nfp.IO.Pure.ValueRange.Cert +public import Nfp.IO.Pure.ValueRange.Raw /-! Aggregator for value-range parsing helpers. diff --git a/Nfp/IO/Pure/ValueRange/Cert.lean b/Nfp/IO/Pure/ValueRange/Cert.lean index 5a54f32..ee7c14f 100644 --- a/Nfp/IO/Pure/ValueRange/Cert.lean +++ b/Nfp/IO/Pure/ValueRange/Cert.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.ValueRange -import Nfp.IO.Pure.ValueRange.Shared +module + +public import Nfp.Circuit.Cert.ValueRange +public import Nfp.IO.Pure.ValueRange.Shared /-! Pure parsing helpers for value-range certificates. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/ValueRange/Raw.lean b/Nfp/IO/Pure/ValueRange/Raw.lean index a9da85b..7807093 100644 --- a/Nfp/IO/Pure/ValueRange/Raw.lean +++ b/Nfp/IO/Pure/ValueRange/Raw.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.ValueRange -import Nfp.IO.Pure.ValueRange.Shared +module + +public import Nfp.Circuit.Cert.ValueRange +public import Nfp.IO.Pure.ValueRange.Shared /-! 
Pure parsing helpers for raw value-range inputs. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Pure/ValueRange/Shared.lean b/Nfp/IO/Pure/ValueRange/Shared.lean index 441600f..93a8fc5 100644 --- a/Nfp/IO/Pure/ValueRange/Shared.lean +++ b/Nfp/IO/Pure/ValueRange/Shared.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Cert.ValueRange -import Nfp.IO.Pure.Basic +module + +public import Nfp.Circuit.Cert.ValueRange +public import Nfp.IO.Pure.Basic /-! Shared parsing helpers for value-range payloads. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Run.lean b/Nfp/IO/Run.lean index 2ea7c46..d48e4f3 100644 --- a/Nfp/IO/Run.lean +++ b/Nfp/IO/Run.lean @@ -1,6 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Run.Basic +module + +public import Nfp.IO.Run.Basic /-! IO entrypoints used by the CLI. diff --git a/Nfp/IO/Run/Basic.lean b/Nfp/IO/Run/Basic.lean index 9e9a598..fccad3a 100644 --- a/Nfp/IO/Run/Basic.lean +++ b/Nfp/IO/Run/Basic.lean @@ -1,28 +1,32 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Checks -import Nfp.IO.Derive -import Nfp.IO.HeadScore -import Nfp.IO.InductionHead -import Nfp.IO.Loaders -import Nfp.IO.NfptPure -import Nfp.IO.Timing -import Nfp.IO.Util -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Circuit.Cert.ResidualBound -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Bounds.Transformer -import Nfp.Sound.Induction -import Nfp.Sound.Induction.HeadBounds -import Nfp.Sound.Induction.LogitDiff -import Nfp.Sound.Linear.FinFold +module + +public import Nfp.IO.Checks +public import Nfp.IO.Derive +public import Nfp.IO.HeadScore +public import Nfp.IO.InductionHead +public import Nfp.IO.Loaders +public import Nfp.IO.NfptPure +public import Nfp.IO.Timing +public import Nfp.IO.Util +public import Nfp.Circuit.Cert.DownstreamLinear +public import 
Nfp.Circuit.Cert.LogitDiff +public import Nfp.Circuit.Cert.ResidualBound +public import Nfp.Circuit.Cert.ResidualInterval +public import Nfp.Sound.Bounds.MatrixNorm +public import Nfp.Sound.Bounds.Transformer +public import Nfp.Sound.Induction +public import Nfp.Sound.Induction.HeadBounds +public import Nfp.Sound.Induction.LogitDiff +public import Nfp.Sound.Linear.FinFold /-! IO entrypoints used by the CLI. -/ +public section + namespace Nfp namespace IO open Nfp.Circuit diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean index 41f037e..24b2a78 100644 --- a/Nfp/IO/Timing.lean +++ b/Nfp/IO/Timing.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.List.Range -import Nfp.Model.InductionHead -import Nfp.Sound.Induction.HeadBounds +module + +public import Mathlib.Data.List.Range +public import Nfp.Model.InductionHead +public import Nfp.Sound.Induction.HeadBounds /-! Small IO helpers for profiling slow phases. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/IO/Util.lean b/Nfp/IO/Util.lean index bb374fe..d52c91f 100644 --- a/Nfp/IO/Util.lean +++ b/Nfp/IO/Util.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Pure +module + +public import Nfp.IO.Pure /-! Small shared helpers for IO parsing. -/ +public section + namespace Nfp namespace IO diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean index bfd9644..82db571 100644 --- a/Nfp/Model/Gpt2.lean +++ b/Nfp/Model/Gpt2.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Nfp.Circuit.Cert.ValueRange +module + +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.ValueRange /-! Exact GPT-2 slices for induction certification and downstream bounds. @@ -11,6 +13,8 @@ MLP/LayerNorm parameters needed to build `InductionHeadInputs` and downstream bound computations. 
-/ +@[expose] public section + namespace Nfp namespace Model diff --git a/Nfp/Model/InductionHead.lean b/Nfp/Model/InductionHead.lean index 697652d..9d77227 100644 --- a/Nfp/Model/InductionHead.lean +++ b/Nfp/Model/InductionHead.lean @@ -1,8 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Basic -import Nfp.Core.Basic -import Nfp.Circuit.Cert.ValueRange +module + +public import Mathlib.Data.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.ValueRange /-! Exact inputs for induction-head scoring and value-direction computations. @@ -11,6 +13,8 @@ These structures store exact rational inputs (embeddings and weights) for a single attention head. They are intended to be consumed by sound builders. -/ +@[expose] public section + namespace Nfp namespace Model diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index c49c4ed..df13727 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.Finset.Max -import Mathlib.Data.Fintype.Basic +module + +public import Mathlib.Data.Finset.Max +public import Mathlib.Data.Fintype.Basic /-! Helpers for induction-style prompts. @@ -11,6 +13,8 @@ active-query set from a fixed period. They keep the prompt bookkeeping separate from the model weights. 
-/ +@[expose] public section + namespace Nfp namespace Model diff --git a/Nfp/Sound/Bounds.lean b/Nfp/Sound/Bounds.lean index 21a9f24..b5d1015 100644 --- a/Nfp/Sound/Bounds.lean +++ b/Nfp/Sound/Bounds.lean @@ -1,14 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.Cache -import Nfp.Sound.Bounds.Gelu -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.LayerNorm.InvStd -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Bounds.Mlp -import Nfp.Sound.Bounds.Transformer -import Nfp.Sound.Bounds.UnnormRat +module + +public import Nfp.Sound.Bounds.Attention +public import Nfp.Sound.Bounds.Cache +public import Nfp.Sound.Bounds.Gelu +public import Nfp.Sound.Bounds.LayerNorm +public import Nfp.Sound.Bounds.LayerNorm.InvStd +public import Nfp.Sound.Bounds.MatrixNorm +public import Nfp.Sound.Bounds.Mlp +public import Nfp.Sound.Bounds.Transformer +public import Nfp.Sound.Bounds.UnnormRat /-! Aggregator for sound interval bounds. 
diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index 731e1f0..c81fa85 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -1,21 +1,25 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Field -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Data.Real.Basic -import Nfp.Circuit.Layers.Softmax -import Nfp.Core.Basic -import Nfp.Model.Gpt2 -import Nfp.Sound.Bounds.Cache -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Bounds.Mlp +module + +public import Mathlib.Algebra.BigOperators.Field +public import Mathlib.Algebra.BigOperators.Ring.Finset +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Data.Real.Basic +public import Nfp.Circuit.Layers.Softmax +public import Nfp.Core.Basic +public import Nfp.Model.Gpt2 +public import Nfp.Sound.Bounds.Cache +public import Nfp.Sound.Bounds.LayerNorm +public import Nfp.Sound.Bounds.MatrixNorm +public import Nfp.Sound.Bounds.Mlp /-! Interval bounds for multi-head attention and transformer layers. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/Cache.lean b/Nfp/Sound/Bounds/Cache.lean index 7e21f0b..4b8e1aa 100644 --- a/Nfp/Sound/Bounds/Cache.lean +++ b/Nfp/Sound/Bounds/Cache.lean @@ -1,11 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic +module + +public import Nfp.Core.Basic /-! Caching helpers for interval bounds. 
-/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean index 5fdeb93..f76f06e 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Ring.Abs -import Mathlib.Analysis.Complex.Trigonometric -import Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic -import Nfp.Core.Basic +module + +public import Mathlib.Algebra.Order.Ring.Abs +public import Mathlib.Analysis.Complex.Trigonometric +public import Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic +public import Nfp.Core.Basic /-! Tanh-based GELU bounds for GPT-2 style MLPs. These bounds are used to propagate interval constraints through nonlinear gates. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean index 7e17dd5..ab110cf 100644 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ b/Nfp/Sound/Bounds/LayerNorm.lean @@ -1,9 +1,11 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Bounds.LayerNorm.Basic -import Nfp.Sound.Bounds.LayerNorm.InvStd -import Nfp.Sound.Bounds.LayerNorm.MeanVariance -import Nfp.Sound.Bounds.LayerNorm.SqrtBounds +module + +public import Nfp.Sound.Bounds.LayerNorm.Basic +public import Nfp.Sound.Bounds.LayerNorm.InvStd +public import Nfp.Sound.Bounds.LayerNorm.MeanVariance +public import Nfp.Sound.Bounds.LayerNorm.SqrtBounds /-! LayerNorm bounds and supporting lemmas. 
diff --git a/Nfp/Sound/Bounds/LayerNorm/Basic.lean b/Nfp/Sound/Bounds/LayerNorm/Basic.lean index f5e11ae..bf0846f 100644 --- a/Nfp/Sound/Bounds/LayerNorm/Basic.lean +++ b/Nfp/Sound/Bounds/LayerNorm/Basic.lean @@ -1,17 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Fin -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Field.Basic -import Mathlib.Algebra.Order.Ring.Basic -import Mathlib.Data.Real.Sqrt -import Mathlib.Data.Rat.BigOperators -import Mathlib.Data.Rat.Cast.Order -import Nfp.Core.Basic -import Nfp.Sound.Bounds.LayerNorm.MeanVariance -import Nfp.Sound.Bounds.LayerNorm.SqrtBounds -import Nfp.Sound.Linear.FinFold +module + +public import Mathlib.Algebra.BigOperators.Fin +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Algebra.Order.Field.Basic +public import Mathlib.Algebra.Order.Ring.Basic +public import Mathlib.Data.Real.Sqrt +public import Mathlib.Data.Rat.BigOperators +public import Mathlib.Data.Rat.Cast.Order +public import Nfp.Core.Basic +public import Nfp.Sound.Bounds.LayerNorm.MeanVariance +public import Nfp.Sound.Bounds.LayerNorm.SqrtBounds +public import Nfp.Sound.Linear.FinFold /-! LayerNorm interval bounds for rational inputs. @@ -20,6 +22,8 @@ This module computes rational interval bounds for LayerNorm outputs and proves those bounds sound for real-valued LayerNorm semantics. 
-/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean index aa88768..cac7454 100644 --- a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean +++ b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Bounds.LayerNorm.MeanVariance -import Nfp.Sound.Bounds.LayerNorm.SqrtBounds +module + +public import Nfp.Sound.Bounds.LayerNorm.MeanVariance +public import Nfp.Sound.Bounds.LayerNorm.SqrtBounds /-! Inverse-standard-deviation bounds for LayerNorm. @@ -10,6 +12,8 @@ This module isolates invStd bounds and their soundness proof to keep `LayerNorm/Basic.lean` below the style linter's file-length limit. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean index ffcd835..b03f058 100644 --- a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean +++ b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean @@ -1,13 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Fin -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Field.Basic -import Mathlib.Algebra.Order.Ring.Basic -import Mathlib.Data.Rat.BigOperators -import Mathlib.Data.Rat.Cast.Order -import Nfp.Core.Basic +module + +public import Mathlib.Algebra.BigOperators.Fin +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Algebra.Order.Field.Basic +public import Mathlib.Algebra.Order.Ring.Basic +public import Mathlib.Data.Rat.BigOperators +public import Mathlib.Data.Rat.Cast.Order +public import Nfp.Core.Basic /-! Mean/variance helpers for LayerNorm bounds. 
@@ -16,6 +18,8 @@ This module isolates the rational and real mean/variance definitions and their basic lemmas to keep `LayerNorm` bounds modular. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean index 535c98c..bbf85a0 100644 --- a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean +++ b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean @@ -1,11 +1,13 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.Order.Field.Basic -import Mathlib.Algebra.Order.Ring.Basic -import Mathlib.Data.Nat.Sqrt -import Mathlib.Data.Real.Sqrt -import Mathlib.Data.Rat.Cast.Order -import Nfp.Core.Basic +module + +public import Mathlib.Algebra.Order.Field.Basic +public import Mathlib.Algebra.Order.Ring.Basic +public import Mathlib.Data.Nat.Sqrt +public import Mathlib.Data.Real.Sqrt +public import Mathlib.Data.Rat.Cast.Order +public import Nfp.Core.Basic /-! Square-root bounds for LayerNorm intervals. @@ -14,6 +16,8 @@ This module isolates the rational sqrt lower/upper bounds and their basic nonnegativity/positivity lemmas so the main LayerNorm bounds stay focused. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean index 324dde6..988f4bd 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Sound/Bounds/MatrixNorm.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Bounds.MatrixNorm.Basic -import Nfp.Sound.Bounds.MatrixNorm.Interval +module + +public import Nfp.Sound.Bounds.MatrixNorm.Basic +public import Nfp.Sound.Bounds.MatrixNorm.Interval /-! Matrix norm and interval bound helpers for downstream certificates. 
diff --git a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean b/Nfp/Sound/Bounds/MatrixNorm/Basic.lean index 3e29e40..0d7b1a3 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Basic.lean @@ -1,16 +1,18 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Fin -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Ring.Abs -import Mathlib.Data.Fintype.Basic -import Mathlib.Data.Matrix.Mul -import Mathlib.Data.Real.Basic -import Nfp.Circuit.Cert.DownstreamLinear -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Core.Basic -import Nfp.Sound.Bounds.MatrixNorm.Interval -import Nfp.Sound.Linear.FinFold +module + +public import Mathlib.Algebra.BigOperators.Fin +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Algebra.Order.Ring.Abs +public import Mathlib.Data.Fintype.Basic +public import Mathlib.Data.Matrix.Mul +public import Mathlib.Data.Real.Basic +public import Nfp.Circuit.Cert.DownstreamLinear +public import Nfp.Circuit.Cert.ResidualInterval +public import Nfp.Core.Basic +public import Nfp.Sound.Bounds.MatrixNorm.Interval +public import Nfp.Sound.Linear.FinFold /-! Row-sum matrix norms for downstream linear certificates. @@ -19,6 +21,8 @@ These bounds are used to compute verified downstream error certificates from explicit Rat matrices. 
-/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index 262f39a..8fc7455 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -1,12 +1,14 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Fin -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Ring.Abs -import Mathlib.Data.Matrix.Mul -import Mathlib.Data.Real.Basic -import Nfp.Core.Basic -import Nfp.Sound.Linear.FinFold +module + +public import Mathlib.Algebra.BigOperators.Fin +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Algebra.Order.Ring.Abs +public import Mathlib.Data.Matrix.Mul +public import Mathlib.Data.Real.Basic +public import Nfp.Core.Basic +public import Nfp.Sound.Linear.FinFold /-! Interval bounds for dot products and matrix-vector products. @@ -14,6 +16,8 @@ Interval bounds for dot products and matrix-vector products. This module isolates interval-bound helpers used across downstream certificates. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/Mlp.lean b/Nfp/Sound/Bounds/Mlp.lean index 659e967..118dffd 100644 --- a/Nfp/Sound/Bounds/Mlp.lean +++ b/Nfp/Sound/Bounds/Mlp.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.Core.Basic -import Nfp.Sound.Bounds.Gelu -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.MatrixNorm +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Sound.Bounds.Gelu +public import Nfp.Sound.Bounds.LayerNorm +public import Nfp.Sound.Bounds.MatrixNorm /-! Interval bounds for GPT-2 MLP blocks (linear + GELU + linear). 
-/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/Transformer.lean b/Nfp/Sound/Bounds/Transformer.lean index d76e3f9..2104172 100644 --- a/Nfp/Sound/Bounds/Transformer.lean +++ b/Nfp/Sound/Bounds/Transformer.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Bounds.Transformer.Basic -import Nfp.Sound.Bounds.Transformer.Embedding +module + +public import Nfp.Sound.Bounds.Transformer.Basic +public import Nfp.Sound.Bounds.Transformer.Embedding /-! Transformer-stack interval bounds and supporting lemmas. diff --git a/Nfp/Sound/Bounds/Transformer/Basic.lean b/Nfp/Sound/Bounds/Transformer/Basic.lean index 87fef74..253582c 100644 --- a/Nfp/Sound/Bounds/Transformer/Basic.lean +++ b/Nfp/Sound/Bounds/Transformer/Basic.lean @@ -1,19 +1,23 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.List.Range -import Mathlib.Data.Real.Basic -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Model.Gpt2 -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.Transformer.Embedding -import Nfp.Sound.Linear.FinFold +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Data.List.Range +public import Mathlib.Data.Real.Basic +public import Nfp.Circuit.Cert.ResidualInterval +public import Nfp.Model.Gpt2 +public import Nfp.Sound.Bounds.Attention +public import Nfp.Sound.Bounds.LayerNorm +public import Nfp.Sound.Bounds.Transformer.Embedding +public import Nfp.Sound.Linear.FinFold /-! Interval bounds for transformer stacks and final LayerNorm outputs. 
-/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/Transformer/Embedding.lean b/Nfp/Sound/Bounds/Transformer/Embedding.lean index c63bff7..070d2ac 100644 --- a/Nfp/Sound/Bounds/Transformer/Embedding.lean +++ b/Nfp/Sound/Bounds/Transformer/Embedding.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.Core.Basic +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic /-! Embedding interval bounds for transformer stacks. @@ -9,6 +11,8 @@ Embedding interval bounds for transformer stacks. This module isolates per-position and per-set embedding bounds. -/ +public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Bounds/UnnormRat.lean b/Nfp/Sound/Bounds/UnnormRat.lean index ff9f1c3..802c441 100644 --- a/Nfp/Sound/Bounds/UnnormRat.lean +++ b/Nfp/Sound/Bounds/UnnormRat.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Nfp.Sound.Linear.FinFold +module + +public import Nfp.Core.Basic +public import Nfp.Sound.Linear.FinFold /-! Unnormalized rational arithmetic. @@ -10,6 +12,8 @@ Rat values already avoid gcd normalization, so this module provides a lightweight alias and helper API used by older code paths. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index 7d66d41..7a3452f 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -1,8 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Model.Gpt2 -import Nfp.Model.InductionHead -import Nfp.Model.InductionPrompt +module + +public import Nfp.Model.Gpt2 +public import Nfp.Model.InductionHead +public import Nfp.Model.InductionPrompt /-! Sound builder for GPT-2 induction head inputs. @@ -12,6 +14,8 @@ periodic prompt description. 
The construction is purely definitional and is captured by an explicit theorem, so the trusted core does not hide any logic. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index fa7ef37..ac6fb80 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -1,12 +1,14 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Induction.Core -import Nfp.Sound.Induction.CoreSound -import Nfp.Sound.Induction.EndToEnd -import Nfp.Sound.Induction.HeadBounds -import Nfp.Sound.Induction.HeadOutput -import Nfp.Sound.Induction.LogitDiff -import Nfp.Sound.Induction.OneHot +module + +public import Nfp.Sound.Induction.Core +public import Nfp.Sound.Induction.CoreSound +public import Nfp.Sound.Induction.EndToEnd +public import Nfp.Sound.Induction.HeadBounds +public import Nfp.Sound.Induction.HeadOutput +public import Nfp.Sound.Induction.LogitDiff +public import Nfp.Sound.Induction.OneHot /-! Sound builders for induction certificates. diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index 3d62976..d87c878 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -1,6 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Induction.Core.Basic +module + +public import Nfp.Sound.Induction.Core.Basic /-! Core definitions and constructors for induction certificates. 
diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index e481c48..d058e0e 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -1,22 +1,27 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Algebra.Order.Field.Basic -import Nfp.Core.Basic -import Mathlib.Data.Finset.Lattice.Fold -import Nfp.Circuit.Cert.ResidualInterval -import Nfp.Circuit.Cert.SoftmaxMargin -import Nfp.Circuit.Cert.ValueRange -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.Cache -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.LayerNorm.InvStd -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Induction.CoreDefs -import Nfp.Sound.Induction.OneHot -import Nfp.Sound.Linear.FinFold +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Algebra.Order.Field.Basic +public import Nfp.Core.Basic +public import Mathlib.Data.Finset.Lattice.Fold +public import Nfp.Circuit.Cert.ResidualInterval +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.Circuit.Cert.ValueRange +public import Nfp.Sound.Bounds.Attention +public import Nfp.Sound.Bounds.Cache +public import Nfp.Sound.Bounds.LayerNorm +public import Nfp.Sound.Bounds.LayerNorm.InvStd +public import Nfp.Sound.Bounds.MatrixNorm +public import Nfp.Sound.Induction.CoreDefs +public import Nfp.Sound.Induction.OneHot +public import Nfp.Sound.Linear.FinFold /-! Sound builders for induction certificates; recompute bounds inside Lean from exact inputs and derive softmax tolerances from score margins rather than trusting external weight dumps. 
-/ + +@[expose] public section + namespace Nfp namespace Sound open scoped BigOperators diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index 9831e57..3e586b5 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -1,14 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Vector.Defs -import Nfp.Circuit.Layers.Induction -import Nfp.Circuit.Layers.Softmax -import Nfp.Core.Basic -import Nfp.Model.InductionHead -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Linear.FinFold +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Data.Vector.Defs +public import Nfp.Circuit.Layers.Induction +public import Nfp.Circuit.Layers.Softmax +public import Nfp.Core.Basic +public import Nfp.Model.InductionHead +public import Nfp.Sound.Bounds.LayerNorm +public import Nfp.Sound.Bounds.MatrixNorm +public import Nfp.Sound.Linear.FinFold /-! Core definitions for induction-head certificates. @@ -16,6 +18,8 @@ Core definitions for induction-head certificates. These definitions are shared across induction certificate builders and checkers. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean index 281343f..6542435 100644 --- a/Nfp/Sound/Induction/CoreSound.lean +++ b/Nfp/Sound/Induction/CoreSound.lean @@ -1,6 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Induction.CoreSound.Basic +module + +public import Nfp.Sound.Induction.CoreSound.Basic /-! Soundness proofs for induction-head core certificates. 
diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 0dc5c6d..12ccf0f 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -1,6 +1,11 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Induction.Core -import Nfp.Sound.Induction.CoreSound.Values +module + +public import Nfp.Sound.Induction.Core +public import Nfp.Sound.Induction.CoreSound.Values + +@[expose] public section + namespace Nfp namespace Sound open scoped BigOperators diff --git a/Nfp/Sound/Induction/CoreSound/Values.lean b/Nfp/Sound/Induction/CoreSound/Values.lean index 1a70284..b448ac7 100644 --- a/Nfp/Sound/Induction/CoreSound/Values.lean +++ b/Nfp/Sound/Induction/CoreSound/Values.lean @@ -1,7 +1,9 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.Sound.Induction.CoreDefs -import Nfp.Sound.Linear.FinFold +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Sound.Induction.CoreDefs +public import Nfp.Sound.Linear.FinFold /-! Helper lemmas for value-direction bounds in induction-head soundness. @@ -10,6 +12,8 @@ These isolate the algebra needed to rewrite direction-value projections into dot products over cached `wvDir`/`bDir` terms. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Induction/EndToEnd.lean b/Nfp/Sound/Induction/EndToEnd.lean index 35ad055..e193a60 100644 --- a/Nfp/Sound/Induction/EndToEnd.lean +++ b/Nfp/Sound/Induction/EndToEnd.lean @@ -1,13 +1,17 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Bounds.Transformer -import Nfp.Sound.Induction.HeadOutput -import Nfp.Sound.Induction.LogitDiff +module + +public import Nfp.Sound.Bounds.Transformer +public import Nfp.Sound.Induction.HeadOutput +public import Nfp.Sound.Induction.LogitDiff /-! 
End-to-end induction bounds that combine head certificates with transformer-stack intervals. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean index 5e2f30d..d582fad 100644 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ b/Nfp/Sound/Induction/HeadBounds.lean @@ -1,6 +1,8 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Induction.HeadBounds.Basic +module + +public import Nfp.Sound.Induction.HeadBounds.Basic /-! Helper bounds for head-induction certificate construction. diff --git a/Nfp/Sound/Induction/HeadBounds/Basic.lean b/Nfp/Sound/Induction/HeadBounds/Basic.lean index 037ed18..4d7b4d2 100644 --- a/Nfp/Sound/Induction/HeadBounds/Basic.lean +++ b/Nfp/Sound/Induction/HeadBounds/Basic.lean @@ -1,14 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Core.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Range -import Mathlib.Data.Vector.Defs -import Nfp.Model.InductionHead -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Linear.FinFold +module + +public import Nfp.Core.Basic +public import Mathlib.Data.Finset.Basic +public import Mathlib.Data.List.Range +public import Mathlib.Data.Vector.Defs +public import Nfp.Model.InductionHead +public import Nfp.Sound.Bounds.Attention +public import Nfp.Sound.Bounds.LayerNorm +public import Nfp.Sound.Bounds.MatrixNorm +public import Nfp.Sound.Linear.FinFold /-! Helper bounds for head-induction certificate construction. @@ -16,6 +18,8 @@ Helper bounds for head-induction certificate construction. These are pure precomputations that are useful for profiling and staging. 
-/ +public section + namespace Nfp namespace Sound @@ -126,7 +130,7 @@ private def reduceFnChunked [NeZero seq] (vals : Fin seq → Rat) rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init /-- Unfold `reduceFnChunked` to its chunked sequential definition. -/ -theorem reduceFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) +private theorem reduceFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) (combine : Rat → Rat → Rat) : reduceFnChunked (seq := seq) vals combine = let n := seq @@ -181,7 +185,7 @@ private def reduceFnTask [NeZero seq] (vals : Fin seq → Rat) rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init /-- Unfold `reduceFnTask` to its chunked-task definition. -/ -theorem reduceFnTask_spec [NeZero seq] (vals : Fin seq → Rat) +private theorem reduceFnTask_spec [NeZero seq] (vals : Fin seq → Rat) (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : reduceFnTask (seq := seq) vals combine combineTask = let n := seq @@ -220,7 +224,7 @@ private def reduceMinFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := reduceFnChunked vals min /-- Unfold `reduceMinFnChunked` to `reduceFnChunked` with `min`. -/ -theorem reduceMinFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : +private theorem reduceMinFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : reduceMinFnChunked vals = reduceFnChunked vals min := rfl /-- Chunked sequential maximum over a `Fin seq`-indexed function. -/ @@ -228,11 +232,11 @@ private def reduceMaxFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := reduceFnChunked vals max /-- Unfold `reduceMaxFnChunked` to `reduceFnChunked` with `max`. -/ -theorem reduceMaxFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : +private theorem reduceMaxFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : reduceMaxFnChunked vals = reduceFnChunked vals max := rfl /-- The chunked parallel min-reduction task returns the sequential chunked result. 
-/ -theorem reduceMinFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : +private theorem reduceMinFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : (reduceMinFnTask vals).get = reduceMinFnChunked vals := by classical have hseq : seq ≠ 0 := NeZero.ne (n := seq) @@ -240,7 +244,7 @@ theorem reduceMinFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : Task.spawn, foldl_taskMin_get_eq, task_getD_ofFn] /-- The chunked parallel max-reduction task returns the sequential chunked result. -/ -theorem reduceMaxFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : +private theorem reduceMaxFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : (reduceMaxFnTask vals).get = reduceMaxFnChunked vals := by classical have hseq : seq ≠ 0 := NeZero.ne (n := seq) @@ -260,7 +264,7 @@ def headLnBounds [NeZero seq] {dModel dHead : Nat} Bounds.cacheBoundPair2 (fun q => Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) -theorem headLnBounds_spec [NeZero seq] {dModel dHead : Nat} +private theorem headLnBounds_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : headLnBounds inputs = Bounds.cacheBoundPair2 (fun q => @@ -327,7 +331,7 @@ def headQKVBounds [NeZero seq] {dModel dHead : Nat} qAbs := qAbs kAbs := kAbs } -theorem headQKVBounds_spec [NeZero seq] {dModel dHead : Nat} +private theorem headQKVBounds_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (lnLo lnHi : Fin seq → Fin dModel → Rat) : headQKVBounds inputs lnLo lnHi = @@ -573,7 +577,7 @@ def headScoreBoundsFromDotAbs [NeZero seq] {dModel dHead : Nat} margin := margin eps := eps } -theorem headScoreBoundsFromDotAbs_spec [NeZero seq] {dModel dHead : Nat} +private theorem headScoreBoundsFromDotAbs_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (dotAbs : Fin seq → Fin seq → Rat) : headScoreBoundsFromDotAbs inputs dotAbs = @@ -725,7 +729,7 @@ def headScoreBoundsFromIntervals 
[NeZero seq] {dModel dHead : Nat} inputs.scale * dotLo q k headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi -theorem headScoreBoundsFromIntervals_spec [NeZero seq] {dModel dHead : Nat} +private theorem headScoreBoundsFromIntervals_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : headScoreBoundsFromIntervals inputs qLo qHi kLo kHi = @@ -777,7 +781,7 @@ def headScoreBounds [NeZero seq] {dModel dHead : Nat} headScoreBoundsFromDotAbs inputs (fun q k => Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) -theorem headScoreBounds_spec [NeZero seq] {dModel dHead : Nat} +private theorem headScoreBounds_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (qAbs kAbs : Fin seq → Fin dHead → Rat) : headScoreBounds inputs qAbs kAbs = @@ -903,7 +907,7 @@ def headValueDirHead {seq dModel dHead : Nat} let dirHeadVec := dirHeadVecOfInputs inputs fun d => dirHeadVec.get d -theorem headValueDirHead_spec {seq dModel dHead : Nat} +private theorem headValueDirHead_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : headValueDirHead inputs = let dirHeadVec := dirHeadVecOfInputs inputs @@ -918,7 +922,7 @@ def headValueValsLoArray {seq dModel dHead : Nat} Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) /-- Unfold `headValueValsLoArray` to its `Array.ofFn` definition. 
-/ -theorem headValueValsLoArray_spec {seq dModel dHead : Nat} +private theorem headValueValsLoArray_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsLoArray inputs vLo vHi = @@ -933,7 +937,7 @@ def headValueValsLo {seq dModel dHead : Nat} let arr := headValueValsLoArray inputs vLo vHi fun k => arr.getD k.1 (0 : Rat) -theorem headValueValsLo_spec {seq dModel dHead : Nat} +private theorem headValueValsLo_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsLo inputs vLo vHi = @@ -947,7 +951,7 @@ def headValueValsLoCommonDenArray {seq dModel dHead : Nat} headValueValsLoArray inputs vLo vHi /-- Unfold `headValueValsLoCommonDenArray` to its `Array.ofFn` definition. -/ -theorem headValueValsLoCommonDenArray_spec {seq dModel dHead : Nat} +private theorem headValueValsLoCommonDenArray_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsLoCommonDenArray inputs vLo vHi = @@ -962,7 +966,7 @@ def headValueValsLoCommonDen {seq dModel dHead : Nat} let arr := headValueValsLoCommonDenArray inputs vLo vHi fun k => arr.getD k.1 (0 : Rat) -theorem headValueValsLoCommonDen_spec {seq dModel dHead : Nat} +private theorem headValueValsLoCommonDen_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsLoCommonDen inputs vLo vHi = @@ -992,7 +996,7 @@ def headValueValsHiArray {seq dModel dHead : Nat} Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) /-- Unfold `headValueValsHiArray` to its `Array.ofFn` definition. 
-/ -theorem headValueValsHiArray_spec {seq dModel dHead : Nat} +private theorem headValueValsHiArray_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsHiArray inputs vLo vHi = @@ -1007,7 +1011,7 @@ def headValueValsHi {seq dModel dHead : Nat} let arr := headValueValsHiArray inputs vLo vHi fun k => arr.getD k.1 (0 : Rat) -theorem headValueValsHi_spec {seq dModel dHead : Nat} +private theorem headValueValsHi_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsHi inputs vLo vHi = @@ -1021,7 +1025,7 @@ def headValueValsHiCommonDenArray {seq dModel dHead : Nat} headValueValsHiArray inputs vLo vHi /-- Unfold `headValueValsHiCommonDenArray` to its `Array.ofFn` definition. -/ -theorem headValueValsHiCommonDenArray_spec {seq dModel dHead : Nat} +private theorem headValueValsHiCommonDenArray_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsHiCommonDenArray inputs vLo vHi = @@ -1036,7 +1040,7 @@ def headValueValsHiCommonDen {seq dModel dHead : Nat} let arr := headValueValsHiCommonDenArray inputs vLo vHi fun k => arr.getD k.1 (0 : Rat) -theorem headValueValsHiCommonDen_spec {seq dModel dHead : Nat} +private theorem headValueValsHiCommonDen_spec {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueValsHiCommonDen inputs vLo vHi = @@ -1062,25 +1066,25 @@ def headValueLoArray (valsLo : Array Rat) : Rat := reduceMinArray valsLo /-- Unfold `headValueLoArray` to its reduction helper. -/ -theorem headValueLoArray_spec (valsLo : Array Rat) : +private theorem headValueLoArray_spec (valsLo : Array Rat) : headValueLoArray valsLo = reduceMinArray valsLo := rfl /-- Global lower value bound from cached per-key values. 
-/ def headValueLo [NeZero seq] (valsLo : Fin seq → Rat) : Rat := headValueLoArray (Array.ofFn valsLo) -theorem headValueLo_spec [NeZero seq] (valsLo : Fin seq → Rat) : +private theorem headValueLo_spec [NeZero seq] (valsLo : Fin seq → Rat) : headValueLo valsLo = headValueLoArray (Array.ofFn valsLo) := rfl /-- Task wrapper for `headValueLo`. -/ def headValueLoTask [NeZero seq] (valsLo : Fin seq → Rat) : Task Rat := reduceMinFnTask valsLo -theorem headValueLoTask_spec [NeZero seq] (valsLo : Fin seq → Rat) : +private theorem headValueLoTask_spec [NeZero seq] (valsLo : Fin seq → Rat) : headValueLoTask valsLo = reduceMinFnTask valsLo := rfl /-- Chunked task reduction agrees with the sequential chunked value bound. -/ -theorem headValueLoTask_get_eq [NeZero seq] (valsLo : Fin seq → Rat) : +private theorem headValueLoTask_get_eq [NeZero seq] (valsLo : Fin seq → Rat) : (headValueLoTask valsLo).get = reduceMinFnChunked valsLo := by simp [headValueLoTask_spec, reduceMinFnTask_get_eq] @@ -1089,25 +1093,25 @@ def headValueHiArray (valsHi : Array Rat) : Rat := reduceMaxArray valsHi /-- Unfold `headValueHiArray` to its reduction helper. -/ -theorem headValueHiArray_spec (valsHi : Array Rat) : +private theorem headValueHiArray_spec (valsHi : Array Rat) : headValueHiArray valsHi = reduceMaxArray valsHi := rfl /-- Global upper value bound from cached per-key values. -/ def headValueHi [NeZero seq] (valsHi : Fin seq → Rat) : Rat := headValueHiArray (Array.ofFn valsHi) -theorem headValueHi_spec [NeZero seq] (valsHi : Fin seq → Rat) : +private theorem headValueHi_spec [NeZero seq] (valsHi : Fin seq → Rat) : headValueHi valsHi = headValueHiArray (Array.ofFn valsHi) := rfl /-- Task wrapper for `headValueHi`. 
-/ def headValueHiTask [NeZero seq] (valsHi : Fin seq → Rat) : Task Rat := reduceMaxFnTask valsHi -theorem headValueHiTask_spec [NeZero seq] (valsHi : Fin seq → Rat) : +private theorem headValueHiTask_spec [NeZero seq] (valsHi : Fin seq → Rat) : headValueHiTask valsHi = reduceMaxFnTask valsHi := rfl /-- Chunked task reduction agrees with the sequential chunked value bound. -/ -theorem headValueHiTask_get_eq [NeZero seq] (valsHi : Fin seq → Rat) : +private theorem headValueHiTask_get_eq [NeZero seq] (valsHi : Fin seq → Rat) : (headValueHiTask valsHi).get = reduceMaxFnChunked valsHi := by simp [headValueHiTask_spec, reduceMaxFnTask_get_eq] @@ -1148,7 +1152,7 @@ def headValueBounds [NeZero seq] {dModel dHead : Nat} let valsHiArr := headValueValsHiArray inputs vLo vHi headValueBoundsOfArrays valsLoArr valsHiArr -theorem headValueBounds_spec [NeZero seq] {dModel dHead : Nat} +private theorem headValueBounds_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueBounds inputs vLo vHi = @@ -1170,7 +1174,7 @@ def headValueBoundsTask [NeZero seq] {dModel dHead : Nat} Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) /-- Unfold `headValueBoundsTask` to its task graph. 
-/ -theorem headValueBoundsTask_spec [NeZero seq] {dModel dHead : Nat} +private theorem headValueBoundsTask_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueBoundsTask inputs vLo vHi = @@ -1191,7 +1195,7 @@ def headValueBoundsCommonDen [NeZero seq] {dModel dHead : Nat} let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi headValueBoundsOfArrays valsLoArr valsHiArr -theorem headValueBoundsCommonDen_spec [NeZero seq] {dModel dHead : Nat} +private theorem headValueBoundsCommonDen_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueBoundsCommonDen inputs vLo vHi = @@ -1213,7 +1217,7 @@ def headValueBoundsCommonDenTask [NeZero seq] {dModel dHead : Nat} Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) /-- Unfold `headValueBoundsCommonDenTask` to its task graph. -/ -theorem headValueBoundsCommonDenTask_spec [NeZero seq] {dModel dHead : Nat} +private theorem headValueBoundsCommonDenTask_spec [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (vLo vHi : Fin seq → Fin dHead → Rat) : headValueBoundsCommonDenTask inputs vLo vHi = diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index e6e15dc..53e9930 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Aesop -import Nfp.Sound.Induction.CoreSound +module + +public import Aesop +public import Nfp.Sound.Induction.CoreSound /-! Head-output interval certificates for induction heads. 
-/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index be5b705..2647557 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -1,15 +1,19 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Aesop -import Mathlib.Data.Vector.Basic -import Nfp.Circuit.Cert.LogitDiff -import Nfp.Sound.Bounds.MatrixNorm.Interval -import Nfp.Sound.Induction.HeadOutput +module + +public import Aesop +public import Mathlib.Data.Vector.Basic +public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.Sound.Bounds.MatrixNorm.Interval +public import Nfp.Sound.Induction.HeadOutput /-! Logit-diff bounds derived from induction certificates. -/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 8cbca6d..02ebb29 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -1,16 +1,20 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.Order.BigOperators.Group.Finset -import Mathlib.Data.Rat.BigOperators -import Nfp.Core.Basic -import Nfp.Circuit.Layers.Induction -import Nfp.Circuit.Layers.Softmax +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Data.Rat.BigOperators +public import Nfp.Core.Basic +public import Nfp.Circuit.Layers.Induction +public import Nfp.Circuit.Layers.Softmax /-! Per-query one-hot bounds derived from score margins. 
-/ +@[expose] public section + namespace Nfp namespace Sound diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean index da24769..8668494 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Sound/Linear/FinFold.lean @@ -1,10 +1,12 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Algebra.BigOperators.Fin -import Mathlib.Data.Matrix.Mul -import Mathlib.Data.Rat.BigOperators -import Batteries.Data.Fin.Fold -import Nfp.Core.Basic +module + +public import Mathlib.Algebra.BigOperators.Fin +public import Mathlib.Data.Matrix.Mul +public import Mathlib.Data.Rat.BigOperators +public import Batteries.Data.Fin.Fold +public import Nfp.Core.Basic /-! Tail-recursive folds and sums over `Fin`. @@ -12,6 +14,8 @@ Tail-recursive folds and sums over `Fin`. These helpers keep sound computations stack-safe while remaining explicit. -/ +@[expose] public section + namespace Nfp namespace Sound From f350dbfdf852b77a11f77079943f0f496de0de73 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 19:33:44 +0100 Subject: [PATCH 169/244] Make Nfp aggregate modules explicit modules --- Nfp/Circuit.lean | 24 +++++++++++++----------- Nfp/IO.lean | 20 +++++++++++--------- Nfp/Model.lean | 8 +++++--- Nfp/Sound.lean | 10 ++++++---- 4 files changed, 35 insertions(+), 27 deletions(-) diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean index 51917ea..ae32604 100644 --- a/Nfp/Circuit.lean +++ b/Nfp/Circuit.lean @@ -1,16 +1,18 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Circuit.Basic -import Nfp.Circuit.Combinators -import Nfp.Circuit.Interface -import Nfp.Circuit.Semantics -import Nfp.Circuit.WellFormed -import Nfp.Circuit.Cert -import Nfp.Circuit.Typed -import Nfp.Circuit.Compose -import Nfp.Circuit.Gates -import Nfp.Circuit.Tensor -import Nfp.Circuit.Layers +module + +public import Nfp.Circuit.Basic +public import Nfp.Circuit.Combinators +public import Nfp.Circuit.Interface +public import Nfp.Circuit.Semantics +public import 
Nfp.Circuit.WellFormed +public import Nfp.Circuit.Cert +public import Nfp.Circuit.Typed +public import Nfp.Circuit.Compose +public import Nfp.Circuit.Gates +public import Nfp.Circuit.Tensor +public import Nfp.Circuit.Layers /-! Circuit definitions, semantics, and equivalence checking. diff --git a/Nfp/IO.lean b/Nfp/IO.lean index a2b17d2..47e3f56 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -1,14 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.IO.Checks -import Nfp.IO.Derive -import Nfp.IO.HeadScore -import Nfp.IO.InductionHead -import Nfp.IO.Loaders -import Nfp.IO.NfptPure -import Nfp.IO.Run -import Nfp.IO.Timing -import Nfp.IO.Util +module + +public import Nfp.IO.Checks +public import Nfp.IO.Derive +public import Nfp.IO.HeadScore +public import Nfp.IO.InductionHead +public import Nfp.IO.Loaders +public import Nfp.IO.NfptPure +public import Nfp.IO.Run +public import Nfp.IO.Timing +public import Nfp.IO.Util /-! IO-only wrappers for loading inputs and running checks. diff --git a/Nfp/Model.lean b/Nfp/Model.lean index 9b9f188..7a97665 100644 --- a/Nfp/Model.lean +++ b/Nfp/Model.lean @@ -1,8 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Model.Gpt2 -import Nfp.Model.InductionHead -import Nfp.Model.InductionPrompt +module + +public import Nfp.Model.Gpt2 +public import Nfp.Model.InductionHead +public import Nfp.Model.InductionPrompt /-! Model-specific data containers for the NFP rewrite. diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index 0a49652..dddb21c 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -1,9 +1,11 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Sound.Gpt2.HeadInputs -import Nfp.Sound.Induction -import Nfp.Sound.Bounds -import Nfp.Sound.Linear.FinFold +module + +public import Nfp.Sound.Gpt2.HeadInputs +public import Nfp.Sound.Induction +public import Nfp.Sound.Bounds +public import Nfp.Sound.Linear.FinFold /-! Sound certificate builders and verified helpers. 
From e0a21205e983b8f6c366c053a3539d950a076c78 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 19:49:42 +0100 Subject: [PATCH 170/244] Trim exposure in model and circuit modules --- Nfp/Circuit/Basic.lean | 2 +- Nfp/Circuit/Cert/Basic.lean | 4 ++-- Nfp/Circuit/Cert/DownstreamLinear.lean | 2 +- Nfp/Circuit/Cert/LogitDiff.lean | 2 +- Nfp/Circuit/Cert/ResidualBound.lean | 2 +- Nfp/Circuit/Cert/ResidualInterval.lean | 2 +- Nfp/Circuit/Cert/SoftmaxMargin.lean | 2 +- Nfp/Circuit/Cert/ValueRange.lean | 2 +- Nfp/Circuit/Combinators.lean | 2 +- Nfp/Circuit/Compose.lean | 2 +- Nfp/Circuit/Gates/Basic.lean | 2 +- Nfp/Circuit/Gates/Linear.lean | 2 +- Nfp/Circuit/Interface.lean | 9 +++++---- Nfp/Circuit/Layers/Attention.lean | 18 +++++++++--------- Nfp/Circuit/Layers/Heads.lean | 2 +- Nfp/Circuit/Layers/Induction/Basic.lean | 4 ++-- Nfp/Circuit/Layers/Linear.lean | 4 ++-- Nfp/Circuit/Layers/Reshape.lean | 2 +- Nfp/Circuit/Layers/Softmax.lean | 2 +- Nfp/Circuit/Layers/Tensor.lean | 2 +- Nfp/Circuit/Layers/TransformerBlock.lean | 2 +- Nfp/Circuit/Semantics.lean | 2 +- Nfp/Circuit/Tensor.lean | 2 +- Nfp/Circuit/Typed.lean | 4 ++-- Nfp/Circuit/WellFormed.lean | 2 +- Nfp/Model/Gpt2.lean | 2 +- Nfp/Model/InductionHead.lean | 2 +- Nfp/Model/InductionPrompt.lean | 2 +- 28 files changed, 44 insertions(+), 43 deletions(-) diff --git a/Nfp/Circuit/Basic.lean b/Nfp/Circuit/Basic.lean index dd1f58a..c796bb3 100644 --- a/Nfp/Circuit/Basic.lean +++ b/Nfp/Circuit/Basic.lean @@ -8,7 +8,7 @@ public import Nfp.System.Dag Circuit foundations: a DAG with designated inputs/outputs and gate semantics. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Cert/Basic.lean b/Nfp/Circuit/Cert/Basic.lean index 3b4186e..7636610 100644 --- a/Nfp/Circuit/Cert/Basic.lean +++ b/Nfp/Circuit/Cert/Basic.lean @@ -12,7 +12,7 @@ public import Nfp.Circuit.Semantics Circuit equivalence and a finite checker. 
-/ -@[expose] public section +public section namespace Nfp @@ -28,7 +28,7 @@ def SameInterface (C₁ C₂ : Circuit ι α) : Prop := C₁.inputs = C₂.inputs ∧ C₁.outputs = C₂.outputs /-- `SameInterface` is decidable. -/ -instance (C₁ C₂ : Circuit ι α) : Decidable (SameInterface C₁ C₂) := by +private instance (C₁ C₂ : Circuit ι α) : Decidable (SameInterface C₁ C₂) := by dsimp [SameInterface] infer_instance diff --git a/Nfp/Circuit/Cert/DownstreamLinear.lean b/Nfp/Circuit/Cert/DownstreamLinear.lean index dfb8662..9d21059 100644 --- a/Nfp/Circuit/Cert/DownstreamLinear.lean +++ b/Nfp/Circuit/Cert/DownstreamLinear.lean @@ -13,7 +13,7 @@ The checker only verifies arithmetic consistency (`error = gain * inputBound`) and nonnegativity of the reported quantities. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index 6a43101..8b5327d 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -10,7 +10,7 @@ public import Nfp.Circuit.Layers.Induction Lower bounds for logit-diff contributions from induction-style heads. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Cert/ResidualBound.lean b/Nfp/Circuit/Cert/ResidualBound.lean index 1287511..167803b 100644 --- a/Nfp/Circuit/Cert/ResidualBound.lean +++ b/Nfp/Circuit/Cert/ResidualBound.lean @@ -11,7 +11,7 @@ Residual-stream bound certificates. These certificates record per-coordinate absolute bounds for residual vectors. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Cert/ResidualInterval.lean b/Nfp/Circuit/Cert/ResidualInterval.lean index d9f5f99..7295ecc 100644 --- a/Nfp/Circuit/Cert/ResidualInterval.lean +++ b/Nfp/Circuit/Cert/ResidualInterval.lean @@ -11,7 +11,7 @@ Residual-stream interval certificates. These certificates record per-coordinate lower/upper bounds for residual vectors. 
-/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean b/Nfp/Circuit/Cert/SoftmaxMargin.lean index 263184c..63cd189 100644 --- a/Nfp/Circuit/Cert/SoftmaxMargin.lean +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -11,7 +11,7 @@ public import Nfp.Circuit.Layers.Induction Softmax-margin certificates for approximate one-hot attention weights. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean index cbaf2a8..56be21d 100644 --- a/Nfp/Circuit/Cert/ValueRange.lean +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -11,7 +11,7 @@ public import Nfp.Circuit.Layers.Induction Value-range certificates for attention value vectors. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Combinators.lean b/Nfp/Circuit/Combinators.lean index d3202f9..41258f5 100644 --- a/Nfp/Circuit/Combinators.lean +++ b/Nfp/Circuit/Combinators.lean @@ -10,7 +10,7 @@ public import Nfp.Circuit.Interface Circuit combinators such as relabeling. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Compose.lean b/Nfp/Circuit/Compose.lean index f5a25d5..ad692c0 100644 --- a/Nfp/Circuit/Compose.lean +++ b/Nfp/Circuit/Compose.lean @@ -12,7 +12,7 @@ public import Nfp.Circuit.Typed Combinators for composing typed circuits (sequential and residual wiring). -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Gates/Basic.lean b/Nfp/Circuit/Gates/Basic.lean index d8ec954..31f60c8 100644 --- a/Nfp/Circuit/Gates/Basic.lean +++ b/Nfp/Circuit/Gates/Basic.lean @@ -10,7 +10,7 @@ public import Mathlib.Data.Fintype.BigOperators Basic gate combinators for aggregating parent values. 
-/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Gates/Linear.lean b/Nfp/Circuit/Gates/Linear.lean index 603c57b..6502b9b 100644 --- a/Nfp/Circuit/Gates/Linear.lean +++ b/Nfp/Circuit/Gates/Linear.lean @@ -8,7 +8,7 @@ public import Mathlib.Data.Matrix.Mul Linear and affine gate combinators built from `Matrix.mulVec`. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Interface.lean b/Nfp/Circuit/Interface.lean index fca84b1..be7a97d 100644 --- a/Nfp/Circuit/Interface.lean +++ b/Nfp/Circuit/Interface.lean @@ -8,7 +8,7 @@ public import Nfp.Circuit.Semantics Typed input/output interfaces for circuits. -/ -@[expose] public section +public section namespace Nfp @@ -31,15 +31,16 @@ namespace Interface variable {C : Circuit ι α} {ι_in : Type u_in} {ι_out : Type u_out} /-- Convert a typed input assignment into an input-node assignment. -/ -def toInputAssignment (I : Interface C ι_in ι_out) (input : ι_in → α) : C.InputAssignment := +@[expose] def toInputAssignment (I : Interface C ι_in ι_out) (input : ι_in → α) : + C.InputAssignment := fun i => input (I.inputs.symm i) /-- Evaluate a circuit on a typed interface. -/ -def eval (I : Interface C ι_in ι_out) (input : ι_in → α) : ι_out → α := +@[expose] def eval (I : Interface C ι_in ι_out) (input : ι_in → α) : ι_out → α := fun o => evalInput C (I.toInputAssignment input) (I.outputs o).1 /-- Unfolding equation for `Interface.eval`. 
-/ -theorem eval_eq (I : Interface C ι_in ι_out) (input : ι_in → α) (o : ι_out) : +private theorem eval_eq (I : Interface C ι_in ι_out) (input : ι_in → α) (o : ι_out) : I.eval input o = evalInput C (I.toInputAssignment input) (I.outputs o).1 := rfl diff --git a/Nfp/Circuit/Layers/Attention.lean b/Nfp/Circuit/Layers/Attention.lean index 8df9801..e90255f 100644 --- a/Nfp/Circuit/Layers/Attention.lean +++ b/Nfp/Circuit/Layers/Attention.lean @@ -12,7 +12,7 @@ public import Nfp.Circuit.Layers.Tensor QKV and output projection wiring for attention layers, plus attention score/mixing core. -/ -@[expose] public section +public section namespace Nfp @@ -485,7 +485,7 @@ end Dag section Inputs /-- Input nodes for the attention core. -/ -def attentionInputs : Finset (AttentionNode Batch seq heads dim) := +@[expose] def attentionInputs : Finset (AttentionNode Batch seq heads dim) := (Finset.univ : Finset (AttentionInput Batch seq heads dim)).map Embedding.inl open scoped Classical in @@ -513,7 +513,7 @@ theorem not_mem_attentionInputs_inr (s : Sum (ScoreIndex Batch seq heads) open scoped Classical in /-- Input labels correspond to input nodes in the attention core. -/ -def attentionInputEquiv : +@[expose] def attentionInputEquiv : AttentionInput Batch seq heads dim ≃ { i // i ∈ attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } := { toFun := fun a => @@ -545,7 +545,7 @@ end Inputs section Outputs /-- Output nodes for the attention core. -/ -def attentionOutputs : Finset (AttentionNode Batch seq heads dim) := +@[expose] def attentionOutputs : Finset (AttentionNode Batch seq heads dim) := (Finset.univ : Finset (AttentionOutput Batch seq heads dim)).map { toFun := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) inj' := by @@ -594,7 +594,7 @@ theorem not_mem_attentionOutputs_weight (w : WeightIndex Batch seq heads) : open scoped Classical in /-- Output labels correspond to output nodes in the attention core. 
-/ -def attentionOutputEquiv : +@[expose] def attentionOutputEquiv : AttentionOutput Batch seq heads dim ≃ { i // i ∈ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } := { toFun := fun o => @@ -647,7 +647,7 @@ section Circuits variable [DecidableEq Batch] /-- Gate semantics for attention score/mixing circuits. -/ -def attentionGate (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : +@[expose] def attentionGate (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : ∀ i, (∀ j, (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j i → @@ -707,7 +707,7 @@ def attentionGate (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val exact dotProduct weights vals /-- Circuit for attention score/mixing. -/ -def attentionCircuit (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : +@[expose] def attentionCircuit (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : Circuit (AttentionNode Batch seq heads dim) Val := { dag := attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) inputs := attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) @@ -716,7 +716,7 @@ def attentionCircuit (scale : Val) (softmax : (Fin seq → Val) → Fin seq → attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax } /-- Typed interface for attention score/mixing circuits. -/ -def attentionInterface (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : +@[expose] def attentionInterface (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : Interface (attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax) (AttentionInput Batch seq heads dim) (AttentionOutput Batch seq heads dim) := @@ -730,7 +730,7 @@ section Typed variable [DecidableEq Batch] /-- Typed attention score/mixing circuit. 
-/ -def attentionTyped (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : +@[expose] def attentionTyped (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : TypedCircuit (AttentionNode Batch seq heads dim) Val (AttentionInput Batch seq heads dim) (AttentionOutput Batch seq heads dim) := { circuit := attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) diff --git a/Nfp/Circuit/Layers/Heads.lean b/Nfp/Circuit/Layers/Heads.lean index e12719b..2cebb38 100644 --- a/Nfp/Circuit/Layers/Heads.lean +++ b/Nfp/Circuit/Layers/Heads.lean @@ -10,7 +10,7 @@ public import Nfp.Circuit.Layers.Reshape Head split/merge combinators for transformer-style shapes. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Layers/Induction/Basic.lean b/Nfp/Circuit/Layers/Induction/Basic.lean index da4fcd6..c08c14c 100644 --- a/Nfp/Circuit/Layers/Induction/Basic.lean +++ b/Nfp/Circuit/Layers/Induction/Basic.lean @@ -12,7 +12,7 @@ public import Nfp.Circuit.Layers.Attention Induction-head specifications for attention cores. -/ -@[expose] public section +public section namespace Nfp @@ -79,7 +79,7 @@ def InductionSpecApprox (ε : Val) ∀ q, q ≠ 0 → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε /-- Approximate induction-head spec restricted to active queries. -/ -def InductionSpecApproxOn (ε : Val) (active : Fin (Nat.succ n) → Prop) +@[expose] def InductionSpecApproxOn (ε : Val) (active : Fin (Nat.succ n) → Prop) (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) (out vals : Fin (Nat.succ n) → Val) : Prop := ∀ q, active q → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε diff --git a/Nfp/Circuit/Layers/Linear.lean b/Nfp/Circuit/Layers/Linear.lean index 57b6b3a..21847d2 100644 --- a/Nfp/Circuit/Layers/Linear.lean +++ b/Nfp/Circuit/Layers/Linear.lean @@ -12,7 +12,7 @@ public import Nfp.Circuit.Typed Linear and affine layer circuits. 
-/ -@[expose] public section +public section namespace Nfp @@ -30,7 +30,7 @@ variable {Row Col : Type u} abbrev LinearNode (Row Col : Type u) : Type u := Sum Col Row /-- Rank function used to orient layer edges from inputs to outputs. -/ -def linearRank : LinearNode Row Col → Nat +@[expose] def linearRank : LinearNode Row Col → Nat | Sum.inl _ => 0 | Sum.inr _ => 1 diff --git a/Nfp/Circuit/Layers/Reshape.lean b/Nfp/Circuit/Layers/Reshape.lean index e0180c1..59ad3fa 100644 --- a/Nfp/Circuit/Layers/Reshape.lean +++ b/Nfp/Circuit/Layers/Reshape.lean @@ -9,7 +9,7 @@ public import Nfp.Circuit.Typed Reshape combinators for product-typed circuit interfaces. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Layers/Softmax.lean b/Nfp/Circuit/Layers/Softmax.lean index d9c1bbc..6446159 100644 --- a/Nfp/Circuit/Layers/Softmax.lean +++ b/Nfp/Circuit/Layers/Softmax.lean @@ -14,7 +14,7 @@ These lemmas provide the analytical bridge from score gaps to softmax weight upper bounds. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Layers/Tensor.lean b/Nfp/Circuit/Layers/Tensor.lean index 768f329..11860a2 100644 --- a/Nfp/Circuit/Layers/Tensor.lean +++ b/Nfp/Circuit/Layers/Tensor.lean @@ -8,7 +8,7 @@ public import Nfp.Circuit.Layers.Linear Tensor-shaped layer builders (batched linear and affine layers). -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Layers/TransformerBlock.lean b/Nfp/Circuit/Layers/TransformerBlock.lean index 3124023..14e24c3 100644 --- a/Nfp/Circuit/Layers/TransformerBlock.lean +++ b/Nfp/Circuit/Layers/TransformerBlock.lean @@ -9,7 +9,7 @@ public import Nfp.Circuit.Layers.Attention Transformer block wiring built from sequential composition and residual links. 
-/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Semantics.lean b/Nfp/Circuit/Semantics.lean index f30d6d6..6e629e3 100644 --- a/Nfp/Circuit/Semantics.lean +++ b/Nfp/Circuit/Semantics.lean @@ -8,7 +8,7 @@ public import Nfp.Circuit.Basic Evaluation semantics for finite circuits. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Tensor.lean b/Nfp/Circuit/Tensor.lean index f8b6917..e6b252d 100644 --- a/Nfp/Circuit/Tensor.lean +++ b/Nfp/Circuit/Tensor.lean @@ -8,7 +8,7 @@ public import Mathlib.Data.Matrix.Basic Typed tensor indices and tensor aliases. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Circuit/Typed.lean b/Nfp/Circuit/Typed.lean index 1ef1330..f98f6b6 100644 --- a/Nfp/Circuit/Typed.lean +++ b/Nfp/Circuit/Typed.lean @@ -9,7 +9,7 @@ public import Nfp.Circuit.Cert.Basic Typed circuit wrappers and typed equivalence checking. -/ -@[expose] public section +public section namespace Nfp @@ -31,7 +31,7 @@ variable {Node : Type u} [Fintype Node] [DecidableEq Node] variable {Val : Type v} {Input : Type u_in} {Output : Type u_out} /-- Evaluate a typed circuit on a typed input. -/ -def eval (T : TypedCircuit Node Val Input Output) (input : Input → Val) : Output → Val := +@[expose] def eval (T : TypedCircuit Node Val Input Output) (input : Input → Val) : Output → Val := T.interface.eval input /-- Decide equivalence by enumerating typed inputs. -/ diff --git a/Nfp/Circuit/WellFormed.lean b/Nfp/Circuit/WellFormed.lean index e04bdb9..c9813d5 100644 --- a/Nfp/Circuit/WellFormed.lean +++ b/Nfp/Circuit/WellFormed.lean @@ -8,7 +8,7 @@ public import Nfp.Circuit.Basic Well-formedness conditions for circuits. 
-/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean index 82db571..9efb892 100644 --- a/Nfp/Model/Gpt2.lean +++ b/Nfp/Model/Gpt2.lean @@ -13,7 +13,7 @@ MLP/LayerNorm parameters needed to build `InductionHeadInputs` and downstream bound computations. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Model/InductionHead.lean b/Nfp/Model/InductionHead.lean index 9d77227..f2d879a 100644 --- a/Nfp/Model/InductionHead.lean +++ b/Nfp/Model/InductionHead.lean @@ -13,7 +13,7 @@ These structures store exact rational inputs (embeddings and weights) for a single attention head. They are intended to be consumed by sound builders. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index df13727..5302c6f 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -13,7 +13,7 @@ active-query set from a fixed period. They keep the prompt bookkeeping separate from the model weights. 
-/ -@[expose] public section +public section namespace Nfp From bc087babd18d9ac882f27926c81dfa5476128b7d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 20:20:36 +0100 Subject: [PATCH 171/244] Refine Sound exposure usage --- Nfp/Sound/Bounds/Attention.lean | 4 +- Nfp/Sound/Bounds/Cache.lean | 2 +- Nfp/Sound/Bounds/Gelu.lean | 2 +- Nfp/Sound/Bounds/LayerNorm/Basic.lean | 15 ++++++- Nfp/Sound/Bounds/LayerNorm/InvStd.lean | 15 ++++++- Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean | 2 +- Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean | 2 +- Nfp/Sound/Bounds/MatrixNorm/Basic.lean | 2 +- Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 6 +-- Nfp/Sound/Bounds/Mlp.lean | 2 +- Nfp/Sound/Bounds/Transformer/Basic.lean | 2 +- Nfp/Sound/Bounds/UnnormRat.lean | 2 +- Nfp/Sound/Gpt2/HeadInputs.lean | 4 +- Nfp/Sound/Induction/Core/Basic.lean | 11 ++++- Nfp/Sound/Induction/CoreDefs.lean | 36 ++++++++++------ Nfp/Sound/Induction/CoreSound/Basic.lean | 43 +++++++++++--------- Nfp/Sound/Induction/CoreSound/Values.lean | 2 +- Nfp/Sound/Induction/EndToEnd.lean | 2 +- Nfp/Sound/Induction/HeadOutput.lean | 14 ++++--- Nfp/Sound/Induction/LogitDiff.lean | 8 ++-- Nfp/Sound/Induction/OneHot.lean | 2 +- Nfp/Sound/Linear/FinFold.lean | 6 +-- 22 files changed, 118 insertions(+), 66 deletions(-) diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index c81fa85..27ce779 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -18,7 +18,7 @@ public import Nfp.Sound.Bounds.Mlp Interval bounds for multi-head attention and transformer layers. -/ -@[expose] public section +public section namespace Nfp @@ -49,7 +49,7 @@ noncomputable def attentionOutputReal {seq dModel dHead numHeads : Nat} [NeZero (∑ h, headProj h q i) + (attnBias i : Real) /-- Unfolding lemma for `attentionOutputReal`. 
-/ -theorem attentionOutputReal_def {seq dModel dHead numHeads : Nat} [NeZero seq] +private theorem attentionOutputReal_def {seq dModel dHead numHeads : Nat} [NeZero seq] (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) (attnBias : Fin dModel → Rat) diff --git a/Nfp/Sound/Bounds/Cache.lean b/Nfp/Sound/Bounds/Cache.lean index 4b8e1aa..a88f9d8 100644 --- a/Nfp/Sound/Bounds/Cache.lean +++ b/Nfp/Sound/Bounds/Cache.lean @@ -8,7 +8,7 @@ public import Nfp.Core.Basic Caching helpers for interval bounds. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean index f76f06e..2c4499f 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -12,7 +12,7 @@ Tanh-based GELU bounds for GPT-2 style MLPs. These bounds are used to propagate interval constraints through nonlinear gates. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm/Basic.lean b/Nfp/Sound/Bounds/LayerNorm/Basic.lean index bf0846f..8750aef 100644 --- a/Nfp/Sound/Bounds/LayerNorm/Basic.lean +++ b/Nfp/Sound/Bounds/LayerNorm/Basic.lean @@ -22,7 +22,7 @@ This module computes rational interval bounds for LayerNorm outputs and proves those bounds sound for real-valued LayerNorm semantics. -/ -@[expose] public section +public section namespace Nfp @@ -82,6 +82,19 @@ noncomputable def layerNormReal {n : Nat} let invStd : Real := (Real.sqrt varEps)⁻¹ fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) +/-- Unfolding lemma for `layerNormReal`. 
-/ +theorem layerNormReal_def {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : + layerNormReal eps gamma beta x = + if n = 0 then + fun _ => 0 + else + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) := by + simp [layerNormReal] + /-- Real-valued LayerNorm output for a real vector. -/ noncomputable def layerNormRealOfReal {n : Nat} (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Real) : Fin n → Real := diff --git a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean index cac7454..992d167 100644 --- a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean +++ b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean @@ -12,7 +12,7 @@ This module isolates invStd bounds and their soundness proof to keep `LayerNorm/Basic.lean` below the style linter's file-length limit. -/ -@[expose] public section +public section namespace Nfp @@ -31,6 +31,19 @@ def invStdBounds {n : Nat} (eps : Rat) (x : Fin n → Rat) : Rat × Rat := let sqrtUpperBound : Rat := sqrtUpper varEps (ratDivDown 1 sqrtUpperBound, ratDivUp 1 sqrtLowerBound) +/-- Unfolding lemma for `invStdBounds`. -/ +theorem invStdBounds_def {n : Nat} (eps : Rat) (x : Fin n → Rat) : + invStdBounds eps x = + if n = 0 then + (0, 0) + else + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEps) + let sqrtUpperBound : Rat := sqrtUpper varEps + (ratDivDown 1 sqrtUpperBound, ratDivUp 1 sqrtLowerBound) := by + simp [invStdBounds] + /-- `invStdBounds` soundness for real inverse-std terms. 
-/ theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : diff --git a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean index b03f058..c669a21 100644 --- a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean +++ b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean @@ -18,7 +18,7 @@ This module isolates the rational and real mean/variance definitions and their basic lemmas to keep `LayerNorm` bounds modular. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean index bbf85a0..591a819 100644 --- a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean +++ b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean @@ -16,7 +16,7 @@ This module isolates the rational sqrt lower/upper bounds and their basic nonnegativity/positivity lemmas so the main LayerNorm bounds stay focused. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean b/Nfp/Sound/Bounds/MatrixNorm/Basic.lean index 0d7b1a3..5e9bd32 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Basic.lean @@ -21,7 +21,7 @@ These bounds are used to compute verified downstream error certificates from explicit Rat matrices. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index 8fc7455..ddb9edc 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -16,7 +16,7 @@ Interval bounds for dot products and matrix-vector products. This module isolates interval-bound helpers used across downstream certificates. 
-/ -@[expose] public section +public section namespace Nfp @@ -353,10 +353,10 @@ theorem dotIntervalLowerUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : (dotIntervalLowerCommonDen v lo hi, dotIntervalUpperCommonDen v lo hi) := by ext <;> simp only [dotIntervalLowerUpperCommonDen_fst, dotIntervalLowerUpperCommonDen_snd] -theorem dotIntervalLowerUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : +private theorem dotIntervalLowerUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalLowerUnnorm v lo hi = dotIntervalLower v lo hi := rfl -theorem dotIntervalUpperUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : +private theorem dotIntervalUpperUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : dotIntervalUpperUnnorm v lo hi = dotIntervalUpper v lo hi := rfl /-! Cached endpoints. -/ diff --git a/Nfp/Sound/Bounds/Mlp.lean b/Nfp/Sound/Bounds/Mlp.lean index 118dffd..097bc66 100644 --- a/Nfp/Sound/Bounds/Mlp.lean +++ b/Nfp/Sound/Bounds/Mlp.lean @@ -12,7 +12,7 @@ public import Nfp.Sound.Bounds.MatrixNorm Interval bounds for GPT-2 MLP blocks (linear + GELU + linear). -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Bounds/Transformer/Basic.lean b/Nfp/Sound/Bounds/Transformer/Basic.lean index 253582c..d8a64e3 100644 --- a/Nfp/Sound/Bounds/Transformer/Basic.lean +++ b/Nfp/Sound/Bounds/Transformer/Basic.lean @@ -16,7 +16,7 @@ public import Nfp.Sound.Linear.FinFold Interval bounds for transformer stacks and final LayerNorm outputs. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Bounds/UnnormRat.lean b/Nfp/Sound/Bounds/UnnormRat.lean index 802c441..90824fa 100644 --- a/Nfp/Sound/Bounds/UnnormRat.lean +++ b/Nfp/Sound/Bounds/UnnormRat.lean @@ -12,7 +12,7 @@ Rat values already avoid gcd normalization, so this module provides a lightweight alias and helper API used by older code paths. 
-/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index 7a3452f..9270075 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -14,7 +14,7 @@ periodic prompt description. The construction is purely definitional and is captured by an explicit theorem, so the trusted core does not hide any logic. -/ -@[expose] public section +public section namespace Nfp @@ -25,7 +25,7 @@ namespace Gpt2 open Nfp.Model /-- Build induction-head inputs from a GPT-2 head slice and prompt period. -/ -def buildInductionHeadInputs {seq dModel dHead vocab : Nat} +@[expose] def buildInductionHeadInputs {seq dModel dHead vocab : Nat} (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : Model.InductionHeadInputs seq dModel dHead := { scale := slice.scale diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index d058e0e..1ded587 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -20,7 +20,7 @@ public import Nfp.Sound.Linear.FinFold /-! Sound builders for induction certificates; recompute bounds inside Lean from exact inputs and derive softmax tolerances from score margins rather than trusting external weight dumps. -/ -@[expose] public section +public section namespace Nfp namespace Sound @@ -246,7 +246,7 @@ structure InductionHeadCoreCache (seq dModel dHead : Nat) where cert : InductionHeadCert seq /-- Build cached core quantities for induction-head certificates. -/ -def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} +@[expose] def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} (cfg : InductionHeadSplitConfig) (inputs : Model.InductionHeadInputs seq dModel dHead) : InductionHeadCoreCache seq dModel dHead := by @@ -815,6 +815,13 @@ def buildInductionCertFromHeadCore? 
[NeZero seq] {dModel dHead : Nat} Option (InductionHeadCert seq) := buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs +/-- Unfolding lemma for `buildInductionCertFromHeadCore?`. -/ +theorem buildInductionCertFromHeadCore?_def [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : + buildInductionCertFromHeadCore? inputs = + buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs := by + simp [buildInductionCertFromHeadCore?] + /-- `buildInductionCertFromHeadCoreWith?` succeeds under the guard conditions. -/ theorem buildInductionCertFromHeadCoreWith?_eq_some [NeZero seq] {dModel dHead : Nat} (cfg : InductionHeadSplitConfig) diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index 3e586b5..b599601 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -3,7 +3,7 @@ module public import Mathlib.Algebra.BigOperators.Group.Finset.Basic -public import Mathlib.Data.Vector.Defs +public import Batteries.Data.Vector.Lemmas public import Nfp.Circuit.Layers.Induction public import Nfp.Circuit.Layers.Softmax public import Nfp.Core.Basic @@ -18,7 +18,7 @@ Core definitions for induction-head certificates. These definitions are shared across induction certificate builders and checkers. -/ -@[expose] public section +public section namespace Nfp @@ -37,6 +37,13 @@ def dirHeadVecOfInputs {seq dModel dHead : Nat} Vector.ofFn (fun d : Fin dHead => Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) +/-- Unfolding lemma for `dirHeadVecOfInputs`. -/ +theorem dirHeadVecOfInputs_get {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (d : Fin dHead) : + (dirHeadVecOfInputs inputs).get d = + Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j) := by + simp [dirHeadVecOfInputs] + /-- Real-valued LayerNorm outputs for head inputs. 
-/ noncomputable def lnRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := @@ -47,7 +54,8 @@ noncomputable def lnRealOfInputs {seq dModel dHead : Nat} theorem lnRealOfInputs_def {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (i : Fin dModel) : lnRealOfInputs inputs q i = - Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) i := rfl + Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) i := by + simp [lnRealOfInputs] /-- Real-valued query projections for head inputs. -/ noncomputable def qRealOfInputs {seq dModel dHead : Nat} @@ -60,7 +68,8 @@ theorem qRealOfInputs_def {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : qRealOfInputs inputs q d = dotProduct (fun j => (inputs.wq j d : Real)) (lnRealOfInputs inputs q) + - (inputs.bq d : Real) := rfl + (inputs.bq d : Real) := by + simp [qRealOfInputs] /-- Real-valued key projections for head inputs. -/ noncomputable def kRealOfInputs {seq dModel dHead : Nat} @@ -73,16 +82,17 @@ theorem kRealOfInputs_def {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : kRealOfInputs inputs q d = dotProduct (fun j => (inputs.wk j d : Real)) (lnRealOfInputs inputs q) + - (inputs.bk d : Real) := rfl + (inputs.bk d : Real) := by + simp [kRealOfInputs] /-- Real-valued value projections for head inputs. -/ -noncomputable def vRealOfInputs {seq dModel dHead : Nat} +@[expose] noncomputable def vRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := fun q d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + (inputs.bv d : Real) /-- Unfolding lemma for `vRealOfInputs`. 
-/ -theorem vRealOfInputs_def {seq dModel dHead : Nat} +private theorem vRealOfInputs_def {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : vRealOfInputs inputs q d = dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + @@ -116,7 +126,8 @@ theorem scoresRealOfInputs_def {seq dModel dHead : Nat} else (inputs.maskValue : Real) else - base := rfl + base := by + simp [scoresRealOfInputs] /-- Real-valued per-key head outputs in model space. -/ noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} @@ -128,16 +139,17 @@ noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} theorem headValueRealOfInputs_def {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) (i : Fin dModel) : headValueRealOfInputs inputs k i = - dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) := rfl + dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) := by + simp [headValueRealOfInputs] /-- Real-valued direction scores for head inputs. -/ -noncomputable def valsRealOfInputs {seq dModel dHead : Nat} +@[expose] noncomputable def valsRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Real := let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d fun k => dotProduct dirHead (fun d => vRealOfInputs inputs k d) /-- Unfolding lemma for `valsRealOfInputs`. -/ -theorem valsRealOfInputs_def {seq dModel dHead : Nat} +private theorem valsRealOfInputs_def {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) : valsRealOfInputs inputs k = let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d @@ -188,7 +200,7 @@ def defaultInductionHeadSplitConfig : InductionHeadSplitConfig := splitBudgetDiffRefined := 12 } /-- Unfolding lemma for `defaultInductionHeadSplitConfig`. 
-/ -theorem defaultInductionHeadSplitConfig_def : +private theorem defaultInductionHeadSplitConfig_def : defaultInductionHeadSplitConfig = { splitBudgetQ := 2 splitBudgetK := 2 diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 12ccf0f..3514a9f 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -4,7 +4,7 @@ module public import Nfp.Sound.Induction.Core public import Nfp.Sound.Induction.CoreSound.Values -@[expose] public section +public section namespace Nfp namespace Sound @@ -77,8 +77,8 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by intro q j have hmu := hmeanRat q - simp [lnRealOfInputs, Bounds.layerNormReal, hmodel, lnCoeff, hmu, invStd, add_comm, - mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] + simp [lnRealOfInputs_def, Bounds.layerNormReal_def, hmodel, lnCoeff, hmu, invStd, + add_comm, mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] have hln_fun : ∀ q, lnRealOfInputs inputs q = @@ -90,7 +90,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by intro q simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, - invStdBounds, Task.spawn, Array.getElem_ofFn] using + Bounds.invStdBounds_def, Task.spawn, Array.getElem_ofFn] using (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) hmodel hEps hSqrt) let qBaseArr : Array Rat := @@ -476,9 +476,12 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N prev := inputs.prev values := valCert } have hcore' : buildInductionCertFromHeadCoreWith? 
cfg inputs = some cert := by - simp (config := { zeta := false }) only - [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] - rfl + have hcore'' : + buildInductionCertFromHeadCoreWith? cfg inputs = + some (buildInductionHeadCoreCacheWith cfg inputs).cert := + buildInductionCertFromHeadCoreWith?_eq_some + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive + simpa using hcore'' have hc : c = cert := by have hcert : cert = c := by exact Option.some.inj (hcore'.symm.trans hcore) @@ -492,8 +495,9 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N Bounds.layerNormBounds_spec (eps := inputs.lnEps) (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) (x := inputs.embed q) hmodel hEps hSqrt - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs, Bounds.cacheBoundPair2_apply_left, - Bounds.cacheBoundPair2_apply_right] using hln i + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs_def, + Bounds.cacheBoundPair2_apply_left, Bounds.cacheBoundPair2_apply_right] using + hln i have dotFin_cast {n : Nat} (f g : Fin n → Rat) : (Linear.dotFin n f g : Real) = dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by @@ -619,7 +623,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) have h := proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) (coeff := qCoeff) hbase hcoeff q d - simpa [qLo, qHi, qRealOfInputs] using h + simpa [qLo, qHi, qRealOfInputs_def] using h have hk_bounds : ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ kRealOfInputs inputs q d ≤ (kHi q d : Real) := by @@ -644,7 +648,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) have h := proj_bounds (w := inputs.wk) (b := inputs.bk) (base := kBase) (coeff := kCoeff) hbase hcoeff q d - simpa [kLo, kHi, kRealOfInputs] using h + simpa [kLo, kHi, kRealOfInputs_def] 
using h let scoresReal := scoresRealOfInputs inputs have scoresReal_eq_base_of_not_masked : ∀ q k, ¬ masked q k → @@ -658,15 +662,15 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N intro hlt exact hnot ⟨hcausal, hlt⟩ have hle : k ≤ q := le_of_not_gt hnot_lt - simp [scoresReal, scoresRealOfInputs, hcausal, hle] - · simp [scoresReal, scoresRealOfInputs, hcausal] + simp [scoresReal, scoresRealOfInputs_def, hcausal, hle] + · simp [scoresReal, scoresRealOfInputs_def, hcausal] have scoresReal_eq_masked : ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by intro q k hmask have hmask' : inputs.maskCausal = true ∧ q < k := by simpa [masked] using hmask have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 - simp [scoresReal, scoresRealOfInputs, hmask'.1, hle] + simp [scoresReal, scoresRealOfInputs_def, hmask'.1, hle] have hscore_bounds : ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ scoresReal q k ≤ (scoreHi q k : Real) := by @@ -1303,10 +1307,11 @@ theorem buildInductionCertFromHeadCore?_sound (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) (hcore : buildInductionCertFromHeadCore? inputs = some c) : InductionHeadCertSound inputs c := by - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_sound - (cfg := defaultInductionHeadSplitConfig) inputs c - (by - simpa [buildInductionCertFromHeadCore?] using hcore)) + have hcore' : + buildInductionCertFromHeadCoreWith? 
defaultInductionHeadSplitConfig inputs = some c := by + simpa [buildInductionCertFromHeadCore?_def] using hcore + exact + buildInductionCertFromHeadCoreWith?_sound + (cfg := defaultInductionHeadSplitConfig) inputs c hcore' end Sound end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Values.lean b/Nfp/Sound/Induction/CoreSound/Values.lean index b448ac7..2e5d52a 100644 --- a/Nfp/Sound/Induction/CoreSound/Values.lean +++ b/Nfp/Sound/Induction/CoreSound/Values.lean @@ -12,7 +12,7 @@ These isolate the algebra needed to rewrite direction-value projections into dot products over cached `wvDir`/`bDir` terms. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Induction/EndToEnd.lean b/Nfp/Sound/Induction/EndToEnd.lean index e193a60..c5866ce 100644 --- a/Nfp/Sound/Induction/EndToEnd.lean +++ b/Nfp/Sound/Induction/EndToEnd.lean @@ -10,7 +10,7 @@ public import Nfp.Sound.Induction.LogitDiff End-to-end induction bounds that combine head certificates with transformer-stack intervals. -/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index 53e9930..1a0fe47 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -9,7 +9,7 @@ public import Nfp.Sound.Induction.CoreSound Head-output interval certificates for induction heads. -/ -@[expose] public section +public section namespace Nfp @@ -63,7 +63,8 @@ theorem headOutputWithScores_def (scores : Fin seq → Fin seq → Real) let weights : Fin seq → Fin seq → Real := fun q k => Circuit.softmax (scores q) k let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i - dotProduct (weights q) vals := rfl + dotProduct (weights q) vals := by + simp [headOutputWithScores] /-- Real-valued head output for a query and model dimension. 
-/ def headOutput (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -74,7 +75,8 @@ def headOutput (inputs : Model.InductionHeadInputs seq dModel dHead) theorem headOutput_def (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (i : Fin dModel) : headOutput inputs q i = - headOutputWithScores (scoresRealOfInputs inputs) inputs q i := rfl + headOutputWithScores (scoresRealOfInputs inputs) inputs q i := by + simp [headOutput] /-- Soundness predicate for head-output interval bounds. -/ structure HeadOutputIntervalSound [NeZero seq] @@ -138,7 +140,7 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] Bounds.layerNormBounds_spec (eps := inputs.lnEps) (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) (x := inputs.embed q) hmodel hEps hSqrt - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs] using hln i + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs_def] using hln i have hv_bounds : ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ vRealOfInputs inputs q d ≤ (vHi q d : Real) := by @@ -193,8 +195,8 @@ def buildHeadOutputIntervalFromHead? [NeZero seq] (lo := vLo k) (hi := vHi k) (x := fun d => vRealOfInputs inputs k d) hlo hhi constructor - · simpa [headValueLo, headValueRealOfInputs] using hlow - · simpa [headValueHi, headValueRealOfInputs] using hhigh + · simpa [headValueLo, headValueRealOfInputs_def] using hlow + · simpa [headValueHi, headValueRealOfInputs_def] using hhigh let scoresReal : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := scoresRealOfInputs inputs let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 2647557..d461b9c 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -12,7 +12,7 @@ public import Nfp.Sound.Induction.HeadOutput Logit-diff bounds derived from induction certificates. 
-/ -@[expose] public section +public section namespace Nfp @@ -59,7 +59,7 @@ theorem direction_dot_headValue_eq_valsReal ((dirHeadVecOfInputs inputs).get d : Real) = ratToReal (Linear.dotFin dModel (fun i => inputs.wo i d) (fun i => inputs.direction i)) := by - simp [dirHeadVecOfInputs, Vector.get, Vector.ofFn, ratToReal] + simp [dirHeadVecOfInputs_get, ratToReal] _ = ratToReal (dotProduct (fun i => inputs.wo i d) (fun i => inputs.direction i)) := by simp [Linear.dotFin_eq_dotProduct] @@ -73,7 +73,7 @@ theorem direction_dot_headValue_eq_valsReal dotProduct dir (fun i => headValueRealOfInputs inputs k i) = ∑ i, dir i * ∑ d, (inputs.wo i d : Real) * v d := by - simp [dir, v, headValueRealOfInputs, dotProduct] + simp [dir, v, headValueRealOfInputs_def, dotProduct] _ = ∑ d, (∑ i, dir i * (inputs.wo i d : Real)) * v d := by simp [hswap] _ = ∑ d, ((dirHeadVecOfInputs inputs).get d : Real) * v d := by @@ -569,7 +569,7 @@ theorem headLogitDiff_eq_direction_dot_headOutput calc dotProduct dir (fun i => headOutput inputs q i) = ∑ i, dir i * ∑ k, weights q k * headValueRealOfInputs inputs k i := by - simp [dir, headOutput, headOutputWithScores, weights, dotProduct] + simp [dir, headOutput_def, headOutputWithScores_def, weights, dotProduct] _ = ∑ i, ∑ k, dir i * (weights q k * headValueRealOfInputs inputs k i) := by simp [Finset.mul_sum] _ = ∑ k, ∑ i, dir i * (weights q k * headValueRealOfInputs inputs k i) := by diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 02ebb29..22d1313 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -13,7 +13,7 @@ public import Nfp.Circuit.Layers.Softmax Per-query one-hot bounds derived from score margins. 
-/ -@[expose] public section +public section namespace Nfp diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean index 8668494..8426925 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Sound/Linear/FinFold.lean @@ -14,7 +14,7 @@ Tail-recursive folds and sums over `Fin`. These helpers keep sound computations stack-safe while remaining explicit. -/ -@[expose] public section +public section namespace Nfp @@ -39,7 +39,7 @@ def sumFin (n : Nat) (f : Fin n → Rat) : Rat := foldlFin n (fun acc i => acc + f i) 0 /-- Tail-recursive sum over `Fin n` (alias for `sumFin`). -/ -def sumFinCommonDen (n : Nat) (f : Fin n → Rat) : Rat := +@[expose] def sumFinCommonDen (n : Nat) (f : Fin n → Rat) : Rat := sumFin n f /-- `sumFin` as a left fold over the finite range list. -/ @@ -92,7 +92,7 @@ theorem sumFinCommonDen_eq_sumFin (n : Nat) (f : Fin n → Rat) : sumFinCommonDen n f = sumFin n f := rfl /-- Dot product over `Fin n` (Rat-valued). -/ -def dotFin (n : Nat) (x y : Fin n → Rat) : Rat := +@[expose] def dotFin (n : Nat) (x y : Fin n → Rat) : Rat := sumFin n (fun i => x i * y i) /-- Unfolding lemma for `dotFin`. -/ From 1b505478e9a5b4fab0f8b73dc2083a2300d1ca6f Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 20:34:23 +0100 Subject: [PATCH 172/244] Trim remaining Sound exposures --- Nfp/Sound/Gpt2/HeadInputs.lean | 5 +++-- Nfp/Sound/Induction/CoreDefs.lean | 14 ++++++++------ Nfp/Sound/Induction/CoreSound/Values.lean | 2 +- Nfp/Sound/Induction/HeadOutput.lean | 4 ++-- Nfp/Sound/Induction/LogitDiff.lean | 2 +- Nfp/Sound/Linear/FinFold.lean | 10 ++++++---- 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index 9270075..b29f0f7 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -25,7 +25,7 @@ namespace Gpt2 open Nfp.Model /-- Build induction-head inputs from a GPT-2 head slice and prompt period. 
-/ -@[expose] def buildInductionHeadInputs {seq dModel dHead vocab : Nat} +def buildInductionHeadInputs {seq dModel dHead vocab : Nat} (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : Model.InductionHeadInputs seq dModel dHead := { scale := slice.scale @@ -70,7 +70,8 @@ theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} maskCausal := true maskValue := (-10000 : Rat) directionSpec := slice.direction.spec - direction := slice.directionVec } := rfl + direction := slice.directionVec } := by + simp [buildInductionHeadInputs] end Gpt2 diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index b599601..6cbf702 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -86,17 +86,18 @@ theorem kRealOfInputs_def {seq dModel dHead : Nat} simp [kRealOfInputs] /-- Real-valued value projections for head inputs. -/ -@[expose] noncomputable def vRealOfInputs {seq dModel dHead : Nat} +noncomputable def vRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := fun q d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + (inputs.bv d : Real) /-- Unfolding lemma for `vRealOfInputs`. -/ -private theorem vRealOfInputs_def {seq dModel dHead : Nat} +theorem vRealOfInputs_def {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : vRealOfInputs inputs q d = dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + - (inputs.bv d : Real) := rfl + (inputs.bv d : Real) := by + simp [vRealOfInputs] /-- Real-valued attention scores for head inputs. -/ noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} @@ -143,17 +144,18 @@ theorem headValueRealOfInputs_def {seq dModel dHead : Nat} simp [headValueRealOfInputs] /-- Real-valued direction scores for head inputs. 
-/ -@[expose] noncomputable def valsRealOfInputs {seq dModel dHead : Nat} +noncomputable def valsRealOfInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Real := let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d fun k => dotProduct dirHead (fun d => vRealOfInputs inputs k d) /-- Unfolding lemma for `valsRealOfInputs`. -/ -private theorem valsRealOfInputs_def {seq dModel dHead : Nat} +theorem valsRealOfInputs_def {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) : valsRealOfInputs inputs k = let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d - dotProduct dirHead (fun d => vRealOfInputs inputs k d) := rfl + dotProduct dirHead (fun d => vRealOfInputs inputs k d) := by + simp [valsRealOfInputs] /-- Interval data for direction values. -/ structure ValueInterval (seq : Nat) where diff --git a/Nfp/Sound/Induction/CoreSound/Values.lean b/Nfp/Sound/Induction/CoreSound/Values.lean index 2e5d52a..4978444 100644 --- a/Nfp/Sound/Induction/CoreSound/Values.lean +++ b/Nfp/Sound/Induction/CoreSound/Values.lean @@ -124,7 +124,7 @@ theorem valsReal_eq_of_dir (inputs : Model.InductionHeadInputs seq dModel dHead) (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k) + (inputs.bv d : Real)) := by - simp [valsRealOfInputs, vRealOfInputs, hdirHead_real] + simp [valsRealOfInputs_def, vRealOfInputs_def, hdirHead_real] _ = dotProduct (fun d => (dirHead d : Real)) (fun d => dotProduct (fun j => (inputs.wv j d : Real)) diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index 1a0fe47..4513116 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -173,10 +173,10 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] (lo := lnLo q) (hi := lnHi q) (x := lnRealOfInputs inputs q) (b := (inputs.bv d : Real)) hlo hhi constructor - · simpa [vLo, vRealOfInputs, Bounds.cacheBound2_apply, + · simpa [vLo, vRealOfInputs_def, Bounds.cacheBound2_apply, Bounds.dotIntervalLowerCachedRat_eq, ratToReal_add] using hlow' - · simpa [vHi, vRealOfInputs, Bounds.cacheBound2_apply, + · simpa [vHi, vRealOfInputs_def, Bounds.cacheBound2_apply, Bounds.dotIntervalUpperCachedRat_eq, ratToReal_add] using hhigh' have hhead_bounds : diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index d461b9c..e139d22 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -90,7 +90,7 @@ theorem direction_dot_headValue_eq_valsReal simpa using (hdirHead d).symm simp [hdir] _ = valsRealOfInputs inputs k := by - simp [valsRealOfInputs, v, dotProduct] + simp [valsRealOfInputs_def, v, dotProduct] end Direction diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean index 8426925..2ec1c58 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Sound/Linear/FinFold.lean @@ -39,7 +39,7 @@ def sumFin (n : Nat) (f : Fin n → Rat) : Rat := foldlFin n (fun acc i => acc + f i) 0 /-- Tail-recursive sum over `Fin n` (alias for `sumFin`). -/ -@[expose] def sumFinCommonDen (n : Nat) (f : Fin n → Rat) : Rat := +def sumFinCommonDen (n : Nat) (f : Fin n → Rat) : Rat := sumFin n f /-- `sumFin` as a left fold over the finite range list. -/ @@ -89,15 +89,17 @@ theorem ratToReal_sumFin {n : Nat} (f : Fin n → Rat) : /-- `sumFinCommonDen` agrees with `sumFin`. -/ theorem sumFinCommonDen_eq_sumFin (n : Nat) (f : Fin n → Rat) : - sumFinCommonDen n f = sumFin n f := rfl + sumFinCommonDen n f = sumFin n f := by + simp [sumFinCommonDen] /-- Dot product over `Fin n` (Rat-valued). -/ -@[expose] def dotFin (n : Nat) (x y : Fin n → Rat) : Rat := +def dotFin (n : Nat) (x y : Fin n → Rat) : Rat := sumFin n (fun i => x i * y i) /-- Unfolding lemma for `dotFin`. 
-/ theorem dotFin_def (n : Nat) (x y : Fin n → Rat) : - dotFin n x y = sumFin n (fun i => x i * y i) := rfl + dotFin n x y = sumFin n (fun i => x i * y i) := by + simp [dotFin] /-- `dotFin` matches `dotProduct`. -/ theorem dotFin_eq_dotProduct (n : Nat) (x y : Fin n → Rat) : From b50d55083b28e0180bc5cadc967fff3e541e92a1 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 20:49:38 +0100 Subject: [PATCH 173/244] Remove final Sound exposure --- Nfp/Sound/Induction/Core/Basic.lean | 2 +- Nfp/Sound/Induction/CoreSound/Basic.lean | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index 1ded587..b41eef6 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -246,7 +246,7 @@ structure InductionHeadCoreCache (seq dModel dHead : Nat) where cert : InductionHeadCert seq /-- Build cached core quantities for induction-head certificates. -/ -@[expose] def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} +def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} (cfg : InductionHeadSplitConfig) (inputs : Model.InductionHeadInputs seq dModel dHead) : InductionHeadCoreCache seq dModel dHead := by diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 3514a9f..77c5d9c 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -1,6 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later module +import all Nfp.Sound.Induction.Core.Basic public import Nfp.Sound.Induction.Core public import Nfp.Sound.Induction.CoreSound.Values From 11fa6ab5949c43f32497df78d4c34bcb607acad9 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Tue, 13 Jan 2026 22:24:03 +0100 Subject: [PATCH 174/244] refactor: align module structure and defs --- Nfp/Circuit/Combinators.lean | 3 +- Nfp/Circuit/Interface.lean | 23 ++- Nfp/Circuit/Layers/Attention.lean | 
152 ++++++++++++++++++- Nfp/Circuit/Layers/Induction/Basic.lean | 22 ++- Nfp/Circuit/Layers/Linear.lean | 10 +- Nfp/Circuit/Layers/Tensor.lean | 2 +- Nfp/Circuit/Typed.lean | 7 +- Nfp/Core/Basic.lean | 26 +++- Nfp/Mixer/Basic.lean | 2 +- Nfp/Mixer/Operations.lean | 2 +- Nfp/Prob/Basic.lean | 2 +- Nfp/Prob/Operations.lean | 2 +- Nfp/Sound/Bounds/Attention.lean | 2 +- Nfp/Sound/Bounds/Gelu.lean | 8 +- Nfp/Sound/Bounds/LayerNorm/Basic.lean | 45 +++--- Nfp/Sound/Bounds/LayerNorm/InvStd.lean | 10 +- Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean | 2 +- Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean | 18 +-- Nfp/Sound/Bounds/MatrixNorm/Interval.lean | 40 +++-- Nfp/Sound/Bounds/Mlp.lean | 2 +- Nfp/Sound/Bounds/Transformer/Basic.lean | 4 +- Nfp/Sound/Bounds/Transformer/Embedding.lean | 12 +- Nfp/Sound/Induction/CoreSound/Basic.lean | 38 +++-- Nfp/Sound/Induction/HeadOutput.lean | 66 +++++--- Nfp/Sound/Induction/LogitDiff.lean | 8 +- Nfp/Sound/Induction/OneHot.lean | 18 +-- Nfp/Sound/Linear/FinFold.lean | 2 +- Nfp/System/Dag.lean | 7 +- Nfp/System/LocalSystem.lean | 2 +- 29 files changed, 397 insertions(+), 140 deletions(-) diff --git a/Nfp/Circuit/Combinators.lean b/Nfp/Circuit/Combinators.lean index 41258f5..aaada71 100644 --- a/Nfp/Circuit/Combinators.lean +++ b/Nfp/Circuit/Combinators.lean @@ -33,8 +33,7 @@ def relabel (C : Circuit Node Val) (e : _root_.Equiv Node Node') : Circuit Node' refine C.gate (e.symm i) ?_ intro j h refine rec (e j) ?_ - change C.dag.rel (e.symm (e j)) (e.symm i) - simpa using h + simpa [Dag.relabel_rel_iff] using h namespace Interface diff --git a/Nfp/Circuit/Interface.lean b/Nfp/Circuit/Interface.lean index be7a97d..37ddf57 100644 --- a/Nfp/Circuit/Interface.lean +++ b/Nfp/Circuit/Interface.lean @@ -16,7 +16,7 @@ universe u v u_in u_out namespace Circuit -variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {ι : Type u} [Fintype ι] variable {α : Type v} /-- A typed input/output interface for a circuit. 
-/ @@ -31,19 +31,30 @@ namespace Interface variable {C : Circuit ι α} {ι_in : Type u_in} {ι_out : Type u_out} /-- Convert a typed input assignment into an input-node assignment. -/ -@[expose] def toInputAssignment (I : Interface C ι_in ι_out) (input : ι_in → α) : +def toInputAssignment (I : Interface C ι_in ι_out) (input : ι_in → α) : C.InputAssignment := fun i => input (I.inputs.symm i) +/-- Definitional characterization of `Interface.toInputAssignment`. -/ +theorem toInputAssignment_def (I : Interface C ι_in ι_out) (input : ι_in → α) : + I.toInputAssignment input = fun i => input (I.inputs.symm i) := by + rfl + +section Eval + +variable [DecidableEq ι] + /-- Evaluate a circuit on a typed interface. -/ -@[expose] def eval (I : Interface C ι_in ι_out) (input : ι_in → α) : ι_out → α := +def eval (I : Interface C ι_in ι_out) (input : ι_in → α) : ι_out → α := fun o => evalInput C (I.toInputAssignment input) (I.outputs o).1 -/-- Unfolding equation for `Interface.eval`. -/ -private theorem eval_eq (I : Interface C ι_in ι_out) (input : ι_in → α) (o : ι_out) : - I.eval input o = evalInput C (I.toInputAssignment input) (I.outputs o).1 := +/-- Definitional characterization of `Interface.eval`. -/ +theorem eval_def (I : Interface C ι_in ι_out) (input : ι_in → α) (o : ι_out) : + I.eval input o = evalInput C (I.toInputAssignment input) (I.outputs o).1 := by rfl +end Eval + end Interface end Circuit diff --git a/Nfp/Circuit/Layers/Attention.lean b/Nfp/Circuit/Layers/Attention.lean index e90255f..6005a62 100644 --- a/Nfp/Circuit/Layers/Attention.lean +++ b/Nfp/Circuit/Layers/Attention.lean @@ -485,9 +485,15 @@ end Dag section Inputs /-- Input nodes for the attention core. -/ -@[expose] def attentionInputs : Finset (AttentionNode Batch seq heads dim) := +def attentionInputs : Finset (AttentionNode Batch seq heads dim) := (Finset.univ : Finset (AttentionInput Batch seq heads dim)).map Embedding.inl +/-- Definitional characterization of `attentionInputs`. 
-/ +theorem attentionInputs_def : + attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) = + (Finset.univ : Finset (AttentionInput Batch seq heads dim)).map Embedding.inl := by + rfl + open scoped Classical in /-- Membership in attention inputs corresponds to being a left injection. -/ theorem mem_attentionInputs_iff {s : AttentionNode Batch seq heads dim} : @@ -513,7 +519,7 @@ theorem not_mem_attentionInputs_inr (s : Sum (ScoreIndex Batch seq heads) open scoped Classical in /-- Input labels correspond to input nodes in the attention core. -/ -@[expose] def attentionInputEquiv : +def attentionInputEquiv : AttentionInput Batch seq heads dim ≃ { i // i ∈ attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } := { toFun := fun a => @@ -540,12 +546,40 @@ open scoped Classical in cases (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) s hs) } +/-- Definitional characterization of `attentionInputEquiv`. -/ +theorem attentionInputEquiv_def : + attentionInputEquiv (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) = + { toFun := fun a => + ⟨Sum.inl a, + (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).2 + ⟨a, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inl a, _⟩ => a + | ⟨Sum.inr s, h⟩ => + False.elim + (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + left_inv := by + intro a + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl a => rfl + | inr s => + cases (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s hs) } := by + rfl + end Inputs section Outputs /-- Output nodes for the attention core. 
-/ -@[expose] def attentionOutputs : Finset (AttentionNode Batch seq heads dim) := +def attentionOutputs : Finset (AttentionNode Batch seq heads dim) := (Finset.univ : Finset (AttentionOutput Batch seq heads dim)).map { toFun := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) inj' := by @@ -553,6 +587,17 @@ section Outputs cases h rfl } +/-- Definitional characterization of `attentionOutputs`. -/ +theorem attentionOutputs_def : + attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) = + (Finset.univ : Finset (AttentionOutput Batch seq heads dim)).map + { toFun := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + inj' := by + intro a b h + cases h + rfl } := by + rfl + open scoped Classical in /-- Membership in attention outputs corresponds to being an output injection. -/ theorem mem_attentionOutputs_iff {s : AttentionNode Batch seq heads dim} : @@ -594,7 +639,7 @@ theorem not_mem_attentionOutputs_weight (w : WeightIndex Batch seq heads) : open scoped Classical in /-- Output labels correspond to output nodes in the attention core. -/ -@[expose] def attentionOutputEquiv : +def attentionOutputEquiv : AttentionOutput Batch seq heads dim ≃ { i // i ∈ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } := { toFun := fun o => @@ -640,6 +685,53 @@ open scoped Classical in | inr _ => rfl } +/-- Definitional characterization of `attentionOutputEquiv`. 
-/ +theorem attentionOutputEquiv_def : + attentionOutputEquiv (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) = + { toFun := fun o => + ⟨attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) o, + (mem_attentionOutputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).2 + ⟨o, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr (Sum.inr (Sum.inr o)), _⟩ => o + | ⟨Sum.inl s, h⟩ => + False.elim + (not_mem_attentionOutputs_inl (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + | ⟨Sum.inr (Sum.inl s), h⟩ => + False.elim + (not_mem_attentionOutputs_score (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + | ⟨Sum.inr (Sum.inr (Sum.inl w)), h⟩ => + False.elim + (not_mem_attentionOutputs_weight (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) w h) + left_inv := by + intro o + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl s => + cases (not_mem_attentionOutputs_inl (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s hs) + | inr s => + cases s with + | inl s => + cases (not_mem_attentionOutputs_score (Batch := Batch) (seq := seq) + (heads := heads) (dim := dim) s hs) + | inr s => + cases s with + | inl w => + cases (not_mem_attentionOutputs_weight (Batch := Batch) (seq := seq) + (heads := heads) (dim := dim) w hs) + | inr _ => + rfl } := by + rfl + end Outputs section Circuits @@ -647,7 +739,7 @@ section Circuits variable [DecidableEq Batch] /-- Gate semantics for attention score/mixing circuits. -/ -@[expose] def attentionGate (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : +def attentionGate (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : ∀ i, (∀ j, (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j i → @@ -706,8 +798,33 @@ variable [DecidableEq Batch] (dim := dim) b k h d q) exact dotProduct weights vals +/-- Definitional characterization of `attentionGate` on output nodes. 
-/ +theorem attentionGate_out_def (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) + (b : Batch) (q : Fin seq) (h : Fin heads) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) : + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) rec = + let weightNode : Fin seq → AttentionNode Batch seq heads dim := fun k => + attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k) + let valueNode : Fin seq → AttentionNode Batch seq heads dim := fun k => + attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d) + let weights : Fin seq → Val := fun k => + rec (weightNode k) + (attentionDag_rel_weight_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q k d) + let vals : Fin seq → Val := fun k => + rec (valueNode k) + (attentionDag_rel_v_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b k h d q) + dotProduct weights vals := by + simp [attentionGate] + /-- Circuit for attention score/mixing. -/ -@[expose] def attentionCircuit (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : +def attentionCircuit (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : Circuit (AttentionNode Batch seq heads dim) Val := { dag := attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) inputs := attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) @@ -715,8 +832,18 @@ variable [DecidableEq Batch] gate := attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax } +/-- Definitional characterization of `attentionCircuit`. 
-/ +theorem attentionCircuit_def (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax = + { dag := attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + inputs := attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + outputs := attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + gate := attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale + softmax } := by + rfl + /-- Typed interface for attention score/mixing circuits. -/ -@[expose] def attentionInterface (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : +def attentionInterface (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : Interface (attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax) (AttentionInput Batch seq heads dim) (AttentionOutput Batch seq heads dim) := @@ -730,7 +857,7 @@ section Typed variable [DecidableEq Batch] /-- Typed attention score/mixing circuit. -/ -@[expose] def attentionTyped (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : +def attentionTyped (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : TypedCircuit (AttentionNode Batch seq heads dim) Val (AttentionInput Batch seq heads dim) (AttentionOutput Batch seq heads dim) := { circuit := attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) @@ -738,6 +865,15 @@ variable [DecidableEq Batch] interface := attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax } +/-- Definitional characterization of `attentionTyped`. 
-/ +theorem attentionTyped_def (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax = + { circuit := attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax + interface := attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax } := by + rfl + end Typed end Layers diff --git a/Nfp/Circuit/Layers/Induction/Basic.lean b/Nfp/Circuit/Layers/Induction/Basic.lean index c08c14c..30ff41c 100644 --- a/Nfp/Circuit/Layers/Induction/Basic.lean +++ b/Nfp/Circuit/Layers/Induction/Basic.lean @@ -6,6 +6,7 @@ public import Mathlib.Algebra.BigOperators.Group.Finset.Basic public import Mathlib.Algebra.BigOperators.Ring.Finset public import Mathlib.Algebra.Order.Monoid.Unbundled.Basic public import Mathlib.Algebra.Order.Ring.Defs +import all Nfp.Circuit.Layers.Attention public import Nfp.Circuit.Layers.Attention /-! @@ -69,7 +70,7 @@ end Spec section ApproxSpec -variable {Val : Type v} [AddCommMonoid Val] [PartialOrder Val] [IsOrderedAddMonoid Val] +variable {Val : Type v} [AddCommMonoid Val] [PartialOrder Val] variable {n : Nat} /-- Approximate induction-head spec: outputs are within `ε` of `prev` values. -/ @@ -79,11 +80,21 @@ def InductionSpecApprox (ε : Val) ∀ q, q ≠ 0 → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε /-- Approximate induction-head spec restricted to active queries. -/ -@[expose] def InductionSpecApproxOn (ε : Val) (active : Fin (Nat.succ n) → Prop) +def InductionSpecApproxOn (ε : Val) (active : Fin (Nat.succ n) → Prop) (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) (out vals : Fin (Nat.succ n) → Val) : Prop := ∀ q, active q → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε +/-- Definitional characterization of `InductionSpecApproxOn`. 
-/ +theorem InductionSpecApproxOn_def (ε : Val) (active : Fin (Nat.succ n) → Prop) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : + InductionSpecApproxOn (Val := Val) (n := n) ε active prev out vals = + ∀ q, active q → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε := by + rfl + +variable [IsOrderedAddMonoid Val] + /-- Exact induction spec implies the approximate spec for any nonnegative tolerance. -/ theorem inductionSpecApprox_of_spec (ε : Val) (hε : 0 ≤ ε) (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) @@ -644,7 +655,7 @@ theorem attentionGate_out_eq_of_oneHot (scale : Val) (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) rec = attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) b h q d rec (prev q) := by - simp only [attentionGate] + simp only [attentionGate_out_def] change dotProduct (attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) @@ -732,7 +743,8 @@ theorem attentionTyped_eval_out_eq_of_oneHot (prev : Fin seq → Fin seq) Circuit.evalInput_eq_input (C := C) (input := inputAssign) (i := attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d)) hmem - simpa [inputAssign, I, attentionInterface, attnInputV] using h + simpa [inputAssign, I, attentionInterface, attentionInputEquiv_def, + Interface.toInputAssignment_def, attnInputV] using h have hvals : attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) = @@ -743,7 +755,7 @@ theorem attentionTyped_eval_out_eq_of_oneHot (prev : Fin seq → Fin seq) (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval input (b, q, h, d) = Circuit.evalInput C inputAssign (I.outputs (b, q, h, d)).1 := by - simp [TypedCircuit.eval, Interface.eval, C, I, inputAssign, attentionTyped] + simp [TypedCircuit.eval_def, Interface.eval_def, 
attentionTyped_def, C, I, inputAssign] _ = Circuit.evalInput C inputAssign (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by rfl diff --git a/Nfp/Circuit/Layers/Linear.lean b/Nfp/Circuit/Layers/Linear.lean index 21847d2..1f08adb 100644 --- a/Nfp/Circuit/Layers/Linear.lean +++ b/Nfp/Circuit/Layers/Linear.lean @@ -30,10 +30,18 @@ variable {Row Col : Type u} abbrev LinearNode (Row Col : Type u) : Type u := Sum Col Row /-- Rank function used to orient layer edges from inputs to outputs. -/ -@[expose] def linearRank : LinearNode Row Col → Nat +def linearRank : LinearNode Row Col → Nat | Sum.inl _ => 0 | Sum.inr _ => 1 +/-- Definitional characterization of `linearRank`. -/ +theorem linearRank_def (x : LinearNode Row Col) : + linearRank (Row := Row) (Col := Col) x = + match x with + | Sum.inl _ => 0 + | Sum.inr _ => 1 := by + cases x <;> rfl + section Dag variable [Fintype Row] [Fintype Col] diff --git a/Nfp/Circuit/Layers/Tensor.lean b/Nfp/Circuit/Layers/Tensor.lean index 11860a2..beb5e84 100644 --- a/Nfp/Circuit/Layers/Tensor.lean +++ b/Nfp/Circuit/Layers/Tensor.lean @@ -59,7 +59,7 @@ def batchedLinearDag : Dag (BatchedLinearNode Batch Row Col) := linearRank (Row := Batch × Row) (Col := Batch × Col) j < linearRank (Row := Batch × Row) (Col := Batch × Col) i) := by intro j i h - cases j <;> cases i <;> simp [batchedLinearAdj, linearRank] at h ⊢ + cases j <;> cases i <;> simp [batchedLinearAdj, linearRank_def] at h ⊢ have hwf : WellFounded (fun j i => linearRank (Row := Batch × Row) (Col := Batch × Col) j < linearRank (Row := Batch × Row) (Col := Batch × Col) i) := by diff --git a/Nfp/Circuit/Typed.lean b/Nfp/Circuit/Typed.lean index f98f6b6..4c27e8f 100644 --- a/Nfp/Circuit/Typed.lean +++ b/Nfp/Circuit/Typed.lean @@ -31,9 +31,14 @@ variable {Node : Type u} [Fintype Node] [DecidableEq Node] variable {Val : Type v} {Input : Type u_in} {Output : Type u_out} /-- Evaluate a typed circuit on a typed input. 
-/ -@[expose] def eval (T : TypedCircuit Node Val Input Output) (input : Input → Val) : Output → Val := +def eval (T : TypedCircuit Node Val Input Output) (input : Input → Val) : Output → Val := T.interface.eval input +/-- Definitional characterization of `TypedCircuit.eval`. -/ +theorem eval_def (T : TypedCircuit Node Val Input Output) (input : Input → Val) : + T.eval input = T.interface.eval input := by + rfl + /-- Decide equivalence by enumerating typed inputs. -/ def checkEquiv (T1 T2 : TypedCircuit Node Val Input Output) [Fintype Input] [DecidableEq Input] [Fintype Output] diff --git a/Nfp/Core/Basic.lean b/Nfp/Core/Basic.lean index 5d06aab..2e593b0 100644 --- a/Nfp/Core/Basic.lean +++ b/Nfp/Core/Basic.lean @@ -13,7 +13,7 @@ public import Mathlib.Data.Real.Basic Basic shared definitions for the NFP rewrite. -/ -@[expose] public section +public section namespace Nfp @@ -27,14 +27,28 @@ def defaultRatPrec : Int := 48 def ratRoundDown (q : Rat) (_prec : Int := defaultRatPrec) : Rat := q +/-- Definitional characterization of `ratRoundDown`. -/ +theorem ratRoundDown_def (q : Rat) (prec : Int := defaultRatPrec) : + ratRoundDown q prec = q := by + rfl + /-- Round a rational up (identity in the exact-rational refactor). -/ def ratRoundUp (q : Rat) (_prec : Int := defaultRatPrec) : Rat := q +/-- Definitional characterization of `ratRoundUp`. -/ +theorem ratRoundUp_def (q : Rat) (prec : Int := defaultRatPrec) : + ratRoundUp q prec = q := by + rfl + /-- Real cast of a rational value. -/ def ratToReal (x : Rat) : Real := (x : Real) +/-- Definitional characterization of `ratToReal`. -/ +theorem ratToReal_def (x : Rat) : ratToReal x = (x : Real) := by + rfl + @[simp] theorem ratToReal_zero : ratToReal 0 = 0 := by simp [ratToReal] @@ -84,6 +98,11 @@ def ratDivDown (x y : Rat) (_prec : Int := defaultRatPrec) : Rat := else x / y +/-- Definitional characterization of `ratDivDown`. 
-/ +theorem ratDivDown_def (x y : Rat) (prec : Int := defaultRatPrec) : + ratDivDown x y prec = if y = 0 then 0 else x / y := by + rfl + /-- Rational division with upward rounding (exact for rationals). -/ def ratDivUp (x y : Rat) (_prec : Int := defaultRatPrec) : Rat := if y = 0 then @@ -91,6 +110,11 @@ def ratDivUp (x y : Rat) (_prec : Int := defaultRatPrec) : Rat := else x / y +/-- Definitional characterization of `ratDivUp`. -/ +theorem ratDivUp_def (x y : Rat) (prec : Int := defaultRatPrec) : + ratDivUp x y prec = if y = 0 then 0 else x / y := by + rfl + theorem ratDivUp_ge (x y : Rat) (hy : y ≠ 0) : (x / y : Rat) ≤ (ratDivUp x y : Rat) := by simp [ratDivUp, hy] diff --git a/Nfp/Mixer/Basic.lean b/Nfp/Mixer/Basic.lean index 3549e4c..6f07d65 100644 --- a/Nfp/Mixer/Basic.lean +++ b/Nfp/Mixer/Basic.lean @@ -8,7 +8,7 @@ public import Nfp.Prob.Basic Row-stochastic mixers. -/ -@[expose] public section +public section open scoped BigOperators diff --git a/Nfp/Mixer/Operations.lean b/Nfp/Mixer/Operations.lean index bf70c7a..ee88bec 100644 --- a/Nfp/Mixer/Operations.lean +++ b/Nfp/Mixer/Operations.lean @@ -10,7 +10,7 @@ public import Mathlib.Algebra.BigOperators.Ring.Finset Mixer operations (pushforward, composition, identity). -/ -@[expose] public section +public section open scoped BigOperators diff --git a/Nfp/Prob/Basic.lean b/Nfp/Prob/Basic.lean index b37d658..586db72 100644 --- a/Nfp/Prob/Basic.lean +++ b/Nfp/Prob/Basic.lean @@ -9,7 +9,7 @@ public import Mathlib.Data.Fintype.BigOperators Probability vectors on finite types. -/ -@[expose] public section +public section open scoped BigOperators diff --git a/Nfp/Prob/Operations.lean b/Nfp/Prob/Operations.lean index ff957fa..fbf0f0c 100644 --- a/Nfp/Prob/Operations.lean +++ b/Nfp/Prob/Operations.lean @@ -9,7 +9,7 @@ public import Mathlib.Algebra.BigOperators.Ring.Finset Basic constructions on probability vectors. 
-/ -@[expose] public section +public section open scoped BigOperators diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index 27ce779..9a84016 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -179,7 +179,7 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq max |(lo i : Real)| |(hi i : Real)| ≤ (absBound : Real) := by have hsup' : ratToReal (max |lo i| |hi i|) ≤ ratToReal absBound := ratToReal_le_of_le hsup - simpa [ratToReal_abs, ratToReal_max] using hsup' + simpa [ratToReal_abs, ratToReal_max, ratToReal_def] using hsup' exact le_trans hbound hsup_real have hln_bounds : ∀ q i, (lnLo i : Real) ≤ lnOut q i ∧ lnOut q i ≤ (lnHi i : Real) := by intro q i diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean index 2c4499f..432abec 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Sound/Bounds/Gelu.lean @@ -130,10 +130,10 @@ theorem geluInterval_bounds {lo hi : Rat} {x : Real} exact le_trans hgelu.2 hmax' by_cases hhi0 : 0 ≤ hi · have hhi0r : 0 ≤ (hi : Real) := by - exact ratToReal_nonneg_of_nonneg hhi0 + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hhi0 simpa [geluInterval, hhi0, max_eq_left hhi0r] using hmax · have hhi0r : (hi : Real) ≤ 0 := by - exact (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) + simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) have hx0 : x ≤ 0 := le_trans hhi hhi0r have hmax' : max x 0 = 0 := max_eq_right hx0 have hhi'' : geluTanh x ≤ (0 : Real) := by @@ -141,7 +141,7 @@ theorem geluInterval_bounds {lo hi : Rat} {x : Real} simpa [geluInterval, hhi0, ratToReal_zero] using hhi'' by_cases hlo0 : lo ≤ 0 · have hlo0r : (lo : Real) ≤ 0 := by - exact (ratToReal_nonpos_iff (x := lo)).2 hlo0 + simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := lo)).2 hlo0 have hmin : min (lo : Real) 0 ≤ min x 0 := min_le_min hlo le_rfl have hlo' : (lo : Real) ≤ geluTanh x := by have hmin' : (lo : Real) ≤ 
min x 0 := by @@ -151,7 +151,7 @@ theorem geluInterval_bounds {lo hi : Rat} {x : Real} · simpa [geluInterval, hlo0] using hlo' · exact hupper · have hlo0r : 0 ≤ (lo : Real) := by - exact ratToReal_nonneg_of_nonneg (le_of_not_ge hlo0) + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg (le_of_not_ge hlo0) have hx0 : 0 ≤ x := le_trans hlo0r hlo have hmin' : min x 0 = 0 := min_eq_right hx0 have hlo' : (0 : Real) ≤ geluTanh x := by diff --git a/Nfp/Sound/Bounds/LayerNorm/Basic.lean b/Nfp/Sound/Bounds/LayerNorm/Basic.lean index 8750aef..0e731c0 100644 --- a/Nfp/Sound/Bounds/LayerNorm/Basic.lean +++ b/Nfp/Sound/Bounds/LayerNorm/Basic.lean @@ -59,13 +59,15 @@ theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} let bounds := scaleInterval x lo hi (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by by_cases hx : 0 ≤ x - · have hx' : 0 ≤ (x : Real) := ratToReal_nonneg_of_nonneg hx + · have hx' : 0 ≤ (x : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hx have hbounds : (x : Real) * (lo : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (x : Real) * (hi : Real) := by exact ⟨mul_le_mul_of_nonneg_left hlo hx', mul_le_mul_of_nonneg_left hhi hx'⟩ simpa [scaleInterval, hx] using hbounds · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) - have hx'' : (x : Real) ≤ 0 := (ratToReal_nonpos_iff (x := x)).2 hx' + have hx'' : (x : Real) ≤ 0 := by + simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := x)).2 hx' have hbounds : (x : Real) * (hi : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (x : Real) * (lo : Real) := by exact ⟨mul_le_mul_of_nonpos_left hhi hx'', mul_le_mul_of_nonpos_left hlo hx''⟩ @@ -157,14 +159,16 @@ theorem layerNormBounds_spec {n : Nat} let varEps : Real := (varianceRat x : Real) + (eps : Real) let invStd : Real := (Real.sqrt varEps)⁻¹ have hmu : (μRat : Real) = μ := by - simp [μRat, μ, mean_def, hne, ratRoundDown] + simp [μRat, μ, mean_def, hne, ratRoundDown_def] have hvar : (varRat : Real) = (varianceRat x : Real) 
:= by - simp [varRat, variance_def, hne, ratRoundDown] + simp [varRat, variance_def, hne, ratRoundDown_def] have hvarEps : (varEpsRat : Real) = varEps := by simp [varEpsRat, varEps, hvar] have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hvar_nonneg_real : 0 ≤ ratToReal (varianceRat x) := by + simpa [ratToReal_def] using hvar_nonneg have hvar_nonneg_rat : 0 ≤ varianceRat x := by - exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg_real have hvarRat_nonneg : 0 ≤ varRat := by have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat simpa [varRat, variance_def x hne] using h @@ -225,16 +229,16 @@ theorem layerNormBounds_spec {n : Nat} have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat have hinv_lower : (invStdLower : Real) ≤ invStd := by - simpa [invStdLower, ratDivDown, hupper_ne, one_div] using hinv_lower_real + simpa [invStdLower, ratDivDown_def, hupper_ne, one_div] using hinv_lower_real have hinv_upper : invStd ≤ (invStdUpper : Real) := by - simpa [invStdUpper, ratDivUp, hlower_ne, one_div] using hinv_upper_real + simpa [invStdUpper, ratDivUp_def, hlower_ne, one_div] using hinv_upper_real have hlayer : layerNormReal eps gamma beta x i = (beta i : Real) + (coeff : Real) * invStd := by simp [layerNormReal, hne, coeff, centered, μ, hmu, invStd, varEps, add_comm, mul_assoc] by_cases hcoeff : 0 ≤ coeff - · have hcoeff_real : 0 ≤ (coeff : Real) := - ratToReal_nonneg_of_nonneg hcoeff + · have hcoeff_real : 0 ≤ (coeff : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hcoeff have hlow_raw : (beta i : Real) + (coeff : Real) * (invStdLower : Real) ≤ (beta i : Real) + (coeff : Real) * invStd := by @@ -318,7 +322,8 @@ theorem invStd_le_invStdBound {n : Nat} (eps : Rat) (x : Fin n → Rat) simpa [invStd] using h have hinv_bound : (sqrtLower eps : 
Real)⁻¹ ≤ (invStdBound : Real) := by have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hdiv := by + simpa [ratToReal_def] using ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy simpa [invStdBound, one_div] using hdiv exact le_trans hinv_sqrt hinv_bound @@ -369,7 +374,7 @@ theorem layerNormIntervalBounds_spec {n : Nat} have h0 : 0 ≤ centeredBound i := by dsimp [centeredBound] exact le_trans (abs_nonneg _) (le_max_left _ _) - exact ratToReal_nonneg_of_nonneg h0 + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg h0 have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by have hmean_lo_real : (μLo : Real) ≤ μ := by have hmean_rat : (meanRat lo : Real) ≤ (meanRat x : Real) := @@ -389,14 +394,14 @@ theorem layerNormIntervalBounds_spec {n : Nat} have h2 : (lo i : Real) - μ ≤ (x i : Real) - μ := by exact sub_le_sub_right (by - exact ratToReal_le_of_le (hlo i)) + simpa [ratToReal_def] using ratToReal_le_of_le (hlo i)) μ exact le_trans h1 h2 have hhi' : (x i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by have h1 : (x i : Real) - μ ≤ (hi i : Real) - μ := by exact sub_le_sub_right (by - exact ratToReal_le_of_le (hhi i)) + simpa [ratToReal_def] using ratToReal_le_of_le (hhi i)) μ have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by exact sub_le_sub_left hmean_lo_real (hi i : Real) @@ -493,7 +498,7 @@ theorem layerNormAbsBounds_spec {n : Nat} meanReal_abs_le_bound (x := fun j => (x j : Real)) (bound := absBound) hne (by intro j - exact ratToReal_abs_le_of_le (habs j)) + simpa [ratToReal_def] using ratToReal_abs_le_of_le (habs j)) simpa [meanReal_eq_meanRat] using h have hbound_nonneg : 0 ≤ absBound := by have hposn : 0 < n := Nat.pos_of_ne_zero hne @@ -507,14 +512,14 @@ theorem layerNormAbsBounds_spec {n : Nat} let invStd : Real := (Real.sqrt varEps)⁻¹ have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound : Real) := by have hx : |(x i : Real)| ≤ (absBound : Real) := by - 
exact ratToReal_abs_le_of_le (habs i) + simpa [ratToReal_def] using ratToReal_abs_le_of_le (habs i) have hmu : |μ| ≤ (absBound : Real) := by simpa [μ] using hmean_abs_real have h12 : |(x i : Real) - μ| ≤ (absBound : Real) + (absBound : Real) := abs_sub_le_double_bound hx hmu simpa [centeredBound, two_mul] using h12 have hbound_nonneg_real : 0 ≤ (absBound : Real) := by - exact ratToReal_nonneg_of_nonneg hbound_nonneg + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hbound_nonneg have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real simpa [centeredBound, two_mul] using hsum @@ -626,7 +631,8 @@ theorem layerNormAbsBounds_spec_real {n : Nat} simpa [invStd] using h have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hdiv := by + simpa [ratToReal_def] using ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ (invStdBound : Real) := by exact le_trans hinv_sqrt hinv_bound @@ -723,7 +729,7 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} have h0 : 0 ≤ centeredBound i := by dsimp [centeredBound] exact le_trans (abs_nonneg _) (le_max_left _ _) - exact ratToReal_nonneg_of_nonneg h0 + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg h0 have hcentered_abs : |x i - μ| ≤ (centeredBound i : Real) := by have hmean_lo_real : (μLo : Real) ≤ μ := by simpa [μLo, μ] using hmean_lo @@ -765,7 +771,8 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} simpa [invStd] using h have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt - have hdiv := ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + have hdiv := by + simpa [ratToReal_def] using ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy simpa [invStdBound, one_div] using hdiv have hinv : invStd ≤ 
(invStdBound : Real) := by exact le_trans hinv_sqrt hinv_bound diff --git a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean index 992d167..4e2e03f 100644 --- a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean +++ b/Nfp/Sound/Bounds/LayerNorm/InvStd.lean @@ -60,12 +60,14 @@ theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound let varEps : Real := (varianceRat x : Real) + (eps : Real) have hvar : (varRat : Real) = (varianceRat x : Real) := by - simp [varRat, variance_def, hne, ratRoundDown] + simp [varRat, variance_def, hne, ratRoundDown_def] have hvarEps : (varEpsRat : Real) = varEps := by simp [varEpsRat, varEps, hvar] have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne have hvar_nonneg_rat : 0 ≤ varianceRat x := by - exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg + have hvar_nonneg_real : 0 ≤ ratToReal (varianceRat x) := by + simpa [ratToReal_def] using hvar_nonneg + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg_real have hvarRat_nonneg : 0 ≤ varRat := by have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat simpa [varRat, variance_def x hne] using h @@ -121,9 +123,9 @@ theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat have hinv_lower : (invStdLower : Real) ≤ invStd := by - simpa [invStdLower, ratDivDown, hupper_ne, one_div] using hinv_lower_real + simpa [invStdLower, ratDivDown_def, hupper_ne, one_div] using hinv_lower_real have hinv_upper : invStd ≤ (invStdUpper : Real) := by - simpa [invStdUpper, ratDivUp, hlower_ne, one_div] using hinv_upper_real + simpa [invStdUpper, ratDivUp_def, hlower_ne, one_div] using hinv_upper_real have hbounds : bounds = (invStdLower, invStdUpper) := by simp [bounds, invStdBounds, hne, varRat, varEpsRat, sqrtLowerBound, 
sqrtUpperBound, invStdLower, invStdUpper] diff --git a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean index c669a21..e07069b 100644 --- a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean +++ b/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean @@ -169,7 +169,7 @@ theorem meanRat_le_meanRat_real {n : Nat} (x y : Fin n → Rat) (hne : n ≠ 0) meanReal (fun i => (x i : Real)) ≤ meanReal (fun i => (y i : Real)) := by refine meanReal_le_meanReal (x := fun i => (x i : Real)) (y := fun i => (y i : Real)) hne ?_ intro i - exact ratToReal_le_of_le (hxy i) + simpa [ratToReal_def] using ratToReal_le_of_le (hxy i) simpa [meanReal_eq_meanRat] using hreal /-- Variance of a real vector (defaults to `0` when `n = 0`). -/ diff --git a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean index 591a819..aa0cdf2 100644 --- a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean +++ b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean @@ -307,13 +307,13 @@ theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : have hden_nonneg : 0 ≤ (b + 1 : Real) := by exact_mod_cast (Nat.zero_le (b + 1)) exact div_nonneg hnum_nonneg hden_nonneg have hq_nonneg : 0 ≤ (q : Real) := by - exact ratToReal_nonneg_of_nonneg hq + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hq have hle : (a : Real) / (b + 1 : Real) ≤ Real.sqrt (q : Real) := (Real.le_sqrt hnonneg hq_nonneg).2 hsq have hdown : (sqrtLowerBase q : Real) ≤ (a : Real) / (b + 1 : Real) := by have hdown' : - ratToReal (ratRoundDown ((a : Rat) / (b + 1))) ≤ + (ratRoundDown ((a : Rat) / (b + 1)) : Real) ≤ (a : Real) / (b + 1 : Real) := by simpa using ratRoundDown_le_real ((a : Rat) / (b + 1)) simpa [sqrtLowerBase, num, den, a, b] using hdown' @@ -364,7 +364,7 @@ theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : (a + 1 : Real) / (b : Real) ≤ (sqrtUpperBase q : Real) := by have hup' : (a + 1 : Real) / (b : Real) ≤ - ratToReal (ratRoundUp ((a + 1 : Rat) / b)) := by + (ratRoundUp ((a 
+ 1 : Rat) / b) : Real) := by simpa using real_le_ratRoundUp ((a + 1 : Rat) / b) simpa [sqrtUpperBase, num, den, a, b] using hup' exact le_trans hle hup @@ -406,13 +406,13 @@ theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) exact div_nonneg hnum_nonneg hden_nonneg have hq_nonneg : 0 ≤ (q : Real) := by - exact ratToReal_nonneg_of_nonneg hq + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hq have hle : (a : Real) / (den : Real) ≤ Real.sqrt (q : Real) := (Real.le_sqrt hnonneg hq_nonneg).2 hsq have hdown : (sqrtLowerAlt q : Real) ≤ (a : Real) / (den : Real) := by have hdown' : - ratToReal (ratRoundDown ((a : Rat) / den)) ≤ + (ratRoundDown ((a : Rat) / den) : Real) ≤ (a : Real) / (den : Real) := by simpa using ratRoundDown_le_real ((a : Rat) / den) simpa [sqrtLowerAlt, num, den, a] using hdown' @@ -477,14 +477,14 @@ theorem sqrtLowerScaled_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) exact div_nonneg hnum_nonneg hden_nonneg have hq_nonneg : 0 ≤ (q : Real) := by - exact ratToReal_nonneg_of_nonneg hq + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hq have hle : (a : Real) / ((den : Real) * (scale : Real)) ≤ Real.sqrt (q : Real) := (Real.le_sqrt hnonneg hq_nonneg).2 hsq have hdown : (sqrtLowerScaled q : Real) ≤ (a : Real) / ((den : Real) * (scale : Real)) := by have hdown' : - ratToReal (ratRoundDown ((a : Rat) / (den * scale))) ≤ + (ratRoundDown ((a : Rat) / (den * scale)) : Real) ≤ (a : Real) / ((den : Real) * (scale : Real)) := by simpa using ratRoundDown_le_real ((a : Rat) / (den * scale)) simpa [sqrtLowerScaled, num, den, scale, a] using hdown' @@ -533,7 +533,7 @@ theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : (a + 1 : Real) / (den : Real) ≤ (sqrtUpperAlt q : Real) := by have hup' : (a + 1 : Real) / (den : Real) ≤ - ratToReal (ratRoundUp ((a + 1 : Rat) / den)) := by + (ratRoundUp ((a + 1 : Rat) / den) : 
Real) := by simpa using real_le_ratRoundUp ((a + 1 : Rat) / den) simpa [sqrtUpperAlt, num, den, a] using hup' exact le_trans hle hup @@ -599,7 +599,7 @@ theorem real_sqrt_le_sqrtUpperScaled {q : Rat} (hq : 0 ≤ q) : (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ (sqrtUpperScaled q : Real) := by have hup' : (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ - ratToReal (ratRoundUp ((a + 1 : Rat) / (den * scale))) := by + (ratRoundUp ((a + 1 : Rat) / (den * scale)) : Real) := by simpa using real_le_ratRoundUp ((a + 1 : Rat) / (den * scale)) simpa [sqrtUpperScaled, num, den, scale, a] using hup' exact le_trans hle hup diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean index ddb9edc..1e99329 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Interval.lean @@ -633,7 +633,7 @@ theorem dotIntervalLower2_le_dotProduct_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n have hcast : (dotIntervalLower2 lo1 hi1 lo2 hi2 : Real) = ∑ j, (mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by - simpa [dotIntervalLower2, ratToReal] using + simpa [dotIntervalLower2, ratToReal_def] using (Linear.ratToReal_sumFin (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j))) have hsum := @@ -651,7 +651,7 @@ theorem dotProduct_le_dotIntervalUpper2_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n have hcast : (dotIntervalUpper2 lo1 hi1 lo2 hi2 : Real) = ∑ j, (mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by - simpa [dotIntervalUpper2, ratToReal] using + simpa [dotIntervalUpper2, ratToReal_def] using (Linear.ratToReal_sumFin (f := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j))) have hsum := @@ -766,21 +766,27 @@ theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) have hcast : (dotIntervalLower v lo hi : Real) = ∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real) := by - simpa [dotIntervalLower, ratToReal_mul, ratToReal_if] 
using - (Linear.ratToReal_sumFin - (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j)) + have hcast' : + ratToReal (dotIntervalLower v lo hi) = + ∑ j, if 0 ≤ v j then ratToReal (v j) * ratToReal (lo j) else + ratToReal (v j) * ratToReal (hi j) := by + simpa [dotIntervalLower, ratToReal_if, ratToReal_mul] using + (Linear.ratToReal_sumFin + (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j)) + simpa [ratToReal_def] using hcast' have hsum : (∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real)) ≤ ∑ j, (v j : Real) * x j := by refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j - · have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv + · have hv' : (0 : Real) ≤ (v j : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hv have hmul : (v j : Real) * (lo j : Real) ≤ (v j : Real) * x j := by exact mul_le_mul_of_nonneg_left (hlo j) hv' simpa [hv] using hmul · have hv' : (v j : Real) ≤ 0 := by - exact (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) + simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) have hmul : (v j : Real) * (hi j : Real) ≤ (v j : Real) * x j := by exact mul_le_mul_of_nonpos_left (hhi j) hv' simpa [hv] using hmul @@ -804,21 +810,27 @@ theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) have hcast : (dotIntervalUpper v lo hi : Real) = ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by - simpa [dotIntervalUpper, ratToReal_mul, ratToReal_if] using - (Linear.ratToReal_sumFin - (f := fun j => if 0 ≤ v j then v j * hi j else v j * lo j)) + have hcast' : + ratToReal (dotIntervalUpper v lo hi) = + ∑ j, if 0 ≤ v j then ratToReal (v j) * ratToReal (hi j) else + ratToReal (v j) * ratToReal (lo j) := by + simpa [dotIntervalUpper, ratToReal_if, ratToReal_mul] using + (Linear.ratToReal_sumFin + (f := fun j => if 0 ≤ v j then v j * hi j else v j * lo j)) + simpa [ratToReal_def] 
using hcast' have hsum : ∑ j, (v j : Real) * x j ≤ ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by refine Finset.sum_le_sum ?_ intro j _ by_cases hv : 0 ≤ v j - · have hv' : (0 : Real) ≤ (v j : Real) := ratToReal_nonneg_of_nonneg hv + · have hv' : (0 : Real) ≤ (v j : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hv have hmul : (v j : Real) * x j ≤ (v j : Real) * (hi j : Real) := by exact mul_le_mul_of_nonneg_left (hhi j) hv' simpa [hv] using hmul · have hv' : (v j : Real) ≤ 0 := by - exact (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) + simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) have hmul : (v j : Real) * x j ≤ (v j : Real) * (lo j : Real) := by exact mul_le_mul_of_nonpos_left (hlo j) hv' simpa [hv] using hmul @@ -852,10 +864,10 @@ theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Rat) (x : Fin constructor · have hleft : |lo i| ≤ intervalAbsBound lo hi := by exact le_trans (le_max_left _ _) (max_abs_le_intervalAbsBound lo hi i) - exact ratToReal_abs_le_of_le hleft + simpa [ratToReal_def] using ratToReal_abs_le_of_le hleft · have hright : |hi i| ≤ intervalAbsBound lo hi := by exact le_trans (le_max_right _ _) (max_abs_le_intervalAbsBound lo hi i) - exact ratToReal_abs_le_of_le hright + simpa [ratToReal_def] using ratToReal_abs_le_of_le hright exact le_trans hbound hsup_real theorem abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Rat) diff --git a/Nfp/Sound/Bounds/Mlp.lean b/Nfp/Sound/Bounds/Mlp.lean index 097bc66..1af5c3f 100644 --- a/Nfp/Sound/Bounds/Mlp.lean +++ b/Nfp/Sound/Bounds/Mlp.lean @@ -184,7 +184,7 @@ theorem layerNormAbsMlpBounds_spec {n hidden : Nat} have hsup' : ratToReal (max |lo j| |hi j|) ≤ ratToReal absBound := ratToReal_le_of_le hsup - simpa [ratToReal_abs, ratToReal_max] using hsup' + simpa [ratToReal_abs, ratToReal_max, ratToReal_def] using hsup' exact le_trans hbound hsup_real have hln := 
layerNormAbsBounds_spec_real eps gamma beta absBound x hne heps hsqrt habs diff --git a/Nfp/Sound/Bounds/Transformer/Basic.lean b/Nfp/Sound/Bounds/Transformer/Basic.lean index d8a64e3..579018e 100644 --- a/Nfp/Sound/Bounds/Transformer/Basic.lean +++ b/Nfp/Sound/Bounds/Transformer/Basic.lean @@ -555,7 +555,9 @@ theorem gpt2ResidualIntervalBoundsActive_sound rcases hactive with ⟨q0, hq0⟩ have hq := hspec q0 hq0 i have hreal : (bounds.1 i : Real) ≤ (bounds.2 i : Real) := hq.1.trans hq.2 - exact (ratToReal_le_iff (x := bounds.1 i) (y := bounds.2 i)).1 hreal + have hreal' : ratToReal (bounds.1 i) ≤ ratToReal (bounds.2 i) := by + simpa [ratToReal_def] using hreal + exact (ratToReal_le_iff (x := bounds.1 i) (y := bounds.2 i)).1 hreal' refine And.intro hbounds ?_ intro q hq i have hq' := hspec q hq i diff --git a/Nfp/Sound/Bounds/Transformer/Embedding.lean b/Nfp/Sound/Bounds/Transformer/Embedding.lean index 070d2ac..394eeb0 100644 --- a/Nfp/Sound/Bounds/Transformer/Embedding.lean +++ b/Nfp/Sound/Bounds/Transformer/Embedding.lean @@ -59,8 +59,8 @@ theorem embeddingIntervalBounds_spec {seq dModel : Nat} [NeZero seq] (q := q) (hq := by simp) simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h constructor - · exact ratToReal_le_of_le hbounds.1 - · exact ratToReal_le_of_le hbounds.2 + · simpa [ratToReal_def] using ratToReal_le_of_le hbounds.1 + · simpa [ratToReal_def] using ratToReal_le_of_le hbounds.2 /-- Interval bounds across a finite set of positions for an embedding map. 
-/ def embeddingIntervalBoundsOn {seq dModel : Nat} [NeZero seq] @@ -84,8 +84,8 @@ theorem embeddingIntervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] (f := fun k => x k i) (q := q) (hq := hq) simpa [bounds, embeddingIntervalBoundsOn] using h constructor - · exact ratToReal_le_of_le hbounds.1 - · exact ratToReal_le_of_le hbounds.2 + · simpa [ratToReal_def] using ratToReal_le_of_le hbounds.1 + · simpa [ratToReal_def] using ratToReal_le_of_le hbounds.2 /-- Collapse per-position interval bounds over a finite set of positions. -/ def intervalBoundsOn {seq dModel : Nat} [NeZero seq] @@ -119,11 +119,11 @@ theorem intervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] constructor · have hmin_real : (bounds.1 i : Real) ≤ (lo q i : Real) := by - exact ratToReal_le_of_le hmin + simpa [ratToReal_def] using ratToReal_le_of_le hmin exact le_trans hmin_real hlo' · have hmax_real : (hi q i : Real) ≤ (bounds.2 i : Real) := by - exact ratToReal_le_of_le hmax + simpa [ratToReal_def] using ratToReal_le_of_le hmax exact le_trans hhi' hmax_real end Bounds diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 77c5d9c..6e41449 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -70,8 +70,8 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by intro q have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by - simp [mean_def, hmodel, ratRoundDown] - simpa [ratToReal] using congrArg ratToReal hmu_rat + simp [mean_def, hmodel, ratRoundDown_def] + simpa [ratToReal_def] using congrArg ratToReal hmu_rat have hln_affine : ∀ q j, lnRealOfInputs inputs q j = @@ -719,7 +719,8 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by by_cases hscale : 0 ≤ inputs.scale · have hscale_real : 0 ≤ 
(inputs.scale : Real) := - ratToReal_nonneg_of_nonneg hscale + by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale have hdot := hdot_bounds hnot have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real @@ -729,7 +730,9 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N · have hscale_nonpos : inputs.scale ≤ 0 := le_of_lt (lt_of_not_ge hscale) have hscale_real : (inputs.scale : Real) ≤ 0 := - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + by + simpa [ratToReal_def] using + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos have hdot := hdot_bounds hnot have hlow := mul_le_mul_of_nonpos_left hdot.2 hscale_real have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real @@ -917,14 +920,17 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N kRealOfInputs inputs k d) := by by_cases hscale : 0 ≤ inputs.scale · have hscale_real : 0 ≤ (inputs.scale : Real) := - ratToReal_nonneg_of_nonneg hscale + by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, hprevmask, hmask, hscale] using hle · have hscale_nonpos : inputs.scale ≤ 0 := le_of_lt (lt_of_not_ge hscale) have hscale_real : (inputs.scale : Real) ≤ 0 := - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + by + simpa [ratToReal_def] using + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, hprevmask, hmask, hscale] using hle @@ -1023,7 +1029,8 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N have hmargin_le : marginAt q ≤ scoreGapLo q k := hmarginAt_le q hq k hk have hmargin_le_real : (marginAt q : Real) ≤ (scoreGapLo q k : Real) := - ratToReal_le_of_le hmargin_le + by 
+ simpa [ratToReal_def] using ratToReal_le_of_le hmargin_le have hscore_gap := hscore_gap_real_at q hq k hk have hstep := add_le_add_right hmargin_le_real (scoresReal q k) have hstep' : @@ -1043,7 +1050,8 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ simpa [margin, hnonempty] using hle have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := - ratToReal_le_of_le hmargin_le + by + simpa [ratToReal_def] using ratToReal_le_of_le hmargin_le have hscore := hscore_margin_real_at q hq k hk have hscore' : (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by @@ -1129,7 +1137,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N have hepsAt_le_eps_real : ∀ q, q ∈ inputs.active → (epsAt q : Real) ≤ (eps : Real) := by intro q hq - exact ratToReal_le_of_le (hepsAt_le_eps q hq) + simpa [ratToReal_def] using ratToReal_le_of_le (hepsAt_le_eps q hq) have hsoftmax_bounds : Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by @@ -1237,7 +1245,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N dsimp [lo] refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) (f := valsLo) (a := valsLo k0)).2 ⟨k0, hmem0, le_rfl⟩ - exact ratToReal_le_of_le hloRat + simpa [ratToReal_def] using ratToReal_le_of_le hloRat have hvals : (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by @@ -1248,18 +1256,20 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N dsimp [hi] refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) (f := valsHi) (a := valsHi k0)).2 ⟨k0, ⟨hmem0, le_rfl⟩⟩ - exact ratToReal_le_of_le hhiRat + simpa [ratToReal_def] using ratToReal_le_of_le hhiRat have hreal : (valCert.lo : Real) ≤ (valCert.hi : Real) := le_trans hlo 
(le_trans hvals.1 (le_trans hvals.2 hhi)) - exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal + have hreal' : ratToReal valCert.lo ≤ ratToReal valCert.hi := by + simpa [ratToReal_def] using hreal + exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal' · intro k have hloRat : valCert.lo ≤ valCert.valsLo k := by change lo ≤ valsLo k dsimp [lo] refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) (f := valsLo) (a := valsLo k)).2 ⟨k, by simp [univ], le_rfl⟩ - exact ratToReal_le_of_le hloRat + simpa [ratToReal_def] using ratToReal_le_of_le hloRat · intro k exact hvals_bounds_at k · intro k @@ -1268,7 +1278,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N dsimp [hi] refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) (f := valsHi) (a := valsHi k)).2 ⟨k, by simp [univ], le_rfl⟩ - exact ratToReal_le_of_le hhiRat + simpa [ratToReal_def] using ratToReal_le_of_le hhiRat exact { softmax_bounds := hsoftmax_bounds oneHot_bounds_at := oneHot_bounds_at diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index 4513116..91231e3 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -230,20 +230,25 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] have hloRat : loVal i ≤ headValueLo k0 i := hloVal k0 have hhiRat : headValueHi k0 i ≤ hiVal i := hhiVal k0 have hbounds := hhead_bounds k0 i + have hloReal : (loVal i : Real) ≤ (headValueLo k0 i : Real) := by + simpa [ratToReal_def] using ratToReal_le_of_le hloRat + have hhiReal : (headValueHi k0 i : Real) ≤ (hiVal i : Real) := by + simpa [ratToReal_def] using ratToReal_le_of_le hhiRat have hreal : (loVal i : Real) ≤ (hiVal i : Real) := by - refine le_trans (ratToReal_le_of_le hloRat) ?_ - refine le_trans hbounds.1 ?_ - exact le_trans hbounds.2 (ratToReal_le_of_le hhiRat) + exact le_trans hloReal (le_trans hbounds.1 (le_trans hbounds.2 hhiReal)) exact hreal · intro k have hloRat : loVal i ≤ headValueLo k i := hloVal k have hbounds := hhead_bounds k i - exact (ratToReal_le_of_le hloRat) |>.trans hbounds.1 + have hloReal : (loVal i : Real) ≤ (headValueLo k i : Real) := by + simpa [ratToReal_def] using ratToReal_le_of_le hloRat + exact hloReal.trans hbounds.1 · intro k have hhiRat : headValueHi k i ≤ hiVal i := hhiVal k have hbounds := hhead_bounds k i - exact hbounds.2.trans - (ratToReal_le_of_le hhiRat) + have hhiReal : (headValueHi k i : Real) ≤ (hiVal i : Real) := by + simpa [ratToReal_def] using ratToReal_le_of_le hhiRat + exact hbounds.2.trans hhiReal have hsoftmax : Layers.SoftmaxMarginBoundsOn (Val := Real) (cert.eps : Real) (cert.margin : Real) @@ -302,7 +307,23 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] headOutput inputs q i ≤ (hiOut i : Real) := by intro q hq i have hactive : activeSet.Nonempty := ⟨q, hq⟩ - have hspec := (happrox i) q hq + have hspec : + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ + headValueRealOfInputs inputs (cert.prev q) i + + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ∧ + headValueRealOfInputs inputs (cert.prev q) i ≤ + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) + + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by + have happrox' : + ∀ q, q ∈ activeSet → + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ + headValueRealOfInputs inputs (cert.prev q) i + + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ∧ + headValueRealOfInputs inputs (cert.prev q) i ≤ + dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) + + (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by + simpa [Layers.InductionSpecApproxOn_def] using (happrox i) + exact happrox' q hq have hout_def : headOutput inputs q i = dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by @@ -346,20 +367,22 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] hlower'' have hlo : (loOut i : Real) ≤ (boundLoRat q i : Real) := by - refine ratToReal_le_of_le ?_ - simpa [loOut, hactive] using - (Finset.inf'_le - (s := activeSet) - (f := fun q => boundLoRat q i) - (b := q) hq) + have hloRat : loOut i ≤ boundLoRat q i := by + simpa [loOut, hactive] using + (Finset.inf'_le + (s := activeSet) + (f := fun q => boundLoRat q i) + (b := q) hq) + simpa [ratToReal_def] using ratToReal_le_of_le hloRat have hhi : (boundHiRat q i : Real) ≤ (hiOut i : Real) := by - refine ratToReal_le_of_le ?_ - simpa [hiOut, hactive] using - (Finset.le_sup' - (s := activeSet) - (f := fun q => boundHiRat q i) - (b := q) hq) + have hhiRat : boundHiRat q i ≤ hiOut i := by + simpa [hiOut, hactive] using + (Finset.le_sup' + (s := activeSet) + (f := fun q => boundHiRat q i) + (b := q) hq) + simpa [ratToReal_def] using ratToReal_le_of_le hhiRat exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by refine { lo_le_hi := ?_ } @@ -367,8 +390,9 @@ def buildHeadOutputIntervalFromHead? 
[NeZero seq] by_cases hactive : activeSet.Nonempty · rcases hactive with ⟨q, hq⟩ have hout_i := hout q hq i - exact (ratToReal_le_iff (x := loOut i) (y := hiOut i)).1 - (le_trans hout_i.1 hout_i.2) + have hreal : ratToReal (loOut i) ≤ ratToReal (hiOut i) := by + simpa [ratToReal_def] using le_trans hout_i.1 hout_i.2 + exact (ratToReal_le_iff (x := loOut i) (y := hiOut i)).1 hreal · simp [loOut, hiOut, hactive] let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } exact some diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index e139d22..d22ecae 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -59,7 +59,7 @@ theorem direction_dot_headValue_eq_valsReal ((dirHeadVecOfInputs inputs).get d : Real) = ratToReal (Linear.dotFin dModel (fun i => inputs.wo i d) (fun i => inputs.direction i)) := by - simp [dirHeadVecOfInputs_get, ratToReal] + simp [dirHeadVecOfInputs_get, ratToReal_def] _ = ratToReal (dotProduct (fun i => inputs.wo i d) (fun i => inputs.direction i)) := by simp [Linear.dotFin_eq_dotProduct] @@ -68,7 +68,7 @@ theorem direction_dot_headValue_eq_valsReal simp [dotProduct, Linear.ratToReal_sum_univ] _ = ∑ i, (inputs.wo i d : Real) * (inputs.direction i : Real) := by - simp [ratToReal] + simp [ratToReal_def] calc dotProduct dir (fun i => headValueRealOfInputs inputs k i) = ∑ i, dir i * @@ -151,7 +151,7 @@ theorem logitDiffLowerBoundFromCert_le simpa [logitDiffLowerBoundFromCert] using hbound have hboundReal : (lb : Real) ≤ valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) := by - simpa [ratToReal_sub, ratToReal_mul] using ratToReal_le_of_le hboundRat + simpa [ratToReal_sub, ratToReal_mul, ratToReal_def] using ratToReal_le_of_le hboundRat have hweights_nonneg : ∀ k, 0 ≤ weights q k := hsound.softmax_bounds.nonneg q hq have hweights := hsound.oneHot_bounds_at q hq @@ -307,7 +307,7 @@ theorem logitDiffLowerBoundFromCertWeighted_le (c.weightBoundAt q k : Real) * 
max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)))) := by simpa [valsLoPrevRat, valsLoPrev, ratToReal_sub, ratToReal_mul, ratToReal_max, - ratToReal, Rat.cast_sum] using ratToReal_le_of_le hboundRat + ratToReal_def, Rat.cast_sum] using ratToReal_le_of_le hboundRat have hweights_nonneg : ∀ k, 0 ≤ weights q k := hsound.softmax_bounds.nonneg q hq have hweights := hsound.oneHot_bounds_at q hq diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 22d1313..3e8e82e 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -73,7 +73,7 @@ theorem oneHot_bounds_at_of_marginAt simpa [heps, hsum_one] using hsum_le · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg have hnonneg_real : 0 ≤ (marginAt q : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hnonneg have hbound : ∀ k ∈ others q, weights q k ≤ (1 + (marginAt q : Real))⁻¹ := by @@ -113,7 +113,7 @@ theorem oneHot_bounds_at_of_marginAt (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ (ratDivUp (seq - 1) (1 + marginAt q) : Real) := by have hrat' := ratDivUp_ge_real (seq - 1) (1 + marginAt q) hden - simpa [ratToReal, Rat.cast_div, Rat.cast_add, Rat.cast_natCast, + simpa [ratToReal_def, Rat.cast_div, Rat.cast_add, Rat.cast_natCast, div_eq_mul_inv] using hrat' simpa [hepsAt, hneg] using hrat exact le_trans hsum_le' heps @@ -227,7 +227,7 @@ theorem oneHot_bounds_at_of_scoreGapLo simpa [bound, hneg] using hle · have hnonneg : 0 ≤ scoreGapLo q k := le_of_not_gt hneg have hnonneg_real : 0 ≤ (scoreGapLo q k : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hnonneg have hscore := hscore_gap_real_at q hq k hkne have hsoft : weights q k ≤ 1 / (1 + (scoreGapLo q k : Real)) := by @@ -244,12 +244,12 @@ theorem oneHot_bounds_at_of_scoreGapLo have hrat : 1 / (1 + (scoreGapLo q k : Real)) ≤ ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := by - simpa 
[ratToReal] using + simpa [ratToReal_def] using (ratDivUp_ge_real 1 (1 + scoreGapLo q k) hden) have hbound' : weights q k ≤ ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := hsoft.trans hrat - simpa [bound, hneg] using hbound' + simpa [bound, hneg, ratToReal_def] using hbound' have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by have hsum_le : (∑ k ∈ others q, weights q k) ≤ @@ -267,7 +267,7 @@ theorem oneHot_bounds_at_of_scoreGapLo ratToReal (epsAt q) = ratToReal (min 1 ((others q).sum bound)) := by exact congrArg ratToReal h' -- Avoid rewriting the erased-sum into a difference. - simpa [ratToReal_min, ratToReal, Rat.cast_sum] using h'' + simpa [ratToReal_min, ratToReal_def, Rat.cast_sum] using h'' simpa [hepsAtReal] using hsum_le_min refine { nonneg := ?_ @@ -344,7 +344,7 @@ theorem weight_bound_at_of_scoreGapLo simpa [hweightBoundAt q k hk, hneg] using hle · have hnonneg : 0 ≤ scoreGapLo q k := le_of_not_gt hneg have hnonneg_real : 0 ≤ (scoreGapLo q k : Real) := by - exact ratToReal_nonneg_of_nonneg hnonneg + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hnonneg have hscore := hscore_gap_real_at q hq k hk have hsoft : Circuit.softmax (scoresReal q) k ≤ 1 / (1 + (scoreGapLo q k : Real)) := by @@ -361,12 +361,12 @@ theorem weight_bound_at_of_scoreGapLo have hrat : 1 / (1 + (scoreGapLo q k : Real)) ≤ ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := by - simpa [ratToReal] using + simpa [ratToReal_def] using (ratDivUp_ge_real 1 (1 + scoreGapLo q k) hden) have hbound' : Circuit.softmax (scoresReal q) k ≤ ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := hsoft.trans hrat - simpa [hweightBoundAt q k hk, hneg] using hbound' + simpa [hweightBoundAt q k hk, hneg, ratToReal_def] using hbound' end Sound diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Sound/Linear/FinFold.lean index 2ec1c58..1d72965 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Sound/Linear/FinFold.lean @@ -79,7 +79,7 @@ theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Rat) : 
theorem ratToReal_sum_univ {n : Nat} (f : Fin n → Rat) : ratToReal (∑ i, f i) = ∑ i, ratToReal (f i) := by classical - simp [ratToReal] + simp [ratToReal_def] /-- Casting a rational `sumFin` to `Real` commutes with summation. -/ theorem ratToReal_sumFin {n : Nat} (f : Fin n → Rat) : diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean index 7cde304..6f44ae1 100644 --- a/Nfp/System/Dag.lean +++ b/Nfp/System/Dag.lean @@ -10,7 +10,7 @@ public import Mathlib.Data.Finset.Basic Directed acyclic graph foundations. -/ -@[expose] public section +public section namespace Nfp @@ -68,6 +68,11 @@ def relabel (G : Dag ι) (e : ι ≃ ι') : Dag ι' := wf := by simpa using (InvImage.wf (f := e.symm) (h := G.wf)) } +/-- Relabeling preserves adjacency via the equivalence. -/ +theorem relabel_rel_iff (G : Dag ι) (e : ι ≃ ι') (a b : ι') : + (G.relabel e).rel a b ↔ G.rel (e.symm a) (e.symm b) := by + rfl + end Relabel end Dag diff --git a/Nfp/System/LocalSystem.lean b/Nfp/System/LocalSystem.lean index 5fc3f9e..6870784 100644 --- a/Nfp/System/LocalSystem.lean +++ b/Nfp/System/LocalSystem.lean @@ -10,7 +10,7 @@ public import Nfp.System.Dag Local mixing systems on finite DAGs. -/ -@[expose] public section +public section open scoped BigOperators From dde521ed0e19bef61b7025ac9e3723949413b4a7 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 00:41:23 +0100 Subject: [PATCH 175/244] docs: add mathlib module style guide --- AGENTS.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 1c26020..ab47dd9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -170,6 +170,16 @@ prefer the **clean redesign**, but do it consciously and document the rationale. - update all call sites, - leave a brief comment (or commit message rationale). +### 4.5 Mathlib Module Structure (Local Baseline) +Based on `.lake/packages/mathlib` in this workspace: +- Use module headers (`module`) in Lean source files, with `public import` in the header. 
+- Keep imports minimal; add only what the file directly uses. +- Put exported declarations in a `public section`; otherwise they remain private to the module. +- Do not expose bodies by default; use `@[expose] public section` only when downstream unfolding is required. +- Prefer explicit `*_def` lemmas over exposing bodies when you just need definitional rewriting. +- Avoid `import all` unless a proof truly needs the private scope; keep such imports local and documented. +- Keep aggregator modules (e.g., `Nfp.Core`, `Nfp.Sound`) as thin reexports. + --- ## 6. Axioms & Trust Boundary From f836f989286b8080a93c5514375236c602c00f25 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 00:45:49 +0100 Subject: [PATCH 176/244] docs: refine mathlib module style guide --- AGENTS.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ab47dd9..c549102 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -172,13 +172,13 @@ prefer the **clean redesign**, but do it consciously and document the rationale. ### 4.5 Mathlib Module Structure (Local Baseline) Based on `.lake/packages/mathlib` in this workspace: -- Use module headers (`module`) in Lean source files, with `public import` in the header. -- Keep imports minimal; add only what the file directly uses. -- Put exported declarations in a `public section`; otherwise they remain private to the module. -- Do not expose bodies by default; use `@[expose] public section` only when downstream unfolding is required. -- Prefer explicit `*_def` lemmas over exposing bodies when you just need definitional rewriting. -- Avoid `import all` unless a proof truly needs the private scope; keep such imports local and documented. -- Keep aggregator modules (e.g., `Nfp.Core`, `Nfp.Sound`) as thin reexports. +- All `.lean` files are modules (start with `module`); keep that uniform. 
+- Standard library files use `public import` to reexport dependencies; tactic/meta/Util files often use plain `import` or `public meta import`. +- Most files open with `@[expose] public section` (sometimes `public section`, `public noncomputable section`, or `public meta section`). +- Use `@[expose] public section` when broad unfolding is expected; use plain `public section` or `private` when you want bodies hidden. +- Prefer `*_def` lemmas for targeted definitional rewriting even when definitions are exposed. +- `import all` is rare (used only when private scope is required); keep it local and documented. +- Keep aggregator modules (e.g., `Nfp.Core`, `Nfp.Sound`) as thin `public import` reexports. --- From d37dbe72cf7f96203b8a2bb386beba0e74a480c7 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 00:58:59 +0100 Subject: [PATCH 177/244] Align module headers for CLI and executables --- Main.lean | 6 ++++++ Nfp/Cli.lean | 8 +++++++- TheoremAxioms.lean | 6 ++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Main.lean b/Main.lean index 6f5e59a..e5d1eaa 100644 --- a/Main.lean +++ b/Main.lean @@ -1,7 +1,13 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +module + import Nfp.Cli +public section + /-- CLI entry point. -/ def main (args : List String) : IO UInt32 := Nfp.main args + +end diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 1e4f60d..d89b4bd 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -1,12 +1,16 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Cli +module + +public import Cli import Nfp.IO /-! Minimal CLI surface for the NFP rewrite. -/ +public section + open Cli namespace Nfp @@ -733,3 +737,5 @@ def main (args : List String) : IO UInt32 := do nfpCmd.validate args end Nfp + +end diff --git a/TheoremAxioms.lean b/TheoremAxioms.lean index b390259..5477f81 100644 --- a/TheoremAxioms.lean +++ b/TheoremAxioms.lean @@ -1,5 +1,7 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later +module + import Nfp /-! 
@@ -8,6 +10,8 @@ These `#print axioms` lines help ensure we only depend on a small set of axioms (ideally a subset of: `propext`, `Classical.choice`, `Quot.sound`). -/ +public section + #print axioms Nfp.ProbVec.sum_mass #print axioms Nfp.ProbVec.pure #print axioms Nfp.ProbVec.mix @@ -26,3 +30,5 @@ These `#print axioms` lines help ensure we only depend on a small set of axioms /-- Entrypoint for the axiom report build target. -/ def main : IO Unit := pure () + +end From 231d6b34627d16372a1c51421b63ef4553c0cfd8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 01:26:37 +0100 Subject: [PATCH 178/244] Update AGENTS guidance --- AGENTS.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index c549102..7945cf0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -125,7 +125,7 @@ prefer the **clean redesign**, but do it consciously and document the rationale. ### 4.1 Naming and organization - Prefer consistent, descriptive names: - - `_lemma`, `_iff`, `_eq`, `_mono`, `_nonneg`, `_sum_one`, etc. + - `_iff`, `_eq`, `_mono`, `_nonneg`, `_sum_one`, `_def`, etc. - Keep namespaces coherent: - attach lemmas to the structure/namespace they conceptually belong to. @@ -147,6 +147,7 @@ prefer the **clean redesign**, but do it consciously and document the rationale. - the goal is genuinely routine, - it stays fast and stable under small refactors, and - it does not rely on a large implicit rule universe. + - Note: mathlib itself uses `by aesop` widely; we are stricter in trusted/core code here. - Prefer local rules over global rules: - If a lemma is meant to be reused by Aesop, tag it deliberately (e.g. `@[aesop safe]`) @@ -172,12 +173,12 @@ prefer the **clean redesign**, but do it consciously and document the rationale. ### 4.5 Mathlib Module Structure (Local Baseline) Based on `.lake/packages/mathlib` in this workspace: -- All `.lean` files are modules (start with `module`); keep that uniform. 
-- Standard library files use `public import` to reexport dependencies; tactic/meta/Util files often use plain `import` or `public meta import`. -- Most files open with `@[expose] public section` (sometimes `public section`, `public noncomputable section`, or `public meta section`). +- `Mathlib/` files start with `module`; `MathlibTest/` and `Archive/` files may not. +- Most `Mathlib/` files use `public import` for reexports, but plain `import` appears in core files too; meta helpers often use `public meta import` based on public API/compile-time reachability. +- Many `Mathlib/` files open with `@[expose] public section`; `public meta section` is common in meta/Util, while `public section` and `public noncomputable section` are rarer. - Use `@[expose] public section` when broad unfolding is expected; use plain `public section` or `private` when you want bodies hidden. - Prefer `*_def` lemmas for targeted definitional rewriting even when definitions are exposed. -- `import all` is rare (used only when private scope is required); keep it local and documented. +- `import all` is rare, requires `allowImportAll`, and is used mainly for tests/private access; keep it local and documented. - Keep aggregator modules (e.g., `Nfp.Core`, `Nfp.Sound`) as thin `public import` reexports. 
--- From df9cc8f370454b8257320cad6c48a003b5903973 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 01:40:14 +0100 Subject: [PATCH 179/244] Centralize softmax weight invariants --- Nfp/Circuit/Layers/Softmax.lean | 26 ++++++++++++++++++++ Nfp/Sound/Bounds/Attention.lean | 8 +++++-- Nfp/Sound/Induction/CoreSound/Basic.lean | 9 +++---- Nfp/Sound/Induction/OneHot.lean | 30 +++++++++++------------- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/Nfp/Circuit/Layers/Softmax.lean b/Nfp/Circuit/Layers/Softmax.lean index 6446159..a117099 100644 --- a/Nfp/Circuit/Layers/Softmax.lean +++ b/Nfp/Circuit/Layers/Softmax.lean @@ -61,6 +61,32 @@ lemma softmax_sum_one [NeZero seq] (scores : Fin seq → Real) : _ = 1 := by simp [hdenom] +/-- Real-valued row-stochastic weights with explicit nonnegativity and row-sum proofs. + Kept separate from `ProbVec` because softmax outputs `Real` rather than `NNReal`. -/ +structure SoftmaxWeights (seq : Nat) [NeZero seq] where + /-- Weight assigned to each query/key pair. -/ + weights : Fin seq → Fin seq → Real + /-- All weights are nonnegative. -/ + nonneg : ∀ q k, 0 ≤ weights q k + /-- Each row sums to one. -/ + sum_one : ∀ q, (∑ k, weights q k) = 1 + +/-- Package softmax weights with row-stochastic proofs. -/ +def softmaxWeights [NeZero seq] (scores : Fin seq → Fin seq → Real) : + SoftmaxWeights seq := + { weights := fun q k => softmax (scores q) k + nonneg := by + intro q k + simpa using softmax_nonneg (scores := scores q) k + sum_one := by + intro q + simpa using softmax_sum_one (scores := scores q) } + +/-- Definitional unfolding for `softmaxWeights.weights`. 
-/ +theorem softmaxWeights_weights [NeZero seq] (scores : Fin seq → Fin seq → Real) : + (softmaxWeights scores).weights = fun q k => softmax (scores q) k := by + rfl + lemma softmax_le_one [NeZero seq] (scores : Fin seq → Real) (k : Fin seq) : softmax scores k ≤ 1 := by classical diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean index 9a84016..085bc1b 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Sound/Bounds/Attention.lean @@ -162,6 +162,8 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq let sumHi : Fin dModel → Rat := fun j => ∑ h, headHi h j let headValue : Fin numHeads → Fin seq → Fin dHead → Real := fun h k d => dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d + let softmaxWeights : Fin numHeads → Circuit.SoftmaxWeights seq := fun h => + Circuit.softmaxWeights (scores h) let headWeights : Fin numHeads → Fin seq → Fin seq → Real := fun h q k => Circuit.softmax (scores h q) k let headOutput : Fin numHeads → Fin seq → Fin dHead → Real := fun h q d => @@ -215,9 +217,11 @@ theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq headValue h k d ≤ (vHi h d : Real) := fun k => (hvals k d).2 have hnonneg : ∀ k, 0 ≤ headWeights h q k := by intro k - exact Circuit.softmax_nonneg (scores h q) k + simpa [headWeights, softmaxWeights, Circuit.softmaxWeights_weights] using + (softmaxWeights h).nonneg q k have hsum : ∑ k, headWeights h q k = 1 := by - simpa [headWeights] using Circuit.softmax_sum_one (scores h q) + simpa [headWeights, softmaxWeights, Circuit.softmaxWeights_weights] using + (softmaxWeights h).sum_one q have h := dotProduct_bounds_of_weights (lo := (vLo h d : Real)) (hi := (vHi h d : Real)) (vals := fun k => headValue h k d) (w := headWeights h q) hlo' hhi' hnonneg hsum diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 6e41449..862976d 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean 
+++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -1018,6 +1018,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N hgap_add _ = scoresReal q (inputs.prev q) := hcancel exact hgap_add' + let softmaxWeights := Circuit.softmaxWeights scoresReal let weights : Fin seq → Fin seq → Real := fun q k => Circuit.softmax (scoresReal q) k let others : Fin seq → Finset (Fin seq) := fun q => @@ -1151,11 +1152,11 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N · intro q hq k hk exact hscore_margin_real q hq k hk · intro q _ k - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.nonneg q k · intro q _ - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.sum_one q · intro q hq have honehot := oneHot_bounds_at q hq have hprev := honehot.prev_large q rfl diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 3e8e82e..38ca93b 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -46,17 +46,18 @@ theorem oneHot_bounds_at_of_marginAt (fun q k => Circuit.softmax (scoresReal q) k) := by classical intro q hq + let softmaxWeights := Circuit.softmaxWeights scoresReal let weights : Fin seq → Fin seq → Real := fun q k => Circuit.softmax (scoresReal q) k let others : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (prev q) have hweights_nonneg : ∀ k, 0 ≤ weights q k := by intro k - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.nonneg q k have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) + simpa [weights, softmaxWeights, 
Circuit.softmaxWeights_weights] using + softmaxWeights.sum_one q have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by by_cases hneg : marginAt q < 0 · have heps : (epsAt q : Real) = 1 := by @@ -124,12 +125,10 @@ theorem oneHot_bounds_at_of_marginAt other_le := ?_ } · intro q' hq' k subst q' - change 0 ≤ Circuit.softmax (scoresReal q) k - exact Circuit.softmax_nonneg (scores := scoresReal q) k + exact hweights_nonneg k · intro q' hq' subst q' - change (∑ k, Circuit.softmax (scoresReal q) k) = 1 - exact Circuit.softmax_sum_one (scores := scoresReal q) + exact hsum_one · intro q' hq' subst q' have hsum_eq : @@ -189,6 +188,7 @@ theorem oneHot_bounds_at_of_scoreGapLo (fun q k => Circuit.softmax (scoresReal q) k) := by classical intro q hq + let softmaxWeights := Circuit.softmaxWeights scoresReal let weights : Fin seq → Fin seq → Real := fun q k => Circuit.softmax (scoresReal q) k let others : Fin seq → Finset (Fin seq) := fun q => @@ -200,11 +200,11 @@ theorem oneHot_bounds_at_of_scoreGapLo ratDivUp 1 (1 + scoreGapLo q k) have hweights_nonneg : ∀ k, 0 ≤ weights q k := by intro k - simpa [weights] using - (Circuit.softmax_nonneg (scores := scoresReal q) k) + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.nonneg q k have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using - (Circuit.softmax_sum_one (scores := scoresReal q)) + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.sum_one q have hsum_others_le_one : (∑ k ∈ others q, weights q k) ≤ 1 := by have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by intro k hk @@ -276,12 +276,10 @@ theorem oneHot_bounds_at_of_scoreGapLo other_le := ?_ } · intro q' hq' k subst q' - change 0 ≤ Circuit.softmax (scoresReal q) k - exact Circuit.softmax_nonneg (scores := scoresReal q) k + exact hweights_nonneg k · intro q' hq' subst q' - change (∑ k, Circuit.softmax (scoresReal q) k) = 1 - exact Circuit.softmax_sum_one 
(scores := scoresReal q) + exact hsum_one · intro q' hq' subst q' have hsum_eq : From b1381aff6edb4d44dcfccdbc5357ddcd0672fe90 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 02:45:27 +0100 Subject: [PATCH 180/244] Tighten induction weight bounds and logit-diff LB --- Nfp/IO/InductionHead/Basic.lean | 3 +- Nfp/Sound/Induction/Core/Basic.lean | 22 ++++++-- Nfp/Sound/Induction/CoreDefs.lean | 2 +- Nfp/Sound/Induction/CoreSound/Basic.lean | 68 +++++++++++++++++------- Nfp/Sound/Induction/LogitDiff.lean | 45 ++++++++++++++++ 5 files changed, 114 insertions(+), 26 deletions(-) diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 1708797..b2d66b1 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -928,8 +928,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} timingPrint "timing: head logit-diff lower bound start" timingFlush let logitDiffLB? ← timePure "head: logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - cert.values.lo cert.values.hi cert.values.valsLo) + Sound.logitDiffLowerBoundFromCertBest cert) logTiming "done: head logit-diff lower bound" let effectiveMinLogitDiff := match minLogitDiff? 
with diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index b41eef6..2b605cc 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -653,14 +653,16 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} (1 : Rat) else ratDivUp 1 (1 + gap) - let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 weightBoundAtBase let epsAtBase : Fin seq → Rat := fun q => let other := otherKeys q - let total := other.sum (fun k => weightBoundAt q k) + let total := other.sum (fun k => weightBoundAtBase q k) min (1 : Rat) total let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase + let weightBoundAtClampedBase : Fin seq → Fin seq → Rat := fun q k => + min (weightBoundAtBase q k) (epsAt q) + let weightBoundAt : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 weightBoundAtClampedBase let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt @@ -793,6 +795,20 @@ theorem buildInductionHeadCoreCache_cert_eq [NeZero seq] {dModel dHead : Nat} prev := inputs.prev values := (buildInductionHeadCoreCache inputs).valCert } := by rfl + +/-- The cached certificate is built from cache fields (custom split config). -/ +theorem buildInductionHeadCoreCacheWith_cert_eq [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) : + (buildInductionHeadCoreCacheWith cfg inputs).cert = + { eps := (buildInductionHeadCoreCacheWith cfg inputs).eps + epsAt := (buildInductionHeadCoreCacheWith cfg inputs).epsAt + weightBoundAt := (buildInductionHeadCoreCacheWith cfg inputs).weightBoundAt + margin := (buildInductionHeadCoreCacheWith cfg inputs).margin + active := inputs.active + prev := inputs.prev + values := (buildInductionHeadCoreCacheWith cfg inputs).valCert } := by + rfl /-- Build induction certificates from exact head inputs (core computation). -/ def buildInductionCertFromHeadCoreWith? 
[NeZero seq] {dModel dHead : Nat} (cfg : InductionHeadSplitConfig) diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index 6cbf702..aa315be 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -215,7 +215,7 @@ structure InductionHeadCert (seq : Nat) where eps : Rat /-- Per-query weight tolerance derived from local margins. -/ epsAt : Fin seq → Rat - /-- Per-key weight bounds derived from score gaps. -/ + /-- Per-key weight bounds derived from score gaps, clamped by `epsAt`. -/ weightBoundAt : Fin seq → Fin seq → Rat /-- Score margin used to justify the weight tolerance. -/ margin : Rat diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 862976d..0079b5c 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -427,14 +427,16 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N (1 : Rat) else ratDivUp 1 (1 + gap) - let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 weightBoundAtBase let epsAtBase : Fin seq → Rat := fun q => let other := otherKeys q - let total := other.sum (fun k => weightBoundAt q k) + let total := other.sum (fun k => weightBoundAtBase q k) min (1 : Rat) total let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase + let weightBoundAtClampedBase : Fin seq → Fin seq → Rat := fun q k => + min (weightBoundAtBase q k) (epsAt q) + let weightBoundAt : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 weightBoundAtClampedBase let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt @@ -482,7 +484,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N some (buildInductionHeadCoreCacheWith cfg inputs).cert := buildInductionCertFromHeadCoreWith?_eq_some (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive - simpa using hcore'' + simpa [buildInductionHeadCoreCacheWith_cert_eq] 
using hcore'' have hc : c = cert := by have hcert : cert = c := by exact Option.some.inj (hcore'.symm.trans hcore) @@ -1065,16 +1067,21 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N _ = (marginAt q : Real) + scoresReal q k := by simp [add_comm] exact hstep'.trans hscore' - have hweightBoundAt : + have hweightBoundAtBase : ∀ q k, k ≠ inputs.prev q → - weightBoundAt q k = + weightBoundAtBase q k = if scoreGapLo q k < 0 then (1 : Rat) else ratDivUp 1 (1 + scoreGapLo q k) := by intro q k hk - simpa [weightBoundAt, weightBoundAtBase, hk] using - (Bounds.cacheBound2_apply (f := weightBoundAtBase) q k) + simp [weightBoundAtBase, hk] + have hweightBoundAt : + ∀ q k, + weightBoundAt q k = min (weightBoundAtBase q k) (epsAt q) := by + intro q k + simpa [weightBoundAt, weightBoundAtClampedBase] using + (Bounds.cacheBound2_apply (f := weightBoundAtClampedBase) q k) have hepsAt : ∀ q, epsAt q = min (1 : Rat) @@ -1085,7 +1092,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N ratDivUp 1 (1 + scoreGapLo q k))) := by intro q have hsum : - (otherKeys q).sum (fun k => weightBoundAt q k) = + (otherKeys q).sum (fun k => weightBoundAtBase q k) = (otherKeys q).sum (fun k => if scoreGapLo q k < 0 then (1 : Rat) @@ -1094,7 +1101,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N refine Finset.sum_congr rfl ?_ intro k hk have hk' : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 - simp [hweightBoundAt q k hk'] + simp [hweightBoundAtBase q k hk'] simpa [epsAt, epsAtBase, hsum] using (Bounds.cacheBoundThunk_apply (f := epsAtBase) q) have oneHot_bounds_at : @@ -1116,16 +1123,37 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → weights q k ≤ (weightBoundAt q k : Real) := by intro q hq k hk - exact - Sound.weight_bound_at_of_scoreGapLo - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := 
scoresReal) - (scoreGapLo := scoreGapLo) - (weightBoundAt := weightBoundAt) - (hweightBoundAt := hweightBoundAt) - (hscore_gap_real_at := hscore_gap_real_at) - q hq k hk + have hbound_base : + weights q k ≤ (weightBoundAtBase q k : Real) := by + exact + Sound.weight_bound_at_of_scoreGapLo + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (scoreGapLo := scoreGapLo) + (weightBoundAt := weightBoundAtBase) + (hweightBoundAt := hweightBoundAtBase) + (hscore_gap_real_at := hscore_gap_real_at) + q hq k hk + have hbound_eps : + weights q k ≤ (epsAt q : Real) := by + have honehot := oneHot_bounds_at q hq + exact honehot.other_le q rfl k hk + have hbound_min : + weights q k ≤ min (weightBoundAtBase q k : Real) (epsAt q : Real) := by + exact le_min hbound_base hbound_eps + have hweightBoundAt_real : + (weightBoundAt q k : Real) = + min (weightBoundAtBase q k : Real) (epsAt q : Real) := by + have hmin : + weightBoundAt q k = min (weightBoundAtBase q k) (epsAt q) := + hweightBoundAt q k + have hmin' : + ratToReal (weightBoundAt q k) = + ratToReal (min (weightBoundAtBase q k) (epsAt q)) := + congrArg ratToReal hmin + simpa [ratToReal_min, ratToReal_def] using hmin' + simpa [hweightBoundAt_real] using hbound_min have hepsAt_le_eps : ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by intro q hq diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index d22ecae..d235405 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -114,6 +114,14 @@ def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := def logitDiffLowerBoundFromCertWeighted (c : InductionHeadCert seq) : Option Rat := Circuit.logitDiffLowerBoundWeightedAt c.active c.prev c.weightBoundAt c.values.valsLo +/-- Best available logit-diff lower bound from an induction certificate. 
-/ +def logitDiffLowerBoundFromCertBest (c : InductionHeadCert seq) : Option Rat := + match logitDiffLowerBoundFromCert c, logitDiffLowerBoundFromCertWeighted c with + | some lb0, some lb1 => some (max lb0 lb1) + | some lb0, none => some lb0 + | none, some lb1 => some lb1 + | none, none => none + section WithNeZero variable [NeZero seq] @@ -485,6 +493,43 @@ theorem logitDiffLowerBoundFromCertWeighted_le le_trans hboundReal hdot_lower simpa [headLogitDiff, weights, vals] using hle +/-- The best available logit-diff lower bound is sound on active queries. -/ +theorem logitDiffLowerBoundFromCertBest_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) + {lb : Rat} (hbound : logitDiffLowerBoundFromCertBest c = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases h0 : logitDiffLowerBoundFromCert c with + | none => + cases h1 : logitDiffLowerBoundFromCertWeighted c with + | none => + simp [logitDiffLowerBoundFromCertBest, h0, h1] at hbound + | some lb1 => + have hbound' : lb1 = lb := by + simpa [logitDiffLowerBoundFromCertBest, h0, h1] using hbound + cases hbound' + exact logitDiffLowerBoundFromCertWeighted_le inputs c hsound h1 hq + | some lb0 => + cases h1 : logitDiffLowerBoundFromCertWeighted c with + | none => + have hbound' : lb0 = lb := by + simpa [logitDiffLowerBoundFromCertBest, h0, h1] using hbound + cases hbound' + exact logitDiffLowerBoundFromCert_le inputs c hsound h0 hq + | some lb1 => + have hbound' : max lb0 lb1 = lb := by + simpa [logitDiffLowerBoundFromCertBest, h0, h1] using hbound + cases hbound' + have h0le : (lb0 : Real) ≤ headLogitDiff inputs q := + logitDiffLowerBoundFromCert_le inputs c hsound h0 hq + have h1le : (lb1 : Real) ≤ headLogitDiff inputs q := + logitDiffLowerBoundFromCertWeighted_le inputs c hsound h1 hq + have hmax : max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := + max_le_iff.mpr ⟨h0le, 
h1le⟩ + simpa [ratToReal_max, ratToReal_def] using hmax + /-- Certified logit-diff lower bound derived from exact head inputs. -/ structure InductionLogitLowerBoundResult (inputs : Model.InductionHeadInputs seq dModel dHead) where From f237b5894d6525f9617cbfd635c4d595e973e898 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 11:43:09 +0100 Subject: [PATCH 181/244] Cache logit-diff bounds and parallelize weightBoundAt --- Nfp/IO/InductionHead/Basic.lean | 53 +++++++++++++---- Nfp/Sound/Induction/Core/Basic.lean | 2 +- Nfp/Sound/Induction/CoreSound/Basic.lean | 4 +- Nfp/Sound/Induction/LogitDiff.lean | 76 ++++++++++++++++++++---- 4 files changed, 109 insertions(+), 26 deletions(-) diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index b2d66b1..941c286 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -924,16 +924,46 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let tol := cert.eps * (cert.values.hi - cert.values.lo) timingPrint "timing: head tol done" timingFlush - logTiming "start: head logit-diff lower bound" - timingPrint "timing: head logit-diff lower bound start" - timingFlush - let logitDiffLB? ← timePure "head: logit-diff lower bound" (fun () => - Sound.logitDiffLowerBoundFromCertBest cert) - logTiming "done: head logit-diff lower bound" let effectiveMinLogitDiff := match minLogitDiff? with | some v => some v | none => some (0 : Rat) + let logitCache := Nfp.Sound.logitDiffCache cert + logTiming "start: head logit-diff lower bound" + timingPrint "timing: head logit-diff lower bound start" + timingFlush + let logitDiffLB0? ← timePureWithHeartbeat + "head: logit-diff lower bound unweighted" (fun () => + Nfp.Sound.logitDiffLowerBoundFromCache cert logitCache) + let logitDiffLB? ← + match logitDiffLB0? 
with + | none => pure none + | some lb0 => + match effectiveMinLogitDiff with + | some minLogitDiff => + if lb0 >= minLogitDiff then + timingPrint "timing: head logit-diff weighted skipped" + timingFlush + pure (some lb0) + else + let lb1? ← timePureWithHeartbeat + "head: logit-diff lower bound weighted" (fun () => + Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) + let lb := + match lb1? with + | some lb1 => max lb0 lb1 + | none => lb0 + pure (some lb) + | none => + let lb1? ← timePureWithHeartbeat + "head: logit-diff lower bound weighted" (fun () => + Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) + let lb := + match lb1? with + | some lb1 => max lb0 lb1 + | none => lb0 + pure (some lb) + logTiming "done: head logit-diff lower bound" match logitDiffLB? with | none => IO.eprintln "error: empty active set for logit-diff bound" @@ -1036,8 +1066,10 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} logTiming "start: head logit-diff lower bound" timingPrint "timing: head logit-diff lower bound start" timingFlush - let logitDiffLB0? ← timePure "head: logit-diff lower bound" (fun () => - Sound.logitDiffLowerBoundFromCert cert) + let logitCache := Nfp.Sound.logitDiffCache cert + let logitDiffLB0? ← timePureWithHeartbeat + "head: logit-diff lower bound unweighted" (fun () => + Nfp.Sound.logitDiffLowerBoundFromCache cert logitCache) let needsWeighted : Bool := match logitDiffLB0? with | none => true @@ -1050,8 +1082,9 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} | none => false let logitDiffWeighted? ← if needsWeighted then - timePure "head: logit-diff lower bound weighted" (fun () => - Sound.logitDiffLowerBoundFromCertWeighted cert) + timePureWithHeartbeat + "head: logit-diff lower bound weighted" (fun () => + Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) else pure none let logitDiffLB? 
: Option Rat := diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index 2b605cc..ee071ec 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -662,7 +662,7 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} let weightBoundAtClampedBase : Fin seq → Fin seq → Rat := fun q k => min (weightBoundAtBase q k) (epsAt q) let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 weightBoundAtClampedBase + Bounds.cacheBound2Task weightBoundAtClampedBase let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 0079b5c..49fd792 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -436,7 +436,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N let weightBoundAtClampedBase : Fin seq → Fin seq → Rat := fun q k => min (weightBoundAtBase q k) (epsAt q) let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 weightBoundAtClampedBase + Bounds.cacheBound2Task weightBoundAtClampedBase let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt @@ -1081,7 +1081,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N weightBoundAt q k = min (weightBoundAtBase q k) (epsAt q) := by intro q k simpa [weightBoundAt, weightBoundAtClampedBase] using - (Bounds.cacheBound2_apply (f := weightBoundAtClampedBase) q k) + (Bounds.cacheBound2Task_apply (f := weightBoundAtClampedBase) q k) have hepsAt : ∀ q, epsAt q = min (1 : Rat) diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index d235405..e542c2a 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -107,12 +107,48 @@ noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq 
dModel d /-- Lower bound computed from the per-key lower bounds in an induction certificate. -/ def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := - Circuit.logitDiffLowerBoundAtLo c.active c.prev c.epsAt - c.values.lo c.values.valsLo + let epsAt := Bounds.cacheBoundTask c.epsAt + let valsLo := Bounds.cacheBoundTask c.values.valsLo + Circuit.logitDiffLowerBoundAtLo c.active c.prev epsAt + c.values.lo valsLo /-- Lower bound computed from per-key weight bounds in an induction certificate. -/ def logitDiffLowerBoundFromCertWeighted (c : InductionHeadCert seq) : Option Rat := - Circuit.logitDiffLowerBoundWeightedAt c.active c.prev c.weightBoundAt c.values.valsLo + let valsLo := Bounds.cacheBoundTask c.values.valsLo + Circuit.logitDiffLowerBoundWeightedAt c.active c.prev c.weightBoundAt valsLo + +/-- Cached eps and value lower bounds for logit-diff computations. -/ +structure LogitDiffCache (seq : Nat) where + /-- Per-query eps bounds. -/ + epsAt : Fin seq → Rat + /-- Per-key value lower bounds. -/ + valsLo : Fin seq → Rat + +/-- Build a shared cache for logit-diff computations from a certificate. -/ +def logitDiffCache (c : InductionHeadCert seq) : LogitDiffCache seq := + { epsAt := Bounds.cacheBoundTask c.epsAt + valsLo := Bounds.cacheBoundTask c.values.valsLo } + +/-- Unweighted logit-diff lower bound from a shared cache. -/ +def logitDiffLowerBoundFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + Option Rat := + Circuit.logitDiffLowerBoundAtLo c.active c.prev cache.epsAt c.values.lo cache.valsLo + +/-- Weighted logit-diff lower bound from a shared cache. -/ +def logitDiffLowerBoundWeightedFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + Option Rat := + Circuit.logitDiffLowerBoundWeightedAt c.active c.prev c.weightBoundAt cache.valsLo + +/-- `logitDiffLowerBoundFromCache` matches the cached default computation. 
-/ +theorem logitDiffLowerBoundFromCache_eq (c : InductionHeadCert seq) : + logitDiffLowerBoundFromCache c (logitDiffCache c) = logitDiffLowerBoundFromCert c := by + rfl + +/-- `logitDiffLowerBoundWeightedFromCache` matches the cached default computation. -/ +theorem logitDiffLowerBoundWeightedFromCache_eq (c : InductionHeadCert seq) : + logitDiffLowerBoundWeightedFromCache c (logitDiffCache c) = + logitDiffLowerBoundFromCertWeighted c := by + rfl /-- Best available logit-diff lower bound from an induction certificate. -/ def logitDiffLowerBoundFromCertBest (c : InductionHeadCert seq) : Option Rat := @@ -140,26 +176,33 @@ theorem logitDiffLowerBoundFromCert_le let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => Circuit.softmax (scoresRealOfInputs inputs q) k let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let epsAt := Bounds.cacheBoundTask c.epsAt + let valsLo := Bounds.cacheBoundTask c.values.valsLo let others : Finset (Fin (Nat.succ n)) := (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) let sumOthers : Real := ∑ k ∈ others, weights q k let valsLoPrev : Real := (c.values.valsLo (c.prev q) : Real) let lo : Real := (c.values.lo : Real) have hboundRat : - lb ≤ c.values.valsLo (c.prev q) - - c.epsAt q * (c.values.valsLo (c.prev q) - c.values.lo) := by + lb ≤ valsLo (c.prev q) - + epsAt q * (valsLo (c.prev q) - c.values.lo) := by refine Circuit.logitDiffLowerBoundAtLo_le (active := c.active) (prev := c.prev) - (epsAt := c.epsAt) + (epsAt := epsAt) (lo := c.values.lo) - (valsLo := c.values.valsLo) + (valsLo := valsLo) q hq lb ?_ simpa [logitDiffLowerBoundFromCert] using hbound + have hboundRat' : + lb ≤ c.values.valsLo (c.prev q) - + c.epsAt q * (c.values.valsLo (c.prev q) - c.values.lo) := by + simpa [epsAt, valsLo, Bounds.cacheBoundTask_apply] using hboundRat have hboundReal : (lb : Real) ≤ valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) := by - simpa [ratToReal_sub, ratToReal_mul, ratToReal_def] using 
ratToReal_le_of_le hboundRat + simpa [ratToReal_sub, ratToReal_mul, ratToReal_def] using + ratToReal_le_of_le hboundRat' have hweights_nonneg : ∀ k, 0 ≤ weights q k := hsound.softmax_bounds.nonneg q hq have hweights := hsound.oneHot_bounds_at q hq @@ -292,22 +335,28 @@ theorem logitDiffLowerBoundFromCertWeighted_le let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => Circuit.softmax (scoresRealOfInputs inputs q) k let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let valsLoCached := Bounds.cacheBoundTask c.values.valsLo let others : Finset (Fin (Nat.succ n)) := (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) - let valsLoPrevRat : Rat := c.values.valsLo (c.prev q) + let valsLoPrevRat : Rat := valsLoCached (c.prev q) let valsLoPrev : Real := (valsLoPrevRat : Real) have hboundRat : lb ≤ valsLoPrevRat - (others.sum (fun k => - c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - c.values.valsLo k))) := by + c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - valsLoCached k))) := by refine Circuit.logitDiffLowerBoundWeightedAt_le (active := c.active) (prev := c.prev) (weightBoundAt := c.weightBoundAt) - (valsLo := c.values.valsLo) + (valsLo := valsLoCached) q hq lb ?_ simpa [logitDiffLowerBoundFromCertWeighted] using hbound + have hboundRat' : + lb ≤ valsLoPrevRat - + (others.sum (fun k => + c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - c.values.valsLo k))) := by + simpa [valsLoCached, valsLoPrevRat, Bounds.cacheBoundTask_apply] using hboundRat have hboundReal : (lb : Real) ≤ valsLoPrev - @@ -315,7 +364,7 @@ theorem logitDiffLowerBoundFromCertWeighted_le (c.weightBoundAt q k : Real) * max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)))) := by simpa [valsLoPrevRat, valsLoPrev, ratToReal_sub, ratToReal_mul, ratToReal_max, - ratToReal_def, Rat.cast_sum] using ratToReal_le_of_le hboundRat + ratToReal_def, Rat.cast_sum] using ratToReal_le_of_le hboundRat' have hweights_nonneg : ∀ k, 0 ≤ weights q k := 
hsound.softmax_bounds.nonneg q hq have hweights := hsound.oneHot_bounds_at q hq @@ -328,7 +377,8 @@ theorem logitDiffLowerBoundFromCertWeighted_le weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp _ = 1 := hweights.sum_one q rfl have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by - exact (hsound.value_bounds.vals_bounds (c.prev q)).1 + have hvals := (hsound.value_bounds.vals_bounds (c.prev q)).1 + simpa [valsLoPrev, valsLoPrevRat, valsLoCached, Bounds.cacheBoundTask_apply] using hvals have hvals_lower : ∀ k, valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ vals k := by From 38bb9aae6766eaca101296c252bb49397381fd13 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 13:21:44 +0100 Subject: [PATCH 182/244] Move induction-head timing helpers into IO.Timing --- Nfp/IO/InductionHead/Basic.lean | 157 +------------------------ Nfp/IO/Timing.lean | 199 ++++++++++++++++++++++++++++++++ 2 files changed, 201 insertions(+), 155 deletions(-) diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 941c286..3fed3b1 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -59,161 +59,6 @@ private def valueBoundsModeFromEnv : IO (Option Bool) := do | some "cached" => return some false | _ => return none -/-- Resolve the heartbeat interval (ms) for long-running induction cert builds. 
-/ -private def heartbeatMs : IO UInt32 := - timingHeartbeatMs - -private def timePureWithHeartbeat {α : Type} (label : String) (f : Unit → α) : IO α := do - let t0 ← monoUsNow - timingPrint s!"timing: {label} start" - timingFlush - let task : Task α := Task.spawn (fun _ => f ()) - let heartbeatMs ← heartbeatMs - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished task) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished task) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: {label} running {now - t0} us" - timingFlush - let res ← IO.wait task - let t1 ← monoUsNow - timingPrint s!"timing: {label} {t1 - t0} us" - return res - -private def forceRat (x : Rat) : IO Unit := do - if x = x then - pure () - else - pure () - -/-- Profile the core induction-head bounds used by the sound certificate builder. -/ -private def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do - timingPrint "timing: core stages start" - timingFlush - let lnBounds ← timePureWithHeartbeat "core: ln bounds" (fun () => - Sound.headLnBounds inputs) - let lnLo := lnBounds.1 - let lnHi := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Sound.Bounds.cacheBoundTask (fun q => - Sound.Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr ← timePureWithHeartbeat "core: lnAbsMax force" (fun () => - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q)) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr.getD q.1 (0 : Rat) - let lnAbsMaxMax : Rat := - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by - simp [univ] - univ.sup' hnonempty (fun q => lnAbsMax q) - let qAbsRowTasks : Array (Task (Array Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - Array.ofFn (fun d : Fin dHead => - Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - |inputs.bq d|))) - let qAbsBaseArr ← 
timePureWithHeartbeat "core: qAbs base force" (fun () => - let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) - Array.ofFn (fun q : Fin seq => - (qAbsRowTasks.getD q.1 defaultTask).get)) - let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := qAbsBaseArr.getD q.1 #[] - row.getD d.1 (0 : Rat) - let kAbsRowTasks : Array (Task (Array Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - Array.ofFn (fun d : Fin dHead => - Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - |inputs.bk d|))) - let kAbsBaseArr ← timePureWithHeartbeat "core: kAbs base force" (fun () => - let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) - Array.ofFn (fun q : Fin seq => - (kAbsRowTasks.getD q.1 defaultTask).get)) - let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := kAbsBaseArr.getD q.1 #[] - row.getD d.1 (0 : Rat) - let dotAbs ← timePureWithHeartbeat "core: dotAbs tasks" (fun () => - dotAbsFromQKV qAbsBase kAbsBase) - let _ ← timePureWithHeartbeat "core: dotAbs force" (fun () => - match List.finRange seq with - | [] => (0 : Rat) - | q :: _ => - match List.finRange seq with - | [] => (0 : Rat) - | k :: _ => dotAbs q k) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else scoreBaseAbs q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreLoPrev q - scoreHi q k) - else - 
(0 : Rat) - else - (0 : Rat) - let margin ← timePureWithHeartbeat "core: margin" (fun () => - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat)) - let marginNeg ← timePureWithHeartbeat "core: margin < 0" (fun () => - decide (margin < 0)) - let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" - if verboseTiming.isSome then - timingPrint s!"timing: core: margin neg={marginNeg}" - let tEps0 ← monoUsNow - timingPrint "timing: core: eps start" - timingFlush - let eps := - if marginNeg then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - let tEps1 ← monoUsNow - timingPrint s!"timing: core: eps {tEps1 - tEps0} us" - timingFlush - let _ := marginAt - let dirHeadVec ← timePureWithHeartbeat "core: dir head vec" (fun () => - Sound.dirHeadVecOfInputs inputs) - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := - Sound.Bounds.cacheBoundTask (fun j => - Sound.Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let _ ← timePureWithHeartbeat "core: wvDir force" (fun () => - Array.ofFn (fun j : Fin dModel => wvDir j)) - let bDir ← timePureWithHeartbeat "core: bDir" (fun () => - Sound.Linear.dotFin dHead dirHead (fun d => inputs.bv d)) - let valsAbsBase ← timePureWithHeartbeat "core: valsAbsBase" (fun () => - Sound.Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax) - let valsLoBase := bDir - valsAbsBase - let valsHiBase := bDir + valsAbsBase - let valsLo : Fin seq → Rat := fun _ => valsLoBase - let valsHi : Fin seq → Rat := fun _ => valsHiBase - let _ ← timePureWithHeartbeat "core: value bounds" (fun () => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by - simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - (lo, hi)) - timingPrint "timing: core stages done" - timingFlush - /-- Load induction head inputs from disk. 
-/ def loadInductionHeadInputs (path : System.FilePath) : IO (Except String (Sigma (fun seq => @@ -932,6 +777,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} logTiming "start: head logit-diff lower bound" timingPrint "timing: head logit-diff lower bound start" timingFlush + profileLogitDiffWeighted cert logitCache let logitDiffLB0? ← timePureWithHeartbeat "head: logit-diff lower bound unweighted" (fun () => Nfp.Sound.logitDiffLowerBoundFromCache cert logitCache) @@ -1067,6 +913,7 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} timingPrint "timing: head logit-diff lower bound start" timingFlush let logitCache := Nfp.Sound.logitDiffCache cert + profileLogitDiffWeighted cert logitCache let logitDiffLB0? ← timePureWithHeartbeat "head: logit-diff lower bound unweighted" (fun () => Nfp.Sound.logitDiffLowerBoundFromCache cert logitCache) diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean index 24b2a78..b71c654 100644 --- a/Nfp/IO/Timing.lean +++ b/Nfp/IO/Timing.lean @@ -3,8 +3,10 @@ module public import Mathlib.Data.List.Range +public import Nfp.IO.HeadScore public import Nfp.Model.InductionHead public import Nfp.Sound.Induction.HeadBounds +public import Nfp.Sound.Induction.LogitDiff /-! Small IO helpers for profiling slow phases. @@ -68,6 +70,10 @@ def timingHeartbeatMs : IO UInt32 := do (← IO.getEnv "NFP_TIMING_HEARTBEAT_MS").bind String.toNat? |>.getD defaultMs return UInt32.ofNat ms +/-- Resolve the heartbeat interval (ms) for long-running induction cert builds. -/ +def heartbeatMs : IO UInt32 := + timingHeartbeatMs + /-- Print a timing line only when stdout timing is enabled. -/ def timingPrint (line : String) : IO Unit := do if (← timingStdoutEnabled) then @@ -122,6 +128,27 @@ def timePure {α : Type} (label : String) (f : Unit → α) : IO α := do timingPrint s!"timing: {label} {t1 - t0} us" return res +/-- Time a pure thunk, printing heartbeat updates while it runs. 
-/ +def timePureWithHeartbeat {α : Type} (label : String) (f : Unit → α) : IO α := do + let t0 ← monoUsNow + timingPrint s!"timing: {label} start" + timingFlush + let task : Task α := Task.spawn (fun _ => f ()) + let heartbeatMs ← heartbeatMs + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished task) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished task) + if !finished then + let now ← monoUsNow + timingPrint s!"timing: {label} running {now - t0} us" + timingFlush + let res ← IO.wait task + let t1 ← monoUsNow + timingPrint s!"timing: {label} {t1 - t0} us" + return res + /-- Flush stdout immediately for interleaved timing output. -/ def flushStdout : IO Unit := do let h ← IO.getStdout @@ -248,6 +275,178 @@ def timeHeadScoreFieldForces {seq dModel dHead : Nat} timingPrint "timing: head score field force done" timingFlush +/-- Force a rational to help isolate cached computations. -/ +def forceRat (x : Rat) : IO Unit := do + if x = x then + pure () + else + pure () + +/-- Report detailed timing for weighted logit-diff components when enabled. -/ +def logitDiffProfileEnabled : IO Bool := do + return (← IO.getEnv "NFP_TIMING_LOGITDIFF_PROFILE").isSome + +/-- Profile weighted logit-diff sub-steps when logit-diff profiling is enabled. 
-/ +def profileLogitDiffWeighted {seq : Nat} + (cert : Sound.InductionHeadCert seq) + (cache : Sound.LogitDiffCache seq) : IO Unit := do + if !(← logitDiffProfileEnabled) then + pure () + else + timingPrint "timing: logit-diff profile start" + timingFlush + let _ ← timePureWithHeartbeat "logit-diff profile: valsLo force" (fun () => + Array.ofFn (fun q : Fin seq => cache.valsLo q)) + let _ ← timePureWithHeartbeat "logit-diff profile: weightBoundAt force" (fun () => + Array.ofFn (fun q : Fin seq => + Array.ofFn (fun k : Fin seq => cert.weightBoundAt q k))) + let _ ← timePureWithHeartbeat "logit-diff profile: weighted gap sum" (fun () => + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (cert.prev q) + Array.ofFn (fun q : Fin seq => + (others q).sum (fun k => + let diff := cache.valsLo (cert.prev q) - cache.valsLo k + cert.weightBoundAt q k * max (0 : Rat) diff))) + let _ ← timePureWithHeartbeat "logit-diff profile: weighted min" (fun () => + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (cert.prev q) + let gap : Fin seq → Rat := fun q => + (others q).sum (fun k => + let diff := cache.valsLo (cert.prev q) - cache.valsLo k + cert.weightBoundAt q k * max (0 : Rat) diff) + let f : Fin seq → Rat := fun q => cache.valsLo (cert.prev q) - gap q + if h : cert.active.Nonempty then + let img := cert.active.image f + let _ := Finset.min' img (h.image f) + () + else + ()) + +/-- Profile the core induction-head bounds used by the sound certificate builder. 
-/ +def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do + timingPrint "timing: core stages start" + timingFlush + let lnBounds ← timePureWithHeartbeat "core: ln bounds" (fun () => + Sound.headLnBounds inputs) + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Sound.Bounds.cacheBoundTask (fun q => + Sound.Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr ← timePureWithHeartbeat "core: lnAbsMax force" (fun () => + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q)) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr.getD q.1 (0 : Rat) + let lnAbsMaxMax : Rat := + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by + simp [univ] + univ.sup' hnonempty (fun q => lnAbsMax q) + let qAbsRowTasks : Array (Task (Array Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + Array.ofFn (fun d : Fin dHead => + Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + + |inputs.bq d|))) + let qAbsBaseArr ← timePureWithHeartbeat "core: qAbs base force" (fun () => + let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) + Array.ofFn (fun q : Fin seq => + (qAbsRowTasks.getD q.1 defaultTask).get)) + let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := qAbsBaseArr.getD q.1 #[] + row.getD d.1 (0 : Rat) + let kAbsRowTasks : Array (Task (Array Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + Array.ofFn (fun d : Fin dHead => + Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + + |inputs.bk d|))) + let kAbsBaseArr ← timePureWithHeartbeat "core: kAbs base force" (fun () => + let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) + Array.ofFn (fun q : Fin seq => + (kAbsRowTasks.getD q.1 defaultTask).get)) + let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => + let row := kAbsBaseArr.getD q.1 #[] + row.getD d.1 (0 
: Rat) + let dotAbs ← timePureWithHeartbeat "core: dotAbs tasks" (fun () => + dotAbsFromQKV qAbsBase kAbsBase) + let _ ← timePureWithHeartbeat "core: dotAbs force" (fun () => + match List.finRange seq with + | [] => (0 : Rat) + | q :: _ => + match List.finRange seq with + | [] => (0 : Rat) + | k :: _ => dotAbs q k) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then inputs.maskValue else -scoreBaseAbs q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then inputs.maskValue else scoreBaseAbs q k + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreLoPrev q - scoreHi q k) + else + (0 : Rat) + else + (0 : Rat) + let margin ← timePureWithHeartbeat "core: margin" (fun () => + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat)) + let marginNeg ← timePureWithHeartbeat "core: margin < 0" (fun () => + decide (margin < 0)) + let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" + if verboseTiming.isSome then + timingPrint s!"timing: core: margin neg={marginNeg}" + let tEps0 ← monoUsNow + timingPrint "timing: core: eps start" + timingFlush + let eps := + if marginNeg then + (1 : Rat) + else + ratDivUp (seq - 1) (1 + margin) + let tEps1 ← monoUsNow + timingPrint s!"timing: core: eps {tEps1 - tEps0} us" + timingFlush + let _ := marginAt + let dirHeadVec ← timePureWithHeartbeat "core: dir head vec" (fun () => + Sound.dirHeadVecOfInputs inputs) + let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d + let wvDir : Fin dModel → Rat := 
+ Sound.Bounds.cacheBoundTask (fun j => + Sound.Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let _ ← timePureWithHeartbeat "core: wvDir force" (fun () => + Array.ofFn (fun j : Fin dModel => wvDir j)) + let bDir ← timePureWithHeartbeat "core: bDir" (fun () => + Sound.Linear.dotFin dHead dirHead (fun d => inputs.bv d)) + let valsAbsBase ← timePureWithHeartbeat "core: valsAbsBase" (fun () => + Sound.Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax) + let valsLoBase := bDir - valsAbsBase + let valsHiBase := bDir + valsAbsBase + let valsLo : Fin seq → Rat := fun _ => valsLoBase + let valsHi : Fin seq → Rat := fun _ => valsHiBase + let _ ← timePureWithHeartbeat "core: value bounds" (fun () => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by + simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + (lo, hi)) + timingPrint "timing: core stages done" + timingFlush + end IO end Nfp From 4312769032b9305898fd823feddcd64186d4e5ae Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 18:35:59 +0100 Subject: [PATCH 183/244] Optimize weighted logit-diff path --- Nfp/Circuit/Cert/LogitDiff.lean | 33 ++++- Nfp/IO/InductionHead/Basic.lean | 24 +++- Nfp/Sound/Induction/Core/Basic.lean | 108 ++++++++------ Nfp/Sound/Induction/CoreSound/Basic.lean | 176 +++++++++++++---------- 4 files changed, 207 insertions(+), 134 deletions(-) diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index 8b5327d..183639e 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -74,7 +74,11 @@ def logitDiffLowerBoundWeightedAt (active : Finset (Fin seq)) let gap : Fin seq → Rat := fun q => (others q).sum (fun k => let diff := valsLo (prev q) - valsLo k - weightBoundAt q k * max (0 : Rat) diff) + let diffPos := max (0 : Rat) diff + if hdiff : diffPos = 0 then + 0 + else + weightBoundAt q k * diffPos) let f : Fin seq → Rat := fun q => valsLo (prev q) - 
gap q let img := active.image f have himg : img.Nonempty := h.image f @@ -166,7 +170,10 @@ theorem logitDiffLowerBoundWeightedAt_le (active : Finset (Fin seq)) let gap : Fin seq → Rat := fun q => (others q).sum (fun k => let diff := valsLo (prev q) - valsLo k - weightBoundAt q k * max (0 : Rat) diff) + if hdiff : max (0 : Rat) diff = 0 then + 0 + else + weightBoundAt q k * max (0 : Rat) diff) let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q have hbound' : (active.image f).min' (hnonempty.image f) = lb := by simpa [logitDiffLowerBoundWeightedAt, hnonempty, f, gap, others] using hbound @@ -177,13 +184,25 @@ theorem logitDiffLowerBoundWeightedAt_le (active : Finset (Fin seq)) Finset.min'_le _ _ hmem have hmin' : lb ≤ f q := by simpa [hbound'] using hmin - have hmin'' : + have hgap_eq : + gap q = + (others q).sum (fun k => + weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) := by + classical + dsimp [gap] + refine Finset.sum_congr rfl ?_ + intro k hk + by_cases hdiff : max (0 : Rat) (valsLo (prev q) - valsLo k) = 0 + · simp [hdiff] + · simp [hdiff] + have hmin'' : lb ≤ valsLo (prev q) - gap q := by + simpa [f] using hmin' + have hmin''' : lb ≤ valsLo (prev q) - (others q).sum (fun k => - let diff := valsLo (prev q) - valsLo k - weightBoundAt q k * max (0 : Rat) diff) := by - simpa [f, gap] using hmin' - simpa [others] using hmin'' + weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) := by + simpa [hgap_eq] using hmin'' + simpa [others] using hmin''' end Circuit diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 3fed3b1..8f7be5e 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -913,7 +913,17 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} timingPrint "timing: head logit-diff lower bound start" timingFlush let logitCache := Nfp.Sound.logitDiffCache cert - profileLogitDiffWeighted cert logitCache + let profiling ← logitDiffProfileEnabled + 
if profiling then + profileLogitDiffWeighted cert logitCache + else + pure () + let weightedTask? : Option (Task (Option Rat)) := + if profiling then + none + else + some (Task.spawn (fun _ => + Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache)) let logitDiffLB0? ← timePureWithHeartbeat "head: logit-diff lower bound unweighted" (fun () => Nfp.Sound.logitDiffLowerBoundFromCache cert logitCache) @@ -929,9 +939,15 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} | none => false let logitDiffWeighted? ← if needsWeighted then - timePureWithHeartbeat - "head: logit-diff lower bound weighted" (fun () => - Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) + match weightedTask? with + | some task => + timePureWithHeartbeat + "head: logit-diff lower bound weighted" (fun () => + task.get) + | none => + timePureWithHeartbeat + "head: logit-diff lower bound weighted" (fun () => + Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) else pure none let logitDiffLB? : Option Rat := diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index ee071ec..d44c805 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -578,52 +578,64 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} Bounds.cacheBound2 scoreGapLoBaseRaw let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => - if hq : q ∈ inputs.active then - let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (scoreGapLoBase q k, k)).2 + -- Skip worst-key refinement when base/refined budgets match to avoid duplicate score-gap work. 
+ let worstKey : Fin seq → Option (Fin seq) := + if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then + fun _ => none else - none - let worstKeyArr : Array (Thunk (Option (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) - let worstKey : Fin seq → Option (Fin seq) := fun q => - let t := worstKeyArr[q.1]'(by - simp [worstKeyArr, q.isLt]) - Thunk.get t - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - | none => dotDiffLoBase q k - let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 + let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => + if hq : q ∈ inputs.active then + let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (scoreGapLoBase q k, k)).2 else - dotDiffHiBase q k - | none => dotDiffHiBase q k + none + let worstKeyArr : Array (Thunk (Option (Fin seq))) := + Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) + fun q => + let t := worstKeyArr[q.1]'(by + simp [worstKeyArr, q.isLt]) + Thunk.get t + let dotDiffLoHi : (Fin seq → Fin seq → Rat) × (Fin seq → Fin seq → Rat) := + if h : 
cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then + (dotDiffLoBase, dotDiffHiBase) + else + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k + | none => dotDiffLoBase q k + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k + | none => dotDiffHiBase q k + (dotDiffLo, dotDiffHi) + let dotDiffLo : Fin seq → Fin seq → Rat := dotDiffLoHi.1 + let dotDiffHi : Fin seq → Fin seq → Rat := dotDiffLoHi.2 let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => if masked q (inputs.prev q) then scoreLoPrev q - scoreHi q k @@ -653,16 +665,18 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} (1 : Rat) else ratDivUp 1 (1 + gap) + let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := + Bounds.cacheBound2Task weightBoundAtBase let epsAtBase : Fin seq → Rat := fun q => let other := otherKeys q - let total := other.sum (fun k => weightBoundAtBase q k) + let total := other.sum (fun k => weightBoundAtBaseCached q k) min (1 : Rat) total let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase let weightBoundAtClampedBase : Fin seq → Fin seq → Rat := fun q k => - min (weightBoundAtBase q k) (epsAt q) + min (weightBoundAtBaseCached q k) (epsAt q) let weightBoundAt : Fin seq → Fin seq → Rat := - 
Bounds.cacheBound2Task weightBoundAtClampedBase + Bounds.cacheBound2 weightBoundAtClampedBase let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 49fd792..cb744a9 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -352,52 +352,63 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N Bounds.cacheBound2 scoreGapLoBaseRaw let otherKeys : Fin seq → Finset (Fin seq) := fun q => (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => - if hq : q ∈ inputs.active then - let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (scoreGapLoBase q k, k)).2 + let worstKey : Fin seq → Option (Fin seq) := + if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then + fun _ => none else - none - let worstKeyArr : Array (Thunk (Option (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) - let worstKey : Fin seq → Option (Fin seq) := fun q => - let t := worstKeyArr[q.1]'(by - simp [worstKeyArr, q.isLt]) - Thunk.get t - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - | none => dotDiffLoBase q k - let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let 
dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 + let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => + if hq : q ∈ inputs.active then + let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (scoreGapLoBase q k, k)).2 else - dotDiffHiBase q k - | none => dotDiffHiBase q k + none + let worstKeyArr : Array (Thunk (Option (Fin seq))) := + Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) + fun q => + let t := worstKeyArr[q.1]'(by + simp [worstKeyArr, q.isLt]) + Thunk.get t + let dotDiffLoHi : (Fin seq → Fin seq → Rat) × (Fin seq → Fin seq → Rat) := + if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then + (dotDiffLoBase, dotDiffHiBase) + else + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k + | none => dotDiffLoBase q k + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + match worstKey q with + | some k' => + if hk : k = k' then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k + 
| none => dotDiffHiBase q k + (dotDiffLo, dotDiffHi) + let dotDiffLo : Fin seq → Fin seq → Rat := dotDiffLoHi.1 + let dotDiffHi : Fin seq → Fin seq → Rat := dotDiffLoHi.2 let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => if masked q (inputs.prev q) then scoreLoPrev q - scoreHi q k @@ -427,16 +438,18 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N (1 : Rat) else ratDivUp 1 (1 + gap) + let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := + Bounds.cacheBound2Task weightBoundAtBase let epsAtBase : Fin seq → Rat := fun q => let other := otherKeys q - let total := other.sum (fun k => weightBoundAtBase q k) + let total := other.sum (fun k => weightBoundAtBaseCached q k) min (1 : Rat) total let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase let weightBoundAtClampedBase : Fin seq → Fin seq → Rat := fun q k => - min (weightBoundAtBase q k) (epsAt q) + min (weightBoundAtBaseCached q k) (epsAt q) let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2Task weightBoundAtClampedBase + Bounds.cacheBound2 weightBoundAtClampedBase let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt @@ -826,35 +839,40 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N Array.getElem_ofFn] using hspecBase.1 · simpa [dotDiffHiBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, Array.getElem_ofFn] using hspecBase.2 - cases hkey : worstKey q with - | none => - simpa [dotDiffLo, dotDiffHi, hkey] using hspecBase_bounds - | some k' => - by_cases hk : k = k' - · have hlow' : - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, hkey, hk] using hspecRef.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) + by_cases hbudget : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase + · simpa [dotDiffLo, dotDiffHi, dotDiffLoHi, hbudget] using 
hspecBase_bounds + · cases hkey : worstKey q with + | none => + simpa [dotDiffLo, dotDiffHi, dotDiffLoHi, hbudget, hkey] using + hspecBase_bounds + | some k' => + by_cases hk : k = k' + · have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, hkey, hk] using hspecRef.2 - exact ⟨hlow', hhigh'⟩ - · have hlow' : - (dotDiffLo q k : Real) ≤ + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, dotDiffLoHi, hbudget, hkey, hk] using hspecRef.1 + have hhigh' : dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, hkey, hk] using hspecBase_bounds.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, dotDiffLoHi, hbudget, hkey, hk] using hspecRef.2 + exact ⟨hlow', hhigh'⟩ + · have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, hkey, hk] using hspecBase_bounds.2 - exact ⟨hlow', hhigh'⟩ + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, dotDiffLoHi, hbudget, hkey, hk] using + hspecBase_bounds.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, dotDiffLoHi, hbudget, hkey, hk] using + hspecBase_bounds.2 + exact ⟨hlow', hhigh'⟩ have hmarginAt_le : ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → marginAt q ≤ scoreGapLo q k := by @@ -1080,8 +1098,8 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N ∀ q k, weightBoundAt q k = min 
(weightBoundAtBase q k) (epsAt q) := by intro q k - simpa [weightBoundAt, weightBoundAtClampedBase] using - (Bounds.cacheBound2Task_apply (f := weightBoundAtClampedBase) q k) + simp [weightBoundAt, weightBoundAtClampedBase, weightBoundAtBaseCached, + Bounds.cacheBound2_apply, Bounds.cacheBound2Task_apply] have hepsAt : ∀ q, epsAt q = min (1 : Rat) @@ -1092,6 +1110,12 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N ratDivUp 1 (1 + scoreGapLo q k))) := by intro q have hsum : + (otherKeys q).sum (fun k => weightBoundAtBaseCached q k) = + (otherKeys q).sum (fun k => weightBoundAtBase q k) := by + refine Finset.sum_congr rfl ?_ + intro k hk + simp [weightBoundAtBaseCached, Bounds.cacheBound2Task_apply] + have hsum' : (otherKeys q).sum (fun k => weightBoundAtBase q k) = (otherKeys q).sum (fun k => if scoreGapLo q k < 0 then @@ -1102,7 +1126,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N intro k hk have hk' : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 simp [hweightBoundAtBase q k hk'] - simpa [epsAt, epsAtBase, hsum] using + simpa [epsAt, epsAtBase, hsum, hsum'] using (Bounds.cacheBoundThunk_apply (f := epsAtBase) q) have oneHot_bounds_at : ∀ q, q ∈ inputs.active → From 2a1c43e51204b699b9fdf8fd3ef3a4dd649bc905 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Wed, 14 Jan 2026 19:59:19 +0100 Subject: [PATCH 184/244] Speed weighted logit-diff gap --- Nfp/Circuit/Cert/LogitDiff.lean | 28 ++++++++++++++++++++ Nfp/Sound/Induction/Core/Basic.lean | 5 +--- Nfp/Sound/Induction/CoreDefs.lean | 2 +- Nfp/Sound/Induction/CoreSound/Basic.lean | 33 +++++++----------------- Nfp/Sound/Induction/LogitDiff.lean | 26 +++++++++++++++++-- 5 files changed, 64 insertions(+), 30 deletions(-) diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index 183639e..5284c0c 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -86,6 +86,34 @@ def 
logitDiffLowerBoundWeightedAt (active : Finset (Fin seq)) else exact none +/-- Unfolding lemma for `logitDiffLowerBoundWeightedAt`. -/ +theorem logitDiffLowerBoundWeightedAt_def (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (weightBoundAt : Fin seq → Fin seq → Rat) + (valsLo : Fin seq → Rat) : + logitDiffLowerBoundWeightedAt active prev weightBoundAt valsLo = + by + classical + if h : active.Nonempty then + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + let gap : Fin seq → Rat := fun q => + (others q).sum (fun k => + let diff := valsLo (prev q) - valsLo k + let diffPos := max (0 : Rat) diff + if diffPos = 0 then + 0 + else + weightBoundAt q k * diffPos) + let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q + let img := active.image f + have himg : img.Nonempty := h.image f + exact some (Finset.min' img himg) + else + exact none := by + classical + rfl + /-- The computed lower bound is below every active `prev` value minus the tolerance gap. 
-/ theorem logitDiffLowerBound_le (active : Finset (Fin seq)) (prev : Fin seq → Fin seq) diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index d44c805..9c13f66 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -673,10 +673,7 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} min (1 : Rat) total let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase - let weightBoundAtClampedBase : Fin seq → Fin seq → Rat := fun q k => - min (weightBoundAtBaseCached q k) (epsAt q) - let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 weightBoundAtClampedBase + let weightBoundAt : Fin seq → Fin seq → Rat := weightBoundAtBaseCached let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index aa315be..6cbf702 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -215,7 +215,7 @@ structure InductionHeadCert (seq : Nat) where eps : Rat /-- Per-query weight tolerance derived from local margins. -/ epsAt : Fin seq → Rat - /-- Per-key weight bounds derived from score gaps, clamped by `epsAt`. -/ + /-- Per-key weight bounds derived from score gaps. -/ weightBoundAt : Fin seq → Fin seq → Rat /-- Score margin used to justify the weight tolerance. 
-/ margin : Rat diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index cb744a9..09ce05c 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -446,10 +446,7 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N min (1 : Rat) total let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase - let weightBoundAtClampedBase : Fin seq → Fin seq → Rat := fun q k => - min (weightBoundAtBaseCached q k) (epsAt q) - let weightBoundAt : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 weightBoundAtClampedBase + let weightBoundAt : Fin seq → Fin seq → Rat := weightBoundAtBaseCached let margin : Rat := if h : inputs.active.Nonempty then inputs.active.inf' h marginAt @@ -1096,10 +1093,9 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N simp [weightBoundAtBase, hk] have hweightBoundAt : ∀ q k, - weightBoundAt q k = min (weightBoundAtBase q k) (epsAt q) := by + weightBoundAt q k = weightBoundAtBase q k := by intro q k - simp [weightBoundAt, weightBoundAtClampedBase, weightBoundAtBaseCached, - Bounds.cacheBound2_apply, Bounds.cacheBound2Task_apply] + simp [weightBoundAt, weightBoundAtBaseCached, Bounds.cacheBound2Task_apply] have hepsAt : ∀ q, epsAt q = min (1 : Rat) @@ -1159,25 +1155,16 @@ theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : N (hweightBoundAt := hweightBoundAtBase) (hscore_gap_real_at := hscore_gap_real_at) q hq k hk - have hbound_eps : - weights q k ≤ (epsAt q : Real) := by - have honehot := oneHot_bounds_at q hq - exact honehot.other_le q rfl k hk - have hbound_min : - weights q k ≤ min (weightBoundAtBase q k : Real) (epsAt q : Real) := by - exact le_min hbound_base hbound_eps have hweightBoundAt_real : (weightBoundAt q k : Real) = - min (weightBoundAtBase q k : Real) (epsAt q : Real) := by - have hmin : - weightBoundAt q k = min (weightBoundAtBase q k) (epsAt q) := + 
(weightBoundAtBase q k : Real) := by + have hbase : weightBoundAt q k = weightBoundAtBase q k := hweightBoundAt q k - have hmin' : - ratToReal (weightBoundAt q k) = - ratToReal (min (weightBoundAtBase q k) (epsAt q)) := - congrArg ratToReal hmin - simpa [ratToReal_min, ratToReal_def] using hmin' - simpa [hweightBoundAt_real] using hbound_min + have hbase' : + ratToReal (weightBoundAt q k) = ratToReal (weightBoundAtBase q k) := + congrArg ratToReal hbase + simpa [ratToReal_def] using hbase' + simpa [hweightBoundAt_real] using hbound_base have hepsAt_le_eps : ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by intro q hq diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index e542c2a..eada082 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -137,7 +137,24 @@ def logitDiffLowerBoundFromCache (c : InductionHeadCert seq) (cache : LogitDiffC /-- Weighted logit-diff lower bound from a shared cache. -/ def logitDiffLowerBoundWeightedFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : Option Rat := - Circuit.logitDiffLowerBoundWeightedAt c.active c.prev c.weightBoundAt cache.valsLo + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + let gapBase : Fin seq → Rat := fun q => + (others q).sum (fun k => + let diff := cache.valsLo (c.prev q) - cache.valsLo k + let diffPos := max (0 : Rat) diff + if diffPos = 0 then + 0 + else + c.weightBoundAt q k * diffPos) + let gap : Fin seq → Rat := Bounds.cacheBoundTask gapBase + if h : c.active.Nonempty then + let f : Fin seq → Rat := fun q => cache.valsLo (c.prev q) - gap q + let img := c.active.image f + have himg : img.Nonempty := h.image f + some (Finset.min' img himg) + else + none /-- `logitDiffLowerBoundFromCache` matches the cached default computation. 
-/ theorem logitDiffLowerBoundFromCache_eq (c : InductionHeadCert seq) : @@ -148,7 +165,12 @@ theorem logitDiffLowerBoundFromCache_eq (c : InductionHeadCert seq) : theorem logitDiffLowerBoundWeightedFromCache_eq (c : InductionHeadCert seq) : logitDiffLowerBoundWeightedFromCache c (logitDiffCache c) = logitDiffLowerBoundFromCertWeighted c := by - rfl + classical + unfold logitDiffLowerBoundWeightedFromCache logitDiffLowerBoundFromCertWeighted logitDiffCache + have hvals : Bounds.cacheBoundTask c.values.valsLo = c.values.valsLo := by + funext k + simp [Bounds.cacheBoundTask_apply] + simp [hvals, Bounds.cacheBoundTask_apply, logitDiffLowerBoundWeightedAt_def] /-- Best available logit-diff lower bound from an induction certificate. -/ def logitDiffLowerBoundFromCertBest (c : InductionHeadCert seq) : Option Rat := From a4fbf98914e6786b1a4ad8bd1108c64302b2661e Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 04:29:57 +0100 Subject: [PATCH 185/244] Optimize induction head logit-diff caching --- Nfp/Circuit/Cert/LogitDiff.lean | 151 +- Nfp/Core/Basic.lean | 1 + Nfp/IO/InductionHead/Basic.lean | 140 +- Nfp/IO/Timing.lean | 32 +- Nfp/Sound/Induction.lean | 2 + Nfp/Sound/Induction/Core/Basic.lean | 452 +++++- Nfp/Sound/Induction/CoreDefs.lean | 21 + Nfp/Sound/Induction/CoreSound/Basic.lean | 1371 +---------------- .../CoreSound/Basic/CacheBounds.lean | 615 ++++++++ .../Induction/CoreSound/Basic/CertSound.lean | 1316 ++++++++++++++++ .../CoreSound/Basic/DefaultSound.lean | 29 + Nfp/Sound/Induction/CoreSound/Values.lean | 83 + Nfp/Sound/Induction/HeadOutput.lean | 26 + Nfp/Sound/Induction/LogitDiff.lean | 755 ++++++++- Nfp/Sound/Induction/OneHot.lean | 114 ++ Nfp/Sound/Induction/Refine.lean | 331 ++++ Nfp/Sound/Induction/RefineSound.lean | 596 +++++++ Nfp/System/Dag.lean | 1 + Nfp/Tactic/Linter.lean | 9 + Nfp/Tactic/Linter/NoHeartbeats.lean | 56 + lakefile.toml | 1 + 21 files changed, 4471 insertions(+), 1631 deletions(-) create mode 100644 
Nfp/Sound/Induction/CoreSound/Basic/CacheBounds.lean create mode 100644 Nfp/Sound/Induction/CoreSound/Basic/CertSound.lean create mode 100644 Nfp/Sound/Induction/CoreSound/Basic/DefaultSound.lean create mode 100644 Nfp/Sound/Induction/Refine.lean create mode 100644 Nfp/Sound/Induction/RefineSound.lean create mode 100644 Nfp/Tactic/Linter.lean create mode 100644 Nfp/Tactic/Linter/NoHeartbeats.lean diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean index 5284c0c..8eecac8 100644 --- a/Nfp/Circuit/Cert/LogitDiff.lean +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -3,7 +3,7 @@ module public import Nfp.Core.Basic -public import Mathlib.Data.Finset.Image +public import Mathlib.Data.Finset.Lattice.Fold public import Nfp.Circuit.Layers.Induction /-! @@ -26,9 +26,7 @@ def logitDiffLowerBound (active : Finset (Fin seq)) if h : active.Nonempty then let gap := eps * (hi - lo) let f : Fin seq → Rat := fun q => vals (prev q) - gap - let img := active.image f - have himg : img.Nonempty := h.image f - exact some (Finset.min' img himg) + exact some (active.inf' h f) else exact none @@ -40,9 +38,7 @@ def logitDiffLowerBoundAt (active : Finset (Fin seq)) if h : active.Nonempty then let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) let f : Fin seq → Rat := fun q => vals (prev q) - gap q - let img := active.image f - have himg : img.Nonempty := h.image f - exact some (Finset.min' img himg) + exact some (active.inf' h f) else exact none @@ -55,9 +51,21 @@ def logitDiffLowerBoundAtLo (active : Finset (Fin seq)) if h : active.Nonempty then let f : Fin seq → Rat := fun q => valsLo (prev q) - epsAt q * (valsLo (prev q) - lo) - let img := active.image f - have himg : img.Nonempty := h.image f - exact some (Finset.min' img himg) + exact some (active.inf' h f) + else + exact none + +/-- Compute a lower bound on the logit-diff contribution using per-query eps and per-query + lower bounds for other values. 
-/ +def logitDiffLowerBoundAtLoAt (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (loAt : Fin seq → Rat) (valsLo : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let f : Fin seq → Rat := fun q => + let delta := valsLo (prev q) - loAt q + valsLo (prev q) - epsAt q * max (0 : Rat) delta + exact some (active.inf' h f) else exact none @@ -69,20 +77,12 @@ def logitDiffLowerBoundWeightedAt (active : Finset (Fin seq)) (valsLo : Fin seq → Rat) : Option Rat := by classical if h : active.Nonempty then - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (prev q) let gap : Fin seq → Rat := fun q => - (others q).sum (fun k => + (Finset.univ : Finset (Fin seq)).sum (fun k => let diff := valsLo (prev q) - valsLo k - let diffPos := max (0 : Rat) diff - if hdiff : diffPos = 0 then - 0 - else - weightBoundAt q k * diffPos) + weightBoundAt q k * max (0 : Rat) diff) let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q - let img := active.image f - have himg : img.Nonempty := h.image f - exact some (Finset.min' img himg) + exact some (active.inf' h f) else exact none @@ -95,20 +95,12 @@ theorem logitDiffLowerBoundWeightedAt_def (active : Finset (Fin seq)) by classical if h : active.Nonempty then - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (prev q) let gap : Fin seq → Rat := fun q => - (others q).sum (fun k => + (Finset.univ : Finset (Fin seq)).sum (fun k => let diff := valsLo (prev q) - valsLo k - let diffPos := max (0 : Rat) diff - if diffPos = 0 then - 0 - else - weightBoundAt q k * diffPos) + weightBoundAt q k * max (0 : Rat) diff) let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q - let img := active.image f - have himg : img.Nonempty := h.image f - exact some (Finset.min' img himg) + exact some (active.inf' h f) else exact none := by classical @@ -126,13 +118,10 @@ theorem logitDiffLowerBound_le (active : Finset 
(Fin seq)) have hnonempty : active.Nonempty := ⟨q, hq⟩ let gap := eps * (hi - lo) let f : Fin seq → Rat := fun q => vals (prev q) - gap - have hbound' : (active.image f).min' (hnonempty.image f) = lb := by + have hbound' : active.inf' hnonempty f = lb := by simpa [logitDiffLowerBound, hnonempty, f, gap] using hbound - have hmem : f q ∈ (active.image f) := by - refine Finset.mem_image.2 ?_ - exact ⟨q, hq, rfl⟩ - have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := - Finset.min'_le _ _ hmem + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq simpa [f, gap, hbound'] using hmin /-- The per-query lower bound is below every active `prev` value minus the local gap. -/ @@ -147,13 +136,10 @@ theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) have hnonempty : active.Nonempty := ⟨q, hq⟩ let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) let f : Fin seq → Rat := fun q => vals (prev q) - gap q - have hbound' : (active.image f).min' (hnonempty.image f) = lb := by + have hbound' : active.inf' hnonempty f = lb := by simpa [logitDiffLowerBoundAt, hnonempty, f, gap] using hbound - have hmem : f q ∈ (active.image f) := by - refine Finset.mem_image.2 ?_ - exact ⟨q, hq, rfl⟩ - have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := - Finset.min'_le _ _ hmem + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq simpa [f, gap, hbound'] using hmin /-- The per-query lower bound is below every active `prev` value minus the `lo`-gap. 
-/ @@ -168,16 +154,30 @@ theorem logitDiffLowerBoundAtLo_le (active : Finset (Fin seq)) have hnonempty : active.Nonempty := ⟨q, hq⟩ let f : Fin seq → Rat := fun q => valsLo (prev q) - epsAt q * (valsLo (prev q) - lo) - have hbound' : (active.image f).min' (hnonempty.image f) = lb := by + have hbound' : active.inf' hnonempty f = lb := by simpa [logitDiffLowerBoundAtLo, hnonempty, f] using hbound - have hmem : f q ∈ (active.image f) := by - refine Finset.mem_image.2 ?_ - exact ⟨q, hq, rfl⟩ - have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := - Finset.min'_le _ _ hmem - have hmin' : lb ≤ f q := by - simpa [hbound'] using hmin - simpa [f] using hmin' + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq + simpa [f, hbound'] using hmin + +/-- The per-query lower bound is below every active `prev` value minus the local `loAt` gap. -/ +theorem logitDiffLowerBoundAtLoAt_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (loAt : Fin seq → Rat) (valsLo : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBoundAtLoAt active prev epsAt loAt valsLo = some lb → + lb ≤ valsLo (prev q) - epsAt q * max (0 : Rat) (valsLo (prev q) - loAt q) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + let f : Fin seq → Rat := fun q => + let delta := valsLo (prev q) - loAt q + valsLo (prev q) - epsAt q * max (0 : Rat) delta + have hbound' : active.inf' hnonempty f = lb := by + simpa [logitDiffLowerBoundAtLoAt, hnonempty, f] using hbound + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq + simpa [f, hbound'] using hmin /-- The weighted lower bound is below every active `prev` value minus the weighted gap. 
-/ theorem logitDiffLowerBoundWeightedAt_le (active : Finset (Fin seq)) @@ -188,49 +188,20 @@ theorem logitDiffLowerBoundWeightedAt_le (active : Finset (Fin seq)) ∀ lb, logitDiffLowerBoundWeightedAt active prev weightBoundAt valsLo = some lb → lb ≤ valsLo (prev q) - - ((Finset.univ : Finset (Fin seq)).erase (prev q)).sum (fun k => + (Finset.univ : Finset (Fin seq)).sum (fun k => weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) := by classical intro lb hbound have hnonempty : active.Nonempty := ⟨q, hq⟩ - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (prev q) let gap : Fin seq → Rat := fun q => - (others q).sum (fun k => - let diff := valsLo (prev q) - valsLo k - if hdiff : max (0 : Rat) diff = 0 then - 0 - else - weightBoundAt q k * max (0 : Rat) diff) + (Finset.univ : Finset (Fin seq)).sum (fun k => + weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q - have hbound' : (active.image f).min' (hnonempty.image f) = lb := by - simpa [logitDiffLowerBoundWeightedAt, hnonempty, f, gap, others] using hbound - have hmem : f q ∈ (active.image f) := by - refine Finset.mem_image.2 ?_ - exact ⟨q, hq, rfl⟩ - have hmin : (active.image f).min' (hnonempty.image f) ≤ f q := - Finset.min'_le _ _ hmem - have hmin' : lb ≤ f q := by - simpa [hbound'] using hmin - have hgap_eq : - gap q = - (others q).sum (fun k => - weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) := by - classical - dsimp [gap] - refine Finset.sum_congr rfl ?_ - intro k hk - by_cases hdiff : max (0 : Rat) (valsLo (prev q) - valsLo k) = 0 - · simp [hdiff] - · simp [hdiff] - have hmin'' : lb ≤ valsLo (prev q) - gap q := by - simpa [f] using hmin' - have hmin''' : - lb ≤ valsLo (prev q) - - (others q).sum (fun k => - weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) := by - simpa [hgap_eq] using hmin'' - simpa [others] using hmin''' + have hbound' : active.inf' 
hnonempty f = lb := by + simpa [logitDiffLowerBoundWeightedAt, hnonempty, f, gap] using hbound + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq + simpa [f, gap, hbound'] using hmin end Circuit diff --git a/Nfp/Core/Basic.lean b/Nfp/Core/Basic.lean index 2e593b0..36e5f7b 100644 --- a/Nfp/Core/Basic.lean +++ b/Nfp/Core/Basic.lean @@ -2,6 +2,7 @@ module +public meta import Nfp.Tactic.Linter public import Mathlib.Algebra.Order.Group.Unbundled.Abs public import Mathlib.Data.NNReal.Defs public import Mathlib.Data.NNReal.Basic diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 8f7be5e..6614293 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -79,6 +79,17 @@ def loadInductionHeadInputs (path : System.FilePath) : private def ratToString (x : Rat) : String := toString x +private def ratOptToString (x : Option Rat) : String := + match x with + | some v => ratToString v + | none => "none" + +private def logitDiffDebugEnabled : IO Bool := do + return (← IO.getEnv "NFP_LOGITDIFF_DEBUG").isSome + +private def logitDiffRefineEnabled : IO Bool := do + return (← IO.getEnv "NFP_LOGITDIFF_REFINE").isSome + private def renderResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) : String := let header := s!"dim {n}" let lines := @@ -714,14 +725,14 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} let tCert0 ← monoUsNow let certTask : Task - (Option { c : Sound.InductionHeadCert seq // - Sound.InductionHeadCertSound inputs c }) := + (Option { cache : Sound.InductionHeadCoreCache seq dModel dHead // + Sound.InductionHeadCertSound inputs cache.cert }) := Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHeadWith? cfg inputs with + match Sound.buildInductionCertFromHeadWithCache? 
cfg inputs with | none => none - | some ⟨cert, hcert⟩ => - let _ := cert.active.card - some ⟨cert, hcert⟩) + | some ⟨cache, hcert⟩ => + let _ := cache.cert.active.card + some ⟨cache, hcert⟩) let heartbeatMs ← heartbeatMs if heartbeatMs ≠ 0 then let mut finished := (← IO.hasFinished certTask) @@ -742,7 +753,8 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} | none => IO.eprintln "error: head inputs rejected" return 2 - | some ⟨cert, _hcert⟩ => + | some ⟨cache, _hcert⟩ => + let cert := cache.cert timingPrint "timing: head active count start" timingFlush let activeCount := cert.active.card @@ -780,7 +792,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} profileLogitDiffWeighted cert logitCache let logitDiffLB0? ← timePureWithHeartbeat "head: logit-diff lower bound unweighted" (fun () => - Nfp.Sound.logitDiffLowerBoundFromCache cert logitCache) + Nfp.Sound.logitDiffLowerBoundRefineOnDemand inputs cache cert logitCache) let logitDiffLB? ← match logitDiffLB0? with | none => pure none @@ -854,14 +866,14 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} let tCert0 ← monoUsNow let certTask : Task - (Option { c : Sound.InductionHeadCert seq // - Sound.InductionHeadCertSound inputs c }) := + (Option { cache : Sound.InductionHeadCoreCache seq dModel dHead // + Sound.InductionHeadCertSound inputs cache.cert }) := Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHeadWith? cfg inputs with + match Sound.buildInductionCertFromHeadWithCache? 
cfg inputs with | none => none - | some ⟨cert, hcert⟩ => - let _ := cert.active.card - some ⟨cert, hcert⟩) + | some ⟨cache, hcert⟩ => + let _ := cache.cert.active.card + some ⟨cache, hcert⟩) let heartbeatMs ← heartbeatMs if heartbeatMs ≠ 0 then let mut finished := (← IO.hasFinished certTask) @@ -882,7 +894,8 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} | none => IO.eprintln "error: head inputs rejected" return 2 - | some ⟨cert, _hcert⟩ => + | some ⟨cache, _hcert⟩ => + let cert := cache.cert let activeCount := cert.active.card let defaultMinActive := max 1 (seq / 8) let minActive := minActive?.getD defaultMinActive @@ -926,7 +939,96 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache)) let logitDiffLB0? ← timePureWithHeartbeat "head: logit-diff lower bound unweighted" (fun () => - Nfp.Sound.logitDiffLowerBoundFromCache cert logitCache) + Nfp.Sound.logitDiffLowerBoundRefineOnDemand inputs cache cert logitCache) + if (← logitDiffDebugEnabled) then + match logitDiffLB0? 
with + | some lb0 => + if lb0 ≤ 0 then + match Nfp.Sound.logitDiffLowerBoundAtLoDebug cert logitCache with + | none => + IO.eprintln "debug: logitDiffLB0 witness not found" + | some ⟨info, _⟩ => + IO.eprintln + s!"debug: logitDiffLB0 witness q={info.q.1}, prev={info.prev.1}" + IO.eprintln + s!"debug: eps={ratToString info.eps}, \ + valsPrevLo={ratToString info.valsPrevLo}, \ + loAt={ratToString info.loAt}, \ + lo={ratToString info.lo}" + IO.eprintln + s!"debug: valsPrevLoMinusLoAt={ratToString info.valsPrevLoMinusLoAt}, \ + gap={ratToString info.gap}, \ + fAtQ={ratToString (info.valsPrevLo - info.gap)}, \ + lbAtQ={ratToString info.lbAtQ}" + let weightBoundAt := cert.weightBoundAt + let step : (Rat × Nat × Rat) → Fin seq → (Rat × Nat × Rat) := + fun acc k => + if k = info.prev then + acc + else + let w := weightBoundAt info.q k + let sum := acc.1 + w + let ones := if w = (1 : Rat) then acc.2.1 + 1 else acc.2.1 + let maxW := if w > acc.2.2 then w else acc.2.2 + (sum, ones, maxW) + let acc := Sound.Linear.foldlFin seq step (0, 0, 0) + IO.eprintln + s!"debug: epsAt={ratToString (cert.epsAt info.q)}, \ + weightSum={ratToString acc.1}, ones={acc.2.1}, \ + maxWeight={ratToString acc.2.2}" + let valsLo := logitCache.valsLo + let stepOnes : Array String → Fin seq → Array String := + fun acc k => + if k = info.prev then + acc + else + let w := weightBoundAt info.q k + if w = (1 : Rat) then + acc.push + s!"k={k.1} valsLo={ratToString (valsLo k)}" + else + acc + let ones := Sound.Linear.foldlFin seq stepOnes #[] + let onesMsg := + if ones.isEmpty then + "none" + else + String.intercalate ", " ones.toList + IO.eprintln s!"debug: weightBoundAt=1 keys: {onesMsg}" + let stepLoAt : Array String → Fin seq → Array String := + fun acc k => + if k = info.prev then + acc + else if valsLo k = info.loAt then + acc.push + s!"k={k.1} w={ratToString (weightBoundAt info.q k)}" + else + acc + let loAtKeys := Sound.Linear.foldlFin seq stepLoAt #[] + let loAtMsg := + if loAtKeys.isEmpty then + 
"none" + else + String.intercalate ", " loAtKeys.toList + IO.eprintln s!"debug: loAt keys: {loAtMsg}" + if (← logitDiffRefineEnabled) then + let refineBudget := max 1 cfg.splitBudgetDiffRefined + let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q + IO.eprintln + s!"debug: refine budget={refineBudget}, \ + refineKeys.card={refineKeys.card}" + let refineSpec := + Sound.refineSpecForQueryWithWeightOnes inputs cache info.q refineBudget + let refinedLB? := + Sound.logitDiffLowerBoundRefinedFromCache + inputs cache cert logitCache refineSpec + match refinedLB? with + | none => + IO.eprintln "debug: refined logitDiffLB0 none" + | some lb => + IO.eprintln + s!"debug: refined logitDiffLB0={ratToString lb}" + | none => pure () let needsWeighted : Bool := match logitDiffLB0? with | none => true @@ -969,6 +1071,12 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} return 2 | some logitDiffLB => if logitDiffLB ≤ 0 then + if (← logitDiffDebugEnabled) then + IO.eprintln + s!"debug: logitDiffLB0={ratOptToString logitDiffLB0?}, \ + logitDiffWeighted={ratOptToString logitDiffWeighted?}, \ + logitDiffLB={ratToString logitDiffLB}, \ + bound={boundLabel}" IO.eprintln s!"error: logitDiffLB {ratToString logitDiffLB} \ is not strictly positive" diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean index b71c654..e755868 100644 --- a/Nfp/IO/Timing.lean +++ b/Nfp/IO/Timing.lean @@ -295,29 +295,31 @@ def profileLogitDiffWeighted {seq : Nat} else timingPrint "timing: logit-diff profile start" timingFlush - let _ ← timePureWithHeartbeat "logit-diff profile: valsLo force" (fun () => + let valsLoArr ← timePureWithHeartbeat "logit-diff profile: valsLo force" (fun () => Array.ofFn (fun q : Fin seq => cache.valsLo q)) - let _ ← timePureWithHeartbeat "logit-diff profile: weightBoundAt force" (fun () => + let weightRows ← timePureWithHeartbeat "logit-diff profile: weightBoundAt force" (fun () => Array.ofFn (fun q : Fin seq => Array.ofFn (fun k : Fin seq => 
cert.weightBoundAt q k))) + let valsLo : Fin seq → Rat := fun k => + valsLoArr.getD k.1 (0 : Rat) + let weightBoundAt : Fin seq → Fin seq → Rat := fun q k => + let row := weightRows.getD q.1 #[] + row.getD k.1 (0 : Rat) let _ ← timePureWithHeartbeat "logit-diff profile: weighted gap sum" (fun () => - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (cert.prev q) Array.ofFn (fun q : Fin seq => - (others q).sum (fun k => - let diff := cache.valsLo (cert.prev q) - cache.valsLo k - cert.weightBoundAt q k * max (0 : Rat) diff))) + let valsLoPrev := valsLo (cert.prev q) + Linear.sumFin seq (fun k => + let diff := valsLoPrev - valsLo k + weightBoundAt q k * max (0 : Rat) diff))) let _ ← timePureWithHeartbeat "logit-diff profile: weighted min" (fun () => - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (cert.prev q) let gap : Fin seq → Rat := fun q => - (others q).sum (fun k => - let diff := cache.valsLo (cert.prev q) - cache.valsLo k - cert.weightBoundAt q k * max (0 : Rat) diff) - let f : Fin seq → Rat := fun q => cache.valsLo (cert.prev q) - gap q + let valsLoPrev := valsLo (cert.prev q) + Linear.sumFin seq (fun k => + let diff := valsLoPrev - valsLo k + weightBoundAt q k * max (0 : Rat) diff) + let f : Fin seq → Rat := fun q => valsLo (cert.prev q) - gap q if h : cert.active.Nonempty then - let img := cert.active.image f - let _ := Finset.min' img (h.image f) + let _ := cert.active.inf' h f () else ()) diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index ac6fb80..0343827 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -9,6 +9,8 @@ public import Nfp.Sound.Induction.HeadBounds public import Nfp.Sound.Induction.HeadOutput public import Nfp.Sound.Induction.LogitDiff public import Nfp.Sound.Induction.OneHot +public import Nfp.Sound.Induction.Refine +public import Nfp.Sound.Induction.RefineSound /-! Sound builders for induction certificates. 
diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index 9c13f66..07459fd 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -245,6 +245,323 @@ structure InductionHeadCoreCache (seq dModel dHead : Nat) where /-- Induction-head certificate. -/ cert : InductionHeadCert seq +/-- Cached certificate-related fields derived from score gaps. -/ +structure InductionHeadCertFields (seq : Nat) where + /-- Margin per query. -/ + marginAt : Fin seq → Rat + /-- Base weight bounds derived from score gaps. -/ + weightBoundAtBase : Fin seq → Fin seq → Rat + /-- Cached base weight bounds. -/ + weightBoundAtBaseCached : Fin seq → Fin seq → Rat + /-- Base epsilon per query. -/ + epsAtBase : Fin seq → Rat + /-- Epsilon per query. -/ + epsAt : Fin seq → Rat + /-- Per-key weight bounds derived from score gaps. -/ + weightBoundAt : Fin seq → Fin seq → Rat + /-- Global margin. -/ + margin : Rat + /-- Global epsilon. -/ + eps : Rat + +/-- Build certificate-related cached fields from score gaps. 
-/ +def buildInductionHeadCertFields [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (otherKeys : Fin seq → Finset (Fin seq)) + (scoreGapLo : Fin seq → Fin seq → Rat) : InductionHeadCertFields seq := by + let marginAt : Fin seq → Rat := fun q => + if hq : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreGapLo q k) + else + (0 : Rat) + else + (0 : Rat) + let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => + if hk : k = inputs.prev q then + (0 : Rat) + else + let gap := scoreGapLo q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap) + let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := + Bounds.cacheBound2Task weightBoundAtBase + let epsAtBase : Fin seq → Rat := fun q => + let other := otherKeys q + let total := other.sum (fun k => weightBoundAtBaseCached q k) + min (1 : Rat) total + let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase + let weightBoundAt : Fin seq → Fin seq → Rat := weightBoundAtBaseCached + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if h : inputs.active.Nonempty then + inputs.active.sup' h epsAt + else + (0 : Rat) + exact + { marginAt := marginAt + weightBoundAtBase := weightBoundAtBase + weightBoundAtBaseCached := weightBoundAtBaseCached + epsAtBase := epsAtBase + epsAt := epsAt + weightBoundAt := weightBoundAt + margin := margin + eps := eps } + +/-- Unfolding lemma for `buildInductionHeadCertFields`. 
-/ +theorem buildInductionHeadCertFields_def [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (otherKeys : Fin seq → Finset (Fin seq)) + (scoreGapLo : Fin seq → Fin seq → Rat) : + buildInductionHeadCertFields inputs otherKeys scoreGapLo = + (let marginAt : Fin seq → Rat := fun q => + if _ : q ∈ inputs.active then + let other := otherKeys q + if h : other.Nonempty then + other.inf' h (fun k => scoreGapLo q k) + else + (0 : Rat) + else + (0 : Rat) + let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => + if _ : k = inputs.prev q then + (0 : Rat) + else + let gap := scoreGapLo q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap) + let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := + Bounds.cacheBound2Task weightBoundAtBase + let epsAtBase : Fin seq → Rat := fun q => + let other := otherKeys q + let total := other.sum (fun k => weightBoundAtBaseCached q k) + min (1 : Rat) total + let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase + let weightBoundAt : Fin seq → Fin seq → Rat := weightBoundAtBaseCached + let margin : Rat := + if h : inputs.active.Nonempty then + inputs.active.inf' h marginAt + else + (0 : Rat) + let eps : Rat := + if h : inputs.active.Nonempty then + inputs.active.sup' h epsAt + else + (0 : Rat) + { marginAt := marginAt + weightBoundAtBase := weightBoundAtBase + weightBoundAtBaseCached := weightBoundAtBaseCached + epsAtBase := epsAtBase + epsAt := epsAt + weightBoundAt := weightBoundAt + margin := margin + eps := eps }) := by + rfl + +/-- The `eps` field of `buildInductionHeadCertFields` is the active supremum when nonempty. 
-/ +theorem buildInductionHeadCertFields_eps_eq [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (otherKeys : Fin seq → Finset (Fin seq)) + (scoreGapLo : Fin seq → Fin seq → Rat) : + (buildInductionHeadCertFields inputs otherKeys scoreGapLo).eps = + if h : inputs.active.Nonempty then + inputs.active.sup' h + (buildInductionHeadCertFields inputs otherKeys scoreGapLo).epsAt + else + (0 : Rat) := by + rfl + +/-- Build an induction-head certificate from cached fields. -/ +def inductionHeadCertOfCacheFields [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (eps : Rat) (epsAt : Fin seq → Rat) (weightBoundAt : Fin seq → Fin seq → Rat) + (margin : Rat) (valCert : ValueInterval seq) : InductionHeadCert seq := + { eps := eps + epsAt := epsAt + weightBoundAt := weightBoundAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } + +/-- Unfolding lemma for `inductionHeadCertOfCacheFields`. -/ +theorem inductionHeadCertOfCacheFields_def [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (eps : Rat) (epsAt : Fin seq → Rat) (weightBoundAt : Fin seq → Fin seq → Rat) + (margin : Rat) (valCert : ValueInterval seq) : + inductionHeadCertOfCacheFields inputs eps epsAt weightBoundAt margin valCert = + { eps := eps + epsAt := epsAt + weightBoundAt := weightBoundAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } := by + rfl + +/-- Build a value-interval certificate from per-query bounds. 
-/ +def buildInductionHeadValCert [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (valsLo valsHi : Fin seq → Rat) : ValueInterval seq := by + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + exact + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec } + +/-- Unfolding lemma for `buildInductionHeadValCert`. -/ +theorem buildInductionHeadValCert_def [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (valsLo valsHi : Fin seq → Rat) : + buildInductionHeadValCert inputs valsLo valsHi = + (let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + { lo := lo + hi := hi + valsLo := valsLo + valsHi := valsHi + direction := some inputs.directionSpec }) := by + rfl + +/-- `buildInductionHeadValCert` preserves the provided lower bounds. -/ +theorem buildInductionHeadValCert_valsLo [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (valsLo valsHi : Fin seq → Rat) : + (buildInductionHeadValCert inputs valsLo valsHi).valsLo = valsLo := by + rfl + +/-- `buildInductionHeadValCert` preserves the provided upper bounds. -/ +theorem buildInductionHeadValCert_valsHi [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (valsLo valsHi : Fin seq → Rat) : + (buildInductionHeadValCert inputs valsLo valsHi).valsHi = valsHi := by + rfl + +/-- `buildInductionHeadValCert` yields pointwise value bounds from interval bounds. 
-/ +theorem buildInductionHeadValCert_bounds_at [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (valsReal : Fin seq → Real) + (valsLo valsHi : Fin seq → Rat) + (hvals : ∀ k, (valsLo k : Real) ≤ valsReal k ∧ valsReal k ≤ (valsHi k : Real)) : + ∀ k, + ((buildInductionHeadValCert inputs valsLo valsHi).valsLo k : Real) ≤ valsReal k ∧ + valsReal k ≤ ((buildInductionHeadValCert inputs valsLo valsHi).valsHi k : Real) := by + intro k + have hprojLo : + (buildInductionHeadValCert inputs valsLo valsHi).valsLo = valsLo := by + exact buildInductionHeadValCert_valsLo (inputs := inputs) valsLo valsHi + have hprojHi : + (buildInductionHeadValCert inputs valsLo valsHi).valsHi = valsHi := by + exact buildInductionHeadValCert_valsHi (inputs := inputs) valsLo valsHi + have hlo : + ((buildInductionHeadValCert inputs valsLo valsHi).valsLo k : Real) = valsLo k := by + exact congrArg (fun r : Rat => (r : Real)) (congrArg (fun f => f k) hprojLo) + have hhi : + ((buildInductionHeadValCert inputs valsLo valsHi).valsHi k : Real) = valsHi k := by + exact congrArg (fun r : Rat => (r : Real)) (congrArg (fun f => f k) hprojHi) + simpa [hlo, hhi] using hvals k + +/-- `buildInductionHeadValCert` satisfies `ValueIntervalBounds` from pointwise value bounds. 
-/ +theorem buildInductionHeadValCert_bounds [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (valsReal : Fin seq → Real) + (valsLo valsHi : Fin seq → Rat) + (hvals : ∀ k, (valsLo k : Real) ≤ valsReal k ∧ valsReal k ≤ (valsHi k : Real)) : + ValueIntervalBounds (vals := valsReal) (buildInductionHeadValCert inputs valsLo valsHi) := by + let valCert := buildInductionHeadValCert inputs valsLo valsHi + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + refine + { lo_le_hi := ?_ + lo_le_valsLo := ?_ + vals_bounds := ?_ + valsHi_le_hi := ?_ } + · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ + have hmem0 : k0 ∈ univ := hk0 + have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by + have hloRat : valCert.lo ≤ valCert.valsLo k0 := by + change lo ≤ valsLo k0 + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k0)).2 ⟨k0, hmem0, le_rfl⟩ + simpa [ratToReal_def] using ratToReal_le_of_le hloRat + have hvals : (valCert.valsLo k0 : Real) ≤ valsReal k0 ∧ + valsReal k0 ≤ (valCert.valsHi k0 : Real) := by + exact buildInductionHeadValCert_bounds_at (inputs := inputs) + (valsReal := valsReal) (valsLo := valsLo) (valsHi := valsHi) hvals k0 + have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by + have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by + change valsHi k0 ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k0)).2 ⟨k0, ⟨hmem0, le_rfl⟩⟩ + simpa [ratToReal_def] using ratToReal_le_of_le hhiRat + have hreal : (valCert.lo : Real) ≤ (valCert.hi : Real) := + le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) + have hreal' : ratToReal valCert.lo ≤ ratToReal valCert.hi := by + simpa [ratToReal_def] using hreal + exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal' + · 
intro k + have hloRat : valCert.lo ≤ valCert.valsLo k := by + change lo ≤ valsLo k + dsimp [lo] + refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) + (f := valsLo) (a := valsLo k)).2 ⟨k, by simp [univ], le_rfl⟩ + simpa [ratToReal_def] using ratToReal_le_of_le hloRat + · exact buildInductionHeadValCert_bounds_at (inputs := inputs) + (valsReal := valsReal) (valsLo := valsLo) (valsHi := valsHi) hvals + · intro k + have hhiRat : valCert.valsHi k ≤ valCert.hi := by + change valsHi k ≤ hi + dsimp [hi] + refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) + (f := valsHi) (a := valsHi k)).2 ⟨k, by simp [univ], le_rfl⟩ + simpa [ratToReal_def] using ratToReal_le_of_le hhiRat + +/-- Build an induction-head certificate from cached fields and value bounds. -/ +def buildInductionHeadCert [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (certFields : InductionHeadCertFields seq) + (valCert : ValueInterval seq) : InductionHeadCert seq := + inductionHeadCertOfCacheFields inputs certFields.eps certFields.epsAt + certFields.weightBoundAt certFields.margin valCert + +/-- Unfolding lemma for `buildInductionHeadCert`. -/ +theorem buildInductionHeadCert_def [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (certFields : InductionHeadCertFields seq) + (valCert : ValueInterval seq) : + buildInductionHeadCert inputs certFields valCert = + inductionHeadCertOfCacheFields inputs certFields.eps certFields.epsAt + certFields.weightBoundAt certFields.margin valCert := by + rfl + +/-- `buildInductionHeadCert` preserves the provided value certificate. -/ +theorem buildInductionHeadCert_values [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (certFields : InductionHeadCertFields seq) + (valCert : ValueInterval seq) : + (buildInductionHeadCert inputs certFields valCert).values = valCert := by + rfl + /-- Build cached core quantities for induction-head certificates. 
-/ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} (cfg : InductionHeadSplitConfig) @@ -601,38 +918,52 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} let t := worstKeyArr[q.1]'(by simp [worstKeyArr, q.isLt]) Thunk.get t + let refineKeys : Fin seq → Finset (Fin seq) := + if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then + fun _ => ∅ + else + let refineKeysRaw : Fin seq → Finset (Fin seq) := fun q => + let base : Finset (Fin seq) := + match worstKey q with + | some k => {k} + | none => ∅ + if hq : q ∈ inputs.active then + let other := otherKeys q + base ∪ other.filter (fun k => decide (scoreGapLoBase q k < 0)) + else + base + let refineKeysArr : Array (Thunk (Finset (Fin seq))) := + Array.ofFn (fun q => Thunk.mk (fun _ => refineKeysRaw q)) + fun q => + let t := refineKeysArr[q.1]'(by + simp [refineKeysArr, q.isLt]) + Thunk.get t let dotDiffLoHi : (Fin seq → Fin seq → Rat) × (Fin seq → Fin seq → Rat) := if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then (dotDiffLoBase, dotDiffHiBase) else let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - | none => dotDiffLoBase q k + if hk : k ∈ refineKeys q then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then 
- let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 - else - dotDiffHiBase q k - | none => dotDiffHiBase q k + if hk : k ∈ refineKeys q then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k (dotDiffLo, dotDiffHi) let dotDiffLo : Fin seq → Fin seq → Rat := dotDiffLoHi.1 let dotDiffHi : Fin seq → Fin seq → Rat := dotDiffLoHi.2 @@ -647,48 +978,26 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} inputs.scale * dotDiffHi q k let scoreGapLo : Fin seq → Fin seq → Rat := Bounds.cacheBound2 scoreGapLoRaw - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreGapLo q k) - else - (0 : Rat) - else - (0 : Rat) - let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => - if hk : k = inputs.prev q then - (0 : Rat) - else - let gap := scoreGapLo q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap) - let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := - Bounds.cacheBound2Task weightBoundAtBase - let epsAtBase : Fin seq → Rat := fun q => - let other := otherKeys q - let total := other.sum (fun k => weightBoundAtBaseCached q k) - min (1 : Rat) total - let epsAt : Fin seq → Rat := - Bounds.cacheBoundThunk epsAtBase - let weightBoundAt : Fin seq → Fin seq → Rat := weightBoundAtBaseCached - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if h : 
inputs.active.Nonempty then - inputs.active.sup' h epsAt - else - (0 : Rat) + let certFields := buildInductionHeadCertFields inputs otherKeys scoreGapLo + let marginAt : Fin seq → Rat := certFields.marginAt + let weightBoundAtBase : Fin seq → Fin seq → Rat := certFields.weightBoundAtBase + let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := certFields.weightBoundAtBaseCached + let epsAtBase : Fin seq → Rat := certFields.epsAtBase + let epsAt : Fin seq → Rat := certFields.epsAt + let weightBoundAt : Fin seq → Fin seq → Rat := certFields.weightBoundAt + let margin : Rat := certFields.margin + let eps : Rat := certFields.eps let dirHeadVec := dirHeadVecOfInputs inputs let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := + let wvDirTask : Fin dModel → Rat := Bounds.cacheBoundTask (fun j => Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + let wvDirArr : Array Rat := Array.ofFn wvDirTask + let wvDir : Fin dModel → Rat := fun j => + wvDirArr[j.1]'(by + have hsize : wvDirArr.size = dModel := by + simp [wvDirArr] + simp [hsize, j.isLt]) let bDir : Rat := Linear.dotFin dHead dirHead (fun d => inputs.bv d) let valsLo : Fin seq → Rat := fun q => @@ -699,20 +1008,9 @@ def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} have hnonempty : univ.Nonempty := by simp [univ] let lo := univ.inf' hnonempty valsLo let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec } + let valCert : ValueInterval seq := buildInductionHeadValCert inputs valsLo valsHi let cert : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } + buildInductionHeadCert inputs certFields valCert exact { lnBounds := lnBounds lnLo := lnLo diff --git a/Nfp/Sound/Induction/CoreDefs.lean 
b/Nfp/Sound/Induction/CoreDefs.lean index 6cbf702..3444c49 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -226,6 +226,27 @@ structure InductionHeadCert (seq : Nat) where /-- Value-interval certificate for the direction values. -/ values : ValueInterval seq +/-- Extensionality lemma for `InductionHeadCert`. -/ +@[ext] theorem InductionHeadCert.ext {seq : Nat} {c₁ c₂ : InductionHeadCert seq} + (hε : c₁.eps = c₂.eps) + (hεAt : c₁.epsAt = c₂.epsAt) + (hweight : c₁.weightBoundAt = c₂.weightBoundAt) + (hmargin : c₁.margin = c₂.margin) + (hactive : c₁.active = c₂.active) + (hprev : c₁.prev = c₂.prev) + (hvalues : c₁.values = c₂.values) : + c₁ = c₂ := by + cases c₁ + cases c₂ + cases hε + cases hεAt + cases hweight + cases hmargin + cases hactive + cases hprev + cases hvalues + rfl + /-- Soundness predicate for `InductionHeadCert`. -/ structure InductionHeadCertSound [NeZero seq] {dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) : Prop where diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean index 09ce05c..19bde24 100644 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ b/Nfp/Sound/Induction/CoreSound/Basic.lean @@ -1,1368 +1,11 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -module - -import all Nfp.Sound.Induction.Core.Basic -public import Nfp.Sound.Induction.Core -public import Nfp.Sound.Induction.CoreSound.Values -public section +module -namespace Nfp -namespace Sound -open scoped BigOperators -open Nfp.Circuit -open Nfp.Sound.Bounds -variable {seq : Nat} -set_option maxHeartbeats 5000000 in --- The soundness proof expands many cached bounds; extra heartbeats avoid spurious timeouts. -set_option synthInstance.maxHeartbeats 200000 in --- Instance search also touches the expanded caches; allow more room to avoid timeouts. -/-- Soundness for `buildInductionCertFromHeadCoreWith?`. 
-/ -theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) - (hcore : buildInductionCertFromHeadCoreWith? cfg inputs = some c) : - InductionHeadCertSound inputs c := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : 0 < sqrtLower inputs.lnEps - · by_cases hmodel : dModel = 0 - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero - (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · by_cases hactive : inputs.active.Nonempty - · let lnBounds := Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Bounds.cacheBoundTask (fun q => - Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr : Array Rat := - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr[q.1]'(by - simp [lnAbsMaxArr]) - let invStdBoundsTasks : Array (Task (Rat × Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) - let invStdBoundsArr : Array (Rat × Rat) := - Array.ofFn (fun q : Fin seq => - (invStdBoundsTasks[q.1]'(by - simp [invStdBoundsTasks])).get) - let invStdLo : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - simp [invStdBoundsArr])).1 - let invStdHi : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - simp [invStdBoundsArr])).2 - let lnCoeff : Fin seq → Fin dModel → Rat := fun q j => - inputs.ln1Gamma j * (inputs.embed q j - mean (inputs.embed q)) - let invStd : Fin seq → Real := fun q => - (Real.sqrt 
((varianceRat (inputs.embed q) : Real) + (inputs.lnEps : Real)))⁻¹ - have hmeanRat : - ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by - intro q - have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by - simp [mean_def, hmodel, ratRoundDown_def] - simpa [ratToReal_def] using congrArg ratToReal hmu_rat - have hln_affine : - ∀ q j, - lnRealOfInputs inputs q j = - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - intro q j - have hmu := hmeanRat q - simp [lnRealOfInputs_def, Bounds.layerNormReal_def, hmodel, lnCoeff, hmu, invStd, - add_comm, mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] - have hln_fun : - ∀ q, - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - intro q - funext j - exact hln_affine q j - have hinv_bounds : - ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by - intro q - simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, - Bounds.invStdBounds_def, Task.spawn, Array.getElem_ofFn] using - (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) - hmodel hEps hSqrt) - let qBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + - inputs.bq d) - let qBase : Fin dHead → Rat := fun d => - qBaseArr[d.1]'(by - simp [qBaseArr]) - let kBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + - inputs.bk d) - let kBase : Fin dHead → Rat := fun d => - kBaseArr[d.1]'(by - simp [kBaseArr]) - let coeffRowTasks : - (Fin dModel → Fin dHead → Rat) → - Array (Task { row : Array Rat // row.size = dHead }) := - fun w => - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin 
dModel (fun j => w j d) coeff), - by simp⟩)) - let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - coeffRowTasks inputs.wq - let qCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qCoeffRowTasks[q.1]'(by - simp [qCoeffRowTasks, coeffRowTasks])).get) - let qCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := qCoeffArr[q.1]'(by - simp [qCoeffArr]) - row.1[d.1]'(by - simp [row.2]) - let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - coeffRowTasks inputs.wk - let kCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kCoeffRowTasks[q.1]'(by - simp [kCoeffRowTasks, coeffRowTasks])).get) - let kCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := kCoeffArr[q.1]'(by - simp [kCoeffArr]) - row.1[d.1]'(by - simp [row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.1 - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.2 - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.1 - let kHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.2 - let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| - let qAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => qAbs q d)) - let qAbsMax : Fin dHead → Rat := fun d => - qAbsMaxArr[d.1]'(by - simp [qAbsMaxArr]) - let kAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) 
:= Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d)) - let kAbsMax : Fin dHead → Rat := fun d => - kAbsMaxArr[d.1]'(by - simp [kAbsMaxArr]) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := cfg.splitBudgetQ - let splitBudgetK : Nat := cfg.splitBudgetK - let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase - let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined - let top2ByScore : - (Fin dHead → Rat) → List (Fin dHead) → List (Fin dHead) := fun score ambig => - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let finRangeHead : List (Fin dHead) := List.finRange dHead - let finRangeSeq : List (Fin seq) := List.finRange seq - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - if splitBudgetQ = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let dims1 := top2ByScore score ambig - let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ - let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - if splitBudgetK = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < 
kHi k d)) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let dims1 := top2ByScore score ambig - let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK - let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => - if budget = 0 then - [] - else - let prev := inputs.prev q - let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d - let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d - let ambig := - finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) - let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 : List (Fin dHead) := top2 ambig - let dims2 : List (Fin dHead) := - top2 (ambig.filter (fun d => decide (d ∉ dims1))) - let memDims2 : Fin dHead → Bool := fun d => - dims2.any (fun d' => decide (d' = d)) - let dims3 : List (Fin dHead) := - top2 - ((ambig.filter (fun d => decide (d ∉ dims1))).filter - (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take budget - let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffBase - let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore 
splitBudgetDiffRefined - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if hq : q ∈ inputs.active then - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsDiff := splitDimsDiffBase q k - let prev := inputs.prev q - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)), - by simp⟩ - else - ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - 
|inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLoBase q k - else - inputs.scale * dotDiffHiBase q k - let scoreGapLoBase : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoBaseRaw - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKey : Fin seq → Option (Fin seq) := - if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then - fun _ => none - else - let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => - if hq : q ∈ inputs.active then - let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (scoreGapLoBase q k, k)).2 - else - none - let worstKeyArr : Array (Thunk (Option (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) - fun q => - let t := worstKeyArr[q.1]'(by - simp [worstKeyArr, q.isLt]) - Thunk.get t - let dotDiffLoHi : (Fin seq → Fin seq → Rat) × (Fin seq → Fin seq → Rat) := - if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then - (dotDiffLoBase, dotDiffHiBase) - else - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with 
- | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - | none => dotDiffLoBase q k - let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - match worstKey q with - | some k' => - if hk : k = k' then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 - else - dotDiffHiBase q k - | none => dotDiffHiBase q k - (dotDiffLo, dotDiffHi) - let dotDiffLo : Fin seq → Fin seq → Rat := dotDiffLoHi.1 - let dotDiffHi : Fin seq → Fin seq → Rat := dotDiffLoHi.2 - let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLo q k - else - inputs.scale * dotDiffHi q k - let scoreGapLo : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoRaw - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreGapLo q k) - else - (0 : Rat) - else - (0 : Rat) - let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => - if hk : k = inputs.prev q then - (0 : Rat) - else - let gap := scoreGapLo q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap) - let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := - Bounds.cacheBound2Task weightBoundAtBase - let epsAtBase : Fin seq → Rat := fun q => - let other := otherKeys q - let total := other.sum (fun k => 
weightBoundAtBaseCached q k) - min (1 : Rat) total - let epsAt : Fin seq → Rat := - Bounds.cacheBoundThunk epsAtBase - let weightBoundAt : Fin seq → Fin seq → Rat := weightBoundAtBaseCached - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if h : inputs.active.Nonempty then - inputs.active.sup' h epsAt - else - (0 : Rat) - have hseq : (1 : Nat) ≤ seq := - Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) - let dirHeadVec := dirHeadVecOfInputs inputs - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := - Bounds.cacheBoundTask (fun j => - Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsLo : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalLower (fun j => wvDir j) (lnLo q) (lnHi q) - let valsHi : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo q) (lnHi q) - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec } - let cert : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } - have hcore' : buildInductionCertFromHeadCoreWith? cfg inputs = some cert := by - have hcore'' : - buildInductionCertFromHeadCoreWith? 
cfg inputs = - some (buildInductionHeadCoreCacheWith cfg inputs).cert := - buildInductionCertFromHeadCoreWith?_eq_some - (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive - simpa [buildInductionHeadCoreCacheWith_cert_eq] using hcore'' - have hc : c = cert := by - have hcert : cert = c := by - exact Option.some.inj (hcore'.symm.trans hcore) - simpa using hcert.symm - subst hc - have hln_bounds : - ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ - lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by - intro q i - have hln := - Bounds.layerNormBounds_spec (eps := inputs.lnEps) - (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) - (x := inputs.embed q) hmodel hEps hSqrt - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs_def, - Bounds.cacheBoundPair2_apply_left, Bounds.cacheBoundPair2_apply_right] using - hln i - have dotFin_cast {n : Nat} (f g : Fin n → Rat) : - (Linear.dotFin n f g : Real) = - dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by - simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] - have proj_bounds - (w : Fin dModel → Fin dHead → Rat) - (b base : Fin dHead → Rat) - (coeff : Fin seq → Fin dHead → Rat) - (hbase : ∀ d, - (base d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (b d : Real)) - (hcoeff : ∀ q d, - (coeff q d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real))) : - ∀ q d, - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ∧ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ≤ - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by - intro q d - have hinv : (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := - hinv_bounds q - have hln_fun_q : - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - exact 
hln_fun q - have hdot_add : - dotProduct (fun j => (w j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real) * invStd q) := by - simpa using - (Nfp.Sound.Linear.dotProduct_add_right - (x := fun j => (w j d : Real)) - (y := fun j => (inputs.ln1Beta j : Real)) - (z := fun j => (lnCoeff q j : Real) * invStd q)) - have hdot_coeff : - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real) * invStd q) = - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q := by - simpa using - (Nfp.Sound.Linear.dotProduct_mul_right - (x := fun j => (w j d : Real)) - (y := fun j => (lnCoeff q j : Real)) - (a := invStd q)) - have hreal : - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) = - (base d : Real) + (coeff q d : Real) * invStd q := by - calc - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) + - (b d : Real) := by - simp [hln_fun_q] - _ = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q + - (b d : Real) := by - simp [hdot_add, hdot_coeff, add_assoc] - _ = - (dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (b d : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q := by ac_rfl - _ = (base d : Real) + (coeff q d : Real) * invStd q := by - simp [hbase, hcoeff] - have hscale : - let bounds := scaleInterval (coeff q d) (invStdLo q) (invStdHi q) - (bounds.1 : Real) ≤ (coeff q d : Real) * invStd q ∧ - (coeff q d : Real) * invStd q ≤ (bounds.2 : Real) := by - exact scaleInterval_bounds_real (x 
:= coeff q d) (lo := invStdLo q) - (hi := invStdHi q) (y := invStd q) hinv.1 hinv.2 - have hlow : - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) := by - simpa [hreal] using add_le_add_left hscale.1 (base d : Real) - have hhigh : - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ≤ - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by - simpa [hreal] using add_le_add_left hscale.2 (base d : Real) - exact ⟨hlow, hhigh⟩ - have hq_bounds : - ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by - intro q d - have hbase : - ∀ d, - (qBase d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bq d : Real) := by - intro d - simp [qBase, qBaseArr, dotFin_cast] - have hcoeff : - ∀ q' d, - (qCoeff q' d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (lnCoeff q' j : Real)) := by - intro q' d - simpa [qCoeff, qCoeffArr, qCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using - (dotFin_cast (f := fun j => inputs.wq j d) - (g := fun j => - inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) - have h := proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) - (coeff := qCoeff) hbase hcoeff q d - simpa [qLo, qHi, qRealOfInputs_def] using h - have hk_bounds : - ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ (kHi q d : Real) := by - intro q d - have hbase : - ∀ d, - (kBase d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bk d : Real) := by - intro d - simp [kBase, kBaseArr, dotFin_cast] - have hcoeff : - ∀ q' d, - (kCoeff q' d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (lnCoeff q' j : Real)) := by - intro q' d - simpa [kCoeff, kCoeffArr, 
kCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using - (dotFin_cast (f := fun j => inputs.wk j d) - (g := fun j => - inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) - have h := proj_bounds (w := inputs.wk) (b := inputs.bk) (base := kBase) - (coeff := kCoeff) hbase hcoeff q d - simpa [kLo, kHi, kRealOfInputs_def] using h - let scoresReal := scoresRealOfInputs inputs - have scoresReal_eq_base_of_not_masked : - ∀ q k, ¬ masked q k → - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - intro q k hnot - by_cases hcausal : inputs.maskCausal - · have hnot_lt : ¬ q < k := by - intro hlt - exact hnot ⟨hcausal, hlt⟩ - have hle : k ≤ q := le_of_not_gt hnot_lt - simp [scoresReal, scoresRealOfInputs_def, hcausal, hle] - · simp [scoresReal, scoresRealOfInputs_def, hcausal] - have scoresReal_eq_masked : - ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by - intro q k hmask - have hmask' : inputs.maskCausal = true ∧ q < k := by - simpa [masked] using hmask - have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 - simp [scoresReal, scoresRealOfInputs_def, hmask'.1, hle] - have hscore_bounds : - ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ - scoresReal q k ≤ (scoreHi q k : Real) := by - intro q k - let base := - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - have hdot_bounds (hnot : ¬ masked q k) : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - have hq := hq_bounds q - have hk := hk_bounds k - have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq d).1 - have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq d).2 - have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d 
:= fun d => - (hk d).1 - have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => - (hk d).2 - have hspec := - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := splitDimsQ q) (dims2 := splitDimsK q k) - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hlow' : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] - using hspec.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] - using hspec.2 - exact ⟨hlow', hhigh'⟩ - have hscore_base_bounds (hnot : ¬ masked q k) : - (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - by - simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale - have hdot := hdot_bounds hnot - have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real - have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real - constructor - · simpa [scoreLo, masked, hnot, hscale, base] using hlow - · simpa [scoreHi, masked, hnot, hscale, base] using hhigh - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := - by - simpa [ratToReal_def] using - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hdot := hdot_bounds hnot - have hlow := mul_le_mul_of_nonpos_left hdot.2 hscale_real - have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real - constructor - · simpa [scoreLo, masked, hnot, hscale, base] using hlow - · simpa [scoreHi, masked, hnot, hscale, base] using hhigh - by_cases 
hcausal : inputs.maskCausal - · by_cases hle : k ≤ q - · have hnot : ¬ q < k := not_lt_of_ge hle - have hnot_masked : ¬ masked q k := fun hmk => hnot hmk.2 - have hscore_eq : scoresReal q k = base := - scoresReal_eq_base_of_not_masked q k hnot_masked - have hbase := hscore_base_bounds hnot_masked - constructor - · simpa [hscore_eq] using hbase.1 - · simpa [hscore_eq] using hbase.2 - · have hlt : q < k := lt_of_not_ge hle - have hmask : masked q k := ⟨hcausal, hlt⟩ - have hscore : scoresReal q k = (inputs.maskValue : Real) := - scoresReal_eq_masked q k hmask - constructor - · simp [hscore, scoreLo, hmask] - · simp [hscore, scoreHi, hmask] - · have hnot_masked : ¬ masked q k := by - simp [masked, hcausal] - have hscore_eq : scoresReal q k = base := - scoresReal_eq_base_of_not_masked q k hnot_masked - have hbase := hscore_base_bounds hnot_masked - constructor - · simpa [hscore_eq] using hbase.1 - · simpa [hscore_eq] using hbase.2 - have hdot_diff_bounds : - ∀ q, q ∈ inputs.active → ∀ k, ¬ masked q k → - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - intro q hq k hmask - have hq_bounds' := hq_bounds q - have hkprev := hk_bounds (inputs.prev q) - have hk := hk_bounds k - have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq_bounds' d).1 - have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq_bounds' d).2 - have hlo2 : - ∀ d, - (kLo (inputs.prev q) d - kHi k d : Rat) ≤ - (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by - intro d - have hprev_lo := (hkprev d).1 - have hk_hi := (hk d).2 - have h := sub_le_sub hprev_lo hk_hi - simpa [ratToReal_sub] using h - have hhi2 : - ∀ d, - (kRealOfInputs inputs (inputs.prev q) d - 
kRealOfInputs inputs k d) ≤ - (kHi (inputs.prev q) d - kLo k d : Rat) := by - intro d - have hprev_hi := (hkprev d).2 - have hk_lo := (hk d).1 - have h := sub_le_sub hprev_hi hk_lo - simpa [ratToReal_sub] using h - have hspec (dimsDiff : List (Fin dHead)) := - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := splitDimsQ q) (dims2 := dimsDiff) - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo (inputs.prev q) d - kHi k d) - (hi2 := fun d => kHi (inputs.prev q) d - kLo k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => - kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hspecBase := hspec (splitDimsDiffBase q k) - have hspecRef := hspec (splitDimsDiffRefined q k) - have hspecBase_bounds : - (dotDiffLoBase q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHiBase q k : Real) := by - refine ⟨?_, ?_⟩ - · simpa [dotDiffLoBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, - Array.getElem_ofFn] using hspecBase.1 - · simpa [dotDiffHiBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, - Array.getElem_ofFn] using hspecBase.2 - by_cases hbudget : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase - · simpa [dotDiffLo, dotDiffHi, dotDiffLoHi, hbudget] using hspecBase_bounds - · cases hkey : worstKey q with - | none => - simpa [dotDiffLo, dotDiffHi, dotDiffLoHi, hbudget, hkey] using - hspecBase_bounds - | some k' => - by_cases hk : k = k' - · have hlow' : - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, dotDiffLoHi, hbudget, hkey, hk] using hspecRef.1 - have hhigh' : - dotProduct (fun d => 
qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, dotDiffLoHi, hbudget, hkey, hk] using hspecRef.2 - exact ⟨hlow', hhigh'⟩ - · have hlow' : - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, dotDiffLoHi, hbudget, hkey, hk] using - hspecBase_bounds.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, dotDiffLoHi, hbudget, hkey, hk] using - hspecBase_bounds.2 - exact ⟨hlow', hhigh'⟩ - have hmarginAt_le : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - marginAt q ≤ scoreGapLo q k := by - intro q hq k hk - have hmem : k ∈ otherKeys q := by simp [otherKeys, hk] - have hnonempty : (otherKeys q).Nonempty := ⟨k, hmem⟩ - have hle : - (otherKeys q).inf' hnonempty (fun k => scoreGapLo q k) ≤ scoreGapLo q k := by - exact (Finset.inf'_le_iff (s := otherKeys q) (H := hnonempty) - (f := fun k => scoreGapLo q k) (a := scoreGapLo q k)).2 - ⟨k, hmem, le_rfl⟩ - simpa [marginAt, hq, hnonempty] using hle - have hscore_gap_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - by_cases hprevmask : masked q (inputs.prev q) - · have hscore_hi : scoresReal q k ≤ (scoreHi q k : Real) := - (hscore_bounds q k).2 - have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by - have hprev_bounds := hscore_bounds q (inputs.prev q) - simpa [scoreLoPrev] using hprev_bounds.1 - have hsum_le' : - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ - (scoreLoPrev q : Real) := by - have hsub : - (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ - (scoreLoPrev q : Real) - scoresReal q 
k := - sub_le_sub_left hscore_hi (scoreLoPrev q : Real) - calc - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k - ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by - simpa [add_comm, add_left_comm, add_assoc] using - (add_le_add_left hsub (scoresReal q k)) - _ = (scoreLoPrev q : Real) := by - simp [sub_add_cancel] - calc - scoresReal q k + (scoreGapLo q k : Real) - = (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by - simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, add_comm] - _ ≤ (scoreLoPrev q : Real) := hsum_le' - _ ≤ scoresReal q (inputs.prev q) := hscore_prev - · by_cases hmask : masked q k - · have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by - have hprev_bounds := hscore_bounds q (inputs.prev q) - simpa [scoreLoPrev] using hprev_bounds.1 - have hscore_k : scoresReal q k = (inputs.maskValue : Real) := - scoresReal_eq_masked q k hmask - calc - scoresReal q k + (scoreGapLo q k : Real) - = (inputs.maskValue : Real) + (scoreLoPrev q : Real) - - (inputs.maskValue : Real) := by - simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscore_k] - _ = (scoreLoPrev q : Real) := by - simp [add_sub_cancel_left] - _ ≤ scoresReal q (inputs.prev q) := hscore_prev - · have hdiff := hdot_diff_bounds q hq k hmask - have hgap_le : - (scoreGapLo q k : Real) ≤ - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - by - simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale - have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real - simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscale] using hle - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 
:= - by - simpa [ratToReal_def] using - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real - simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscale] using hle - have hscore_prev : - scoresReal q (inputs.prev q) = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) := by - simpa using - (scoresReal_eq_base_of_not_masked q (inputs.prev q) hprevmask) - have hscore_k : - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa using (scoresReal_eq_base_of_not_masked q k hmask) - have hdot_sub : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) = - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - classical - simpa using - (Nfp.Sound.Linear.dotProduct_sub_right - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs (inputs.prev q) d) - (z := fun d => kRealOfInputs inputs k d)) - have hscore_diff : - scoresReal q (inputs.prev q) - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - calc - scoresReal q (inputs.prev q) - scoresReal q k - = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simp [hscore_prev, hscore_k] - _ = - (inputs.scale : Real) * - (dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) 
d) - - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)) := by - simp [mul_sub] - _ = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simp [hdot_sub] - have hgap_le' : - (scoreGapLo q k : Real) ≤ - scoresReal q (inputs.prev q) - scoresReal q k := by - simpa [hscore_diff] using hgap_le - have hgap_add := - add_le_add_right hgap_le' (scoresReal q k) - have hgap_add' : - scoresReal q k + (scoreGapLo q k : Real) ≤ - scoresReal q (inputs.prev q) := by - have hcancel : - scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) = - scoresReal q (inputs.prev q) := by - calc - scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) - = - scoresReal q k + scoresReal q (inputs.prev q) - - scoresReal q k := by - symm - exact add_sub_assoc (scoresReal q k) - (scoresReal q (inputs.prev q)) (scoresReal q k) - _ = scoresReal q (inputs.prev q) := by - simp [add_sub_cancel_left] - calc - scoresReal q k + (scoreGapLo q k : Real) - ≤ scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) := - hgap_add - _ = scoresReal q (inputs.prev q) := hcancel - exact hgap_add' - let softmaxWeights := Circuit.softmaxWeights scoresReal - let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax (scoresReal q) k - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - have hscore_margin_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hmargin_le : marginAt q ≤ scoreGapLo q k := - hmarginAt_le q hq k hk - have hmargin_le_real : (marginAt q : Real) ≤ (scoreGapLo q k : Real) := - by - simpa [ratToReal_def] using ratToReal_le_of_le hmargin_le - have hscore_gap := hscore_gap_real_at q hq k hk - have hstep := add_le_add_right hmargin_le_real 
(scoresReal q k) - have hstep' : - scoresReal q k + (marginAt q : Real) ≤ - scoresReal q k + (scoreGapLo q k : Real) := by - exact hstep - exact hstep'.trans hscore_gap - have hscore_margin_real : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hmargin_le : margin ≤ marginAt q := by - have hmem : q ∈ inputs.active := hq - have hnonempty : inputs.active.Nonempty := hactive - have hle := - (Finset.inf'_le_iff (s := inputs.active) (H := hnonempty) - (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ - simpa [margin, hnonempty] using hle - have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := - by - simpa [ratToReal_def] using ratToReal_le_of_le hmargin_le - have hscore := hscore_margin_real_at q hq k hk - have hscore' : - (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by - simpa [add_comm] using hscore - have hstep := add_le_add_right hmargin_le_real (scoresReal q k) - have hstep' : - scoresReal q k + (margin : Real) ≤ (marginAt q : Real) + scoresReal q k := by - calc - scoresReal q k + (margin : Real) ≤ scoresReal q k + (marginAt q : Real) := hstep - _ = (marginAt q : Real) + scoresReal q k := by - simp [add_comm] - exact hstep'.trans hscore' - have hweightBoundAtBase : - ∀ q k, k ≠ inputs.prev q → - weightBoundAtBase q k = - if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k) := by - intro q k hk - simp [weightBoundAtBase, hk] - have hweightBoundAt : - ∀ q k, - weightBoundAt q k = weightBoundAtBase q k := by - intro q k - simp [weightBoundAt, weightBoundAtBaseCached, Bounds.cacheBound2Task_apply] - have hepsAt : - ∀ q, epsAt q = - min (1 : Rat) - ((otherKeys q).sum (fun k => - if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k))) := by - intro q - have hsum : - (otherKeys q).sum (fun k => weightBoundAtBaseCached q k) = - (otherKeys q).sum (fun k => weightBoundAtBase q k) := 
by - refine Finset.sum_congr rfl ?_ - intro k hk - simp [weightBoundAtBaseCached, Bounds.cacheBound2Task_apply] - have hsum' : - (otherKeys q).sum (fun k => weightBoundAtBase q k) = - (otherKeys q).sum (fun k => - if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k)) := by - refine Finset.sum_congr rfl ?_ - intro k hk - have hk' : k ≠ inputs.prev q := (Finset.mem_erase.mp hk).1 - simp [hweightBoundAtBase q k hk'] - simpa [epsAt, epsAtBase, hsum, hsum'] using - (Bounds.cacheBoundThunk_apply (f := epsAtBase) q) - have oneHot_bounds_at : - ∀ q, q ∈ inputs.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) - (fun q' => q' = q) inputs.prev weights := by - intro q hq - exact - Sound.oneHot_bounds_at_of_scoreGapLo - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (scoreGapLo := scoreGapLo) - (epsAt := epsAt) - (hepsAt := hepsAt) - (hscore_gap_real_at := hscore_gap_real_at) - q hq - have weight_bounds_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - weights q k ≤ (weightBoundAt q k : Real) := by - intro q hq k hk - have hbound_base : - weights q k ≤ (weightBoundAtBase q k : Real) := by - exact - Sound.weight_bound_at_of_scoreGapLo - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (scoreGapLo := scoreGapLo) - (weightBoundAt := weightBoundAtBase) - (hweightBoundAt := hweightBoundAtBase) - (hscore_gap_real_at := hscore_gap_real_at) - q hq k hk - have hweightBoundAt_real : - (weightBoundAt q k : Real) = - (weightBoundAtBase q k : Real) := by - have hbase : weightBoundAt q k = weightBoundAtBase q k := - hweightBoundAt q k - have hbase' : - ratToReal (weightBoundAt q k) = ratToReal (weightBoundAtBase q k) := - congrArg ratToReal hbase - simpa [ratToReal_def] using hbase' - simpa [hweightBoundAt_real] using hbound_base - have hepsAt_le_eps : - ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by - intro q hq - have hle : - epsAt q ≤ inputs.active.sup' hactive 
epsAt := by - exact - (Finset.le_sup'_iff (s := inputs.active) (H := hactive) - (f := epsAt) (a := epsAt q)).2 ⟨q, hq, le_rfl⟩ - simpa [eps, hactive] using hle - have hepsAt_le_eps_real : - ∀ q, q ∈ inputs.active → (epsAt q : Real) ≤ (eps : Real) := by - intro q hq - simpa [ratToReal_def] using ratToReal_le_of_le (hepsAt_le_eps q hq) - have hsoftmax_bounds : - Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) - (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by - classical - refine - { score_margin := ?_ - nonneg := ?_ - sum_one := ?_ - prev_large := ?_ - other_le := ?_ } - · intro q hq k hk - exact hscore_margin_real q hq k hk - · intro q _ k - simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using - softmaxWeights.nonneg q k - · intro q _ - simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using - softmaxWeights.sum_one q - · intro q hq - have honehot := oneHot_bounds_at q hq - have hprev := honehot.prev_large q rfl - have hle : - weights q (inputs.prev q) + (epsAt q : Real) ≤ - weights q (inputs.prev q) + (eps : Real) := by - simpa [add_comm] using - (add_le_add_right (hepsAt_le_eps_real q hq) (weights q (inputs.prev q))) - exact hprev.trans hle - · intro q hq k hk - have honehot := oneHot_bounds_at q hq - have hother := honehot.other_le q rfl k hk - exact hother.trans (hepsAt_le_eps_real q hq) - have hdirHead : - dirHead = fun d => (dirHeadVecOfInputs inputs).get d := by - simp [dirHead, dirHeadVec] - have hwvDir : - ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := by - intro j - simp [wvDir, Bounds.cacheBoundTask_apply] - have hbDir : - bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d) := by - rfl - have hdir_wv : - ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := - wvDir_real_eq_sum inputs dirHead wvDir hwvDir - have hdir_bv : - (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := - bDir_real_eq_sum inputs dirHead bDir hbDir 
- have hvals_eq : - ∀ k, - valsRealOfInputs inputs k = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - (bDir : Real) := - valsReal_eq_of_dir inputs dirHead hdirHead wvDir bDir hdir_wv hdir_bv - have hvals_bounds_at : - ∀ k, - (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k ∧ - valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - intro k - have hln := hln_bounds k - have hlo : ∀ j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j := fun j => - (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs k j ≤ (lnHi k j : Real) := fun j => - (hln j).2 - have hlow' : - (Bounds.dotIntervalLower (fun j => wvDir j) (lnLo k) (lnHi k) : Real) + - (bDir : Real) ≤ - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - (bDir : Real) := by - simpa using - (Bounds.dotIntervalLower_le_dotProduct_real_add - (v := fun j => wvDir j) - (lo := lnLo k) (hi := lnHi k) - (x := lnRealOfInputs inputs k) (b := (bDir : Real)) hlo hhi) - have hhigh' : - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - (bDir : Real) ≤ - (Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo k) (lnHi k) : Real) + - (bDir : Real) := by - simpa using - (Bounds.dotProduct_le_dotIntervalUpper_real_add - (v := fun j => wvDir j) - (lo := lnLo k) (hi := lnHi k) - (x := lnRealOfInputs inputs k) (b := (bDir : Real)) hlo hhi) - have hlow : - (valCert.valsLo k : Real) ≤ valsRealOfInputs inputs k := by - simpa [valCert, valsLo, hvals_eq k, ratToReal_add, add_comm, add_left_comm, - add_assoc] using hlow' - have hhigh : - valsRealOfInputs inputs k ≤ (valCert.valsHi k : Real) := by - simpa [valCert, valsHi, hvals_eq k, ratToReal_add, add_comm, add_left_comm, - add_assoc] using hhigh' - exact ⟨hlow, hhigh⟩ - have hvals_bounds : - ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := by - refine - { lo_le_hi := ?_ - lo_le_valsLo := ?_ - vals_bounds := ?_ - valsHi_le_hi := ?_ } - · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ - have hmem0 : k0 ∈ 
univ := hk0 - have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by - have hloRat : valCert.lo ≤ valCert.valsLo k0 := by - change lo ≤ valsLo k0 - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k0)).2 ⟨k0, hmem0, le_rfl⟩ - simpa [ratToReal_def] using ratToReal_le_of_le hloRat - have hvals : - (valCert.valsLo k0 : Real) ≤ valsRealOfInputs inputs k0 ∧ - valsRealOfInputs inputs k0 ≤ (valCert.valsHi k0 : Real) := by - exact hvals_bounds_at k0 - have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by - have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by - change valsHi k0 ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k0)).2 ⟨k0, ⟨hmem0, le_rfl⟩⟩ - simpa [ratToReal_def] using ratToReal_le_of_le hhiRat - have hreal : - (valCert.lo : Real) ≤ (valCert.hi : Real) := - le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) - have hreal' : ratToReal valCert.lo ≤ ratToReal valCert.hi := by - simpa [ratToReal_def] using hreal - exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal' - · intro k - have hloRat : valCert.lo ≤ valCert.valsLo k := by - change lo ≤ valsLo k - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k)).2 ⟨k, by simp [univ], le_rfl⟩ - simpa [ratToReal_def] using ratToReal_le_of_le hloRat - · intro k - exact hvals_bounds_at k - · intro k - have hhiRat : valCert.valsHi k ≤ valCert.hi := by - change valsHi k ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k)).2 ⟨k, by simp [univ], le_rfl⟩ - simpa [ratToReal_def] using ratToReal_le_of_le hhiRat - exact - { softmax_bounds := hsoftmax_bounds - oneHot_bounds_at := oneHot_bounds_at - weight_bounds_at := weight_bounds_at - value_bounds := hvals_bounds } - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_active - (cfg := cfg) (inputs 
:= inputs) hEps hSqrt hmodel hactive - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt - (cfg := cfg) (inputs := inputs) hEps hSqrt - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps - (cfg := cfg) (inputs := inputs) hEps - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim +public import Nfp.Sound.Induction.CoreSound.Basic.CertSound +public import Nfp.Sound.Induction.CoreSound.Basic.CacheBounds +public import Nfp.Sound.Induction.CoreSound.Basic.DefaultSound -/-- Soundness for `buildInductionCertFromHeadCore?`. -/ -theorem buildInductionCertFromHeadCore?_sound - [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) - (hcore : buildInductionCertFromHeadCore? inputs = some c) : - InductionHeadCertSound inputs c := by - have hcore' : - buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs = some c := by - simpa [buildInductionCertFromHeadCore?_def] using hcore - exact - buildInductionCertFromHeadCoreWith?_sound - (cfg := defaultInductionHeadSplitConfig) inputs c hcore' -end Sound -end Nfp +/-! +Core soundness proofs for induction-head certificates. 
+-/ diff --git a/Nfp/Sound/Induction/CoreSound/Basic/CacheBounds.lean b/Nfp/Sound/Induction/CoreSound/Basic/CacheBounds.lean new file mode 100644 index 0000000..626a97f --- /dev/null +++ b/Nfp/Sound/Induction/CoreSound/Basic/CacheBounds.lean @@ -0,0 +1,615 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later +module + +import all Nfp.Sound.Induction.Core.Basic +public import Nfp.Sound.Induction.Core +public import Nfp.Sound.Induction.CoreSound.Values + +public section + +namespace Nfp +namespace Sound +open scoped BigOperators +open Nfp.Circuit +open Nfp.Sound.Bounds +variable {seq : Nat} +/-- Bounds for cached projections and scores from `buildInductionHeadCoreCacheWith`. -/ +theorem buildInductionHeadCoreCacheWith_bounds + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) : + (∀ q d, + ((buildInductionHeadCoreCacheWith cfg inputs).qLo q d : Real) ≤ + qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ + ((buildInductionHeadCoreCacheWith cfg inputs).qHi q d : Real)) ∧ + (∀ q d, + ((buildInductionHeadCoreCacheWith cfg inputs).kLo q d : Real) ≤ + kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ + ((buildInductionHeadCoreCacheWith cfg inputs).kHi q d : Real)) ∧ + (∀ q k, + ((buildInductionHeadCoreCacheWith cfg inputs).scoreLo q k : Real) ≤ + scoresRealOfInputs inputs q k ∧ + scoresRealOfInputs inputs q k ≤ + ((buildInductionHeadCoreCacheWith cfg inputs).scoreHi q k : Real)) := by + classical + set cache := buildInductionHeadCoreCacheWith cfg inputs with hcache + have dotFin_cast {n : Nat} (f g : Fin n → Rat) : + (Linear.dotFin n f g : Real) = + dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by + simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] + let lnCoeff : Fin seq → Fin dModel → Rat := fun q j => + inputs.ln1Gamma j * (inputs.embed q j - mean (inputs.embed q)) + let 
invStd : Fin seq → Real := fun q => + (Real.sqrt ((varianceRat (inputs.embed q) : Real) + (inputs.lnEps : Real)))⁻¹ + have hmeanRat : + ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by + intro q + have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by + simp [mean_def, hmodel, ratRoundDown_def] + simpa [ratToReal_def] using congrArg ratToReal hmu_rat + have hln_affine : + ∀ q j, + lnRealOfInputs inputs q j = + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + intro q j + have hmu := hmeanRat q + simp [lnRealOfInputs_def, Bounds.layerNormReal_def, hmodel, lnCoeff, hmu, invStd, + add_comm, mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] + have hln_fun : + ∀ q, + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + intro q + funext j + exact hln_affine q j + let invStdBoundsTasks : Array (Task (Rat × Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) + let invStdBoundsArr : Array (Rat × Rat) := + Array.ofFn (fun q : Fin seq => + (invStdBoundsTasks[q.1]'(by + simp [invStdBoundsTasks])).get) + let invStdLo : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + simp [invStdBoundsArr])).1 + let invStdHi : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + simp [invStdBoundsArr])).2 + have hinv_bounds : + ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by + intro q + simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, + Bounds.invStdBounds_def, Task.spawn, Array.getElem_ofFn] using + (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) + hmodel hEps hSqrt) + have proj_bounds + (w : Fin dModel → Fin dHead → Rat) + (b base : Fin dHead → Rat) + (coeff : Fin seq → Fin dHead → Rat) + (hbase : ∀ d, + (base d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (b d : Real)) + (hcoeff : ∀ q d, + (coeff q d 
: Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real))) : + ∀ q d, + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ∧ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ≤ + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by + intro q d + have hinv : (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := + hinv_bounds q + have hln_fun_q : + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + exact hln_fun q + have hdot_add : + dotProduct (fun j => (w j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real) * invStd q) := by + simpa using + (Nfp.Sound.Linear.dotProduct_add_right + (x := fun j => (w j d : Real)) + (y := fun j => (inputs.ln1Beta j : Real)) + (z := fun j => (lnCoeff q j : Real) * invStd q)) + have hdot_coeff : + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real) * invStd q) = + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q := by + simpa using + (Nfp.Sound.Linear.dotProduct_mul_right + (x := fun j => (w j d : Real)) + (y := fun j => (lnCoeff q j : Real)) + (a := invStd q)) + have hreal : + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) = + (base d : Real) + (coeff q d : Real) * invStd q := by + calc + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) + + (b d : Real) := by + simp [hln_fun_q] + _ = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + 
dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q + + (b d : Real) := by + simp [hdot_add, hdot_coeff, add_assoc] + _ = + (dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (b d : Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q := by ac_rfl + _ = (base d : Real) + (coeff q d : Real) * invStd q := by + simp [hbase, hcoeff] + have hscale : + let bounds := scaleInterval (coeff q d) (invStdLo q) (invStdHi q) + (bounds.1 : Real) ≤ (coeff q d : Real) * invStd q ∧ + (coeff q d : Real) * invStd q ≤ (bounds.2 : Real) := by + exact scaleInterval_bounds_real (x := coeff q d) (lo := invStdLo q) + (hi := invStdHi q) (y := invStd q) hinv.1 hinv.2 + have hlow : + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) := by + simpa [hreal] using add_le_add_left hscale.1 (base d : Real) + have hhigh : + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ≤ + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by + simpa [hreal] using add_le_add_left hscale.2 (base d : Real) + exact ⟨hlow, hhigh⟩ + let qBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + + inputs.bq d) + let qBase : Fin dHead → Rat := fun d => + qBaseArr[d.1]'(by + simp [qBaseArr]) + let kBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + + inputs.bk d) + let kBase : Fin dHead → Rat := fun d => + kBaseArr[d.1]'(by + simp [kBaseArr]) + let coeffRowTasks : + (Fin dModel → Fin dHead → Rat) → + Array (Task { row : Array Rat // row.size = dHead }) := + fun w => + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + 
inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => w j d) coeff), + by simp⟩)) + let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + coeffRowTasks inputs.wq + let qCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (qCoeffRowTasks[q.1]'(by + simp [qCoeffRowTasks, coeffRowTasks])).get) + let qCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := qCoeffArr[q.1]'(by + simp [qCoeffArr]) + row.1[d.1]'(by + simp [row.2]) + let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + coeffRowTasks inputs.wk + let kCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kCoeffRowTasks[q.1]'(by + simp [kCoeffRowTasks, coeffRowTasks])).get) + let kCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := kCoeffArr[q.1]'(by + simp [kCoeffArr]) + row.1[d.1]'(by + simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.1 + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.2 + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.1 + let kHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.2 + let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let qAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => qAbs q d)) + let qAbsMax : Fin dHead → Rat := fun d => + qAbsMaxArr[d.1]'(by + simp [qAbsMaxArr]) + 
let kAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d)) + let kAbsMax : Fin dHead → Rat := fun d => + kAbsMaxArr[d.1]'(by + simp [kAbsMaxArr]) + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + let splitBudgetQ : Nat := cfg.splitBudgetQ + let splitBudgetK : Nat := cfg.splitBudgetK + let finRangeHead : List (Fin dHead) := List.finRange dHead + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => + if splitBudgetQ = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetQ + let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => + if splitBudgetK = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin 
dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 := top2 ambig + let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetK + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsK := splitDimsK q k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + 
inputs.scale * dotLo q k + have dotLo_eq : + ∀ q k, + dotLo q k = + if masked q k then + (0 : Rat) + else + let dimsQ := splitDimsQ q + let dimsK := splitDimsK q k + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)).1 := by + intro q k + classical + by_cases hmk : masked q k + · simp [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hmk] + · simp [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hmk] + have dotHi_eq : + ∀ q k, + dotHi q k = + if masked q k then + (0 : Rat) + else + let dimsQ := splitDimsQ q + let dimsK := splitDimsK q k + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)).2 := by + intro q k + classical + by_cases hmk : masked q k + · simp [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hmk] + · simp [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hmk] + have hq_bounds_local : + ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + intro q d + have hbase : + ∀ d, + (qBase d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bq d : Real) := by + intro d + simp [qBase, qBaseArr, dotFin_cast] + have hcoeff : + ∀ q' d, + (qCoeff q' d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (lnCoeff q' j : Real)) := by + intro q' d + simpa [qCoeff, qCoeffArr, qCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using + (dotFin_cast (f := fun j => inputs.wq j d) + (g := fun j => + inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) + have h := proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) + (coeff := qCoeff) hbase hcoeff q d + simpa [qLo, qHi, qRealOfInputs_def] using h + have hk_bounds_local : + ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ (kHi q d 
: Real) := by + intro q d + have hbase : + ∀ d, + (kBase d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bk d : Real) := by + intro d + simp [kBase, kBaseArr, dotFin_cast] + have hcoeff : + ∀ q' d, + (kCoeff q' d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (lnCoeff q' j : Real)) := by + intro q' d + simpa [kCoeff, kCoeffArr, kCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using + (dotFin_cast (f := fun j => inputs.wk j d) + (g := fun j => + inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) + have h := proj_bounds (w := inputs.wk) (b := inputs.bk) (base := kBase) + (coeff := kCoeff) hbase hcoeff q d + simpa [kLo, kHi, kRealOfInputs_def] using h + let scoresReal := scoresRealOfInputs inputs + have scoresReal_eq_base_of_not_masked : + ∀ q k, ¬ masked q k → + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + intro q k hnot + by_cases hcausal : inputs.maskCausal + · have hnot_lt : ¬ q < k := by + intro hlt + exact hnot ⟨hcausal, hlt⟩ + have hle : k ≤ q := le_of_not_gt hnot_lt + simp [scoresReal, scoresRealOfInputs_def, hcausal, hle] + · simp [scoresReal, scoresRealOfInputs_def, hcausal] + have scoresReal_eq_masked : + ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by + intro q k hmask + have hmask' : inputs.maskCausal = true ∧ q < k := by + simpa [masked] using hmask + have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 + simp [scoresReal, scoresRealOfInputs_def, hmask'.1, hle] + have hscore_bounds_local : + ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ + scoresReal q k ≤ (scoreHi q k : Real) := by + intro q k + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + have hdot_bounds (hnot : ¬ masked q k) : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => 
kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + have hq := hq_bounds_local q + have hk := hk_bounds_local k + have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq d).1 + have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => + (hq d).2 + have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => + (hk d).1 + have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => + (hk d).2 + have hspec := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := splitDimsQ q) (dims2 := splitDimsK q k) + (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hlow' : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simpa [dotLo_eq, hnot] using hspec.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + simpa [dotHi_eq, hnot] using hspec.2 + exact ⟨hlow', hhigh'⟩ + have hscore_base_bounds (hnot : ¬ masked q k) : + (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale + have hdot := hdot_bounds hnot + have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real + have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real + constructor + · simpa [scoreLo, masked, hnot, hscale, base] using hlow + · simpa [scoreHi, masked, hnot, hscale, base] using hhigh + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := by + simpa [ratToReal_def] using + 
(ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hdot := hdot_bounds hnot + have hlow := mul_le_mul_of_nonpos_left hdot.2 hscale_real + have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real + constructor + · simpa [scoreLo, masked, hnot, hscale, base] using hlow + · simpa [scoreHi, masked, hnot, hscale, base] using hhigh + by_cases hcausal : inputs.maskCausal + · by_cases hle : k ≤ q + · have hnot : ¬ q < k := not_lt_of_ge hle + have hnot_masked : ¬ masked q k := fun hmk => hnot hmk.2 + have hscore_eq : scoresReal q k = base := + scoresReal_eq_base_of_not_masked q k hnot_masked + have hbase := hscore_base_bounds hnot_masked + constructor + · simpa [hscore_eq] using hbase.1 + · simpa [hscore_eq] using hbase.2 + · have hlt : q < k := lt_of_not_ge hle + have hmask : masked q k := ⟨hcausal, hlt⟩ + have hscore : scoresReal q k = (inputs.maskValue : Real) := + scoresReal_eq_masked q k hmask + constructor + · simp [hscore, scoreLo, hmask, masked] + · simp [hscore, scoreHi, hmask, masked] + · have hnot_masked : ¬ masked q k := by + simp [masked, hcausal] + have hscore_eq : scoresReal q k = base := + scoresReal_eq_base_of_not_masked q k hnot_masked + have hbase := hscore_base_bounds hnot_masked + constructor + · simpa [hscore_eq] using hbase.1 + · simpa [hscore_eq] using hbase.2 + have hlocal : + (∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qHi q d : Real)) ∧ + (∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ (kHi q d : Real)) ∧ + (∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ + scoresReal q k ≤ (scoreHi q k : Real)) := by + exact ⟨hq_bounds_local, hk_bounds_local, hscore_bounds_local⟩ + simpa (config := { zeta := false }) [hcache, buildInductionHeadCoreCacheWith] using + hlocal + +/-- Query bounds for `buildInductionHeadCoreCacheWith`. 
-/ +theorem buildInductionHeadCoreCacheWith_q_bounds + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) : + ∀ q d, + ((buildInductionHeadCoreCacheWith cfg inputs).qLo q d : Real) ≤ + qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ + ((buildInductionHeadCoreCacheWith cfg inputs).qHi q d : Real) := by + exact (buildInductionHeadCoreCacheWith_bounds cfg inputs hEps hSqrt hmodel).1 + +/-- Key bounds for `buildInductionHeadCoreCacheWith`. -/ +theorem buildInductionHeadCoreCacheWith_k_bounds + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) : + ∀ q d, + ((buildInductionHeadCoreCacheWith cfg inputs).kLo q d : Real) ≤ + kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ + ((buildInductionHeadCoreCacheWith cfg inputs).kHi q d : Real) := by + exact (buildInductionHeadCoreCacheWith_bounds cfg inputs hEps hSqrt hmodel).2.1 + +/-- Score bounds for `buildInductionHeadCoreCacheWith`. 
-/ +theorem buildInductionHeadCoreCacheWith_score_bounds + [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) + (hmodel : dModel ≠ 0) : + ∀ q k, + ((buildInductionHeadCoreCacheWith cfg inputs).scoreLo q k : Real) ≤ + scoresRealOfInputs inputs q k ∧ + scoresRealOfInputs inputs q k ≤ + ((buildInductionHeadCoreCacheWith cfg inputs).scoreHi q k : Real) := by + exact (buildInductionHeadCoreCacheWith_bounds cfg inputs hEps hSqrt hmodel).2.2 + + +end Sound +end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Basic/CertSound.lean b/Nfp/Sound/Induction/CoreSound/Basic/CertSound.lean new file mode 100644 index 0000000..1981093 --- /dev/null +++ b/Nfp/Sound/Induction/CoreSound/Basic/CertSound.lean @@ -0,0 +1,1316 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later +module + +import all Nfp.Sound.Induction.Core.Basic +public import Nfp.Sound.Induction.Core +public import Nfp.Sound.Induction.CoreSound.Values + +public section + +namespace Nfp +namespace Sound +open scoped BigOperators +open Nfp.Circuit +open Nfp.Sound.Bounds +variable {seq : Nat} +/-- Soundness for `buildInductionCertFromHeadCoreWith?`. -/ +theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) + (hcore : buildInductionCertFromHeadCoreWith? 
cfg inputs = some c) : + InductionHeadCertSound inputs c := by + classical + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : 0 < sqrtLower inputs.lnEps + · by_cases hmodel : dModel = 0 + · have : False := by + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' + exact this.elim + · by_cases hactive : inputs.active.Nonempty + · let lnBounds := Bounds.cacheBoundPair2 (fun q => + Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) + let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 + let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 + let lnAbsMaxTask : Fin seq → Rat := + Bounds.cacheBoundTask (fun q => + Bounds.intervalAbsBound (lnLo q) (lnHi q)) + let lnAbsMaxArr : Array Rat := + Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) + let lnAbsMax : Fin seq → Rat := fun q => + lnAbsMaxArr[q.1]'(by + simp [lnAbsMaxArr]) + let invStdBoundsTasks : Array (Task (Rat × Rat)) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) + let invStdBoundsArr : Array (Rat × Rat) := + Array.ofFn (fun q : Fin seq => + (invStdBoundsTasks[q.1]'(by + simp [invStdBoundsTasks])).get) + let invStdLo : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + simp [invStdBoundsArr])).1 + let invStdHi : Fin seq → Rat := fun q => + (invStdBoundsArr[q.1]'(by + simp [invStdBoundsArr])).2 + let lnCoeff : Fin seq → Fin dModel → Rat := fun q j => + inputs.ln1Gamma j * (inputs.embed q j - mean (inputs.embed q)) + let invStd : Fin seq → Real := fun q => + (Real.sqrt ((varianceRat (inputs.embed q) : Real) + (inputs.lnEps : Real)))⁻¹ + have hmeanRat : + ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by + intro q + have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by + simp [mean_def, 
hmodel, ratRoundDown_def] + simpa [ratToReal_def] using congrArg ratToReal hmu_rat + have hln_affine : + ∀ q j, + lnRealOfInputs inputs q j = + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + intro q j + have hmu := hmeanRat q + simp [lnRealOfInputs_def, Bounds.layerNormReal_def, hmodel, lnCoeff, hmu, invStd, + add_comm, mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] + have hln_fun : + ∀ q, + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by + intro q + funext j + exact hln_affine q j + have hinv_bounds : + ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by + intro q + simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, + Bounds.invStdBounds_def, Task.spawn, Array.getElem_ofFn] using + (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) + hmodel hEps hSqrt) + let qBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + + inputs.bq d) + let qBase : Fin dHead → Rat := fun d => + qBaseArr[d.1]'(by + simp [qBaseArr]) + let kBaseArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + + inputs.bk d) + let kBase : Fin dHead → Rat := fun d => + kBaseArr[d.1]'(by + simp [kBaseArr]) + let coeffRowTasks : + (Fin dModel → Fin dHead → Rat) → + Array (Task { row : Array Rat // row.size = dHead }) := + fun w => + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let μ := mean (inputs.embed q) + let coeff : Fin dModel → Rat := fun j => + inputs.ln1Gamma j * (inputs.embed q j - μ) + ⟨Array.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => w j d) coeff), + by simp⟩)) + let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + coeffRowTasks inputs.wq + let qCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + 
(qCoeffRowTasks[q.1]'(by + simp [qCoeffRowTasks, coeffRowTasks])).get) + let qCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := qCoeffArr[q.1]'(by + simp [qCoeffArr]) + row.1[d.1]'(by + simp [row.2]) + let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := + coeffRowTasks inputs.wk + let kCoeffArr : Array { row : Array Rat // row.size = dHead } := + Array.ofFn (fun q : Fin seq => + (kCoeffRowTasks[q.1]'(by + simp [kCoeffRowTasks, coeffRowTasks])).get) + let kCoeff : Fin seq → Fin dHead → Rat := fun q d => + let row := kCoeffArr[q.1]'(by + simp [kCoeffArr]) + row.1[d.1]'(by + simp [row.2]) + let qLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.1 + let qHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) + qBase d + bounds.2 + let kLo : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.1 + let kHi : Fin seq → Fin dHead → Rat := fun q d => + let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) + kBase d + bounds.2 + let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| + let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| + let qAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun q => qAbs q d)) + let qAbsMax : Fin dHead → Rat := fun d => + qAbsMaxArr[d.1]'(by + simp [qAbsMaxArr]) + let kAbsMaxArr : Array Rat := + Array.ofFn (fun d : Fin dHead => + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := Finset.univ_nonempty + univ.sup' hnonempty (fun k => kAbs k d)) + let kAbsMax : Fin dHead → Rat := fun d => + kAbsMaxArr[d.1]'(by + simp [kAbsMaxArr]) + let masked : Fin seq → Fin seq → Prop := fun q k 
=> + inputs.maskCausal = true ∧ q < k + let splitBudgetQ : Nat := cfg.splitBudgetQ + let splitBudgetK : Nat := cfg.splitBudgetK + let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase + let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined + let top2ByScore : + (Fin dHead → Rat) → List (Fin dHead) → List (Fin dHead) := fun score ambig => + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let finRangeHead : List (Fin dHead) := List.finRange dHead + let finRangeSeq : List (Fin seq) := List.finRange seq + let splitDimsQ : Fin seq → List (Fin dHead) := fun q => + if splitBudgetQ = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) + let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d + let dims1 := top2ByScore score ambig + let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetQ + let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => + if splitBudgetK = 0 then + [] + else + let ambig := + finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) + let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d + let dims1 := top2ByScore score ambig + let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) + (dims1 ++ dims2).take splitBudgetK + let 
splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := + fun budget q k => + if budget = 0 then + [] + else + let prev := inputs.prev q + let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d + let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d + let ambig := + finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) + let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d + let step + (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) + (d : Fin dHead) : + Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := + let s := score d + match best with + | (none, none) => (some (s, d), none) + | (some b1, none) => + if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) + | (some b1, some b2) => + if b1.1 < s then (some (s, d), some b1) + else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) + | (none, some b2) => + if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) + let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => + match ambig.foldl step (none, none) with + | (some b1, some b2) => [b1.2, b2.2] + | (some b1, none) => [b1.2] + | (none, _) => [] + let dims1 : List (Fin dHead) := top2 ambig + let dims2 : List (Fin dHead) := + top2 (ambig.filter (fun d => decide (d ∉ dims1))) + let memDims2 : Fin dHead → Bool := fun d => + dims2.any (fun d' => decide (d' = d)) + let dims3 : List (Fin dHead) := + top2 + ((ambig.filter (fun d => decide (d ∉ dims1))).filter + (fun d => !memDims2 d)) + (dims1 ++ dims2 ++ dims3).take budget + let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffBase + let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := + splitDimsDiffCore splitBudgetDiffRefined + let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k 
then + (0, 0) + else + let dimsK := splitDimsK q k + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo k d) (fun d => kHi k d)), + by simp⟩)) + let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := + Array.ofFn (fun q : Fin seq => + Task.spawn (fun _ => + if hq : q ∈ inputs.active then + let dimsQ := splitDimsQ q + ⟨Array.ofFn (fun k : Fin seq => + if masked q k then + (0, 0) + else + let dimsDiff := splitDimsDiffBase q k + let prev := inputs.prev q + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)), + by simp⟩ + else + ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) + let dotLo : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotHi : Fin seq → Fin seq → Rat := fun q k => + let row := (dotRowTasks[q.1]'(by + simp [dotRowTasks, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.1 + let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => + let row := (dotDiffRowTasksBase[q.1]'(by + simp [dotDiffRowTasksBase, q.isLt])).get + let entry := row.1[k.1]'(by + simp [row.2, k.isLt]) + entry.2 + let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| + let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => + |inputs.scale| * dotAbs q k + let scoreLo : Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotLo q k + else + inputs.scale * dotHi q k + let scoreHi 
: Fin seq → Fin seq → Rat := fun q k => + if masked q k then + inputs.maskValue + else + if hscale : 0 ≤ inputs.scale then + inputs.scale * dotHi q k + else + inputs.scale * dotLo q k + let scoreLoPrev : Fin seq → Rat := fun q => + scoreLo q (inputs.prev q) + let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLoBase q k + else + inputs.scale * dotDiffHiBase q k + let scoreGapLoBase : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoBaseRaw + let otherKeys : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + let worstKey : Fin seq → Option (Fin seq) := + if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then + fun _ => none + else + let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => + if hq : q ∈ inputs.active then + let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (scoreGapLoBase q k, k)).2 + else + none + let worstKeyArr : Array (Thunk (Option (Fin seq))) := + Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) + fun q => + let t := worstKeyArr[q.1]'(by + simp [worstKeyArr, q.isLt]) + Thunk.get t + let refineKeys : Fin seq → Finset (Fin seq) := + if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then + fun _ => ∅ + else + let refineKeysRaw : Fin seq → Finset (Fin seq) := fun q => + let base : Finset (Fin seq) := + match worstKey q with + | some k => {k} + | none => ∅ + if hq : q ∈ inputs.active then + let other := otherKeys q + base ∪ other.filter (fun k => decide (scoreGapLoBase q k < 0)) + else + base + let refineKeysArr : Array (Thunk (Finset (Fin seq))) := + Array.ofFn (fun q 
=> Thunk.mk (fun _ => refineKeysRaw q)) + fun q => + let t := refineKeysArr[q.1]'(by + simp [refineKeysArr, q.isLt]) + Thunk.get t + let dotDiffLoHi : (Fin seq → Fin seq → Rat) × (Fin seq → Fin seq → Rat) := + if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then + (dotDiffLoBase, dotDiffHiBase) + else + let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => + if hk : k ∈ refineKeys q then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).1 + else + dotDiffLoBase q k + let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => + if hk : k ∈ refineKeys q then + let dimsQ := splitDimsQ q + let dimsDiff := splitDimsDiffRefined q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => qLo q d) (fun d => qHi q d) + (fun d => kLo prev d - kHi k d) + (fun d => kHi prev d - kLo k d)).2 + else + dotDiffHiBase q k + (dotDiffLo, dotDiffHi) + let dotDiffLo : Fin seq → Fin seq → Rat := dotDiffLoHi.1 + let dotDiffHi : Fin seq → Fin seq → Rat := dotDiffLoHi.2 + let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => + if masked q (inputs.prev q) then + scoreLoPrev q - scoreHi q k + else if masked q k then + scoreLoPrev q - inputs.maskValue + else if hscale : 0 ≤ inputs.scale then + inputs.scale * dotDiffLo q k + else + inputs.scale * dotDiffHi q k + let scoreGapLo : Fin seq → Fin seq → Rat := + Bounds.cacheBound2 scoreGapLoRaw + let certFields := buildInductionHeadCertFields inputs otherKeys scoreGapLo + let marginAt : Fin seq → Rat := certFields.marginAt + let weightBoundAtBase : Fin seq → Fin seq → Rat := certFields.weightBoundAtBase + let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := + certFields.weightBoundAtBaseCached + let epsAtBase : Fin seq → Rat := 
certFields.epsAtBase + let epsAt : Fin seq → Rat := certFields.epsAt + let weightBoundAt : Fin seq → Fin seq → Rat := certFields.weightBoundAt + let margin : Rat := certFields.margin + let eps : Rat := certFields.eps + have hseq : (1 : Nat) ≤ seq := + Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) + let dirHead : Fin dHead → Rat := fun d => (dirHeadVecOfInputs inputs).get d + let wvDirRaw : Fin dModel → Rat := fun j => + Linear.dotFin dHead dirHead (fun d => inputs.wv j d) + let wvDirTask : Fin dModel → Rat := Bounds.cacheBoundTask wvDirRaw + let wvDirArr : Array Rat := Array.ofFn wvDirTask + let wvDir : Fin dModel → Rat := fun j => + wvDirArr[j.1]'(by + have hsize : wvDirArr.size = dModel := by + simp [wvDirArr] + simp [hsize, j.isLt]) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv d) + let valsLo : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalLower wvDir (lnLo q) (lnHi q) + let valsHi : Fin seq → Rat := fun q => + bDir + Bounds.dotIntervalUpper wvDir (lnLo q) (lnHi q) + let hvalsLo : ∀ k, + valsLo k = bDir + Bounds.dotIntervalLower wvDir (lnLo k) (lnHi k) := fun _ => rfl + let hvalsHi : ∀ k, + valsHi k = bDir + Bounds.dotIntervalUpper wvDir (lnLo k) (lnHi k) := fun _ => rfl + let univ : Finset (Fin seq) := Finset.univ + have hnonempty : univ.Nonempty := by simp [univ] + let lo := univ.inf' hnonempty valsLo + let hi := univ.sup' hnonempty valsHi + let valCert : ValueInterval seq := buildInductionHeadValCert inputs valsLo valsHi + let cert : InductionHeadCert seq := + { eps := eps + epsAt := epsAt + weightBoundAt := weightBoundAt + margin := margin + active := inputs.active + prev := inputs.prev + values := valCert } + have hcore' : buildInductionCertFromHeadCoreWith? cfg inputs = some cert := by + have hcore'' : + buildInductionCertFromHeadCoreWith? 
cfg inputs = + some (buildInductionHeadCoreCacheWith cfg inputs).cert := + buildInductionCertFromHeadCoreWith?_eq_some + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive + have hcert_eq : + (buildInductionHeadCoreCacheWith cfg inputs).cert = cert := by + rfl + simpa [hcert_eq] using hcore'' + have hc : c = cert := by + have hcert : cert = c := by + exact Option.some.inj (hcore'.symm.trans hcore) + simpa using hcert.symm + subst hc + have hln_bounds : + ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ + lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by + intro q i + have hln := + Bounds.layerNormBounds_spec (eps := inputs.lnEps) + (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) + (x := inputs.embed q) hmodel hEps hSqrt + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs_def, + Bounds.cacheBoundPair2_apply_left, Bounds.cacheBoundPair2_apply_right] using + hln i + have dotFin_cast {n : Nat} (f g : Fin n → Rat) : + (Linear.dotFin n f g : Real) = + dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by + simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] + have proj_bounds + (w : Fin dModel → Fin dHead → Rat) + (b base : Fin dHead → Rat) + (coeff : Fin seq → Fin dHead → Rat) + (hbase : ∀ d, + (base d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (b d : Real)) + (hcoeff : ∀ q d, + (coeff q d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real))) : + ∀ q d, + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ∧ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ≤ + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by + intro q d + have hinv : (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := + hinv_bounds q + have hln_fun_q : + lnRealOfInputs inputs q = + fun j => (inputs.ln1Beta j : 
Real) + (lnCoeff q j : Real) * invStd q := by + exact hln_fun q + have hdot_add : + dotProduct (fun j => (w j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real) * invStd q) := by + simpa using + (Nfp.Sound.Linear.dotProduct_add_right + (x := fun j => (w j d : Real)) + (y := fun j => (inputs.ln1Beta j : Real)) + (z := fun j => (lnCoeff q j : Real) * invStd q)) + have hdot_coeff : + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real) * invStd q) = + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q := by + simpa using + (Nfp.Sound.Linear.dotProduct_mul_right + (x := fun j => (w j d : Real)) + (y := fun j => (lnCoeff q j : Real)) + (a := invStd q)) + have hreal : + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) = + (base d : Real) + (coeff q d : Real) * invStd q := by + calc + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) = + dotProduct (fun j => (w j d : Real)) + (fun j => + (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) + + (b d : Real) := by + simp [hln_fun_q] + _ = + dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q + + (b d : Real) := by + simp [hdot_add, hdot_coeff, add_assoc] + _ = + (dotProduct (fun j => (w j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (b d : Real)) + + dotProduct (fun j => (w j d : Real)) + (fun j => (lnCoeff q j : Real)) * + invStd q := by ac_rfl + _ = (base d : Real) + (coeff q d : Real) * invStd q := by + simp [hbase, hcoeff] + have hscale : + let bounds := scaleInterval (coeff q d) (invStdLo q) (invStdHi q) + (bounds.1 : Real) ≤ (coeff q d : Real) * invStd q ∧ + (coeff q d : Real) * invStd q ≤ 
(bounds.2 : Real) := by + exact scaleInterval_bounds_real (x := coeff q d) (lo := invStdLo q) + (hi := invStdHi q) (y := invStd q) hinv.1 hinv.2 + have hlow : + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) := by + simpa [hreal] using add_le_add_left hscale.1 (base d : Real) + have hhigh : + dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + + (b d : Real) ≤ + (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by + simpa [hreal] using add_le_add_left hscale.2 (base d : Real) + exact ⟨hlow, hhigh⟩ + have hq_bounds : + ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (qHi q d : Real) := by + intro q d + have hbase : + ∀ d, + (qBase d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bq d : Real) := by + intro d + simp [qBase, qBaseArr, dotFin_cast] + have hcoeff : + ∀ q' d, + (qCoeff q' d : Real) = + dotProduct (fun j => (inputs.wq j d : Real)) + (fun j => (lnCoeff q' j : Real)) := by + intro q' d + simpa [qCoeff, qCoeffArr, qCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using + (dotFin_cast (f := fun j => inputs.wq j d) + (g := fun j => + inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) + have h := proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) + (coeff := qCoeff) hbase hcoeff q d + simpa [qLo, qHi, qRealOfInputs_def] using h + have hk_bounds : + ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ (kHi q d : Real) := by + intro q d + have hbase : + ∀ d, + (kBase d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (inputs.ln1Beta j : Real)) + + (inputs.bk d : Real) := by + intro d + simp [kBase, kBaseArr, dotFin_cast] + have hcoeff : + ∀ q' d, + (kCoeff q' d : Real) = + dotProduct (fun j => (inputs.wk j d : Real)) + (fun j => (lnCoeff q' j : 
Real)) := by + intro q' d + simpa [kCoeff, kCoeffArr, kCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using + (dotFin_cast (f := fun j => inputs.wk j d) + (g := fun j => + inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) + have h := proj_bounds (w := inputs.wk) (b := inputs.bk) (base := kBase) + (coeff := kCoeff) hbase hcoeff q d + simpa [kLo, kHi, kRealOfInputs_def] using h + let scoresReal := scoresRealOfInputs inputs + have scoresReal_eq_base_of_not_masked : + ∀ q k, ¬ masked q k → + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + intro q k hnot + by_cases hcausal : inputs.maskCausal + · have hnot_lt : ¬ q < k := by + intro hlt + exact hnot ⟨hcausal, hlt⟩ + have hle : k ≤ q := le_of_not_gt hnot_lt + simp [scoresReal, scoresRealOfInputs_def, hcausal, hle] + · simp [scoresReal, scoresRealOfInputs_def, hcausal] + have scoresReal_eq_masked : + ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by + intro q k hmask + have hmask' : inputs.maskCausal = true ∧ q < k := by + simpa [masked] using hmask + have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 + simp [scoresReal, scoresRealOfInputs_def, hmask'.1, hle] + have hscore_bounds : + ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ + scoresReal q k ≤ (scoreHi q k : Real) := by + intro q k + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + have hdot_bounds (hnot : ¬ masked q k) : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + have hq := hq_bounds q + have hk := hk_bounds k + have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq d).1 + have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => + (hq d).2 + have 
hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => + (hk d).1 + have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => + (hk d).2 + have hspec := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := splitDimsQ q) (dims2 := splitDimsK q k) + (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hlow' : + (dotLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] + using hspec.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by + simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] + using hspec.2 + exact ⟨hlow', hhigh'⟩ + have hscore_base_bounds (hnot : ¬ masked q k) : + (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale + have hdot := hdot_bounds hnot + have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real + have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real + have hscoreLo : scoreLo q k = inputs.scale * dotLo q k := by + simp [scoreLo, masked, hnot, hscale] + have hscoreHi : scoreHi q k = inputs.scale * dotHi q k := by + simp [scoreHi, masked, hnot, hscale] + constructor + · simpa [hscoreLo, base, Rat.cast_mul] using hlow + · simpa [hscoreHi, base, Rat.cast_mul] using hhigh + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + by + simpa [ratToReal_def] using + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hdot := hdot_bounds hnot + have hlow := 
mul_le_mul_of_nonpos_left hdot.2 hscale_real + have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real + have hscoreLo : scoreLo q k = inputs.scale * dotHi q k := by + simp [scoreLo, masked, hnot, hscale] + have hscoreHi : scoreHi q k = inputs.scale * dotLo q k := by + simp [scoreHi, masked, hnot, hscale] + constructor + · simpa [hscoreLo, base, Rat.cast_mul] using hlow + · simpa [hscoreHi, base, Rat.cast_mul] using hhigh + by_cases hcausal : inputs.maskCausal + · by_cases hle : k ≤ q + · have hnot : ¬ q < k := not_lt_of_ge hle + have hnot_masked : ¬ masked q k := fun hmk => hnot hmk.2 + have hscore_eq : scoresReal q k = base := + scoresReal_eq_base_of_not_masked q k hnot_masked + have hbase := hscore_base_bounds hnot_masked + constructor + · simpa [hscore_eq] using hbase.1 + · simpa [hscore_eq] using hbase.2 + · have hlt : q < k := lt_of_not_ge hle + have hmask : masked q k := ⟨hcausal, hlt⟩ + have hscore : scoresReal q k = (inputs.maskValue : Real) := + scoresReal_eq_masked q k hmask + constructor + · simp [hscore, scoreLo, hmask] + · simp [hscore, scoreHi, hmask] + · have hnot_masked : ¬ masked q k := by + simp [masked, hcausal] + have hscore_eq : scoresReal q k = base := + scoresReal_eq_base_of_not_masked q k hnot_masked + have hbase := hscore_base_bounds hnot_masked + constructor + · simpa [hscore_eq] using hbase.1 + · simpa [hscore_eq] using hbase.2 + have hdot_diff_bounds : + ∀ q, q ∈ inputs.active → ∀ k, ¬ masked q k → + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + intro q hq k hmask + have hq_bounds' := hq_bounds q + have hkprev := hk_bounds (inputs.prev q) + have hk := hk_bounds k + have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq_bounds' d).1 + have 
hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => + (hq_bounds' d).2 + have hlo2 : + ∀ d, + (kLo (inputs.prev q) d - kHi k d : Rat) ≤ + (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by + intro d + have hprev_lo := (hkprev d).1 + have hk_hi := (hk d).2 + have h := sub_le_sub hprev_lo hk_hi + simpa [ratToReal_sub] using h + have hhi2 : + ∀ d, + (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ + (kHi (inputs.prev q) d - kLo k d : Rat) := by + intro d + have hprev_hi := (hkprev d).2 + have hk_lo := (hk d).1 + have h := sub_le_sub hprev_hi hk_lo + simpa [ratToReal_sub] using h + have hspec (dimsDiff : List (Fin dHead)) := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := splitDimsQ q) (dims2 := dimsDiff) + (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) + (lo2 := fun d => kLo (inputs.prev q) d - kHi k d) + (hi2 := fun d => kHi (inputs.prev q) d - kLo k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => + kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hspecBase := hspec (splitDimsDiffBase q k) + have hspecRef := hspec (splitDimsDiffRefined q k) + have hspecBase_bounds : + (dotDiffLoBase q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHiBase q k : Real) := by + refine ⟨?_, ?_⟩ + · simpa [dotDiffLoBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, + Array.getElem_ofFn] using hspecBase.1 + · simpa [dotDiffHiBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, + Array.getElem_ofFn] using hspecBase.2 + by_cases hbudget : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase + · simpa [dotDiffLo, dotDiffHi, dotDiffLoHi, hbudget] using hspecBase_bounds + · by_cases hmem : k ∈ refineKeys q 
+ · have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, dotDiffLoHi, hbudget, hmem] using hspecRef.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, dotDiffLoHi, hbudget, hmem] using hspecRef.2 + exact ⟨hlow', hhigh'⟩ + · have hlow' : + (dotDiffLo q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLo, dotDiffLoHi, hbudget, hmem] using + hspecBase_bounds.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by + simpa [dotDiffHi, dotDiffLoHi, hbudget, hmem] using + hspecBase_bounds.2 + exact ⟨hlow', hhigh'⟩ + have hmarginAt_le : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + marginAt q ≤ scoreGapLo q k := by + intro q hq k hk + have hmem : k ∈ otherKeys q := by simp [otherKeys, hk] + have hnonempty : (otherKeys q).Nonempty := ⟨k, hmem⟩ + have hmarginAt_eq : + marginAt q = + (otherKeys q).inf' hnonempty (fun k => scoreGapLo q k) := by + simp [marginAt, certFields, buildInductionHeadCertFields_def, hq, hnonempty] + have hle : + (otherKeys q).inf' hnonempty (fun k => scoreGapLo q k) ≤ scoreGapLo q k := by + exact (Finset.inf'_le_iff (s := otherKeys q) (H := hnonempty) + (f := fun k => scoreGapLo q k) (a := scoreGapLo q k)).2 + ⟨k, hmem, le_rfl⟩ + simpa [hmarginAt_eq] using hle + have hscore_gap_real_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + by_cases hprevmask : masked q (inputs.prev q) + · have hscore_hi : scoresReal q k ≤ 
(scoreHi q k : Real) := + (hscore_bounds q k).2 + have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by + have hprev_bounds := hscore_bounds q (inputs.prev q) + simpa [scoreLoPrev] using hprev_bounds.1 + have hsum_le' : + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ + (scoreLoPrev q : Real) := by + have hsub : + (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ + (scoreLoPrev q : Real) - scoresReal q k := + sub_le_sub_left hscore_hi (scoreLoPrev q : Real) + calc + (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k + ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsub (scoresReal q k)) + _ = (scoreLoPrev q : Real) := by + simp [sub_add_cancel] + calc + scoresReal q k + (scoreGapLo q k : Real) + = (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by + simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, add_comm] + _ ≤ (scoreLoPrev q : Real) := hsum_le' + _ ≤ scoresReal q (inputs.prev q) := hscore_prev + · by_cases hmask : masked q k + · have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by + have hprev_bounds := hscore_bounds q (inputs.prev q) + simpa [scoreLoPrev] using hprev_bounds.1 + have hscore_k : scoresReal q k = (inputs.maskValue : Real) := + scoresReal_eq_masked q k hmask + calc + scoresReal q k + (scoreGapLo q k : Real) + = (inputs.maskValue : Real) + (scoreLoPrev q : Real) - + (inputs.maskValue : Real) := by + simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscore_k] + _ = (scoreLoPrev q : Real) := by + simp [add_sub_cancel_left] + _ ≤ scoresReal q (inputs.prev q) := hscore_prev + · have hdiff := hdot_diff_bounds q hq k hmask + have hgap_le : + (scoreGapLo q k : Real) ≤ + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := 
by + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := + by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale + have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real + simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscale] using hle + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := + by + simpa [ratToReal_def] using + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real + simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, + hprevmask, hmask, hscale] using hle + have hscore_prev : + scoresReal q (inputs.prev q) = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) := by + simpa using + (scoresReal_eq_base_of_not_masked q (inputs.prev q) hprevmask) + have hscore_k : + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simpa using (scoresReal_eq_base_of_not_masked q k hmask) + have hdot_sub : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) = + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + classical + simpa using + (Nfp.Sound.Linear.dotProduct_sub_right + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => kRealOfInputs inputs (inputs.prev q) d) + (z := fun d => kRealOfInputs inputs k d)) + have hscore_diff : + scoresReal q (inputs.prev q) - scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + calc + scoresReal q 
(inputs.prev q) - scoresReal q k + = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simp [hscore_prev, hscore_k] + _ = + (inputs.scale : Real) * + (dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)) := by + simp [mul_sub] + _ = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simp [hdot_sub] + have hgap_le' : + (scoreGapLo q k : Real) ≤ + scoresReal q (inputs.prev q) - scoresReal q k := by + simpa [hscore_diff] using hgap_le + have hgap_add := + add_le_add_right hgap_le' (scoresReal q k) + have hgap_add' : + scoresReal q k + (scoreGapLo q k : Real) ≤ + scoresReal q (inputs.prev q) := by + have hcancel : + scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) = + scoresReal q (inputs.prev q) := by + calc + scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) + = + scoresReal q k + scoresReal q (inputs.prev q) - + scoresReal q k := by + symm + exact add_sub_assoc (scoresReal q k) + (scoresReal q (inputs.prev q)) (scoresReal q k) + _ = scoresReal q (inputs.prev q) := by + simp [add_sub_cancel_left] + calc + scoresReal q k + (scoreGapLo q k : Real) + ≤ scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) := + hgap_add + _ = scoresReal q (inputs.prev q) := hcancel + exact hgap_add' + let softmaxWeights := Circuit.softmaxWeights scoresReal + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + have hscore_margin_real_at : + ∀ q, q ∈ 
inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + have hmargin_le : marginAt q ≤ scoreGapLo q k := + hmarginAt_le q hq k hk + have hmargin_le_real : (marginAt q : Real) ≤ (scoreGapLo q k : Real) := + by + simpa [ratToReal_def] using ratToReal_le_of_le hmargin_le + have hscore_gap := hscore_gap_real_at q hq k hk + have hstep := add_le_add_right hmargin_le_real (scoresReal q k) + have hstep' : + scoresReal q k + (marginAt q : Real) ≤ + scoresReal q k + (scoreGapLo q k : Real) := by + exact hstep + exact hstep'.trans hscore_gap + have hscore_margin_real : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by + intro q hq k hk + have hmargin_le : margin ≤ marginAt q := by + have hmem : q ∈ inputs.active := hq + have hnonempty : inputs.active.Nonempty := hactive + have hle := + (Finset.inf'_le_iff (s := inputs.active) (H := hnonempty) + (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ + simpa [margin, certFields, buildInductionHeadCertFields_def, hnonempty] using hle + have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := + by + simpa [ratToReal_def] using ratToReal_le_of_le hmargin_le + have hscore := hscore_margin_real_at q hq k hk + have hscore' : + (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by + simpa [add_comm] using hscore + have hstep := add_le_add_right hmargin_le_real (scoresReal q k) + have hstep' : + scoresReal q k + (margin : Real) ≤ (marginAt q : Real) + scoresReal q k := by + calc + scoresReal q k + (margin : Real) ≤ scoresReal q k + (marginAt q : Real) := hstep + _ = (marginAt q : Real) + scoresReal q k := by + simp [add_comm] + exact hstep'.trans hscore' + have hweightBoundAtBase : + ∀ q k, k ≠ inputs.prev q → + weightBoundAtBase q k = + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k) := by + intro q k hk + simp 
[weightBoundAtBase, certFields, buildInductionHeadCertFields_def, hk] + have hweightBoundAt : + ∀ q k, + weightBoundAt q k = weightBoundAtBase q k := by + intro q k + by_cases hk : k = inputs.prev q + · simp [weightBoundAt, weightBoundAtBase, certFields, + buildInductionHeadCertFields_def, Bounds.cacheBound2Task_apply, + hk] + · simp [weightBoundAt, weightBoundAtBase, certFields, + buildInductionHeadCertFields_def, Bounds.cacheBound2Task_apply, + hk] + have hepsAt : + ∀ q, epsAt q = + min (1 : Rat) + ((otherKeys q).sum (fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k))) := by + intro q + have hsum : + (otherKeys q).sum (fun k => + Bounds.cacheBound2Task + (fun q k => + if k = inputs.prev q then + (0 : Rat) + else if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k)) q k) = + (otherKeys q).sum (fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k)) := by + refine Finset.sum_congr rfl ?_ + intro k hk + have hk' : k ≠ inputs.prev q := by + have hk'' : k ∈ Finset.univ.erase (inputs.prev q) := by + simpa [otherKeys] using hk + exact (Finset.mem_erase.mp hk'').1 + simp [Bounds.cacheBound2Task_apply, hk'] + simpa [epsAt, epsAtBase, certFields, buildInductionHeadCertFields_def, hsum] using + (Bounds.cacheBoundThunk_apply (f := epsAtBase) q) + have oneHot_bounds_at : + ∀ q, q ∈ inputs.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) inputs.prev weights := by + intro q hq + exact + Sound.oneHot_bounds_at_of_scoreGapLo + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (scoreGapLo := scoreGapLo) + (epsAt := epsAt) + (hepsAt := hepsAt) + (hscore_gap_real_at := hscore_gap_real_at) + q hq + have weight_bounds_at : + ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → + weights q k ≤ (weightBoundAt q k : Real) := by + intro q hq k hk + have hbound_base : + weights q k ≤ (weightBoundAtBase q k : 
Real) := by + exact + Sound.weight_bound_at_of_scoreGapLo + (active := inputs.active) + (prev := inputs.prev) + (scoresReal := scoresReal) + (scoreGapLo := scoreGapLo) + (weightBoundAt := weightBoundAtBase) + (hweightBoundAt := hweightBoundAtBase) + (hscore_gap_real_at := hscore_gap_real_at) + q hq k hk + have hweightBoundAt_real : + (weightBoundAt q k : Real) = + (weightBoundAtBase q k : Real) := by + have hbase : weightBoundAt q k = weightBoundAtBase q k := + hweightBoundAt q k + have hbase' : + ratToReal (weightBoundAt q k) = ratToReal (weightBoundAtBase q k) := + congrArg ratToReal hbase + simpa [ratToReal_def] using hbase' + simpa [hweightBoundAt_real] using hbound_base + have hepsAt_le_eps : + ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by + intro q hq + have hle : + epsAt q ≤ inputs.active.sup' hactive epsAt := by + exact + (Finset.le_sup'_iff (s := inputs.active) (H := hactive) + (f := epsAt) (a := epsAt q)).2 ⟨q, hq, le_rfl⟩ + have heps_def : + eps = inputs.active.sup' hactive epsAt := by + have heps_def' : + certFields.eps = + if h : inputs.active.Nonempty then + inputs.active.sup' h certFields.epsAt + else + (0 : Rat) := by + simpa [certFields] using + (buildInductionHeadCertFields_eps_eq + (inputs := inputs) (otherKeys := otherKeys) (scoreGapLo := scoreGapLo)) + have heps_def'' : + certFields.eps = inputs.active.sup' hactive certFields.epsAt := by + simpa [hactive] using heps_def' + simpa [eps, epsAt] using heps_def'' + simpa [heps_def] using hle + have hepsAt_le_eps_real : + ∀ q, q ∈ inputs.active → (epsAt q : Real) ≤ (eps : Real) := by + intro q hq + simpa [ratToReal_def] using ratToReal_le_of_le (hepsAt_le_eps q hq) + have hsoftmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) + (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by + classical + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + exact hscore_margin_real q hq k hk + · 
intro q _ k + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.nonneg q k + · intro q _ + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.sum_one q + · intro q hq + have honehot := oneHot_bounds_at q hq + have hprev := honehot.prev_large q rfl + have hle : + weights q (inputs.prev q) + (epsAt q : Real) ≤ + weights q (inputs.prev q) + (eps : Real) := by + have hle' := + add_le_add_left (hepsAt_le_eps_real q hq) (weights q (inputs.prev q)) + calc + weights q (inputs.prev q) + (epsAt q : Real) + = (epsAt q : Real) + weights q (inputs.prev q) := by + exact add_comm _ _ + _ ≤ (eps : Real) + weights q (inputs.prev q) := hle' + _ = weights q (inputs.prev q) + (eps : Real) := by + exact add_comm _ _ + exact hprev.trans hle + · intro q hq k hk + have honehot := oneHot_bounds_at q hq + have hother := honehot.other_le q rfl k hk + exact hother.trans (hepsAt_le_eps_real q hq) + have hwvDirRaw : + ∀ j, wvDirRaw j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := by + intro j + rfl + have hwvDirTask : ∀ j, wvDirTask j = wvDirRaw j := by + intro j + simpa [wvDirTask] using (Bounds.cacheBoundTask_apply wvDirRaw j) + have hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := by + intro j + have hwv : wvDir j = wvDirTask j := by + simp only [wvDir, wvDirArr, Array.getElem_ofFn] + have hwv' : wvDir j = wvDirRaw j := by + exact hwv.trans (hwvDirTask j) + calc + wvDir j = wvDirRaw j := hwv' + _ = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := hwvDirRaw j + have hbDir : + bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d) := by + rfl + have hvals_bounds : + ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := + valCert_bounds_of_ln_bounds (inputs := inputs) (dirHead := dirHead) (hdirHead := rfl) + (wvDir := wvDir) (bDir := bDir) (hwvDir := hwvDir) (hbDir := hbDir) + (lnLo := lnLo) (lnHi := lnHi) (valsLo := valsLo) (valsHi := valsHi) + (hvalsLo := 
hvalsLo) (hvalsHi := hvalsHi) (hln := hln_bounds) + have hcert_eps : cert.eps = eps := by rfl + have hcert_margin : cert.margin = margin := by rfl + have hcert_active : cert.active = inputs.active := by rfl + have hcert_prev : cert.prev = inputs.prev := by rfl + have hcert_epsAt : cert.epsAt = epsAt := by rfl + have hcert_weight : cert.weightBoundAt = weightBoundAt := by rfl + have hcert_values : cert.values = valCert := by rfl + refine + { softmax_bounds := ?_ + oneHot_bounds_at := ?_ + weight_bounds_at := ?_ + value_bounds := ?_ } + · simpa [hcert_eps, hcert_margin, hcert_active, hcert_prev] using hsoftmax_bounds + · intro q hq + simpa [hcert_epsAt, hcert_active, hcert_prev] using oneHot_bounds_at q hq + · intro q hq k hk + simpa [hcert_weight, hcert_active, hcert_prev] using weight_bounds_at q hq k hk + · simpa [hcert_values] using hvals_bounds + · have : False := by + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_not_active + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' + exact this.elim + · have : False := by + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt + (cfg := cfg) (inputs := inputs) hEps hSqrt + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' + exact this.elim + · have : False := by + have hnone := + buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps + (cfg := cfg) (inputs := inputs) hEps + have hcore' : + (none : Option (InductionHeadCert seq)) = some c := by + exact hnone.symm.trans hcore + cases hcore' + exact this.elim + + +end Sound +end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Basic/DefaultSound.lean b/Nfp/Sound/Induction/CoreSound/Basic/DefaultSound.lean new file mode 100644 index 0000000..5171a42 --- /dev/null +++ b/Nfp/Sound/Induction/CoreSound/Basic/DefaultSound.lean @@ -0,0 +1,29 @@ +-- 
SPDX-License-Identifier: AGPL-3.0-or-later +module + +import all Nfp.Sound.Induction.Core.Basic +public import Nfp.Sound.Induction.Core +public import Nfp.Sound.Induction.CoreSound.Basic.CertSound + +public section + +namespace Nfp +namespace Sound +open scoped BigOperators +open Nfp.Circuit +open Nfp.Sound.Bounds +variable {seq : Nat} +/-- Soundness for `buildInductionCertFromHeadCore?`. -/ +theorem buildInductionCertFromHeadCore?_sound + [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) + (hcore : buildInductionCertFromHeadCore? inputs = some c) : + InductionHeadCertSound inputs c := by + have hcore' : + buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs = some c := by + simpa [buildInductionCertFromHeadCore?_def] using hcore + exact + buildInductionCertFromHeadCoreWith?_sound + (cfg := defaultInductionHeadSplitConfig) inputs c hcore' +end Sound +end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Values.lean b/Nfp/Sound/Induction/CoreSound/Values.lean index 4978444..68e0962 100644 --- a/Nfp/Sound/Induction/CoreSound/Values.lean +++ b/Nfp/Sound/Induction/CoreSound/Values.lean @@ -3,6 +3,7 @@ module public import Mathlib.Algebra.BigOperators.Group.Finset.Basic public import Nfp.Sound.Induction.CoreDefs +public import Nfp.Sound.Induction.Core.Basic public import Nfp.Sound.Linear.FinFold /-! @@ -142,5 +143,87 @@ theorem valsReal_eq_of_dir (inputs : Model.InductionHeadInputs seq dModel dHead) simpa [dotProduct] using hdir_bv.symm simp [hb] +/-- Bound `valsRealOfInputs` using cached `wvDir`/`bDir` and logit interval bounds. 
-/ +theorem valsReal_bounds_at_of_ln_bounds (inputs : Model.InductionHeadInputs seq dModel dHead) + (dirHead : Fin dHead → Rat) + (hdirHead : dirHead = fun d => (dirHeadVecOfInputs inputs).get d) + (wvDir : Fin dModel → Rat) (bDir : Rat) + (hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + (hbDir : bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d)) + (lnLo lnHi : Fin seq → Fin dModel → Rat) + (valsLo valsHi : Fin seq → Rat) + (hvalsLo : + ∀ k, valsLo k = bDir + Bounds.dotIntervalLower wvDir (lnLo k) (lnHi k)) + (hvalsHi : + ∀ k, valsHi k = bDir + Bounds.dotIntervalUpper wvDir (lnLo k) (lnHi k)) + (hln : + ∀ k j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j ∧ + lnRealOfInputs inputs k j ≤ (lnHi k j : Real)) : + ∀ k, + (valsLo k : Rat) ≤ valsRealOfInputs inputs k ∧ + valsRealOfInputs inputs k ≤ (valsHi k : Rat) := by + intro k + have hdir_wv : + ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := + wvDir_real_eq_sum inputs dirHead wvDir hwvDir + have hdir_bv : + (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := + bDir_real_eq_sum inputs dirHead bDir hbDir + have hvals_eq := + valsReal_eq_of_dir inputs dirHead hdirHead wvDir bDir hdir_wv hdir_bv k + have hlo : ∀ j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j := + fun j => (hln k j).1 + have hhi : ∀ j, lnRealOfInputs inputs k j ≤ (lnHi k j : Real) := + fun j => (hln k j).2 + have hlow' : + (bDir + Bounds.dotIntervalLower wvDir (lnLo k) (lnHi k) : Rat) ≤ + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + (bDir : Real) := by + simpa [Rat.cast_add, add_comm] using + (Bounds.dotIntervalLower_le_dotProduct_real_add + (n := dModel) (v := wvDir) + (lo := lnLo k) (hi := lnHi k) + (x := lnRealOfInputs inputs k) (b := (bDir : Real)) + hlo hhi) + have hhigh' : + dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + (bDir : Real) ≤ + (bDir + Bounds.dotIntervalUpper wvDir (lnLo k) (lnHi k) : Rat) := by + simpa 
[Rat.cast_add, add_comm] using + (Bounds.dotProduct_le_dotIntervalUpper_real_add + (n := dModel) (v := wvDir) + (lo := lnLo k) (hi := lnHi k) + (x := lnRealOfInputs inputs k) (b := (bDir : Real)) + hlo hhi) + constructor + · rw [hvalsLo k, hvals_eq] + exact hlow' + · rw [hvalsHi k, hvals_eq] + exact hhigh' + +/-- Build `ValueIntervalBounds` from logit interval bounds for `buildInductionHeadValCert`. -/ +theorem valCert_bounds_of_ln_bounds [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (dirHead : Fin dHead → Rat) + (hdirHead : dirHead = fun d => (dirHeadVecOfInputs inputs).get d) + (wvDir : Fin dModel → Rat) (bDir : Rat) + (hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) + (hbDir : bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d)) + (lnLo lnHi : Fin seq → Fin dModel → Rat) + (valsLo valsHi : Fin seq → Rat) + (hvalsLo : + ∀ k, valsLo k = bDir + Bounds.dotIntervalLower wvDir (lnLo k) (lnHi k)) + (hvalsHi : + ∀ k, valsHi k = bDir + Bounds.dotIntervalUpper wvDir (lnLo k) (lnHi k)) + (hln : + ∀ k j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j ∧ + lnRealOfInputs inputs k j ≤ (lnHi k j : Real)) : + ValueIntervalBounds (vals := valsRealOfInputs inputs) + (buildInductionHeadValCert inputs valsLo valsHi) := by + have hvals_bounds_at := + valsReal_bounds_at_of_ln_bounds inputs dirHead hdirHead wvDir bDir hwvDir hbDir + lnLo lnHi valsLo valsHi hvalsLo hvalsHi hln + exact buildInductionHeadValCert_bounds (inputs := inputs) + (valsReal := valsRealOfInputs inputs) (valsLo := valsLo) (valsHi := valsHi) + hvals_bounds_at + end Sound end Nfp diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index 91231e3..af9ff4c 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -33,6 +33,32 @@ def buildInductionCertFromHeadWith? 
[NeZero seq] {dModel dHead : Nat} | some c => exact some ⟨c, buildInductionCertFromHeadCoreWith?_sound (cfg := cfg) inputs c hcore⟩ +/-- Build and certify induction certificates from exact head inputs, retaining the core cache. -/ +def buildInductionCertFromHeadWithCache? [NeZero seq] {dModel dHead : Nat} + (cfg : InductionHeadSplitConfig) + (inputs : Model.InductionHeadInputs seq dModel dHead) : + Option {cache : InductionHeadCoreCache seq dModel dHead // + InductionHeadCertSound inputs cache.cert} := by + classical + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : 0 < sqrtLower inputs.lnEps + · by_cases hmodel : dModel = 0 + · exact none + · by_cases hactive : inputs.active.Nonempty + · let cache := buildInductionHeadCoreCacheWith cfg inputs + have hmodel' : dModel ≠ 0 := by + exact hmodel + have hcore : + buildInductionCertFromHeadCoreWith? cfg inputs = some cache.cert := by + simpa [cache] using + (buildInductionCertFromHeadCoreWith?_eq_some + (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel' hactive) + exact some ⟨cache, + buildInductionCertFromHeadCoreWith?_sound (cfg := cfg) inputs cache.cert hcore⟩ + · exact none + · exact none + · exact none + /-- Build and certify induction certificates from exact head inputs using the default split budgets. -/ def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index eada082..5ff77bf 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -3,10 +3,12 @@ module public import Aesop +public import Mathlib.Data.List.MinMax public import Mathlib.Data.Vector.Basic public import Nfp.Circuit.Cert.LogitDiff public import Nfp.Sound.Bounds.MatrixNorm.Interval public import Nfp.Sound.Induction.HeadOutput +public import Nfp.Sound.Induction.Refine /-! Logit-diff bounds derived from induction certificates. 
@@ -109,8 +111,14 @@ noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel d def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := let epsAt := Bounds.cacheBoundTask c.epsAt let valsLo := Bounds.cacheBoundTask c.values.valsLo - Circuit.logitDiffLowerBoundAtLo c.active c.prev epsAt - c.values.lo valsLo + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo /-- Lower bound computed from per-key weight bounds in an induction certificate. -/ def logitDiffLowerBoundFromCertWeighted (c : InductionHeadCert seq) : Option Rat := @@ -129,37 +137,355 @@ def logitDiffCache (c : InductionHeadCert seq) : LogitDiffCache seq := { epsAt := Bounds.cacheBoundTask c.epsAt valsLo := Bounds.cacheBoundTask c.values.valsLo } +/-- Unfolding lemma for `logitDiffCache`. -/ +theorem logitDiffCache_def (c : InductionHeadCert seq) : + logitDiffCache c = + { epsAt := Bounds.cacheBoundTask c.epsAt + valsLo := Bounds.cacheBoundTask c.values.valsLo } := by + rfl + /-- Unweighted logit-diff lower bound from a shared cache. 
-/ def logitDiffLowerBoundFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : Option Rat := - Circuit.logitDiffLowerBoundAtLo c.active c.prev cache.epsAt c.values.lo cache.valsLo + let epsArr : Array Rat := Array.ofFn cache.epsAt + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo + +/-- Query attaining the cached unweighted logit-diff lower bound, if any. -/ +def logitDiffLowerBoundArgminFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + Option (Fin seq) := + let epsArr : Array Rat := Array.ofFn cache.epsAt + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let f : Fin seq → Rat := fun q => + let delta := valsLo (c.prev q) - loAt q + valsLo (c.prev q) - epsAt q * max (0 : Rat) delta + let qs := (List.finRange seq).filter (fun q => decide (q ∈ c.active)) + List.argmin f qs + +/-- Unfolding lemma for `logitDiffLowerBoundArgminFromCache`. 
-/ +theorem logitDiffLowerBoundArgminFromCache_def + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + logitDiffLowerBoundArgminFromCache c cache = + let epsArr : Array Rat := Array.ofFn cache.epsAt + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let f : Fin seq → Rat := fun q => + let delta := valsLo (c.prev q) - loAt q + valsLo (c.prev q) - epsAt q * max (0 : Rat) delta + let qs := (List.finRange seq).filter (fun q => decide (q ∈ c.active)) + List.argmin f qs := by + rfl + +/-- Unweighted logit-diff lower bound from a shared cache and custom `epsAt`. -/ +def logitDiffLowerBoundFromCacheWithEps (c : InductionHeadCert seq) (cache : LogitDiffCache seq) + (epsAtCustom : Fin seq → Rat) : Option Rat := + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo + +/-- Unfold `logitDiffLowerBoundFromCache` as the custom-eps variant. 
-/ +theorem logitDiffLowerBoundFromCache_eq_withEps + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + logitDiffLowerBoundFromCache c cache = + logitDiffLowerBoundFromCacheWithEps c cache cache.epsAt := by + rfl + +/-- Unfolding lemma for `logitDiffLowerBoundFromCacheWithEps`. -/ +theorem logitDiffLowerBoundFromCacheWithEps_def + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) + (epsAtCustom : Fin seq → Rat) : + logitDiffLowerBoundFromCacheWithEps c cache epsAtCustom = + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo := by + rfl + +/-- Refined unweighted logit-diff lower bound using an overlayed `epsAt`. -/ +def logitDiffLowerBoundRefinedFromCache + (inputs : Model.InductionHeadInputs seq dModel dHead) + (core : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) + (spec : InductionHeadRefineSpec seq) : Option Rat := + let weightBoundAt := weightBoundAtOverlay inputs core spec + let epsAt := epsAtOverlay core weightBoundAt + logitDiffLowerBoundFromCacheWithEps c cache epsAt + +/-- Unfolding lemma for `logitDiffLowerBoundRefinedFromCache`. 
-/ +theorem logitDiffLowerBoundRefinedFromCache_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (core : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) + (spec : InductionHeadRefineSpec seq) : + logitDiffLowerBoundRefinedFromCache inputs core c cache spec = + let weightBoundAt := weightBoundAtOverlay inputs core spec + let epsAt := epsAtOverlay core weightBoundAt + logitDiffLowerBoundFromCacheWithEps c cache epsAt := by + rfl + +/-- Refine-on-demand unweighted logit-diff bound using a supplied refinement spec. -/ +def logitDiffLowerBoundRefineOnDemandWithSpec + (inputs : Model.InductionHeadInputs seq dModel dHead) + (core : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) + (spec : InductionHeadRefineSpec seq) : Option Rat := + match logitDiffLowerBoundFromCache c cache with + | none => none + | some lb0 => + if lb0 ≤ 0 then + match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with + | some lb1 => some (max lb0 lb1) + | none => some lb0 + else + some lb0 + +/-- Unfolding lemma for `logitDiffLowerBoundRefineOnDemandWithSpec`. -/ +theorem logitDiffLowerBoundRefineOnDemandWithSpec_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (core : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) + (spec : InductionHeadRefineSpec seq) : + logitDiffLowerBoundRefineOnDemandWithSpec inputs core c cache spec = + match logitDiffLowerBoundFromCache c cache with + | none => none + | some lb0 => + if lb0 ≤ 0 then + match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with + | some lb1 => some (max lb0 lb1) + | none => some lb0 + else + some lb0 := by + rfl + +/-- Refine-on-demand unweighted logit-diff bound, refining only the argmin query. 
-/ +def logitDiffLowerBoundRefineOnDemand + (inputs : Model.InductionHeadInputs seq dModel dHead) + (core : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : Option Rat := + match logitDiffLowerBoundFromCache c cache with + | none => none + | some lb0 => + if lb0 ≤ 0 then + match logitDiffLowerBoundArgminFromCache c cache with + | none => some lb0 + | some q0 => + let refineBudget := max 1 core.splitBudgetDiffRefined + let spec := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget + match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with + | some lb1 => some (max lb0 lb1) + | none => some lb0 + else + some lb0 + +/-- Unfolding lemma for `logitDiffLowerBoundRefineOnDemand`. -/ +theorem logitDiffLowerBoundRefineOnDemand_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (core : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + logitDiffLowerBoundRefineOnDemand inputs core c cache = + match logitDiffLowerBoundFromCache c cache with + | none => none + | some lb0 => + if lb0 ≤ 0 then + match logitDiffLowerBoundArgminFromCache c cache with + | none => some lb0 + | some q0 => + let refineBudget := max 1 core.splitBudgetDiffRefined + let spec := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget + match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with + | some lb1 => some (max lb0 lb1) + | none => some lb0 + else + some lb0 := by + rfl /-- Weighted logit-diff lower bound from a shared cache. 
-/ def logitDiffLowerBoundWeightedFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : Option Rat := - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (c.prev q) + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let valsLo : Fin seq → Rat := fun k => + valsLoArr[k.1]'(by + simp [valsLoArr, k.isLt]) + let weightRows : Array (Array Rat) := + Array.ofFn (fun q : Fin seq => Array.ofFn (fun k : Fin seq => c.weightBoundAt q k)) + let weightBoundAt : Fin seq → Fin seq → Rat := fun q k => + let row := weightRows[q.1]'(by + simp [weightRows, q.isLt]) + row[k.1]'(by + have hrow : row.size = seq := by + simp [row, weightRows] + simp [hrow, k.isLt]) let gapBase : Fin seq → Rat := fun q => - (others q).sum (fun k => - let diff := cache.valsLo (c.prev q) - cache.valsLo k - let diffPos := max (0 : Rat) diff - if diffPos = 0 then - 0 - else - c.weightBoundAt q k * diffPos) + let valsLoPrev := valsLo (c.prev q) + Linear.sumFin seq (fun k => + let diff := valsLoPrev - valsLo k + weightBoundAt q k * max (0 : Rat) diff) let gap : Fin seq → Rat := Bounds.cacheBoundTask gapBase if h : c.active.Nonempty then - let f : Fin seq → Rat := fun q => cache.valsLo (c.prev q) - gap q - let img := c.active.image f - have himg : img.Nonempty := h.image f - some (Finset.min' img himg) + let f : Fin seq → Rat := fun q => valsLo (c.prev q) - gap q + some (c.active.inf' h f) else none +/-- Debug payload for the unweighted logit-diff lower bound. -/ +structure LogitDiffAtLoDebug (seq : Nat) where + /-- Query attaining the bound, if found. -/ + q : Fin seq + /-- Previous index for the query. -/ + prev : Fin seq + /-- Per-query eps bound. -/ + eps : Rat + /-- Lower bound for the previous value. -/ + valsPrevLo : Rat + /-- Global lower value bound. -/ + lo : Rat + /-- Per-query lower bound for other values. -/ + loAt : Rat + /-- `valsPrevLo - loAt`. -/ + valsPrevLoMinusLoAt : Rat + /-- `eps * max 0 (valsPrevLo - loAt)`. 
-/ + gap : Rat + /-- Lower bound reported by `logitDiffLowerBoundFromCache`. -/ + lbAtQ : Rat + +/-- Attempt to recover a query that attains the unweighted logit-diff bound. -/ +def logitDiffLowerBoundAtLoDebug (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + Option {d : LogitDiffAtLoDebug seq // + logitDiffLowerBoundFromCache c cache = some d.lbAtQ} := + let epsArr : Array Rat := Array.ofFn cache.epsAt + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let f : Fin seq → Rat := fun q => + let valsPrevLo := valsLo (c.prev q) + let delta := valsPrevLo - loAt q + valsPrevLo - epsAt q * max (0 : Rat) delta + let best? : Option (Fin seq × Rat) := + Linear.foldlFin seq + (fun acc q => + if hq : q ∈ c.active then + let val := f q + match acc with + | none => some (q, val) + | some (qBest, best) => + if val ≤ best then + some (q, val) + else + some (qBest, best) + else + acc) + none + match logitDiffLowerBoundFromCache c cache with + | none => none + | some lb => + match best? with + | none => none + | some (q, _) => + let prev := c.prev q + let valsPrevLo := valsLo prev + let loAtQ := loAt q + let delta := valsPrevLo - loAtQ + let gap := epsAt q * max (0 : Rat) delta + let d : LogitDiffAtLoDebug seq := + { q := q + prev := prev + eps := epsAt q + valsPrevLo := valsPrevLo + lo := c.values.lo + loAt := loAtQ + valsPrevLoMinusLoAt := delta + gap := gap + lbAtQ := lb } + have h' : some lb = some d.lbAtQ := by + simp [d] + some ⟨d, h'⟩ + /-- `logitDiffLowerBoundFromCache` matches the cached default computation. 
-/ theorem logitDiffLowerBoundFromCache_eq (c : InductionHeadCert seq) : logitDiffLowerBoundFromCache c (logitDiffCache c) = logitDiffLowerBoundFromCert c := by - rfl + classical + unfold logitDiffLowerBoundFromCache logitDiffLowerBoundFromCert logitDiffCache + have heps : Bounds.cacheBoundTask c.epsAt = c.epsAt := by + funext k + simp [Bounds.cacheBoundTask_apply] + have hvals : Bounds.cacheBoundTask c.values.valsLo = c.values.valsLo := by + funext k + simp [Bounds.cacheBoundTask_apply] + simp [heps, hvals, Bounds.cacheBoundTask_apply] /-- `logitDiffLowerBoundWeightedFromCache` matches the cached default computation. -/ theorem logitDiffLowerBoundWeightedFromCache_eq (c : InductionHeadCert seq) : @@ -170,7 +496,8 @@ theorem logitDiffLowerBoundWeightedFromCache_eq (c : InductionHeadCert seq) : have hvals : Bounds.cacheBoundTask c.values.valsLo = c.values.valsLo := by funext k simp [Bounds.cacheBoundTask_apply] - simp [hvals, Bounds.cacheBoundTask_apply, logitDiffLowerBoundWeightedAt_def] + simp [hvals, Bounds.cacheBoundTask_apply, logitDiffLowerBoundWeightedAt_def, + Linear.sumFin_eq_sum_univ] /-- Best available logit-diff lower bound from an induction certificate. 
-/ def logitDiffLowerBoundFromCertBest (c : InductionHeadCert seq) : Option Rat := @@ -200,30 +527,40 @@ theorem logitDiffLowerBoundFromCert_le let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs let epsAt := Bounds.cacheBoundTask c.epsAt let valsLo := Bounds.cacheBoundTask c.values.valsLo + let loAt : Fin (Nat.succ n) → Rat := fun q => + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo let others : Finset (Fin (Nat.succ n)) := (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) let sumOthers : Real := ∑ k ∈ others, weights q k let valsLoPrev : Real := (c.values.valsLo (c.prev q) : Real) - let lo : Real := (c.values.lo : Real) + let loAtRat : Rat := loAt q + let loAtReal : Real := (loAtRat : Real) have hboundRat : lb ≤ valsLo (c.prev q) - - epsAt q * (valsLo (c.prev q) - c.values.lo) := by + epsAt q * max (0 : Rat) (valsLo (c.prev q) - loAt q) := by refine - Circuit.logitDiffLowerBoundAtLo_le + Circuit.logitDiffLowerBoundAtLoAt_le (active := c.active) (prev := c.prev) (epsAt := epsAt) - (lo := c.values.lo) + (loAt := loAt) (valsLo := valsLo) q hq lb ?_ - simpa [logitDiffLowerBoundFromCert] using hbound + simpa [logitDiffLowerBoundFromCert, loAt] using hbound have hboundRat' : lb ≤ c.values.valsLo (c.prev q) - - c.epsAt q * (c.values.valsLo (c.prev q) - c.values.lo) := by + c.epsAt q * max (0 : Rat) (c.values.valsLo (c.prev q) - loAt q) := by simpa [epsAt, valsLo, Bounds.cacheBoundTask_apply] using hboundRat have hboundReal : - (lb : Real) ≤ valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) := by - simpa [ratToReal_sub, ratToReal_mul, ratToReal_def] using + (lb : Real) ≤ + valsLoPrev - (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + simpa [loAtRat, loAtReal, ratToReal_sub, ratToReal_mul, ratToReal_max, + ratToReal_def] using ratToReal_le_of_le hboundRat' have hweights_nonneg : ∀ k, 0 ≤ weights q k := 
hsound.softmax_bounds.nonneg q hq @@ -244,31 +581,275 @@ theorem logitDiffLowerBoundFromCert_le weights q (c.prev q) + (c.epsAt q : Real) := by simpa [hsum, sumOthers] using hprev exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' - have hvals_lo : ∀ k, lo ≤ vals k := by + have hloAt_le_valsLo : ∀ k ∈ others, loAtRat ≤ valsLo k := by + intro k hk + have hnonempty : others.Nonempty := ⟨k, hk⟩ + have hmin : others.inf' hnonempty valsLo ≤ valsLo k := + Finset.inf'_le (s := others) (f := valsLo) hk + have hnonempty' : (Finset.univ.erase (c.prev q)).Nonempty := by + simpa [others] using hnonempty + have hloAt : loAtRat = others.inf' hnonempty valsLo := by + dsimp [loAtRat, loAt] + simp [hnonempty', others] + calc + loAtRat = others.inf' hnonempty valsLo := hloAt + _ ≤ valsLo k := hmin + have hvals_lo : ∀ k ∈ others, loAtReal ≤ vals k := by + intro k hk + have hloRat := hloAt_le_valsLo k hk + have hloReal : loAtReal ≤ (valsLo k : Real) := by + simpa [loAtReal, ratToReal_def] using (ratToReal_le_of_le hloRat) + have hvals : (valsLo k : Real) ≤ vals k := by + simpa [valsLo, Bounds.cacheBoundTask_apply] using + (hsound.value_bounds.vals_bounds k).1 + exact le_trans hloReal hvals + have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by + exact (hsound.value_bounds.vals_bounds (c.prev q)).1 + have hsum_vals_ge : + sumOthers * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by + have hsum_lo : + sumOthers * loAtReal = ∑ k ∈ others, weights q k * loAtReal := by + have hsum_lo' : + (∑ k ∈ others, weights q k) * loAtReal = + ∑ k ∈ others, weights q k * loAtReal := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := loAtReal)) + simpa [sumOthers] using hsum_lo' + have hle : + ∀ k ∈ others, weights q k * loAtReal ≤ weights q k * vals k := by + intro k _hk + have hval := hvals_lo k _hk + have hnonneg := hweights_nonneg k + exact mul_le_mul_of_nonneg_left hval hnonneg + have hsum' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, 
weights q k * vals k := by + exact Finset.sum_le_sum hle + simpa [hsum_lo] using hsum' + have hsum_prod : + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k = + ∑ k, weights q k * vals k := by + simp [others] + have hout_eq : + dotProduct (weights q) vals = + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [dotProduct] using hsum_prod.symm + have hdot_ge : + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + have hle : + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_vals_ge (weights q (c.prev q) * vals (c.prev q))) + simpa [sumOthers, hout_eq, add_comm, add_left_comm, add_assoc] using hle + have hprev_lo : + weights q (c.prev q) * valsLoPrev ≤ + weights q (c.prev q) * vals (c.prev q) := by + exact mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) + have hdot_ge' : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + have hle : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_right hprev_lo (sumOthers * loAtReal)) + exact hle.trans hdot_ge + have hsplit : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + have hsplit' : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = + (weights q (c.prev q) + sumOthers) * valsLoPrev - + sumOthers * (valsLoPrev - loAtReal) := by + ring + calc + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = + (weights q (c.prev q) + sumOthers) * valsLoPrev - + sumOthers * (valsLoPrev - loAtReal) := hsplit' + _ = valsLoPrev - sumOthers * (valsLoPrev - 
loAtReal) := by + simp [hsum, sumOthers] + have hdiff_le : valsLoPrev - loAtReal ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_right _ _ + have hsum_nonneg : 0 ≤ sumOthers := by + have hnonneg : ∀ k ∈ others, 0 ≤ weights q k := by + intro k _hk + exact hweights_nonneg k + have hsum_nonneg' : 0 ≤ ∑ k ∈ others, weights q k := by + exact Finset.sum_nonneg hnonneg + simpa [sumOthers] using hsum_nonneg' + have hsum_mul_le_left : + sumOthers * (valsLoPrev - loAtReal) ≤ + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_left hdiff_le hsum_nonneg + have hmax_nonneg : 0 ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_left _ _ + have hsum_mul_le : + sumOthers * (valsLoPrev - loAtReal) ≤ + (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + have hsum_mul_le_right : + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) ≤ + (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hmax_nonneg + exact le_trans hsum_mul_le_left hsum_mul_le_right + have hsub_le : + valsLoPrev - (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + exact sub_le_sub_left hsum_mul_le valsLoPrev + have hdot_lower : + valsLoPrev - (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := hsub_le + _ = weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal := by + simp [hsplit] + _ ≤ dotProduct (weights q) vals := hdot_ge' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal hdot_lower + simpa [headLogitDiff, weights, vals] using hle + +/-- The unweighted logit-diff lower bound is sound for any valid per-query `epsAt`. 
-/ +theorem logitDiffLowerBoundFromCacheWithEps_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (epsAtCustom : Fin seq → Rat) + (hsound : InductionHeadCertSound inputs c) + (honeHot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAtCustom q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k)) + {lb : Rat} + (hbound : + logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) epsAtCustom = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn (logitDiffCache c).valsLo + let epsAt : Fin (Nat.succ n) → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin (Nat.succ n) → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin (Nat.succ n) → Rat := fun q => + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let sumOthers : Real := ∑ k ∈ others, weights q k + let valsLoPrev : Real := (c.values.valsLo (c.prev q) : Real) + let loAtRat : Rat := loAt q + let loAtReal : Real := (loAtRat : Real) + have hboundRat : + lb ≤ valsLo (c.prev q) - + epsAt q * max (0 : Rat) (valsLo (c.prev q) - loAt q) := by + refine + Circuit.logitDiffLowerBoundAtLoAt_le + (active := c.active) + (prev := c.prev) + (epsAt := epsAt) + (loAt := loAt) + (valsLo := valsLo) + q hq lb ?_ + simpa 
[logitDiffLowerBoundFromCacheWithEps, loAt, epsAt, valsLo, valsLoArr, epsArr, + logitDiffCache] using hbound + have hepsAt : epsAt q = epsAtCustom q := by + simp [epsAt, epsArr] + have hvalsLo : ∀ k, valsLo k = c.values.valsLo k := by intro k - have hlo := hsound.value_bounds.lo_le_valsLo k - have hvals := (hsound.value_bounds.vals_bounds k).1 - exact le_trans hlo hvals + simp [valsLo, valsLoArr, logitDiffCache, Bounds.cacheBoundTask_apply] + have hboundRat' : + lb ≤ c.values.valsLo (c.prev q) - + epsAtCustom q * max (0 : Rat) (c.values.valsLo (c.prev q) - loAt q) := by + simpa [hepsAt, hvalsLo] using hboundRat + have hboundReal : + (lb : Real) ≤ + valsLoPrev - (epsAtCustom q : Real) * + max (0 : Real) (valsLoPrev - loAtReal) := by + simpa [loAtRat, loAtReal, ratToReal_sub, ratToReal_mul, ratToReal_max, ratToReal_def] + using ratToReal_le_of_le hboundRat' + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + have hweights := honeHot q hq + simpa [weights] using hweights.nonneg q rfl + have hweights := honeHot q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using hweights.sum_one q rfl + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hsum_one + have hsum_others_le : sumOthers ≤ (epsAtCustom q : Real) := by + have hprev : 1 ≤ weights q (c.prev q) + (epsAtCustom q : Real) := + hweights.prev_large q rfl + have hprev' : + weights q (c.prev q) + sumOthers ≤ + weights q (c.prev q) + (epsAtCustom q : Real) := by + simpa [hsum, sumOthers] using hprev + exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' + have hloAt_le_valsLo : ∀ k ∈ others, loAtRat ≤ c.values.valsLo k := by + intro k hk + have hnonempty : others.Nonempty := ⟨k, hk⟩ + have hmin : others.inf' hnonempty valsLo ≤ valsLo k := + Finset.inf'_le 
(s := others) (f := valsLo) hk + have hnonempty' : (Finset.univ.erase (c.prev q)).Nonempty := by + simpa [others] using hnonempty + have hloAt : loAtRat = others.inf' hnonempty valsLo := by + dsimp [loAtRat, loAt] + simp [hnonempty', others] + have hvalsLo' : valsLo k = c.values.valsLo k := hvalsLo k + calc + loAtRat = others.inf' hnonempty valsLo := hloAt + _ ≤ valsLo k := hmin + _ = c.values.valsLo k := hvalsLo' + have hvals_lo : ∀ k ∈ others, loAtReal ≤ vals k := by + intro k hk + have hloRat := hloAt_le_valsLo k hk + have hloReal : loAtReal ≤ (c.values.valsLo k : Real) := by + simpa [loAtReal, ratToReal_def] using (ratToReal_le_of_le hloRat) + have hvals : (c.values.valsLo k : Real) ≤ vals k := by + simpa using (hsound.value_bounds.vals_bounds k).1 + exact le_trans hloReal hvals have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by exact (hsound.value_bounds.vals_bounds (c.prev q)).1 have hsum_vals_ge : - sumOthers * lo ≤ ∑ k ∈ others, weights q k * vals k := by + sumOthers * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by have hsum_lo : - sumOthers * lo = ∑ k ∈ others, weights q k * lo := by + sumOthers * loAtReal = ∑ k ∈ others, weights q k * loAtReal := by have hsum_lo' : - (∑ k ∈ others, weights q k) * lo = - ∑ k ∈ others, weights q k * lo := by + (∑ k ∈ others, weights q k) * loAtReal = + ∑ k ∈ others, weights q k * loAtReal := by simpa using - (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := lo)) + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := loAtReal)) simpa [sumOthers] using hsum_lo' have hle : - ∀ k ∈ others, weights q k * lo ≤ weights q k * vals k := by + ∀ k ∈ others, weights q k * loAtReal ≤ weights q k * vals k := by intro k _hk - have hval := hvals_lo k + have hval := hvals_lo k _hk have hnonneg := hweights_nonneg k exact mul_le_mul_of_nonneg_left hval hnonneg have hsum' : - ∑ k ∈ others, weights q k * lo ≤ + ∑ k ∈ others, weights q k * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by exact 
Finset.sum_le_sum hle simpa [hsum_lo] using hsum' @@ -283,10 +864,10 @@ theorem logitDiffLowerBoundFromCert_le ∑ k ∈ others, weights q k * vals k := by simpa [dotProduct] using hsum_prod.symm have hdot_ge : - weights q (c.prev q) * vals (c.prev q) + sumOthers * lo ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ dotProduct (weights q) vals := by have hle : - weights q (c.prev q) * vals (c.prev q) + sumOthers * lo ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ weights q (c.prev q) * vals (c.prev q) + ∑ k ∈ others, weights q k * vals k := by simpa [add_comm, add_left_comm, add_assoc] using @@ -297,45 +878,62 @@ theorem logitDiffLowerBoundFromCert_le weights q (c.prev q) * vals (c.prev q) := by exact mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) have hdot_ge' : - weights q (c.prev q) * valsLoPrev + sumOthers * lo ≤ + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ dotProduct (weights q) vals := by have hle : - weights q (c.prev q) * valsLoPrev + sumOthers * lo ≤ - weights q (c.prev q) * vals (c.prev q) + sumOthers * lo := by + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal := by simpa [add_comm, add_left_comm, add_assoc] using - (add_le_add_right hprev_lo (sumOthers * lo)) + (add_le_add_right hprev_lo (sumOthers * loAtReal)) exact hle.trans hdot_ge have hsplit : - weights q (c.prev q) * valsLoPrev + sumOthers * lo = - valsLoPrev - sumOthers * (valsLoPrev - lo) := by + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by have hsplit' : - weights q (c.prev q) * valsLoPrev + sumOthers * lo = + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = (weights q (c.prev q) + sumOthers) * valsLoPrev - - sumOthers * (valsLoPrev - lo) := by + sumOthers * (valsLoPrev - loAtReal) := by ring calc - weights q (c.prev q) * valsLoPrev + sumOthers * lo = + weights q 
(c.prev q) * valsLoPrev + sumOthers * loAtReal = (weights q (c.prev q) + sumOthers) * valsLoPrev - - sumOthers * (valsLoPrev - lo) := hsplit' - _ = valsLoPrev - sumOthers * (valsLoPrev - lo) := by + sumOthers * (valsLoPrev - loAtReal) := hsplit' + _ = valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by simp [hsum, sumOthers] - have hdiff_nonneg : 0 ≤ valsLoPrev - lo := by - exact sub_nonneg.mpr (hsound.value_bounds.lo_le_valsLo (c.prev q)) + have hdiff_le : valsLoPrev - loAtReal ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_right _ _ + have hsum_nonneg : 0 ≤ sumOthers := by + have hnonneg : ∀ k ∈ others, 0 ≤ weights q k := by + intro k _hk + exact hweights_nonneg k + have hsum_nonneg' : 0 ≤ ∑ k ∈ others, weights q k := by + exact Finset.sum_nonneg hnonneg + simpa [sumOthers] using hsum_nonneg' + have hsum_mul_le_left : + sumOthers * (valsLoPrev - loAtReal) ≤ + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_left hdiff_le hsum_nonneg + have hmax_nonneg : 0 ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_left _ _ have hsum_mul_le : - sumOthers * (valsLoPrev - lo) ≤ - (c.epsAt q : Real) * (valsLoPrev - lo) := by - exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + sumOthers * (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + have hsum_mul_le_right : + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hmax_nonneg + exact le_trans hsum_mul_le_left hsum_mul_le_right have hsub_le : - valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) ≤ - valsLoPrev - sumOthers * (valsLoPrev - lo) := by + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by exact sub_le_sub_left hsum_mul_le valsLoPrev have hdot_lower : - valsLoPrev - (c.epsAt q : Real) * 
(valsLoPrev - lo) ≤ + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ dotProduct (weights q) vals := by calc - valsLoPrev - (c.epsAt q : Real) * (valsLoPrev - lo) ≤ - valsLoPrev - sumOthers * (valsLoPrev - lo) := hsub_le - _ = weights q (c.prev q) * valsLoPrev + sumOthers * lo := by + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := hsub_le + _ = weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal := by simp [hsplit] _ ≤ dotProduct (weights q) vals := hdot_ge' have hle : (lb : Real) ≤ dotProduct (weights q) vals := @@ -364,8 +962,8 @@ theorem logitDiffLowerBoundFromCertWeighted_le let valsLoPrev : Real := (valsLoPrevRat : Real) have hboundRat : lb ≤ valsLoPrevRat - - (others.sum (fun k => - c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - valsLoCached k))) := by + (Finset.univ : Finset (Fin (Nat.succ n))).sum (fun k => + c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - valsLoCached k)) := by refine Circuit.logitDiffLowerBoundWeightedAt_le (active := c.active) @@ -376,17 +974,36 @@ theorem logitDiffLowerBoundFromCertWeighted_le simpa [logitDiffLowerBoundFromCertWeighted] using hbound have hboundRat' : lb ≤ valsLoPrevRat - - (others.sum (fun k => - c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - c.values.valsLo k))) := by + (Finset.univ : Finset (Fin (Nat.succ n))).sum (fun k => + c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - c.values.valsLo k)) := by simpa [valsLoCached, valsLoPrevRat, Bounds.cacheBoundTask_apply] using hboundRat have hboundReal : (lb : Real) ≤ valsLoPrev - - (others.sum (fun k => + (Finset.univ : Finset (Fin (Nat.succ n))).sum (fun k => (c.weightBoundAt q k : Real) * - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)))) := by + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) := by simpa [valsLoPrevRat, valsLoPrev, ratToReal_sub, ratToReal_mul, ratToReal_max, ratToReal_def, Rat.cast_sum] using 
ratToReal_le_of_le hboundRat' + let gapTerm : Fin (Nat.succ n) → Real := fun k => + (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) + have hgap_prev : gapTerm (c.prev q) = 0 := by + have hdiff : valsLoPrev - (c.values.valsLo (c.prev q) : Real) = 0 := by + simp [valsLoPrev, valsLoPrevRat, valsLoCached, Bounds.cacheBoundTask_apply] + simp [gapTerm, hdiff] + have hsum_gap : + (Finset.univ : Finset (Fin (Nat.succ n))).sum gapTerm = + ∑ k ∈ others, gapTerm k := by + classical + have hsum := + (Finset.sum_erase (s := (Finset.univ : Finset (Fin (Nat.succ n)))) + (f := gapTerm) (a := c.prev q) hgap_prev) + simpa [others] using hsum.symm + have hboundReal' : + (lb : Real) ≤ + valsLoPrev - ∑ k ∈ others, gapTerm k := by + simpa [gapTerm, hsum_gap] using hboundReal have hweights_nonneg : ∀ k, 0 ≤ weights q k := hsound.softmax_bounds.nonneg q hq have hweights := hsound.oneHot_bounds_at q hq @@ -562,7 +1179,7 @@ theorem logitDiffLowerBoundFromCertWeighted_le simpa using hsplit.symm _ ≤ dotProduct (weights q) vals := hdot_ge' have hle : (lb : Real) ≤ dotProduct (weights q) vals := - le_trans hboundReal hdot_lower + le_trans hboundReal' hdot_lower simpa [headLogitDiff, weights, vals] using hle /-- The best available logit-diff lower bound is sound on active queries. -/ diff --git a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean index 38ca93b..71d9ca9 100644 --- a/Nfp/Sound/Induction/OneHot.lean +++ b/Nfp/Sound/Induction/OneHot.lean @@ -315,6 +315,120 @@ theorem oneHot_bounds_at_of_scoreGapLo simpa using h exact hle.trans hsum_others_le +/-- One-hot bounds on a single active query, derived from per-key weight bounds. 
-/ +theorem oneHot_bounds_at_of_weight_bounds + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scoresReal : Fin seq → Fin seq → Real) + (weightBoundAt : Fin seq → Fin seq → Rat) + (epsAt : Fin seq → Rat) + (hepsAt : + ∀ q, epsAt q = + min (1 : Rat) + ((Finset.univ : Finset (Fin seq)).erase (prev q) |>.sum (fun k => + weightBoundAt q k))) + (hweight_bounds : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + Circuit.softmax (scoresReal q) k ≤ (weightBoundAt q k : Real)) : + ∀ q, q ∈ active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) prev + (fun q k => Circuit.softmax (scoresReal q) k) := by + classical + intro q hq + let softmaxWeights := Circuit.softmaxWeights scoresReal + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + intro k + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.nonneg q k + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.sum_one q + have hsum_others_le_one : (∑ k ∈ others q, weights q k) ≤ 1 := by + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k _ _ + exact hweights_nonneg k) + simpa [hsum_one] using hsum_le + have hbound : + ∀ k ∈ others q, weights q k ≤ (weightBoundAt q k : Real) := by + intro k hk + have hkne : k ≠ prev q := (Finset.mem_erase.mp hk).1 + exact hweight_bounds q hq k hkne + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (weightBoundAt q k : Real) := + 
Finset.sum_le_sum hbound + have hsum_le_min : + (∑ k ∈ others q, weights q k) ≤ + min (1 : Real) (∑ k ∈ others q, (weightBoundAt q k : Real)) := by + exact le_min hsum_others_le_one hsum_le + have hepsAtReal : + (epsAt q : Real) = min (1 : Real) (∑ k ∈ others q, (weightBoundAt q k : Real)) := by + have h' : epsAt q = min 1 ((others q).sum (fun k => weightBoundAt q k)) := by + simpa [others] using hepsAt q + have h'' : + ratToReal (epsAt q) = + ratToReal (min 1 ((others q).sum (fun k => weightBoundAt q k))) := by + exact congrArg ratToReal h' + simpa [ratToReal_min, ratToReal_def, Rat.cast_sum] using h'' + simpa [hepsAtReal] using hsum_le_min + refine + { nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q' hq' k + subst q' + exact hweights_nonneg k + · intro q' hq' + subst q' + exact hsum_one + · intro q' hq' + subst q' + have hsum_eq : + weights q (prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + calc + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (prev q) + (epsAt q : Real) := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_others_le (weights q (prev q))) + have hprev : + 1 ≤ weights q (prev q) + (epsAt q : Real) := by + simpa [hsum_eq] using hsum_le' + exact hprev + · intro q' hq' k hk + subst q' + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ j ∈ others q, 0 ≤ weights q j := by + intro j _ + exact hweights_nonneg j + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + /-- Per-key weight bounds on a single active query, derived from per-key score gaps. 
-/ theorem weight_bound_at_of_scoreGapLo (active : Finset (Fin seq)) diff --git a/Nfp/Sound/Induction/Refine.lean b/Nfp/Sound/Induction/Refine.lean new file mode 100644 index 0000000..ed57a35 --- /dev/null +++ b/Nfp/Sound/Induction/Refine.lean @@ -0,0 +1,331 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Sound.Induction.Core + +/-! +Refine-on-demand helpers for induction-head bounds. + +These definitions reuse cached core bounds to compute tightened score gaps and +weight bounds for selected query/key pairs without rebuilding the full cache. +-/ + +public section + +namespace Nfp + +namespace Sound + +variable {seq dModel dHead : Nat} + +/-- Specification for refining per-key bounds. -/ +structure InductionHeadRefineSpec (seq : Nat) where + /-- Keys to refine for each query. -/ + refineKeys : Fin seq → Finset (Fin seq) + /-- Split budget for refined diff bounds. -/ + splitBudgetDiffRefined : Nat + +/-- Worst key under the base score-gap lower bound (excluding `prev`). -/ +def worstKeyBase + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : Option (Fin seq) := + let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := cache.scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (cache.scoreGapLoBase q k, k)).2 + +/-- Unfolding lemma for `worstKeyBase`. 
-/ +theorem worstKeyBase_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : + worstKeyBase inputs cache q = + let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + match ks with + | [] => none + | k :: ks => + let step (best : Rat × Fin seq) (k : Fin seq) := + let s := cache.scoreGapLoBase q k + if s ≤ best.1 then (s, k) else best + some (ks.foldl step (cache.scoreGapLoBase q k, k)).2 := by + rfl + +/-- Keys whose base weight bounds are already `1`. -/ +def weightOneKeysAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : Finset (Fin seq) := + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + others.filter (fun k => decide (cache.weightBoundAt q k = (1 : Rat))) + +/-- Unfolding lemma for `weightOneKeysAt`. -/ +theorem weightOneKeysAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : + weightOneKeysAt inputs cache q = + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + others.filter (fun k => decide (cache.weightBoundAt q k = (1 : Rat))) := by + rfl + +/-- Refinement keys for a query, seeded by negative base gaps and the worst key. -/ +def refineKeysAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : Finset (Fin seq) := + let neg := + (cache.otherKeys q).filter (fun k => decide (cache.scoreGapLoBase q k < 0)) + match worstKeyBase inputs cache q with + | none => neg + | some k => insert k neg + +/-- Unfolding lemma for `refineKeysAt`. 
-/ +theorem refineKeysAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : + refineKeysAt inputs cache q = + let neg := + (cache.otherKeys q).filter (fun k => decide (cache.scoreGapLoBase q k < 0)) + match worstKeyBase inputs cache q with + | none => neg + | some k => insert k neg := by + rfl + +/-- Refinement keys that also include weight-one keys. -/ +def refineKeysAtWithWeightOnes + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : Finset (Fin seq) := + refineKeysAt inputs cache q ∪ weightOneKeysAt inputs cache q + +/-- Unfolding lemma for `refineKeysAtWithWeightOnes`. -/ +theorem refineKeysAtWithWeightOnes_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : + refineKeysAtWithWeightOnes inputs cache q = + refineKeysAt inputs cache q ∪ weightOneKeysAt inputs cache q := by + rfl + +/-- Refinement spec focused on a single query. -/ +def refineSpecForQuery + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) (budget : Nat) : InductionHeadRefineSpec seq := + let keys := refineKeysAt inputs cache q + { refineKeys := fun q' => if _ : q' = q then keys else ∅ + splitBudgetDiffRefined := budget } + +/-- Unfolding lemma for `refineSpecForQuery`. -/ +theorem refineSpecForQuery_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) (budget : Nat) : + refineSpecForQuery inputs cache q budget = + let keys := refineKeysAt inputs cache q + { refineKeys := fun q' => if _ : q' = q then keys else ∅ + splitBudgetDiffRefined := budget } := by + rfl + +/-- Refinement spec for a single query, including weight-one keys. 
-/ +def refineSpecForQueryWithWeightOnes + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) (budget : Nat) : InductionHeadRefineSpec seq := + let keys := refineKeysAtWithWeightOnes inputs cache q + { refineKeys := fun q' => if _ : q' = q then keys else ∅ + splitBudgetDiffRefined := budget } + +/-- Unfolding lemma for `refineSpecForQueryWithWeightOnes`. -/ +theorem refineSpecForQueryWithWeightOnes_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) (budget : Nat) : + refineSpecForQueryWithWeightOnes inputs cache q budget = + let keys := refineKeysAtWithWeightOnes inputs cache q + { refineKeys := fun q' => if _ : q' = q then keys else ∅ + splitBudgetDiffRefined := budget } := by + rfl + +/-- Refined diff dot-product lower bound at a single `(q,k)` pair. -/ +def dotDiffLoRefinedAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (q k : Fin seq) : Rat := + let dimsQ := cache.splitDimsQ q + let dimsDiff := cache.splitDimsDiffCore budget q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => cache.qLo q d) (fun d => cache.qHi q d) + (fun d => cache.kLo prev d - cache.kHi k d) + (fun d => cache.kHi prev d - cache.kLo k d)).1 + +/-- Unfolding lemma for `dotDiffLoRefinedAt`. 
-/ +theorem dotDiffLoRefinedAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (q k : Fin seq) : + dotDiffLoRefinedAt inputs cache budget q k = + let dimsQ := cache.splitDimsQ q + let dimsDiff := cache.splitDimsDiffCore budget q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => cache.qLo q d) (fun d => cache.qHi q d) + (fun d => cache.kLo prev d - cache.kHi k d) + (fun d => cache.kHi prev d - cache.kLo k d)).1 := by + rfl + +/-- Refined diff dot-product upper bound at a single `(q,k)` pair. -/ +def dotDiffHiRefinedAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (q k : Fin seq) : Rat := + let dimsQ := cache.splitDimsQ q + let dimsDiff := cache.splitDimsDiffCore budget q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => cache.qLo q d) (fun d => cache.qHi q d) + (fun d => cache.kLo prev d - cache.kHi k d) + (fun d => cache.kHi prev d - cache.kLo k d)).2 + +/-- Unfolding lemma for `dotDiffHiRefinedAt`. -/ +theorem dotDiffHiRefinedAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (q k : Fin seq) : + dotDiffHiRefinedAt inputs cache budget q k = + let dimsQ := cache.splitDimsQ q + let dimsDiff := cache.splitDimsDiffCore budget q k + let prev := inputs.prev q + (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff + (fun d => cache.qLo q d) (fun d => cache.qHi q d) + (fun d => cache.kLo prev d - cache.kHi k d) + (fun d => cache.kHi prev d - cache.kLo k d)).2 := by + rfl + +/-- Refined score-gap lower bound at `(q,k)` using a custom diff budget. 
-/ +def scoreGapLoRefinedAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (q k : Fin seq) : Rat := + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + if masked q (inputs.prev q) then + cache.scoreLoPrev q - cache.scoreHi q k + else if masked q k then + cache.scoreLoPrev q - inputs.maskValue + else if _ : 0 ≤ inputs.scale then + inputs.scale * dotDiffLoRefinedAt inputs cache budget q k + else + inputs.scale * dotDiffHiRefinedAt inputs cache budget q k + +/-- Unfolding lemma for `scoreGapLoRefinedAt`. -/ +theorem scoreGapLoRefinedAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (q k : Fin seq) : + scoreGapLoRefinedAt inputs cache budget q k = + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + if masked q (inputs.prev q) then + cache.scoreLoPrev q - cache.scoreHi q k + else if masked q k then + cache.scoreLoPrev q - inputs.maskValue + else if _ : 0 ≤ inputs.scale then + inputs.scale * dotDiffLoRefinedAt inputs cache budget q k + else + inputs.scale * dotDiffHiRefinedAt inputs cache budget q k := by + rfl + +/-- Refined per-key weight bound at `(q,k)` derived from refined score gaps. -/ +def weightBoundAtRefinedAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (q k : Fin seq) : Rat := + if _ : k = inputs.prev q then + (0 : Rat) + else + let gap := scoreGapLoRefinedAt inputs cache budget q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap) + +/-- Unfolding lemma for `weightBoundAtRefinedAt`. 
-/ +theorem weightBoundAtRefinedAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (q k : Fin seq) : + weightBoundAtRefinedAt inputs cache budget q k = + if _ : k = inputs.prev q then + (0 : Rat) + else + let gap := scoreGapLoRefinedAt inputs cache budget q k + if gap < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + gap) := by + rfl + +/-- Overlay that refines only selected `(q,k)` weight bounds. -/ +def weightBoundAtOverlay + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (spec : InductionHeadRefineSpec seq) : + Fin seq → Fin seq → Rat := fun q k => + if _ : k = inputs.prev q then + (0 : Rat) + else if _ : k ∈ spec.refineKeys q then + weightBoundAtRefinedAt inputs cache spec.splitBudgetDiffRefined q k + else + cache.weightBoundAt q k + +/-- Unfolding lemma for `weightBoundAtOverlay`. -/ +theorem weightBoundAtOverlay_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (spec : InductionHeadRefineSpec seq) + (q k : Fin seq) : + weightBoundAtOverlay inputs cache spec q k = + if _ : k = inputs.prev q then + (0 : Rat) + else if _ : k ∈ spec.refineKeys q then + weightBoundAtRefinedAt inputs cache spec.splitBudgetDiffRefined q k + else + cache.weightBoundAt q k := by + rfl + +/-- Overlayed eps bound derived from overlayed per-key bounds. -/ +def epsAtOverlay + (cache : InductionHeadCoreCache seq dModel dHead) + (weightBoundAt : Fin seq → Fin seq → Rat) : + Fin seq → Rat := fun q => + let other : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (cache.cert.prev q) + let total := other.sum (fun k => weightBoundAt q k) + min (1 : Rat) total + +/-- Unfolding lemma for `epsAtOverlay`. 
-/ +theorem epsAtOverlay_def + (cache : InductionHeadCoreCache seq dModel dHead) + (weightBoundAt : Fin seq → Fin seq → Rat) + (q : Fin seq) : + epsAtOverlay cache weightBoundAt q = + let other : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (cache.cert.prev q) + let total := other.sum (fun k => weightBoundAt q k) + min (1 : Rat) total := by + rfl + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/RefineSound.lean b/Nfp/Sound/Induction/RefineSound.lean new file mode 100644 index 0000000..7b25c4e --- /dev/null +++ b/Nfp/Sound/Induction/RefineSound.lean @@ -0,0 +1,596 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Sound.Induction.LogitDiff +public import Nfp.Sound.Induction.OneHot +public import Nfp.Sound.Induction.Refine + +/-! +Soundness lemmas for refine-on-demand overlays. +-/ + +public section + +namespace Nfp + +namespace Sound + +open Nfp.Circuit +open Nfp.Sound.Bounds + +variable {seq dModel dHead : Nat} + +/-- Refined score-gap bounds are sound when cache score and KV bounds are sound. 
-/ +theorem scoreGapLoRefinedAt_real_at_of_bounds + [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) + (active : Finset (Fin seq)) + (hq_bounds : + ∀ q d, (cache.qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ + qRealOfInputs inputs q d ≤ (cache.qHi q d : Real)) + (hk_bounds : + ∀ q d, (cache.kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ + kRealOfInputs inputs q d ≤ (cache.kHi q d : Real)) + (hscore_prev : + ∀ q, q ∈ active → + (cache.scoreLoPrev q : Real) ≤ scoresRealOfInputs inputs q (inputs.prev q)) + (hscore_hi : + ∀ q k, scoresRealOfInputs inputs q k ≤ (cache.scoreHi q k : Real)) : + ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → + scoresRealOfInputs inputs q k + + (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ + scoresRealOfInputs inputs q (inputs.prev q) := by + classical + let scoresReal := scoresRealOfInputs inputs + let masked : Fin seq → Fin seq → Prop := fun q k => + inputs.maskCausal = true ∧ q < k + have scoresReal_eq_base_of_not_masked : + ∀ q k, ¬ masked q k → + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + intro q k hnot + by_cases hcausal : inputs.maskCausal + · have hnot_lt : ¬ q < k := by + intro hlt + exact hnot ⟨hcausal, hlt⟩ + have hle : k ≤ q := le_of_not_gt hnot_lt + simp [scoresReal, scoresRealOfInputs_def, hcausal, hle] + · simp [scoresReal, scoresRealOfInputs_def, hcausal] + have scoresReal_eq_masked : + ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by + intro q k hmask + have hmask' : inputs.maskCausal = true ∧ q < k := by + simpa [masked] using hmask + have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 + simp [scoresReal, scoresRealOfInputs_def, hmask'.1, hle] + have hdot_diff_bounds : + ∀ q k, + (dotDiffLoRefinedAt inputs cache budget q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d 
- + kRealOfInputs inputs k d) ∧ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ + (dotDiffHiRefinedAt inputs cache budget q k : Real) := by + intro q k + have hlo1 : ∀ d, (cache.qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => + (hq_bounds q d).1 + have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (cache.qHi q d : Real) := fun d => + (hq_bounds q d).2 + have hlo2 : + ∀ d, + (cache.kLo (inputs.prev q) d - cache.kHi k d : Rat) ≤ + (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by + intro d + have hprev_lo := (hk_bounds (inputs.prev q) d).1 + have hk_hi := (hk_bounds k d).2 + have h := sub_le_sub hprev_lo hk_hi + simpa [ratToReal_sub] using h + have hhi2 : + ∀ d, + (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ + (cache.kHi (inputs.prev q) d - cache.kLo k d : Rat) := by + intro d + have hprev_hi := (hk_bounds (inputs.prev q) d).2 + have hk_lo := (hk_bounds k d).1 + have h := sub_le_sub hprev_hi hk_lo + simpa [ratToReal_sub] using h + have hspec := + _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real + (dims1 := cache.splitDimsQ q) (dims2 := cache.splitDimsDiffCore budget q k) + (lo1 := fun d => cache.qLo q d) (hi1 := fun d => cache.qHi q d) + (lo2 := fun d => cache.kLo (inputs.prev q) d - cache.kHi k d) + (hi2 := fun d => cache.kHi (inputs.prev q) d - cache.kLo k d) + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => + kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) + hlo1 hhi1 hlo2 hhi2 + have hlow' : + (dotDiffLoRefinedAt inputs cache budget q k : Real) ≤ + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simpa [dotDiffLoRefinedAt_def] using hspec.1 + have hhigh' : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) ≤ 
(dotDiffHiRefinedAt inputs cache budget q k : Real) := by + simpa [dotDiffHiRefinedAt_def] using hspec.2 + exact ⟨hlow', hhigh'⟩ + intro q hq k hk + by_cases hprevmask : masked q (inputs.prev q) + · have hscore_hi' : scoresReal q k ≤ (cache.scoreHi q k : Real) := + hscore_hi q k + have hscore_prev' : (cache.scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := + hscore_prev q hq + have hsum_le' : + (cache.scoreLoPrev q : Real) - (cache.scoreHi q k : Real) + scoresReal q k ≤ + (cache.scoreLoPrev q : Real) := by + have hsub : + (cache.scoreLoPrev q : Real) - (cache.scoreHi q k : Real) ≤ + (cache.scoreLoPrev q : Real) - scoresReal q k := + sub_le_sub_left hscore_hi' (cache.scoreLoPrev q : Real) + calc + (cache.scoreLoPrev q : Real) - (cache.scoreHi q k : Real) + scoresReal q k + ≤ (cache.scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsub (scoresReal q k)) + _ = (cache.scoreLoPrev q : Real) := by + simp [sub_add_cancel] + calc + scoresReal q k + (scoreGapLoRefinedAt inputs cache budget q k : Real) + = (cache.scoreLoPrev q : Real) - (cache.scoreHi q k : Real) + scoresReal q k := by + simp [scoreGapLoRefinedAt_def, hprevmask, masked, add_comm] + _ ≤ (cache.scoreLoPrev q : Real) := hsum_le' + _ ≤ scoresReal q (inputs.prev q) := hscore_prev' + · by_cases hmask : masked q k + · have hscore_prev' : (cache.scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := + hscore_prev q hq + have hscore_k : scoresReal q k = (inputs.maskValue : Real) := + scoresReal_eq_masked q k hmask + have hmask' : inputs.maskCausal = true ∧ q < k := by + simpa [masked] using hmask + have hnot_lt_prev : ¬ q < inputs.prev q := by + intro hlt + exact hprevmask ⟨hmask'.1, hlt⟩ + calc + scoresReal q k + (scoreGapLoRefinedAt inputs cache budget q k : Real) + = (inputs.maskValue : Real) + (cache.scoreLoPrev q : Real) - + (inputs.maskValue : Real) := by + simp [scoreGapLoRefinedAt_def, hmask', hnot_lt_prev, hscore_k] + _ 
= (cache.scoreLoPrev q : Real) := by + simp [add_sub_cancel_left] + _ ≤ scoresReal q (inputs.prev q) := hscore_prev' + · have hdiff := hdot_diff_bounds q k + have hgap_le : + (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + by_cases hscale : 0 ≤ inputs.scale + · have hscale_real : 0 ≤ (inputs.scale : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale + have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real + simpa [scoreGapLoRefinedAt_def, hprevmask, hmask, hscale, masked] using hle + · have hscale_nonpos : inputs.scale ≤ 0 := + le_of_lt (lt_of_not_ge hscale) + have hscale_real : (inputs.scale : Real) ≤ 0 := by + simpa [ratToReal_def] using + (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos + have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real + simpa [scoreGapLoRefinedAt_def, hprevmask, hmask, hscale, masked] using hle + have hscore_prev : + scoresReal q (inputs.prev q) = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) := by + simpa using + (scoresReal_eq_base_of_not_masked q (inputs.prev q) hprevmask) + have hscore_k : + scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simpa using (scoresReal_eq_base_of_not_masked q k hmask) + have hdot_sub : + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) = + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + classical + simpa using + (Nfp.Sound.Linear.dotProduct_sub_right + (x := fun d => qRealOfInputs inputs q d) + (y := fun d => 
kRealOfInputs inputs (inputs.prev q) d) + (z := fun d => kRealOfInputs inputs k d)) + have hscore_diff : + scoresReal q (inputs.prev q) - scoresReal q k = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + calc + scoresReal q (inputs.prev q) - scoresReal q k + = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d) := by + simp [hscore_prev, hscore_k] + _ = + (inputs.scale : Real) * + (dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d) - + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs k d)) := by + simp [mul_sub] + _ = + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) + (fun d => kRealOfInputs inputs (inputs.prev q) d - + kRealOfInputs inputs k d) := by + simp [hdot_sub] + have hgap_le' : + (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ + scoresReal q (inputs.prev q) - scoresReal q k := by + simpa [hscore_diff] using hgap_le + have hgap_add := add_le_add_right hgap_le' (scoresReal q k) + have hgap_add' : + scoresReal q k + (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ + scoresReal q (inputs.prev q) := by + have hcancel : + scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) = + scoresReal q (inputs.prev q) := by + calc + scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) + = + scoresReal q k + scoresReal q (inputs.prev q) - + scoresReal q k := by + symm + exact add_sub_assoc (scoresReal q k) + (scoresReal q (inputs.prev q)) (scoresReal q k) + _ = scoresReal q (inputs.prev q) := by + simp [add_sub_cancel_left] + calc + scoresReal q k + (scoreGapLoRefinedAt inputs cache budget q k : Real) + ≤ scoresReal q k + (scoresReal q 
(inputs.prev q) - scoresReal q k) := hgap_add + _ = scoresReal q (inputs.prev q) := hcancel + exact hgap_add' + +/-- Refined per-key weight bounds are sound when refined score gaps are sound. -/ +theorem weight_bound_at_refinedAt_of_scoreGapLo + [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) + (active : Finset (Fin seq)) + (hscore_gap_real_at : + ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → + scoresRealOfInputs inputs q k + + (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ + scoresRealOfInputs inputs q (inputs.prev q)) : + ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtRefinedAt inputs cache budget q k : Real) := by + classical + intro q hq k hk + refine + Sound.weight_bound_at_of_scoreGapLo + (active := active) + (prev := inputs.prev) + (scoresReal := scoresRealOfInputs inputs) + (scoreGapLo := scoreGapLoRefinedAt inputs cache budget) + (weightBoundAt := weightBoundAtRefinedAt inputs cache budget) + (hweightBoundAt := ?_) + (hscore_gap_real_at := hscore_gap_real_at) + q hq k hk + intro q' k' hk' + simp [weightBoundAtRefinedAt_def, hk'] + +/-- Overlayed per-key bounds are sound when base and refined bounds are sound. 
-/ +theorem weight_bounds_at_overlay_of_refined + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (spec : InductionHeadRefineSpec seq) + (active : Finset (Fin seq)) + (hbase : + ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (cache.weightBoundAt q k : Real)) + (hrefine : + ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → k ∈ spec.refineKeys q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtRefinedAt inputs cache spec.splitBudgetDiffRefined q k : Real)) : + ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real) := by + classical + intro q hq k hk + by_cases hmem : k ∈ spec.refineKeys q + · have h := hrefine q hq k hk hmem + simpa [weightBoundAtOverlay_def, hk, hmem] using h + · have h := hbase q hq k hk + simpa [weightBoundAtOverlay_def, hk, hmem] using h + +/-- One-hot bounds derived from an overlayed per-key bound. 
-/ +theorem oneHot_bounds_at_overlay + [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) + (hcert : c = cache.cert) + (spec : InductionHeadRefineSpec seq) + (hweight_overlay : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real)) : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) + (epsAtOverlay cache (weightBoundAtOverlay inputs cache spec) q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) := by + classical + intro q hq + refine + Sound.oneHot_bounds_at_of_weight_bounds + (active := c.active) + (prev := c.prev) + (scoresReal := scoresRealOfInputs inputs) + (weightBoundAt := weightBoundAtOverlay inputs cache spec) + (epsAt := epsAtOverlay cache (weightBoundAtOverlay inputs cache spec)) + (hepsAt := ?_) + (hweight_bounds := ?_) q hq + · intro q' + cases hcert + simp [epsAtOverlay_def] + · intro q' hq' k hk + exact hweight_overlay q' hq' k hk + +/-- The refined unweighted logit-diff lower bound is sound on active queries. 
-/ +theorem logitDiffLowerBoundRefinedFromCache_le + [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (logitCache : LogitDiffCache seq) + (spec : InductionHeadRefineSpec seq) + (hcert : c = cache.cert) + (hcache : logitCache = logitDiffCache c) + (hsound : InductionHeadCertSound inputs c) + (hweight_overlay : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real)) + {lb : Rat} + (hbound : logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + have honeHot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) + (epsAtOverlay cache (weightBoundAtOverlay inputs cache spec) q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) := by + intro q hq + exact oneHot_bounds_at_overlay (inputs := inputs) (cache := cache) (c := c) (hcert := hcert) + (spec := spec) (hweight_overlay := hweight_overlay) q hq + have hbound' : + logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) + (epsAtOverlay cache (weightBoundAtOverlay inputs cache spec)) = some lb := by + simpa [logitDiffLowerBoundRefinedFromCache_def, hcache] using hbound + exact + logitDiffLowerBoundFromCacheWithEps_le + (inputs := inputs) + (c := c) + (epsAtCustom := epsAtOverlay cache (weightBoundAtOverlay inputs cache spec)) + (hsound := hsound) + (honeHot := honeHot) + (hbound := hbound') + (hq := hq) + +/-- Refine-on-demand logit-diff lower bound using a supplied refinement spec is sound. 
-/ +theorem logitDiffLowerBoundRefineOnDemandWithSpec_le + [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (logitCache : LogitDiffCache seq) + (spec : InductionHeadRefineSpec seq) + (hcert : c = cache.cert) + (hcache : logitCache = logitDiffCache c) + (hsound : InductionHeadCertSound inputs c) + (hweight_overlay : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real)) + {lb : Rat} + (hbound : + logitDiffLowerBoundRefineOnDemandWithSpec inputs cache c logitCache spec = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + have honeHot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) + ((logitDiffCache c).epsAt q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) := by + intro q hq + have h := hsound.oneHot_bounds_at q hq + have heps : (logitDiffCache c).epsAt q = c.epsAt q := by + simp [logitDiffCache_def, Bounds.cacheBoundTask_apply] + simpa [heps] using h + have hbase_le : + ∀ {lb0 : Rat}, + logitDiffLowerBoundFromCache c logitCache = some lb0 → + (lb0 : Real) ≤ headLogitDiff inputs q := by + intro lb0 hbound0 + have hbound0' : + logitDiffLowerBoundFromCache c (logitDiffCache c) = some lb0 := by + simpa [hcache] using hbound0 + have hbound0'' : + logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) + (logitDiffCache c).epsAt = some lb0 := by + simpa [logitDiffLowerBoundFromCache_eq_withEps] using hbound0' + exact + logitDiffLowerBoundFromCacheWithEps_le + (inputs := inputs) + (c := c) + (epsAtCustom := (logitDiffCache c).epsAt) + (hsound := hsound) + (honeHot := honeHot) + (hbound := hbound0'') + (hq := hq) + cases h0 : logitDiffLowerBoundFromCache c logitCache with + | none => + simp [logitDiffLowerBoundRefineOnDemandWithSpec_def, h0] at 
hbound + | some lb0 => + by_cases hnonpos : lb0 ≤ 0 + · cases h1 : logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec with + | none => + have hlb : lb = lb0 := by + simpa [logitDiffLowerBoundRefineOnDemandWithSpec_def, h0, hnonpos, h1] using + hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + simpa [hlb] using hbase + | some lb1 => + have hlb : lb = max lb0 lb1 := by + simpa [logitDiffLowerBoundRefineOnDemandWithSpec_def, h0, hnonpos, h1] using + hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + have hrefine := + logitDiffLowerBoundRefinedFromCache_le + (inputs := inputs) + (cache := cache) + (c := c) + (logitCache := logitCache) + (spec := spec) + (hcert := hcert) + (hcache := hcache) + (hsound := hsound) + (hweight_overlay := hweight_overlay) + (hbound := h1) + (hq := hq) + have hmax' : + max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hbase, hrefine⟩ + have hmax : (max lb0 lb1 : Real) ≤ headLogitDiff inputs q := by + simpa [ratToReal_max] using hmax' + simpa [hlb] using hmax + · have hlb : lb = lb0 := by + simpa [logitDiffLowerBoundRefineOnDemandWithSpec_def, h0, hnonpos] using hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + simpa [hlb] using hbase + +/-- Refine-on-demand logit-diff lower bound using argmin refinement keys is sound. 
-/ +theorem logitDiffLowerBoundRefineOnDemand_le + [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (c : InductionHeadCert seq) (logitCache : LogitDiffCache seq) + (hcert : c = cache.cert) + (hcache : logitCache = logitDiffCache c) + (hsound : InductionHeadCertSound inputs c) + (hweight_overlay : + let refineBudget := max 1 cache.splitBudgetDiffRefined + ∀ q0 : Fin seq, + let spec := refineSpecForQueryWithWeightOnes inputs cache q0 refineBudget + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real)) + {lb : Rat} + (hbound : logitDiffLowerBoundRefineOnDemand inputs cache c logitCache = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + have hbase_le : + ∀ {lb0 : Rat}, + logitDiffLowerBoundFromCache c logitCache = some lb0 → + (lb0 : Real) ≤ headLogitDiff inputs q := by + intro lb0 hbound0 + have hbound0' : + logitDiffLowerBoundFromCache c (logitDiffCache c) = some lb0 := by + simpa [hcache] using hbound0 + have hbound0'' : + logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) + (logitDiffCache c).epsAt = some lb0 := by + simpa [logitDiffLowerBoundFromCache_eq_withEps] using hbound0' + exact + logitDiffLowerBoundFromCacheWithEps_le + (inputs := inputs) + (c := c) + (epsAtCustom := (logitDiffCache c).epsAt) + (hsound := hsound) + (honeHot := by + intro q' hq' + have h := hsound.oneHot_bounds_at q' hq' + have heps : (logitDiffCache c).epsAt q' = c.epsAt q' := by + simp [logitDiffCache_def, Bounds.cacheBoundTask_apply] + simpa [heps] using h) + (hbound := hbound0'') + (hq := hq) + cases h0 : logitDiffLowerBoundFromCache c logitCache with + | none => + simp [logitDiffLowerBoundRefineOnDemand_def, h0] at hbound + | some lb0 => + by_cases hnonpos : lb0 ≤ 0 + · cases hargmin : logitDiffLowerBoundArgminFromCache c logitCache with + | none => + have hlb : 
lb = lb0 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin] using + hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + simpa [hlb] using hbase + | some q0 => + let refineBudget := max 1 cache.splitBudgetDiffRefined + let spec := refineSpecForQueryWithWeightOnes inputs cache q0 refineBudget + cases h1 : + logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec with + | none => + have hlb : lb = lb0 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, + spec, refineBudget] using hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + simpa [hlb] using hbase + | some lb1 => + have hlb : lb = max lb0 lb1 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, + spec, refineBudget] using hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + have hweight_overlay' : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real) := by + simpa [spec, refineBudget] using hweight_overlay q0 + have hrefine := + logitDiffLowerBoundRefinedFromCache_le + (inputs := inputs) + (cache := cache) + (c := c) + (logitCache := logitCache) + (spec := spec) + (hcert := hcert) + (hcache := hcache) + (hsound := hsound) + (hweight_overlay := hweight_overlay') + (hbound := h1) + (hq := hq) + have hmax' : + max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hbase, hrefine⟩ + have hmax : (max lb0 lb1 : Real) ≤ headLogitDiff inputs q := by + simpa [ratToReal_max] using hmax' + simpa [hlb] using hmax + · have hlb : lb = lb0 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos] using hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + simpa [hlb] using hbase + +end Sound + +end Nfp diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean index 6f44ae1..7bf74f2 100644 --- a/Nfp/System/Dag.lean +++ b/Nfp/System/Dag.lean @@ -2,6 +2,7 @@ module +public meta import Nfp.Tactic.Linter 
public import Mathlib.Combinatorics.Digraph.Basic public import Mathlib.Data.Fintype.Defs public import Mathlib.Data.Finset.Basic diff --git a/Nfp/Tactic/Linter.lean b/Nfp/Tactic/Linter.lean new file mode 100644 index 0000000..610b3a3 --- /dev/null +++ b/Nfp/Tactic/Linter.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public meta import Nfp.Tactic.Linter.NoHeartbeats + +/-! +Aggregator for NFP linters. +-/ diff --git a/Nfp/Tactic/Linter/NoHeartbeats.lean b/Nfp/Tactic/Linter/NoHeartbeats.lean new file mode 100644 index 0000000..cd86d33 --- /dev/null +++ b/Nfp/Tactic/Linter/NoHeartbeats.lean @@ -0,0 +1,56 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public meta import Lean.Elab.Command +public meta import Lean.Linter.Basic + +/-! +Syntax linter forbidding heartbeat budget options. +-/ + +meta section + +namespace Nfp +namespace Linter + +open Lean Parser Elab Command Linter + +/-- Enable the no-heartbeats linter. -/ +public register_option linter.nfp.noHeartbeats : Bool := { + defValue := false + descr := "enable the noHeartbeats linter" +} + +namespace NoHeartbeats + +/-- Return the option name if syntax is a `set_option` command, term, or tactic. -/ +def parseSetOption : Syntax → Option Name + | `(command|set_option $name:ident $_val) => some name.getId + | `(set_option $name:ident $_val in $_x) => some name.getId + | `(tactic|set_option $name:ident $_val in $_x) => some name.getId + | _ => none + +/-- True if the option is a heartbeat budget. -/ +def isHeartbeatOption (name : Name) : Bool := + name == `maxHeartbeats || name == `synthInstance.maxHeartbeats + +/-- Linter that forbids heartbeat budget options in this repository. -/ +def noHeartbeatsLinter : Linter where + run stx := do + unless getLinterValue linter.nfp.noHeartbeats (← getLinterOptions) do + return + if (← MonadState.get).messages.hasErrors then + return + if let some head := stx.find? 
(fun stx => (parseSetOption stx).isSome) then + if let some name := parseSetOption head then + if isHeartbeatOption name then + logLint linter.nfp.noHeartbeats head + m!"Setting option '{name}' is forbidden; refactor the proof instead." + +initialize addLinter noHeartbeatsLinter + +end NoHeartbeats + +end Linter +end Nfp diff --git a/lakefile.toml b/lakefile.toml index d394613..29ba710 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -24,6 +24,7 @@ weak.linter.style.lambdaSyntax = true weak.linter.style.dollarSyntax = true weak.linter.style.cdot = true weak.linter.style.longLine = true +weak.linter.nfp.noHeartbeats = true [[require]] name = "mathlib" From 49fec42367d325f8eae076b37e2b351a8fccebf1 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 05:05:22 +0100 Subject: [PATCH 186/244] Boost refine-on-demand budgets for logit-diff --- Nfp/Sound/Induction/LogitDiff.lean | 22 +++- Nfp/Sound/Induction/Refine.lean | 9 ++ Nfp/Sound/Induction/RefineSound.lean | 148 +++++++++++++++++++++------ 3 files changed, 147 insertions(+), 32 deletions(-) diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 5ff77bf..f985181 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -335,8 +335,17 @@ def logitDiffLowerBoundRefineOnDemand let refineBudget := max 1 core.splitBudgetDiffRefined let spec := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with - | some lb1 => some (max lb0 lb1) | none => some lb0 + | some lb1 => + let lb01 := max lb0 lb1 + if lb01 ≤ 0 then + let refineBudget' := refineBudgetBoost refineBudget + let spec' := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget' + match logitDiffLowerBoundRefinedFromCache inputs core c cache spec' with + | some lb2 => some (max lb01 lb2) + | none => some lb01 + else + some lb01 else some lb0 @@ -356,8 +365,17 @@ theorem 
logitDiffLowerBoundRefineOnDemand_def let refineBudget := max 1 core.splitBudgetDiffRefined let spec := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with - | some lb1 => some (max lb0 lb1) | none => some lb0 + | some lb1 => + let lb01 := max lb0 lb1 + if lb01 ≤ 0 then + let refineBudget' := refineBudgetBoost refineBudget + let spec' := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget' + match logitDiffLowerBoundRefinedFromCache inputs core c cache spec' with + | some lb2 => some (max lb01 lb2) + | none => some lb01 + else + some lb01 else some lb0 := by rfl diff --git a/Nfp/Sound/Induction/Refine.lean b/Nfp/Sound/Induction/Refine.lean index ed57a35..a273327 100644 --- a/Nfp/Sound/Induction/Refine.lean +++ b/Nfp/Sound/Induction/Refine.lean @@ -26,6 +26,15 @@ structure InductionHeadRefineSpec (seq : Nat) where /-- Split budget for refined diff bounds. -/ splitBudgetDiffRefined : Nat +/-- Heuristic boost for refinement budgets. -/ +def refineBudgetBoost (budget : Nat) : Nat := + max (budget + 1) (2 * budget) + +/-- Unfolding lemma for `refineBudgetBoost`. -/ +theorem refineBudgetBoost_def (budget : Nat) : + refineBudgetBoost budget = max (budget + 1) (2 * budget) := by + rfl + /-- Worst key under the base score-gap lower bound (excluding `prev`). 
-/ def worstKeyBase (inputs : Model.InductionHeadInputs seq dModel dHead) diff --git a/Nfp/Sound/Induction/RefineSound.lean b/Nfp/Sound/Induction/RefineSound.lean index 7b25c4e..e3c90d1 100644 --- a/Nfp/Sound/Induction/RefineSound.lean +++ b/Nfp/Sound/Induction/RefineSound.lean @@ -497,8 +497,7 @@ theorem logitDiffLowerBoundRefineOnDemand_le (hcache : logitCache = logitDiffCache c) (hsound : InductionHeadCertSound inputs c) (hweight_overlay : - let refineBudget := max 1 cache.splitBudgetDiffRefined - ∀ q0 : Fin seq, + ∀ q0 : Fin seq, ∀ refineBudget : Nat, let spec := refineSpecForQueryWithWeightOnes inputs cache q0 refineBudget ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → Circuit.softmax (scoresRealOfInputs inputs q) k ≤ @@ -558,34 +557,123 @@ theorem logitDiffLowerBoundRefineOnDemand_le have hbase := hbase_le (lb0 := lb0) h0 simpa [hlb] using hbase | some lb1 => - have hlb : lb = max lb0 lb1 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, - spec, refineBudget] using hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - have hweight_overlay' : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real) := by - simpa [spec, refineBudget] using hweight_overlay q0 - have hrefine := - logitDiffLowerBoundRefinedFromCache_le - (inputs := inputs) - (cache := cache) - (c := c) - (logitCache := logitCache) - (spec := spec) - (hcert := hcert) - (hcache := hcache) - (hsound := hsound) - (hweight_overlay := hweight_overlay') - (hbound := h1) - (hq := hq) - have hmax' : - max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hbase, hrefine⟩ - have hmax : (max lb0 lb1 : Real) ≤ headLogitDiff inputs q := by - simpa [ratToReal_max] using hmax' - simpa [hlb] using hmax + let lb01 := max lb0 lb1 + by_cases hnonpos1 : lb01 ≤ 0 + · let refineBudget' := refineBudgetBoost refineBudget + let spec' := refineSpecForQueryWithWeightOnes inputs 
cache q0 refineBudget' + cases h2 : + logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec' with + | none => + have hlb : lb = lb01 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, + spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget'] using + hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + have hweight_overlay' : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real) := by + simpa [spec, refineBudget] using hweight_overlay q0 refineBudget + have hrefine := + logitDiffLowerBoundRefinedFromCache_le + (inputs := inputs) + (cache := cache) + (c := c) + (logitCache := logitCache) + (spec := spec) + (hcert := hcert) + (hcache := hcache) + (hsound := hsound) + (hweight_overlay := hweight_overlay') + (hbound := h1) + (hq := hq) + have hmax' : + max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hbase, hrefine⟩ + have hmax : (lb01 : Real) ≤ headLogitDiff inputs q := by + simpa [lb01, ratToReal_max] using hmax' + simpa [hlb] using hmax + | some lb2 => + have hlb : lb = max lb01 lb2 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, + spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget'] using + hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + have hweight_overlay' : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real) := by + simpa [spec, refineBudget] using hweight_overlay q0 refineBudget + have hweight_overlay'' : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec' q k : Real) := by + simpa [spec', refineBudget'] using hweight_overlay q0 refineBudget' + have hrefine := + logitDiffLowerBoundRefinedFromCache_le + (inputs := inputs) + (cache := cache) + (c := c) + 
(logitCache := logitCache) + (spec := spec) + (hcert := hcert) + (hcache := hcache) + (hsound := hsound) + (hweight_overlay := hweight_overlay') + (hbound := h1) + (hq := hq) + have hrefine' := + logitDiffLowerBoundRefinedFromCache_le + (inputs := inputs) + (cache := cache) + (c := c) + (logitCache := logitCache) + (spec := spec') + (hcert := hcert) + (hcache := hcache) + (hsound := hsound) + (hweight_overlay := hweight_overlay'') + (hbound := h2) + (hq := hq) + have hmax01' : + max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hbase, hrefine⟩ + have hmax01 : (lb01 : Real) ≤ headLogitDiff inputs q := by + simpa [lb01, ratToReal_max] using hmax01' + have hmax' : + max (lb01 : Real) (lb2 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hmax01, hrefine'⟩ + have hmax : (max lb01 lb2 : Real) ≤ headLogitDiff inputs q := by + simpa [ratToReal_max] using hmax' + simpa [hlb] using hmax + · have hlb : lb = lb01 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, + spec, refineBudget, lb01, hnonpos1] using hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + have hweight_overlay' : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real) := by + simpa [spec, refineBudget] using hweight_overlay q0 refineBudget + have hrefine := + logitDiffLowerBoundRefinedFromCache_le + (inputs := inputs) + (cache := cache) + (c := c) + (logitCache := logitCache) + (spec := spec) + (hcert := hcert) + (hcache := hcache) + (hsound := hsound) + (hweight_overlay := hweight_overlay') + (hbound := h1) + (hq := hq) + have hmax' : + max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hbase, hrefine⟩ + have hmax : (lb01 : Real) ≤ headLogitDiff inputs q := by + simpa [lb01, ratToReal_max] using hmax' + simpa [hlb] using hmax · have hlb : lb = lb0 := by simpa [logitDiffLowerBoundRefineOnDemand_def, h0, 
hnonpos] using hbound.symm have hbase := hbase_le (lb0 := lb0) h0 From 703cd76e11346938a67db17c9f9234046431178d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 13:33:03 +0100 Subject: [PATCH 187/244] Refactor logit-diff soundness and nonvacuous IO --- Nfp/IO/InductionHead.lean | 1 + Nfp/IO/InductionHead/Basic.lean | 490 +++++---------------- Nfp/IO/InductionHead/Nonvacuous.lean | 440 ++++++++++++++++++ Nfp/Sound/Bounds/LayerNorm/Basic.lean | 193 ++++++++ Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean | 190 ++++++++ Nfp/Sound/Induction/LogitDiff.lean | 335 ++++---------- Nfp/Sound/Induction/LogitDiffSound.lean | 467 ++++++++++++++++++++ Nfp/Sound/Induction/Refine.lean | 115 +++++ Nfp/Sound/Induction/RefineSound.lean | 289 ++++++++++-- 9 files changed, 1864 insertions(+), 656 deletions(-) create mode 100644 Nfp/IO/InductionHead/Nonvacuous.lean create mode 100644 Nfp/Sound/Induction/LogitDiffSound.lean diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index aa94960..f29e2d9 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -3,6 +3,7 @@ module public import Nfp.IO.InductionHead.Basic +public import Nfp.IO.InductionHead.Nonvacuous /-! IO helpers for induction-head certificate construction. diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 6614293..1a72f93 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -30,7 +30,8 @@ private def unwrapTaskResult {α : Type} (res : Except IO.Error α) : IO α := | .ok a => pure a | .error e => throw e -private def configureTiming (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO Unit := do +/-- Configure timing output and heartbeat reporting. -/ +def configureTiming (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO Unit := do match timing? with | some v => setTimingStdout (v ≠ 0) | none => pure () @@ -41,7 +42,8 @@ private def configureTiming (timing? : Option Nat) (heartbeatMs? 
: Option Nat) : setTimingStdout true | none => pure () -private def splitConfigFromOptions +/-- Translate CLI split-budget options into a split config. -/ +def splitConfigFromOptions (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : Sound.InductionHeadSplitConfig := let base := Sound.defaultInductionHeadSplitConfig @@ -76,18 +78,26 @@ def loadInductionHeadInputs (path : System.FilePath) : timingPrint s!"timing: parse head input file {t3 - t2} us" return parsed -private def ratToString (x : Rat) : String := +/-- Render a rational for logging. -/ +def ratToString (x : Rat) : String := toString x -private def ratOptToString (x : Option Rat) : String := +/-- Render an optional rational for logging. -/ +def ratOptToString (x : Option Rat) : String := match x with | some v => ratToString v | none => "none" -private def logitDiffDebugEnabled : IO Bool := do +/-- Check whether logit-diff debug logging is enabled. -/ +def logitDiffDebugEnabled : IO Bool := do return (← IO.getEnv "NFP_LOGITDIFF_DEBUG").isSome -private def logitDiffRefineEnabled : IO Bool := do +/-- Check whether logit-diff debug should exit early after dumping a witness. -/ +def logitDiffDebugEarlyExitEnabled : IO Bool := do + return (← IO.getEnv "NFP_LOGITDIFF_DEBUG_EARLY_EXIT").isSome + +/-- Check whether logit-diff refinement debug output is enabled. -/ +def logitDiffRefineEnabled : IO Bool := do return (← IO.getEnv "NFP_LOGITDIFF_REFINE").isSome private def renderResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) : String := @@ -786,157 +796,109 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} | some v => some v | none => some (0 : Rat) let logitCache := Nfp.Sound.logitDiffCache cert - logTiming "start: head logit-diff lower bound" - timingPrint "timing: head logit-diff lower bound start" - timingFlush - profileLogitDiffWeighted cert logitCache - let logitDiffLB0? 
← timePureWithHeartbeat - "head: logit-diff lower bound unweighted" (fun () => - Nfp.Sound.logitDiffLowerBoundRefineOnDemand inputs cache cert logitCache) - let logitDiffLB? ← - match logitDiffLB0? with - | none => pure none - | some lb0 => - match effectiveMinLogitDiff with - | some minLogitDiff => - if lb0 >= minLogitDiff then - timingPrint "timing: head logit-diff weighted skipped" - timingFlush - pure (some lb0) - else - let lb1? ← timePureWithHeartbeat - "head: logit-diff lower bound weighted" (fun () => - Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) - let lb := - match lb1? with - | some lb1 => max lb0 lb1 - | none => lb0 - pure (some lb) - | none => - let lb1? ← timePureWithHeartbeat - "head: logit-diff lower bound weighted" (fun () => - Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) - let lb := - match lb1? with - | some lb1 => max lb0 lb1 - | none => lb0 - pure (some lb) - logTiming "done: head logit-diff lower bound" - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - below minimum {ratToString minLogitDiff}" - return 2 - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB})" - return 0 - -private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cfg : Sound.InductionHeadSplitConfig) - (minActive? : Option Nat) (minLogitDiff? : Option Rat) - (minMargin? 
: Option Rat) (maxEps : Rat) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - logTiming "start: head build induction cert" - timingPrint "timing: head build induction cert start" - timingFlush - let tCert0 ← monoUsNow - let certTask : - Task - (Option { cache : Sound.InductionHeadCoreCache seq dModel dHead // - Sound.InductionHeadCertSound inputs cache.cert }) := - Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHeadWithCache? cfg inputs with - | none => none - | some ⟨cache, hcert⟩ => - let _ := cache.cert.active.card - some ⟨cache, hcert⟩) - let heartbeatMs ← heartbeatMs - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished certTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished certTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: head build induction cert running {now - tCert0} us" - timingFlush - let certOpt ← IO.wait certTask - let tCert1 ← monoUsNow - logTiming s!"done: head build induction cert {tCert1 - tCert0} us" - timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" - timingPrint "timing: head build induction cert returned" - timingFlush - match certOpt with - | none => - IO.eprintln "error: head inputs rejected" - return 2 - | some ⟨cache, _hcert⟩ => - let cert := cache.cert - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then + let emitLogitDiffDebug (info : Nfp.Sound.LogitDiffAtLoDebug seq) : IO Unit := do IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if maxEps < cert.eps then + s!"debug: logitDiffLB0 witness q={info.q.1}, prev={info.prev.1}" IO.eprintln - s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" - return 2 - let 
marginViolation? : Option Rat := - match minMargin? with - | none => none - | some minMargin => - if cert.margin < minMargin then - some minMargin + s!"debug: eps={ratToString info.eps}, \ + valsPrevLo={ratToString info.valsPrevLo}, \ + loAt={ratToString info.loAt}, \ + lo={ratToString info.lo}" + IO.eprintln + s!"debug: valsPrevLoMinusLoAt={ratToString info.valsPrevLoMinusLoAt}, \ + gap={ratToString info.gap}, \ + fAtQ={ratToString (info.valsPrevLo - info.gap)}, \ + lbAtQ={ratToString info.lbAtQ}" + let weightBoundAt := cert.weightBoundAt + let step : (Rat × Nat × Rat) → Fin seq → (Rat × Nat × Rat) := + fun acc k => + if k = info.prev then + acc + else + let w := weightBoundAt info.q k + let sum := acc.1 + w + let ones := if w = (1 : Rat) then acc.2.1 + 1 else acc.2.1 + let maxW := if w > acc.2.2 then w else acc.2.2 + (sum, ones, maxW) + let acc := Sound.Linear.foldlFin seq step (0, 0, 0) + IO.eprintln + s!"debug: epsAt={ratToString (cert.epsAt info.q)}, \ + weightSum={ratToString acc.1}, ones={acc.2.1}, \ + maxWeight={ratToString acc.2.2}" + let valsLo := logitCache.valsLo + let stepOnes : Array String → Fin seq → Array String := + fun acc k => + if k = info.prev then + acc else - none - match marginViolation? 
with - | some minMargin => + let w := weightBoundAt info.q k + if w = (1 : Rat) then + acc.push + s!"k={k.1} valsLo={ratToString (valsLo k)}" + else + acc + let ones := Sound.Linear.foldlFin seq stepOnes #[] + let onesMsg := + if ones.isEmpty then + "none" + else + String.intercalate ", " ones.toList + IO.eprintln s!"debug: weightBoundAt=1 keys: {onesMsg}" + let stepLoAt : Array String → Fin seq → Array String := + fun acc k => + if k = info.prev then + acc + else if valsLo k = info.loAt then + acc.push + s!"k={k.1} w={ratToString (weightBoundAt info.q k)}" + else + acc + let loAtKeys := Sound.Linear.foldlFin seq stepLoAt #[] + let loAtMsg := + if loAtKeys.isEmpty then + "none" + else + String.intercalate ", " loAtKeys.toList + IO.eprintln s!"debug: loAt keys: {loAtMsg}" + if (← logitDiffRefineEnabled) then + let refineBudget := max 1 cfg.splitBudgetDiffRefined + let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q IO.eprintln - s!"error: margin {ratToString cert.margin} \ - below minimum {ratToString minMargin}" - return 2 - | none => pure () + s!"debug: refine budget={refineBudget}, \ + refineKeys.card={refineKeys.card}" + let refineSpec := + Sound.refineSpecForQueryWithWeightOnes inputs cache info.q refineBudget + let refinedLB? := + Sound.logitDiffLowerBoundRefinedFromCache + inputs cache cert logitCache refineSpec + match refinedLB? with + | none => + IO.eprintln "debug: refined logitDiffLB0 none" + | some lb => + IO.eprintln + s!"debug: refined logitDiffLB0={ratToString lb}" logTiming "start: head logit-diff lower bound" timingPrint "timing: head logit-diff lower bound start" timingFlush - let logitCache := Nfp.Sound.logitDiffCache cert - let profiling ← logitDiffProfileEnabled - if profiling then - profileLogitDiffWeighted cert logitCache - else - pure () - let weightedTask? : Option (Task (Option Rat)) := - if profiling then - none + profileLogitDiffWeighted cert logitCache + let earlyExit? 
← + if (← logitDiffDebugEnabled) && (← logitDiffDebugEarlyExitEnabled) then + let debug? ← timePureWithHeartbeat + "head: logit-diff lower bound debug" (fun () => + Nfp.Sound.logitDiffLowerBoundAtLoDebug cert logitCache) + match debug? with + | none => + IO.eprintln "debug: logitDiffLB0 witness not found" + | some ⟨info, _⟩ => + emitLogitDiffDebug info + IO.eprintln "debug: early exit requested" + pure (some ()) else - some (Task.spawn (fun _ => - Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache)) + pure none + match earlyExit? with + | some _ => return 2 + | none => pure () + let weightedTask? : Option (Task (Option Rat)) := none let logitDiffLB0? ← timePureWithHeartbeat "head: logit-diff lower bound unweighted" (fun () => Nfp.Sound.logitDiffLowerBoundRefineOnDemand inputs cache cert logitCache) @@ -948,86 +910,7 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} | none => IO.eprintln "debug: logitDiffLB0 witness not found" | some ⟨info, _⟩ => - IO.eprintln - s!"debug: logitDiffLB0 witness q={info.q.1}, prev={info.prev.1}" - IO.eprintln - s!"debug: eps={ratToString info.eps}, \ - valsPrevLo={ratToString info.valsPrevLo}, \ - loAt={ratToString info.loAt}, \ - lo={ratToString info.lo}" - IO.eprintln - s!"debug: valsPrevLoMinusLoAt={ratToString info.valsPrevLoMinusLoAt}, \ - gap={ratToString info.gap}, \ - fAtQ={ratToString (info.valsPrevLo - info.gap)}, \ - lbAtQ={ratToString info.lbAtQ}" - let weightBoundAt := cert.weightBoundAt - let step : (Rat × Nat × Rat) → Fin seq → (Rat × Nat × Rat) := - fun acc k => - if k = info.prev then - acc - else - let w := weightBoundAt info.q k - let sum := acc.1 + w - let ones := if w = (1 : Rat) then acc.2.1 + 1 else acc.2.1 - let maxW := if w > acc.2.2 then w else acc.2.2 - (sum, ones, maxW) - let acc := Sound.Linear.foldlFin seq step (0, 0, 0) - IO.eprintln - s!"debug: epsAt={ratToString (cert.epsAt info.q)}, \ - weightSum={ratToString acc.1}, ones={acc.2.1}, \ - maxWeight={ratToString 
acc.2.2}" - let valsLo := logitCache.valsLo - let stepOnes : Array String → Fin seq → Array String := - fun acc k => - if k = info.prev then - acc - else - let w := weightBoundAt info.q k - if w = (1 : Rat) then - acc.push - s!"k={k.1} valsLo={ratToString (valsLo k)}" - else - acc - let ones := Sound.Linear.foldlFin seq stepOnes #[] - let onesMsg := - if ones.isEmpty then - "none" - else - String.intercalate ", " ones.toList - IO.eprintln s!"debug: weightBoundAt=1 keys: {onesMsg}" - let stepLoAt : Array String → Fin seq → Array String := - fun acc k => - if k = info.prev then - acc - else if valsLo k = info.loAt then - acc.push - s!"k={k.1} w={ratToString (weightBoundAt info.q k)}" - else - acc - let loAtKeys := Sound.Linear.foldlFin seq stepLoAt #[] - let loAtMsg := - if loAtKeys.isEmpty then - "none" - else - String.intercalate ", " loAtKeys.toList - IO.eprintln s!"debug: loAt keys: {loAtMsg}" - if (← logitDiffRefineEnabled) then - let refineBudget := max 1 cfg.splitBudgetDiffRefined - let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q - IO.eprintln - s!"debug: refine budget={refineBudget}, \ - refineKeys.card={refineKeys.card}" - let refineSpec := - Sound.refineSpecForQueryWithWeightOnes inputs cache info.q refineBudget - let refinedLB? := - Sound.logitDiffLowerBoundRefinedFromCache - inputs cache cert logitCache refineSpec - match refinedLB? with - | none => - IO.eprintln "debug: refined logitDiffLB0 none" - | some lb => - IO.eprintln - s!"debug: refined logitDiffLB0={ratToString lb}" + emitLogitDiffDebug info | none => pure () let needsWeighted : Bool := match logitDiffLB0? with @@ -1103,7 +986,6 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB}, \ bound={boundLabel})" return 0 - /-- Build and check induction certificates from exact head inputs. -/ def runInductionCertifyHead (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? 
: Option String) @@ -1139,41 +1021,6 @@ def runInductionCertifyHead (inputsPath : System.FilePath) | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps -/-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ -def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedInputs ← timePhase "load head inputs" <| - loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - /-- Build and check induction certificates from a model binary. -/ def runInductionCertifyHeadModel (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? 
: Option Nat) @@ -1222,56 +1069,8 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) | Except.ok inputs => checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps -/-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ -def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) 
- match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - /-- Heuristic logit-diff direction derived from prompt tokens. -/ -private def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : +def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : Except String (Nat × Nat) := do let tokenArr : Array Nat := Array.ofFn (fun i : Fin seq => tokens i) let n := tokenArr.size @@ -1362,69 +1161,6 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps -/-- Build and check a strictly positive induction logit-diff bound from a model binary, deriving -direction tokens from the prompt sequence. -/ -def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) - (layer head : Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let tokensE ← timePure "read prompt tokens" (fun () => - NfptPure.readTokens data start header) - match tokensE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok tokens => - match deriveDirectionFromTokens tokens with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨dirTarget, dirNegative⟩ => - IO.println - s!"info: direction-target={dirTarget} direction-negative={dirNegative}" - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - /-- Build head-output interval bounds from exact head inputs. -/ def runInductionHeadInterval (inputsPath : System.FilePath) (outPath? : Option System.FilePath) : IO UInt32 := do diff --git a/Nfp/IO/InductionHead/Nonvacuous.lean b/Nfp/IO/InductionHead/Nonvacuous.lean new file mode 100644 index 0000000..73a8bc6 --- /dev/null +++ b/Nfp/IO/InductionHead/Nonvacuous.lean @@ -0,0 +1,440 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +import Nfp.IO.InductionHead.Basic + +/-! +IO helpers for nonvacuous induction-head certificate checks. +-/ + +public section + +namespace Nfp + +namespace IO + +/-- Build and check induction certificates from exact head inputs. 
-/ +private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cfg : Sound.InductionHeadSplitConfig) + (minActive? : Option Nat) (minLogitDiff? : Option Rat) + (minMargin? : Option Rat) (maxEps : Rat) : IO UInt32 := do + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + logTiming "start: head build induction cert" + timingPrint "timing: head build induction cert start" + timingFlush + let tCert0 ← monoUsNow + let certTask : + Task + (Option { cache : Sound.InductionHeadCoreCache seq dModel dHead // + Sound.InductionHeadCertSound inputs cache.cert }) := + Task.spawn (prio := Task.Priority.dedicated) (fun _ => + match Sound.buildInductionCertFromHeadWithCache? cfg inputs with + | none => none + | some ⟨cache, hcert⟩ => + let _ := cache.cert.active.card + some ⟨cache, hcert⟩) + let heartbeatMs ← heartbeatMs + if heartbeatMs ≠ 0 then + let mut finished := (← IO.hasFinished certTask) + while !finished do + IO.sleep heartbeatMs + finished := (← IO.hasFinished certTask) + if !finished then + let now ← monoUsNow + timingPrint s!"timing: head build induction cert running {now - tCert0} us" + timingFlush + let certOpt ← IO.wait certTask + let tCert1 ← monoUsNow + logTiming s!"done: head build induction cert {tCert1 - tCert0} us" + timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" + timingPrint "timing: head build induction cert returned" + timingFlush + match certOpt with + | none => + IO.eprintln "error: head inputs rejected" + return 2 + | some ⟨cache, _hcert⟩ => + let cert := cache.cert + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if maxEps < cert.eps then + 
IO.eprintln + s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" + return 2 + let marginViolation? : Option Rat := + match minMargin? with + | none => none + | some minMargin => + if cert.margin < minMargin then + some minMargin + else + none + match marginViolation? with + | some minMargin => + IO.eprintln + s!"error: margin {ratToString cert.margin} \ + below minimum {ratToString minMargin}" + return 2 + | none => pure () + logTiming "start: head logit-diff lower bound" + timingPrint "timing: head logit-diff lower bound start" + timingFlush + let logitCache := Nfp.Sound.logitDiffCache cert + let emitLogitDiffDebug (info : Nfp.Sound.LogitDiffAtLoDebug seq) : IO Unit := do + IO.eprintln + s!"debug: logitDiffLB0 witness q={info.q.1}, prev={info.prev.1}" + IO.eprintln + s!"debug: eps={ratToString info.eps}, \ + valsPrevLo={ratToString info.valsPrevLo}, \ + loAt={ratToString info.loAt}, \ + lo={ratToString info.lo}" + IO.eprintln + s!"debug: valsPrevLoMinusLoAt={ratToString info.valsPrevLoMinusLoAt}, \ + gap={ratToString info.gap}, \ + fAtQ={ratToString (info.valsPrevLo - info.gap)}, \ + lbAtQ={ratToString info.lbAtQ}" + let weightBoundAt := cert.weightBoundAt + let step : (Rat × Nat × Rat) → Fin seq → (Rat × Nat × Rat) := + fun acc k => + if k = info.prev then + acc + else + let w := weightBoundAt info.q k + let sum := acc.1 + w + let ones := if w = (1 : Rat) then acc.2.1 + 1 else acc.2.1 + let maxW := if w > acc.2.2 then w else acc.2.2 + (sum, ones, maxW) + let acc := Sound.Linear.foldlFin seq step (0, 0, 0) + IO.eprintln + s!"debug: epsAt={ratToString (cert.epsAt info.q)}, \ + weightSum={ratToString acc.1}, ones={acc.2.1}, \ + maxWeight={ratToString acc.2.2}" + let valsLo := logitCache.valsLo + let stepOnes : Array String → Fin seq → Array String := + fun acc k => + if k = info.prev then + acc + else + let w := weightBoundAt info.q k + if w = (1 : Rat) then + acc.push + s!"k={k.1} valsLo={ratToString (valsLo k)}" + else + acc + let ones := 
Sound.Linear.foldlFin seq stepOnes #[] + let onesMsg := + if ones.isEmpty then + "none" + else + String.intercalate ", " ones.toList + IO.eprintln s!"debug: weightBoundAt=1 keys: {onesMsg}" + let stepLoAt : Array String → Fin seq → Array String := + fun acc k => + if k = info.prev then + acc + else if valsLo k = info.loAt then + acc.push + s!"k={k.1} w={ratToString (weightBoundAt info.q k)}" + else + acc + let loAtKeys := Sound.Linear.foldlFin seq stepLoAt #[] + let loAtMsg := + if loAtKeys.isEmpty then + "none" + else + String.intercalate ", " loAtKeys.toList + IO.eprintln s!"debug: loAt keys: {loAtMsg}" + if (← logitDiffRefineEnabled) then + let refineBudget := max 1 cfg.splitBudgetDiffRefined + let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q + IO.eprintln + s!"debug: refine budget={refineBudget}, \ + refineKeys.card={refineKeys.card}" + let refineSpec := + Sound.refineSpecForQueryWithWeightOnes inputs cache info.q refineBudget + let refinedLB? := + Sound.logitDiffLowerBoundRefinedFromCache + inputs cache cert logitCache refineSpec + match refinedLB? with + | none => + IO.eprintln "debug: refined logitDiffLB0 none" + | some lb => + IO.eprintln + s!"debug: refined logitDiffLB0={ratToString lb}" + let profiling ← logitDiffProfileEnabled + if profiling then + profileLogitDiffWeighted cert logitCache + else + pure () + let earlyExit? ← + if (← logitDiffDebugEnabled) && (← logitDiffDebugEarlyExitEnabled) then + let debug? ← timePureWithHeartbeat + "head: logit-diff lower bound debug" (fun () => + Nfp.Sound.logitDiffLowerBoundAtLoDebug cert logitCache) + match debug? with + | none => + IO.eprintln "debug: logitDiffLB0 witness not found" + | some ⟨info, _⟩ => + emitLogitDiffDebug info + IO.eprintln "debug: early exit requested" + pure (some ()) + else + pure none + match earlyExit? with + | some _ => return 2 + | none => pure () + let weightedTask? 
: Option (Task (Option Rat)) := + if profiling then + none + else + some (Task.spawn (fun _ => + Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache)) + let logitDiffLB0? ← timePureWithHeartbeat + "head: logit-diff lower bound unweighted" (fun () => + Nfp.Sound.logitDiffLowerBoundRefineOnDemand inputs cache cert logitCache) + if (← logitDiffDebugEnabled) then + match logitDiffLB0? with + | some lb0 => + if lb0 ≤ 0 then + match Nfp.Sound.logitDiffLowerBoundAtLoDebug cert logitCache with + | none => + IO.eprintln "debug: logitDiffLB0 witness not found" + | some ⟨info, _⟩ => emitLogitDiffDebug info + | none => pure () + let needsWeighted : Bool := + match logitDiffLB0? with + | none => true + | some lb0 => + if lb0 ≤ 0 then + true + else + match minLogitDiff? with + | some minLogitDiff => lb0 < minLogitDiff + | none => false + let logitDiffWeighted? ← + if needsWeighted then + match weightedTask? with + | some task => + timePureWithHeartbeat + "head: logit-diff lower bound weighted" (fun () => + task.get) + | none => + timePureWithHeartbeat + "head: logit-diff lower bound weighted" (fun () => + Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) + else + pure none + let logitDiffLB? : Option Rat := + match logitDiffLB0?, logitDiffWeighted? with + | some lb0, some lb1 => some (max lb0 lb1) + | some lb0, none => some lb0 + | none, some lb1 => some lb1 + | none, none => none + let boundLabel : String := + match logitDiffLB0?, logitDiffWeighted? with + | some _, some _ => "max" + | none, some _ => "weighted" + | some _, none => "eps" + | none, none => "none" + logTiming "done: head logit-diff lower bound" + match logitDiffLB? 
with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + if logitDiffLB ≤ 0 then + if (← logitDiffDebugEnabled) then + IO.eprintln + s!"debug: logitDiffLB0={ratOptToString logitDiffLB0?}, \ + logitDiffWeighted={ratOptToString logitDiffWeighted?}, \ + logitDiffLB={ratToString logitDiffLB}, \ + bound={boundLabel}" + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + is not strictly positive" + return 2 + let violation? : Option Rat := + match minLogitDiff? with + | none => none + | some minLogitDiff => + if logitDiffLB < minLogitDiff then + some minLogitDiff + else + none + match violation? with + | some minLogitDiff => + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" + return 2 + | none => pure () + let tol := cert.eps * (cert.values.hi - cert.values.lo) + IO.println + s!"ok: nonvacuous induction bound certified \ + (seq={seq}, active={activeCount}, \ + tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB}, \ + bound={boundLabel})" + return 0 + +/-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ +def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsedInputs ← timePhase "load head inputs" <| + loadInductionHeadInputs inputsPath + match parsedInputs with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? + minMargin? maxEps + +/-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ +def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + logTiming "start: read model file" + timingPrint "timing: read model file start" + timingFlush + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? + minMargin? maxEps + +/-- Build and check a strictly positive induction logit-diff bound from a model binary, deriving +direction tokens from the prompt sequence. -/ +def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) + (layer head : Nat) (period? : Option Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + IO UInt32 := do + configureTiming timing? heartbeatMs? + let splitCfg := + splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
+ match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + logTiming "start: read model file" + timingPrint "timing: read model file start" + timingFlush + let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath + let headerE ← timePure "parse model header" (fun () => + NfptPure.parseHeader data) + match headerE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨header, start⟩ => + let tokensE ← timePure "read prompt tokens" (fun () => + NfptPure.readTokens data start header) + match tokensE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok tokens => + match deriveDirectionFromTokens tokens with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨dirTarget, dirNegative⟩ => + IO.println + s!"info: direction-target={dirTarget} direction-negative={dirNegative}" + let inputsE ← timePure "read head inputs" (fun () => + NfptPure.readInductionHeadInputs + data start header layer head dirTarget dirNegative period?) + match inputsE with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok inputs => + checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? + minMargin? maxEps + +end IO + +end Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm/Basic.lean b/Nfp/Sound/Bounds/LayerNorm/Basic.lean index 0e731c0..998402a 100644 --- a/Nfp/Sound/Bounds/LayerNorm/Basic.lean +++ b/Nfp/Sound/Bounds/LayerNorm/Basic.lean @@ -136,6 +136,35 @@ def layerNormBounds {n : Nat} beta i + coeff i * invStdLower (lo, hi) +/-- Interval bounds for LayerNorm outputs with a custom sqrt scale. 
-/ +def layerNormBoundsWithScale {n : Nat} (scale : Nat) + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if n = 0 then + (fun _ => 0, fun _ => 0) + else + let μ : Rat := mean x + let centered : Fin n → Rat := fun i => x i - μ + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := + max (sqrtLowerWithScale scale eps) (sqrtLowerWithScale scale varEps) + let sqrtUpperBound : Rat := sqrtUpperWithScale scale varEps + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let coeff : Fin n → Rat := fun i => gamma i * centered i + let lo : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdLower + else + beta i + coeff i * invStdUpper + let hi : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdUpper + else + beta i + coeff i * invStdLower + (lo, hi) + /-- `layerNormBounds` soundness for real LayerNorm outputs. -/ theorem layerNormBounds_spec {n : Nat} (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) @@ -281,6 +310,170 @@ theorem layerNormBounds_spec {n : Nat} using hhigh_raw exact And.intro hlo hhi +/-- `layerNormBoundsWithScale` soundness for real LayerNorm outputs. 
-/ +theorem layerNormBoundsWithScale_spec {n : Nat} {scale : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLowerWithScale scale eps) + (hscale : 0 < scale) : + let bounds := layerNormBoundsWithScale scale eps gamma beta x + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let μRat : Rat := mean x + let varRat : Rat := variance x + let varEpsRat : Rat := varRat + eps + let sqrtLowerBound : Rat := + max (sqrtLowerWithScale scale eps) (sqrtLowerWithScale scale varEpsRat) + let sqrtUpperBound : Rat := sqrtUpperWithScale scale varEpsRat + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let centered : Rat := x i - μRat + let coeff : Rat := gamma i * centered + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hmu : (μRat : Real) = μ := by + simp [μRat, μ, mean_def, hne, ratRoundDown_def] + have hvar : (varRat : Real) = (varianceRat x : Real) := by + simp [varRat, variance_def, hne, ratRoundDown_def] + have hvarEps : (varEpsRat : Real) = varEps := by + simp [varEpsRat, varEps, hvar] + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hvar_nonneg_real : 0 ≤ ratToReal (varianceRat x) := by + simpa [ratToReal_def] using hvar_nonneg + have hvar_nonneg_rat : 0 ≤ varianceRat x := by + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg_real + have hvarRat_nonneg : 0 ≤ varRat := by + have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat + simpa [varRat, variance_def x hne] using h + have hvarEps_nonneg : 0 ≤ varEpsRat := by + exact add_nonneg hvarRat_nonneg (le_of_lt heps) + have hsqrt_lower : + (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : (sqrtLowerWithScale scale eps : Real) 
≤ Real.sqrt varEps := by + have hsqrt_eps' : + (sqrtLowerWithScale scale eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := + sqrtLowerWithScale_le_real_sqrt (q := eps) (scale := scale) + (by exact le_of_lt heps) hscale + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) + have hsqrt_var : + (sqrtLowerWithScale scale varEpsRat : Real) ≤ Real.sqrt varEps := by + have hsqrt_var' : + (sqrtLowerWithScale scale varEpsRat : Real) ≤ + Real.sqrt (varEpsRat : Real) := by + have h := + sqrtLowerWithScale_le_real_sqrt (q := varEpsRat) (scale := scale) + hvarEps_nonneg hscale + simpa using h + have hle : (varEpsRat : Real) ≤ varEps := by + simp [hvarEps] + exact le_trans hsqrt_var' (Real.sqrt_le_sqrt hle) + have hmax : + max (sqrtLowerWithScale scale eps : Real) + (sqrtLowerWithScale scale varEpsRat : Real) ≤ Real.sqrt varEps := + (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ + simpa [sqrtLowerBound, ratToReal_max] using hmax + have hsqrt_upper : + Real.sqrt varEps ≤ (sqrtUpperBound : Real) := by + have h := + real_sqrt_le_sqrtUpperWithScale (q := varEpsRat) (scale := scale) + hvarEps_nonneg hscale + simpa [sqrtUpperBound, hvarEps] using h + have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by + have hpos : 0 < sqrtLowerWithScale scale eps := hsqrt + have hpos' : 0 < max (sqrtLowerWithScale scale eps) (sqrtLowerWithScale scale varEpsRat) := + lt_of_lt_of_le hpos (le_max_left _ _) + simpa [sqrtLowerBound] using hpos' + have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLowerBound)).2 hsqrt_lower_pos_rat + have hsqrt_upper_pos_rat : 0 < sqrtUpperBound := by + have hpos_real : 0 < (sqrtUpperWithScale scale varEpsRat : Real) := by + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := 
add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + exact lt_of_lt_of_le (Real.sqrt_pos.2 hvarEps_pos) hsqrt_upper + exact (Rat.cast_pos (K := Real) (q := sqrtUpperWithScale scale varEpsRat)).1 hpos_real + have hsqrt_upper_pos : 0 < (sqrtUpperBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtUpperBound)).2 hsqrt_upper_pos_rat + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + have hsqrt_pos : 0 < Real.sqrt varEps := Real.sqrt_pos.2 hvarEps_pos + have hinv_lower_real : + (sqrtUpperBound : Real)⁻¹ ≤ invStd := by + have hle := inv_anti₀ hsqrt_pos hsqrt_upper + simpa [invStd] using hle + have hinv_upper_real : + invStd ≤ (sqrtLowerBound : Real)⁻¹ := by + have hle := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using hle + have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat + have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat + have hinv_lower : (invStdLower : Real) ≤ invStd := by + simpa [invStdLower, ratDivDown_def, hupper_ne, one_div] using hinv_lower_real + have hinv_upper : invStd ≤ (invStdUpper : Real) := by + simpa [invStdUpper, ratDivUp_def, hlower_ne, one_div] using hinv_upper_real + have hlayer : + layerNormReal eps gamma beta x i = + (beta i : Real) + (coeff : Real) * invStd := by + simp [layerNormReal, hne, coeff, centered, μ, hmu, invStd, varEps, add_comm, mul_assoc] + by_cases hcoeff : 0 ≤ coeff + · have hcoeff_real : 0 ≤ (coeff : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hcoeff + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdLower : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonneg_left hinv_lower hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : 
Real) + (coeff : Real) * (invStdUpper : Real) := by + have hmul := mul_le_mul_of_nonneg_left hinv_upper hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBoundsWithScale, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBoundsWithScale, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + · have hcoeff_lt : coeff < 0 := lt_of_not_ge hcoeff + have hcoeff_real : (coeff : Real) ≤ 0 := by + exact_mod_cast (le_of_lt hcoeff_lt) + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdUpper : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonpos_left hinv_upper hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : Real) + (coeff : Real) * (invStdLower : Real) := by + have hmul := mul_le_mul_of_nonpos_left hinv_lower hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBoundsWithScale, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBoundsWithScale, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + /-! Local bounds for monotone multiplication in real-valued bounds. 
-/ diff --git a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean index aa0cdf2..40362ea 100644 --- a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean +++ b/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean @@ -80,6 +80,10 @@ def sqrtUpperAlt (q : Rat) : Rat := /-- Extra precision scale for `sqrtLowerScaled`. -/ def sqrtLowerScale : Nat := 1048576 +/-- Unfolding lemma for `sqrtLowerScale`. -/ +theorem sqrtLowerScale_def : sqrtLowerScale = 1048576 := by + rfl + /-- Scaled rational lower bound for a square root (extra precision). -/ def sqrtLowerScaled (q : Rat) : Rat := let num := q.num.natAbs @@ -96,6 +100,20 @@ def sqrtUpperScaled (q : Rat) : Rat := let a := Nat.sqrt (num * den * scale * scale) ratRoundUp ((a + 1 : Rat) / (den * scale)) +/-- Scaled rational lower bound for a square root with a custom scale. -/ +def sqrtLowerScaledWith (scale : Nat) (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den * scale * scale) + ratRoundDown ((a : Rat) / (den * scale)) + +/-- Scaled rational upper bound for a square root with a custom scale. -/ +def sqrtUpperScaledWith (scale : Nat) (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den * scale * scale) + ratRoundUp ((a + 1 : Rat) / (den * scale)) + /-- Rational lower bound for a square root (tighter of three bounds). -/ def sqrtLower (q : Rat) : Rat := max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) @@ -104,6 +122,14 @@ def sqrtLower (q : Rat) : Rat := def sqrtUpper (q : Rat) : Rat := min (min (sqrtUpperBase q) (sqrtUpperAlt q)) (sqrtUpperScaled q) +/-- Rational lower bound for a square root with a custom scale. -/ +def sqrtLowerWithScale (scale : Nat) (q : Rat) : Rat := + max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaledWith scale q) + +/-- Rational upper bound for a square root with a custom scale. 
-/ +def sqrtUpperWithScale (scale : Nat) (q : Rat) : Rat := + min (min (sqrtUpperBase q) (sqrtUpperAlt q)) (sqrtUpperScaledWith scale q) + /-- `sqrtLowerBase` is nonnegative. -/ theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by classical @@ -490,6 +516,77 @@ theorem sqrtLowerScaled_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : simpa [sqrtLowerScaled, num, den, scale, a] using hdown' exact le_trans hdown hle +/-- Scaled square-root lower bound in reals with a custom scale. -/ +theorem sqrtLowerScaledWith_le_real_sqrt {q : Rat} {scale : Nat} + (hq : 0 ≤ q) (hscale : 0 < scale) : + (sqrtLowerScaledWith scale q : Real) ≤ Real.sqrt (q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale + have hnumden_le : (a ^ 2 : Real) ≤ (num * den * scale * scale : Nat) := by + exact_mod_cast (Nat.sqrt_le' (num * den * scale * scale)) + have hmul : + (a ^ 2 : Real) ≤ (num : Real) * den * (scale : Real) * (scale : Real) := by + simpa [num, den, Nat.cast_mul, mul_assoc, mul_left_comm, mul_comm] using hnumden_le + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := + mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hmul' : + (a ^ 2 : Real) * ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) * + ((den : Real) * (scale : Real)) ^ 2 := by + have hnonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul_of_nonneg_right hmul hnonneg + have hdiv : + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := by + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hq_cast : + (q : 
Real) = + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 := by + simpa [num, den] using + (rat_cast_eq_num_den_scale (q := q) hq (scale := scale) hscale) + have hpow : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 ≤ (q : Real) := by + calc + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 + = (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := hpow + _ ≤ ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := hdiv + _ = (q : Real) := by simp [hq_cast] + have hnonneg : 0 ≤ (a : Real) / ((den : Real) * (scale : Real)) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hq + have hle : + (a : Real) / ((den : Real) * (scale : Real)) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerScaledWith scale q : Real) ≤ + (a : Real) / ((den : Real) * (scale : Real)) := by + have hdown' : + (ratRoundDown ((a : Rat) / (den * scale)) : Real) ≤ + (a : Real) / ((den : Real) * (scale : Real)) := by + simpa using ratRoundDown_le_real ((a : Rat) / (den * scale)) + simpa [sqrtLowerScaledWith, num, den, a] using hdown' + exact le_trans hdown hle + /-- Alternate square-root upper bound in reals. 
-/ theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : Real.sqrt (q : Real) ≤ (sqrtUpperAlt q : Real) := by @@ -604,6 +701,99 @@ theorem real_sqrt_le_sqrtUpperScaled {q : Rat} (hq : 0 ≤ q) : simpa [sqrtUpperScaled, num, den, scale, a] using hup' exact le_trans hle hup +/-- Scaled square-root upper bound in reals with a custom scale. -/ +theorem real_sqrt_le_sqrtUpperScaledWith {q : Rat} {scale : Nat} + (hq : 0 ≤ q) (hscale : 0 < scale) : + Real.sqrt (q : Real) ≤ (sqrtUpperScaledWith scale q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale + have hnumden_lt : (num * den * scale * scale : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' (num * den * scale * scale)) + have hmul : + (num : Real) * den * (scale : Real) * (scale : Real) ≤ (a + 1 : Real) ^ 2 := by + exact le_of_lt hnumden_lt + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := by + exact mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hmul' : + (num : Real) * den * (scale : Real) * (scale : Real) * + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 * ((den : Real) * (scale : Real)) ^ 2 := by + have hden_sq_nonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + have hdiv : + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hq_cast : + (q : Real) = + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 := by + simpa [num, den] using + (rat_cast_eq_num_den_scale (q := q) 
hq (scale := scale) hscale) + have hpow : + ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : (q : Real) ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 := by + simpa [hq_cast, hpow] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : + Real.sqrt (q : Real) ≤ (a + 1 : Real) / ((den : Real) * (scale : Real)) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ + (sqrtUpperScaledWith scale q : Real) := by + have hup' : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ + (ratRoundUp ((a + 1 : Rat) / (den * scale)) : Real) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / (den * scale)) + simpa [sqrtUpperScaledWith, num, den, a] using hup' + exact le_trans hle hup + +theorem sqrtLowerWithScale_le_real_sqrt {q : Rat} {scale : Nat} + (hq : 0 ≤ q) (hscale : 0 < scale) : + (sqrtLowerWithScale scale q : Real) ≤ Real.sqrt (q : Real) := by + have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq + have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq + have hscaled := sqrtLowerScaledWith_le_real_sqrt (q := q) (scale := scale) hq hscale + have hmax1 : + (max (sqrtLowerBase q) (sqrtLowerAlt q) : Real) ≤ Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hbase, halt⟩ + have hmax2 : + (max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaledWith scale q) : Real) ≤ + Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hmax1, hscaled⟩ + simpa [sqrtLowerWithScale] using hmax2 + +theorem real_sqrt_le_sqrtUpperWithScale {q : Rat} {scale : Nat} + (hq : 0 ≤ q) 
(hscale : 0 < scale) : + Real.sqrt (q : Real) ≤ (sqrtUpperWithScale scale q : Real) := by + have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq + have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq + have hscaled := real_sqrt_le_sqrtUpperScaledWith (q := q) (scale := scale) hq hscale + have hmin1 : + Real.sqrt (q : Real) ≤ min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real) := by + exact (le_min_iff).2 ⟨hbase, halt⟩ + have hmin2 : + Real.sqrt (q : Real) ≤ + min (min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real)) + (sqrtUpperScaledWith scale q : Real) := by + exact (le_min_iff).2 ⟨hmin1, hscaled⟩ + simpa [sqrtUpperWithScale, ratToReal_min] using hmin2 /-- Square-root lower bound in reals (tighter of three bounds). -/ theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index f985181..ae49847 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -107,6 +107,14 @@ noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel d Circuit.softmax (scoresRealOfInputs inputs q) k dotProduct (weights q) (valsRealOfInputs inputs) +/-- Unfolding lemma for `headLogitDiff`. -/ +theorem headLogitDiff_def (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) : + headLogitDiff inputs q = + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + dotProduct (weights q) (valsRealOfInputs inputs) := by + rfl + /-- Lower bound computed from the per-key lower bounds in an induction certificate. 
-/ def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := let epsAt := Bounds.cacheBoundTask c.epsAt @@ -234,6 +242,26 @@ def logitDiffLowerBoundFromCacheWithEps (c : InductionHeadCert seq) (cache : Log c.values.lo Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo +/-- Unweighted logit-diff lower bound from a custom eps and value lower bounds. -/ +def logitDiffLowerBoundFromCacheWithEpsVals (c : InductionHeadCert seq) + (epsAtCustom valsLoCustom : Fin seq → Rat) : Option Rat := + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn valsLoCustom + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo + /-- Unfold `logitDiffLowerBoundFromCache` as the custom-eps variant. -/ theorem logitDiffLowerBoundFromCache_eq_withEps (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : @@ -264,6 +292,28 @@ theorem logitDiffLowerBoundFromCacheWithEps_def Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo := by rfl +/-- Unfolding lemma for `logitDiffLowerBoundFromCacheWithEpsVals`. 
-/ +theorem logitDiffLowerBoundFromCacheWithEpsVals_def + (c : InductionHeadCert seq) (epsAtCustom valsLoCustom : Fin seq → Rat) : + logitDiffLowerBoundFromCacheWithEpsVals c epsAtCustom valsLoCustom = + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn valsLoCustom + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo := by + rfl + /-- Refined unweighted logit-diff lower bound using an overlayed `epsAt`. -/ def logitDiffLowerBoundRefinedFromCache (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -338,14 +388,27 @@ def logitDiffLowerBoundRefineOnDemand | none => some lb0 | some lb1 => let lb01 := max lb0 lb1 - if lb01 ≤ 0 then - let refineBudget' := refineBudgetBoost refineBudget - let spec' := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget' - match logitDiffLowerBoundRefinedFromCache inputs core c cache spec' with - | some lb2 => some (max lb01 lb2) - | none => some lb01 - else - some lb01 + let lbWeight? : Option Rat := + if lb01 ≤ 0 then + let refineBudget' := refineBudgetBoost refineBudget + let spec' := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget' + match logitDiffLowerBoundRefinedFromCache inputs core c cache spec' with + | some lb2 => some (max lb01 lb2) + | none => some lb01 + else + some lb01 + match lbWeight? 
with + | none => some lb01 + | some lbWeight => + if lbWeight ≤ 0 then + let valBudget := refineBudgetBoost refineBudget + let valKeys := loAtKeysAt inputs core q0 + let valsLo := valsLoOverlay inputs core valBudget valKeys + match logitDiffLowerBoundFromCacheWithEpsVals c cache.epsAt valsLo with + | some lb2 => some (max lbWeight lb2) + | none => some lbWeight + else + some lbWeight else some lb0 @@ -368,14 +431,27 @@ theorem logitDiffLowerBoundRefineOnDemand_def | none => some lb0 | some lb1 => let lb01 := max lb0 lb1 - if lb01 ≤ 0 then - let refineBudget' := refineBudgetBoost refineBudget - let spec' := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget' - match logitDiffLowerBoundRefinedFromCache inputs core c cache spec' with - | some lb2 => some (max lb01 lb2) - | none => some lb01 - else - some lb01 + let lbWeight? : Option Rat := + if lb01 ≤ 0 then + let refineBudget' := refineBudgetBoost refineBudget + let spec' := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget' + match logitDiffLowerBoundRefinedFromCache inputs core c cache spec' with + | some lb2 => some (max lb01 lb2) + | none => some lb01 + else + some lb01 + match lbWeight? with + | none => some lb01 + | some lbWeight => + if lbWeight ≤ 0 then + let valBudget := refineBudgetBoost refineBudget + let valKeys := loAtKeysAt inputs core q0 + let valsLo := valsLoOverlay inputs core valBudget valKeys + match logitDiffLowerBoundFromCacheWithEpsVals c cache.epsAt valsLo with + | some lb2 => some (max lbWeight lb2) + | none => some lbWeight + else + some lbWeight else some lb0 := by rfl @@ -731,233 +807,6 @@ theorem logitDiffLowerBoundFromCert_le le_trans hboundReal hdot_lower simpa [headLogitDiff, weights, vals] using hle -/-- The unweighted logit-diff lower bound is sound for any valid per-query `epsAt`. 
-/ -theorem logitDiffLowerBoundFromCacheWithEps_le - (inputs : Model.InductionHeadInputs seq dModel dHead) - (c : InductionHeadCert seq) (epsAtCustom : Fin seq → Rat) - (hsound : InductionHeadCertSound inputs c) - (honeHot : - ∀ q, q ∈ c.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAtCustom q : Real) - (fun q' => q' = q) c.prev - (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k)) - {lb : Rat} - (hbound : - logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) epsAtCustom = some lb) - {q : Fin seq} (hq : q ∈ c.active) : - (lb : Real) ≤ headLogitDiff inputs q := by - classical - cases seq with - | zero => - cases (NeZero.ne (n := (0 : Nat)) rfl) - | succ n => - let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => - Circuit.softmax (scoresRealOfInputs inputs q) k - let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs - let epsArr : Array Rat := Array.ofFn epsAtCustom - let valsLoArr : Array Rat := Array.ofFn (logitDiffCache c).valsLo - let epsAt : Fin (Nat.succ n) → Rat := fun q => - epsArr[q.1]'(by - simp [epsArr, q.isLt]) - let valsLo : Fin (Nat.succ n) → Rat := fun q => - valsLoArr[q.1]'(by - simp [valsLoArr, q.isLt]) - let loAt : Fin (Nat.succ n) → Rat := fun q => - let others : Finset (Fin (Nat.succ n)) := - (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) - if h : others.Nonempty then - others.inf' h valsLo - else - c.values.lo - let others : Finset (Fin (Nat.succ n)) := - (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) - let sumOthers : Real := ∑ k ∈ others, weights q k - let valsLoPrev : Real := (c.values.valsLo (c.prev q) : Real) - let loAtRat : Rat := loAt q - let loAtReal : Real := (loAtRat : Real) - have hboundRat : - lb ≤ valsLo (c.prev q) - - epsAt q * max (0 : Rat) (valsLo (c.prev q) - loAt q) := by - refine - Circuit.logitDiffLowerBoundAtLoAt_le - (active := c.active) - (prev := c.prev) - (epsAt := epsAt) - (loAt := loAt) - (valsLo := valsLo) - q hq lb ?_ - simpa 
[logitDiffLowerBoundFromCacheWithEps, loAt, epsAt, valsLo, valsLoArr, epsArr, - logitDiffCache] using hbound - have hepsAt : epsAt q = epsAtCustom q := by - simp [epsAt, epsArr] - have hvalsLo : ∀ k, valsLo k = c.values.valsLo k := by - intro k - simp [valsLo, valsLoArr, logitDiffCache, Bounds.cacheBoundTask_apply] - have hboundRat' : - lb ≤ c.values.valsLo (c.prev q) - - epsAtCustom q * max (0 : Rat) (c.values.valsLo (c.prev q) - loAt q) := by - simpa [hepsAt, hvalsLo] using hboundRat - have hboundReal : - (lb : Real) ≤ - valsLoPrev - (epsAtCustom q : Real) * - max (0 : Real) (valsLoPrev - loAtReal) := by - simpa [loAtRat, loAtReal, ratToReal_sub, ratToReal_mul, ratToReal_max, ratToReal_def] - using ratToReal_le_of_le hboundRat' - have hweights_nonneg : ∀ k, 0 ≤ weights q k := by - have hweights := honeHot q hq - simpa [weights] using hweights.nonneg q rfl - have hweights := honeHot q hq - have hsum_decomp : - weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by - simp [others] - have hsum : - weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by - have hsum_one : (∑ k, weights q k) = 1 := by - simpa [weights] using hweights.sum_one q rfl - calc - weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp - _ = 1 := hsum_one - have hsum_others_le : sumOthers ≤ (epsAtCustom q : Real) := by - have hprev : 1 ≤ weights q (c.prev q) + (epsAtCustom q : Real) := - hweights.prev_large q rfl - have hprev' : - weights q (c.prev q) + sumOthers ≤ - weights q (c.prev q) + (epsAtCustom q : Real) := by - simpa [hsum, sumOthers] using hprev - exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' - have hloAt_le_valsLo : ∀ k ∈ others, loAtRat ≤ c.values.valsLo k := by - intro k hk - have hnonempty : others.Nonempty := ⟨k, hk⟩ - have hmin : others.inf' hnonempty valsLo ≤ valsLo k := - Finset.inf'_le (s := others) (f := valsLo) hk - have hnonempty' : (Finset.univ.erase (c.prev q)).Nonempty := by - simpa [others] using hnonempty 
- have hloAt : loAtRat = others.inf' hnonempty valsLo := by - dsimp [loAtRat, loAt] - simp [hnonempty', others] - have hvalsLo' : valsLo k = c.values.valsLo k := hvalsLo k - calc - loAtRat = others.inf' hnonempty valsLo := hloAt - _ ≤ valsLo k := hmin - _ = c.values.valsLo k := hvalsLo' - have hvals_lo : ∀ k ∈ others, loAtReal ≤ vals k := by - intro k hk - have hloRat := hloAt_le_valsLo k hk - have hloReal : loAtReal ≤ (c.values.valsLo k : Real) := by - simpa [loAtReal, ratToReal_def] using (ratToReal_le_of_le hloRat) - have hvals : (c.values.valsLo k : Real) ≤ vals k := by - simpa using (hsound.value_bounds.vals_bounds k).1 - exact le_trans hloReal hvals - have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by - exact (hsound.value_bounds.vals_bounds (c.prev q)).1 - have hsum_vals_ge : - sumOthers * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by - have hsum_lo : - sumOthers * loAtReal = ∑ k ∈ others, weights q k * loAtReal := by - have hsum_lo' : - (∑ k ∈ others, weights q k) * loAtReal = - ∑ k ∈ others, weights q k * loAtReal := by - simpa using - (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := loAtReal)) - simpa [sumOthers] using hsum_lo' - have hle : - ∀ k ∈ others, weights q k * loAtReal ≤ weights q k * vals k := by - intro k _hk - have hval := hvals_lo k _hk - have hnonneg := hweights_nonneg k - exact mul_le_mul_of_nonneg_left hval hnonneg - have hsum' : - ∑ k ∈ others, weights q k * loAtReal ≤ - ∑ k ∈ others, weights q k * vals k := by - exact Finset.sum_le_sum hle - simpa [hsum_lo] using hsum' - have hsum_prod : - weights q (c.prev q) * vals (c.prev q) + - ∑ k ∈ others, weights q k * vals k = - ∑ k, weights q k * vals k := by - simp [others] - have hout_eq : - dotProduct (weights q) vals = - weights q (c.prev q) * vals (c.prev q) + - ∑ k ∈ others, weights q k * vals k := by - simpa [dotProduct] using hsum_prod.symm - have hdot_ge : - weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ - dotProduct (weights q) vals := by - 
have hle : - weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ - weights q (c.prev q) * vals (c.prev q) + - ∑ k ∈ others, weights q k * vals k := by - simpa [add_comm, add_left_comm, add_assoc] using - (add_le_add_left hsum_vals_ge (weights q (c.prev q) * vals (c.prev q))) - simpa [sumOthers, hout_eq, add_comm, add_left_comm, add_assoc] using hle - have hprev_lo : - weights q (c.prev q) * valsLoPrev ≤ - weights q (c.prev q) * vals (c.prev q) := by - exact mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) - have hdot_ge' : - weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ - dotProduct (weights q) vals := by - have hle : - weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ - weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal := by - simpa [add_comm, add_left_comm, add_assoc] using - (add_le_add_right hprev_lo (sumOthers * loAtReal)) - exact hle.trans hdot_ge - have hsplit : - weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = - valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by - have hsplit' : - weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = - (weights q (c.prev q) + sumOthers) * valsLoPrev - - sumOthers * (valsLoPrev - loAtReal) := by - ring - calc - weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = - (weights q (c.prev q) + sumOthers) * valsLoPrev - - sumOthers * (valsLoPrev - loAtReal) := hsplit' - _ = valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by - simp [hsum, sumOthers] - have hdiff_le : valsLoPrev - loAtReal ≤ max (0 : Real) (valsLoPrev - loAtReal) := by - exact le_max_right _ _ - have hsum_nonneg : 0 ≤ sumOthers := by - have hnonneg : ∀ k ∈ others, 0 ≤ weights q k := by - intro k _hk - exact hweights_nonneg k - have hsum_nonneg' : 0 ≤ ∑ k ∈ others, weights q k := by - exact Finset.sum_nonneg hnonneg - simpa [sumOthers] using hsum_nonneg' - have hsum_mul_le_left : - sumOthers * (valsLoPrev - loAtReal) ≤ - sumOthers * max (0 : Real) (valsLoPrev - loAtReal) := 
by - exact mul_le_mul_of_nonneg_left hdiff_le hsum_nonneg - have hmax_nonneg : 0 ≤ max (0 : Real) (valsLoPrev - loAtReal) := by - exact le_max_left _ _ - have hsum_mul_le : - sumOthers * (valsLoPrev - loAtReal) ≤ - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by - have hsum_mul_le_right : - sumOthers * max (0 : Real) (valsLoPrev - loAtReal) ≤ - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by - exact mul_le_mul_of_nonneg_right hsum_others_le hmax_nonneg - exact le_trans hsum_mul_le_left hsum_mul_le_right - have hsub_le : - valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ - valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by - exact sub_le_sub_left hsum_mul_le valsLoPrev - have hdot_lower : - valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ - dotProduct (weights q) vals := by - calc - valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ - valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := hsub_le - _ = weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal := by - simp [hsplit] - _ ≤ dotProduct (weights q) vals := hdot_ge' - have hle : (lb : Real) ≤ dotProduct (weights q) vals := - le_trans hboundReal hdot_lower - simpa [headLogitDiff, weights, vals] using hle - /-- The weighted per-key logit-diff lower bound is sound on active queries. -/ theorem logitDiffLowerBoundFromCertWeighted_le (inputs : Model.InductionHeadInputs seq dModel dHead) diff --git a/Nfp/Sound/Induction/LogitDiffSound.lean b/Nfp/Sound/Induction/LogitDiffSound.lean new file mode 100644 index 0000000..e4d329d --- /dev/null +++ b/Nfp/Sound/Induction/LogitDiffSound.lean @@ -0,0 +1,467 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Sound.Induction.LogitDiff + +/-! +Soundness lemmas for logit-diff lower bounds with custom eps/values. 
+-/ + +public section + +namespace Nfp + +namespace Sound + +open Nfp.Circuit + +variable {seq dModel dHead : Nat} + +section WithNeZero + +variable [NeZero seq] + +/-- The unweighted logit-diff lower bound is sound for any valid per-query `epsAt`. -/ +theorem logitDiffLowerBoundFromCacheWithEps_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (epsAtCustom : Fin seq → Rat) + (hsound : InductionHeadCertSound inputs c) + (honeHot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAtCustom q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k)) + {lb : Rat} + (hbound : + logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) epsAtCustom = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn (logitDiffCache c).valsLo + let epsAt : Fin (Nat.succ n) → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin (Nat.succ n) → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin (Nat.succ n) → Rat := fun q => + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let sumOthers : Real := ∑ k ∈ others, weights q k + let valsLoPrev : Real := (c.values.valsLo (c.prev q) : Real) + let loAtRat : Rat := loAt q + let loAtReal : Real := (loAtRat : Real) + have hboundRat : + lb ≤ valsLo 
(c.prev q) - + epsAt q * max (0 : Rat) (valsLo (c.prev q) - loAt q) := by + refine + Circuit.logitDiffLowerBoundAtLoAt_le + (active := c.active) + (prev := c.prev) + (epsAt := epsAt) + (loAt := loAt) + (valsLo := valsLo) + q hq lb ?_ + simpa [logitDiffLowerBoundFromCacheWithEps_def, logitDiffCache_def, loAt, epsAt, valsLo, + valsLoArr, epsArr] using hbound + have hepsAt : epsAt q = epsAtCustom q := by + simp [epsAt, epsArr] + have hvalsLo : ∀ k, valsLo k = c.values.valsLo k := by + intro k + simp [valsLo, valsLoArr, logitDiffCache_def, Bounds.cacheBoundTask_apply] + have hboundRat' : + lb ≤ c.values.valsLo (c.prev q) - + epsAtCustom q * max (0 : Rat) (c.values.valsLo (c.prev q) - loAt q) := by + simpa [hepsAt, hvalsLo] using hboundRat + have hboundReal : + (lb : Real) ≤ + valsLoPrev - (epsAtCustom q : Real) * + max (0 : Real) (valsLoPrev - loAtReal) := by + simpa [loAtRat, loAtReal, ratToReal_sub, ratToReal_mul, ratToReal_max, ratToReal_def] + using ratToReal_le_of_le hboundRat' + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + have hweights := honeHot q hq + simpa [weights] using hweights.nonneg q rfl + have hweights := honeHot q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using hweights.sum_one q rfl + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hsum_one + have hsum_others_le : sumOthers ≤ (epsAtCustom q : Real) := by + have hprev : 1 ≤ weights q (c.prev q) + (epsAtCustom q : Real) := + hweights.prev_large q rfl + have hprev' : + weights q (c.prev q) + sumOthers ≤ + weights q (c.prev q) + (epsAtCustom q : Real) := by + simpa [hsum, sumOthers] using hprev + exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' + have hloAt_le_valsLo : ∀ k ∈ others, loAtRat ≤ c.values.valsLo k := by + intro 
k hk + have hnonempty : others.Nonempty := ⟨k, hk⟩ + have hmin : others.inf' hnonempty valsLo ≤ valsLo k := + Finset.inf'_le (s := others) (f := valsLo) hk + have hnonempty' : (Finset.univ.erase (c.prev q)).Nonempty := by + simpa [others] using hnonempty + have hloAt : loAtRat = others.inf' hnonempty valsLo := by + dsimp [loAtRat, loAt] + simp [hnonempty', others] + have hvalsLo' : valsLo k = c.values.valsLo k := hvalsLo k + calc + loAtRat = others.inf' hnonempty valsLo := hloAt + _ ≤ valsLo k := hmin + _ = c.values.valsLo k := hvalsLo' + have hvals_lo : ∀ k ∈ others, loAtReal ≤ vals k := by + intro k hk + have hloRat := hloAt_le_valsLo k hk + have hloReal : loAtReal ≤ (c.values.valsLo k : Real) := by + simpa [loAtReal, ratToReal_def] using (ratToReal_le_of_le hloRat) + have hvals : (c.values.valsLo k : Real) ≤ vals k := by + simpa using (hsound.value_bounds.vals_bounds k).1 + exact le_trans hloReal hvals + have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by + exact (hsound.value_bounds.vals_bounds (c.prev q)).1 + have hsum_lo : + sumOthers * loAtReal = ∑ k ∈ others, weights q k * loAtReal := by + have hsum_lo' : + (∑ k ∈ others, weights q k) * loAtReal = + ∑ k ∈ others, weights q k * loAtReal := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := loAtReal)) + simpa [sumOthers] using hsum_lo' + have hsum_vals_ge : + sumOthers * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by + have hle : + ∀ k ∈ others, weights q k * loAtReal ≤ weights q k * vals k := by + intro k _hk + have hval := hvals_lo k _hk + have hnonneg := hweights_nonneg k + exact mul_le_mul_of_nonneg_left hval hnonneg + have hsum' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + exact Finset.sum_le_sum hle + simpa [hsum_lo] using hsum' + have hsum_vals_ge' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + simpa [hsum_lo] using hsum_vals_ge + have hsum_nonneg : 0 ≤ sumOthers := by + have 
hnonneg : ∀ k ∈ others, 0 ≤ weights q k := by + intro k hk + exact hweights_nonneg k + have hsum_nonneg' : 0 ≤ ∑ k ∈ others, weights q k := by + exact Finset.sum_nonneg hnonneg + simpa [sumOthers] using hsum_nonneg' + have hsplit : + weights q (c.prev q) = 1 - sumOthers := by + have hsum' : weights q (c.prev q) + sumOthers = 1 := by + simpa [sumOthers] using hsum + exact (eq_sub_iff_add_eq).2 hsum' + have hdiff_le : valsLoPrev - loAtReal ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_right _ _ + have hsum_mul_le_left : + sumOthers * (valsLoPrev - loAtReal) ≤ + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_left hdiff_le hsum_nonneg + have hmax_nonneg : 0 ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_left _ _ + have hsum_mul_le : + sumOthers * (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + have hsum_mul_le_right : + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hmax_nonneg + exact le_trans hsum_mul_le_left hsum_mul_le_right + have hsub_le : + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + exact sub_le_sub_left hsum_mul_le valsLoPrev + have hdot_lower : + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := hsub_le + _ = weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal := by + have hsplit_calc : + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) = + (1 - sumOthers) * valsLoPrev + sumOthers * loAtReal := by + ring + simpa [hsplit] using hsplit_calc + _ ≤ dotProduct (weights q) vals := by + have hprev_le := 
mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) + have hdot_ge : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_right hprev_le (sumOthers * loAtReal)) + have hdot_ge' : + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + calc + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal + = weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * loAtReal := by + simp [hsum_lo] + _ ≤ weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_vals_ge' + (weights q (c.prev q) * vals (c.prev q))) + _ = dotProduct (weights q) vals := by + simp [dotProduct, others] + exact le_trans hdot_ge hdot_ge' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal hdot_lower + simpa [headLogitDiff_def, weights, vals] using hle + +/-- The unweighted logit-diff lower bound is sound for custom eps and values. 
-/ +theorem logitDiffLowerBoundFromCacheWithEpsVals_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (epsAtCustom valsLoCustom : Fin seq → Rat) + (hsound : InductionHeadCertSound inputs c) + (honeHot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAtCustom q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k)) + (hvalsLo : + ∀ k, (valsLoCustom k : Real) ≤ valsRealOfInputs inputs k) + {lb : Rat} + (hbound : + logitDiffLowerBoundFromCacheWithEpsVals c epsAtCustom valsLoCustom = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn valsLoCustom + let epsAt : Fin (Nat.succ n) → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin (Nat.succ n) → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin (Nat.succ n) → Rat := fun q => + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let sumOthers : Real := ∑ k ∈ others, weights q k + let valsLoPrev : Real := (valsLo (c.prev q) : Real) + let loAtRat : Rat := loAt q + let loAtReal : Real := (loAtRat : Real) + have hvalsLo_eq : ∀ k, valsLo k = valsLoCustom k := by + intro k + simp [valsLo, valsLoArr] + have hboundRat : + lb ≤ valsLo (c.prev q) - + epsAt q * max (0 : Rat) (valsLo (c.prev q) - loAt q) := by + refine + 
Circuit.logitDiffLowerBoundAtLoAt_le + (active := c.active) + (prev := c.prev) + (epsAt := epsAt) + (loAt := loAt) + (valsLo := valsLo) + q hq lb ?_ + simpa [logitDiffLowerBoundFromCacheWithEpsVals_def, loAt, epsAt, valsLo, valsLoArr, epsArr] + using hbound + have hepsAt : epsAt q = epsAtCustom q := by + simp [epsAt, epsArr] + have hboundRat' : + lb ≤ valsLoCustom (c.prev q) - + epsAtCustom q * max (0 : Rat) (valsLoCustom (c.prev q) - loAt q) := by + simpa [hepsAt, hvalsLo_eq] using hboundRat + have hboundReal : + (lb : Real) ≤ + valsLoPrev - (epsAtCustom q : Real) * + max (0 : Real) (valsLoPrev - loAtReal) := by + have hvalsLoPrev_eq : valsLoPrev = (valsLoCustom (c.prev q) : Real) := by + simp [valsLoPrev, valsLo, valsLoArr] + simpa [hvalsLoPrev_eq, loAtRat, loAtReal, ratToReal_sub, ratToReal_mul, ratToReal_max, + ratToReal_def] + using ratToReal_le_of_le hboundRat' + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + have hweights := honeHot q hq + simpa [weights] using hweights.nonneg q rfl + have hweights := honeHot q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using hweights.sum_one q rfl + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hsum_one + have hsum_others_le : sumOthers ≤ (epsAtCustom q : Real) := by + have hprev : 1 ≤ weights q (c.prev q) + (epsAtCustom q : Real) := + hweights.prev_large q rfl + have hprev' : + weights q (c.prev q) + sumOthers ≤ + weights q (c.prev q) + (epsAtCustom q : Real) := by + simpa [hsum, sumOthers] using hprev + exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' + have hvalsLo_real : ∀ k, (valsLo k : Real) ≤ vals k := by + intro k + have hvals := hvalsLo k + simpa [valsLo, valsLoArr, vals] using hvals + have hprev_lo : valsLoPrev ≤ vals (c.prev q) := by 
+ simpa [valsLoPrev] using hvalsLo_real (c.prev q) + have hloAt_le_valsLo : ∀ k ∈ others, loAtRat ≤ valsLo k := by + intro k hk + have hnonempty : others.Nonempty := ⟨k, hk⟩ + have hmin : others.inf' hnonempty valsLo ≤ valsLo k := + Finset.inf'_le (s := others) (f := valsLo) hk + have hnonempty' : (Finset.univ.erase (c.prev q)).Nonempty := by + simpa [others] using hnonempty + have hloAt : loAtRat = others.inf' hnonempty valsLo := by + dsimp [loAtRat, loAt] + simp [hnonempty', others] + calc + loAtRat = others.inf' hnonempty valsLo := hloAt + _ ≤ valsLo k := hmin + have hvals_lo : ∀ k ∈ others, loAtReal ≤ vals k := by + intro k hk + have hloRat := hloAt_le_valsLo k hk + have hloReal : loAtReal ≤ (valsLo k : Real) := by + simpa [loAtReal, ratToReal_def] using (ratToReal_le_of_le hloRat) + have hvalsReal : (valsLo k : Real) ≤ vals k := hvalsLo_real k + exact le_trans hloReal hvalsReal + have hsum_nonneg : 0 ≤ sumOthers := by + have hnonneg : ∀ k ∈ others, 0 ≤ weights q k := by + intro k hk + exact hweights_nonneg k + have hsum_nonneg' : 0 ≤ ∑ k ∈ others, weights q k := by + exact Finset.sum_nonneg hnonneg + simpa [sumOthers] using hsum_nonneg' + have hsplit : + weights q (c.prev q) = 1 - sumOthers := by + have hsum' : weights q (c.prev q) + sumOthers = 1 := by + simpa [sumOthers] using hsum + exact (eq_sub_iff_add_eq).2 hsum' + have hdiff_le : valsLoPrev - loAtReal ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_right _ _ + have hsum_mul_le_left : + sumOthers * (valsLoPrev - loAtReal) ≤ + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_left hdiff_le hsum_nonneg + have hmax_nonneg : 0 ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_left _ _ + have hsum_mul_le : + sumOthers * (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + have hsum_mul_le_right : + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) 
(valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hmax_nonneg + exact le_trans hsum_mul_le_left hsum_mul_le_right + have hsub_le : + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + exact sub_le_sub_left hsum_mul_le valsLoPrev + have hsum_lo : + sumOthers * loAtReal = ∑ k ∈ others, weights q k * loAtReal := by + have hsum' : + (∑ k ∈ others, weights q k) * loAtReal = + ∑ k ∈ others, weights q k * loAtReal := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := loAtReal)) + simpa [sumOthers] using hsum' + have hsum_vals_ge : + sumOthers * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by + have hle : ∀ k ∈ others, weights q k * loAtReal ≤ weights q k * vals k := by + intro k hk + have hlo := hvals_lo k hk + have hnonneg := hweights_nonneg k + exact mul_le_mul_of_nonneg_left hlo hnonneg + have hle' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + exact Finset.sum_le_sum hle + simpa [hsum_lo] using hle' + have hsum_vals_ge' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + simpa [hsum_lo] using hsum_vals_ge + have hdot_ge'' : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + have hdot_ge : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal := by + have hprev_le := mul_le_mul_of_nonneg_left hprev_lo (hweights_nonneg (c.prev q)) + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_right hprev_le (sumOthers * loAtReal)) + have hdot_ge' : + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + calc + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal = + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * loAtReal := by + simp [hsum_lo] + 
_ ≤ weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_vals_ge' + (weights q (c.prev q) * vals (c.prev q))) + _ = dotProduct (weights q) vals := by + simp [dotProduct, others] + exact le_trans hdot_ge hdot_ge' + have hdot_lower : + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := hsub_le + _ = weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal := by + have hsplit_calc : + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) = + (1 - sumOthers) * valsLoPrev + sumOthers * loAtReal := by + ring + simpa [hsplit] using hsplit_calc + _ ≤ dotProduct (weights q) vals := hdot_ge'' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal hdot_lower + simpa [headLogitDiff_def, weights, vals] using hle + +end WithNeZero + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/Refine.lean b/Nfp/Sound/Induction/Refine.lean index a273327..d9d74d4 100644 --- a/Nfp/Sound/Induction/Refine.lean +++ b/Nfp/Sound/Induction/Refine.lean @@ -35,6 +35,15 @@ theorem refineBudgetBoost_def (budget : Nat) : refineBudgetBoost budget = max (budget + 1) (2 * budget) := by rfl +/-- Scale used for refined value bounds. -/ +def valRefineScale (budget : Nat) : Nat := + Bounds.sqrtLowerScale * refineBudgetBoost budget + +/-- Unfolding lemma for `valRefineScale`. -/ +theorem valRefineScale_def (budget : Nat) : + valRefineScale budget = Bounds.sqrtLowerScale * refineBudgetBoost budget := by + rfl + /-- Worst key under the base score-gap lower bound (excluding `prev`). 
-/ def worstKeyBase (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -85,6 +94,34 @@ theorem weightOneKeysAt_def others.filter (fun k => decide (cache.weightBoundAt q k = (1 : Rat))) := by rfl +/-- Keys attaining the per-query lower-value minimum (excluding `prev`). -/ +def loAtKeysAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : Finset (Fin seq) := + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + if h : others.Nonempty then + let lo := others.inf' h cache.valsLo + others.filter (fun k => decide (cache.valsLo k = lo)) + else + ∅ + +/-- Unfolding lemma for `loAtKeysAt`. -/ +theorem loAtKeysAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) : + loAtKeysAt inputs cache q = + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) + if h : others.Nonempty then + let lo := others.inf' h cache.valsLo + others.filter (fun k => decide (cache.valsLo k = lo)) + else + ∅ := by + rfl + /-- Refinement keys for a query, seeded by negative base gaps and the worst key. -/ def refineKeysAt (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -165,6 +202,84 @@ theorem refineSpecForQueryWithWeightOnes_def splitBudgetDiffRefined := budget } := by rfl +/-- Refined value lower bound at a single key (fallbacks to base bounds if disabled). 
-/ +def valsLoRefinedAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (k : Fin seq) : Rat := + let scale := valRefineScale budget + if _ : 0 < inputs.lnEps then + if _ : 0 < Bounds.sqrtLowerWithScale scale inputs.lnEps then + if _ : dModel = 0 then + cache.cert.values.valsLo k + else + let dirHead : Fin dHead → Rat := fun d => (dirHeadVecOfInputs inputs).get d + let wvDir : Fin dModel → Rat := + fun j => Linear.dotFin dHead dirHead (fun d => inputs.wv j d) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv d) + let lnBounds := + Bounds.layerNormBoundsWithScale scale inputs.lnEps inputs.ln1Gamma inputs.ln1Beta + (inputs.embed k) + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + bDir + Bounds.dotIntervalLower wvDir lnLo lnHi + else + cache.cert.values.valsLo k + else + cache.cert.values.valsLo k + +/-- Unfolding lemma for `valsLoRefinedAt`. -/ +theorem valsLoRefinedAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (k : Fin seq) : + valsLoRefinedAt inputs cache budget k = + let scale := valRefineScale budget + if _ : 0 < inputs.lnEps then + if _ : 0 < Bounds.sqrtLowerWithScale scale inputs.lnEps then + if _ : dModel = 0 then + cache.cert.values.valsLo k + else + let dirHead : Fin dHead → Rat := fun d => (dirHeadVecOfInputs inputs).get d + let wvDir : Fin dModel → Rat := + fun j => Linear.dotFin dHead dirHead (fun d => inputs.wv j d) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv d) + let lnBounds := + Bounds.layerNormBoundsWithScale scale inputs.lnEps inputs.ln1Gamma inputs.ln1Beta + (inputs.embed k) + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + bDir + Bounds.dotIntervalLower wvDir lnLo lnHi + else + cache.cert.values.valsLo k + else + cache.cert.values.valsLo k := by + rfl + +/-- Overlay refined value lower bounds on a subset of keys. 
-/ +def valsLoOverlay + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (refineKeys : Finset (Fin seq)) : Fin seq → Rat := fun k => + if k ∈ refineKeys then + valsLoRefinedAt inputs cache budget k + else + cache.cert.values.valsLo k + +/-- Unfolding lemma for `valsLoOverlay`. -/ +theorem valsLoOverlay_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (refineKeys : Finset (Fin seq)) (k : Fin seq) : + valsLoOverlay inputs cache budget refineKeys k = + if k ∈ refineKeys then + valsLoRefinedAt inputs cache budget k + else + cache.cert.values.valsLo k := by + rfl + /-- Refined diff dot-product lower bound at a single `(q,k)` pair. -/ def dotDiffLoRefinedAt (inputs : Model.InductionHeadInputs seq dModel dHead) diff --git a/Nfp/Sound/Induction/RefineSound.lean b/Nfp/Sound/Induction/RefineSound.lean index e3c90d1..ccfd3b6 100644 --- a/Nfp/Sound/Induction/RefineSound.lean +++ b/Nfp/Sound/Induction/RefineSound.lean @@ -2,9 +2,10 @@ module -public import Nfp.Sound.Induction.LogitDiff +public import Nfp.Sound.Induction.LogitDiffSound public import Nfp.Sound.Induction.OneHot public import Nfp.Sound.Induction.Refine +public import Nfp.Sound.Induction.CoreSound.Values /-! Soundness lemmas for refine-on-demand overlays. @@ -355,6 +356,88 @@ theorem oneHot_bounds_at_overlay · intro q' hq' k hk exact hweight_overlay q' hq' k hk +/-- Refined value lower bounds are sound when LayerNorm bounds are sound. 
-/ +theorem valsLoRefinedAt_le_valsReal + [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) + (hsound : InductionHeadCertSound inputs cache.cert) + (k : Fin seq) : + (valsLoRefinedAt inputs cache budget k : Real) ≤ valsRealOfInputs inputs k := by + classical + by_cases hEps : 0 < inputs.lnEps + · by_cases hSqrt : + 0 < Bounds.sqrtLowerWithScale (valRefineScale budget) inputs.lnEps + · by_cases hmodel : dModel = 0 + · have hvals := hsound.value_bounds.vals_bounds k + simpa [valsLoRefinedAt_def, hEps, hSqrt, hmodel] using hvals.1 + · let scale : Nat := valRefineScale budget + let dirHead : Fin dHead → Rat := fun d => (dirHeadVecOfInputs inputs).get d + let wvDir : Fin dModel → Rat := + fun j => Linear.dotFin dHead dirHead (fun d => inputs.wv j d) + let bDir : Rat := + Linear.dotFin dHead dirHead (fun d => inputs.bv d) + let lnBounds : + Fin seq → (Fin dModel → Rat) × (Fin dModel → Rat) := fun k' => + Bounds.layerNormBoundsWithScale scale inputs.lnEps inputs.ln1Gamma inputs.ln1Beta + (inputs.embed k') + let lnLo : Fin seq → Fin dModel → Rat := fun k' => (lnBounds k').1 + let lnHi : Fin seq → Fin dModel → Rat := fun k' => (lnBounds k').2 + let valsLo : Fin seq → Rat := fun k' => + bDir + Bounds.dotIntervalLower wvDir (lnLo k') (lnHi k') + let valsHi : Fin seq → Rat := fun k' => + bDir + Bounds.dotIntervalUpper wvDir (lnLo k') (lnHi k') + have hscale_pos : 0 < scale := by + have hbase : 0 < Bounds.sqrtLowerScale := by + simp [Bounds.sqrtLowerScale_def] + have hboost : 0 < refineBudgetBoost budget := by + have hle : budget + 1 ≤ refineBudgetBoost budget := by + simp [refineBudgetBoost_def] + exact lt_of_lt_of_le (Nat.succ_pos budget) hle + simpa [scale, valRefineScale_def] using Nat.mul_pos hbase hboost + have hln : + ∀ k' j, (lnLo k' j : Real) ≤ lnRealOfInputs inputs k' j ∧ + lnRealOfInputs inputs k' j ≤ (lnHi k' j : Real) := by + intro k' j + have hln' := + 
Bounds.layerNormBoundsWithScale_spec (scale := scale) + (eps := inputs.lnEps) (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) + (x := inputs.embed k') hmodel hEps hSqrt hscale_pos + simpa [lnBounds, lnLo, lnHi, lnRealOfInputs_def] using hln' j + have hvals := + valsReal_bounds_at_of_ln_bounds (inputs := inputs) + (dirHead := dirHead) (hdirHead := rfl) + (wvDir := wvDir) (bDir := bDir) + (hwvDir := by intro j; rfl) + (hbDir := by rfl) + (lnLo := lnLo) (lnHi := lnHi) + (valsLo := valsLo) (valsHi := valsHi) + (hvalsLo := by intro k'; rfl) + (hvalsHi := by intro k'; rfl) + (hln := hln) + have hvals_k := (hvals k).1 + simpa [valsLoRefinedAt_def, hEps, hSqrt, hmodel, scale, lnBounds, lnLo, lnHi, + valsLo, wvDir, bDir] using hvals_k + · have hvals := hsound.value_bounds.vals_bounds k + simpa [valsLoRefinedAt_def, hEps, hSqrt] using hvals.1 + · have hvals := hsound.value_bounds.vals_bounds k + simpa [valsLoRefinedAt_def, hEps] using hvals.1 + +/-- Overlayed value lower bounds remain sound. -/ +theorem valsLoOverlay_le_valsReal + [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (budget : Nat) (refineKeys : Finset (Fin seq)) + (hsound : InductionHeadCertSound inputs cache.cert) : + ∀ k, (valsLoOverlay inputs cache budget refineKeys k : Real) ≤ valsRealOfInputs inputs k := by + intro k + by_cases hmem : k ∈ refineKeys + · simpa [valsLoOverlay_def, hmem] using + (valsLoRefinedAt_le_valsReal (inputs := inputs) (cache := cache) + (budget := budget) (hsound := hsound) k) + · have hvals := hsound.value_bounds.vals_bounds k + simpa [valsLoOverlay_def, hmem] using hvals.1 + /-- The refined unweighted logit-diff lower bound is sound on active queries. 
-/ theorem logitDiffLowerBoundRefinedFromCache_le [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -564,40 +647,110 @@ theorem logitDiffLowerBoundRefineOnDemand_le cases h2 : logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec' with | none => - have hlb : lb = lb01 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, - spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget'] using - hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - have hweight_overlay' : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real) := by - simpa [spec, refineBudget] using hweight_overlay q0 refineBudget - have hrefine := - logitDiffLowerBoundRefinedFromCache_le - (inputs := inputs) - (cache := cache) - (c := c) - (logitCache := logitCache) - (spec := spec) - (hcert := hcert) - (hcache := hcache) - (hsound := hsound) - (hweight_overlay := hweight_overlay') - (hbound := h1) - (hq := hq) - have hmax' : - max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hbase, hrefine⟩ - have hmax : (lb01 : Real) ≤ headLogitDiff inputs q := by - simpa [lb01, ratToReal_max] using hmax' - simpa [hlb] using hmax + let valBudget := refineBudgetBoost refineBudget + let valKeys := loAtKeysAt inputs cache q0 + let valsLo := valsLoOverlay inputs cache valBudget valKeys + cases hval : + logitDiffLowerBoundFromCacheWithEpsVals c logitCache.epsAt valsLo with + | none => + have hlb : lb = lb01 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, + h1, spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', + valBudget, valKeys, valsLo, hval] using hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + have hweight_overlay' : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real) := by + simpa [spec, 
refineBudget] using hweight_overlay q0 refineBudget + have hrefine := + logitDiffLowerBoundRefinedFromCache_le + (inputs := inputs) + (cache := cache) + (c := c) + (logitCache := logitCache) + (spec := spec) + (hcert := hcert) + (hcache := hcache) + (hsound := hsound) + (hweight_overlay := hweight_overlay') + (hbound := h1) + (hq := hq) + have hmax' : + max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hbase, hrefine⟩ + have hmax : (lb01 : Real) ≤ headLogitDiff inputs q := by + simpa [lb01, ratToReal_max] using hmax' + simpa [hlb] using hmax + | some lb2 => + have hlb : lb = max lb01 lb2 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, + h1, spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', + valBudget, valKeys, valsLo, hval] using hbound.symm + have hbase := hbase_le (lb0 := lb0) h0 + have hweight_overlay' : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ + (weightBoundAtOverlay inputs cache spec q k : Real) := by + simpa [spec, refineBudget] using hweight_overlay q0 refineBudget + have hrefine := + logitDiffLowerBoundRefinedFromCache_le + (inputs := inputs) + (cache := cache) + (c := c) + (logitCache := logitCache) + (spec := spec) + (hcert := hcert) + (hcache := hcache) + (hsound := hsound) + (hweight_overlay := hweight_overlay') + (hbound := h1) + (hq := hq) + have hmax' : + max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hbase, hrefine⟩ + have hmax : (lb01 : Real) ≤ headLogitDiff inputs q := by + simpa [lb01, ratToReal_max] using hmax' + have hsound_cache : InductionHeadCertSound inputs cache.cert := by + simpa [hcert] using hsound + have hvalsLo : + ∀ k, (valsLo k : Real) ≤ valsRealOfInputs inputs k := by + exact valsLoOverlay_le_valsReal (inputs := inputs) (cache := cache) + (budget := valBudget) (refineKeys := valKeys) hsound_cache + have honeHot : + ∀ q, q ∈ c.active → + 
Layers.OneHotApproxBoundsOnActive (Val := Real) + ((logitDiffCache c).epsAt q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) := + by + intro q' hq' + have h := hsound.oneHot_bounds_at q' hq' + have heps : (logitDiffCache c).epsAt q' = c.epsAt q' := by + simp [logitDiffCache_def, Bounds.cacheBoundTask_apply] + simpa [heps] using h + have hval' : + logitDiffLowerBoundFromCacheWithEpsVals c (logitDiffCache c).epsAt + valsLo = some lb2 := by + simpa [hcache] using hval + have hrefine_val := + logitDiffLowerBoundFromCacheWithEpsVals_le + (inputs := inputs) + (c := c) + (epsAtCustom := (logitDiffCache c).epsAt) + (valsLoCustom := valsLo) + (hsound := hsound) + (honeHot := honeHot) + (hvalsLo := hvalsLo) + (hbound := hval') + (hq := hq) + have hmax' : + max (lb01 : Real) (lb2 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨hmax, hrefine_val⟩ + have hmax : (max lb01 lb2 : Real) ≤ headLogitDiff inputs q := by + simpa [ratToReal_max] using hmax' + simpa [hlb] using hmax | some lb2 => - have hlb : lb = max lb01 lb2 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, - spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget'] using - hbound.symm have hbase := hbase_le (lb0 := lb0) h0 have hweight_overlay' : ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → @@ -643,9 +796,73 @@ theorem logitDiffLowerBoundRefineOnDemand_le have hmax' : max (lb01 : Real) (lb2 : Real) ≤ headLogitDiff inputs q := by exact max_le_iff.mpr ⟨hmax01, hrefine'⟩ - have hmax : (max lb01 lb2 : Real) ≤ headLogitDiff inputs q := by + have hmax_weight : (max lb01 lb2 : Real) ≤ headLogitDiff inputs q := by simpa [ratToReal_max] using hmax' - simpa [hlb] using hmax + let lbWeight : Rat := max lb01 lb2 + by_cases hweight_nonpos : lbWeight ≤ 0 + · let valBudget := refineBudgetBoost refineBudget + let valKeys := loAtKeysAt inputs cache q0 + let valsLo := valsLoOverlay inputs cache valBudget valKeys + cases hval : + 
logitDiffLowerBoundFromCacheWithEpsVals c logitCache.epsAt valsLo with + | none => + have hlb : lb = lbWeight := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, + h1, spec, refineBudget, lb01, hnonpos1, h2, spec', + refineBudget', lbWeight, hweight_nonpos, valBudget, valKeys, + valsLo, hval] using hbound.symm + simpa [hlb, lbWeight] using hmax_weight + | some lb3 => + have hlb : lb = max lbWeight lb3 := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, + h1, spec, refineBudget, lb01, hnonpos1, h2, spec', + refineBudget', lbWeight, hweight_nonpos, valBudget, valKeys, + valsLo, hval] using hbound.symm + have hsound_cache : InductionHeadCertSound inputs cache.cert := by + simpa [hcert] using hsound + have hvalsLo : + ∀ k, (valsLo k : Real) ≤ valsRealOfInputs inputs k := by + exact valsLoOverlay_le_valsReal (inputs := inputs) (cache := cache) + (budget := valBudget) (refineKeys := valKeys) hsound_cache + have honeHot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) + ((logitDiffCache c).epsAt q : Real) + (fun q' => q' = q) c.prev + (fun q' k => + Circuit.softmax (scoresRealOfInputs inputs q') k) := by + intro q' hq' + have h := hsound.oneHot_bounds_at q' hq' + have heps : (logitDiffCache c).epsAt q' = c.epsAt q' := by + simp [logitDiffCache_def, Bounds.cacheBoundTask_apply] + simpa [heps] using h + have hval' : + logitDiffLowerBoundFromCacheWithEpsVals c (logitDiffCache c).epsAt + valsLo = some lb3 := by + simpa [hcache] using hval + have hrefine_val := + logitDiffLowerBoundFromCacheWithEpsVals_le + (inputs := inputs) + (c := c) + (epsAtCustom := (logitDiffCache c).epsAt) + (valsLoCustom := valsLo) + (hsound := hsound) + (honeHot := honeHot) + (hvalsLo := hvalsLo) + (hbound := hval') + (hq := hq) + have hmax' : + max (lbWeight : Real) (lb3 : Real) ≤ headLogitDiff inputs q := by + exact max_le_iff.mpr ⟨by simpa [lbWeight] using hmax_weight, + hrefine_val⟩ + have hmax : (max lbWeight lb3 : Real) ≤ 
headLogitDiff inputs q := by + simpa [ratToReal_max] using hmax' + simpa [hlb] using hmax + · have hlb : lb = lbWeight := by + simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, + spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', + lbWeight, hweight_nonpos] using hbound.symm + simpa [hlb, lbWeight] using hmax_weight · have hlb : lb = lb01 := by simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, spec, refineBudget, lb01, hnonpos1] using hbound.symm From 3a74bc6a7fd4d986b3452684de033894ee2015ce Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 13:54:03 +0100 Subject: [PATCH 188/244] Refine loAt-min keys in logit-diff bounds --- Nfp/Sound/Induction/Refine.lean | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Nfp/Sound/Induction/Refine.lean b/Nfp/Sound/Induction/Refine.lean index d9d74d4..b232fdd 100644 --- a/Nfp/Sound/Induction/Refine.lean +++ b/Nfp/Sound/Induction/Refine.lean @@ -146,12 +146,12 @@ theorem refineKeysAt_def | some k => insert k neg := by rfl -/-- Refinement keys that also include weight-one keys. -/ +/-- Refinement keys that also include weight-one and `loAt`-minimizing keys. -/ def refineKeysAtWithWeightOnes (inputs : Model.InductionHeadInputs seq dModel dHead) (cache : InductionHeadCoreCache seq dModel dHead) (q : Fin seq) : Finset (Fin seq) := - refineKeysAt inputs cache q ∪ weightOneKeysAt inputs cache q + refineKeysAt inputs cache q ∪ weightOneKeysAt inputs cache q ∪ loAtKeysAt inputs cache q /-- Unfolding lemma for `refineKeysAtWithWeightOnes`. 
-/ theorem refineKeysAtWithWeightOnes_def @@ -159,7 +159,7 @@ theorem refineKeysAtWithWeightOnes_def (cache : InductionHeadCoreCache seq dModel dHead) (q : Fin seq) : refineKeysAtWithWeightOnes inputs cache q = - refineKeysAt inputs cache q ∪ weightOneKeysAt inputs cache q := by + refineKeysAt inputs cache q ∪ weightOneKeysAt inputs cache q ∪ loAtKeysAt inputs cache q := by rfl /-- Refinement spec focused on a single query. -/ @@ -182,7 +182,7 @@ theorem refineSpecForQuery_def splitBudgetDiffRefined := budget } := by rfl -/-- Refinement spec for a single query, including weight-one keys. -/ +/-- Refinement spec for a single query, including weight-one and `loAt`-minimizing keys. -/ def refineSpecForQueryWithWeightOnes (inputs : Model.InductionHeadInputs seq dModel dHead) (cache : InductionHeadCoreCache seq dModel dHead) From 7239a066a12633814f9f264f999cf14c10dcbe3d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 14:15:26 +0100 Subject: [PATCH 189/244] Refine top-weight keys in logit-diff --- Nfp/IO/InductionHead/Basic.lean | 2 +- Nfp/IO/InductionHead/Nonvacuous.lean | 2 +- Nfp/Sound/Induction/Refine.lean | 68 ++++++++++++++++++++++++---- 3 files changed, 61 insertions(+), 11 deletions(-) diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 1a72f93..ce3f2fe 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -862,7 +862,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} IO.eprintln s!"debug: loAt keys: {loAtMsg}" if (← logitDiffRefineEnabled) then let refineBudget := max 1 cfg.splitBudgetDiffRefined - let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q + let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q refineBudget IO.eprintln s!"debug: refine budget={refineBudget}, \ refineKeys.card={refineKeys.card}" diff --git a/Nfp/IO/InductionHead/Nonvacuous.lean b/Nfp/IO/InductionHead/Nonvacuous.lean index 73a8bc6..9291e14 100644 --- 
a/Nfp/IO/InductionHead/Nonvacuous.lean +++ b/Nfp/IO/InductionHead/Nonvacuous.lean @@ -159,7 +159,7 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} IO.eprintln s!"debug: loAt keys: {loAtMsg}" if (← logitDiffRefineEnabled) then let refineBudget := max 1 cfg.splitBudgetDiffRefined - let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q + let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q refineBudget IO.eprintln s!"debug: refine budget={refineBudget}, \ refineKeys.card={refineKeys.card}" diff --git a/Nfp/Sound/Induction/Refine.lean b/Nfp/Sound/Induction/Refine.lean index b232fdd..699fa5c 100644 --- a/Nfp/Sound/Induction/Refine.lean +++ b/Nfp/Sound/Induction/Refine.lean @@ -35,6 +35,15 @@ theorem refineBudgetBoost_def (budget : Nat) : refineBudgetBoost budget = max (budget + 1) (2 * budget) := by rfl +/-- Heuristic cap on the number of top-weight keys to refine. -/ +def refineTopWeightCount (budget : Nat) : Nat := + min 8 (max 1 (2 * budget)) + +/-- Unfolding lemma for `refineTopWeightCount`. -/ +theorem refineTopWeightCount_def (budget : Nat) : + refineTopWeightCount budget = min 8 (max 1 (2 * budget)) := by + rfl + /-- Scale used for refined value bounds. -/ def valRefineScale (budget : Nat) : Nat := Bounds.sqrtLowerScale * refineBudgetBoost budget @@ -122,6 +131,38 @@ theorem loAtKeysAt_def ∅ := by rfl +/-- Top-weight keys for a query (excluding `prev`), capped by `count`. 
-/ +def topWeightKeysAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) (count : Nat) : Finset (Fin seq) := + if count = 0 then + ∅ + else + let others := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + let weighted : Array (Rat × Fin seq) := + others.toArray.map (fun k => (cache.weightBoundAt q k, k)) + let sorted := weighted.qsort (fun a b => a.1 > b.1) + let keys := (sorted.toList.take count).map (fun p => p.2) + keys.foldr (fun k acc => insert k acc) ∅ + +/-- Unfolding lemma for `topWeightKeysAt`. -/ +theorem topWeightKeysAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) (count : Nat) : + topWeightKeysAt inputs cache q count = + if count = 0 then + ∅ + else + let others := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + let weighted : Array (Rat × Fin seq) := + others.toArray.map (fun k => (cache.weightBoundAt q k, k)) + let sorted := weighted.qsort (fun a b => a.1 > b.1) + let keys := (sorted.toList.take count).map (fun p => p.2) + keys.foldr (fun k acc => insert k acc) ∅ := by + rfl + /-- Refinement keys for a query, seeded by negative base gaps and the worst key. -/ def refineKeysAt (inputs : Model.InductionHeadInputs seq dModel dHead) @@ -146,20 +187,28 @@ theorem refineKeysAt_def | some k => insert k neg := by rfl -/-- Refinement keys that also include weight-one and `loAt`-minimizing keys. -/ +/-- Refinement keys that also include weight-one, `loAt`-minimizing, and top-weight keys. 
-/ def refineKeysAtWithWeightOnes (inputs : Model.InductionHeadInputs seq dModel dHead) (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : Finset (Fin seq) := - refineKeysAt inputs cache q ∪ weightOneKeysAt inputs cache q ∪ loAtKeysAt inputs cache q + (q : Fin seq) (budget : Nat) : Finset (Fin seq) := + let topCount := refineTopWeightCount budget + refineKeysAt inputs cache q ∪ + weightOneKeysAt inputs cache q ∪ + loAtKeysAt inputs cache q ∪ + topWeightKeysAt inputs cache q topCount /-- Unfolding lemma for `refineKeysAtWithWeightOnes`. -/ theorem refineKeysAtWithWeightOnes_def (inputs : Model.InductionHeadInputs seq dModel dHead) (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : - refineKeysAtWithWeightOnes inputs cache q = - refineKeysAt inputs cache q ∪ weightOneKeysAt inputs cache q ∪ loAtKeysAt inputs cache q := by + (q : Fin seq) (budget : Nat) : + refineKeysAtWithWeightOnes inputs cache q budget = + let topCount := refineTopWeightCount budget + refineKeysAt inputs cache q ∪ + weightOneKeysAt inputs cache q ∪ + loAtKeysAt inputs cache q ∪ + topWeightKeysAt inputs cache q topCount := by rfl /-- Refinement spec focused on a single query. -/ @@ -182,12 +231,13 @@ theorem refineSpecForQuery_def splitBudgetDiffRefined := budget } := by rfl -/-- Refinement spec for a single query, including weight-one and `loAt`-minimizing keys. -/ +/-- Refinement spec for a single query, including weight-one, `loAt`-minimizing, and top-weight +keys. 
-/ def refineSpecForQueryWithWeightOnes (inputs : Model.InductionHeadInputs seq dModel dHead) (cache : InductionHeadCoreCache seq dModel dHead) (q : Fin seq) (budget : Nat) : InductionHeadRefineSpec seq := - let keys := refineKeysAtWithWeightOnes inputs cache q + let keys := refineKeysAtWithWeightOnes inputs cache q budget { refineKeys := fun q' => if _ : q' = q then keys else ∅ splitBudgetDiffRefined := budget } @@ -197,7 +247,7 @@ theorem refineSpecForQueryWithWeightOnes_def (cache : InductionHeadCoreCache seq dModel dHead) (q : Fin seq) (budget : Nat) : refineSpecForQueryWithWeightOnes inputs cache q budget = - let keys := refineKeysAtWithWeightOnes inputs cache q + let keys := refineKeysAtWithWeightOnes inputs cache q budget { refineKeys := fun q' => if _ : q' = q then keys else ∅ splitBudgetDiffRefined := budget } := by rfl From ab82c87a946ce6390144369e0b7af4af698dc963 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 14:41:32 +0100 Subject: [PATCH 190/244] Add LocalSystem sparsity lemmas --- Nfp/System/LocalSystem.lean | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/Nfp/System/LocalSystem.lean b/Nfp/System/LocalSystem.lean index 6870784..60ea203 100644 --- a/Nfp/System/LocalSystem.lean +++ b/Nfp/System/LocalSystem.lean @@ -47,6 +47,38 @@ theorem weight_eq_zero_of_not_parent (L : LocalSystem ι) {i j : ι} (h : ¬ L.d by simpa using L.support i j h +/-- Off-edge weights remain zero after coercion to a mixer. -/ +theorem toMixer_weight_eq_zero_of_not_parent (L : LocalSystem ι) (h : IsRowStochastic L) + {i j : ι} (hrel : ¬ L.dag.rel j i) : + (toMixer L h).weight i j = 0 := by + simpa [toMixer] using weight_eq_zero_of_not_parent (L := L) (i := i) (j := j) hrel + +/-- Row sums can be restricted to parents in a row-stochastic local system. 
-/ +theorem row_sum_parents (L : LocalSystem ι) (h : IsRowStochastic L) (i : ι) : + Finset.sum (L.dag.parents i) (fun j => L.weight i j) = 1 := by + classical + have hsubset : L.dag.parents i ⊆ (Finset.univ : Finset ι) := by + intro j _hj + exact Finset.mem_univ j + have hzero : + ∀ j ∈ (Finset.univ : Finset ι), j ∉ L.dag.parents i → L.weight i j = 0 := by + intro j _hj hj + have hrel : ¬ L.dag.rel j i := by + intro hrel + have hmem : j ∈ L.dag.parents i := by + simpa using hrel + exact hj hmem + exact weight_eq_zero_of_not_parent (L := L) (i := i) (j := j) hrel + have hsum : + Finset.sum (L.dag.parents i) (fun j => L.weight i j) = + Finset.sum (Finset.univ : Finset ι) (fun j => L.weight i j) := by + exact Finset.sum_subset hsubset hzero + calc + Finset.sum (L.dag.parents i) (fun j => L.weight i j) = + Finset.sum (Finset.univ : Finset ι) (fun j => L.weight i j) := hsum + _ = 1 := by + simpa using h i + /-- One-step evaluation functional used by `eval`. -/ def evalStep (L : LocalSystem ι) (input : ι → Mass) (i : ι) (rec : ∀ j, L.dag.rel j i → Mass) : Mass := From 0bef4077d3fc23c6154c5230d74ab2f75250e3cd Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 18:35:30 +0100 Subject: [PATCH 191/244] Add induction diagnostics and literature-aligned discovery defaults --- Nfp/IO/InductionHead/Basic.lean | 186 ++++++++++++++++++ Nfp/IO/InductionHead/Nonvacuous.lean | 82 ++++++++ README.md | 5 + scripts/discover_gpt2_induction_targets.py | 216 ++++++++++++++++++--- scripts/scan_gpt2_induction_sound.py | 152 +++++++++------ scripts/sweep_gpt2_induction_nonvacuous.py | 154 +++++++++++---- 6 files changed, 669 insertions(+), 126 deletions(-) diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index ce3f2fe..200f21d 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -100,6 +100,40 @@ def logitDiffDebugEarlyExitEnabled : IO Bool := do def logitDiffRefineEnabled : IO Bool := do return (← IO.getEnv 
"NFP_LOGITDIFF_REFINE").isSome +/-- Parse an optional query index for alternative logit-diff bound diagnostics. -/ +def logitDiffAltBoundQuery : IO (Option Nat) := do + match (← IO.getEnv "NFP_LOGITDIFF_ALT_BOUND_Q") with + | none => return none + | some txt => + match txt.toNat? with + | some n => return some n + | none => + IO.eprintln s!"warn: invalid NFP_LOGITDIFF_ALT_BOUND_Q={txt}" + return none + +/-- Parse an optional query index for q-only logit-diff diagnostics. -/ +def logitDiffQueryOnly : IO (Option Nat) := do + match (← IO.getEnv "NFP_LOGITDIFF_Q_ONLY") with + | none => return none + | some txt => + match txt.toNat? with + | some n => return some n + | none => + IO.eprintln s!"warn: invalid NFP_LOGITDIFF_Q_ONLY={txt}" + return none + +/-- Check whether q-only logit-diff diagnostics should include refined weight bounds. -/ +def logitDiffQueryOnlyRefineEnabled : IO Bool := do + return (← IO.getEnv "NFP_LOGITDIFF_Q_ONLY_REFINE").isSome + +/-- Check whether q-only logit-diff diagnostics should include refined value bounds. -/ +def logitDiffQueryOnlyValsEnabled : IO Bool := do + return (← IO.getEnv "NFP_LOGITDIFF_Q_ONLY_VALS").isSome + +/-- Check whether q-only logit-diff diagnostics should exit early. -/ +def logitDiffQueryOnlyEarlyExitEnabled : IO Bool := do + return (← IO.getEnv "NFP_LOGITDIFF_Q_ONLY_EARLY_EXIT").isSome + private def renderResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) : String := let header := s!"dim {n}" let lines := @@ -115,6 +149,76 @@ private def emitResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert | some path => IO.FS.writeFile path (payload ++ "\n") | none => IO.println payload +/-- Emit q-only logit-diff diagnostics, returning whether early exit was requested. 
-/ +def emitLogitDiffQueryOnly {seq dModel dHead : Nat} [NeZero seq] + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cfg : Sound.InductionHeadSplitConfig) + (cache : Sound.InductionHeadCoreCache seq dModel dHead) + (cert : Sound.InductionHeadCert seq) + (logitCache : Sound.LogitDiffCache seq) : IO Bool := do + match (← logitDiffQueryOnly) with + | none => return false + | some qNat => + if hq : qNat < seq then + let q : Fin seq := ⟨qNat, hq⟩ + let prev := cert.prev q + let epsAt : Fin seq → Rat := logitCache.epsAt + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase prev + IO.eprintln + s!"debug: q-only q={qNat} prev={prev.1} \ + epsAt={ratToString (epsAt q)}" + if (← logitDiffQueryOnlyValsEnabled) then + let valsLo : Fin seq → Rat := logitCache.valsLo + let loAt : Rat := + if h : others.Nonempty then + others.inf' h valsLo + else + cert.values.lo + let valsPrevLo := valsLo prev + let delta := valsPrevLo - loAt + let gap := epsAt q * max (0 : Rat) delta + let lbAtQ := valsPrevLo - gap + IO.eprintln + s!"debug: q-only loAt={ratToString loAt} \ + valsPrevLo={ratToString valsPrevLo} \ + lbAtQ={ratToString lbAtQ}" + if (← logitDiffQueryOnlyRefineEnabled) then + let refineBudget := max 1 cfg.splitBudgetDiffRefined + let spec := Sound.refineSpecForQueryWithWeightOnes inputs cache q refineBudget + let weightBoundAt := Sound.weightBoundAtOverlay inputs cache spec + let epsAtRef := Sound.epsAtOverlay cache weightBoundAt q + IO.eprintln + s!"debug: q-only refined budget={refineBudget} \ + epsAt={ratToString epsAtRef}" + if (← logitDiffQueryOnlyValsEnabled) then + let valBudget := Sound.refineBudgetBoost refineBudget + let valKeys := Sound.loAtKeysAt inputs cache q + let valsLoRef : Fin seq → Rat := + Sound.valsLoOverlay inputs cache valBudget valKeys + let loAtRef : Rat := + if h : others.Nonempty then + others.inf' h valsLoRef + else + cert.values.lo + let valsPrevLoRef := valsLoRef prev + let deltaRef := valsPrevLoRef - loAtRef + let 
gapRef := epsAtRef * max (0 : Rat) deltaRef + let lbAtQRef := valsPrevLoRef - gapRef + IO.eprintln + s!"debug: q-only refined loAt={ratToString loAtRef} \ + valsPrevLo={ratToString valsPrevLoRef} \ + lbAtQ={ratToString lbAtQRef}" + let earlyExit := (← logitDiffQueryOnlyEarlyExitEnabled) || + (← logitDiffDebugEarlyExitEnabled) + if earlyExit then + IO.eprintln "debug: early exit requested (q-only)" + return true + return false + else + IO.eprintln s!"warn: NFP_LOGITDIFF_Q_ONLY={qNat} out of range (seq={seq})" + return false + private def buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (outPath? : Option System.FilePath) : IO UInt32 := do @@ -796,6 +900,10 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} | some v => some v | none => some (0 : Rat) let logitCache := Nfp.Sound.logitDiffCache cert + let qOnlyExit ← + emitLogitDiffQueryOnly inputs cfg cache cert logitCache + if qOnlyExit then + return 2 let emitLogitDiffDebug (info : Nfp.Sound.LogitDiffAtLoDebug seq) : IO Unit := do IO.eprintln s!"debug: logitDiffLB0 witness q={info.q.1}, prev={info.prev.1}" @@ -860,6 +968,38 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} else String.intercalate ", " loAtKeys.toList IO.eprintln s!"debug: loAt keys: {loAtMsg}" + let scoreLoPrev := cache.scoreLoPrev info.q + let stepAlt : + (Rat × Nat × Nat) → Fin seq → (Rat × Nat × Nat) := + fun acc k => + if k = info.prev then + acc + else + let g := scoreLoPrev - cache.scoreHi info.q k + let nonneg := if g ≥ (0 : Rat) then acc.2.1 + 1 else acc.2.1 + let gtNegOne := if g > (-1 : Rat) then acc.2.2 + 1 else acc.2.2 + let expLB := + if g ≥ (0 : Rat) then + (1 : Rat) + g + g * g / (2 : Rat) + else + max (0 : Rat) ((1 : Rat) + g) + let w := (1 : Rat) / ((1 : Rat) + expLB) + (acc.1 + w, nonneg, gtNegOne) + let accAlt := Sound.Linear.foldlFin seq stepAlt (0, 0, 0) + IO.eprintln + s!"debug: alt-exp epsAt={ratToString accAlt.1}, \ + 
g>=0={accAlt.2.1}, g>-1={accAlt.2.2}" + let stepMin : Option Rat → Fin seq → Option Rat := + fun acc k => + if k = info.prev then + acc + else + let g := scoreLoPrev - cache.scoreHi info.q k + match acc with + | none => some g + | some cur => some (min cur g) + let minGap := Sound.Linear.foldlFin seq stepMin none + IO.eprintln s!"debug: alt-exp min(scoreLoPrev-scoreHi)={ratOptToString minGap}" if (← logitDiffRefineEnabled) then let refineBudget := max 1 cfg.splitBudgetDiffRefined let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q refineBudget @@ -881,6 +1021,52 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} timingPrint "timing: head logit-diff lower bound start" timingFlush profileLogitDiffWeighted cert logitCache + let altQuery? ← logitDiffAltBoundQuery + match altQuery? with + | none => pure () + | some qNat => + if hq : qNat < seq then + let q : Fin seq := ⟨qNat, hq⟩ + let prev := cert.prev q + let scoreLoPrev := cache.scoreLoPrev q + let stepAlt : + (Rat × Nat × Nat) → Fin seq → (Rat × Nat × Nat) := + fun acc k => + if k = prev then + acc + else + let g := scoreLoPrev - cache.scoreHi q k + let nonneg := if g ≥ (0 : Rat) then acc.2.1 + 1 else acc.2.1 + let gtNegOne := if g > (-1 : Rat) then acc.2.2 + 1 else acc.2.2 + let expLB := + if g ≥ (0 : Rat) then + (1 : Rat) + g + g * g / (2 : Rat) + else + max (0 : Rat) ((1 : Rat) + g) + let w := (1 : Rat) / ((1 : Rat) + expLB) + (acc.1 + w, nonneg, gtNegOne) + let accAlt := Sound.Linear.foldlFin seq stepAlt (0, 0, 0) + let stepMin : Option Rat → Fin seq → Option Rat := + fun acc k => + if k = prev then + acc + else + let g := scoreLoPrev - cache.scoreHi q k + match acc with + | none => some g + | some cur => some (min cur g) + let minGap := Sound.Linear.foldlFin seq stepMin none + IO.eprintln + s!"debug: alt-exp q={qNat} prev={prev.1} \ + epsAt={ratToString accAlt.1} \ + g>=0={accAlt.2.1} g>-1={accAlt.2.2} \ + minGap={ratOptToString minGap}" + if (← logitDiffDebugEarlyExitEnabled) 
then + IO.eprintln "debug: early exit requested (alt bound)" + return 2 + else + IO.eprintln + s!"warn: NFP_LOGITDIFF_ALT_BOUND_Q={qNat} out of range (seq={seq})" let earlyExit? ← if (← logitDiffDebugEnabled) && (← logitDiffDebugEarlyExitEnabled) then let debug? ← timePureWithHeartbeat diff --git a/Nfp/IO/InductionHead/Nonvacuous.lean b/Nfp/IO/InductionHead/Nonvacuous.lean index 9291e14..3f0e96c 100644 --- a/Nfp/IO/InductionHead/Nonvacuous.lean +++ b/Nfp/IO/InductionHead/Nonvacuous.lean @@ -93,6 +93,10 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} timingPrint "timing: head logit-diff lower bound start" timingFlush let logitCache := Nfp.Sound.logitDiffCache cert + let qOnlyExit ← + emitLogitDiffQueryOnly inputs cfg cache cert logitCache + if qOnlyExit then + return 2 let emitLogitDiffDebug (info : Nfp.Sound.LogitDiffAtLoDebug seq) : IO Unit := do IO.eprintln s!"debug: logitDiffLB0 witness q={info.q.1}, prev={info.prev.1}" @@ -157,6 +161,38 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} else String.intercalate ", " loAtKeys.toList IO.eprintln s!"debug: loAt keys: {loAtMsg}" + let scoreLoPrev := cache.scoreLoPrev info.q + let stepAlt : + (Rat × Nat × Nat) → Fin seq → (Rat × Nat × Nat) := + fun acc k => + if k = info.prev then + acc + else + let g := scoreLoPrev - cache.scoreHi info.q k + let nonneg := if g ≥ (0 : Rat) then acc.2.1 + 1 else acc.2.1 + let gtNegOne := if g > (-1 : Rat) then acc.2.2 + 1 else acc.2.2 + let expLB := + if g ≥ (0 : Rat) then + (1 : Rat) + g + g * g / (2 : Rat) + else + max (0 : Rat) ((1 : Rat) + g) + let w := (1 : Rat) / ((1 : Rat) + expLB) + (acc.1 + w, nonneg, gtNegOne) + let accAlt := Sound.Linear.foldlFin seq stepAlt (0, 0, 0) + IO.eprintln + s!"debug: alt-exp epsAt={ratToString accAlt.1}, \ + g>=0={accAlt.2.1}, g>-1={accAlt.2.2}" + let stepMin : Option Rat → Fin seq → Option Rat := + fun acc k => + if k = info.prev then + acc + else + let g := scoreLoPrev - cache.scoreHi info.q k + 
match acc with + | none => some g + | some cur => some (min cur g) + let minGap := Sound.Linear.foldlFin seq stepMin none + IO.eprintln s!"debug: alt-exp min(scoreLoPrev-scoreHi)={ratOptToString minGap}" if (← logitDiffRefineEnabled) then let refineBudget := max 1 cfg.splitBudgetDiffRefined let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q refineBudget @@ -179,6 +215,52 @@ private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} profileLogitDiffWeighted cert logitCache else pure () + let altQuery? ← logitDiffAltBoundQuery + match altQuery? with + | none => pure () + | some qNat => + if hq : qNat < seq then + let q : Fin seq := ⟨qNat, hq⟩ + let prev := cert.prev q + let scoreLoPrev := cache.scoreLoPrev q + let stepAlt : + (Rat × Nat × Nat) → Fin seq → (Rat × Nat × Nat) := + fun acc k => + if k = prev then + acc + else + let g := scoreLoPrev - cache.scoreHi q k + let nonneg := if g ≥ (0 : Rat) then acc.2.1 + 1 else acc.2.1 + let gtNegOne := if g > (-1 : Rat) then acc.2.2 + 1 else acc.2.2 + let expLB := + if g ≥ (0 : Rat) then + (1 : Rat) + g + g * g / (2 : Rat) + else + max (0 : Rat) ((1 : Rat) + g) + let w := (1 : Rat) / ((1 : Rat) + expLB) + (acc.1 + w, nonneg, gtNegOne) + let accAlt := Sound.Linear.foldlFin seq stepAlt (0, 0, 0) + let stepMin : Option Rat → Fin seq → Option Rat := + fun acc k => + if k = prev then + acc + else + let g := scoreLoPrev - cache.scoreHi q k + match acc with + | none => some g + | some cur => some (min cur g) + let minGap := Sound.Linear.foldlFin seq stepMin none + IO.eprintln + s!"debug: alt-exp q={qNat} prev={prev.1} \ + epsAt={ratToString accAlt.1} \ + g>=0={accAlt.2.1} g>-1={accAlt.2.2} \ + minGap={ratOptToString minGap}" + if (← logitDiffDebugEarlyExitEnabled) then + IO.eprintln "debug: early exit requested (alt bound)" + return 2 + else + IO.eprintln + s!"warn: NFP_LOGITDIFF_ALT_BOUND_Q={qNat} out of range (seq={seq})" let earlyExit? 
← if (← logitDiffDebugEnabled) && (← logitDiffDebugEarlyExitEnabled) then let debug? ← timePureWithHeartbeat diff --git a/README.md b/README.md index b56fdef..77ecb4f 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,11 @@ verified by the CLI. ### Build a head certificate (untrusted) +Note: the discovery/scan/sweep helper scripts use **one-based** layer/head +indices (literature-aligned), default to **bigram prefix matching** for +`prev`, and **rank by attention score** unless you explicitly switch to +logit-diff mode. The Lean CLI continues to accept zero-based layer/head indices. + ```bash python scripts/build_gpt2_induction_cert.py \ --output reports/gpt2_induction.cert \ diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py index b4cce0f..721eb23 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ b/scripts/discover_gpt2_induction_targets.py @@ -6,7 +6,14 @@ This script is untrusted: it uses floating-point arithmetic to score candidates and optionally invokes the Lean verifier (`nfp induction certify_head_model_nonvacuous`) -to confirm nonvacuous bounds. +to confirm nonvacuous bounds when scoring by logit-diff. + +Layer/head indices are one-based to align with the mechanistic interpretability +literature. + +By default, `prev`/active are built from bigram prefix matches (the token at +q-1 maps to its previous occurrence), and heads are ranked by attention to +`prev`. 
""" from __future__ import annotations @@ -35,6 +42,22 @@ class HeadResult: min_prev: float value_range: float active: int + prev_mean: float + prev_median: float + prev_top1_frac: float + + +@dataclass(frozen=True) +class AttnResult: + layer: int + head: int + score: float + prev_mean: float + prev_median: float + prev_top1_frac: float + eps: float + margin: float + active: int def parse_header(f) -> Dict[str, str]: @@ -93,6 +116,17 @@ def build_prev(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: return prev, active +def build_prev_bigram(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + prev_token, active_token = build_prev(tokens) + prev = np.zeros_like(tokens) + active = np.zeros_like(tokens, dtype=bool) + if tokens.size <= 1: + return prev, active + prev[1:] = prev_token[:-1] + active[1:] = active_token[:-1] + return prev, active + + def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float) -> np.ndarray: mean = x.mean(axis=1, keepdims=True) var = ((x - mean) ** 2).mean(axis=1, keepdims=True) @@ -118,9 +152,9 @@ def parse_index_list(raw: str | None, max_value: int) -> List[int] | None: if not part: continue idx = int(part) - if idx < 0 or idx >= max_value: - raise ValueError(f"index {idx} out of range [0,{max_value})") - out.append(idx) + if idx <= 0 or idx > max_value: + raise ValueError(f"index {idx} out of range [1,{max_value}]") + out.append(idx - 1) return out @@ -154,11 +188,16 @@ def read_unembed_column( return data -def compute_eps_margin(weights: np.ndarray, scores: np.ndarray, - prev: np.ndarray, active_positions: Iterable[int]) -> Tuple[float, float]: +def compute_eps_margin( + weights: np.ndarray, + scores: np.ndarray, + prev: np.ndarray, + active_positions: Iterable[int], +) -> Tuple[float, float, float, float, float]: eps_vals: List[float] = [] margin_vals: List[float] = [] - seq = weights.shape[0] + prev_vals: List[float] = [] + max_other_vals: List[float] = [] for q in active_positions: prev_q = int(prev[q]) 
prev_w = weights[q, prev_q] @@ -166,18 +205,39 @@ def compute_eps_margin(weights: np.ndarray, scores: np.ndarray, eps_vals.append(max(max_other, 1.0 - prev_w)) diffs = scores[q, prev_q] - np.delete(scores[q], prev_q) margin_vals.append(float(np.min(diffs)) if diffs.size > 0 else 0.0) + prev_vals.append(float(prev_w)) + max_other_vals.append(float(max_other)) if not eps_vals: - return 0.0, 0.0 - return max(eps_vals), min(margin_vals) + return 0.0, 0.0, 0.0, 0.0, 0.0 + prev_arr = np.asarray(prev_vals, dtype=np.float64) + max_other_arr = np.asarray(max_other_vals, dtype=np.float64) + prev_mean = float(prev_arr.mean()) + prev_median = float(np.median(prev_arr)) + prev_top1 = float(np.mean(prev_arr >= max_other_arr)) + return max(eps_vals), min(margin_vals), prev_mean, prev_median, prev_top1 def format_result(result: HeadResult) -> str: + layer = result.layer + 1 + head = result.head + 1 return ( - f"L{result.layer}H{result.head} target={result.target} " + f"L{layer}H{head} target={result.target} " f"negative={result.negative} logitLB={result.logit_lb:.6f} " f"eps={result.eps:.6f} margin={result.margin:.6f} " f"minPrev={result.min_prev:.6f} range={result.value_range:.6f} " - f"active={result.active}" + f"prevMean={result.prev_mean:.6f} prevMedian={result.prev_median:.6f} " + f"prevTop1={result.prev_top1_frac:.3f} active={result.active}" + ) + + +def format_attn_result(result: AttnResult) -> str: + layer = result.layer + 1 + head = result.head + 1 + return ( + f"L{layer}H{head} score={result.score:.6f} " + f"prevMean={result.prev_mean:.6f} prevMedian={result.prev_median:.6f} " + f"prevTop1={result.prev_top1_frac:.3f} " + f"eps={result.eps:.6f} margin={result.margin:.6f} active={result.active}" ) @@ -189,6 +249,18 @@ def main() -> int: parser.add_argument("--top", type=int, default=20, help="Number of results to report") parser.add_argument("--verify-top", type=int, default=0, help="Run verifier on the top N candidates") + parser.add_argument( + "--score-mode", + 
choices=["attn", "logit"], + default="attn", + help="Rank heads by attention to prev (attn) or logit lower bound (logit).", + ) + parser.add_argument( + "--min-score", + type=float, + default=0.0, + help="Minimum attention score (attn mode).", + ) parser.add_argument("--min-eps", type=float, default=0.5, help="Filter candidates with eps above this value") parser.add_argument("--min-margin", type=float, default=0.0, @@ -198,6 +270,12 @@ def main() -> int: parser.add_argument("--layers", help="Comma-separated layer list or 'all'") parser.add_argument("--heads", help="Comma-separated head list or 'all'") parser.add_argument("--period", type=int, help="Optional prompt period override") + parser.add_argument( + "--prev-mode", + choices=["bigram", "token", "period"], + default="bigram", + help="Choose prev/active construction (default: bigram prefix match).", + ) parser.add_argument("--output", type=Path, default=Path("reports/gpt2_induction_discover.txt")) parser.add_argument("--json-out", type=Path, help="Optional JSON output path") parser.add_argument("--nfp-bin", help="Path to nfp binary (defaults to .lake/build/bin/nfp)") @@ -205,6 +283,8 @@ def main() -> int: if args.max_tokens <= 1: raise SystemExit("max-tokens must be at least 2") + if args.verify_top > 0 and args.score_mode != "logit": + raise SystemExit("--verify-top requires --score-mode=logit") if not args.model.exists(): raise SystemExit(f"Missing model file: {args.model}") @@ -226,11 +306,18 @@ def main() -> int: tokens = read_i32(f, seq_len) embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) - if args.period is not None: + if args.prev_mode != "period" and args.period is not None: + raise SystemExit("--period is incompatible with --prev-mode=token/bigram") + if args.prev_mode == "period" and args.period is None: + raise SystemExit("--prev-mode=period requires --period") + + if args.prev_mode == "period": period = int(args.period) prev = np.arange(seq_len, dtype=np.int64) prev = 
np.where(prev >= period, prev - period, 0) active_mask = np.arange(seq_len) >= period + elif args.prev_mode == "bigram": + prev, active_mask = build_prev_bigram(tokens) else: prev, active_mask = build_prev(tokens) @@ -249,7 +336,10 @@ def main() -> int: if len(unique_tokens) < 2: raise SystemExit("Need at least two unique tokens to form directions") - head_data: Dict[Tuple[int, int], Tuple[np.ndarray, np.ndarray, float, float]] = {} + head_data: Dict[ + Tuple[int, int], + Tuple[np.ndarray, np.ndarray, float, float, float, float, float], + ] = {} for layer_idx in range(num_layers): head_weights = [] @@ -289,8 +379,18 @@ def main() -> int: scores[mask] = -10000.0 weights = softmax(scores) - eps, margin = compute_eps_margin(weights, scores, prev, active_positions) - head_data[(layer_idx, head_idx)] = (v, wo, eps, margin) + eps, margin, prev_mean, prev_median, prev_top1 = compute_eps_margin( + weights, scores, prev, active_positions + ) + head_data[(layer_idx, head_idx)] = ( + v, + wo, + eps, + margin, + prev_mean, + prev_median, + prev_top1, + ) ln_f_gamma = read_f64(f, model_dim) _ln_f_beta = read_f64(f, model_dim) @@ -308,9 +408,29 @@ def main() -> int: ) results: List[HeadResult] = [] + attn_results: List[AttnResult] = [] prev_indices = prev[np.array(active_positions, dtype=np.int64)] - for (layer_idx, head_idx), (v, wo, eps, margin) in head_data.items(): - if eps > args.min_eps or margin < args.min_margin: + for (layer_idx, head_idx), (v, wo, eps, margin, prev_mean, prev_median, prev_top1) in head_data.items(): + if args.score_mode == "logit": + if eps > args.min_eps or margin < args.min_margin: + continue + if args.score_mode == "attn": + score = prev_mean + if score < args.min_score: + continue + attn_results.append( + AttnResult( + layer=layer_idx, + head=head_idx, + score=score, + prev_mean=prev_mean, + prev_median=prev_median, + prev_top1_frac=prev_top1, + eps=eps, + margin=margin, + active=len(active_positions), + ) + ) continue proj: Dict[int, np.ndarray] = 
{} for tok in unique_tokens: @@ -340,25 +460,43 @@ def main() -> int: min_prev=min_prev, value_range=value_range, active=len(active_positions), + prev_mean=prev_mean, + prev_median=prev_median, + prev_top1_frac=prev_top1, ) if best is None or candidate.logit_lb > best.logit_lb: best = candidate if best is not None: results.append(best) - results.sort(key=lambda r: r.logit_lb, reverse=True) + if args.score_mode == "attn": + attn_results.sort(key=lambda r: r.score, reverse=True) + else: + results.sort(key=lambda r: r.logit_lb, reverse=True) args.output.parent.mkdir(parents=True, exist_ok=True) with args.output.open("w", encoding="ascii") as f: f.write("Induction discovery (approximate ranking)\n") f.write(f"model={args.model}\n") + f.write(f"score_mode={args.score_mode}\n") f.write(f"tokens={len(unique_tokens)} active={len(active_positions)}\n") - f.write(f"min-eps={args.min_eps} min-margin={args.min_margin} min-logit-lb={args.min_logit_lb}\n") - for rank, result in enumerate(results[: args.top], start=1): - f.write(f"{rank:02d} {format_result(result)}\n") + f.write( + f"min-eps={args.min_eps} min-margin={args.min_margin} " + f"min-logit-lb={args.min_logit_lb} min-score={args.min_score}\n" + ) + if args.score_mode == "attn": + for rank, result in enumerate(attn_results[: args.top], start=1): + f.write(f"{rank:02d} {format_attn_result(result)}\n") + else: + for rank, result in enumerate(results[: args.top], start=1): + f.write(f"{rank:02d} {format_result(result)}\n") print(f"Wrote report to {args.output}") - for rank, result in enumerate(results[: args.top], start=1): - print(f"{rank:02d} {format_result(result)}") + if args.score_mode == "attn": + for rank, result in enumerate(attn_results[: args.top], start=1): + print(f"{rank:02d} {format_attn_result(result)}") + else: + for rank, result in enumerate(results[: args.top], start=1): + print(f"{rank:02d} {format_result(result)}") if args.json_out is not None: args.json_out.parent.mkdir(parents=True, exist_ok=True) @@ 
-366,14 +504,34 @@ def main() -> int: "model": str(args.model), "tokens": len(unique_tokens), "active": len(active_positions), + "score_mode": args.score_mode, "min_eps": args.min_eps, "min_margin": args.min_margin, "min_logit_lb": args.min_logit_lb, - "results": [ + "min_score": args.min_score, + } + if args.score_mode == "attn": + payload["results"] = [ + { + "rank": rank, + "layer": r.layer + 1, + "head": r.head + 1, + "score": r.score, + "prev_mean": r.prev_mean, + "prev_median": r.prev_median, + "prev_top1_frac": r.prev_top1_frac, + "eps": r.eps, + "margin": r.margin, + "active": r.active, + } + for rank, r in enumerate(attn_results[: args.top], start=1) + ] + else: + payload["results"] = [ { "rank": rank, - "layer": r.layer, - "head": r.head, + "layer": r.layer + 1, + "head": r.head + 1, "target": r.target, "negative": r.negative, "logit_lb": r.logit_lb, @@ -381,11 +539,13 @@ def main() -> int: "margin": r.margin, "min_prev": r.min_prev, "value_range": r.value_range, + "prev_mean": r.prev_mean, + "prev_median": r.prev_median, + "prev_top1_frac": r.prev_top1_frac, "active": r.active, } for rank, r in enumerate(results[: args.top], start=1) - ], - } + ] args.json_out.write_text(json.dumps(payload, indent=2), encoding="ascii") if args.verify_top > 0 and results: diff --git a/scripts/scan_gpt2_induction_sound.py b/scripts/scan_gpt2_induction_sound.py index b9e8ec4..dbac47c 100755 --- a/scripts/scan_gpt2_induction_sound.py +++ b/scripts/scan_gpt2_induction_sound.py @@ -2,12 +2,15 @@ # SPDX-License-Identifier: AGPL-3.0-or-later """ -Scan GPT-2 induction head candidates with SOUND logit-diff bounds. +Scan GPT-2 induction head candidates with attention or logit-diff bounds. This script: 1) Ensures a GPT-2 "rigorous induction" binary model exists. -2) Uses the untrusted discovery helper to propose head/direction candidates. -3) Runs `nfp induction certify_head_model_nonvacuous` to check each candidate. 
+2) Uses the untrusted discovery helper to propose head candidates. +3) Optionally runs `nfp induction certify_head_model_nonvacuous` in logit mode. + +Layer/head indices are one-based (literature-aligned). `prev` defaults to bigram +prefix matching. """ from __future__ import annotations @@ -119,9 +122,22 @@ def main() -> int: parser.add_argument("--min-eps", type=float, default=0.5) parser.add_argument("--min-margin", type=float, default=0.0) parser.add_argument("--min-logit-lb", type=float, default=0.0) + parser.add_argument("--min-score", type=float, default=0.0) + parser.add_argument( + "--score-mode", + choices=["attn", "logit"], + default="attn", + help="Rank by attention score or logit-diff bound.", + ) parser.add_argument("--layers", help="Comma-separated layer list or 'all'") parser.add_argument("--heads", help="Comma-separated head list or 'all'") parser.add_argument("--period", type=int) + parser.add_argument( + "--prev-mode", + choices=["bigram", "token", "period"], + default="bigram", + help="Choose prev/active construction (forwarded to discovery).", + ) parser.add_argument("--output", default="reports/gpt2_induction_sound_scan.txt") args = parser.parse_args() args.jobs = max(1, args.jobs) @@ -156,12 +172,16 @@ def main() -> int: str(model_path), "--top", str(args.top), + "--score-mode", + args.score_mode, "--min-eps", str(args.min_eps), "--min-margin", str(args.min_margin), "--min-logit-lb", str(args.min_logit_lb), + "--min-score", + str(args.min_score), "--output", str(discover_txt), "--json-out", @@ -173,6 +193,8 @@ def main() -> int: discover_cmd += ["--heads", args.heads] if args.period is not None: discover_cmd += ["--period", str(args.period)] + if args.prev_mode != "bigram": + discover_cmd += ["--prev-mode", args.prev_mode] run_cmd(discover_cmd) payload = json.loads(discover_json.read_text(encoding="ascii")) candidates = payload.get("results", []) @@ -182,70 +204,90 @@ def main() -> int: results: list[tuple[Fraction, dict[str, int]]] = [] - 
def run_cert(candidate: dict[str, int]) -> tuple[dict[str, int], Fraction | None]: - layer = int(candidate["layer"]) - head = int(candidate["head"]) - target_id = int(candidate.get("target", target)) - negative_id = int(candidate.get("negative", negative)) - cmd = nfp_cmd + [ - "induction", - "certify_head_model_nonvacuous", - "--model", - str(model_path), - "--layer", - str(layer), - "--head", - str(head), - "--direction-target", - str(target_id), - "--direction-negative", - str(negative_id), - ] - if args.period is not None: - cmd += ["--period", str(args.period)] - try: - cert_out = run_cmd(cmd) - except subprocess.CalledProcessError: - return candidate, None - return candidate, parse_logit_lb(cert_out) - - if args.jobs == 1: - for candidate in candidates: - candidate_out, logit_lb = run_cert(candidate) - if logit_lb is None: - continue - results.append((logit_lb, candidate_out)) - else: - with ThreadPoolExecutor(max_workers=args.jobs) as executor: - futures = {executor.submit(run_cert, candidate): candidate for candidate in candidates} - for future in as_completed(futures): - candidate_out, logit_lb = future.result() + if args.score_mode == "logit": + def run_cert(candidate: dict[str, int]) -> tuple[dict[str, int], Fraction | None]: + layer = int(candidate["layer"]) - 1 + head = int(candidate["head"]) - 1 + target_id = int(candidate.get("target", target)) + negative_id = int(candidate.get("negative", negative)) + cmd = nfp_cmd + [ + "induction", + "certify_head_model_nonvacuous", + "--model", + str(model_path), + "--layer", + str(layer), + "--head", + str(head), + "--direction-target", + str(target_id), + "--direction-negative", + str(negative_id), + ] + if args.period is not None: + cmd += ["--period", str(args.period)] + try: + cert_out = run_cmd(cmd) + except subprocess.CalledProcessError: + return candidate, None + return candidate, parse_logit_lb(cert_out) + + if args.jobs == 1: + for candidate in candidates: + candidate_out, logit_lb = run_cert(candidate) 
if logit_lb is None: continue results.append((logit_lb, candidate_out)) + else: + with ThreadPoolExecutor(max_workers=args.jobs) as executor: + futures = {executor.submit(run_cert, candidate): candidate for candidate in candidates} + for future in as_completed(futures): + candidate_out, logit_lb = future.result() + if logit_lb is None: + continue + results.append((logit_lb, candidate_out)) - if not results: - print("No sound logit bounds produced.", file=sys.stderr) - return 1 + if not results: + print("No sound logit bounds produced.", file=sys.stderr) + return 1 - results.sort(key=lambda x: x[0], reverse=True) + results.sort(key=lambda x: x[0], reverse=True) out_path = Path(args.output) out_path.parent.mkdir(parents=True, exist_ok=True) with out_path.open("w", encoding="ascii") as f: - f.write("SOUND induction scan (logitDiffLB ranking)\n") + if args.score_mode == "logit": + f.write("SOUND induction scan (logitDiffLB ranking)\n") + else: + f.write("Induction attention scan (prev-attn ranking)\n") f.write(f"model={model_path}\n") f.write(f"target={target} negative={negative}\n") eps_header = header.get("layer_norm_eps") or header.get("eps") or "unknown" f.write(f"top={args.top} eps={eps_header}\n") - for rank, (lb, candidate) in enumerate(results, start=1): - layer = int(candidate["layer"]) - head = int(candidate["head"]) - target_id = int(candidate.get("target", target)) - negative_id = int(candidate.get("negative", negative)) - f.write( - f"{rank:02d} L{layer}H{head} " - f"target={target_id} negative={negative_id} logitDiffLB={lb}\n" - ) + if args.score_mode == "logit": + for rank, (lb, candidate) in enumerate(results, start=1): + layer = int(candidate["layer"]) + head = int(candidate["head"]) + target_id = int(candidate.get("target", target)) + negative_id = int(candidate.get("negative", negative)) + f.write( + f"{rank:02d} L{layer}H{head} " + f"target={target_id} negative={negative_id} logitDiffLB={lb}\n" + ) + else: + for rank, candidate in 
enumerate(candidates[: args.top], start=1): + layer = int(candidate["layer"]) + head = int(candidate["head"]) + score = candidate.get("score") + prev_mean = candidate.get("prev_mean") + prev_median = candidate.get("prev_median") + prev_top1 = candidate.get("prev_top1_frac") + eps = candidate.get("eps") + margin = candidate.get("margin") + f.write( + f"{rank:02d} L{layer}H{head} score={score} " + f"prevMean={prev_mean} prevMedian={prev_median} prevTop1={prev_top1} " + f"eps={eps} margin={margin}\n" + ) print(f"Report written to {out_path}") return 0 diff --git a/scripts/sweep_gpt2_induction_nonvacuous.py b/scripts/sweep_gpt2_induction_nonvacuous.py index 8b996a3..26d3541 100644 --- a/scripts/sweep_gpt2_induction_nonvacuous.py +++ b/scripts/sweep_gpt2_induction_nonvacuous.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: AGPL-3.0-or-later """ -Sweep prompt parameters and verify nonvacuous induction bounds for GPT-2. +Sweep prompt parameters and evaluate induction head scores for GPT-2. This is untrusted orchestration: discovery uses floating-point math and only -Lean verification results are treated as definitive. +Lean verification results are treated as definitive in logit mode. + +Layer/head indices are one-based (literature-aligned). `prev` defaults to bigram +prefix matching. 
""" from __future__ import annotations @@ -97,8 +100,11 @@ def run_discovery( min_eps: float, min_margin: float, min_logit_lb: float, + min_score: float, + score_mode: str, period: int | None, output_dir: Path, + prev_mode: str, ) -> list[dict]: output_dir.mkdir(parents=True, exist_ok=True) json_out = output_dir / f"{model.stem}.json" @@ -116,11 +122,17 @@ def run_discovery( str(min_margin), "--min-logit-lb", str(min_logit_lb), + "--min-score", + str(min_score), + "--score-mode", + score_mode, "--json-out", str(json_out), ] if period is not None: cmd += ["--period", str(period)] + if prev_mode != "bigram": + cmd += ["--prev-mode", prev_mode] run_cmd(cmd, check=True) payload = json.loads(json_out.read_text(encoding="ascii")) return payload.get("results", []) @@ -171,6 +183,11 @@ def write_csv_row(path: Path, row: dict) -> None: "head", "target", "negative", + "score_mode", + "score", + "prev_mean", + "prev_median", + "prev_top1_frac", "approx_logit_lb", "approx_eps", "approx_margin", @@ -204,8 +221,21 @@ def main() -> int: parser.add_argument("--min-eps", type=float, default=0.5) parser.add_argument("--min-margin", type=float, default=0.0) parser.add_argument("--min-logit-lb", type=float, default=0.0) + parser.add_argument("--min-score", type=float, default=0.0) + parser.add_argument( + "--score-mode", + choices=["attn", "logit"], + default="attn", + help="Rank by attention score or logit-diff bound.", + ) parser.add_argument("--use-period", action="store_true", help="Use pattern length as the period override") + parser.add_argument( + "--prev-mode", + choices=["bigram", "token", "period"], + default="bigram", + help="Choose prev/active construction (forwarded to discovery).", + ) parser.add_argument("--nfp-bin", help="Path to nfp binary") parser.add_argument("--discovery-dir", type=Path, default=Path("reports/discovery")) args = parser.parse_args() @@ -231,8 +261,11 @@ def main() -> int: args.min_eps, args.min_margin, args.min_logit_lb, + args.min_score, + 
args.score_mode, period, args.discovery_dir, + args.prev_mode, ) if not results: print( @@ -240,48 +273,83 @@ def main() -> int: flush=True, ) continue - for result in results[: args.verify_top]: - verify = verify_candidate( - nfp_cmd, - model_path, - result["layer"], - result["head"], - result["target"], - result["negative"], - period, - ) - status = "ok" if verify.ok else "fail" - if verify.ok: - print( - f"verified L{result['layer']}H{result['head']} " - f"seq={seq_len} pat={pattern_len} seed={seed}", - flush=True, + if args.score_mode == "logit": + for result in results[: args.verify_top]: + layer = result["layer"] - 1 + head = result["head"] - 1 + verify = verify_candidate( + nfp_cmd, + model_path, + layer, + head, + result["target"], + result["negative"], + period, ) - row = { - "model_path": model_path, - "seq_len": seq_len, - "pattern_len": pattern_len, - "seed": seed, - "layer": result["layer"], - "head": result["head"], - "target": result["target"], - "negative": result["negative"], - "approx_logit_lb": result["logit_lb"], - "approx_eps": result["eps"], - "approx_margin": result["margin"], - "approx_min_prev": result["min_prev"], - "approx_value_range": result["value_range"], - "active": result["active"], - "period": period if period is not None else "", - "verify_status": status, - "verify_logit_lb": verify.logit_lb or "", - } - write_csv_row(args.output, row) - if not verify.ok: - if verify.stdout: - print(f" out: {verify.stdout}", flush=True) - if verify.stderr: - print(f" err: {verify.stderr}", flush=True) + status = "ok" if verify.ok else "fail" + if verify.ok: + print( + f"verified L{result['layer']}H{result['head']} " + f"seq={seq_len} pat={pattern_len} seed={seed}", + flush=True, + ) + row = { + "model_path": model_path, + "seq_len": seq_len, + "pattern_len": pattern_len, + "seed": seed, + "layer": result["layer"], + "head": result["head"], + "target": result["target"], + "negative": result["negative"], + "score_mode": args.score_mode, + "score": 
"", + "prev_mean": result.get("prev_mean", ""), + "prev_median": result.get("prev_median", ""), + "prev_top1_frac": result.get("prev_top1_frac", ""), + "approx_logit_lb": result["logit_lb"], + "approx_eps": result["eps"], + "approx_margin": result["margin"], + "approx_min_prev": result["min_prev"], + "approx_value_range": result["value_range"], + "active": result["active"], + "period": period if period is not None else "", + "verify_status": status, + "verify_logit_lb": verify.logit_lb or "", + } + write_csv_row(args.output, row) + if not verify.ok: + if verify.stdout: + print(f" out: {verify.stdout}", flush=True) + if verify.stderr: + print(f" err: {verify.stderr}", flush=True) + else: + for result in results[: args.top]: + row = { + "model_path": model_path, + "seq_len": seq_len, + "pattern_len": pattern_len, + "seed": seed, + "layer": result["layer"], + "head": result["head"], + "target": result.get("target", ""), + "negative": result.get("negative", ""), + "score_mode": args.score_mode, + "score": result.get("score", ""), + "prev_mean": result.get("prev_mean", ""), + "prev_median": result.get("prev_median", ""), + "prev_top1_frac": result.get("prev_top1_frac", ""), + "approx_logit_lb": result.get("logit_lb", ""), + "approx_eps": result.get("eps", ""), + "approx_margin": result.get("margin", ""), + "approx_min_prev": result.get("min_prev", ""), + "approx_value_range": result.get("value_range", ""), + "active": result.get("active", ""), + "period": period if period is not None else "", + "verify_status": "", + "verify_logit_lb": "", + } + write_csv_row(args.output, row) return 0 From 3d304daee9aa31f024d8d6c3de9c3e65ad843821 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 18:37:03 +0100 Subject: [PATCH 192/244] Fix bigram prev mapping to canonical induction definition --- scripts/discover_gpt2_induction_targets.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/scripts/discover_gpt2_induction_targets.py 
b/scripts/discover_gpt2_induction_targets.py index 721eb23..b649e75 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ b/scripts/discover_gpt2_induction_targets.py @@ -12,8 +12,8 @@ literature. By default, `prev`/active are built from bigram prefix matches (the token at -q-1 maps to its previous occurrence), and heads are ranked by attention to -`prev`. +q-1 maps to the *following* token after its previous occurrence), and heads are +ranked by attention to `prev`. """ from __future__ import annotations @@ -117,13 +117,22 @@ def build_prev(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: def build_prev_bigram(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Bigram-prefix induction prev: for each position q>=1, look at token at q-1, + find its previous occurrence index j, and set prev[q] = j + 1. + + Example (tokens = [1,2,1,3,2,1]): + prev = [0,0,0,1,0,2], active = [F,F,F,T,F,T] + """ prev_token, active_token = build_prev(tokens) prev = np.zeros_like(tokens) active = np.zeros_like(tokens, dtype=bool) if tokens.size <= 1: return prev, active - prev[1:] = prev_token[:-1] - active[1:] = active_token[:-1] + prev_shift = prev_token[:-1] + 1 + active_shift = active_token[:-1] + prev[1:] = np.where(active_shift, prev_shift, 0) + active[1:] = active_shift return prev, active From e5cab0e1a55744e5f814ed3dceb04d07c00896b8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 18:45:25 +0100 Subject: [PATCH 193/244] Add shifted induction prev/active definitions --- Nfp/Model/InductionPrompt.lean | 95 ++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index 5302c6f..ba0ea5b 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -32,6 +32,33 @@ theorem mem_activeOfPeriod {seq : Nat} {period : Nat} {q : Fin seq} : q ∈ activeOfPeriod (seq := seq) period ↔ period ≤ q.val := by simp [activeOfPeriod] +/-- +Shifted `prev` map 
for a periodic induction prompt: if `0 < period` and +`period ≤ q`, return `q - period + 1`; otherwise default to `0`. +-/ +def prevOfPeriodShift {seq : Nat} (period : Nat) (q : Fin seq) : Fin seq := by + classical + by_cases hq : period ≤ q.val + · by_cases hper : 0 < period + · have hlt : q.val - period + 1 < seq := by + have hsub : q.val - period < q.val := Nat.sub_lt_of_pos_le hper hq + have hle : q.val - period + 1 ≤ q.val := Nat.succ_le_of_lt hsub + exact lt_of_le_of_lt hle q.isLt + exact ⟨q.val - period + 1, hlt⟩ + · have hpos : 0 < seq := lt_of_le_of_lt (Nat.zero_le _) q.isLt + exact ⟨0, hpos⟩ + · have hpos : 0 < seq := lt_of_le_of_lt (Nat.zero_le _) q.isLt + exact ⟨0, hpos⟩ + +/-- Active queries for shifted periodic induction prompts (`0 < period ≤ q`). -/ +def activeOfPeriodShift {seq : Nat} (period : Nat) : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun q => 0 < period ∧ period ≤ q.val) + +/-- Membership characterization for `activeOfPeriodShift`. -/ +theorem mem_activeOfPeriodShift {seq : Nat} {period : Nat} {q : Fin seq} : + q ∈ activeOfPeriodShift (seq := seq) period ↔ 0 < period ∧ period ≤ q.val := by + simp [activeOfPeriodShift] + /-- `prev` map induced by token repeats (defaulting to `0` when no prior match exists). -/ def prevOfTokens {seq : Nat} (tokens : Fin seq → Nat) (q : Fin seq) : Fin seq := by classical @@ -100,6 +127,74 @@ theorem prevOfTokens_spec_of_active {seq : Nat} {tokens : Fin seq → Nat} {q : exact (Fin.lt_def).2 hk exact prevOfTokens_spec (tokens := tokens) (q := q) ⟨k, hk', htok⟩ +/-- +Shifted `prev` map for induction: match the current token to its previous +occurrence and return the following position (`A B ... A -> B`). 
+ +Example (tokens = [1,2,1,3,2,1]): + prevShift = [0,0,0,1,0,2], activeShift = {3,5} +-/ +def prevOfTokensShift {seq : Nat} (tokens : Fin seq → Nat) (q : Fin seq) : Fin seq := by + classical + by_cases hq : q ∈ activeOfTokens tokens + · let p := prevOfTokens tokens q + have hp : + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + simpa [p] using + (prevOfTokens_spec_of_active (tokens := tokens) (q := q) hq) + have hpv : p.val < q.val := (Fin.lt_def).1 hp.1 + have hle : p.val + 1 ≤ q.val := Nat.succ_le_of_lt hpv + have hlt : p.val + 1 < seq := lt_of_le_of_lt hle q.isLt + exact ⟨p.val + 1, hlt⟩ + · let hpos : 0 < seq := lt_of_le_of_lt (Nat.zero_le _) q.isLt + exact ⟨0, hpos⟩ + +/-- Active queries for shifted-token induction (same witness condition). -/ +def activeOfTokensShift {seq : Nat} (tokens : Fin seq → Nat) : Finset (Fin seq) := + activeOfTokens tokens + +/-- Membership characterization for `activeOfTokensShift`. -/ +theorem mem_activeOfTokensShift {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} : + q ∈ activeOfTokensShift tokens ↔ ∃ k, k.val < q.val ∧ tokens k = tokens q := by + simp [activeOfTokensShift, activeOfTokens] + +/-- Shifted `prev` agrees with the maximal previous match, advanced by one. 
-/ +theorem prevOfTokensShift_spec {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} + (h : ∃ k, k < q ∧ tokens k = tokens q) : + let p := prevOfTokensShift tokens q + let p0 := prevOfTokens tokens q + p.val = p0.val + 1 ∧ + p0 < q ∧ tokens p0 = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p0 := by + classical + have hactive : q ∈ activeOfTokens tokens := by + rcases h with ⟨k, hk, htok⟩ + exact (mem_activeOfTokens (tokens := tokens) (q := q)).2 + ⟨k, (Fin.lt_def).1 hk, htok⟩ + let p0 := prevOfTokens tokens q + have hp0 : + p0 < q ∧ tokens p0 = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p0 := by + simpa [p0] using (prevOfTokens_spec (tokens := tokens) (q := q) h) + have hpval : (prevOfTokensShift tokens q).val = p0.val + 1 := by + simpa [prevOfTokensShift, hactive, p0] + simpa [p0, hpval, hp0] + +/-- Active shifted queries imply the shifted `prev` maximal-match specification. -/ +theorem prevOfTokensShift_spec_of_active {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} + (hq : q ∈ activeOfTokensShift tokens) : + let p := prevOfTokensShift tokens q + let p0 := prevOfTokens tokens q + p.val = p0.val + 1 ∧ + p0 < q ∧ tokens p0 = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p0 := by + have h := (mem_activeOfTokensShift (tokens := tokens) (q := q)).1 hq + rcases h with ⟨k, hk, htok⟩ + have hk' : k < q := by + exact (Fin.lt_def).2 hk + exact prevOfTokensShift_spec (tokens := tokens) (q := q) ⟨k, hk', htok⟩ + end Model end Nfp From 503fe5d03e9ab2928a72d349f1cbcaf98e1af58d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 18:46:32 +0100 Subject: [PATCH 194/244] Add shifted GPT-2 induction input builder --- Nfp/Sound/Gpt2/HeadInputs.lean | 59 +++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index b29f0f7..15e7fdb 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -24,7 +24,12 @@ namespace 
Gpt2 open Nfp.Model -/-- Build induction-head inputs from a GPT-2 head slice and prompt period. -/ +/-- +Build induction-head inputs from a GPT-2 head slice and prompt period. + +This uses the unshifted periodic prompt (`prev = q - period`), i.e. it matches +the current token rather than the canonical induction copy target. +-/ def buildInductionHeadInputs {seq dModel dHead vocab : Nat} (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : Model.InductionHeadInputs seq dModel dHead := @@ -73,6 +78,58 @@ theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} direction := slice.directionVec } := by simp [buildInductionHeadInputs] +/-- +Build induction-head inputs using the canonical shifted periodic prompt +(`prev = q - period + 1`, with `0 < period`). +-/ +def buildInductionHeadInputsShift {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : + Model.InductionHeadInputs seq dModel dHead := + { scale := slice.scale + active := activeOfPeriodShift (seq := seq) period + prev := prevOfPeriodShift (seq := seq) period + embed := slice.embed + lnEps := slice.lnEps + ln1Gamma := slice.ln1Gamma + ln1Beta := slice.ln1Beta + wq := slice.wq + bq := slice.bq + wk := slice.wk + bk := slice.bk + wv := slice.wv + bv := slice.bv + wo := slice.wo + attnBias := slice.attnBias + maskCausal := true + maskValue := (-10000 : Rat) + directionSpec := slice.direction.spec + direction := slice.directionVec } + +/-- Definitional characterization of `buildInductionHeadInputsShift`. 
-/ +theorem buildInductionHeadInputsShift_def {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : + buildInductionHeadInputsShift slice period = + { scale := slice.scale + active := activeOfPeriodShift (seq := seq) period + prev := prevOfPeriodShift (seq := seq) period + embed := slice.embed + lnEps := slice.lnEps + ln1Gamma := slice.ln1Gamma + ln1Beta := slice.ln1Beta + wq := slice.wq + bq := slice.bq + wk := slice.wk + bk := slice.bk + wv := slice.wv + bv := slice.bv + wo := slice.wo + attnBias := slice.attnBias + maskCausal := true + maskValue := (-10000 : Rat) + directionSpec := slice.direction.spec + direction := slice.directionVec } := by + simp [buildInductionHeadInputsShift] + end Gpt2 end Sound From c4fa6de6349c2b21a503bad61010943c42f609ca Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 18:56:37 +0100 Subject: [PATCH 195/244] Add shifted prev spec predicates for induction --- Nfp/Model/InductionPrompt.lean | 30 ++++++++++++++++++++++++++++++ Nfp/Sound/Gpt2/HeadInputs.lean | 10 +++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index ba0ea5b..ab00e30 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -4,6 +4,7 @@ module public import Mathlib.Data.Finset.Max public import Mathlib.Data.Fintype.Basic +public import Nfp.Model.InductionHead /-! Helpers for induction-style prompts. @@ -195,6 +196,35 @@ theorem prevOfTokensShift_spec_of_active {seq : Nat} {tokens : Fin seq → Nat} exact (Fin.lt_def).2 hk exact prevOfTokensShift_spec (tokens := tokens) (q := q) ⟨k, hk', htok⟩ +/-- Active queries select a `prev` strictly in the past. -/ +def InductionPrevInPast {seq dModel dHead : Nat} + (inputs : InductionHeadInputs seq dModel dHead) : Prop := + ∀ q, q ∈ inputs.active → inputs.prev q < q + +/-- +Canonical shifted-prev spec for periodic prompts. 
+ +Note: when `1 < period`, every active query has `prev q < q`. +-/ +structure InductionPrevSpecPeriodShift {seq dModel dHead : Nat} + (period : Nat) (inputs : InductionHeadInputs seq dModel dHead) : Prop where + /-- Active queries are the shifted-period active set. -/ + active_eq : inputs.active = activeOfPeriodShift (seq := seq) period + /-- Prev map matches the shifted-period definition. -/ + prev_eq : inputs.prev = prevOfPeriodShift (seq := seq) period + +/-- +Canonical shifted-prev spec for token-based prompts. + +Note: if successive tokens repeat, the shifted target can coincide with `q`. +-/ +structure InductionPrevSpecTokensShift {seq dModel dHead : Nat} + (tokens : Fin seq → Nat) (inputs : InductionHeadInputs seq dModel dHead) : Prop where + /-- Active queries match the shifted-token definition. -/ + active_eq : inputs.active = activeOfTokensShift (seq := seq) tokens + /-- Prev map matches the shifted-token definition. -/ + prev_eq : inputs.prev = prevOfTokensShift (seq := seq) tokens + end Model end Nfp diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Sound/Gpt2/HeadInputs.lean index 15e7fdb..10526c5 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Sound/Gpt2/HeadInputs.lean @@ -80,7 +80,8 @@ theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} /-- Build induction-head inputs using the canonical shifted periodic prompt -(`prev = q - period + 1`, with `0 < period`). +(`prev = q - period + 1`, with `0 < period`). When `1 < period`, every active +query has `prev q < q`. -/ def buildInductionHeadInputsShift {seq dModel dHead vocab : Nat} (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : @@ -130,6 +131,13 @@ theorem buildInductionHeadInputsShift_def {seq dModel dHead vocab : Nat} direction := slice.directionVec } := by simp [buildInductionHeadInputsShift] +/-- `buildInductionHeadInputsShift` satisfies the shifted-period prev/active spec. 
-/ +theorem buildInductionHeadInputsShift_prev_spec {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : + InductionPrevSpecPeriodShift (seq := seq) period + (buildInductionHeadInputsShift slice period) := by + constructor <;> simp [buildInductionHeadInputsShift] + end Gpt2 end Sound From 5406f58282892188ed7b2228c4238ae686459e06 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 19:02:44 +0100 Subject: [PATCH 196/244] Add copy/OV scoring to induction discovery --- scripts/discover_gpt2_induction_targets.py | 101 ++++++++++++++++++--- scripts/scan_gpt2_induction_sound.py | 12 ++- scripts/sweep_gpt2_induction_nonvacuous.py | 15 ++- 3 files changed, 108 insertions(+), 20 deletions(-) diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py index b649e75..ec0ddc6 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ b/scripts/discover_gpt2_induction_targets.py @@ -13,7 +13,8 @@ By default, `prev`/active are built from bigram prefix matches (the token at q-1 maps to the *following* token after its previous occurrence), and heads are -ranked by attention to `prev`. +ranked by attention to `prev`. Use `--score-mode=copy` or `--score-mode=attn_copy` +to include OV/copying alignment in the ranking. 
""" from __future__ import annotations @@ -55,6 +56,8 @@ class AttnResult: prev_mean: float prev_median: float prev_top1_frac: float + copy_mean: float + copy_weighted_mean: float eps: float margin: float active: int @@ -226,6 +229,30 @@ def compute_eps_margin( return max(eps_vals), min(margin_vals), prev_mean, prev_median, prev_top1 +def compute_copy_scores( + ov: np.ndarray, + weights_prev: np.ndarray, + columns: Dict[int, np.ndarray], + tokens: np.ndarray, + prev: np.ndarray, + active_positions: Iterable[int], +) -> Tuple[float, float]: + copy_vals: List[float] = [] + copy_weighted_vals: List[float] = [] + for idx, q in enumerate(active_positions): + tok = int(tokens[q]) + col = columns.get(tok) + if col is None: + continue + prev_q = int(prev[q]) + val = float(ov[prev_q] @ col) + copy_vals.append(val) + copy_weighted_vals.append(float(weights_prev[idx]) * val) + if not copy_vals: + return 0.0, 0.0 + return float(np.mean(copy_vals)), float(np.mean(copy_weighted_vals)) + + def format_result(result: HeadResult) -> str: layer = result.layer + 1 head = result.head + 1 @@ -246,6 +273,7 @@ def format_attn_result(result: AttnResult) -> str: f"L{layer}H{head} score={result.score:.6f} " f"prevMean={result.prev_mean:.6f} prevMedian={result.prev_median:.6f} " f"prevTop1={result.prev_top1_frac:.3f} " + f"copyMean={result.copy_mean:.6f} copyWeighted={result.copy_weighted_mean:.6f} " f"eps={result.eps:.6f} margin={result.margin:.6f} active={result.active}" ) @@ -260,15 +288,24 @@ def main() -> int: help="Run verifier on the top N candidates") parser.add_argument( "--score-mode", - choices=["attn", "logit"], + choices=["attn", "copy", "attn_copy", "logit"], default="attn", - help="Rank heads by attention to prev (attn) or logit lower bound (logit).", + help=( + "Rank heads by attention to prev (attn), OV copy score (copy), " + "attention-weighted copy score (attn_copy), or logit lower bound (logit)." 
+ ), ) parser.add_argument( "--min-score", type=float, default=0.0, - help="Minimum attention score (attn mode).", + help="Minimum score threshold for the selected score mode.", + ) + parser.add_argument( + "--min-copy", + type=float, + default=None, + help="Optional minimum OV copy score.", ) parser.add_argument("--min-eps", type=float, default=0.5, help="Filter candidates with eps above this value") @@ -333,7 +370,9 @@ def main() -> int: active_positions = [int(i) for i, flag in enumerate(active_mask) if flag] if not active_positions: raise SystemExit("No active positions found in the prompt") + prev_indices = prev[np.array(active_positions, dtype=np.int64)] + prompt_tokens = sorted({int(tok) for tok in tokens.tolist()}) unique_tokens = [] seen = set() for tok in tokens.tolist(): @@ -347,7 +386,7 @@ def main() -> int: head_data: Dict[ Tuple[int, int], - Tuple[np.ndarray, np.ndarray, float, float, float, float, float], + Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], ] = {} for layer_idx in range(num_layers): @@ -391,9 +430,11 @@ def main() -> int: eps, margin, prev_mean, prev_median, prev_top1 = compute_eps_margin( weights, scores, prev, active_positions ) + prev_weights = weights[np.array(active_positions), prev_indices] head_data[(layer_idx, head_idx)] = ( v, wo, + prev_weights, eps, margin, prev_mean, @@ -407,7 +448,7 @@ def main() -> int: unembed_start = f.tell() columns: Dict[int, np.ndarray] = {} - for tok in unique_tokens: + for tok in prompt_tokens: columns[tok] = read_unembed_column( f, unembed_start, @@ -418,13 +459,37 @@ def main() -> int: results: List[HeadResult] = [] attn_results: List[AttnResult] = [] - prev_indices = prev[np.array(active_positions, dtype=np.int64)] - for (layer_idx, head_idx), (v, wo, eps, margin, prev_mean, prev_median, prev_top1) in head_data.items(): + for (layer_idx, head_idx), ( + v, + wo, + prev_weights, + eps, + margin, + prev_mean, + prev_median, + prev_top1, + ) in head_data.items(): if 
args.score_mode == "logit": if eps > args.min_eps or margin < args.min_margin: continue - if args.score_mode == "attn": - score = prev_mean + if args.score_mode != "logit": + ov = v @ wo + copy_mean, copy_weighted_mean = compute_copy_scores( + ov, + prev_weights, + columns, + tokens, + prev, + active_positions, + ) + if args.min_copy is not None and copy_mean < args.min_copy: + continue + if args.score_mode == "attn": + score = prev_mean + elif args.score_mode == "copy": + score = copy_mean + else: + score = copy_weighted_mean if score < args.min_score: continue attn_results.append( @@ -435,6 +500,8 @@ def main() -> int: prev_mean=prev_mean, prev_median=prev_median, prev_top1_frac=prev_top1, + copy_mean=copy_mean, + copy_weighted_mean=copy_weighted_mean, eps=eps, margin=margin, active=len(active_positions), @@ -478,7 +545,7 @@ def main() -> int: if best is not None: results.append(best) - if args.score_mode == "attn": + if args.score_mode != "logit": attn_results.sort(key=lambda r: r.score, reverse=True) else: results.sort(key=lambda r: r.logit_lb, reverse=True) @@ -490,9 +557,10 @@ def main() -> int: f.write(f"tokens={len(unique_tokens)} active={len(active_positions)}\n") f.write( f"min-eps={args.min_eps} min-margin={args.min_margin} " - f"min-logit-lb={args.min_logit_lb} min-score={args.min_score}\n" + f"min-logit-lb={args.min_logit_lb} min-score={args.min_score} " + f"min-copy={args.min_copy}\n" ) - if args.score_mode == "attn": + if args.score_mode != "logit": for rank, result in enumerate(attn_results[: args.top], start=1): f.write(f"{rank:02d} {format_attn_result(result)}\n") else: @@ -500,7 +568,7 @@ def main() -> int: f.write(f"{rank:02d} {format_result(result)}\n") print(f"Wrote report to {args.output}") - if args.score_mode == "attn": + if args.score_mode != "logit": for rank, result in enumerate(attn_results[: args.top], start=1): print(f"{rank:02d} {format_attn_result(result)}") else: @@ -518,8 +586,9 @@ def main() -> int: "min_margin": args.min_margin, 
"min_logit_lb": args.min_logit_lb, "min_score": args.min_score, + "min_copy": args.min_copy, } - if args.score_mode == "attn": + if args.score_mode != "logit": payload["results"] = [ { "rank": rank, @@ -529,6 +598,8 @@ def main() -> int: "prev_mean": r.prev_mean, "prev_median": r.prev_median, "prev_top1_frac": r.prev_top1_frac, + "copy_mean": r.copy_mean, + "copy_weighted_mean": r.copy_weighted_mean, "eps": r.eps, "margin": r.margin, "active": r.active, diff --git a/scripts/scan_gpt2_induction_sound.py b/scripts/scan_gpt2_induction_sound.py index dbac47c..2ebb93e 100755 --- a/scripts/scan_gpt2_induction_sound.py +++ b/scripts/scan_gpt2_induction_sound.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later """ -Scan GPT-2 induction head candidates with attention or logit-diff bounds. +Scan GPT-2 induction head candidates with attention/copy or logit-diff bounds. This script: 1) Ensures a GPT-2 "rigorous induction" binary model exists. @@ -123,11 +123,12 @@ def main() -> int: parser.add_argument("--min-margin", type=float, default=0.0) parser.add_argument("--min-logit-lb", type=float, default=0.0) parser.add_argument("--min-score", type=float, default=0.0) + parser.add_argument("--min-copy", type=float) parser.add_argument( "--score-mode", - choices=["attn", "logit"], + choices=["attn", "copy", "attn_copy", "logit"], default="attn", - help="Rank by attention score or logit-diff bound.", + help="Rank by attention/copy score or logit-diff bound.", ) parser.add_argument("--layers", help="Comma-separated layer list or 'all'") parser.add_argument("--heads", help="Comma-separated head list or 'all'") @@ -187,6 +188,8 @@ def main() -> int: "--json-out", str(discover_json), ] + if args.min_copy is not None: + discover_cmd += ["--min-copy", str(args.min_copy)] if args.layers is not None: discover_cmd += ["--layers", args.layers] if args.heads is not None: @@ -281,11 +284,14 @@ def run_cert(candidate: dict[str, int]) -> tuple[dict[str, int], Fraction | None prev_mean 
= candidate.get("prev_mean") prev_median = candidate.get("prev_median") prev_top1 = candidate.get("prev_top1_frac") + copy_mean = candidate.get("copy_mean") + copy_weighted = candidate.get("copy_weighted_mean") eps = candidate.get("eps") margin = candidate.get("margin") f.write( f"{rank:02d} L{layer}H{head} score={score} " f"prevMean={prev_mean} prevMedian={prev_median} prevTop1={prev_top1} " + f"copyMean={copy_mean} copyWeighted={copy_weighted} " f"eps={eps} margin={margin}\n" ) diff --git a/scripts/sweep_gpt2_induction_nonvacuous.py b/scripts/sweep_gpt2_induction_nonvacuous.py index 26d3541..b71a822 100644 --- a/scripts/sweep_gpt2_induction_nonvacuous.py +++ b/scripts/sweep_gpt2_induction_nonvacuous.py @@ -101,6 +101,7 @@ def run_discovery( min_margin: float, min_logit_lb: float, min_score: float, + min_copy: float | None, score_mode: str, period: int | None, output_dir: Path, @@ -129,6 +130,8 @@ def run_discovery( "--json-out", str(json_out), ] + if min_copy is not None: + cmd += ["--min-copy", str(min_copy)] if period is not None: cmd += ["--period", str(period)] if prev_mode != "bigram": @@ -188,6 +191,8 @@ def write_csv_row(path: Path, row: dict) -> None: "prev_mean", "prev_median", "prev_top1_frac", + "copy_mean", + "copy_weighted_mean", "approx_logit_lb", "approx_eps", "approx_margin", @@ -222,11 +227,12 @@ def main() -> int: parser.add_argument("--min-margin", type=float, default=0.0) parser.add_argument("--min-logit-lb", type=float, default=0.0) parser.add_argument("--min-score", type=float, default=0.0) + parser.add_argument("--min-copy", type=float) parser.add_argument( "--score-mode", - choices=["attn", "logit"], + choices=["attn", "copy", "attn_copy", "logit"], default="attn", - help="Rank by attention score or logit-diff bound.", + help="Rank by attention/copy score or logit-diff bound.", ) parser.add_argument("--use-period", action="store_true", help="Use pattern length as the period override") @@ -262,6 +268,7 @@ def main() -> int: args.min_margin, 
args.min_logit_lb, args.min_score, + args.min_copy, args.score_mode, period, args.discovery_dir, @@ -307,6 +314,8 @@ def main() -> int: "prev_mean": result.get("prev_mean", ""), "prev_median": result.get("prev_median", ""), "prev_top1_frac": result.get("prev_top1_frac", ""), + "copy_mean": result.get("copy_mean", ""), + "copy_weighted_mean": result.get("copy_weighted_mean", ""), "approx_logit_lb": result["logit_lb"], "approx_eps": result["eps"], "approx_margin": result["margin"], @@ -339,6 +348,8 @@ def main() -> int: "prev_mean": result.get("prev_mean", ""), "prev_median": result.get("prev_median", ""), "prev_top1_frac": result.get("prev_top1_frac", ""), + "copy_mean": result.get("copy_mean", ""), + "copy_weighted_mean": result.get("copy_weighted_mean", ""), "approx_logit_lb": result.get("logit_lb", ""), "approx_eps": result.get("eps", ""), "approx_margin": result.get("margin", ""), From def4cb94f45542280facc3eda4411ab47808db0f Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 19:10:04 +0100 Subject: [PATCH 197/244] Make CLI layer/head indices 1-based by default --- Nfp/Cli.lean | 192 +++++++++++++++++++++++++++++++++++++-------------- README.md | 3 +- 2 files changed, 142 insertions(+), 53 deletions(-) diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index d89b4bd..90e697d 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -50,6 +50,16 @@ private def parseSplitPreset (raw : String) : | _ => throw s!"unknown preset '{raw}' (expected: fast, balanced, tight)" +private def toZeroBased (label : String) (idx : Nat) (zeroBased : Bool) : + Except String Nat := do + if zeroBased then + pure idx + else + if idx = 0 then + throw s!"{label} must be >= 1 for 1-based indexing (use --zero-based for 0-based)" + else + pure (idx - 1) + private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : IO UInt32 := do let inputsPath? := (p.flag? "inputs").map (·.as! String) let modelPath? := (p.flag? "model").map (·.as! 
String) @@ -64,6 +74,7 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) let timing? := (p.flag? "timing").map (·.as! Nat) let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let zeroBased := p.hasFlag "zero-based" let fail (msg : String) : IO UInt32 := do IO.eprintln s!"error: {msg}" return 2 @@ -97,27 +108,33 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : | none, some modelPath => match layer?, head? with | some layer, some head => - match direction? with - | some ⟨dirTarget, dirNegative⟩ => - if requireNonvacuous then - IO.runInductionCertifyHeadModelNonvacuous modelPath layer head dirTarget - dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - else - IO.runInductionCertifyHeadModel modelPath layer head dirTarget dirNegative - period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - | none => - if requireNonvacuous then - IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer head period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - else - IO.runInductionCertifyHeadModelAuto modelPath layer head period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let layerE := toZeroBased "layer" layer zeroBased + let headE := toZeroBased "head" head zeroBased + match layerE, headE with + | Except.error msg, _ => fail msg + | _, Except.error msg => fail msg + | Except.ok layer', Except.ok head' => + match direction? 
with + | some ⟨dirTarget, dirNegative⟩ => + if requireNonvacuous then + IO.runInductionCertifyHeadModelNonvacuous modelPath layer' head' dirTarget + dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + else + IO.runInductionCertifyHeadModel modelPath layer' head' dirTarget dirNegative + period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + | none => + if requireNonvacuous then + IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer' head' period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + else + IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? | _, _ => fail "--layer and --head are required with --model" | none, none => @@ -139,6 +156,7 @@ private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do let period? := (p.flag? "period").map (·.as! Nat) let directionStr? := (p.flag? "direction").map (·.as! String) let outPath? := (p.flag? "out").map (·.as! String) + let zeroBased := p.hasFlag "zero-based" let fail (msg : String) : IO UInt32 := do IO.eprintln s!"error: {msg}" return 2 @@ -160,8 +178,14 @@ private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do | none, some modelPath => match layer?, head?, direction? with | some layer, some head, some ⟨dirTarget, dirNegative⟩ => - IO.runInductionHeadIntervalModel modelPath layer head dirTarget dirNegative period? - outPath? 
+ let layerE := toZeroBased "layer" layer zeroBased + let headE := toZeroBased "head" head zeroBased + match layerE, headE with + | Except.error msg, _ => fail msg + | _, Except.error msg => fail msg + | Except.ok layer', Except.ok head' => + IO.runInductionHeadIntervalModel modelPath layer' head' dirTarget dirNegative + period? outPath? | _, _, none => fail "--direction is required with --model (use \"target,negative\")" | _, _, _ => @@ -178,8 +202,9 @@ def inductionCertifySimpleCmd : Cmd := `[Cli| FLAGS: inputs : String; "Path to the induction head input file (use either --inputs or --model)." model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." - layer : Nat; "Layer index for the induction head (required with --model)." - head : Nat; "Head index for the induction head (required with --model)." + layer : Nat; "Layer index for the induction head (1-based, required with --model)." + head : Nat; "Head index for the induction head (1-based, required with --model)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (model only; default: derive from tokens)." direction : String; "Optional logit-diff direction as \"target,negative\" (model only). \ When omitted with --model, direction is derived from tokens." @@ -201,8 +226,9 @@ def inductionCertifyNonvacuousSimpleCmd : Cmd := `[Cli| FLAGS: inputs : String; "Path to the induction head input file (use either --inputs or --model)." model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." - layer : Nat; "Layer index for the induction head (required with --model)." - head : Nat; "Head index for the induction head (required with --model)." + layer : Nat; "Layer index for the induction head (1-based, required with --model)." + head : Nat; "Head index for the induction head (1-based, required with --model)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." 
period : Nat; "Optional prompt period override (model only; default: derive from tokens)." direction : String; "Optional logit-diff direction as \"target,negative\" (model only). \ When omitted with --model, direction is derived from tokens." @@ -224,8 +250,9 @@ def inductionIntervalSimpleCmd : Cmd := `[Cli| FLAGS: inputs : String; "Path to the induction head input file (use either --inputs or --model)." model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." - layer : Nat; "Layer index for the induction head (required with --model)." - head : Nat; "Head index for the induction head (required with --model)." + layer : Nat; "Layer index for the induction head (1-based, required with --model)." + head : Nat; "Head index for the induction head (1-based, required with --model)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (model only; default: derive from tokens)." direction : String; "Required logit-diff direction as \"target,negative\" (model only)." out : String; "Optional path to write the residual-interval certificate." @@ -477,9 +504,20 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) - IO.runInductionCertifyHeadModel modelPath layer head dirTarget dirNegative period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
+ let zeroBased := p.hasFlag "zero-based" + let layerE := toZeroBased "layer" layer zeroBased + let headE := toZeroBased "head" head zeroBased + match layerE, headE with + | Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok layer', Except.ok head' => + IO.runInductionCertifyHeadModel modelPath layer' head' dirTarget dirNegative period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do @@ -499,9 +537,20 @@ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) - IO.runInductionCertifyHeadModelNonvacuous modelPath layer head dirTarget dirNegative period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let zeroBased := p.hasFlag "zero-based" + let layerE := toZeroBased "layer" layer zeroBased + let headE := toZeroBased "head" head zeroBased + match layerE, headE with + | Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok layer', Except.ok head' => + IO.runInductionCertifyHeadModelNonvacuous modelPath layer' head' dirTarget dirNegative + period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model_auto` subcommand. 
-/ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do @@ -519,9 +568,20 @@ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) - IO.runInductionCertifyHeadModelAuto modelPath layer head period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + let zeroBased := p.hasFlag "zero-based" + let layerE := toZeroBased "layer" layer zeroBased + let headE := toZeroBased "head" head zeroBased + match layerE, headE with + | Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok layer', Except.ok head' => + IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do @@ -539,9 +599,20 @@ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) - IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer head period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
+ let zeroBased := p.hasFlag "zero-based" + let layerE := toZeroBased "layer" layer zeroBased + let headE := toZeroBased "head" head zeroBased + match layerE, headE with + | Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok layer', Except.ok head' => + IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer' head' period? + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model` subcommand. -/ def inductionCertifyHeadModelCmd : Cmd := `[Cli| @@ -549,8 +620,9 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| "Check induction certificates by reading a model binary directly." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head." - head : Nat; "Head index for the induction head." + layer : Nat; "Layer index for the induction head (1-based)." + head : Nat; "Head index for the induction head (1-based)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." "direction-target" : Nat; "Target token id for logit-diff direction." "direction-negative" : Nat; "Negative token id for logit-diff direction." @@ -576,8 +648,9 @@ def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| "Require a strictly positive logit-diff bound from a model binary." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head." - head : Nat; "Head index for the induction head." + layer : Nat; "Layer index for the induction head (1-based)." + head : Nat; "Head index for the induction head (1-based)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." 
"direction-target" : Nat; "Target token id for logit-diff direction." "direction-negative" : Nat; "Negative token id for logit-diff direction." @@ -604,8 +677,9 @@ def inductionCertifyHeadModelAutoCmd : Cmd := `[Cli| from the prompt tokens." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head." - head : Nat; "Head index for the induction head." + layer : Nat; "Layer index for the induction head (1-based)." + head : Nat; "Head index for the induction head (1-based)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." "min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." @@ -630,8 +704,9 @@ def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| derived from the prompt tokens." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head." - head : Nat; "Head index for the induction head." + layer : Nat; "Layer index for the induction head (1-based)." + head : Nat; "Head index for the induction head (1-based)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." "min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." @@ -673,7 +748,19 @@ def runInductionHeadIntervalModel (p : Parsed) : IO UInt32 := do let dirTarget := p.flag! "direction-target" |>.as! Nat let dirNegative := p.flag! "direction-negative" |>.as! Nat let outPath? := (p.flag? "out").map (·.as! String) - IO.runInductionHeadIntervalModel modelPath layer head dirTarget dirNegative period? outPath? 
+ let zeroBased := p.hasFlag "zero-based" + let layerE := toZeroBased "layer" layer zeroBased + let headE := toZeroBased "head" head zeroBased + match layerE, headE with + | Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok layer', Except.ok head' => + IO.runInductionHeadIntervalModel modelPath layer' head' dirTarget dirNegative period? + outPath? /-- `nfp induction head_interval_model` subcommand. -/ def inductionHeadIntervalModelCmd : Cmd := `[Cli| @@ -681,8 +768,9 @@ def inductionHeadIntervalModelCmd : Cmd := `[Cli| "Build head-output interval bounds by reading a model binary directly." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head." - head : Nat; "Head index for the induction head." + layer : Nat; "Layer index for the induction head (1-based)." + head : Nat; "Head index for the induction head (1-based)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." "direction-target" : Nat; "Target token id for logit-diff direction." "direction-negative" : Nat; "Negative token id for logit-diff direction." diff --git a/README.md b/README.md index 77ecb4f..4434b25 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,8 @@ verified by the CLI. Note: the discovery/scan/sweep helper scripts use **one-based** layer/head indices (literature-aligned), default to **bigram prefix matching** for `prev`, and **rank by attention score** unless you explicitly switch to -logit-diff mode. The Lean CLI continues to accept zero-based layer/head indices. +logit-diff mode. The Lean CLI now expects **one-based** layer/head indices by +default; pass `--zero-based` to use legacy zero-based indices. 
```bash python scripts/build_gpt2_induction_cert.py \ From c2753ca7cefc20d7dee6e46d54acf16553042512 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 19:13:39 +0100 Subject: [PATCH 198/244] Fix shifted prev spec proof --- Nfp/Model/InductionPrompt.lean | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index ab00e30..267d7d9 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -179,8 +179,15 @@ theorem prevOfTokensShift_spec {seq : Nat} {tokens : Fin seq → Nat} {q : Fin s ∀ k, k < q → tokens k = tokens q → k ≤ p0 := by simpa [p0] using (prevOfTokens_spec (tokens := tokens) (q := q) h) have hpval : (prevOfTokensShift tokens q).val = p0.val + 1 := by - simpa [prevOfTokensShift, hactive, p0] - simpa [p0, hpval, hp0] + simp [prevOfTokensShift, hactive, p0] + have hpval' : + (prevOfTokensShift tokens q).val = (prevOfTokens tokens q).val + 1 := by + simpa [p0] using hpval + have hp0' : + prevOfTokens tokens q < q ∧ tokens (prevOfTokens tokens q) = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ prevOfTokens tokens q := by + simpa [p0] using hp0 + simpa using And.intro hpval' hp0' /-- Active shifted queries imply the shifted `prev` maximal-match specification. -/ theorem prevOfTokensShift_spec_of_active {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} From f995ce73f00c28152bf23f5670eb9a1162d9e8fb Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 21:19:23 +0100 Subject: [PATCH 199/244] Add synthetic prefix-matching benchmark mode --- README.md | 4 ++ scripts/scan_gpt2_induction_sound.py | 84 ++++++++++++++++++++-- scripts/sweep_gpt2_induction_nonvacuous.py | 3 +- 3 files changed, 85 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 4434b25..70fcdfc 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,10 @@ indices (literature-aligned), default to **bigram prefix matching** for logit-diff mode. 
The Lean CLI now expects **one-based** layer/head indices by default; pass `--zero-based` to use legacy zero-based indices. +For canonical prefix-matching benchmarks, `scripts/scan_gpt2_induction_sound.py` +supports `--synthetic` to generate repeated-random pattern prompts and score +attention/copying on that distribution. + ```bash python scripts/build_gpt2_induction_cert.py \ --output reports/gpt2_induction.cert \ diff --git a/scripts/scan_gpt2_induction_sound.py b/scripts/scan_gpt2_induction_sound.py index 2ebb93e..bc42271 100755 --- a/scripts/scan_gpt2_induction_sound.py +++ b/scripts/scan_gpt2_induction_sound.py @@ -5,7 +5,8 @@ Scan GPT-2 induction head candidates with attention/copy or logit-diff bounds. This script: -1) Ensures a GPT-2 "rigorous induction" binary model exists. +1) Ensures a GPT-2 "rigorous induction" binary model exists (or generates one + from repeated random patterns via --synthetic). 2) Uses the untrusted discovery helper to propose head candidates. 3) Optionally runs `nfp induction certify_head_model_nonvacuous` in logit mode. 
@@ -32,11 +33,43 @@ def run_cmd(cmd: list[str]) -> str: return proc.stdout -def ensure_model(model_path: Path) -> None: +def ensure_model( + model_path: Path, + *, + seq_len: int = 256, + pattern_len: int = 20, + seed: int = 1337, + vocab_min: int = 1000, + vocab_max: int = 5000, + min_word_length: int = 4, + allow_no_leading_space: bool = False, + model_name: str = "gpt2", +) -> None: if model_path.exists(): return model_path.parent.mkdir(parents=True, exist_ok=True) - generator = [sys.executable, "scripts/generate_rigorous_induction.py", str(model_path)] + generator = [ + sys.executable, + "scripts/generate_rigorous_induction.py", + "--output", + str(model_path), + "--seq-len", + str(seq_len), + "--pattern-len", + str(pattern_len), + "--seed", + str(seed), + "--vocab-min", + str(vocab_min), + "--vocab-max", + str(vocab_max), + "--min-word-length", + str(min_word_length), + "--model", + model_name, + ] + if allow_no_leading_space: + generator.append("--allow-no-leading-space") if shutil.which("uv"): generator = ["uv", "run"] + generator subprocess.run(generator, check=True) @@ -139,6 +172,19 @@ def main() -> int: default="bigram", help="Choose prev/active construction (forwarded to discovery).", ) + parser.add_argument( + "--synthetic", + action="store_true", + help="Generate a repeated-random pattern prompt (prefix-matching benchmark).", + ) + parser.add_argument("--synthetic-seq-len", type=int, default=256) + parser.add_argument("--synthetic-pattern-len", type=int, default=20) + parser.add_argument("--synthetic-seed", type=int, default=1337) + parser.add_argument("--synthetic-vocab-min", type=int, default=1000) + parser.add_argument("--synthetic-vocab-max", type=int, default=5000) + parser.add_argument("--synthetic-min-word-length", type=int, default=4) + parser.add_argument("--synthetic-allow-no-leading-space", action="store_true") + parser.add_argument("--synthetic-model", default="gpt2") parser.add_argument("--output", 
default="reports/gpt2_induction_sound_scan.txt") args = parser.parse_args() args.jobs = max(1, args.jobs) @@ -147,8 +193,30 @@ def main() -> int: if args.fast and not top_arg and args.top == parser.get_default("top"): args.top = 4 - model_path = Path(args.model) - ensure_model(model_path) + model_arg = any(a.startswith("--model") for a in sys.argv[1:]) + if args.synthetic and not model_arg: + model_path = Path( + "models/" + f"gpt2_rigorous_seq{args.synthetic_seq_len}" + f"_pat{args.synthetic_pattern_len}" + f"_seed{args.synthetic_seed}.nfpt" + ) + else: + model_path = Path(args.model) + if args.synthetic: + ensure_model( + model_path, + seq_len=args.synthetic_seq_len, + pattern_len=args.synthetic_pattern_len, + seed=args.synthetic_seed, + vocab_min=args.synthetic_vocab_min, + vocab_max=args.synthetic_vocab_max, + min_word_length=args.synthetic_min_word_length, + allow_no_leading_space=args.synthetic_allow_no_leading_space, + model_name=args.synthetic_model, + ) + else: + ensure_model(model_path) nfp_cmd = resolve_nfp_cmd(args.nfp_bin) header, tokens = read_header_and_tokens(model_path) @@ -263,6 +331,12 @@ def run_cert(candidate: dict[str, int]) -> tuple[dict[str, int], Fraction | None else: f.write("Induction attention scan (prev-attn ranking)\n") f.write(f"model={model_path}\n") + if args.synthetic: + f.write( + "synthetic=" + f"seq{args.synthetic_seq_len}_pat{args.synthetic_pattern_len}_" + f"seed{args.synthetic_seed}\n" + ) f.write(f"target={target} negative={negative}\n") eps_header = header.get("layer_norm_eps") or header.get("eps") or "unknown" f.write(f"top={args.top} eps={eps_header}\n") diff --git a/scripts/sweep_gpt2_induction_nonvacuous.py b/scripts/sweep_gpt2_induction_nonvacuous.py index b71a822..aad0723 100644 --- a/scripts/sweep_gpt2_induction_nonvacuous.py +++ b/scripts/sweep_gpt2_induction_nonvacuous.py @@ -8,7 +8,8 @@ Lean verification results are treated as definitive in logit mode. Layer/head indices are one-based (literature-aligned). 
`prev` defaults to bigram -prefix matching. +prefix matching. This sweep generates repeated-random patterns to benchmark +prefix-matching scores across seeds/lengths. """ from __future__ import annotations From d883c8b02f9f63d3bf1be7639c8b5a0930235a95 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 21:42:25 +0100 Subject: [PATCH 200/244] Add activation-based discovery scoring --- scripts/discover_gpt2_induction_targets.py | 207 ++++++++++++++++++++- 1 file changed, 204 insertions(+), 3 deletions(-) diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py index ec0ddc6..b7c71fd 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ b/scripts/discover_gpt2_induction_targets.py @@ -14,7 +14,9 @@ By default, `prev`/active are built from bigram prefix matches (the token at q-1 maps to the *following* token after its previous occurrence), and heads are ranked by attention to `prev`. Use `--score-mode=copy` or `--score-mode=attn_copy` -to include OV/copying alignment in the ranking. +to include OV/copying alignment in the ranking. Use `--use-activations` to +score heads using real layer activations from a HuggingFace GPT-2 model rather +than the embedding-only approximation stored in the NFP file. 
""" from __future__ import annotations @@ -23,7 +25,7 @@ import json import os import subprocess -import sys +import math from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterable, List, Tuple @@ -146,6 +148,144 @@ def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float) - return x_hat * gamma + beta +def skip_head_weights(f, model_dim: int, head_dim: int) -> None: + skip_f64(f, model_dim * head_dim) # wq + skip_f64(f, head_dim) # bq + skip_f64(f, model_dim * head_dim) # wk + skip_f64(f, head_dim) # bk + skip_f64(f, model_dim * head_dim) # wv + skip_f64(f, head_dim) # bv + skip_f64(f, head_dim * model_dim) # wo + + +def skip_layer_weights( + f, + model_dim: int, + head_dim: int, + num_heads: int, + hidden_dim: int, +) -> None: + for _ in range(num_heads): + skip_head_weights(f, model_dim, head_dim) + skip_f64(f, model_dim) # attn bias + skip_f64(f, model_dim * hidden_dim) + skip_f64(f, hidden_dim) + skip_f64(f, hidden_dim * model_dim) + skip_f64(f, model_dim) + skip_f64(f, model_dim) # ln1 gamma + skip_f64(f, model_dim) # ln1 beta + skip_f64(f, model_dim) # ln2 gamma + skip_f64(f, model_dim) # ln2 beta + + +def load_hf_model_and_states(tokens: np.ndarray, model_name: str, device: str): + try: + import torch + from transformers import AutoModel + except ImportError as exc: # pragma: no cover - optional dependency + raise SystemExit( + "Activation mode requires torch + transformers. " + "Install them (e.g., `uv run --with torch --with transformers ...`)." 
+ ) from exc + + torch.set_grad_enabled(False) + model = AutoModel.from_pretrained(model_name) + model.eval() + model.to(device) + input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0) + outputs = model(input_ids, output_hidden_states=True, use_cache=False) + hidden_states = outputs.hidden_states + if hidden_states is None: + raise SystemExit("HuggingFace model did not return hidden states.") + return model, hidden_states + + +def get_transformer_blocks(model): + if hasattr(model, "transformer"): + return model.transformer.h + if hasattr(model, "h"): + return model.h + raise SystemExit("Unsupported HuggingFace model structure (missing transformer blocks).") + + +def compute_head_data_from_activations( + model, + hidden_states, + layers: List[int], + heads: List[int], + prev: np.ndarray, + active_positions: List[int], + prev_indices: np.ndarray, + head_dim: int, + seq_len: int, +) -> Dict[ + Tuple[int, int], + Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], +]: + try: + import torch + except ImportError as exc: # pragma: no cover - optional dependency + raise SystemExit("Activation mode requires torch.") from exc + + blocks = get_transformer_blocks(model) + head_data: Dict[ + Tuple[int, int], + Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], + ] = {} + device = hidden_states[0].device + causal_mask = torch.triu( + torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), + diagonal=1, + ) + for layer_idx in layers: + block = blocks[layer_idx] + hidden = hidden_states[layer_idx] + ln = block.ln_1(hidden) + qkv = block.attn.c_attn(ln) + split_size = getattr(block.attn, "split_size", qkv.shape[-1] // 3) + q, k, v = qkv.split(split_size, dim=2) + num_heads = getattr(block.attn, "num_heads", q.shape[-1] // head_dim) + head_dim_local = getattr(block.attn, "head_dim", head_dim) + if head_dim_local != head_dim: + raise SystemExit("HuggingFace head_dim does not match NFP header.") + 
scale = 1.0 / math.sqrt(head_dim_local) + q = block.attn._split_heads(q, num_heads, head_dim_local) + k = block.attn._split_heads(k, num_heads, head_dim_local) + v = block.attn._split_heads(v, num_heads, head_dim_local) + scores = torch.matmul(q, k.transpose(-2, -1)) * scale + scores = scores.masked_fill(causal_mask, -10000.0) + weights = torch.softmax(scores, dim=-1) + wo_full = block.attn.c_proj.weight + + for head_idx in heads: + weights_head = weights[0, head_idx] + scores_head = scores[0, head_idx] + weights_np = weights_head.detach().cpu().numpy() + scores_np = scores_head.detach().cpu().numpy() + eps, margin, prev_mean, prev_median, prev_top1 = compute_eps_margin( + weights_np, scores_np, prev, active_positions + ) + prev_weights = weights_np[np.array(active_positions), prev_indices] + + v_head = v[0, head_idx] + v_np = v_head.detach().cpu().numpy() + start = head_idx * head_dim_local + end = start + head_dim_local + wo_np = wo_full[start:end, :].detach().cpu().numpy() + + head_data[(layer_idx, head_idx)] = ( + v_np, + wo_np, + prev_weights, + eps, + margin, + prev_mean, + prev_median, + prev_top1, + ) + return head_data + + def softmax(scores: np.ndarray) -> np.ndarray: shift = scores - scores.max(axis=1, keepdims=True) exp = np.exp(shift) @@ -325,6 +465,21 @@ def main() -> int: parser.add_argument("--output", type=Path, default=Path("reports/gpt2_induction_discover.txt")) parser.add_argument("--json-out", type=Path, help="Optional JSON output path") parser.add_argument("--nfp-bin", help="Path to nfp binary (defaults to .lake/build/bin/nfp)") + parser.add_argument( + "--use-activations", + action="store_true", + help="Use HuggingFace GPT-2 activations for Q/K/V instead of embedding-only approximation.", + ) + parser.add_argument( + "--hf-model", + default="gpt2", + help="HuggingFace model name or path (activation mode).", + ) + parser.add_argument( + "--device", + default="cpu", + help="Torch device for activation mode (e.g. 
cpu, cuda, mps).", + ) args = parser.parse_args() if args.max_tokens <= 1: @@ -350,7 +505,11 @@ def main() -> int: heads = parse_index_list(args.heads, num_heads) or list(range(num_heads)) tokens = read_i32(f, seq_len) - embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) + if args.use_activations: + skip_f64(f, seq_len * model_dim) + embeddings = None + else: + embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) if args.prev_mode != "period" and args.period is not None: raise SystemExit("--period is incompatible with --prev-mode=token/bigram") @@ -389,7 +548,30 @@ def main() -> int: Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], ] = {} + hf_model = None + hf_states = None + if args.use_activations: + hf_model, hf_states = load_hf_model_and_states(tokens, args.hf_model, args.device) + config = getattr(hf_model, "config", None) + if config is not None: + if getattr(config, "n_layer", num_layers) != num_layers: + raise SystemExit("HuggingFace model layer count does not match NFP header.") + if getattr(config, "n_head", num_heads) != num_heads: + raise SystemExit("HuggingFace model head count does not match NFP header.") + if getattr(config, "n_embd", model_dim) != model_dim: + raise SystemExit("HuggingFace model dimension does not match NFP header.") + if getattr(config, "vocab_size", vocab_size) != vocab_size: + raise SystemExit("HuggingFace vocab size does not match NFP header.") + if getattr(config, "n_positions", seq_len) < seq_len: + raise SystemExit("Prompt length exceeds HuggingFace model context.") + if len(hf_states) < num_layers + 1: + raise SystemExit("Hidden state count is smaller than expected.") + if hf_states[0].shape[1] != seq_len: + raise SystemExit("Hidden state sequence length does not match NFP header.") for layer_idx in range(num_layers): + if args.use_activations: + skip_layer_weights(f, model_dim, head_dim, num_heads, hidden_dim) + continue head_weights = [] for _ in 
range(num_heads): wq = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) @@ -457,6 +639,19 @@ def main() -> int: tok, ) + if args.use_activations: + head_data = compute_head_data_from_activations( + hf_model, + hf_states, + layers, + heads, + prev, + active_positions, + prev_indices, + head_dim, + seq_len, + ) + results: List[HeadResult] = [] attn_results: List[AttnResult] = [] for (layer_idx, head_idx), ( @@ -554,6 +749,9 @@ def main() -> int: f.write("Induction discovery (approximate ranking)\n") f.write(f"model={args.model}\n") f.write(f"score_mode={args.score_mode}\n") + f.write(f"use_activations={args.use_activations}\n") + if args.use_activations: + f.write(f"hf_model={args.hf_model} device={args.device}\n") f.write(f"tokens={len(unique_tokens)} active={len(active_positions)}\n") f.write( f"min-eps={args.min_eps} min-margin={args.min_margin} " @@ -587,6 +785,9 @@ def main() -> int: "min_logit_lb": args.min_logit_lb, "min_score": args.min_score, "min_copy": args.min_copy, + "use_activations": args.use_activations, + "hf_model": args.hf_model if args.use_activations else None, + "device": args.device if args.use_activations else None, } if args.score_mode != "logit": payload["results"] = [ From 9ac3a53e87a5b668f3a2881994542b07cc6a6669 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 21:45:27 +0100 Subject: [PATCH 201/244] Fix activation head split for transformers --- scripts/discover_gpt2_induction_targets.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py index b7c71fd..0d7034b 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ b/scripts/discover_gpt2_induction_targets.py @@ -227,6 +227,11 @@ def compute_head_data_from_activations( except ImportError as exc: # pragma: no cover - optional dependency raise SystemExit("Activation mode requires torch.") from exc + def split_heads(x: "torch.Tensor", num_heads: 
int, head_dim_local: int) -> "torch.Tensor": + batch, seq, _ = x.shape + x = x.reshape(batch, seq, num_heads, head_dim_local) + return x.permute(0, 2, 1, 3) + blocks = get_transformer_blocks(model) head_data: Dict[ Tuple[int, int], @@ -249,9 +254,9 @@ def compute_head_data_from_activations( if head_dim_local != head_dim: raise SystemExit("HuggingFace head_dim does not match NFP header.") scale = 1.0 / math.sqrt(head_dim_local) - q = block.attn._split_heads(q, num_heads, head_dim_local) - k = block.attn._split_heads(k, num_heads, head_dim_local) - v = block.attn._split_heads(v, num_heads, head_dim_local) + q = split_heads(q, num_heads, head_dim_local) + k = split_heads(k, num_heads, head_dim_local) + v = split_heads(v, num_heads, head_dim_local) scores = torch.matmul(q, k.transpose(-2, -1)) * scale scores = scores.masked_fill(causal_mask, -10000.0) weights = torch.softmax(scores, dim=-1) From 9cbd7839abad30ceddb880ab1e0138ceffe733d1 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 22:07:46 +0100 Subject: [PATCH 202/244] Add induction stripe attention scoring --- scripts/discover_gpt2_induction_targets.py | 162 +++++++++++++++++++-- 1 file changed, 146 insertions(+), 16 deletions(-) diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py index 0d7034b..87cdc14 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ b/scripts/discover_gpt2_induction_targets.py @@ -65,6 +65,19 @@ class AttnResult: active: int +@dataclass(frozen=True) +class StripeResult: + layer: int + head: int + score: float + stripe_mean: float + stripe_median: float + stripe_top1_frac: float + eps: float + margin: float + active: int + + def parse_header(f) -> Dict[str, str]: header: Dict[str, str] = {} magic = f.readline().decode("ascii").strip() @@ -141,6 +154,16 @@ def build_prev_bigram(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: return prev, active +def build_prev_period(seq_len: int, period: int) -> Tuple[np.ndarray, 
np.ndarray]: + prev = np.zeros(seq_len, dtype=np.int64) + active = np.zeros(seq_len, dtype=bool) + idx = np.arange(seq_len) + mask = idx >= period + prev[mask] = idx[mask] - period + active[mask] = True + return prev, active + + def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float) -> np.ndarray: mean = x.mean(axis=1, keepdims=True) var = ((x - mean) ** 2).mean(axis=1, keepdims=True) @@ -218,9 +241,14 @@ def compute_head_data_from_activations( prev_indices: np.ndarray, head_dim: int, seq_len: int, -) -> Dict[ - Tuple[int, int], - Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], + stripe_prev: np.ndarray | None = None, + stripe_positions: List[int] | None = None, +) -> Tuple[ + Dict[ + Tuple[int, int], + Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], + ], + Dict[Tuple[int, int], Tuple[float, float, float, float, float]], ]: try: import torch @@ -237,6 +265,7 @@ def split_heads(x: "torch.Tensor", num_heads: int, head_dim_local: int) -> "torc Tuple[int, int], Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], ] = {} + stripe_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} device = hidden_states[0].device causal_mask = torch.triu( torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), @@ -270,6 +299,17 @@ def split_heads(x: "torch.Tensor", num_heads: int, head_dim_local: int) -> "torc eps, margin, prev_mean, prev_median, prev_top1 = compute_eps_margin( weights_np, scores_np, prev, active_positions ) + if stripe_prev is not None and stripe_positions is not None: + eps_s, margin_s, stripe_mean, stripe_median, stripe_top1 = compute_eps_margin( + weights_np, scores_np, stripe_prev, stripe_positions + ) + stripe_data[(layer_idx, head_idx)] = ( + stripe_mean, + stripe_median, + stripe_top1, + eps_s, + margin_s, + ) prev_weights = weights_np[np.array(active_positions), prev_indices] v_head = v[0, head_idx] @@ -288,7 +328,7 @@ def 
split_heads(x: "torch.Tensor", num_heads: int, head_dim_local: int) -> "torc prev_median, prev_top1, ) - return head_data + return head_data, stripe_data def softmax(scores: np.ndarray) -> np.ndarray: @@ -423,6 +463,17 @@ def format_attn_result(result: AttnResult) -> str: ) +def format_stripe_result(result: StripeResult) -> str: + layer = result.layer + 1 + head = result.head + 1 + return ( + f"L{layer}H{head} score={result.score:.6f} " + f"stripeMean={result.stripe_mean:.6f} stripeMedian={result.stripe_median:.6f} " + f"stripeTop1={result.stripe_top1_frac:.3f} " + f"eps={result.eps:.6f} margin={result.margin:.6f} active={result.active}" + ) + + def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--model", required=True, type=Path, help="Path to NFP_BINARY_V1 model") @@ -433,11 +484,12 @@ def main() -> int: help="Run verifier on the top N candidates") parser.add_argument( "--score-mode", - choices=["attn", "copy", "attn_copy", "logit"], + choices=["attn", "copy", "attn_copy", "stripe", "logit"], default="attn", help=( "Rank heads by attention to prev (attn), OV copy score (copy), " - "attention-weighted copy score (attn_copy), or logit lower bound (logit)." + "attention-weighted copy score (attn_copy), induction stripe attention " + "(stripe), or logit lower bound (logit)." 
), ) parser.add_argument( @@ -467,6 +519,11 @@ def main() -> int: default="bigram", help="Choose prev/active construction (default: bigram prefix match).", ) + parser.add_argument( + "--stripe-period", + type=int, + help="Period for induction stripe scoring (required for --score-mode=stripe).", + ) parser.add_argument("--output", type=Path, default=Path("reports/gpt2_induction_discover.txt")) parser.add_argument("--json-out", type=Path, help="Optional JSON output path") parser.add_argument("--nfp-bin", help="Path to nfp binary (defaults to .lake/build/bin/nfp)") @@ -491,6 +548,8 @@ def main() -> int: raise SystemExit("max-tokens must be at least 2") if args.verify_top > 0 and args.score_mode != "logit": raise SystemExit("--verify-top requires --score-mode=logit") + if args.score_mode == "stripe" and args.stripe_period is None: + raise SystemExit("--score-mode=stripe requires --stripe-period") if not args.model.exists(): raise SystemExit(f"Missing model file: {args.model}") @@ -523,9 +582,7 @@ def main() -> int: if args.prev_mode == "period": period = int(args.period) - prev = np.arange(seq_len, dtype=np.int64) - prev = np.where(prev >= period, prev - period, 0) - active_mask = np.arange(seq_len) >= period + prev, active_mask = build_prev_period(seq_len, period) elif args.prev_mode == "bigram": prev, active_mask = build_prev_bigram(tokens) else: @@ -535,6 +592,14 @@ def main() -> int: if not active_positions: raise SystemExit("No active positions found in the prompt") prev_indices = prev[np.array(active_positions, dtype=np.int64)] + stripe_prev = None + stripe_positions = None + if args.score_mode == "stripe": + stripe_period = int(args.stripe_period) + stripe_prev, stripe_active = build_prev_period(seq_len, stripe_period) + stripe_positions = [int(i) for i, flag in enumerate(stripe_active) if flag] + if not stripe_positions: + raise SystemExit("No stripe positions found for the requested period") prompt_tokens = sorted({int(tok) for tok in tokens.tolist()}) 
unique_tokens = [] @@ -552,6 +617,7 @@ def main() -> int: Tuple[int, int], Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], ] = {} + stripe_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} hf_model = None hf_states = None @@ -617,6 +683,17 @@ def main() -> int: eps, margin, prev_mean, prev_median, prev_top1 = compute_eps_margin( weights, scores, prev, active_positions ) + if stripe_prev is not None and stripe_positions is not None: + eps_s, margin_s, stripe_mean, stripe_median, stripe_top1 = compute_eps_margin( + weights, scores, stripe_prev, stripe_positions + ) + stripe_data[(layer_idx, head_idx)] = ( + stripe_mean, + stripe_median, + stripe_top1, + eps_s, + margin_s, + ) prev_weights = weights[np.array(active_positions), prev_indices] head_data[(layer_idx, head_idx)] = ( v, @@ -645,7 +722,7 @@ def main() -> int: ) if args.use_activations: - head_data = compute_head_data_from_activations( + head_data, stripe_data = compute_head_data_from_activations( hf_model, hf_states, layers, @@ -655,10 +732,13 @@ def main() -> int: prev_indices, head_dim, seq_len, + stripe_prev=stripe_prev, + stripe_positions=stripe_positions, ) results: List[HeadResult] = [] attn_results: List[AttnResult] = [] + stripe_results: List[StripeResult] = [] for (layer_idx, head_idx), ( v, wo, @@ -672,6 +752,28 @@ def main() -> int: if args.score_mode == "logit": if eps > args.min_eps or margin < args.min_margin: continue + if args.score_mode == "stripe": + stripe = stripe_data.get((layer_idx, head_idx)) + if stripe is None: + continue + stripe_mean, stripe_median, stripe_top1, eps_s, margin_s = stripe + score = stripe_mean + if score < args.min_score: + continue + stripe_results.append( + StripeResult( + layer=layer_idx, + head=head_idx, + score=score, + stripe_mean=stripe_mean, + stripe_median=stripe_median, + stripe_top1_frac=stripe_top1, + eps=eps_s, + margin=margin_s, + active=len(stripe_positions) if stripe_positions is not None else 0, + 
) + ) + continue if args.score_mode != "logit": ov = v @ wo copy_mean, copy_weighted_mean = compute_copy_scores( @@ -745,11 +847,14 @@ def main() -> int: if best is not None: results.append(best) - if args.score_mode != "logit": + if args.score_mode == "stripe": + stripe_results.sort(key=lambda r: r.score, reverse=True) + elif args.score_mode != "logit": attn_results.sort(key=lambda r: r.score, reverse=True) else: results.sort(key=lambda r: r.logit_lb, reverse=True) args.output.parent.mkdir(parents=True, exist_ok=True) + active_count = len(stripe_positions) if args.score_mode == "stripe" and stripe_positions else len(active_positions) with args.output.open("w", encoding="ascii") as f: f.write("Induction discovery (approximate ranking)\n") f.write(f"model={args.model}\n") @@ -757,13 +862,18 @@ def main() -> int: f.write(f"use_activations={args.use_activations}\n") if args.use_activations: f.write(f"hf_model={args.hf_model} device={args.device}\n") - f.write(f"tokens={len(unique_tokens)} active={len(active_positions)}\n") + f.write(f"tokens={len(unique_tokens)} active={active_count}\n") + if args.score_mode == "stripe": + f.write(f"stripe_period={args.stripe_period}\n") f.write( f"min-eps={args.min_eps} min-margin={args.min_margin} " f"min-logit-lb={args.min_logit_lb} min-score={args.min_score} " f"min-copy={args.min_copy}\n" ) - if args.score_mode != "logit": + if args.score_mode == "stripe": + for rank, result in enumerate(stripe_results[: args.top], start=1): + f.write(f"{rank:02d} {format_stripe_result(result)}\n") + elif args.score_mode != "logit": for rank, result in enumerate(attn_results[: args.top], start=1): f.write(f"{rank:02d} {format_attn_result(result)}\n") else: @@ -771,7 +881,10 @@ def main() -> int: f.write(f"{rank:02d} {format_result(result)}\n") print(f"Wrote report to {args.output}") - if args.score_mode != "logit": + if args.score_mode == "stripe": + for rank, result in enumerate(stripe_results[: args.top], start=1): + print(f"{rank:02d} 
{format_stripe_result(result)}") + elif args.score_mode != "logit": for rank, result in enumerate(attn_results[: args.top], start=1): print(f"{rank:02d} {format_attn_result(result)}") else: @@ -783,7 +896,7 @@ def main() -> int: payload = { "model": str(args.model), "tokens": len(unique_tokens), - "active": len(active_positions), + "active": active_count, "score_mode": args.score_mode, "min_eps": args.min_eps, "min_margin": args.min_margin, @@ -793,8 +906,25 @@ def main() -> int: "use_activations": args.use_activations, "hf_model": args.hf_model if args.use_activations else None, "device": args.device if args.use_activations else None, + "stripe_period": args.stripe_period if args.score_mode == "stripe" else None, } - if args.score_mode != "logit": + if args.score_mode == "stripe": + payload["results"] = [ + { + "rank": rank, + "layer": r.layer + 1, + "head": r.head + 1, + "score": r.score, + "stripe_mean": r.stripe_mean, + "stripe_median": r.stripe_median, + "stripe_top1_frac": r.stripe_top1_frac, + "eps": r.eps, + "margin": r.margin, + "active": r.active, + } + for rank, r in enumerate(stripe_results[: args.top], start=1) + ] + elif args.score_mode != "logit": payload["results"] = [ { "rank": rank, From 14417002e60b4cae3c05ac4973483a63591737bd Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 22:20:47 +0100 Subject: [PATCH 203/244] Add circuit scoring for induction discovery --- scripts/discover_gpt2_induction_targets.py | 149 +++++++++++++++++++-- 1 file changed, 136 insertions(+), 13 deletions(-) diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py index 87cdc14..6322501 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ b/scripts/discover_gpt2_induction_targets.py @@ -78,6 +78,21 @@ class StripeResult: active: int +@dataclass(frozen=True) +class CircuitResult: + prev_layer: int + prev_head: int + induction_layer: int + induction_head: int + score: float + prev_mean: float + prev_median: 
float + prev_top1_frac: float + stripe_mean: float + stripe_median: float + stripe_top1_frac: float + + def parse_header(f) -> Dict[str, str]: header: Dict[str, str] = {} magic = f.readline().decode("ascii").strip() @@ -243,12 +258,15 @@ def compute_head_data_from_activations( seq_len: int, stripe_prev: np.ndarray | None = None, stripe_positions: List[int] | None = None, + prevtok_prev: np.ndarray | None = None, + prevtok_positions: List[int] | None = None, ) -> Tuple[ Dict[ Tuple[int, int], Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], ], Dict[Tuple[int, int], Tuple[float, float, float, float, float]], + Dict[Tuple[int, int], Tuple[float, float, float, float, float]], ]: try: import torch @@ -266,6 +284,7 @@ def split_heads(x: "torch.Tensor", num_heads: int, head_dim_local: int) -> "torc Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], ] = {} stripe_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} + prevtok_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} device = hidden_states[0].device causal_mask = torch.triu( torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), @@ -310,6 +329,17 @@ def split_heads(x: "torch.Tensor", num_heads: int, head_dim_local: int) -> "torc eps_s, margin_s, ) + if prevtok_prev is not None and prevtok_positions is not None: + eps_p, margin_p, prevtok_mean, prevtok_median, prevtok_top1 = compute_eps_margin( + weights_np, scores_np, prevtok_prev, prevtok_positions + ) + prevtok_data[(layer_idx, head_idx)] = ( + prevtok_mean, + prevtok_median, + prevtok_top1, + eps_p, + margin_p, + ) prev_weights = weights_np[np.array(active_positions), prev_indices] v_head = v[0, head_idx] @@ -328,7 +358,7 @@ def split_heads(x: "torch.Tensor", num_heads: int, head_dim_local: int) -> "torc prev_median, prev_top1, ) - return head_data, stripe_data + return head_data, stripe_data, prevtok_data def softmax(scores: np.ndarray) -> np.ndarray: 
@@ -474,6 +504,20 @@ def format_stripe_result(result: StripeResult) -> str: ) +def format_circuit_result(result: CircuitResult) -> str: + prev_layer = result.prev_layer + 1 + prev_head = result.prev_head + 1 + ind_layer = result.induction_layer + 1 + ind_head = result.induction_head + 1 + return ( + f"prev=L{prev_layer}H{prev_head} ind=L{ind_layer}H{ind_head} " + f"score={result.score:.6f} prevMean={result.prev_mean:.6f} " + f"prevMedian={result.prev_median:.6f} prevTop1={result.prev_top1_frac:.3f} " + f"stripeMean={result.stripe_mean:.6f} stripeMedian={result.stripe_median:.6f} " + f"stripeTop1={result.stripe_top1_frac:.3f}" + ) + + def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--model", required=True, type=Path, help="Path to NFP_BINARY_V1 model") @@ -484,12 +528,12 @@ def main() -> int: help="Run verifier on the top N candidates") parser.add_argument( "--score-mode", - choices=["attn", "copy", "attn_copy", "stripe", "logit"], + choices=["attn", "copy", "attn_copy", "stripe", "circuit", "logit"], default="attn", help=( "Rank heads by attention to prev (attn), OV copy score (copy), " "attention-weighted copy score (attn_copy), induction stripe attention " - "(stripe), or logit lower bound (logit)." + "(stripe), circuit pairing (circuit), or logit lower bound (logit)." 
), ) parser.add_argument( @@ -548,8 +592,8 @@ def main() -> int: raise SystemExit("max-tokens must be at least 2") if args.verify_top > 0 and args.score_mode != "logit": raise SystemExit("--verify-top requires --score-mode=logit") - if args.score_mode == "stripe" and args.stripe_period is None: - raise SystemExit("--score-mode=stripe requires --stripe-period") + if args.score_mode in {"stripe", "circuit"} and args.stripe_period is None: + raise SystemExit("--score-mode=stripe/circuit requires --stripe-period") if not args.model.exists(): raise SystemExit(f"Missing model file: {args.model}") @@ -594,12 +638,19 @@ def main() -> int: prev_indices = prev[np.array(active_positions, dtype=np.int64)] stripe_prev = None stripe_positions = None - if args.score_mode == "stripe": + if args.score_mode in {"stripe", "circuit"}: stripe_period = int(args.stripe_period) stripe_prev, stripe_active = build_prev_period(seq_len, stripe_period) stripe_positions = [int(i) for i, flag in enumerate(stripe_active) if flag] if not stripe_positions: raise SystemExit("No stripe positions found for the requested period") + prevtok_prev = None + prevtok_positions = None + if args.score_mode == "circuit": + prevtok_prev, prevtok_active = build_prev_period(seq_len, 1) + prevtok_positions = [int(i) for i, flag in enumerate(prevtok_active) if flag] + if not prevtok_positions: + raise SystemExit("No previous-token positions found") prompt_tokens = sorted({int(tok) for tok in tokens.tolist()}) unique_tokens = [] @@ -618,6 +669,7 @@ def main() -> int: Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], ] = {} stripe_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} + prevtok_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} hf_model = None hf_states = None @@ -694,6 +746,17 @@ def main() -> int: eps_s, margin_s, ) + if prevtok_prev is not None and prevtok_positions is not None: + eps_p, margin_p, prevtok_mean, prevtok_median, 
prevtok_top1 = compute_eps_margin( + weights, scores, prevtok_prev, prevtok_positions + ) + prevtok_data[(layer_idx, head_idx)] = ( + prevtok_mean, + prevtok_median, + prevtok_top1, + eps_p, + margin_p, + ) prev_weights = weights[np.array(active_positions), prev_indices] head_data[(layer_idx, head_idx)] = ( v, @@ -722,7 +785,7 @@ def main() -> int: ) if args.use_activations: - head_data, stripe_data = compute_head_data_from_activations( + head_data, stripe_data, prevtok_data = compute_head_data_from_activations( hf_model, hf_states, layers, @@ -734,11 +797,14 @@ def main() -> int: seq_len, stripe_prev=stripe_prev, stripe_positions=stripe_positions, + prevtok_prev=prevtok_prev, + prevtok_positions=prevtok_positions, ) results: List[HeadResult] = [] attn_results: List[AttnResult] = [] stripe_results: List[StripeResult] = [] + circuit_results: List[CircuitResult] = [] for (layer_idx, head_idx), ( v, wo, @@ -774,6 +840,37 @@ def main() -> int: ) ) continue + if args.score_mode == "circuit": + stripe = stripe_data.get((layer_idx, head_idx)) + if stripe is None: + continue + stripe_mean, stripe_median, stripe_top1, _eps_s, _margin_s = stripe + best_prev: CircuitResult | None = None + for (prev_layer, prev_head), prev_stats in prevtok_data.items(): + if prev_layer >= layer_idx: + continue + prev_mean, prev_median, prev_top1, _eps_p, _margin_p = prev_stats + score = prev_mean * stripe_mean + if score < args.min_score: + continue + candidate = CircuitResult( + prev_layer=prev_layer, + prev_head=prev_head, + induction_layer=layer_idx, + induction_head=head_idx, + score=score, + prev_mean=prev_mean, + prev_median=prev_median, + prev_top1_frac=prev_top1, + stripe_mean=stripe_mean, + stripe_median=stripe_median, + stripe_top1_frac=stripe_top1, + ) + if best_prev is None or candidate.score > best_prev.score: + best_prev = candidate + if best_prev is not None: + circuit_results.append(best_prev) + continue if args.score_mode != "logit": ov = v @ wo copy_mean, copy_weighted_mean = 
compute_copy_scores( @@ -847,14 +944,16 @@ def main() -> int: if best is not None: results.append(best) - if args.score_mode == "stripe": + if args.score_mode == "circuit": + circuit_results.sort(key=lambda r: r.score, reverse=True) + elif args.score_mode == "stripe": stripe_results.sort(key=lambda r: r.score, reverse=True) elif args.score_mode != "logit": attn_results.sort(key=lambda r: r.score, reverse=True) else: results.sort(key=lambda r: r.logit_lb, reverse=True) args.output.parent.mkdir(parents=True, exist_ok=True) - active_count = len(stripe_positions) if args.score_mode == "stripe" and stripe_positions else len(active_positions) + active_count = len(stripe_positions) if args.score_mode in {"stripe", "circuit"} and stripe_positions else len(active_positions) with args.output.open("w", encoding="ascii") as f: f.write("Induction discovery (approximate ranking)\n") f.write(f"model={args.model}\n") @@ -863,14 +962,17 @@ def main() -> int: if args.use_activations: f.write(f"hf_model={args.hf_model} device={args.device}\n") f.write(f"tokens={len(unique_tokens)} active={active_count}\n") - if args.score_mode == "stripe": + if args.score_mode in {"stripe", "circuit"}: f.write(f"stripe_period={args.stripe_period}\n") f.write( f"min-eps={args.min_eps} min-margin={args.min_margin} " f"min-logit-lb={args.min_logit_lb} min-score={args.min_score} " f"min-copy={args.min_copy}\n" ) - if args.score_mode == "stripe": + if args.score_mode == "circuit": + for rank, result in enumerate(circuit_results[: args.top], start=1): + f.write(f"{rank:02d} {format_circuit_result(result)}\n") + elif args.score_mode == "stripe": for rank, result in enumerate(stripe_results[: args.top], start=1): f.write(f"{rank:02d} {format_stripe_result(result)}\n") elif args.score_mode != "logit": @@ -881,7 +983,10 @@ def main() -> int: f.write(f"{rank:02d} {format_result(result)}\n") print(f"Wrote report to {args.output}") - if args.score_mode == "stripe": + if args.score_mode == "circuit": + for rank, 
result in enumerate(circuit_results[: args.top], start=1): + print(f"{rank:02d} {format_circuit_result(result)}") + elif args.score_mode == "stripe": for rank, result in enumerate(stripe_results[: args.top], start=1): print(f"{rank:02d} {format_stripe_result(result)}") elif args.score_mode != "logit": @@ -908,7 +1013,25 @@ def main() -> int: "device": args.device if args.use_activations else None, "stripe_period": args.stripe_period if args.score_mode == "stripe" else None, } - if args.score_mode == "stripe": + if args.score_mode == "circuit": + payload["results"] = [ + { + "rank": rank, + "prev_layer": r.prev_layer + 1, + "prev_head": r.prev_head + 1, + "induction_layer": r.induction_layer + 1, + "induction_head": r.induction_head + 1, + "score": r.score, + "prev_mean": r.prev_mean, + "prev_median": r.prev_median, + "prev_top1_frac": r.prev_top1_frac, + "stripe_mean": r.stripe_mean, + "stripe_median": r.stripe_median, + "stripe_top1_frac": r.stripe_top1_frac, + } + for rank, r in enumerate(circuit_results[: args.top], start=1) + ] + elif args.score_mode == "stripe": payload["results"] = [ { "rank": rank, From 873c5105c39ed9d9a99998b43a79cb95dade65f7 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 22:26:22 +0100 Subject: [PATCH 204/244] Add circuit+copy induction scoring --- scripts/discover_gpt2_induction_targets.py | 131 +++++++++++++++++++-- 1 file changed, 120 insertions(+), 11 deletions(-) diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py index 6322501..e6946cf 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ b/scripts/discover_gpt2_induction_targets.py @@ -93,6 +93,23 @@ class CircuitResult: stripe_top1_frac: float +@dataclass(frozen=True) +class CircuitCopyResult: + prev_layer: int + prev_head: int + induction_layer: int + induction_head: int + score: float + prev_mean: float + prev_median: float + prev_top1_frac: float + stripe_mean: float + stripe_median: float + 
stripe_top1_frac: float + copy_mean: float + copy_weighted_mean: float + + def parse_header(f) -> Dict[str, str]: header: Dict[str, str] = {} magic = f.readline().decode("ascii").strip() @@ -518,6 +535,21 @@ def format_circuit_result(result: CircuitResult) -> str: ) +def format_circuit_copy_result(result: CircuitCopyResult) -> str: + prev_layer = result.prev_layer + 1 + prev_head = result.prev_head + 1 + ind_layer = result.induction_layer + 1 + ind_head = result.induction_head + 1 + return ( + f"prev=L{prev_layer}H{prev_head} ind=L{ind_layer}H{ind_head} " + f"score={result.score:.6f} prevMean={result.prev_mean:.6f} " + f"stripeMean={result.stripe_mean:.6f} copyMean={result.copy_mean:.6f} " + f"copyWeighted={result.copy_weighted_mean:.6f} " + f"prevMedian={result.prev_median:.6f} prevTop1={result.prev_top1_frac:.3f} " + f"stripeMedian={result.stripe_median:.6f} stripeTop1={result.stripe_top1_frac:.3f}" + ) + + def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--model", required=True, type=Path, help="Path to NFP_BINARY_V1 model") @@ -528,12 +560,13 @@ def main() -> int: help="Run verifier on the top N candidates") parser.add_argument( "--score-mode", - choices=["attn", "copy", "attn_copy", "stripe", "circuit", "logit"], + choices=["attn", "copy", "attn_copy", "stripe", "circuit", "circuit_copy", "logit"], default="attn", help=( "Rank heads by attention to prev (attn), OV copy score (copy), " "attention-weighted copy score (attn_copy), induction stripe attention " - "(stripe), circuit pairing (circuit), or logit lower bound (logit)." + "(stripe), circuit pairing (circuit), circuit + copy (circuit_copy), " + "or logit lower bound (logit)." 
), ) parser.add_argument( @@ -592,8 +625,10 @@ def main() -> int: raise SystemExit("max-tokens must be at least 2") if args.verify_top > 0 and args.score_mode != "logit": raise SystemExit("--verify-top requires --score-mode=logit") - if args.score_mode in {"stripe", "circuit"} and args.stripe_period is None: + if args.score_mode in {"stripe", "circuit", "circuit_copy"} and args.stripe_period is None: raise SystemExit("--score-mode=stripe/circuit requires --stripe-period") + if args.score_mode in {"circuit", "circuit_copy"} and args.prev_mode != "bigram": + raise SystemExit("--score-mode=circuit requires --prev-mode=bigram") if not args.model.exists(): raise SystemExit(f"Missing model file: {args.model}") @@ -638,7 +673,7 @@ def main() -> int: prev_indices = prev[np.array(active_positions, dtype=np.int64)] stripe_prev = None stripe_positions = None - if args.score_mode in {"stripe", "circuit"}: + if args.score_mode in {"stripe", "circuit", "circuit_copy"}: stripe_period = int(args.stripe_period) stripe_prev, stripe_active = build_prev_period(seq_len, stripe_period) stripe_positions = [int(i) for i, flag in enumerate(stripe_active) if flag] @@ -646,7 +681,7 @@ def main() -> int: raise SystemExit("No stripe positions found for the requested period") prevtok_prev = None prevtok_positions = None - if args.score_mode == "circuit": + if args.score_mode in {"circuit", "circuit_copy"}: prevtok_prev, prevtok_active = build_prev_period(seq_len, 1) prevtok_positions = [int(i) for i, flag in enumerate(prevtok_active) if flag] if not prevtok_positions: @@ -805,6 +840,7 @@ def main() -> int: attn_results: List[AttnResult] = [] stripe_results: List[StripeResult] = [] circuit_results: List[CircuitResult] = [] + circuit_copy_results: List[CircuitCopyResult] = [] for (layer_idx, head_idx), ( v, wo, @@ -871,6 +907,51 @@ def main() -> int: if best_prev is not None: circuit_results.append(best_prev) continue + if args.score_mode == "circuit_copy": + stripe = stripe_data.get((layer_idx, 
head_idx)) + if stripe is None: + continue + stripe_mean, stripe_median, stripe_top1, _eps_s, _margin_s = stripe + ov = v @ wo + copy_mean, copy_weighted_mean = compute_copy_scores( + ov, + prev_weights, + columns, + tokens, + prev, + active_positions, + ) + if args.min_copy is not None and copy_mean < args.min_copy: + continue + copy_score = max(copy_mean, 0.0) + best_prev: CircuitCopyResult | None = None + for (prev_layer, prev_head), prev_stats in prevtok_data.items(): + if prev_layer >= layer_idx: + continue + prev_mean, prev_median, prev_top1, _eps_p, _margin_p = prev_stats + score = prev_mean * stripe_mean * copy_score + if score < args.min_score: + continue + candidate = CircuitCopyResult( + prev_layer=prev_layer, + prev_head=prev_head, + induction_layer=layer_idx, + induction_head=head_idx, + score=score, + prev_mean=prev_mean, + prev_median=prev_median, + prev_top1_frac=prev_top1, + stripe_mean=stripe_mean, + stripe_median=stripe_median, + stripe_top1_frac=stripe_top1, + copy_mean=copy_mean, + copy_weighted_mean=copy_weighted_mean, + ) + if best_prev is None or candidate.score > best_prev.score: + best_prev = candidate + if best_prev is not None: + circuit_copy_results.append(best_prev) + continue if args.score_mode != "logit": ov = v @ wo copy_mean, copy_weighted_mean = compute_copy_scores( @@ -944,7 +1025,9 @@ def main() -> int: if best is not None: results.append(best) - if args.score_mode == "circuit": + if args.score_mode == "circuit_copy": + circuit_copy_results.sort(key=lambda r: r.score, reverse=True) + elif args.score_mode == "circuit": circuit_results.sort(key=lambda r: r.score, reverse=True) elif args.score_mode == "stripe": stripe_results.sort(key=lambda r: r.score, reverse=True) @@ -953,7 +1036,7 @@ def main() -> int: else: results.sort(key=lambda r: r.logit_lb, reverse=True) args.output.parent.mkdir(parents=True, exist_ok=True) - active_count = len(stripe_positions) if args.score_mode in {"stripe", "circuit"} and stripe_positions else 
len(active_positions) + active_count = len(stripe_positions) if args.score_mode in {"stripe", "circuit", "circuit_copy"} and stripe_positions else len(active_positions) with args.output.open("w", encoding="ascii") as f: f.write("Induction discovery (approximate ranking)\n") f.write(f"model={args.model}\n") @@ -962,14 +1045,17 @@ def main() -> int: if args.use_activations: f.write(f"hf_model={args.hf_model} device={args.device}\n") f.write(f"tokens={len(unique_tokens)} active={active_count}\n") - if args.score_mode in {"stripe", "circuit"}: + if args.score_mode in {"stripe", "circuit", "circuit_copy"}: f.write(f"stripe_period={args.stripe_period}\n") f.write( f"min-eps={args.min_eps} min-margin={args.min_margin} " f"min-logit-lb={args.min_logit_lb} min-score={args.min_score} " f"min-copy={args.min_copy}\n" ) - if args.score_mode == "circuit": + if args.score_mode == "circuit_copy": + for rank, result in enumerate(circuit_copy_results[: args.top], start=1): + f.write(f"{rank:02d} {format_circuit_copy_result(result)}\n") + elif args.score_mode == "circuit": for rank, result in enumerate(circuit_results[: args.top], start=1): f.write(f"{rank:02d} {format_circuit_result(result)}\n") elif args.score_mode == "stripe": @@ -983,7 +1069,10 @@ def main() -> int: f.write(f"{rank:02d} {format_result(result)}\n") print(f"Wrote report to {args.output}") - if args.score_mode == "circuit": + if args.score_mode == "circuit_copy": + for rank, result in enumerate(circuit_copy_results[: args.top], start=1): + print(f"{rank:02d} {format_circuit_copy_result(result)}") + elif args.score_mode == "circuit": for rank, result in enumerate(circuit_results[: args.top], start=1): print(f"{rank:02d} {format_circuit_result(result)}") elif args.score_mode == "stripe": @@ -1013,7 +1102,27 @@ def main() -> int: "device": args.device if args.use_activations else None, "stripe_period": args.stripe_period if args.score_mode == "stripe" else None, } - if args.score_mode == "circuit": + if args.score_mode 
== "circuit_copy": + payload["results"] = [ + { + "rank": rank, + "prev_layer": r.prev_layer + 1, + "prev_head": r.prev_head + 1, + "induction_layer": r.induction_layer + 1, + "induction_head": r.induction_head + 1, + "score": r.score, + "prev_mean": r.prev_mean, + "prev_median": r.prev_median, + "prev_top1_frac": r.prev_top1_frac, + "stripe_mean": r.stripe_mean, + "stripe_median": r.stripe_median, + "stripe_top1_frac": r.stripe_top1_frac, + "copy_mean": r.copy_mean, + "copy_weighted_mean": r.copy_weighted_mean, + } + for rank, r in enumerate(circuit_copy_results[: args.top], start=1) + ] + elif args.score_mode == "circuit": payload["results"] = [ { "rank": rank, From 33832d51b164d3f793f35fcf5d644e82177b04ea Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 22:46:23 +0100 Subject: [PATCH 205/244] Add seq_len=50 induction diagnostic script --- scripts/diagnose_induction_heads.py | 177 ++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 scripts/diagnose_induction_heads.py diff --git a/scripts/diagnose_induction_heads.py b/scripts/diagnose_induction_heads.py new file mode 100644 index 0000000..b35976f --- /dev/null +++ b/scripts/diagnose_induction_heads.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Diagnostic scan for induction heads on repeated random sequences. + +This mirrors the common literature setup: +- build a batch of repeated random token sequences (pattern_len repeated twice), +- run GPT-2 with output_attentions, +- rank heads by induction stripe attention (q -> q - period), +- rank heads by previous-token attention (q -> q - 1). 
+""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +import numpy as np + +try: + import torch + from transformers import GPT2Model, GPT2Tokenizer +except ImportError as exc: + raise SystemExit("Requires torch + transformers (use `uv run --with torch --with transformers`).") from exc + + +def select_vocab_candidates( + tokenizer, + vocab_min: int, + vocab_max: int, + min_word_length: int, + require_leading_space: bool, +) -> list[int]: + candidates = [] + for tid in range(vocab_min, vocab_max): + word = tokenizer.decode([tid]) + if len(word.strip()) <= min_word_length: + continue + if require_leading_space and not word.startswith(" "): + continue + candidates.append(tid) + return candidates + + +def build_batch( + rng: np.random.Generator, + candidates: list[int], + batch_size: int, + pattern_len: int, + seq_len: int, +) -> np.ndarray: + if seq_len != 2 * pattern_len: + raise ValueError("seq_len must equal 2 * pattern_len for repeated sequence diagnostic") + if len(candidates) < pattern_len: + raise ValueError("Not enough vocab candidates for requested pattern length") + batch = np.zeros((batch_size, seq_len), dtype=np.int64) + for idx in range(batch_size): + pattern = rng.choice(candidates, size=pattern_len, replace=False) + batch[idx] = np.tile(pattern, 2) + return batch + + +def compute_stripe_stats(attn: torch.Tensor, period: int) -> tuple[torch.Tensor, torch.Tensor]: + batch, heads, seq_len, _ = attn.shape + q = torch.arange(period, seq_len, device=attn.device) + k = q - period + block = attn[:, :, q, :] + k_index = k.view(1, 1, -1, 1).expand(batch, heads, -1, 1) + stripe_vals = block.gather(dim=-1, index=k_index).squeeze(-1) + stripe_mean = stripe_vals.mean(dim=(0, 2)) + max_vals = block.max(dim=-1).values + stripe_top1 = (stripe_vals >= max_vals - 1e-12).float().mean(dim=(0, 2)) + return stripe_mean, stripe_top1 + + +def compute_prev_stats(attn: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch, 
heads, seq_len, _ = attn.shape + q = torch.arange(1, seq_len, device=attn.device) + k = q - 1 + block = attn[:, :, q, :] + k_index = k.view(1, 1, -1, 1).expand(batch, heads, -1, 1) + prev_vals = block.gather(dim=-1, index=k_index).squeeze(-1) + prev_mean = prev_vals.mean(dim=(0, 2)) + max_vals = block.max(dim=-1).values + prev_top1 = (prev_vals >= max_vals - 1e-12).float().mean(dim=(0, 2)) + return prev_mean, prev_top1 + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="gpt2", help="HuggingFace model name") + parser.add_argument("--seq-len", type=int, default=50) + parser.add_argument("--pattern-len", type=int, default=25) + parser.add_argument("--batch", type=int, default=30) + parser.add_argument("--seed", type=int, default=1337) + parser.add_argument("--vocab-min", type=int, default=1000) + parser.add_argument("--vocab-max", type=int, default=5000) + parser.add_argument("--min-word-length", type=int, default=4) + parser.add_argument("--require-leading-space", action="store_true", default=True) + parser.add_argument("--allow-no-leading-space", action="store_true") + parser.add_argument("--device", default="cpu") + parser.add_argument("--attn-implementation", default="eager") + parser.add_argument("--top", type=int, default=20) + parser.add_argument("--output", type=Path, default=Path("reports/induction_diagnostic.txt")) + args = parser.parse_args() + + require_leading_space = args.require_leading_space and not args.allow_no_leading_space + rng = np.random.default_rng(args.seed) + + tokenizer = GPT2Tokenizer.from_pretrained(args.model) + model = GPT2Model.from_pretrained(args.model, attn_implementation=args.attn_implementation) + model.eval() + model.to(args.device) + + candidates = select_vocab_candidates( + tokenizer, + vocab_min=args.vocab_min, + vocab_max=args.vocab_max, + min_word_length=args.min_word_length, + require_leading_space=require_leading_space, + ) + batch_tokens = build_batch( + 
rng, + candidates, + batch_size=args.batch, + pattern_len=args.pattern_len, + seq_len=args.seq_len, + ) + input_ids = torch.tensor(batch_tokens, dtype=torch.long, device=args.device) + with torch.no_grad(): + outputs = model(input_ids, output_attentions=True, use_cache=False) + if outputs.attentions is None: + raise SystemExit("Model did not return attention weights.") + + stripe_scores = [] + prev_scores = [] + for layer_idx, attn in enumerate(outputs.attentions): + stripe_mean, stripe_top1 = compute_stripe_stats(attn, args.pattern_len) + prev_mean, prev_top1 = compute_prev_stats(attn) + for head_idx in range(attn.shape[1]): + stripe_scores.append((float(stripe_mean[head_idx]), float(stripe_top1[head_idx]), layer_idx, head_idx)) + prev_scores.append((float(prev_mean[head_idx]), float(prev_top1[head_idx]), layer_idx, head_idx)) + + stripe_scores.sort(key=lambda x: x[0], reverse=True) + prev_scores.sort(key=lambda x: x[0], reverse=True) + + args.output.parent.mkdir(parents=True, exist_ok=True) + with args.output.open("w", encoding="ascii") as f: + f.write("Induction diagnostic (repeated random sequence)\n") + f.write(f"model={args.model}\n") + f.write(f"seq_len={args.seq_len} pattern_len={args.pattern_len} batch={args.batch}\n") + f.write( + f"vocab_range=[{args.vocab_min},{args.vocab_max}) min_word_length={args.min_word_length} " + f"leading_space={require_leading_space}\n" + ) + f.write("\nTop induction stripe heads:\n") + for rank, (mean, top1, layer, head) in enumerate(stripe_scores[: args.top], start=1): + f.write(f"{rank:02d} L{layer+1}H{head+1} stripeMean={mean:.6f} stripeTop1={top1:.3f}\n") + f.write("\nTop previous-token heads:\n") + for rank, (mean, top1, layer, head) in enumerate(prev_scores[: args.top], start=1): + f.write(f"{rank:02d} L{layer+1}H{head+1} prevMean={mean:.6f} prevTop1={top1:.3f}\n") + + print(f"Wrote report to {args.output}") + print("\nTop induction stripe heads:") + for rank, (mean, top1, layer, head) in enumerate(stripe_scores[: 
args.top], start=1): + print(f"{rank:02d} L{layer+1}H{head+1} stripeMean={mean:.6f} stripeTop1={top1:.3f}") + print("\nTop previous-token heads:") + for rank, (mean, top1, layer, head) in enumerate(prev_scores[: args.top], start=1): + print(f"{rank:02d} L{layer+1}H{head+1} prevMean={mean:.6f} prevTop1={top1:.3f}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 62e8f7270ff076f67d60b4be4bbd59b067e6115e Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 23:17:20 +0100 Subject: [PATCH 206/244] Fix diagnostic token lemmas for induction prompts --- Nfp/Model/InductionPrompt.lean | 183 +++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index 267d7d9..ca1fdac 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -232,6 +232,189 @@ structure InductionPrevSpecTokensShift {seq dModel dHead : Nat} /-- Prev map matches the shifted-token definition. -/ prev_eq : inputs.prev = prevOfTokensShift (seq := seq) tokens +/-- Helper: lift a first-half index into `Fin (2 * period)`. -/ +lemma lt_double_of_lt_period {period i : Nat} (hi : i < period) : i < 2 * period := by + have hle : period ≤ 2 * period := by + have hpos : 0 < (2 : Nat) := by decide + exact Nat.le_mul_of_pos_left period hpos + exact Nat.lt_of_lt_of_le hi hle + +/-- +Tokens are a repeated pattern of length `period` with no repeats in the first half. + +This matches the usual induction diagnostic: a random pattern of length `period` +repeated twice. +-/ +structure InductionDiagnosticTokens (period : Nat) (tokens : Fin (2 * period) → Nat) : Prop where + /-- Second half repeats the first half with period `period`. -/ + repeat_tok : ∀ q : Fin (2 * period), period ≤ q.val → + tokens q = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + /-- The first-half tokens are pairwise distinct. 
-/ + inj : ∀ {i j : Nat} (hi : i < period) (hj : j < period), + tokens ⟨i, lt_double_of_lt_period hi⟩ = + tokens ⟨j, lt_double_of_lt_period hj⟩ → i = j + +/-- +In a diagnostic prompt (repeated distinct pattern), the previous matching token +for any query in the second half is exactly `q - period`. +-/ +theorem prevOfTokens_eq_prevOfPeriod_of_diag {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) {q : Fin (2 * period)} (hq : period ≤ q.val) : + prevOfTokens tokens q = prevOfPeriod (seq := 2 * period) period q := by + classical + let kq : Fin (2 * period) := + ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + have hklt : kq < q := by + have hklt' : kq.val < q.val := by + simpa [kq] using (Nat.sub_lt_of_pos_le hper hq) + exact (Fin.lt_def).2 hklt' + have htok : tokens kq = tokens q := by + have := hdiag.repeat_tok q hq + simpa [kq] using this.symm + have hspec := prevOfTokens_spec (tokens := tokens) (q := q) ⟨kq, hklt, htok⟩ + let p := prevOfTokens tokens q + have hp : + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + simpa [p] using hspec + have hqsub : q.val - period < period := by + have hq2 : q.val < 2 * period := q.isLt + have hq2' : q.val < period + period := by simpa [two_mul] using hq2 + exact (Nat.sub_lt_iff_lt_add hq).2 (by simpa [Nat.add_comm] using hq2') + have huniq : + ∀ r : Fin (2 * period), r < q → tokens r = tokens q → r.val = q.val - period := by + intro r hr htokr + by_cases hrper : period ≤ r.val + · have hrsub : r.val - period < period := by + have hr2 : r.val < 2 * period := r.isLt + have hr2' : r.val < period + period := by simpa [two_mul] using hr2 + exact (Nat.sub_lt_iff_lt_add hrper).2 (by simpa [Nat.add_comm] using hr2') + have htok_r : + tokens ⟨r.val - period, lt_of_le_of_lt (Nat.sub_le _ _) r.isLt⟩ = tokens r := by + have := hdiag.repeat_tok r hrper + simpa using this.symm + have htok_q : + tokens q = tokens ⟨q.val - period, 
lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + simpa using hdiag.repeat_tok q hq + have htok_first : + tokens ⟨r.val - period, lt_of_le_of_lt (Nat.sub_le _ _) r.isLt⟩ = + tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + calc + tokens ⟨r.val - period, lt_of_le_of_lt (Nat.sub_le _ _) r.isLt⟩ + = tokens r := by simpa using htok_r + _ = tokens q := by simpa using htokr + _ = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + simpa using htok_q + have hkeq : r.val - period = q.val - period := by + apply hdiag.inj hrsub hqsub + simpa using htok_first + have hrval : r.val = q.val := by + have h := congrArg (fun x => x + period) hkeq + simpa [Nat.sub_add_cancel hrper, Nat.sub_add_cancel hq] using h + have hrlt : r.val < q.val := (Fin.lt_def).1 hr + have hrlt' : r.val < r.val := by + have hrlt' := hrlt + rw [← hrval] at hrlt' + exact hrlt' + exact (False.elim (lt_irrefl _ hrlt')) + · have hrlt : r.val < period := lt_of_not_ge hrper + have hrfin : + (⟨r.val, lt_double_of_lt_period hrlt⟩ : Fin (2 * period)) = r := by + apply Fin.ext + rfl + have htok_q : + tokens q = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + simpa using hdiag.repeat_tok q hq + have htok_first : + tokens ⟨r.val, lt_double_of_lt_period hrlt⟩ = + tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + calc + tokens ⟨r.val, lt_double_of_lt_period hrlt⟩ = tokens r := by + simp [hrfin] + _ = tokens q := by simpa using htokr + _ = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + simpa using htok_q + have hkeq : r.val = q.val - period := by + apply hdiag.inj hrlt hqsub + simpa using htok_first + exact hkeq + have hpval : p.val = q.val - period := by + have := hp.2.2 p hp.1 hp.2.1 + have huniq' := huniq p hp.1 hp.2.1 + exact huniq' + apply Fin.ext + simp [prevOfPeriod, hpval, p] + +/-- Shifted `prev` map matches the period-shifted map under diagnostic tokens. 
-/ +theorem prevOfTokensShift_eq_prevOfPeriodShift_of_diag {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) {q : Fin (2 * period)} (hq : period ≤ q.val) : + prevOfTokensShift tokens q = prevOfPeriodShift (seq := 2 * period) period q := by + have hprev := + prevOfTokens_eq_prevOfPeriod_of_diag (tokens := tokens) hdiag hper (q := q) hq + have hactive : q ∈ activeOfTokensShift tokens := by + let kq : Fin (2 * period) := + ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + have hklt : kq < q := by + have hklt' : kq.val < q.val := by + simpa [kq] using (Nat.sub_lt_of_pos_le hper hq) + exact (Fin.lt_def).2 hklt' + have htok : tokens kq = tokens q := by + have := hdiag.repeat_tok q hq + simpa [kq] using this.symm + exact (mem_activeOfTokensShift (tokens := tokens) (q := q)).2 + ⟨kq, hklt, htok⟩ + have hactive' : q ∈ activeOfTokens tokens := by + simpa [activeOfTokensShift] using hactive + simp [prevOfTokensShift, hactive', hprev, prevOfPeriodShift, prevOfPeriod, hq, hper] + +/-- Active shifted queries coincide with the periodic active set in diagnostics. 
-/ +theorem activeOfTokensShift_eq_activeOfPeriodShift_of_diag {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) : + activeOfTokensShift (seq := 2 * period) tokens = + activeOfPeriodShift (seq := 2 * period) period := by + ext q + constructor + · intro hq + have hq' := (mem_activeOfTokensShift (tokens := tokens) (q := q)).1 hq + rcases hq' with ⟨k, hk, htok⟩ + have hkper : period ≤ q.val := by + by_contra hlt + have hqlt : q.val < period := lt_of_not_ge hlt + have hklt : k.val < period := lt_of_lt_of_le hk (Nat.le_of_lt hqlt) + have hkfin : + (⟨k.val, lt_double_of_lt_period hklt⟩ : Fin (2 * period)) = k := by + apply Fin.ext + rfl + have hqfin : + (⟨q.val, lt_double_of_lt_period hqlt⟩ : Fin (2 * period)) = q := by + apply Fin.ext + rfl + have hkeq : k.val = q.val := by + apply hdiag.inj hklt hqlt + simpa [hkfin, hqfin] using htok + have hk' : k.val < k.val := by + have hk' := hk + rw [← hkeq] at hk' + exact hk' + exact (False.elim (lt_irrefl _ hk')) + exact (mem_activeOfPeriodShift (seq := 2 * period) (period := period) (q := q)).2 + ⟨hper, hkper⟩ + · intro hq + have hq' := (mem_activeOfPeriodShift (seq := 2 * period) (period := period) (q := q)).1 hq + rcases hq' with ⟨_, hqper⟩ + let kq : Fin (2 * period) := + ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + have hklt : kq.val < q.val := by + simpa [kq] using (Nat.sub_lt_of_pos_le hper hqper) + have htok : tokens kq = tokens q := by + have := hdiag.repeat_tok q hqper + simpa [kq] using this.symm + exact (mem_activeOfTokensShift (tokens := tokens) (q := q)).2 + ⟨kq, hklt, htok⟩ + end Model end Nfp From 73222da9fbbd3d0f36694ab1a28d6da3a808bc03 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 23:26:29 +0100 Subject: [PATCH 207/244] Add diagnostic spec bridge for induction prompts --- Nfp/Model/InductionPrompt.lean | 45 ++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git 
a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean index ca1fdac..fe50221 100644 --- a/Nfp/Model/InductionPrompt.lean +++ b/Nfp/Model/InductionPrompt.lean @@ -415,6 +415,51 @@ theorem activeOfTokensShift_eq_activeOfPeriodShift_of_diag {period : Nat} exact (mem_activeOfTokensShift (tokens := tokens) (q := q)).2 ⟨kq, hklt, htok⟩ +/-- Diagnostic prompts align shifted-token `prev` with the period-shifted map. -/ +theorem prevOfTokensShift_eq_prevOfPeriodShift_of_diag_all {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) : + prevOfTokensShift tokens = prevOfPeriodShift (seq := 2 * period) period := by + funext q + by_cases hqper : period ≤ q.val + · simpa using + (prevOfTokensShift_eq_prevOfPeriodShift_of_diag (tokens := tokens) hdiag hper + (q := q) hqper) + · have hqnot_period : q ∉ activeOfPeriodShift (seq := 2 * period) period := by + intro hqmem + have hcond := + (mem_activeOfPeriodShift (seq := 2 * period) (period := period) (q := q)).1 hqmem + exact (hqper hcond.2).elim + have hactive_eq := + activeOfTokensShift_eq_activeOfPeriodShift_of_diag (tokens := tokens) hdiag hper + have hqnot_tokens_shift : q ∉ activeOfTokensShift (seq := 2 * period) tokens := by + simpa [hactive_eq] using hqnot_period + have hqnot_tokens : q ∉ activeOfTokens tokens := by + simpa [activeOfTokensShift] using hqnot_tokens_shift + simp [prevOfTokensShift, hqnot_tokens, prevOfPeriodShift, hqper] + +/-- Diagnostic prompts let period-shift specs re-express as token-shift specs. 
-/ +theorem InductionPrevSpecTokensShift_of_diag {dModel dHead : Nat} {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) {inputs : InductionHeadInputs (2 * period) dModel dHead} + (hspec : InductionPrevSpecPeriodShift (seq := 2 * period) period inputs) : + InductionPrevSpecTokensShift (seq := 2 * period) tokens inputs := by + refine ⟨?active, ?prev⟩ + · have hactive_eq := + activeOfTokensShift_eq_activeOfPeriodShift_of_diag (tokens := tokens) hdiag hper + calc + inputs.active = activeOfPeriodShift (seq := 2 * period) period := hspec.active_eq + _ = activeOfTokensShift (seq := 2 * period) tokens := by + symm + exact hactive_eq + · have hprev_eq := + prevOfTokensShift_eq_prevOfPeriodShift_of_diag_all (tokens := tokens) hdiag hper + calc + inputs.prev = prevOfPeriodShift (seq := 2 * period) period := hspec.prev_eq + _ = prevOfTokensShift tokens := by + symm + exact hprev_eq + end Model end Nfp From b031709d4134ff2ab30331a096777a2ece9417e6 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 23:31:39 +0100 Subject: [PATCH 208/244] Add script to certify any induction head --- scripts/certify_induction_head.py | 157 ++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 scripts/certify_induction_head.py diff --git a/scripts/certify_induction_head.py b/scripts/certify_induction_head.py new file mode 100644 index 0000000..25e55f7 --- /dev/null +++ b/scripts/certify_induction_head.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Certify a single induction head from a model binary. + +This script is a small wrapper around +`nfp induction certify_head_model_auto(_nonvacuous)` and optionally +creates a diagnostic prompt model with repeated patterns. 
+""" + +from __future__ import annotations + +import argparse +import os +import shutil +import subprocess +import sys +from pathlib import Path + + +def resolve_nfp_cmd(nfp_bin: str | None) -> list[str]: + if nfp_bin: + return [nfp_bin] + env_bin = os.environ.get("NFP_BIN") + if env_bin: + return [env_bin] + local_bin = Path(".lake/build/bin/nfp") + if local_bin.exists(): + return [str(local_bin)] + return ["lake", "exe", "nfp"] + + +def ensure_model( + model_path: Path, + *, + seq_len: int, + pattern_len: int, + seed: int, + vocab_min: int, + vocab_max: int, + min_word_length: int, + allow_no_leading_space: bool, + model_name: str, +) -> None: + if model_path.exists(): + return + model_path.parent.mkdir(parents=True, exist_ok=True) + generator = [ + sys.executable, + "scripts/generate_rigorous_induction.py", + "--output", + str(model_path), + "--seq-len", + str(seq_len), + "--pattern-len", + str(pattern_len), + "--seed", + str(seed), + "--vocab-min", + str(vocab_min), + "--vocab-max", + str(vocab_max), + "--min-word-length", + str(min_word_length), + "--model", + model_name, + ] + if allow_no_leading_space: + generator.append("--allow-no-leading-space") + if shutil.which("uv"): + generator = ["uv", "run"] + generator + subprocess.run(generator, check=True) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Certify an induction head from a model binary." 
+ ) + parser.add_argument("--model", default="models/gpt2_rigorous.nfpt") + parser.add_argument("--layer", type=int, required=True) + parser.add_argument("--head", type=int, required=True) + parser.add_argument("--period", type=int) + parser.add_argument("--nonvacuous", action="store_true") + parser.add_argument("--zero-based", action="store_true") + parser.add_argument("--min-active", type=int) + parser.add_argument("--min-logit-diff", type=str) + parser.add_argument("--min-margin", type=str) + parser.add_argument("--max-eps", type=str) + parser.add_argument("--nfp-bin", help="Path to nfp binary") + + parser.add_argument( + "--ensure-model", + action="store_true", + help="Generate a diagnostic model if the path does not exist", + ) + parser.add_argument("--seq-len", type=int, default=256) + parser.add_argument("--pattern-len", type=int, default=20) + parser.add_argument("--seed", type=int, default=1337) + parser.add_argument("--vocab-min", type=int, default=1000) + parser.add_argument("--vocab-max", type=int, default=5000) + parser.add_argument("--min-word-length", type=int, default=4) + parser.add_argument("--model-name", default="gpt2") + parser.add_argument("--allow-no-leading-space", action="store_true") + + args = parser.parse_args() + + model_path = Path(args.model) + if args.ensure_model: + ensure_model( + model_path, + seq_len=args.seq_len, + pattern_len=args.pattern_len, + seed=args.seed, + vocab_min=args.vocab_min, + vocab_max=args.vocab_max, + min_word_length=args.min_word_length, + allow_no_leading_space=args.allow_no_leading_space, + model_name=args.model_name, + ) + if not model_path.exists(): + print(f"error: model not found at {model_path}", file=sys.stderr) + return 1 + + subcmd = "certify_head_model_auto" + if args.nonvacuous: + subcmd = "certify_head_model_auto_nonvacuous" + + cmd = resolve_nfp_cmd(args.nfp_bin) + [ + "induction", + subcmd, + "--model", + str(model_path), + "--layer", + str(args.layer), + "--head", + str(args.head), + ] + if 
args.zero_based: + cmd.append("--zero-based") + if args.period is not None: + cmd += ["--period", str(args.period)] + if args.min_active is not None: + cmd += ["--min-active", str(args.min_active)] + if args.min_logit_diff is not None: + cmd += ["--min-logit-diff", args.min_logit_diff] + if args.min_margin is not None: + cmd += ["--min-margin", args.min_margin] + if args.max_eps is not None: + cmd += ["--max-eps", args.max_eps] + + proc = subprocess.run(cmd) + return proc.returncode + + +if __name__ == "__main__": + raise SystemExit(main()) From 70d9bd67b8ef7e82721ceeb78d7a9fe7ccce245d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Fri, 16 Jan 2026 23:44:44 +0100 Subject: [PATCH 209/244] Extend induction head certify wrapper with advanced flags --- scripts/certify_induction_head.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/scripts/certify_induction_head.py b/scripts/certify_induction_head.py index 25e55f7..671a77b 100644 --- a/scripts/certify_induction_head.py +++ b/scripts/certify_induction_head.py @@ -88,6 +88,12 @@ def main() -> int: parser.add_argument("--min-margin", type=str) parser.add_argument("--max-eps", type=str) parser.add_argument("--nfp-bin", help="Path to nfp binary") + parser.add_argument("--timing", type=int) + parser.add_argument("--heartbeat-ms", type=int) + parser.add_argument("--split-budget-q", type=int) + parser.add_argument("--split-budget-k", type=int) + parser.add_argument("--split-budget-diff-base", type=int) + parser.add_argument("--split-budget-diff-refined", type=int) parser.add_argument( "--ensure-model", @@ -128,6 +134,7 @@ def main() -> int: cmd = resolve_nfp_cmd(args.nfp_bin) + [ "induction", + "advanced", subcmd, "--model", str(model_path), @@ -148,6 +155,18 @@ def main() -> int: cmd += ["--min-margin", args.min_margin] if args.max_eps is not None: cmd += ["--max-eps", args.max_eps] + if args.timing is not None: + cmd += ["--timing", str(args.timing)] + if args.heartbeat_ms is not None: + cmd += 
["--heartbeat-ms", str(args.heartbeat_ms)] + if args.split_budget_q is not None: + cmd += ["--split-budget-q", str(args.split_budget_q)] + if args.split_budget_k is not None: + cmd += ["--split-budget-k", str(args.split_budget_k)] + if args.split_budget_diff_base is not None: + cmd += ["--split-budget-diff-base", str(args.split_budget_diff_base)] + if args.split_budget_diff_refined is not None: + cmd += ["--split-budget-diff-refined", str(args.split_budget_diff_refined)] proc = subprocess.run(cmd) return proc.returncode From cd87283b50dd2488f3f1732e60f548a5222118fc Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 00:26:30 +0100 Subject: [PATCH 210/244] Add skip-logit-diff option for induction certs --- Nfp/Cli.lean | 27 +++++++++++++++++---- Nfp/IO/InductionHead/Basic.lean | 21 +++++++++++++---- scripts/certify_induction_head.py | 39 +++++++++++++++++++++++++------ 3 files changed, 70 insertions(+), 17 deletions(-) diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 90e697d..4135af6 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -74,6 +74,7 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) let timing? := (p.flag? "timing").map (·.as! Nat) let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let skipLogitDiff := p.hasFlag "skip-logit-diff" let zeroBased := p.hasFlag "zero-based" let fail (msg : String) : IO UInt32 := do IO.eprintln s!"error: {msg}" @@ -97,6 +98,8 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : fail "--layer/--head/--period are only valid with --model" else if direction?.isSome then fail "--direction is only valid with --model" + else if requireNonvacuous && skipLogitDiff then + fail "--skip-logit-diff is not allowed with certify_nonvacuous" else if requireNonvacuous then IO.runInductionCertifyHeadNonvacuous inputsPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? 
@@ -105,6 +108,7 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + skipLogitDiff | none, some modelPath => match layer?, head? with | some layer, some head => @@ -116,7 +120,9 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : | Except.ok layer', Except.ok head' => match direction? with | some ⟨dirTarget, dirNegative⟩ => - if requireNonvacuous then + if requireNonvacuous && skipLogitDiff then + fail "--skip-logit-diff is not allowed with certify_nonvacuous" + else if requireNonvacuous then IO.runInductionCertifyHeadModelNonvacuous modelPath layer' head' dirTarget dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? @@ -126,8 +132,11 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + skipLogitDiff | none => - if requireNonvacuous then + if requireNonvacuous && skipLogitDiff then + fail "--skip-logit-diff is not allowed with certify_nonvacuous" + else if requireNonvacuous then IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer' head' period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? @@ -135,6 +144,7 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
+ skipLogitDiff | _, _ => fail "--layer and --head are required with --model" | none, none => @@ -215,6 +225,7 @@ def inductionCertifySimpleCmd : Cmd := `[Cli| (rational literal; default: 0)." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + "skip-logit-diff"; "Skip logit-diff lower bound computation." timing : Nat; "Emit timing output to stdout (0=off, 1=on)." "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." ] @@ -421,9 +432,10 @@ def runInductionCertifyHead (p : Parsed) : IO UInt32 := do let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) + let skipLogitDiff := p.hasFlag "skip-logit-diff" IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff /-- `nfp induction certify_head_nonvacuous` subcommand. -/ def runInductionCertifyHeadNonvacuous (p : Parsed) : IO UInt32 := do @@ -462,6 +474,7 @@ def inductionCertifyHeadCmd : Cmd := `[Cli| (default: 0)." "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ bounds (default: 12)." + "skip-logit-diff" : Bool; "Skip logit-diff lower bound computation." ] /-- `nfp induction certify_head_nonvacuous` subcommand. -/ @@ -484,6 +497,7 @@ def inductionCertifyHeadNonvacuousCmd : Cmd := `[Cli| (default: 0)." "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ bounds (default: 12)." + "skip-logit-diff" : Bool; "Skip logit-diff lower bound computation." ] /-- `nfp induction certify_head_model` subcommand. 
-/ @@ -504,6 +518,7 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) + let skipLogitDiff := p.hasFlag "skip-logit-diff" let zeroBased := p.hasFlag "zero-based" let layerE := toZeroBased "layer" layer zeroBased let headE := toZeroBased "head" head zeroBased @@ -517,7 +532,7 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do | Except.ok layer', Except.ok head' => IO.runInductionCertifyHeadModel modelPath layer' head' dirTarget dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do @@ -568,6 +583,7 @@ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) + let skipLogitDiff := p.hasFlag "skip-logit-diff" let zeroBased := p.hasFlag "zero-based" let layerE := toZeroBased "layer" layer zeroBased let headE := toZeroBased "head" head zeroBased @@ -581,7 +597,7 @@ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do | Except.ok layer', Except.ok head' => IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? 
splitBudgetDiffRefined? skipLogitDiff /-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do @@ -640,6 +656,7 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| (default: 0)." "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ bounds (default: 12)." + "skip-logit-diff" : Bool; "Skip logit-diff lower bound computation." ] /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 200f21d..8a522f4 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -318,7 +318,7 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} (inputs : Model.InductionHeadInputs seq dModel dHead) (cfg : Sound.InductionHeadSplitConfig) (minActive? : Option Nat) (minLogitDiff? : Option Rat) - (minMargin maxEps : Rat) : IO UInt32 := do + (minMargin maxEps : Rat) (skipLogitDiff : Bool) : IO UInt32 := do match seq with | 0 => IO.eprintln "error: seq must be positive" @@ -890,6 +890,12 @@ private def checkInductionHeadInputs {seq dModel dHead : Nat} s!"error: eps {ratToString cert.eps} \ above maximum {ratToString maxEps}" return 2 + if skipLogitDiff then + IO.println + s!"ok: induction head certificate built (seq={seq}, active={activeCount}, \ + margin={ratToString cert.margin}, eps={ratToString cert.eps}, \ + note=logit-diff skipped)" + return 0 timingPrint "timing: head tol start" timingFlush let tol := cert.eps * (cert.values.hi - cert.values.lo) @@ -1177,7 +1183,8 @@ def runInductionCertifyHead (inputsPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + (splitBudgetQ? splitBudgetK? 
splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) + (skipLogitDiff : Bool) : IO UInt32 := do configureTiming timing? heartbeatMs? let splitCfg := @@ -1206,6 +1213,7 @@ def runInductionCertifyHead (inputsPath : System.FilePath) return 1 | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps + skipLogitDiff /-- Build and check induction certificates from a model binary. -/ def runInductionCertifyHeadModel (modelPath : System.FilePath) @@ -1213,7 +1221,8 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) + (skipLogitDiff : Bool) : IO UInt32 := do configureTiming timing? heartbeatMs? let splitCfg := @@ -1254,6 +1263,7 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) return 1 | Except.ok inputs => checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps + skipLogitDiff /-- Heuristic logit-diff direction derived from prompt tokens. -/ def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : @@ -1290,7 +1300,8 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) + (skipLogitDiff : Bool) : IO UInt32 := do configureTiming timing? heartbeatMs? 
let splitCfg := @@ -1345,7 +1356,7 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) return 1 | Except.ok inputs => checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? - minMargin maxEps + minMargin maxEps skipLogitDiff /-- Build head-output interval bounds from exact head inputs. -/ def runInductionHeadInterval (inputsPath : System.FilePath) diff --git a/scripts/certify_induction_head.py b/scripts/certify_induction_head.py index 671a77b..fa6a159 100644 --- a/scripts/certify_induction_head.py +++ b/scripts/certify_induction_head.py @@ -88,12 +88,18 @@ def main() -> int: parser.add_argument("--min-margin", type=str) parser.add_argument("--max-eps", type=str) parser.add_argument("--nfp-bin", help="Path to nfp binary") + parser.add_argument( + "--preset", + choices=["fast", "balanced", "tight"], + help="Split-budget preset for streamlined certify", + ) parser.add_argument("--timing", type=int) parser.add_argument("--heartbeat-ms", type=int) parser.add_argument("--split-budget-q", type=int) parser.add_argument("--split-budget-k", type=int) parser.add_argument("--split-budget-diff-base", type=int) parser.add_argument("--split-budget-diff-refined", type=int) + parser.add_argument("--skip-logit-diff", action="store_true") parser.add_argument( "--ensure-model", @@ -128,13 +134,28 @@ def main() -> int: print(f"error: model not found at {model_path}", file=sys.stderr) return 1 - subcmd = "certify_head_model_auto" - if args.nonvacuous: - subcmd = "certify_head_model_auto_nonvacuous" - - cmd = resolve_nfp_cmd(args.nfp_bin) + [ - "induction", - "advanced", + use_advanced = any( + val is not None + for val in ( + args.split_budget_q, + args.split_budget_k, + args.split_budget_diff_base, + args.split_budget_diff_refined, + ) + ) + if use_advanced: + subcmd = "certify_head_model_auto" + if args.nonvacuous: + subcmd = "certify_head_model_auto_nonvacuous" + else: + subcmd = "certify" + if args.nonvacuous: + subcmd = "certify_nonvacuous" + + cmd = 
resolve_nfp_cmd(args.nfp_bin) + ["induction"] + if use_advanced: + cmd.append("advanced") + cmd += [ subcmd, "--model", str(model_path), @@ -155,6 +176,8 @@ def main() -> int: cmd += ["--min-margin", args.min_margin] if args.max_eps is not None: cmd += ["--max-eps", args.max_eps] + if args.preset is not None and not use_advanced: + cmd += ["--preset", args.preset] if args.timing is not None: cmd += ["--timing", str(args.timing)] if args.heartbeat_ms is not None: @@ -167,6 +190,8 @@ def main() -> int: cmd += ["--split-budget-diff-base", str(args.split_budget_diff_base)] if args.split_budget_diff_refined is not None: cmd += ["--split-budget-diff-refined", str(args.split_budget_diff_refined)] + if args.skip_logit_diff: + cmd.append("--skip-logit-diff") proc = subprocess.run(cmd) return proc.returncode From 437c19bb0eab47f4c7dc205c3ea822661c2b414a Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 06:47:42 +0100 Subject: [PATCH 211/244] Add induction circuit specs and shifted-prev circuit cert --- Nfp/Circuit/Layers/Induction.lean | 1 + Nfp/Circuit/Layers/Induction/Basic.lean | 16 +++ Nfp/Circuit/Layers/Induction/Circuit.lean | 80 ++++++++++++ Nfp/Cli.lean | 136 ++++++++++++++++++--- Nfp/IO/InductionHead.lean | 1 + Nfp/IO/InductionHead/Basic.lean | 12 +- Nfp/IO/InductionHead/Circuit.lean | 80 ++++++++++++ Nfp/IO/InductionHead/Nonvacuous.lean | 8 +- Nfp/IO/NfptPure.lean | 100 +++++++++++---- Nfp/IO/Run/Basic.lean | 2 +- Nfp/Model.lean | 1 + Nfp/Model/InductionCircuit.lean | 60 +++++++++ Nfp/Sound/Induction/LogitDiff.lean | 18 ++- Nfp/Sound/Induction/Refine.lean | 41 +++++++ Nfp/Sound/Induction/RefineSound.lean | 18 ++- docs/induction_cert_audit.md | 7 ++ scripts/certify_induction_head.py | 3 + scripts/discover_gpt2_induction_targets.py | 28 ++++- scripts/scan_gpt2_induction_sound.py | 2 +- scripts/sweep_gpt2_induction_nonvacuous.py | 2 +- 20 files changed, 549 insertions(+), 67 deletions(-) create mode 100644 Nfp/Circuit/Layers/Induction/Circuit.lean 
create mode 100644 Nfp/IO/InductionHead/Circuit.lean create mode 100644 Nfp/Model/InductionCircuit.lean diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean index 934c3dc..a2c49ea 100644 --- a/Nfp/Circuit/Layers/Induction.lean +++ b/Nfp/Circuit/Layers/Induction.lean @@ -3,6 +3,7 @@ module public import Nfp.Circuit.Layers.Induction.Basic +public import Nfp.Circuit.Layers.Induction.Circuit /-! Induction-head layer wiring and helper lemmas. diff --git a/Nfp/Circuit/Layers/Induction/Basic.lean b/Nfp/Circuit/Layers/Induction/Basic.lean index 30ff41c..cc398db 100644 --- a/Nfp/Circuit/Layers/Induction/Basic.lean +++ b/Nfp/Circuit/Layers/Induction/Basic.lean @@ -60,12 +60,28 @@ def InductionSpec (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) (out vals : Fin (Nat.succ n) → Val) : Prop := ∀ q, q ≠ 0 → out q = vals (prev q) +/-- Unfolding lemma for `InductionSpec`. -/ +theorem InductionSpec_def (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : + InductionSpec (n := n) prev out vals = ∀ q, q ≠ 0 → out q = vals (prev q) := by + rfl + /-- Concrete `prev` map on `Fin (n + 1)` (with `0 ↦ 0`). -/ def prevIndex : Fin (Nat.succ n) → Fin (Nat.succ n) | ⟨0, _⟩ => 0 | ⟨Nat.succ k, hk⟩ => ⟨k, Nat.lt_trans (Nat.lt_of_succ_lt_succ hk) (Nat.lt_succ_self n)⟩ +/-- Previous-token head spec: copies the immediately preceding token. -/ +def PrevTokenSpec (out vals : Fin (Nat.succ n) → Val) : Prop := + InductionSpec (n := n) (prevIndex (n := n)) out vals + +/-- Unfolding lemma for `PrevTokenSpec`. 
-/ +theorem PrevTokenSpec_def (out vals : Fin (Nat.succ n) → Val) : + PrevTokenSpec (n := n) out vals = + InductionSpec (n := n) (prevIndex (n := n)) out vals := by + rfl + end Spec section ApproxSpec diff --git a/Nfp/Circuit/Layers/Induction/Circuit.lean b/Nfp/Circuit/Layers/Induction/Circuit.lean new file mode 100644 index 0000000..30bcce2 --- /dev/null +++ b/Nfp/Circuit/Layers/Induction/Circuit.lean @@ -0,0 +1,80 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Layers.Induction.Basic + +/-! +Circuit-level induction specs: previous-token head feeding an induction head. + +These are definitional wrappers that name the canonical two-head induction +mechanism used in the literature: a previous-token head writes the prior token +representation into the residual stream, and the induction head then copies the +appropriate continuation. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe v + +section Specs + +variable {Val : Type v} +variable {n : Nat} + +/-- +Two-head induction circuit spec. + +The previous-token head copies `vals (q-1)` into `prevOut`, and the induction +head copies from `prevOut (prev q)` into `indOut`. +-/ +def InductionCircuitSpec + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (prevOut indOut vals : Fin (Nat.succ n) → Val) : Prop := + PrevTokenSpec (n := n) prevOut vals ∧ + InductionSpec (n := n) prev indOut prevOut + +/-- Unfolding lemma for `InductionCircuitSpec`. -/ +theorem InductionCircuitSpec_def + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (prevOut indOut vals : Fin (Nat.succ n) → Val) : + InductionCircuitSpec (n := n) prev prevOut indOut vals = + (PrevTokenSpec (n := n) prevOut vals ∧ + InductionSpec (n := n) prev indOut prevOut) := by + rfl + +/-- +Circuit composition: the induction head copies the predecessor of the predecessor. + +This is the direct consequence of a previous-token head feeding the induction head. 
+-/ +theorem InductionCircuitSpec_compose + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (prevOut indOut vals : Fin (Nat.succ n) → Val) + (h : InductionCircuitSpec (n := n) prev prevOut indOut vals) : + ∀ q, q ≠ 0 → prev q ≠ 0 → + indOut q = vals (prevIndex (n := n) (prev q)) := by + intro q hq hprev + rcases h with ⟨hprevTok, hind⟩ + have hprevTok' : + ∀ q, q ≠ 0 → prevOut q = vals (prevIndex (n := n) q) := by + simpa [PrevTokenSpec_def, InductionSpec_def] using hprevTok + have hind' : ∀ q, q ≠ 0 → indOut q = prevOut (prev q) := by + simpa [InductionSpec_def] using hind + calc + indOut q = prevOut (prev q) := hind' q hq + _ = vals (prevIndex (n := n) (prev q)) := hprevTok' (prev q) hprev + +end Specs + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 4135af6..982ff57 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -66,6 +66,7 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : let layer? := (p.flag? "layer").map (·.as! Nat) let head? := (p.flag? "head").map (·.as! Nat) let period? := (p.flag? "period").map (·.as! Nat) + let prevShift := p.hasFlag "prev-shift" let directionStr? := (p.flag? "direction").map (·.as! String) let presetStr? := (p.flag? "preset").map (·.as! String) let minActive? := (p.flag? "min-active").map (·.as! Nat) @@ -94,8 +95,8 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : Except.ok direction? => match inputsPath?, modelPath? 
with | some inputsPath, none => - if layer?.isSome || head?.isSome || period?.isSome then - fail "--layer/--head/--period are only valid with --model" + if layer?.isSome || head?.isSome || period?.isSome || prevShift then + fail "--layer/--head/--period/--prev-shift are only valid with --model" else if direction?.isSome then fail "--direction is only valid with --model" else if requireNonvacuous && skipLogitDiff then @@ -124,12 +125,13 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : fail "--skip-logit-diff is not allowed with certify_nonvacuous" else if requireNonvacuous then IO.runInductionCertifyHeadModelNonvacuous modelPath layer' head' dirTarget - dirNegative period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + dirNegative period? prevShift minActive? minLogitDiffStr? minMarginStr? + maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? else IO.runInductionCertifyHeadModel modelPath layer' head' dirTarget dirNegative - period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + period? prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff @@ -138,11 +140,13 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : fail "--skip-logit-diff is not allowed with certify_nonvacuous" else if requireNonvacuous then IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer' head' period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? else - IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? 
+ IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? prevShift + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff | _, _ => @@ -164,6 +168,7 @@ private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do let layer? := (p.flag? "layer").map (·.as! Nat) let head? := (p.flag? "head").map (·.as! Nat) let period? := (p.flag? "period").map (·.as! Nat) + let prevShift := p.hasFlag "prev-shift" let directionStr? := (p.flag? "direction").map (·.as! String) let outPath? := (p.flag? "out").map (·.as! String) let zeroBased := p.hasFlag "zero-based" @@ -179,8 +184,8 @@ private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do | Except.ok direction? => match inputsPath?, modelPath? with | some inputsPath, none => - if layer?.isSome || head?.isSome || period?.isSome then - fail "--layer/--head/--period are only valid with --model" + if layer?.isSome || head?.isSome || period?.isSome || prevShift then + fail "--layer/--head/--period/--prev-shift are only valid with --model" else if direction?.isSome then fail "--direction is only valid with --model" else @@ -195,7 +200,7 @@ private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do | _, Except.error msg => fail msg | Except.ok layer', Except.ok head' => IO.runInductionHeadIntervalModel modelPath layer' head' dirTarget dirNegative - period? outPath? + period? prevShift outPath? | _, _, none => fail "--direction is required with --model (use \"target,negative\")" | _, _, _ => @@ -216,6 +221,7 @@ def inductionCertifySimpleCmd : Cmd := `[Cli| head : Nat; "Head index for the induction head (1-based, required with --model)." "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (model only; default: derive from tokens)." + "prev-shift"; "Use shifted prev/active (prev = q - period + 1) for model inputs." 
direction : String; "Optional logit-diff direction as \"target,negative\" (model only). \ When omitted with --model, direction is derived from tokens." preset : String; "Split-budget preset: fast | balanced | tight (default: balanced)." @@ -241,6 +247,7 @@ def inductionCertifyNonvacuousSimpleCmd : Cmd := `[Cli| head : Nat; "Head index for the induction head (1-based, required with --model)." "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (model only; default: derive from tokens)." + "prev-shift"; "Use shifted prev/active (prev = q - period + 1) for model inputs." direction : String; "Optional logit-diff direction as \"target,negative\" (model only). \ When omitted with --model, direction is derived from tokens." preset : String; "Split-budget preset: fast | balanced | tight (default: balanced)." @@ -506,6 +513,7 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do let layer := p.flag! "layer" |>.as! Nat let head := p.flag! "head" |>.as! Nat let period? := (p.flag? "period").map (·.as! Nat) + let prevShift := p.hasFlag "prev-shift" let dirTarget := p.flag! "direction-target" |>.as! Nat let dirNegative := p.flag! "direction-negative" |>.as! Nat let minActive? := (p.flag? "min-active").map (·.as! Nat) @@ -531,7 +539,7 @@ def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do return 2 | Except.ok layer', Except.ok head' => IO.runInductionCertifyHeadModel modelPath layer' head' dirTarget dirNegative period? - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ @@ -540,6 +548,7 @@ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do let layer := p.flag! "layer" |>.as! Nat let head := p.flag! 
"head" |>.as! Nat let period? := (p.flag? "period").map (·.as! Nat) + let prevShift := p.hasFlag "prev-shift" let dirTarget := p.flag! "direction-target" |>.as! Nat let dirNegative := p.flag! "direction-negative" |>.as! Nat let minActive? := (p.flag? "min-active").map (·.as! Nat) @@ -564,7 +573,8 @@ def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do return 2 | Except.ok layer', Except.ok head' => IO.runInductionCertifyHeadModelNonvacuous modelPath layer' head' dirTarget dirNegative - period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + period? prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? /-- `nfp induction certify_head_model_auto` subcommand. -/ @@ -573,6 +583,7 @@ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do let layer := p.flag! "layer" |>.as! Nat let head := p.flag! "head" |>.as! Nat let period? := (p.flag? "period").map (·.as! Nat) + let prevShift := p.hasFlag "prev-shift" let minActive? := (p.flag? "min-active").map (·.as! Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) @@ -595,7 +606,7 @@ def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do IO.eprintln s!"error: {msg}" return 2 | Except.ok layer', Except.ok head' => - IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? + IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff @@ -605,6 +616,7 @@ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do let layer := p.flag! "layer" |>.as! Nat let head := p.flag! "head" |>.as! Nat let period? := (p.flag? "period").map (·.as! 
Nat) + let prevShift := p.hasFlag "prev-shift" let minActive? := (p.flag? "min-active").map (·.as! Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) @@ -626,10 +638,65 @@ def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do IO.eprintln s!"error: {msg}" return 2 | Except.ok layer', Except.ok head' => - IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer' head' period? + IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer' head' period? prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? +/-! `nfp induction certify_circuit_model` subcommand. -/ +/-- CLI entrypoint for `nfp induction certify_circuit_model`. -/ +def runInductionCertifyCircuitModel (p : Parsed) : IO UInt32 := do + let modelPath := p.flag! "model" |>.as! String + let prevLayer := p.flag! "prev-layer" |>.as! Nat + let prevHead := p.flag! "prev-head" |>.as! Nat + let indLayer := p.flag! "ind-layer" |>.as! Nat + let indHead := p.flag! "ind-head" |>.as! Nat + let period? := (p.flag? "period").map (·.as! Nat) + let dirTarget := p.flag! "direction-target" |>.as! Nat + let dirNegative := p.flag! "direction-negative" |>.as! Nat + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let timing? := (p.flag? "timing").map (·.as! Nat) + let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) + let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) + let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) + let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) + let splitBudgetDiffRefined? := (p.flag? 
"split-budget-diff-refined").map (·.as! Nat) + let skipLogitDiff := p.hasFlag "skip-logit-diff" + let zeroBased := p.hasFlag "zero-based" + match period? with + | none => + IO.eprintln "error: --period is required for circuit certification" + return 2 + | some period => + if period = 0 then + IO.eprintln "error: --period must be positive for circuit certification" + return 2 + let prevLayerE := toZeroBased "prev-layer" prevLayer zeroBased + let prevHeadE := toZeroBased "prev-head" prevHead zeroBased + let indLayerE := toZeroBased "ind-layer" indLayer zeroBased + let indHeadE := toZeroBased "ind-head" indHead zeroBased + match prevLayerE, prevHeadE, indLayerE, indHeadE with + | Except.error msg, _, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok prevLayer', Except.ok prevHead', Except.ok indLayer', Except.ok indHead' => + IO.runInductionCertifyCircuitModel modelPath prevLayer' prevHead' indLayer' indHead' + dirTarget dirNegative period + minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? + splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? + skipLogitDiff + /-- `nfp induction certify_head_model` subcommand. -/ def inductionCertifyHeadModelCmd : Cmd := `[Cli| certify_head_model VIA runInductionCertifyHeadModel; @@ -640,6 +707,7 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| head : Nat; "Head index for the induction head (1-based)." "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." + "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." "direction-target" : Nat; "Target token id for logit-diff direction." 
"direction-negative" : Nat; "Negative token id for logit-diff direction." "min-active" : Nat; "Optional minimum number of active queries required \ @@ -669,6 +737,7 @@ def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| head : Nat; "Head index for the induction head (1-based)." "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." + "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." "direction-target" : Nat; "Target token id for logit-diff direction." "direction-negative" : Nat; "Negative token id for logit-diff direction." "min-active" : Nat; "Optional minimum number of active queries required \ @@ -698,6 +767,7 @@ def inductionCertifyHeadModelAutoCmd : Cmd := `[Cli| head : Nat; "Head index for the induction head (1-based)." "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." + "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." "min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ @@ -725,6 +795,7 @@ def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| head : Nat; "Head index for the induction head (1-based)." "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." + "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." "min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ @@ -741,6 +812,38 @@ def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| bounds (default: 12)." ] +/-- `nfp induction certify_circuit_model` subcommand. 
-/ +def inductionCertifyCircuitModelCmd : Cmd := `[Cli| + certify_circuit_model VIA runInductionCertifyCircuitModel; + "Check a two-head induction circuit by reading a model binary directly \ + (induction head uses shifted prev)." + FLAGS: + model : String; "Path to the NFP_BINARY_V1 model file." + "prev-layer" : Nat; "Layer index for the previous-token head (1-based)." + "prev-head" : Nat; "Head index for the previous-token head (1-based)." + "ind-layer" : Nat; "Layer index for the induction head (1-based)." + "ind-head" : Nat; "Head index for the induction head (1-based)." + "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." + period : Nat; "Prompt period override (required)." + "direction-target" : Nat; "Target token id for logit-diff direction." + "direction-negative" : Nat; "Negative token id for logit-diff direction." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal). Defaults to 0." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." + "skip-logit-diff"; "Skip logit-diff lower bound computation." + timing : Nat; "Emit timing output to stdout (0=off, 1=on)." + "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." + "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." + "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." + "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ + (default: 0)." + "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ + bounds (default: 12)." +] + /-- `nfp induction head_interval` subcommand. 
-/ def runInductionHeadInterval (p : Parsed) : IO UInt32 := do let inputsPath := p.flag! "inputs" |>.as! String @@ -762,6 +865,7 @@ def runInductionHeadIntervalModel (p : Parsed) : IO UInt32 := do let layer := p.flag! "layer" |>.as! Nat let head := p.flag! "head" |>.as! Nat let period? := (p.flag? "period").map (·.as! Nat) + let prevShift := p.hasFlag "prev-shift" let dirTarget := p.flag! "direction-target" |>.as! Nat let dirNegative := p.flag! "direction-negative" |>.as! Nat let outPath? := (p.flag? "out").map (·.as! String) @@ -777,7 +881,7 @@ def runInductionHeadIntervalModel (p : Parsed) : IO UInt32 := do return 2 | Except.ok layer', Except.ok head' => IO.runInductionHeadIntervalModel modelPath layer' head' dirTarget dirNegative period? - outPath? + prevShift outPath? /-- `nfp induction head_interval_model` subcommand. -/ def inductionHeadIntervalModelCmd : Cmd := `[Cli| @@ -789,6 +893,7 @@ def inductionHeadIntervalModelCmd : Cmd := `[Cli| head : Nat; "Head index for the induction head (1-based)." "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." period : Nat; "Optional prompt period override (default: derive from tokens)." + "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." "direction-target" : Nat; "Target token id for logit-diff direction." "direction-negative" : Nat; "Negative token id for logit-diff direction." out : String; "Optional path to write the residual-interval certificate." 
@@ -810,6 +915,7 @@ def inductionAdvancedCmd : Cmd := `[Cli| inductionCertifyHeadModelNonvacuousCmd; inductionCertifyHeadModelAutoCmd; inductionCertifyHeadModelAutoNonvacuousCmd; + inductionCertifyCircuitModelCmd; inductionHeadIntervalCmd; inductionHeadIntervalModelCmd ] diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index f29e2d9..1b9fd5b 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -3,6 +3,7 @@ module public import Nfp.IO.InductionHead.Basic +public import Nfp.IO.InductionHead.Circuit public import Nfp.IO.InductionHead.Nonvacuous /-! diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 8a522f4..7b076e0 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -1217,7 +1217,7 @@ def runInductionCertifyHead (inputsPath : System.FilePath) /-- Build and check induction certificates from a model binary. -/ def runInductionCertifyHeadModel (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) (timing? : Option Nat) (heartbeatMs? : Option Nat) @@ -1256,7 +1256,7 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) | Except.ok ⟨header, start⟩ => let inputsE ← timePure "read head inputs" (fun () => NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) + data start header layer head dirTarget dirNegative period? shiftPrev) match inputsE with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -1296,7 +1296,7 @@ def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : /-- Build and check induction certificates from a model binary, deriving direction tokens from the prompt sequence. 
-/ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) - (layer head : Nat) (period? : Option Nat) + (layer head : Nat) (period? : Option Nat) (shiftPrev : Bool) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) (timing? : Option Nat) (heartbeatMs? : Option Nat) @@ -1349,7 +1349,7 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) s!"info: direction-target={dirTarget} direction-negative={dirNegative}" let inputsE ← timePure "read head inputs" (fun () => NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) + data start header layer head dirTarget dirNegative period? shiftPrev) match inputsE with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -1371,7 +1371,7 @@ def runInductionHeadInterval (inputsPath : System.FilePath) /-- Build head-output interval bounds from a model binary. -/ def runInductionHeadIntervalModel (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) (outPath? : Option System.FilePath) : IO UInt32 := do let data ← IO.FS.readBinFile modelPath match NfptPure.parseHeader data with @@ -1381,7 +1381,7 @@ def runInductionHeadIntervalModel (modelPath : System.FilePath) | Except.ok ⟨header, start⟩ => match NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? + data start header layer head dirTarget dirNegative period? shiftPrev with | Except.error msg => IO.eprintln s!"error: {msg}" diff --git a/Nfp/IO/InductionHead/Circuit.lean b/Nfp/IO/InductionHead/Circuit.lean new file mode 100644 index 0000000..9576b5e --- /dev/null +++ b/Nfp/IO/InductionHead/Circuit.lean @@ -0,0 +1,80 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.IO.InductionHead.Basic + +/-! 
+IO helpers for induction-circuit checks (previous-token head + induction head). +-/ + +public section + +namespace Nfp + +namespace IO + +/-- Check a two-head induction circuit directly from a model binary. + +The induction head is certified with shifted `prev` (canonical circuit), while +the previous-token head uses the unshifted period-1 map. +-/ +def runInductionCertifyCircuitModel (modelPath : System.FilePath) + (prevLayer prevHead indLayer indHead dirTarget dirNegative : Nat) (period : Nat) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (timing? : Option Nat) (heartbeatMs? : Option Nat) + (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) + (skipLogitDiff : Bool) : + IO UInt32 := do + let prevCode ← + runInductionCertifyHeadModel + modelPath + prevLayer + prevHead + dirTarget + dirNegative + (some 1) + false + minActive? + minLogitDiffStr? + minMarginStr? + maxEpsStr? + timing? + heartbeatMs? + splitBudgetQ? + splitBudgetK? + splitBudgetDiffBase? + splitBudgetDiffRefined? + skipLogitDiff + if prevCode ≠ 0 then + return prevCode + let indCode ← + runInductionCertifyHeadModel + modelPath + indLayer + indHead + dirTarget + dirNegative + (some period) + true + minActive? + minLogitDiffStr? + minMarginStr? + maxEpsStr? + timing? + heartbeatMs? + splitBudgetQ? + splitBudgetK? + splitBudgetDiffBase? + splitBudgetDiffRefined? 
+ skipLogitDiff + if indCode ≠ 0 then + return indCode + IO.println + "ok: circuit head certificates built (prev-token head + shifted-prev induction head)" + return 0 + +end IO + +end Nfp diff --git a/Nfp/IO/InductionHead/Nonvacuous.lean b/Nfp/IO/InductionHead/Nonvacuous.lean index 3f0e96c..ae7d0f4 100644 --- a/Nfp/IO/InductionHead/Nonvacuous.lean +++ b/Nfp/IO/InductionHead/Nonvacuous.lean @@ -408,7 +408,7 @@ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) /-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) (timing? : Option Nat) (heartbeatMs? : Option Nat) @@ -445,7 +445,7 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) | Except.ok ⟨header, start⟩ => let inputsE ← timePure "read head inputs" (fun () => NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) + data start header layer head dirTarget dirNegative period? shiftPrev) match inputsE with | Except.error msg => IO.eprintln s!"error: {msg}" @@ -457,7 +457,7 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) /-- Build and check a strictly positive induction logit-diff bound from a model binary, deriving direction tokens from the prompt sequence. -/ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) - (layer head : Nat) (period? : Option Nat) + (layer head : Nat) (period? : Option Nat) (shiftPrev : Bool) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) (timing? : Option Nat) (heartbeatMs? 
: Option Nat) @@ -508,7 +508,7 @@ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) s!"info: direction-target={dirTarget} direction-negative={dirNegative}" let inputsE ← timePure "read head inputs" (fun () => NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period?) + data start header layer head dirTarget dirNegative period? shiftPrev) match inputsE with | Except.error msg => IO.eprintln s!"error: {msg}" diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean index f33ee8d..483e405 100644 --- a/Nfp/IO/NfptPure.lean +++ b/Nfp/IO/NfptPure.lean @@ -600,19 +600,35 @@ def buildInductionHeadInputs (h : NfptHeader) (scale : Rat) (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) (dirTarget dirNegative : Nat) (colTarget colNegative : Fin h.modelDim → Rat) - (period? : Option Nat) : + (period? : Option Nat) (shiftPrev : Bool) : Model.InductionHeadInputs h.seqLen h.modelDim h.headDim := let direction : Fin h.modelDim → Rat := fun i => colTarget i - colNegative i let directionSpec : Circuit.DirectionSpec := { target := dirTarget, negative := dirNegative } let active := match period? with - | some period => Model.activeOfPeriod (seq := h.seqLen) period - | none => Model.activeOfTokens (seq := h.seqLen) tokens + | some period => + if shiftPrev then + Model.activeOfPeriodShift (seq := h.seqLen) period + else + Model.activeOfPeriod (seq := h.seqLen) period + | none => + if shiftPrev then + Model.activeOfTokensShift (seq := h.seqLen) tokens + else + Model.activeOfTokens (seq := h.seqLen) tokens let prev := match period? 
with - | some period => Model.prevOfPeriod (seq := h.seqLen) period - | none => Model.prevOfTokens (seq := h.seqLen) tokens + | some period => + if shiftPrev then + Model.prevOfPeriodShift (seq := h.seqLen) period + else + Model.prevOfPeriod (seq := h.seqLen) period + | none => + if shiftPrev then + Model.prevOfTokensShift (seq := h.seqLen) tokens + else + Model.prevOfTokens (seq := h.seqLen) tokens { scale := scale active := active prev := prev @@ -641,18 +657,34 @@ private theorem buildInductionHeadInputs_def (h : NfptHeader) (scale : Rat) (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) (dirTarget dirNegative : Nat) (colTarget colNegative : Fin h.modelDim → Rat) - (period? : Option Nat) : + (period? : Option Nat) (shiftPrev : Bool) : buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative period? = + dirTarget dirNegative colTarget colNegative period? shiftPrev = { scale := scale active := match period? with - | some period => Model.activeOfPeriod (seq := h.seqLen) period - | none => Model.activeOfTokens (seq := h.seqLen) tokens + | some period => + if shiftPrev then + Model.activeOfPeriodShift (seq := h.seqLen) period + else + Model.activeOfPeriod (seq := h.seqLen) period + | none => + if shiftPrev then + Model.activeOfTokensShift (seq := h.seqLen) tokens + else + Model.activeOfTokens (seq := h.seqLen) tokens prev := match period? 
with - | some period => Model.prevOfPeriod (seq := h.seqLen) period - | none => Model.prevOfTokens (seq := h.seqLen) tokens + | some period => + if shiftPrev then + Model.prevOfPeriodShift (seq := h.seqLen) period + else + Model.prevOfPeriod (seq := h.seqLen) period + | none => + if shiftPrev then + Model.prevOfTokensShift (seq := h.seqLen) tokens + else + Model.prevOfTokens (seq := h.seqLen) tokens embed := embed lnEps := h.layerNormEps ln1Gamma := ln1Gamma @@ -678,10 +710,10 @@ theorem buildInductionHeadInputs_direction_def (h : NfptHeader) (scale : Rat) (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) (dirTarget dirNegative : Nat) (colTarget colNegative : Fin h.modelDim → Rat) - (period? : Option Nat) : + (period? : Option Nat) (shiftPrev : Bool) : let inputs := buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative period? + dirTarget dirNegative colTarget colNegative period? shiftPrev inputs.directionSpec = { target := dirTarget, negative := dirNegative } ∧ inputs.direction = fun i => colTarget i - colNegative i := by simp [buildInductionHeadInputs] @@ -694,18 +726,34 @@ theorem buildInductionHeadInputs_prev_active_def (h : NfptHeader) (scale : Rat) (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) (dirTarget dirNegative : Nat) (colTarget colNegative : Fin h.modelDim → Rat) - (period? : Option Nat) : + (period? : Option Nat) (shiftPrev : Bool) : let inputs := buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative period? + dirTarget dirNegative colTarget colNegative period? shiftPrev inputs.active = (match period? 
with - | some period => Model.activeOfPeriod (seq := h.seqLen) period - | none => Model.activeOfTokens (seq := h.seqLen) tokens) ∧ + | some period => + if shiftPrev then + Model.activeOfPeriodShift (seq := h.seqLen) period + else + Model.activeOfPeriod (seq := h.seqLen) period + | none => + if shiftPrev then + Model.activeOfTokensShift (seq := h.seqLen) tokens + else + Model.activeOfTokens (seq := h.seqLen) tokens) ∧ inputs.prev = (match period? with - | some period => Model.prevOfPeriod (seq := h.seqLen) period - | none => Model.prevOfTokens (seq := h.seqLen) tokens) := by + | some period => + if shiftPrev then + Model.prevOfPeriodShift (seq := h.seqLen) period + else + Model.prevOfPeriod (seq := h.seqLen) period + | none => + if shiftPrev then + Model.prevOfTokensShift (seq := h.seqLen) tokens + else + Model.prevOfTokens (seq := h.seqLen) tokens) := by constructor <;> rfl /-- Active queries pick the maximal matching prior token when `period? = none`. -/ @@ -718,10 +766,10 @@ theorem buildInductionHeadInputs_prev_spec_of_active (h : NfptHeader) (scale : R (colTarget colNegative : Fin h.modelDim → Rat) : ∀ {q}, q ∈ (buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative none).active → + dirTarget dirNegative colTarget colNegative none false).active → let p := (buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative none).prev q + dirTarget dirNegative colTarget colNegative none false).prev q p < q ∧ tokens p = tokens q ∧ ∀ k, k < q → tokens k = tokens q → k ≤ p := by intro q hq @@ -730,9 +778,13 @@ theorem buildInductionHeadInputs_prev_spec_of_active (h : NfptHeader) (scale : R have hspec := Model.prevOfTokens_spec_of_active (tokens := tokens) (q := q) hq' simpa [buildInductionHeadInputs] using hspec -/-- Read induction-head inputs directly from the model binary. 
-/ +/-- Read induction-head inputs directly from the model binary. + +`shiftPrev` selects between the unshifted prompt map (`prev = q - period`) and the +shifted map (`prev = q - period + 1`), with analogous token-derived versions. +-/ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) : + (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) : Except String (Model.InductionHeadInputs h.seqLen h.modelDim h.headDim) := do let scale ← scaleOfHeadDim h.headDim let tokens ← readTokens data start h @@ -743,7 +795,7 @@ def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) let colNegative ← readUnembedColumn data start h dirNegative pure <| buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative period? + dirTarget dirNegative colTarget colNegative period? shiftPrev end NfptPure diff --git a/Nfp/IO/Run/Basic.lean b/Nfp/IO/Run/Basic.lean index fccad3a..e2a3b69 100644 --- a/Nfp/IO/Run/Basic.lean +++ b/Nfp/IO/Run/Basic.lean @@ -717,7 +717,7 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) timePure "read head inputs" (fun () => NfptPure.readInductionHeadInputs data start header layer head - dirPos dirNeg period?) + dirPos dirNeg period? false) match inputsE with | Except.error msg => IO.eprintln s!"warning: {msg}" diff --git a/Nfp/Model.lean b/Nfp/Model.lean index 7a97665..cd4f980 100644 --- a/Nfp/Model.lean +++ b/Nfp/Model.lean @@ -5,6 +5,7 @@ module public import Nfp.Model.Gpt2 public import Nfp.Model.InductionHead public import Nfp.Model.InductionPrompt +public import Nfp.Model.InductionCircuit /-! Model-specific data containers for the NFP rewrite. 
diff --git a/Nfp/Model/InductionCircuit.lean b/Nfp/Model/InductionCircuit.lean new file mode 100644 index 0000000..0ddd03b --- /dev/null +++ b/Nfp/Model/InductionCircuit.lean @@ -0,0 +1,60 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Model.InductionPrompt + +/-! +Circuit-level induction prompt specifications. + +These wrappers name the *shifted* prev/active maps that correspond to the +canonical induction-head circuit (previous-token head feeding induction head). +They are definitional aliases over `InductionPrevSpec*Shift`, but make the +intended mechanistic interpretation explicit for downstream lemmas. +-/ + +public section + +namespace Nfp + +namespace Model + +/-- +Circuit-level shifted-prev spec for periodic prompts. + +This matches the canonical induction circuit: a previous-token head shifts +the match by one position, so the induction head attends to `q - period + 1`. +-/ +structure InductionCircuitSpecPeriodShift {seq dModel dHead : Nat} + (period : Nat) (inputs : InductionHeadInputs seq dModel dHead) : Prop where + /-- The underlying shifted-period prev/active spec. -/ + prev_spec : InductionPrevSpecPeriodShift (seq := seq) period inputs + +/-- +Circuit-level shifted-prev spec for token-based prompts. + +This corresponds to the canonical induction circuit with a previous-token head, +so the induction head uses the shifted-token map. +-/ +structure InductionCircuitSpecTokensShift {seq dModel dHead : Nat} + (tokens : Fin seq → Nat) (inputs : InductionHeadInputs seq dModel dHead) : Prop where + /-- The underlying shifted-token prev/active spec. -/ + prev_spec : InductionPrevSpecTokensShift (seq := seq) tokens inputs + +/-- +Lift a shifted-period circuit spec to the token-based spec for diagnostic prompts. 
+-/ +theorem InductionCircuitSpecTokensShift_of_diag {dModel dHead : Nat} {period : Nat} + (tokens : Fin (2 * period) → Nat) + (inputs : InductionHeadInputs (2 * period) dModel dHead) + (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) + (hspec : InductionCircuitSpecPeriodShift (seq := 2 * period) period inputs) : + InductionCircuitSpecTokensShift (seq := 2 * period) tokens inputs := by + refine ⟨?_,⟩ + exact + InductionPrevSpecTokensShift_of_diag (tokens := tokens) hdiag hper hspec.prev_spec + +end Model + +end Nfp diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index ae49847..2554f68 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -402,9 +402,14 @@ def logitDiffLowerBoundRefineOnDemand | some lbWeight => if lbWeight ≤ 0 then let valBudget := refineBudgetBoost refineBudget - let valKeys := loAtKeysAt inputs core q0 + let valCount := refineLowValueCount refineBudget + let valKeys := + loAtKeysAt inputs core q0 ∪ + lowValueKeysAt inputs core q0 valCount let valsLo := valsLoOverlay inputs core valBudget valKeys - match logitDiffLowerBoundFromCacheWithEpsVals c cache.epsAt valsLo with + let lbRefined? := + logitDiffLowerBoundFromCacheWithEpsVals c cache.epsAt valsLo + match lbRefined? with | some lb2 => some (max lbWeight lb2) | none => some lbWeight else @@ -445,9 +450,14 @@ theorem logitDiffLowerBoundRefineOnDemand_def | some lbWeight => if lbWeight ≤ 0 then let valBudget := refineBudgetBoost refineBudget - let valKeys := loAtKeysAt inputs core q0 + let valCount := refineLowValueCount refineBudget + let valKeys := + loAtKeysAt inputs core q0 ∪ + lowValueKeysAt inputs core q0 valCount let valsLo := valsLoOverlay inputs core valBudget valKeys - match logitDiffLowerBoundFromCacheWithEpsVals c cache.epsAt valsLo with + let lbRefined? := + logitDiffLowerBoundFromCacheWithEpsVals c cache.epsAt valsLo + match lbRefined? 
with | some lb2 => some (max lbWeight lb2) | none => some lbWeight else diff --git a/Nfp/Sound/Induction/Refine.lean b/Nfp/Sound/Induction/Refine.lean index 699fa5c..9f3b9ab 100644 --- a/Nfp/Sound/Induction/Refine.lean +++ b/Nfp/Sound/Induction/Refine.lean @@ -44,6 +44,15 @@ theorem refineTopWeightCount_def (budget : Nat) : refineTopWeightCount budget = min 8 (max 1 (2 * budget)) := by rfl +/-- Heuristic cap on the number of low-value keys to refine. -/ +def refineLowValueCount (budget : Nat) : Nat := + min 8 (max 1 (2 * budget)) + +/-- Unfolding lemma for `refineLowValueCount`. -/ +theorem refineLowValueCount_def (budget : Nat) : + refineLowValueCount budget = min 8 (max 1 (2 * budget)) := by + rfl + /-- Scale used for refined value bounds. -/ def valRefineScale (budget : Nat) : Nat := Bounds.sqrtLowerScale * refineBudgetBoost budget @@ -163,6 +172,38 @@ theorem topWeightKeysAt_def keys.foldr (fun k acc => insert k acc) ∅ := by rfl +/-- Low-value keys for a query (excluding `prev`), capped by `count`. -/ +def lowValueKeysAt + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) (count : Nat) : Finset (Fin seq) := + if count = 0 then + ∅ + else + let others := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + let valued : Array (Rat × Fin seq) := + others.toArray.map (fun k => (cache.valsLo k, k)) + let sorted := valued.qsort (fun a b => a.1 < b.1) + let keys := (sorted.toList.take count).map (fun p => p.2) + keys.foldr (fun k acc => insert k acc) ∅ + +/-- Unfolding lemma for `lowValueKeysAt`. 
-/ +theorem lowValueKeysAt_def + (inputs : Model.InductionHeadInputs seq dModel dHead) + (cache : InductionHeadCoreCache seq dModel dHead) + (q : Fin seq) (count : Nat) : + lowValueKeysAt inputs cache q count = + if count = 0 then + ∅ + else + let others := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) + let valued : Array (Rat × Fin seq) := + others.toArray.map (fun k => (cache.valsLo k, k)) + let sorted := valued.qsort (fun a b => a.1 < b.1) + let keys := (sorted.toList.take count).map (fun p => p.2) + keys.foldr (fun k acc => insert k acc) ∅ := by + rfl + /-- Refinement keys for a query, seeded by negative base gaps and the worst key. -/ def refineKeysAt (inputs : Model.InductionHeadInputs seq dModel dHead) diff --git a/Nfp/Sound/Induction/RefineSound.lean b/Nfp/Sound/Induction/RefineSound.lean index ccfd3b6..9a09962 100644 --- a/Nfp/Sound/Induction/RefineSound.lean +++ b/Nfp/Sound/Induction/RefineSound.lean @@ -648,7 +648,10 @@ theorem logitDiffLowerBoundRefineOnDemand_le logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec' with | none => let valBudget := refineBudgetBoost refineBudget - let valKeys := loAtKeysAt inputs cache q0 + let valCount := refineLowValueCount refineBudget + let valKeys := + loAtKeysAt inputs cache q0 ∪ + lowValueKeysAt inputs cache q0 valCount let valsLo := valsLoOverlay inputs cache valBudget valKeys cases hval : logitDiffLowerBoundFromCacheWithEpsVals c logitCache.epsAt valsLo with @@ -656,7 +659,7 @@ theorem logitDiffLowerBoundRefineOnDemand_le have hlb : lb = lb01 := by simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', - valBudget, valKeys, valsLo, hval] using hbound.symm + valBudget, valCount, valKeys, valsLo, hval] using hbound.symm have hbase := hbase_le (lb0 := lb0) h0 have hweight_overlay' : ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → @@ -686,7 +689,7 @@ theorem logitDiffLowerBoundRefineOnDemand_le have hlb : lb = 
max lb01 lb2 := by simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', - valBudget, valKeys, valsLo, hval] using hbound.symm + valBudget, valCount, valKeys, valsLo, hval] using hbound.symm have hbase := hbase_le (lb0 := lb0) h0 have hweight_overlay' : ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → @@ -801,7 +804,10 @@ theorem logitDiffLowerBoundRefineOnDemand_le let lbWeight : Rat := max lb01 lb2 by_cases hweight_nonpos : lbWeight ≤ 0 · let valBudget := refineBudgetBoost refineBudget - let valKeys := loAtKeysAt inputs cache q0 + let valCount := refineLowValueCount refineBudget + let valKeys := + loAtKeysAt inputs cache q0 ∪ + lowValueKeysAt inputs cache q0 valCount let valsLo := valsLoOverlay inputs cache valBudget valKeys cases hval : logitDiffLowerBoundFromCacheWithEpsVals c logitCache.epsAt valsLo with @@ -810,14 +816,14 @@ theorem logitDiffLowerBoundRefineOnDemand_le simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', lbWeight, hweight_nonpos, valBudget, valKeys, - valsLo, hval] using hbound.symm + valCount, valsLo, hval] using hbound.symm simpa [hlb, lbWeight] using hmax_weight | some lb3 => have hlb : lb = max lbWeight lb3 := by simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', lbWeight, hweight_nonpos, valBudget, valKeys, - valsLo, hval] using hbound.symm + valCount, valsLo, hval] using hbound.symm have hsound_cache : InductionHeadCertSound inputs cache.cert := by simpa [hcert] using hsound have hvalsLo : diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index df7c88f..d7ed707 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -66,6 +66,13 @@ Key assumptions and limitations: - For `certify_head_model` with `period? 
= none`, `prev`/`active` are derived from tokens and `prev` is the maximal prior match. For head-input files or when `period?` is set explicitly, `prev` remains a user-supplied input. +- The `--prev-shift` flag switches to the **shifted** `prev` map (`q - period + 1` + or the token-shifted analogue). This aligns the head-level certificate with + the canonical induction circuit (previous-token head → induction head), but + it is still a head-level approximation rather than a verified two-head + composition. +- The `certify_circuit_model` CLI uses shifted `prev` for the induction head + by default, while the previous-token head uses the unshifted period-1 map. - The certificate proves a logit-diff bound along the supplied `direction` vector. For model-derived inputs, this vector is the target-minus-negative unembedding column difference, but we still assume that the unembedding diff --git a/scripts/certify_induction_head.py b/scripts/certify_induction_head.py index fa6a159..b183df9 100644 --- a/scripts/certify_induction_head.py +++ b/scripts/certify_induction_head.py @@ -81,6 +81,7 @@ def main() -> int: parser.add_argument("--layer", type=int, required=True) parser.add_argument("--head", type=int, required=True) parser.add_argument("--period", type=int) + parser.add_argument("--prev-shift", action="store_true") parser.add_argument("--nonvacuous", action="store_true") parser.add_argument("--zero-based", action="store_true") parser.add_argument("--min-active", type=int) @@ -168,6 +169,8 @@ def main() -> int: cmd.append("--zero-based") if args.period is not None: cmd += ["--period", str(args.period)] + if args.prev_shift: + cmd.append("--prev-shift") if args.min_active is not None: cmd += ["--min-active", str(args.min_active)] if args.min_logit_diff is not None: diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py index e6946cf..df146ce 100644 --- a/scripts/discover_gpt2_induction_targets.py +++ 
b/scripts/discover_gpt2_induction_targets.py @@ -196,6 +196,18 @@ def build_prev_period(seq_len: int, period: int) -> Tuple[np.ndarray, np.ndarray return prev, active +def build_prev_period_shift(seq_len: int, period: int) -> Tuple[np.ndarray, np.ndarray]: + prev = np.zeros(seq_len, dtype=np.int64) + active = np.zeros(seq_len, dtype=bool) + if period <= 0: + return prev, active + idx = np.arange(seq_len) + mask = idx >= period + prev[mask] = idx[mask] - period + 1 + active[mask] = True + return prev, active + + def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float) -> np.ndarray: mean = x.mean(axis=1, keepdims=True) var = ((x - mean) ** 2).mean(axis=1, keepdims=True) @@ -592,9 +604,12 @@ def main() -> int: parser.add_argument("--period", type=int, help="Optional prompt period override") parser.add_argument( "--prev-mode", - choices=["bigram", "token", "period"], + choices=["bigram", "token", "period", "period_shift"], default="bigram", - help="Choose prev/active construction (default: bigram prefix match).", + help=( + "Choose prev/active construction (default: bigram prefix match). " + "period_shift uses q-period+1." 
+ ), ) parser.add_argument( "--stripe-period", @@ -654,14 +669,17 @@ def main() -> int: else: embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) - if args.prev_mode != "period" and args.period is not None: + if args.prev_mode not in {"period", "period_shift"} and args.period is not None: raise SystemExit("--period is incompatible with --prev-mode=token/bigram") - if args.prev_mode == "period" and args.period is None: - raise SystemExit("--prev-mode=period requires --period") + if args.prev_mode in {"period", "period_shift"} and args.period is None: + raise SystemExit("--prev-mode=period/period_shift requires --period") if args.prev_mode == "period": period = int(args.period) prev, active_mask = build_prev_period(seq_len, period) + elif args.prev_mode == "period_shift": + period = int(args.period) + prev, active_mask = build_prev_period_shift(seq_len, period) elif args.prev_mode == "bigram": prev, active_mask = build_prev_bigram(tokens) else: diff --git a/scripts/scan_gpt2_induction_sound.py b/scripts/scan_gpt2_induction_sound.py index bc42271..861af47 100755 --- a/scripts/scan_gpt2_induction_sound.py +++ b/scripts/scan_gpt2_induction_sound.py @@ -168,7 +168,7 @@ def main() -> int: parser.add_argument("--period", type=int) parser.add_argument( "--prev-mode", - choices=["bigram", "token", "period"], + choices=["bigram", "token", "period", "period_shift"], default="bigram", help="Choose prev/active construction (forwarded to discovery).", ) diff --git a/scripts/sweep_gpt2_induction_nonvacuous.py b/scripts/sweep_gpt2_induction_nonvacuous.py index aad0723..7e91a33 100644 --- a/scripts/sweep_gpt2_induction_nonvacuous.py +++ b/scripts/sweep_gpt2_induction_nonvacuous.py @@ -239,7 +239,7 @@ def main() -> int: help="Use pattern length as the period override") parser.add_argument( "--prev-mode", - choices=["bigram", "token", "period"], + choices=["bigram", "token", "period", "period_shift"], default="bigram", help="Choose prev/active construction 
(forwarded to discovery).", ) From dd154971766b2fe45947e2ff095ddb7e44e2a6fa Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 07:13:43 +0100 Subject: [PATCH 212/244] Add explicit induction head cert checker --- Nfp/Circuit/Cert.lean | 1 + Nfp/Circuit/Cert/InductionHead.lean | 310 ++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+) create mode 100644 Nfp/Circuit/Cert/InductionHead.lean diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index b71e3c0..b6f1bc4 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -4,6 +4,7 @@ module public import Nfp.Circuit.Cert.Basic public import Nfp.Circuit.Cert.DownstreamLinear +public import Nfp.Circuit.Cert.InductionHead public import Nfp.Circuit.Cert.LogitDiff public import Nfp.Circuit.Cert.ResidualBound public import Nfp.Circuit.Cert.ResidualInterval diff --git a/Nfp/Circuit/Cert/InductionHead.lean b/Nfp/Circuit/Cert/InductionHead.lean new file mode 100644 index 0000000..301b983 --- /dev/null +++ b/Nfp/Circuit/Cert/InductionHead.lean @@ -0,0 +1,310 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.Circuit.Cert.ValueRange +public import Nfp.Circuit.Layers.Induction + +/-! +Induction-head certificates with explicit scores, weights, and value bounds. +-/ + +public section + +namespace Nfp + +namespace Circuit + +open scoped BigOperators + +variable {seq : Nat} + +/-- Certificate payload for value-interval bounds (Rat-valued). -/ +structure ValueIntervalCert (seq : Nat) where + /-- Lower bound for values. -/ + lo : Rat + /-- Upper bound for values. -/ + hi : Rat + /-- Lower bounds on per-key values. -/ + valsLo : Fin seq → Rat + /-- Upper bounds on per-key values. -/ + valsHi : Fin seq → Rat + /-- Exact per-key values. 
-/ + vals : Fin seq → Rat + /-- Optional logit-diff direction metadata (ignored by the checker). -/ + direction : Option DirectionSpec + +/-- Internal consistency predicate for value-interval certificates. -/ +structure ValueIntervalCertBounds {seq : Nat} (c : ValueIntervalCert seq) : Prop where + /-- Interval endpoints are ordered. -/ + lo_le_hi : c.lo ≤ c.hi + /-- `lo` is below every lower bound. -/ + lo_le_valsLo : ∀ k, c.lo ≤ c.valsLo k + /-- Lower bounds are below the values. -/ + valsLo_le_vals : ∀ k, c.valsLo k ≤ c.vals k + /-- Values are below the upper bounds. -/ + vals_le_valsHi : ∀ k, c.vals k ≤ c.valsHi k + /-- Upper bounds are below `hi`. -/ + valsHi_le_hi : ∀ k, c.valsHi k ≤ c.hi + +/-- Boolean checker for value-interval certificates. -/ +def checkValueIntervalCert [NeZero seq] (c : ValueIntervalCert seq) : Bool := + decide (c.lo ≤ c.hi) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (c.lo ≤ c.valsLo k) && + decide (c.valsLo k ≤ c.vals k) && + decide (c.vals k ≤ c.valsHi k) && + decide (c.valsHi k ≤ c.hi)) + +/-- `checkValueIntervalCert` is sound for `ValueIntervalCertBounds`. 
-/ +theorem checkValueIntervalCert_sound [NeZero seq] (c : ValueIntervalCert seq) : + checkValueIntervalCert c = true → ValueIntervalCertBounds c := by + classical + intro hcheck + have hcheck' : + decide (c.lo ≤ c.hi) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (c.lo ≤ c.valsLo k) && + decide (c.valsLo k ≤ c.vals k) && + decide (c.vals k ≤ c.valsHi k) && + decide (c.valsHi k ≤ c.hi)) = true := by + simpa [checkValueIntervalCert, Bool.and_eq_true] using hcheck + rcases hcheck' with ⟨hlohi, hall⟩ + have hlohi' : c.lo ≤ c.hi := by + simpa [decide_eq_true_iff] using hlohi + have hall' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hall + have hbounds : + ∀ k, c.lo ≤ c.valsLo k ∧ c.valsLo k ≤ c.vals k ∧ + c.vals k ≤ c.valsHi k ∧ c.valsHi k ≤ c.hi := by + intro k + have hk := hall' k (by simp) + have hk' : + decide (c.lo ≤ c.valsLo k) = true ∧ + decide (c.valsLo k ≤ c.vals k) = true ∧ + decide (c.vals k ≤ c.valsHi k) = true ∧ + decide (c.valsHi k ≤ c.hi) = true := by + simpa [Bool.and_eq_true, and_assoc] using hk + rcases hk' with ⟨hlo, hloVals, hvalsHi, hhi⟩ + refine ⟨?_, ?_, ?_, ?_⟩ + · simpa [decide_eq_true_iff] using hlo + · simpa [decide_eq_true_iff] using hloVals + · simpa [decide_eq_true_iff] using hvalsHi + · simpa [decide_eq_true_iff] using hhi + refine + { lo_le_hi := hlohi' + lo_le_valsLo := fun k => (hbounds k).1 + valsLo_le_vals := fun k => (hbounds k).2.1 + vals_le_valsHi := fun k => (hbounds k).2.2.1 + valsHi_le_hi := fun k => (hbounds k).2.2.2 } + +/-- Certificate payload for induction-head bounds (Rat-valued). -/ +structure InductionHeadCert (seq : Nat) where + /-- Weight tolerance. -/ + eps : Rat + /-- Per-query weight tolerance. -/ + epsAt : Fin seq → Rat + /-- Per-key weight bounds derived from score gaps. -/ + weightBoundAt : Fin seq → Fin seq → Rat + /-- Score margin used to justify weight bounds. -/ + margin : Rat + /-- Active queries for which bounds are checked. 
-/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Score matrix entries. -/ + scores : Fin seq → Fin seq → Rat + /-- Attention weight entries. -/ + weights : Fin seq → Fin seq → Rat + /-- Value-interval certificate for direction values. -/ + values : ValueIntervalCert seq + +/-- View an induction certificate as a softmax-margin certificate. -/ +def InductionHeadCert.softmaxMargin (c : InductionHeadCert seq) : SoftmaxMarginCert seq := + { eps := c.eps + margin := c.margin + active := c.active + prev := c.prev + scores := c.scores + weights := c.weights } + +private def weightsOkAt [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (epsAt : Fin seq → Rat) (q : Fin seq) : Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ weights q k) && + (if k = prev q then + true + else + decide (weights q k ≤ epsAt q))) + +private def checkOneHotAt [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (epsAt : Fin seq → Rat) (q : Fin seq) : Bool := + weightsOkAt prev weights epsAt q && + decide (1 ≤ weights q (prev q) + epsAt q) && + decide ((∑ k, weights q k) = 1) + +private def checkWeightBoundsAt [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (weightBoundAt : Fin seq → Fin seq → Rat) (q : Fin seq) : + Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = prev q then + true + else + decide (weights q k ≤ weightBoundAt q k)) + +/-- `checkOneHotAt` yields per-query approximate one-hot bounds. 
-/ +private theorem checkOneHotAt_sound [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (epsAt : Fin seq → Rat) (q : Fin seq) : + checkOneHotAt prev weights epsAt q = true → + Layers.OneHotApproxBoundsOnActive (Val := Rat) (epsAt q : Rat) + (fun q' => q' = q) prev weights := by + classical + intro hOneHot + have hOneHot' : + weightsOkAt prev weights epsAt q = true ∧ + decide (1 ≤ weights q (prev q) + epsAt q) = true ∧ + decide ((∑ k, weights q k) = 1) = true := by + simpa [checkOneHotAt, Bool.and_eq_true, and_assoc] using hOneHot + rcases hOneHot' with ⟨hweights, hprev, hsum⟩ + have hweights' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hweights + refine + { nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q' hq' k + cases hq' + have hk := hweights' k (by simp) + have hk' : + decide (0 ≤ weights q k) = true ∧ + (if k = prev q then + true + else + decide (weights q k ≤ epsAt q)) = true := by + simpa [Bool.and_eq_true] using hk + simpa [decide_eq_true_iff] using hk'.1 + · intro q' hq' + cases hq' + simpa [decide_eq_true_iff] using hsum + · intro q' hq' + cases hq' + simpa [decide_eq_true_iff] using hprev + · intro q' hq' k hk + cases hq' + have hk' := hweights' k (by simp) + have hk'' : + decide (0 ≤ weights q k) = true ∧ + (if k = prev q then + true + else + decide (weights q k ≤ epsAt q)) = true := by + simpa [Bool.and_eq_true] using hk' + have hother : + decide (weights q k ≤ epsAt q) = true := by + simpa [hk] using hk''.2 + simpa [decide_eq_true_iff] using hother + +/-- `checkWeightBoundsAt` yields per-key upper bounds on non-`prev` weights. 
-/ +private theorem checkWeightBoundsAt_sound [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (weightBoundAt : Fin seq → Fin seq → Rat) (q : Fin seq) : + checkWeightBoundsAt prev weights weightBoundAt q = true → + ∀ k, k ≠ prev q → weights q k ≤ weightBoundAt q k := by + classical + intro hweights k hk + have hweights' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hweights + have hk' := hweights' k (by simp) + have hk'' : decide (weights q k ≤ weightBoundAt q k) = true := by + simpa [hk] using hk' + simpa [decide_eq_true_iff] using hk'' + +/-- Boolean checker for induction-head certificates. -/ +def checkInductionHeadCert [NeZero seq] (c : InductionHeadCert seq) : Bool := + checkSoftmaxMarginCert c.softmaxMargin && + finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + if q ∈ c.active then + checkOneHotAt c.prev c.weights c.epsAt q && + checkWeightBoundsAt c.prev c.weights c.weightBoundAt q + else + true) && + checkValueIntervalCert c.values + +/-- Soundness predicate for induction-head certificates. -/ +structure InductionHeadCertBounds [NeZero seq] (c : InductionHeadCert seq) : Prop where + /-- Softmax-margin bounds on active queries. -/ + softmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Rat) c.eps c.margin (fun q => q ∈ c.active) + c.prev c.scores c.weights + /-- Per-query one-hot bounds for the weights. -/ + oneHot_bounds_at : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Rat) (c.epsAt q : Rat) + (fun q' => q' = q) c.prev c.weights + /-- Per-key weight bounds for non-`prev` keys. -/ + weight_bounds_at : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + c.weights q k ≤ c.weightBoundAt q k + /-- Value-interval bounds are internally consistent. -/ + value_bounds : ValueIntervalCertBounds c.values + +/-- `checkInductionHeadCert` is sound for `InductionHeadCertBounds`. 
-/ +theorem checkInductionHeadCert_sound [NeZero seq] (c : InductionHeadCert seq) : + checkInductionHeadCert c = true → InductionHeadCertBounds c := by + classical + intro hcheck + have hsplit : + checkSoftmaxMarginCert c.softmaxMargin = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + if q ∈ c.active then + checkOneHotAt c.prev c.weights c.epsAt q && + checkWeightBoundsAt c.prev c.weights c.weightBoundAt q + else + true) = true ∧ + checkValueIntervalCert c.values = true := by + simpa [checkInductionHeadCert, Bool.and_eq_true, and_assoc] using hcheck + rcases hsplit with ⟨hsoftmax, hactive, hvalues⟩ + have hsoftmax' : + Layers.SoftmaxMarginBoundsOn (Val := Rat) c.eps c.margin (fun q => q ∈ c.active) + c.prev c.scores c.weights := + checkSoftmaxMarginCert_sound c.softmaxMargin hsoftmax + have hactive' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hactive + have honehot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Rat) (c.epsAt q : Rat) + (fun q' => q' = q) c.prev c.weights := by + intro q hq + have hq' := hactive' q (by simp) + have hq'' : + checkOneHotAt c.prev c.weights c.epsAt q = true ∧ + checkWeightBoundsAt c.prev c.weights c.weightBoundAt q = true := by + simpa [hq, Bool.and_eq_true] using hq' + have hOneHot : checkOneHotAt c.prev c.weights c.epsAt q = true := hq''.1 + exact checkOneHotAt_sound c.prev c.weights c.epsAt q hOneHot + have hweightBounds : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + c.weights q k ≤ c.weightBoundAt q k := by + intro q hq k hk + have hq' := hactive' q (by simp) + have hq'' : + checkOneHotAt c.prev c.weights c.epsAt q = true ∧ + checkWeightBoundsAt c.prev c.weights c.weightBoundAt q = true := by + simpa [hq, Bool.and_eq_true] using hq' + have hweights : checkWeightBoundsAt c.prev c.weights c.weightBoundAt q = true := hq''.2 + exact checkWeightBoundsAt_sound c.prev c.weights c.weightBoundAt q hweights k hk + have hvals : ValueIntervalCertBounds c.values := + 
checkValueIntervalCert_sound c.values hvalues + exact + { softmax_bounds := hsoftmax' + oneHot_bounds_at := honehot + weight_bounds_at := hweightBounds + value_bounds := hvals } + +end Circuit + +end Nfp From 1838854a24f5ce7480766c39b1bc6208b4ca65bc Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 07:21:51 +0100 Subject: [PATCH 213/244] Add CLI check for explicit induction certs --- Nfp/Cli.lean | 26 +- Nfp/IO/InductionHead.lean | 1 + Nfp/IO/InductionHead/Cert.lean | 420 +++++++++++++++++++++++++++++++++ 3 files changed, 446 insertions(+), 1 deletion(-) create mode 100644 Nfp/IO/InductionHead/Cert.lean diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 982ff57..b24e1f4 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -859,6 +859,29 @@ def inductionHeadIntervalCmd : Cmd := `[Cli| out : String; "Optional path to write the residual-interval certificate." ] +/-- `nfp induction head_cert_check` subcommand. -/ +def runInductionHeadCertCheck (p : Parsed) : IO UInt32 := do + let certPath := p.flag! "cert" |>.as! String + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? + +/-- `nfp induction head_cert_check` subcommand. -/ +def inductionHeadCertCheckCmd : Cmd := `[Cli| + head_cert_check VIA runInductionHeadCertCheck; + "Check an explicit induction-head certificate." + FLAGS: + cert : String; "Path to the induction-head certificate file." + "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal; defaults to 0 when direction is set)." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." 
+ "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + /-- `nfp induction head_interval_model` subcommand. -/ def runInductionHeadIntervalModel (p : Parsed) : IO UInt32 := do let modelPath := p.flag! "model" |>.as! String @@ -917,7 +940,8 @@ def inductionAdvancedCmd : Cmd := `[Cli| inductionCertifyHeadModelAutoNonvacuousCmd; inductionCertifyCircuitModelCmd; inductionHeadIntervalCmd; - inductionHeadIntervalModelCmd + inductionHeadIntervalModelCmd; + inductionHeadCertCheckCmd ] /-- Induction-head subcommands. -/ diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index 1b9fd5b..e718a73 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -3,6 +3,7 @@ module public import Nfp.IO.InductionHead.Basic +public import Nfp.IO.InductionHead.Cert public import Nfp.IO.InductionHead.Circuit public import Nfp.IO.InductionHead.Nonvacuous diff --git a/Nfp/IO/InductionHead/Cert.lean b/Nfp/IO/InductionHead/Cert.lean new file mode 100644 index 0000000..36dc961 --- /dev/null +++ b/Nfp/IO/InductionHead/Cert.lean @@ -0,0 +1,420 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Insert +public import Nfp.Circuit.Cert.InductionHead +public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.IO.Pure.Basic +public import Nfp.IO.Util + +/-! +Untrusted parsing and checking for explicit induction-head certificates. +-/ + +public section + +namespace Nfp + +namespace IO + +open Nfp.Circuit +open Nfp.IO.Pure + +namespace InductionHeadCert + +/-- State for parsing induction-head certificates. -/ +structure ParseState (seq : Nat) where + /-- Optional epsilon bound. -/ + eps : Option Rat + /-- Optional margin bound. -/ + margin : Option Rat + /-- Active query set. -/ + active : Finset (Fin seq) + /-- Whether any active entries were parsed. -/ + activeSeen : Bool + /-- Optional predecessor pointer per query. 
-/ + prev : Array (Option (Fin seq)) + /-- Optional score matrix entries. -/ + scores : Array (Array (Option Rat)) + /-- Optional weight matrix entries. -/ + weights : Array (Array (Option Rat)) + /-- Optional per-query epsilon bounds. -/ + epsAt : Array (Option Rat) + /-- Optional per-key weight bounds. -/ + weightBoundAt : Array (Array (Option Rat)) + /-- Optional lower bound for values. -/ + lo : Option Rat + /-- Optional upper bound for values. -/ + hi : Option Rat + /-- Optional per-key lower bounds. -/ + valsLo : Array (Option Rat) + /-- Optional per-key upper bounds. -/ + valsHi : Array (Option Rat) + /-- Optional per-key exact values. -/ + vals : Array (Option Rat) + /-- Optional direction target index. -/ + directionTarget : Option Nat + /-- Optional direction negative index. -/ + directionNegative : Option Nat + +/-- Initialize a parse state. -/ +def initState (seq : Nat) : ParseState seq := + let row : Array (Option Rat) := Array.replicate seq none + { eps := none + margin := none + active := ∅ + activeSeen := false + prev := Array.replicate seq none + scores := Array.replicate seq row + weights := Array.replicate seq row + epsAt := Array.replicate seq none + weightBoundAt := Array.replicate seq row + lo := none + hi := none + valsLo := Array.replicate seq none + valsHi := Array.replicate seq none + vals := Array.replicate seq none + directionTarget := none + directionNegative := none } + +private def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : + Except String (ParseState seq) := do + if hq : q < seq then + let qFin : Fin seq := ⟨q, hq⟩ + if qFin ∈ st.active then + throw s!"duplicate active entry for q={q}" + else + return { st with active := insert qFin st.active, activeSeen := true } + else + throw s!"active index out of range: q={q}" + +private def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : + Except String (ParseState seq) := do + if q < seq then + if hk : k < seq then + let kFin : Fin seq := ⟨k, hk⟩ + match st.prev[q]! 
with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + let prev' := st.prev.set! q (some kFin) + return { st with prev := prev' } + else + throw s!"prev index out of range: k={k}" + else + throw s!"prev index out of range: q={q}" + +private def setVecEntry {seq : Nat} (arr : Array (Option Rat)) (idx : Nat) (v : Rat) : + Except String (Array (Option Rat)) := do + if idx < seq then + match arr[idx]! with + | some _ => + throw s!"duplicate entry for k={idx}" + | none => + return arr.set! idx (some v) + else + throw s!"index out of range: k={idx}" + +private def setMatrixEntry {seq : Nat} (mat : Array (Array (Option Rat))) + (q k : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do + if q < seq then + if k < seq then + let row := mat[q]! + match row[k]! with + | some _ => + throw s!"duplicate matrix entry at ({q}, {k})" + | none => + let row' := row.set! k (some v) + return mat.set! q row' + else + throw s!"index out of range: k={k}" + else + throw s!"index out of range: q={q}" + +/-- Parse a tokenized line into the parse state. 
-/ +def parseLine {seq : Nat} (st : ParseState seq) (tokens : List String) : + Except String (ParseState seq) := do + match tokens with + | ["eps", val] => + if st.eps.isSome then + throw "duplicate eps entry" + else + return { st with eps := some (← parseRat val) } + | ["margin", val] => + if st.margin.isSome then + throw "duplicate margin entry" + else + return { st with margin := some (← parseRat val) } + | ["active", q] => + setActive st (← parseNat q) + | ["prev", q, k] => + setPrev st (← parseNat q) (← parseNat k) + | ["score", q, k, val] => + let mat ← setMatrixEntry (seq := seq) st.scores (← parseNat q) (← parseNat k) + (← parseRat val) + return { st with scores := mat } + | ["weight", q, k, val] => + let mat ← setMatrixEntry (seq := seq) st.weights (← parseNat q) (← parseNat k) + (← parseRat val) + return { st with weights := mat } + | ["eps-at", q, val] => + let arr ← setVecEntry (seq := seq) st.epsAt (← parseNat q) (← parseRat val) + return { st with epsAt := arr } + | ["weight-bound", q, k, val] => + let mat ← setMatrixEntry (seq := seq) st.weightBoundAt (← parseNat q) (← parseNat k) + (← parseRat val) + return { st with weightBoundAt := mat } + | ["lo", val] => + if st.lo.isSome then + throw "duplicate lo entry" + else + return { st with lo := some (← parseRat val) } + | ["hi", val] => + if st.hi.isSome then + throw "duplicate hi entry" + else + return { st with hi := some (← parseRat val) } + | ["val", k, val] => + let arr ← setVecEntry (seq := seq) st.vals (← parseNat k) (← parseRat val) + return { st with vals := arr } + | ["val-lo", k, val] => + let arr ← setVecEntry (seq := seq) st.valsLo (← parseNat k) (← parseRat val) + return { st with valsLo := arr } + | ["val-hi", k, val] => + let arr ← setVecEntry (seq := seq) st.valsHi (← parseNat k) (← parseRat val) + return { st with valsHi := arr } + | ["direction-target", tok] => + if st.directionTarget.isSome then + throw "duplicate direction-target entry" + else + return { st with directionTarget := 
some (← parseNat tok) } + | ["direction-negative", tok] => + if st.directionNegative.isSome then + throw "duplicate direction-negative entry" + else + return { st with directionNegative := some (← parseNat tok) } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +/-- Extract the `seq` header from tokenized lines. -/ +def parseSeq (tokens : List (List String)) : Except String Nat := do + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + match seq? with + | some v => pure v + | none => throw "missing seq entry" + +private def finalizeState {seq : Nat} (hpos : 0 < seq) (st : ParseState seq) : + Except String (Circuit.InductionHeadCert seq) := do + let eps ← + match st.eps with + | some v => pure v + | none => throw "missing eps entry" + let margin ← + match st.margin with + | some v => pure v + | none => throw "missing margin entry" + let lo ← + match st.lo with + | some v => pure v + | none => throw "missing lo entry" + let hi ← + match st.hi with + | some v => pure v + | none => throw "missing hi entry" + if !st.prev.all Option.isSome then + throw "missing prev entries" + if !st.scores.all (fun row => row.all Option.isSome) then + throw "missing score entries" + if !st.weights.all (fun row => row.all Option.isSome) then + throw "missing weight entries" + if !st.epsAt.all Option.isSome then + throw "missing eps-at entries" + if !st.weightBoundAt.all (fun row => row.all Option.isSome) then + throw "missing weight-bound entries" + if !st.valsLo.all Option.isSome then + throw "missing val-lo entries" + if !st.valsHi.all Option.isSome then + throw "missing val-hi entries" + if !st.vals.all Option.isSome then + throw "missing val entries" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev[q.1]!).getD defaultPrev + let scoresFun : Fin seq → Fin seq → Rat := 
fun q k => + let row := st.scores[q.1]! + (row[k.1]!).getD 0 + let weightsFun : Fin seq → Fin seq → Rat := fun q k => + let row := st.weights[q.1]! + (row[k.1]!).getD 0 + let epsAtFun : Fin seq → Rat := fun q => + (st.epsAt[q.1]!).getD 0 + let weightBoundAtFun : Fin seq → Fin seq → Rat := fun q k => + let row := st.weightBoundAt[q.1]! + (row[k.1]!).getD 0 + let valsLoFun : Fin seq → Rat := fun k => + (st.valsLo[k.1]!).getD 0 + let valsHiFun : Fin seq → Rat := fun k => + (st.valsHi[k.1]!).getD 0 + let valsFun : Fin seq → Rat := fun k => + (st.vals[k.1]!).getD 0 + let direction ← + match st.directionTarget, st.directionNegative with + | none, none => pure none + | some target, some negative => + pure (some { target := target, negative := negative }) + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + let values : Circuit.ValueIntervalCert seq := + { lo := lo + hi := hi + valsLo := valsLoFun + valsHi := valsHiFun + vals := valsFun + direction := direction } + let active := + if st.activeSeen then + st.active + else + (Finset.univ : Finset (Fin seq)).erase defaultPrev + return + { eps := eps + epsAt := epsAtFun + weightBoundAt := weightBoundAtFun + margin := margin + active := active + prev := prevFun + scores := scoresFun + weights := weightsFun + values := values } + +end InductionHeadCert + +/-- Parse an explicit induction-head certificate from a text payload. 
-/ +def parseInductionHeadCert (input : String) : + Except String (Sigma Circuit.InductionHeadCert) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← InductionHeadCert.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let hpos : 0 < seq := Nat.succ_pos n + let st0 : InductionHeadCert.ParseState seq := InductionHeadCert.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => InductionHeadCert.parseLine st t) st0 + let cert ← InductionHeadCert.finalizeState hpos st + return ⟨seq, cert⟩ + +/-- Load an induction-head certificate from disk. -/ +def loadInductionHeadCert (path : System.FilePath) : + IO (Except String (Sigma Circuit.InductionHeadCert)) := do + let data ← IO.FS.readFile path + return parseInductionHeadCert data + +private def ratToString (x : Rat) : String := + toString x + +/-- Check an explicit induction-head certificate from disk. -/ +def runInductionHeadCertCheck (certPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsed ← loadInductionHeadCert certPath + match parsed with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + let ok := Circuit.checkInductionHeadCert cert + if !ok then + IO.eprintln "error: induction-head certificate rejected" + return 2 + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {ratToString cert.margin} \ + below minimum {ratToString minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {ratToString cert.eps} \ + above maximum {ratToString maxEps}" + return 2 + let effectiveMinLogitDiff := + match minLogitDiff?, cert.values.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match effectiveMinLogitDiff with + | none => + IO.println + s!"ok: induction head certificate checked \ + (seq={seq}, active={activeCount}, \ + margin={ratToString cert.margin}, eps={ratToString cert.eps})" + return 0 + | some minLogitDiff => + let logitDiffLB? := + Circuit.logitDiffLowerBoundAt cert.active cert.prev cert.epsAt + cert.values.lo cert.values.hi cert.values.vals + match logitDiffLB? 
with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + if logitDiffLB < minLogitDiff then + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" + return 2 + else + IO.println + s!"ok: induction head certificate checked \ + (seq={seq}, active={activeCount}, \ + margin={ratToString cert.margin}, eps={ratToString cert.eps}, \ + logitDiffLB={ratToString logitDiffLB})" + return 0 + +end IO + +end Nfp + From efda300d0f91a9d5a37f450e1a8e5a27895db8ad Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 07:25:27 +0100 Subject: [PATCH 214/244] Allow induction certify to use explicit certs --- Nfp/Cli.lean | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index b24e1f4..b03182d 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -63,6 +63,7 @@ private def toZeroBased (label : String) (idx : Nat) (zeroBased : Bool) : private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : IO UInt32 := do let inputsPath? := (p.flag? "inputs").map (·.as! String) let modelPath? := (p.flag? "model").map (·.as! String) + let certPath? := (p.flag? "cert").map (·.as! String) let layer? := (p.flag? "layer").map (·.as! Nat) let head? := (p.flag? "head").map (·.as! Nat) let period? := (p.flag? "period").map (·.as! Nat) @@ -93,6 +94,22 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : | _, Except.error msg => fail msg | Except.ok ⟨splitBudgetQ?, splitBudgetK?, splitBudgetDiffBase?, splitBudgetDiffRefined?⟩, Except.ok direction? => + match certPath? 
with + | some certPath => + if inputsPath?.isSome || modelPath?.isSome then + fail "provide exactly one of --cert or --inputs/--model" + else if layer?.isSome || head?.isSome || period?.isSome || prevShift then + fail "--layer/--head/--period/--prev-shift are only valid with --model" + else if direction?.isSome then + fail "--direction is only valid with --model" + else if presetStr?.isSome then + fail "--preset is only valid with --inputs or --model" + else if requireNonvacuous && skipLogitDiff then + fail "--skip-logit-diff is not allowed with certify_nonvacuous" + else + IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? + | none => match inputsPath?, modelPath? with | some inputsPath, none => if layer?.isSome || head?.isSome || period?.isSome || prevShift then @@ -152,9 +169,9 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : | _, _ => fail "--layer and --head are required with --model" | none, none => - fail "provide exactly one of --inputs or --model" + fail "provide exactly one of --cert or --inputs/--model" | some _, some _ => - fail "provide exactly one of --inputs or --model" + fail "provide exactly one of --cert or --inputs/--model" private def runInductionCertifySimple (p : Parsed) : IO UInt32 := runInductionCertifyUnified false p @@ -213,8 +230,10 @@ private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do /-- `nfp induction certify` subcommand (streamlined). -/ def inductionCertifySimpleCmd : Cmd := `[Cli| certify VIA runInductionCertifySimple; - "Check induction head certificates from inputs or a model file." + "Check induction head certificates from an explicit cert, inputs, or a model file." FLAGS: + cert : String; "Path to the induction head certificate file \ + (use either --cert or --inputs/--model)." inputs : String; "Path to the induction head input file (use either --inputs or --model)." 
model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." layer : Nat; "Layer index for the induction head (1-based, required with --model)." @@ -239,8 +258,10 @@ def inductionCertifySimpleCmd : Cmd := `[Cli| /-- `nfp induction certify_nonvacuous` subcommand (streamlined). -/ def inductionCertifyNonvacuousSimpleCmd : Cmd := `[Cli| certify_nonvacuous VIA runInductionCertifyNonvacuousSimple; - "Require a strictly positive logit-diff bound from inputs or a model file." + "Require a strictly positive logit-diff bound from a cert, inputs, or a model file." FLAGS: + cert : String; "Path to the induction head certificate file \ + (use either --cert or --inputs/--model)." inputs : String; "Path to the induction head input file (use either --inputs or --model)." model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." layer : Nat; "Layer index for the induction head (1-based, required with --model)." From e8e31f8625dde174072b523c78ec8b14f4703338 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 07:32:15 +0100 Subject: [PATCH 215/244] Deprecate non-split CLI paths --- Nfp/Cli.lean | 34 ++++++++++++++-------------- Nfp/IO/InductionHead/Basic.lean | 14 ++++++++++++ Nfp/IO/InductionHead/Circuit.lean | 3 +++ Nfp/IO/InductionHead/Nonvacuous.lean | 9 ++++++++ Nfp/IO/Run/Basic.lean | 9 ++++++++ Nfp/IO/Util.lean | 4 ++++ 6 files changed, 56 insertions(+), 17 deletions(-) diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index b03182d..4f1f637 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -230,7 +230,7 @@ private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do /-- `nfp induction certify` subcommand (streamlined). -/ def inductionCertifySimpleCmd : Cmd := `[Cli| certify VIA runInductionCertifySimple; - "Check induction head certificates from an explicit cert, inputs, or a model file." + "Check induction head certificates from an explicit cert. Inputs/model modes are DEPRECATED." 
FLAGS: cert : String; "Path to the induction head certificate file \ (use either --cert or --inputs/--model)." @@ -258,7 +258,7 @@ def inductionCertifySimpleCmd : Cmd := `[Cli| /-- `nfp induction certify_nonvacuous` subcommand (streamlined). -/ def inductionCertifyNonvacuousSimpleCmd : Cmd := `[Cli| certify_nonvacuous VIA runInductionCertifyNonvacuousSimple; - "Require a strictly positive logit-diff bound from a cert, inputs, or a model file." + "Require a strictly positive logit-diff bound from a cert. Inputs/model modes are DEPRECATED." FLAGS: cert : String; "Path to the induction head certificate file \ (use either --cert or --inputs/--model)." @@ -285,7 +285,7 @@ def inductionCertifyNonvacuousSimpleCmd : Cmd := `[Cli| /-- `nfp induction interval` subcommand (streamlined). -/ def inductionIntervalSimpleCmd : Cmd := `[Cli| interval VIA runInductionIntervalSimple; - "Build head-output interval bounds from inputs or a model file." + "DEPRECATED: build head-output interval bounds from inputs or a model file." FLAGS: inputs : String; "Path to the induction head input file (use either --inputs or --model)." model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." @@ -338,7 +338,7 @@ def runInductionCertifySound (p : Parsed) : IO UInt32 := do /-- `nfp induction certify_sound` subcommand. -/ def inductionCertifySoundCmd : Cmd := `[Cli| certify_sound VIA runInductionCertifySound; - "Check induction certificates from raw scores/values." + "DEPRECATED: check induction certificates from raw scores/values." FLAGS: scores : String; "Path to the raw scores/weights file." values : String; "Path to the raw value entries file." @@ -395,7 +395,7 @@ def runInductionCertifyEndToEndMatrix (p : Parsed) : IO UInt32 := do /-- `nfp induction certify_end_to_end_matrix` subcommand. 
-/ def inductionCertifyEndToEndMatrixCmd : Cmd := `[Cli| certify_end_to_end_matrix VIA runInductionCertifyEndToEndMatrix; - "Check end-to-end induction bounds using a downstream matrix payload." + "DEPRECATED: check end-to-end induction bounds using a downstream matrix payload." FLAGS: scores : String; "Path to the softmax-margin certificate file." values : String; "Path to the value-range certificate file." @@ -428,7 +428,7 @@ def runInductionCertifyEndToEndModel (p : Parsed) : IO UInt32 := do /-- `nfp induction certify_end_to_end_model` subcommand. -/ def inductionCertifyEndToEndModelCmd : Cmd := `[Cli| certify_end_to_end_model VIA runInductionCertifyEndToEndModel; - "Check end-to-end induction bounds using a model file for the downstream matrix." + "DEPRECATED: check end-to-end induction bounds using a model file for downstream bounds." FLAGS: scores : String; "Path to the softmax-margin certificate file." values : String; "Path to the value-range certificate file." @@ -485,7 +485,7 @@ def runInductionCertifyHeadNonvacuous (p : Parsed) : IO UInt32 := do /-- `nfp induction certify_head` subcommand. -/ def inductionCertifyHeadCmd : Cmd := `[Cli| certify_head VIA runInductionCertifyHead; - "Check induction certificates from exact head inputs." + "DEPRECATED: check induction certificates from exact head inputs." FLAGS: inputs : String; "Path to the induction head input file." "min-active" : Nat; "Optional minimum number of active queries required \ @@ -508,7 +508,7 @@ def inductionCertifyHeadCmd : Cmd := `[Cli| /-- `nfp induction certify_head_nonvacuous` subcommand. -/ def inductionCertifyHeadNonvacuousCmd : Cmd := `[Cli| certify_head_nonvacuous VIA runInductionCertifyHeadNonvacuous; - "Require a strictly positive logit-diff bound from exact head inputs." + "DEPRECATED: require a strictly positive logit-diff bound from exact head inputs." FLAGS: inputs : String; "Path to the induction head input file." 
"min-active" : Nat; "Optional minimum number of active queries required \ @@ -721,7 +721,7 @@ def runInductionCertifyCircuitModel (p : Parsed) : IO UInt32 := do /-- `nfp induction certify_head_model` subcommand. -/ def inductionCertifyHeadModelCmd : Cmd := `[Cli| certify_head_model VIA runInductionCertifyHeadModel; - "Check induction certificates by reading a model binary directly." + "DEPRECATED: check induction certificates by reading a model binary directly." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." layer : Nat; "Layer index for the induction head (1-based)." @@ -751,7 +751,7 @@ def inductionCertifyHeadModelCmd : Cmd := `[Cli| /-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| certify_head_model_nonvacuous VIA runInductionCertifyHeadModelNonvacuous; - "Require a strictly positive logit-diff bound from a model binary." + "DEPRECATED: require a strictly positive logit-diff bound from a model binary." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." layer : Nat; "Layer index for the induction head (1-based)." @@ -780,8 +780,8 @@ def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| /-- `nfp induction certify_head_model_auto` subcommand. -/ def inductionCertifyHeadModelAutoCmd : Cmd := `[Cli| certify_head_model_auto VIA runInductionCertifyHeadModelAuto; - "Check induction certificates by reading a model binary and deriving the direction \ - from the prompt tokens." + "DEPRECATED: check induction certificates by reading a model binary and deriving the \ + direction from the prompt tokens." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." layer : Nat; "Layer index for the induction head (1-based)." @@ -808,8 +808,8 @@ def inductionCertifyHeadModelAutoCmd : Cmd := `[Cli| /-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. 
-/ def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| certify_head_model_auto_nonvacuous VIA runInductionCertifyHeadModelAutoNonvacuous; - "Require a strictly positive logit-diff bound from a model binary, with the direction \ - derived from the prompt tokens." + "DEPRECATED: require a strictly positive logit-diff bound from a model binary, with the \ + direction derived from the prompt tokens." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." layer : Nat; "Layer index for the induction head (1-based)." @@ -836,7 +836,7 @@ def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| /-- `nfp induction certify_circuit_model` subcommand. -/ def inductionCertifyCircuitModelCmd : Cmd := `[Cli| certify_circuit_model VIA runInductionCertifyCircuitModel; - "Check a two-head induction circuit by reading a model binary directly \ + "DEPRECATED: check a two-head induction circuit by reading a model binary directly \ (induction head uses shifted prev)." FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." @@ -874,7 +874,7 @@ def runInductionHeadInterval (p : Parsed) : IO UInt32 := do /-- `nfp induction head_interval` subcommand. -/ def inductionHeadIntervalCmd : Cmd := `[Cli| head_interval VIA runInductionHeadInterval; - "Build head-output interval bounds from exact head inputs." + "DEPRECATED: build head-output interval bounds from exact head inputs." FLAGS: inputs : String; "Path to the induction head input file." out : String; "Optional path to write the residual-interval certificate." @@ -930,7 +930,7 @@ def runInductionHeadIntervalModel (p : Parsed) : IO UInt32 := do /-- `nfp induction head_interval_model` subcommand. -/ def inductionHeadIntervalModelCmd : Cmd := `[Cli| head_interval_model VIA runInductionHeadIntervalModel; - "Build head-output interval bounds by reading a model binary directly." + "DEPRECATED: build head-output interval bounds by reading a model binary directly." 
FLAGS: model : String; "Path to the NFP_BINARY_V1 model file." layer : Nat; "Layer index for the induction head (1-based)." diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean index 7b076e0..b5de6c2 100644 --- a/Nfp/IO/InductionHead/Basic.lean +++ b/Nfp/IO/InductionHead/Basic.lean @@ -1186,6 +1186,9 @@ def runInductionCertifyHead (inputsPath : System.FilePath) (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) (skipLogitDiff : Bool) : IO UInt32 := do + warnDeprecated + "certify_head builds certificates from head inputs; use explicit certs \ + via `nfp induction certify --cert` or `nfp induction head_cert_check`." configureTiming timing? heartbeatMs? let splitCfg := splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? @@ -1224,6 +1227,9 @@ def runInductionCertifyHeadModel (modelPath : System.FilePath) (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) (skipLogitDiff : Bool) : IO UInt32 := do + warnDeprecated + "certify_head_model builds certificates from a model file; use explicit certs \ + via `nfp induction certify --cert` or `nfp induction head_cert_check`." configureTiming timing? heartbeatMs? let splitCfg := splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? @@ -1303,6 +1309,9 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) (skipLogitDiff : Bool) : IO UInt32 := do + warnDeprecated + "certify_head_model_auto builds certificates from a model file; use explicit certs \ + via `nfp induction certify --cert` or `nfp induction head_cert_check`." configureTiming timing? heartbeatMs? let splitCfg := splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
@@ -1361,6 +1370,8 @@ def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) /-- Build head-output interval bounds from exact head inputs. -/ def runInductionHeadInterval (inputsPath : System.FilePath) (outPath? : Option System.FilePath) : IO UInt32 := do + warnDeprecated + "head_interval builds interval bounds from head inputs; use explicit interval certs instead." let parsedInputs ← loadInductionHeadInputs inputsPath match parsedInputs with | Except.error msg => @@ -1373,6 +1384,9 @@ def runInductionHeadInterval (inputsPath : System.FilePath) def runInductionHeadIntervalModel (modelPath : System.FilePath) (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) (outPath? : Option System.FilePath) : IO UInt32 := do + warnDeprecated + "head_interval_model builds interval bounds from a model file; \ + use explicit interval certs instead." let data ← IO.FS.readBinFile modelPath match NfptPure.parseHeader data with | Except.error msg => diff --git a/Nfp/IO/InductionHead/Circuit.lean b/Nfp/IO/InductionHead/Circuit.lean index 9576b5e..ef6b580 100644 --- a/Nfp/IO/InductionHead/Circuit.lean +++ b/Nfp/IO/InductionHead/Circuit.lean @@ -27,6 +27,9 @@ def runInductionCertifyCircuitModel (modelPath : System.FilePath) (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) (skipLogitDiff : Bool) : IO UInt32 := do + warnDeprecated + "certify_circuit_model builds certificates from a model file; \ + use explicit certs for each head instead." let prevCode ← runInductionCertifyHeadModel modelPath diff --git a/Nfp/IO/InductionHead/Nonvacuous.lean b/Nfp/IO/InductionHead/Nonvacuous.lean index ae7d0f4..ba74348 100644 --- a/Nfp/IO/InductionHead/Nonvacuous.lean +++ b/Nfp/IO/InductionHead/Nonvacuous.lean @@ -378,6 +378,9 @@ def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) (timing? : Option Nat) (heartbeatMs? : Option Nat) (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
: Option Nat) : IO UInt32 := do + warnDeprecated + "certify_head_nonvacuous builds certificates from head inputs; \ + use explicit certs via `nfp induction certify_nonvacuous --cert`." configureTiming timing? heartbeatMs? let splitCfg := splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? @@ -414,6 +417,9 @@ def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) (timing? : Option Nat) (heartbeatMs? : Option Nat) (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : IO UInt32 := do + warnDeprecated + "certify_head_model_nonvacuous builds certificates from a model file; \ + use explicit certs via `nfp induction certify_nonvacuous --cert`." configureTiming timing? heartbeatMs? let splitCfg := splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? @@ -463,6 +469,9 @@ def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) (timing? : Option Nat) (heartbeatMs? : Option Nat) (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : IO UInt32 := do + warnDeprecated + "certify_head_model_auto_nonvacuous builds certificates from a model file; \ + use explicit certs via `nfp induction certify_nonvacuous --cert`." configureTiming timing? heartbeatMs? let splitCfg := splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? diff --git a/Nfp/IO/Run/Basic.lean b/Nfp/IO/Run/Basic.lean index e2a3b69..301ff46 100644 --- a/Nfp/IO/Run/Basic.lean +++ b/Nfp/IO/Run/Basic.lean @@ -150,6 +150,9 @@ def runInductionCertifySound (scoresPath : System.FilePath) (valuesPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? 
: Option String) : IO UInt32 := do + warnDeprecated + "certify_sound builds certificates from raw scores/values; use explicit certs \ + via `nfp induction certify` or `nfp induction head_cert_check`." let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -377,6 +380,9 @@ def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) (valuesPath : System.FilePath) (matrixPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + warnDeprecated + "certify_end_to_end_matrix builds downstream bounds from a raw matrix payload; \ + use a downstream cert instead." let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -505,6 +511,9 @@ def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) (layer? : Option Nat) (head? : Option Nat) (period? : Option Nat) (minActive? : Option Nat) (minLogitDiffStr? : Option String) (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + warnDeprecated + "certify_end_to_end_model derives residual bounds from a model file; \ + use an explicit residual-interval cert instead." let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? diff --git a/Nfp/IO/Util.lean b/Nfp/IO/Util.lean index d52c91f..0f183eb 100644 --- a/Nfp/IO/Util.lean +++ b/Nfp/IO/Util.lean @@ -24,6 +24,10 @@ def parseRatOpt (label : String) (raw? : Option String) : | Except.ok v => Except.ok (some v) | Except.error msg => Except.error s!"invalid {label}: {msg}" +/-- Emit a deprecation warning on stderr. 
-/ +def warnDeprecated (msg : String) : IO Unit := do + IO.eprintln s!"warning: DEPRECATED: {msg}" + end IO end Nfp From 47f338476a2948d2788621916f3f46c80113fa5e Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 07:55:40 +0100 Subject: [PATCH 216/244] Remove non-split induction CLI/IO paths --- Nfp/Cli.lean | 891 +-------------- Nfp/IO.lean | 7 - Nfp/IO/Derive.lean | 141 --- Nfp/IO/HeadScore.lean | 60 - Nfp/IO/InductionHead.lean | 5 +- Nfp/IO/InductionHead/Basic.lean | 1408 ------------------------ Nfp/IO/InductionHead/Cert.lean | 3 +- Nfp/IO/InductionHead/Circuit.lean | 83 -- Nfp/IO/InductionHead/Nonvacuous.lean | 531 --------- Nfp/IO/NfptPure.lean | 804 -------------- Nfp/IO/Run.lean | 9 - Nfp/IO/Run/Basic.lean | 818 -------------- Nfp/IO/Timing.lean | 454 -------- Nfp/Sound/Bounds/MatrixNorm/Basic.lean | 29 - Nfp/Sound/Induction/Core/Basic.lean | 68 -- 15 files changed, 17 insertions(+), 5294 deletions(-) delete mode 100644 Nfp/IO/Derive.lean delete mode 100644 Nfp/IO/HeadScore.lean delete mode 100644 Nfp/IO/InductionHead/Basic.lean delete mode 100644 Nfp/IO/InductionHead/Circuit.lean delete mode 100644 Nfp/IO/InductionHead/Nonvacuous.lean delete mode 100644 Nfp/IO/NfptPure.lean delete mode 100644 Nfp/IO/Run.lean delete mode 100644 Nfp/IO/Run/Basic.lean delete mode 100644 Nfp/IO/Timing.lean diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 4f1f637..67c1b4e 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -29,149 +29,24 @@ def versionCmd : Cmd := `[Cli| "Print the NFP version." ] -private def parseDirectionSpec (raw : String) : Except String (Nat × Nat) := do - let partsComma := raw.splitOn "," - let parts := if partsComma.length = 2 then partsComma else raw.splitOn ":" - match parts with - | [targetRaw, negativeRaw] => - match targetRaw.toNat?, negativeRaw.toNat? 
with - | some target, some negative => pure (target, negative) - | _, _ => throw s!"direction must be two natural numbers (got '{raw}')" - | _ => - throw s!"direction must look like \"target,negative\" (got '{raw}')" - -private def parseSplitPreset (raw : String) : - Except String (Option Nat × Option Nat × Option Nat × Option Nat) := do - let key := raw.toLower - match key with - | "balanced" | "default" => pure (none, none, none, none) - | "fast" => pure (some 0, some 0, some 0, some 0) - | "tight" => pure (some 4, some 4, some 2, some 16) - | _ => - throw s!"unknown preset '{raw}' (expected: fast, balanced, tight)" - -private def toZeroBased (label : String) (idx : Nat) (zeroBased : Bool) : - Except String Nat := do - if zeroBased then - pure idx - else - if idx = 0 then - throw s!"{label} must be >= 1 for 1-based indexing (use --zero-based for 0-based)" - else - pure (idx - 1) - private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : IO UInt32 := do - let inputsPath? := (p.flag? "inputs").map (·.as! String) - let modelPath? := (p.flag? "model").map (·.as! String) let certPath? := (p.flag? "cert").map (·.as! String) - let layer? := (p.flag? "layer").map (·.as! Nat) - let head? := (p.flag? "head").map (·.as! Nat) - let period? := (p.flag? "period").map (·.as! Nat) - let prevShift := p.hasFlag "prev-shift" - let directionStr? := (p.flag? "direction").map (·.as! String) - let presetStr? := (p.flag? "preset").map (·.as! String) let minActive? := (p.flag? "min-active").map (·.as! Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - let timing? := (p.flag? "timing").map (·.as! Nat) - let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! 
Nat) - let skipLogitDiff := p.hasFlag "skip-logit-diff" - let zeroBased := p.hasFlag "zero-based" let fail (msg : String) : IO UInt32 := do IO.eprintln s!"error: {msg}" return 2 - let presetE := - match presetStr? with - | none => Except.ok (none, none, none, none) - | some raw => parseSplitPreset raw - let directionE := - match directionStr? with - | none => Except.ok none - | some raw => (parseDirectionSpec raw).map some - match presetE, directionE with - | Except.error msg, _ => fail msg - | _, Except.error msg => fail msg - | Except.ok ⟨splitBudgetQ?, splitBudgetK?, splitBudgetDiffBase?, splitBudgetDiffRefined?⟩, - Except.ok direction? => - match certPath? with - | some certPath => - if inputsPath?.isSome || modelPath?.isSome then - fail "provide exactly one of --cert or --inputs/--model" - else if layer?.isSome || head?.isSome || period?.isSome || prevShift then - fail "--layer/--head/--period/--prev-shift are only valid with --model" - else if direction?.isSome then - fail "--direction is only valid with --model" - else if presetStr?.isSome then - fail "--preset is only valid with --inputs or --model" - else if requireNonvacuous && skipLogitDiff then - fail "--skip-logit-diff is not allowed with certify_nonvacuous" - else - IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? - | none => - match inputsPath?, modelPath? with - | some inputsPath, none => - if layer?.isSome || head?.isSome || period?.isSome || prevShift then - fail "--layer/--head/--period/--prev-shift are only valid with --model" - else if direction?.isSome then - fail "--direction is only valid with --model" - else if requireNonvacuous && skipLogitDiff then - fail "--skip-logit-diff is not allowed with certify_nonvacuous" - else if requireNonvacuous then - IO.runInductionCertifyHeadNonvacuous inputsPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
- else - IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - skipLogitDiff - | none, some modelPath => - match layer?, head? with - | some layer, some head => - let layerE := toZeroBased "layer" layer zeroBased - let headE := toZeroBased "head" head zeroBased - match layerE, headE with - | Except.error msg, _ => fail msg - | _, Except.error msg => fail msg - | Except.ok layer', Except.ok head' => - match direction? with - | some ⟨dirTarget, dirNegative⟩ => - if requireNonvacuous && skipLogitDiff then - fail "--skip-logit-diff is not allowed with certify_nonvacuous" - else if requireNonvacuous then - IO.runInductionCertifyHeadModelNonvacuous modelPath layer' head' dirTarget - dirNegative period? prevShift minActive? minLogitDiffStr? minMarginStr? - maxEpsStr? - timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - else - IO.runInductionCertifyHeadModel modelPath layer' head' dirTarget dirNegative - period? prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - skipLogitDiff - | none => - if requireNonvacuous && skipLogitDiff then - fail "--skip-logit-diff is not allowed with certify_nonvacuous" - else if requireNonvacuous then - IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer' head' period? - prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - else - IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? prevShift - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
- skipLogitDiff - | _, _ => - fail "--layer and --head are required with --model" - | none, none => - fail "provide exactly one of --cert or --inputs/--model" - | some _, some _ => - fail "provide exactly one of --cert or --inputs/--model" + match certPath? with + | none => fail "provide --cert" + | some certPath => + if requireNonvacuous then + IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? + else + IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? private def runInductionCertifySimple (p : Parsed) : IO UInt32 := runInductionCertifyUnified false p @@ -179,705 +54,32 @@ private def runInductionCertifySimple (p : Parsed) : IO UInt32 := private def runInductionCertifyNonvacuousSimple (p : Parsed) : IO UInt32 := runInductionCertifyUnified true p -private def runInductionIntervalSimple (p : Parsed) : IO UInt32 := do - let inputsPath? := (p.flag? "inputs").map (·.as! String) - let modelPath? := (p.flag? "model").map (·.as! String) - let layer? := (p.flag? "layer").map (·.as! Nat) - let head? := (p.flag? "head").map (·.as! Nat) - let period? := (p.flag? "period").map (·.as! Nat) - let prevShift := p.hasFlag "prev-shift" - let directionStr? := (p.flag? "direction").map (·.as! String) - let outPath? := (p.flag? "out").map (·.as! String) - let zeroBased := p.hasFlag "zero-based" - let fail (msg : String) : IO UInt32 := do - IO.eprintln s!"error: {msg}" - return 2 - let directionE := - match directionStr? with - | none => Except.ok none - | some raw => (parseDirectionSpec raw).map some - match directionE with - | Except.error msg => fail msg - | Except.ok direction? => - match inputsPath?, modelPath? 
with - | some inputsPath, none => - if layer?.isSome || head?.isSome || period?.isSome || prevShift then - fail "--layer/--head/--period/--prev-shift are only valid with --model" - else if direction?.isSome then - fail "--direction is only valid with --model" - else - IO.runInductionHeadInterval inputsPath outPath? - | none, some modelPath => - match layer?, head?, direction? with - | some layer, some head, some ⟨dirTarget, dirNegative⟩ => - let layerE := toZeroBased "layer" layer zeroBased - let headE := toZeroBased "head" head zeroBased - match layerE, headE with - | Except.error msg, _ => fail msg - | _, Except.error msg => fail msg - | Except.ok layer', Except.ok head' => - IO.runInductionHeadIntervalModel modelPath layer' head' dirTarget dirNegative - period? prevShift outPath? - | _, _, none => - fail "--direction is required with --model (use \"target,negative\")" - | _, _, _ => - fail "--layer and --head are required with --model" - | none, none => - fail "provide exactly one of --inputs or --model" - | some _, some _ => - fail "provide exactly one of --inputs or --model" - /-- `nfp induction certify` subcommand (streamlined). -/ def inductionCertifySimpleCmd : Cmd := `[Cli| certify VIA runInductionCertifySimple; - "Check induction head certificates from an explicit cert. Inputs/model modes are DEPRECATED." + "Check induction head certificates from an explicit cert." FLAGS: - cert : String; "Path to the induction head certificate file \ - (use either --cert or --inputs/--model)." - inputs : String; "Path to the induction head input file (use either --inputs or --model)." - model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." - layer : Nat; "Layer index for the induction head (1-based, required with --model)." - head : Nat; "Head index for the induction head (1-based, required with --model)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." 
- period : Nat; "Optional prompt period override (model only; default: derive from tokens)." - "prev-shift"; "Use shifted prev/active (prev = q - period + 1) for model inputs." - direction : String; "Optional logit-diff direction as \"target,negative\" (model only). \ - When omitted with --model, direction is derived from tokens." - preset : String; "Split-budget preset: fast | balanced | tight (default: balanced)." + cert : String; "Path to the induction head certificate file." "min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ (rational literal; default: 0)." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - "skip-logit-diff"; "Skip logit-diff lower bound computation." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." ] /-- `nfp induction certify_nonvacuous` subcommand (streamlined). -/ def inductionCertifyNonvacuousSimpleCmd : Cmd := `[Cli| certify_nonvacuous VIA runInductionCertifyNonvacuousSimple; - "Require a strictly positive logit-diff bound from a cert. Inputs/model modes are DEPRECATED." - FLAGS: - cert : String; "Path to the induction head certificate file \ - (use either --cert or --inputs/--model)." - inputs : String; "Path to the induction head input file (use either --inputs or --model)." - model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." - layer : Nat; "Layer index for the induction head (1-based, required with --model)." - head : Nat; "Head index for the induction head (1-based, required with --model)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." - period : Nat; "Optional prompt period override (model only; default: derive from tokens)." 
- "prev-shift"; "Use shifted prev/active (prev = q - period + 1) for model inputs." - direction : String; "Optional logit-diff direction as \"target,negative\" (model only). \ - When omitted with --model, direction is derived from tokens." - preset : String; "Split-budget preset: fast | balanced | tight (default: balanced)." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal; default: 0)." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." -] - -/-- `nfp induction interval` subcommand (streamlined). -/ -def inductionIntervalSimpleCmd : Cmd := `[Cli| - interval VIA runInductionIntervalSimple; - "DEPRECATED: build head-output interval bounds from inputs or a model file." - FLAGS: - inputs : String; "Path to the induction head input file (use either --inputs or --model)." - model : String; "Path to the NFP_BINARY_V1 model file (use either --inputs or --model)." - layer : Nat; "Layer index for the induction head (1-based, required with --model)." - head : Nat; "Head index for the induction head (1-based, required with --model)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." - period : Nat; "Optional prompt period override (model only; default: derive from tokens)." - direction : String; "Required logit-diff direction as \"target,negative\" (model only)." - out : String; "Optional path to write the residual-interval certificate." -] - -/-- Check induction certificates for induction heads. -/ -def runInductionCertify (p : Parsed) : IO UInt32 := do - let scoresPath := p.flag! "scores" |>.as! String - let valuesPath? := (p.flag? 
"values").map (·.as! String) - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertify scoresPath valuesPath? minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? - -/-- `nfp induction certify` subcommand. -/ -def inductionCertifyCmd : Cmd := `[Cli| - certify VIA runInductionCertify; - "Check induction certificates for induction heads." - FLAGS: - scores : String; "Path to the softmax-margin certificate file." - values : String; "Optional path to a value-range certificate file." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal; requires --values). Defaults \ - to 0 when direction metadata is present." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." -] - -/-- `nfp induction certify-sound` subcommand. -/ -def runInductionCertifySound (p : Parsed) : IO UInt32 := do - let scoresPath := p.flag! "scores" |>.as! String - let valuesPath := p.flag! "values" |>.as! String - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertifySound scoresPath valuesPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? - -/-- `nfp induction certify_sound` subcommand. -/ -def inductionCertifySoundCmd : Cmd := `[Cli| - certify_sound VIA runInductionCertifySound; - "DEPRECATED: check induction certificates from raw scores/values." 
- FLAGS: - scores : String; "Path to the raw scores/weights file." - values : String; "Path to the raw value entries file." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal). Defaults to 0 when \ - direction metadata is present." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." -] - -/-- `nfp induction certify_end_to_end` subcommand. -/ -def runInductionCertifyEndToEnd (p : Parsed) : IO UInt32 := do - let scoresPath := p.flag! "scores" |>.as! String - let valuesPath := p.flag! "values" |>.as! String - let downstreamPath := p.flag! "downstream" |>.as! String - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertifyEndToEnd scoresPath valuesPath downstreamPath - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - -/-- `nfp induction certify_end_to_end` subcommand. -/ -def inductionCertifyEndToEndCmd : Cmd := `[Cli| - certify_end_to_end VIA runInductionCertifyEndToEnd; - "Check end-to-end induction bounds with a downstream error certificate." - FLAGS: - scores : String; "Path to the softmax-margin certificate file." - values : String; "Path to the value-range certificate file." - downstream : String; "Path to the downstream linear certificate file." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal). Defaults to 0 when \ - direction metadata is present." 
- "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." -] - -/-- `nfp induction certify_end_to_end_matrix` subcommand. -/ -def runInductionCertifyEndToEndMatrix (p : Parsed) : IO UInt32 := do - let scoresPath := p.flag! "scores" |>.as! String - let valuesPath := p.flag! "values" |>.as! String - let matrixPath := p.flag! "matrix" |>.as! String - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertifyEndToEndMatrix scoresPath valuesPath matrixPath - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - -/-- `nfp induction certify_end_to_end_matrix` subcommand. -/ -def inductionCertifyEndToEndMatrixCmd : Cmd := `[Cli| - certify_end_to_end_matrix VIA runInductionCertifyEndToEndMatrix; - "DEPRECATED: check end-to-end induction bounds using a downstream matrix payload." - FLAGS: - scores : String; "Path to the softmax-margin certificate file." - values : String; "Path to the value-range certificate file." - matrix : String; "Path to the downstream matrix payload file." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal). Defaults to 0 when \ - direction metadata is present." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." -] - -/-- `nfp induction certify_end_to_end_model` subcommand. -/ -def runInductionCertifyEndToEndModel (p : Parsed) : IO UInt32 := do - let scoresPath := p.flag! "scores" |>.as! String - let valuesPath := p.flag! "values" |>.as! 
String - let modelPath := p.flag! "model" |>.as! String - let residualIntervalPath? := (p.flag? "residual-interval").map (·.as! String) - let layer? := (p.flag? "layer").map (·.as! Nat) - let head? := (p.flag? "head").map (·.as! Nat) - let period? := (p.flag? "period").map (·.as! Nat) - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionCertifyEndToEndModel scoresPath valuesPath modelPath residualIntervalPath? - layer? head? period? minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - -/-- `nfp induction certify_end_to_end_model` subcommand. -/ -def inductionCertifyEndToEndModelCmd : Cmd := `[Cli| - certify_end_to_end_model VIA runInductionCertifyEndToEndModel; - "DEPRECATED: check end-to-end induction bounds using a model file for downstream bounds." - FLAGS: - scores : String; "Path to the softmax-margin certificate file." - values : String; "Path to the value-range certificate file." - model : String; "Path to the NFP_BINARY_V1 model file." - "residual-interval" : String; "Optional path to a residual-interval certificate file \ - (defaults to deriving from the model)." - layer : Nat; "Optional layer index for a head-output interval bound (requires --head)." - head : Nat; "Optional head index for a head-output interval bound (requires --layer)." - period : Nat; "Optional prompt period override when reading head inputs." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal). Defaults to 0 when \ - direction metadata is present." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." 
-] - -/-- `nfp induction certify_head` subcommand. -/ -def runInductionCertifyHead (p : Parsed) : IO UInt32 := do - let inputsPath := p.flag! "inputs" |>.as! String - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - let timing? := (p.flag? "timing").map (·.as! Nat) - let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) - let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) - let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) - let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) - let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) - let skipLogitDiff := p.hasFlag "skip-logit-diff" - IO.runInductionCertifyHead inputsPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff - -/-- `nfp induction certify_head_nonvacuous` subcommand. -/ -def runInductionCertifyHeadNonvacuous (p : Parsed) : IO UInt32 := do - let inputsPath := p.flag! "inputs" |>.as! String - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - let timing? := (p.flag? "timing").map (·.as! Nat) - let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) - let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) - let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) - let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) - let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! 
Nat) - IO.runInductionCertifyHeadNonvacuous inputsPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - -/-- `nfp induction certify_head` subcommand. -/ -def inductionCertifyHeadCmd : Cmd := `[Cli| - certify_head VIA runInductionCertifyHead; - "DEPRECATED: check induction certificates from exact head inputs." - FLAGS: - inputs : String; "Path to the induction head input file." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal). Defaults to 0." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." - "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." - "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." - "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ - (default: 0)." - "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ - bounds (default: 12)." - "skip-logit-diff" : Bool; "Skip logit-diff lower bound computation." -] - -/-- `nfp induction certify_head_nonvacuous` subcommand. -/ -def inductionCertifyHeadNonvacuousCmd : Cmd := `[Cli| - certify_head_nonvacuous VIA runInductionCertifyHeadNonvacuous; - "DEPRECATED: require a strictly positive logit-diff bound from exact head inputs." + "Require a strictly positive logit-diff bound from a cert." FLAGS: - inputs : String; "Path to the induction head input file." + cert : String; "Path to the induction head certificate file." 
"min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ (rational literal; default: 0)." "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." - "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." - "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." - "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ - (default: 0)." - "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ - bounds (default: 12)." - "skip-logit-diff" : Bool; "Skip logit-diff lower bound computation." -] - -/-- `nfp induction certify_head_model` subcommand. -/ -def runInductionCertifyHeadModel (p : Parsed) : IO UInt32 := do - let modelPath := p.flag! "model" |>.as! String - let layer := p.flag! "layer" |>.as! Nat - let head := p.flag! "head" |>.as! Nat - let period? := (p.flag? "period").map (·.as! Nat) - let prevShift := p.hasFlag "prev-shift" - let dirTarget := p.flag! "direction-target" |>.as! Nat - let dirNegative := p.flag! "direction-negative" |>.as! Nat - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - let timing? := (p.flag? "timing").map (·.as! Nat) - let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) - let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) - let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) - let splitBudgetDiffBase? 
:= (p.flag? "split-budget-diff-base").map (·.as! Nat) - let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) - let skipLogitDiff := p.hasFlag "skip-logit-diff" - let zeroBased := p.hasFlag "zero-based" - let layerE := toZeroBased "layer" layer zeroBased - let headE := toZeroBased "head" head zeroBased - match layerE, headE with - | Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok layer', Except.ok head' => - IO.runInductionCertifyHeadModel modelPath layer' head' dirTarget dirNegative period? - prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff - -/-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ -def runInductionCertifyHeadModelNonvacuous (p : Parsed) : IO UInt32 := do - let modelPath := p.flag! "model" |>.as! String - let layer := p.flag! "layer" |>.as! Nat - let head := p.flag! "head" |>.as! Nat - let period? := (p.flag? "period").map (·.as! Nat) - let prevShift := p.hasFlag "prev-shift" - let dirTarget := p.flag! "direction-target" |>.as! Nat - let dirNegative := p.flag! "direction-negative" |>.as! Nat - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - let timing? := (p.flag? "timing").map (·.as! Nat) - let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) - let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) - let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) - let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) - let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! 
Nat) - let zeroBased := p.hasFlag "zero-based" - let layerE := toZeroBased "layer" layer zeroBased - let headE := toZeroBased "head" head zeroBased - match layerE, headE with - | Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok layer', Except.ok head' => - IO.runInductionCertifyHeadModelNonvacuous modelPath layer' head' dirTarget dirNegative - period? prevShift minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - -/-- `nfp induction certify_head_model_auto` subcommand. -/ -def runInductionCertifyHeadModelAuto (p : Parsed) : IO UInt32 := do - let modelPath := p.flag! "model" |>.as! String - let layer := p.flag! "layer" |>.as! Nat - let head := p.flag! "head" |>.as! Nat - let period? := (p.flag? "period").map (·.as! Nat) - let prevShift := p.hasFlag "prev-shift" - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - let timing? := (p.flag? "timing").map (·.as! Nat) - let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) - let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) - let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) - let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) - let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! 
Nat) - let skipLogitDiff := p.hasFlag "skip-logit-diff" - let zeroBased := p.hasFlag "zero-based" - let layerE := toZeroBased "layer" layer zeroBased - let headE := toZeroBased "head" head zeroBased - match layerE, headE with - | Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok layer', Except.ok head' => - IO.runInductionCertifyHeadModelAuto modelPath layer' head' period? prevShift - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? skipLogitDiff - -/-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ -def runInductionCertifyHeadModelAutoNonvacuous (p : Parsed) : IO UInt32 := do - let modelPath := p.flag! "model" |>.as! String - let layer := p.flag! "layer" |>.as! Nat - let head := p.flag! "head" |>.as! Nat - let period? := (p.flag? "period").map (·.as! Nat) - let prevShift := p.hasFlag "prev-shift" - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - let timing? := (p.flag? "timing").map (·.as! Nat) - let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) - let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) - let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) - let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) - let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! 
Nat) - let zeroBased := p.hasFlag "zero-based" - let layerE := toZeroBased "layer" layer zeroBased - let headE := toZeroBased "head" head zeroBased - match layerE, headE with - | Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok layer', Except.ok head' => - IO.runInductionCertifyHeadModelAutoNonvacuous modelPath layer' head' period? prevShift - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - -/-! `nfp induction certify_circuit_model` subcommand. -/ -/-- CLI entrypoint for `nfp induction certify_circuit_model`. -/ -def runInductionCertifyCircuitModel (p : Parsed) : IO UInt32 := do - let modelPath := p.flag! "model" |>.as! String - let prevLayer := p.flag! "prev-layer" |>.as! Nat - let prevHead := p.flag! "prev-head" |>.as! Nat - let indLayer := p.flag! "ind-layer" |>.as! Nat - let indHead := p.flag! "ind-head" |>.as! Nat - let period? := (p.flag? "period").map (·.as! Nat) - let dirTarget := p.flag! "direction-target" |>.as! Nat - let dirNegative := p.flag! "direction-negative" |>.as! Nat - let minActive? := (p.flag? "min-active").map (·.as! Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - let timing? := (p.flag? "timing").map (·.as! Nat) - let heartbeatMs? := (p.flag? "heartbeat-ms").map (·.as! Nat) - let splitBudgetQ? := (p.flag? "split-budget-q").map (·.as! Nat) - let splitBudgetK? := (p.flag? "split-budget-k").map (·.as! Nat) - let splitBudgetDiffBase? := (p.flag? "split-budget-diff-base").map (·.as! Nat) - let splitBudgetDiffRefined? := (p.flag? "split-budget-diff-refined").map (·.as! Nat) - let skipLogitDiff := p.hasFlag "skip-logit-diff" - let zeroBased := p.hasFlag "zero-based" - match period? 
with - | none => - IO.eprintln "error: --period is required for circuit certification" - return 2 - | some period => - if period = 0 then - IO.eprintln "error: --period must be positive for circuit certification" - return 2 - let prevLayerE := toZeroBased "prev-layer" prevLayer zeroBased - let prevHeadE := toZeroBased "prev-head" prevHead zeroBased - let indLayerE := toZeroBased "ind-layer" indLayer zeroBased - let indHeadE := toZeroBased "ind-head" indHead zeroBased - match prevLayerE, prevHeadE, indLayerE, indHeadE with - | Except.error msg, _, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok prevLayer', Except.ok prevHead', Except.ok indLayer', Except.ok indHead' => - IO.runInductionCertifyCircuitModel modelPath prevLayer' prevHead' indLayer' indHead' - dirTarget dirNegative period - minActive? minLogitDiffStr? minMarginStr? maxEpsStr? timing? heartbeatMs? - splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - skipLogitDiff - -/-- `nfp induction certify_head_model` subcommand. -/ -def inductionCertifyHeadModelCmd : Cmd := `[Cli| - certify_head_model VIA runInductionCertifyHeadModel; - "DEPRECATED: check induction certificates by reading a model binary directly." - FLAGS: - model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head (1-based)." - head : Nat; "Head index for the induction head (1-based)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." - period : Nat; "Optional prompt period override (default: derive from tokens)." - "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." - "direction-target" : Nat; "Target token id for logit-diff direction." 
- "direction-negative" : Nat; "Negative token id for logit-diff direction." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal). Defaults to 0." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." - "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." - "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." - "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ - (default: 0)." - "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ - bounds (default: 12)." - "skip-logit-diff" : Bool; "Skip logit-diff lower bound computation." -] - -/-- `nfp induction certify_head_model_nonvacuous` subcommand. -/ -def inductionCertifyHeadModelNonvacuousCmd : Cmd := `[Cli| - certify_head_model_nonvacuous VIA runInductionCertifyHeadModelNonvacuous; - "DEPRECATED: require a strictly positive logit-diff bound from a model binary." - FLAGS: - model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head (1-based)." - head : Nat; "Head index for the induction head (1-based)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." - period : Nat; "Optional prompt period override (default: derive from tokens)." - "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." - "direction-target" : Nat; "Target token id for logit-diff direction." - "direction-negative" : Nat; "Negative token id for logit-diff direction." 
- "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal; default: 0)." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." - "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." - "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." - "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ - (default: 0)." - "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ - bounds (default: 12)." -] - -/-- `nfp induction certify_head_model_auto` subcommand. -/ -def inductionCertifyHeadModelAutoCmd : Cmd := `[Cli| - certify_head_model_auto VIA runInductionCertifyHeadModelAuto; - "DEPRECATED: check induction certificates by reading a model binary and deriving the \ - direction from the prompt tokens." - FLAGS: - model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head (1-based)." - head : Nat; "Head index for the induction head (1-based)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." - period : Nat; "Optional prompt period override (default: derive from tokens)." - "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal). Defaults to 0." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." 
- "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." - "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." - "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." - "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ - (default: 0)." - "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ - bounds (default: 12)." -] - -/-- `nfp induction certify_head_model_auto_nonvacuous` subcommand. -/ -def inductionCertifyHeadModelAutoNonvacuousCmd : Cmd := `[Cli| - certify_head_model_auto_nonvacuous VIA runInductionCertifyHeadModelAutoNonvacuous; - "DEPRECATED: require a strictly positive logit-diff bound from a model binary, with the \ - direction derived from the prompt tokens." - FLAGS: - model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head (1-based)." - head : Nat; "Head index for the induction head (1-based)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." - period : Nat; "Optional prompt period override (default: derive from tokens)." - "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal; default: 0)." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." 
- "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." - "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." - "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ - (default: 0)." - "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ - bounds (default: 12)." -] - -/-- `nfp induction certify_circuit_model` subcommand. -/ -def inductionCertifyCircuitModelCmd : Cmd := `[Cli| - certify_circuit_model VIA runInductionCertifyCircuitModel; - "DEPRECATED: check a two-head induction circuit by reading a model binary directly \ - (induction head uses shifted prev)." - FLAGS: - model : String; "Path to the NFP_BINARY_V1 model file." - "prev-layer" : Nat; "Layer index for the previous-token head (1-based)." - "prev-head" : Nat; "Head index for the previous-token head (1-based)." - "ind-layer" : Nat; "Layer index for the induction head (1-based)." - "ind-head" : Nat; "Head index for the induction head (1-based)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." - period : Nat; "Prompt period override (required)." - "direction-target" : Nat; "Target token id for logit-diff direction." - "direction-negative" : Nat; "Negative token id for logit-diff direction." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal). Defaults to 0." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." - "skip-logit-diff"; "Skip logit-diff lower bound computation." - timing : Nat; "Emit timing output to stdout (0=off, 1=on)." - "heartbeat-ms" : Nat; "Emit progress heartbeat every N ms (0 disables)." 
- "split-budget-q" : Nat; "Split-budget for query dims in sign-splitting bounds (default: 2)." - "split-budget-k" : Nat; "Split-budget for key dims in sign-splitting bounds (default: 2)." - "split-budget-diff-base" : Nat; "Split-budget for base diff dims in sign-splitting bounds \ - (default: 0)." - "split-budget-diff-refined" : Nat; "Split-budget for refined diff dims in sign-splitting \ - bounds (default: 12)." -] - -/-- `nfp induction head_interval` subcommand. -/ -def runInductionHeadInterval (p : Parsed) : IO UInt32 := do - let inputsPath := p.flag! "inputs" |>.as! String - let outPath? := (p.flag? "out").map (·.as! String) - IO.runInductionHeadInterval inputsPath outPath? - -/-- `nfp induction head_interval` subcommand. -/ -def inductionHeadIntervalCmd : Cmd := `[Cli| - head_interval VIA runInductionHeadInterval; - "DEPRECATED: build head-output interval bounds from exact head inputs." - FLAGS: - inputs : String; "Path to the induction head input file." - out : String; "Optional path to write the residual-interval certificate." ] /-- `nfp induction head_cert_check` subcommand. -/ @@ -903,77 +105,14 @@ def inductionHeadCertCheckCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] -/-- `nfp induction head_interval_model` subcommand. -/ -def runInductionHeadIntervalModel (p : Parsed) : IO UInt32 := do - let modelPath := p.flag! "model" |>.as! String - let layer := p.flag! "layer" |>.as! Nat - let head := p.flag! "head" |>.as! Nat - let period? := (p.flag? "period").map (·.as! Nat) - let prevShift := p.hasFlag "prev-shift" - let dirTarget := p.flag! "direction-target" |>.as! Nat - let dirNegative := p.flag! "direction-negative" |>.as! Nat - let outPath? := (p.flag? "out").map (·.as! 
String) - let zeroBased := p.hasFlag "zero-based" - let layerE := toZeroBased "layer" layer zeroBased - let headE := toZeroBased "head" head zeroBased - match layerE, headE with - | Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok layer', Except.ok head' => - IO.runInductionHeadIntervalModel modelPath layer' head' dirTarget dirNegative period? - prevShift outPath? - -/-- `nfp induction head_interval_model` subcommand. -/ -def inductionHeadIntervalModelCmd : Cmd := `[Cli| - head_interval_model VIA runInductionHeadIntervalModel; - "DEPRECATED: build head-output interval bounds by reading a model binary directly." - FLAGS: - model : String; "Path to the NFP_BINARY_V1 model file." - layer : Nat; "Layer index for the induction head (1-based)." - head : Nat; "Head index for the induction head (1-based)." - "zero-based"; "Interpret --layer/--head as zero-based indices (legacy)." - period : Nat; "Optional prompt period override (default: derive from tokens)." - "prev-shift"; "Use shifted prev/active (prev = q - period + 1)." - "direction-target" : Nat; "Target token id for logit-diff direction." - "direction-negative" : Nat; "Negative token id for logit-diff direction." - out : String; "Optional path to write the residual-interval certificate." -] - -/-- Advanced induction-head subcommands (full flag surface). -/ -def inductionAdvancedCmd : Cmd := `[Cli| - advanced NOOP; - "Advanced induction-head utilities (full flag set)." 
- SUBCOMMANDS: - inductionCertifyCmd; - inductionCertifySoundCmd; - inductionCertifyEndToEndCmd; - inductionCertifyEndToEndMatrixCmd; - inductionCertifyEndToEndModelCmd; - inductionCertifyHeadCmd; - inductionCertifyHeadNonvacuousCmd; - inductionCertifyHeadModelCmd; - inductionCertifyHeadModelNonvacuousCmd; - inductionCertifyHeadModelAutoCmd; - inductionCertifyHeadModelAutoNonvacuousCmd; - inductionCertifyCircuitModelCmd; - inductionHeadIntervalCmd; - inductionHeadIntervalModelCmd; - inductionHeadCertCheckCmd -] - /-- Induction-head subcommands. -/ def inductionCmd : Cmd := `[Cli| induction NOOP; - "Induction-head utilities (streamlined). Use `nfp induction advanced --help` for full options." + "Induction-head utilities (streamlined)." SUBCOMMANDS: inductionCertifySimpleCmd; inductionCertifyNonvacuousSimpleCmd; - inductionIntervalSimpleCmd; - inductionAdvancedCmd + inductionHeadCertCheckCmd ] /-- The root CLI command. -/ diff --git a/Nfp/IO.lean b/Nfp/IO.lean index 47e3f56..75af425 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -2,14 +2,7 @@ module -public import Nfp.IO.Checks -public import Nfp.IO.Derive -public import Nfp.IO.HeadScore public import Nfp.IO.InductionHead -public import Nfp.IO.Loaders -public import Nfp.IO.NfptPure -public import Nfp.IO.Run -public import Nfp.IO.Timing public import Nfp.IO.Util /-! diff --git a/Nfp/IO/Derive.lean b/Nfp/IO/Derive.lean deleted file mode 100644 index 7158dc1..0000000 --- a/Nfp/IO/Derive.lean +++ /dev/null @@ -1,141 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Data.List.Range -public import Mathlib.Data.Matrix.Mul -public import Mathlib.Data.Vector.Defs -public import Nfp.IO.NfptPure -public import Nfp.IO.Timing -public import Nfp.Model.Gpt2 -public import Nfp.Sound.Bounds.Transformer -public import Nfp.Sound.Induction -public import Nfp.Sound.Induction.HeadBounds - -/-! -IO derivations that build certificates from model binaries. 
--/ - -public section - -namespace Nfp - -namespace IO - -open Nfp.Circuit - -/-- Build a residual-interval certificate from an on-disk model payload. -/ -def deriveResidualIntervalFromModel (data : ByteArray) (start : Nat) - (header : NfptPure.NfptHeader) (active? : Option (Finset (Fin header.seqLen))) : - IO (Except String (ResidualIntervalCert header.modelDim)) := do - if hseq : header.seqLen = 0 then - return Except.error "seq must be positive" - else - have _ : NeZero header.seqLen := ⟨hseq⟩ - if header.modelDim = 0 then - return Except.error "model dim must be positive" - else if 0 < header.layerNormEps then - let embedE ← timePure "read embeddings" (fun () => - NfptPure.readEmbeddings data start header) - match embedE with - | Except.error msg => return Except.error msg - | Except.ok embed => - let layerSlicesE ← timePure "read layer slices" (fun () => - NfptPure.readLayerSlices data start header) - match layerSlicesE with - | Except.error msg => return Except.error msg - | Except.ok layerSlices => - let headLayersE ← timePure "read layer heads" (fun () => - NfptPure.readLayerHeads data start header) - match headLayersE with - | Except.error msg => return Except.error msg - | Except.ok headLayers => - let finalLnE ← timePure "read final layer norm" (fun () => - NfptPure.readFinalLayerNorm data start header) - match finalLnE with - | Except.error msg => return Except.error msg - | Except.ok finalLn => - let layers : - Fin header.numLayers → - Model.Gpt2LayerSlice header.modelDim header.hiddenDim := - fun l => NfptPure.SizedArray.get layerSlices l - let heads : - Fin header.numLayers → Fin header.numHeads → - Model.Gpt2HeadWeights header.modelDim header.headDim := fun l h => - NfptPure.SizedArray.get (NfptPure.SizedArray.get headLayers l) h - let strict? ← IO.getEnv "NFP_TIMING_STRICT" - match strict? with - | some _ => - logTiming "timing strict enabled" - | none => - logTiming "timing strict disabled" - match active? 
with - | some active => - if hactive : active.Nonempty then - logTiming "before transformer stack bounds (active)" - let bounds ← timePhaseThunk "transformer stack bounds (active)" - (fun () => do - let bounds := Sound.Bounds.gpt2ResidualIntervalBoundsActive - active hactive header.layerNormEps layers heads finalLn embed - match strict? with - | some _ => - let forced := - (List.finRange header.modelDim).foldl - (fun acc i => acc + bounds.1 i + bounds.2 i) (0 : Rat) - logTiming s!"forced transformer stack sum {forced}" - | none => pure () - return bounds) - logTiming "after transformer stack bounds (active)" - return Except.ok { lo := bounds.1, hi := bounds.2 } - else - logTiming "active set empty; falling back to global bounds" - let base ← timePure "embedding interval bounds" (fun () => - Sound.Bounds.embeddingIntervalBounds embed) - logTiming "before transformer stack bounds" - let stack ← timePhaseThunk "transformer stack bounds" (fun () => do - let stack := Sound.Bounds.transformerStackBounds - (eps := header.layerNormEps) layers heads base.1 base.2 - match strict? 
with - | some _ => - let forced := - (List.finRange header.modelDim).foldl - (fun acc i => acc + stack.1 i + stack.2 i) (0 : Rat) - logTiming s!"forced transformer stack sum {forced}" - | none => pure () - return stack) - logTiming "after transformer stack bounds" - logTiming "enter final layer norm bounds" - let bounds ← timePure "final layer norm bounds" (fun () => - Sound.Bounds.layerNormIntervalBounds (eps := header.layerNormEps) - finalLn.gamma finalLn.beta stack.1 stack.2) - logTiming "exit final layer norm bounds" - return Except.ok { lo := bounds.1, hi := bounds.2 } - | none => - let base ← timePure "embedding interval bounds" (fun () => - Sound.Bounds.embeddingIntervalBounds embed) - logTiming "before transformer stack bounds" - let stack ← timePhaseThunk "transformer stack bounds" (fun () => do - let stack := Sound.Bounds.transformerStackBounds - (eps := header.layerNormEps) layers heads base.1 base.2 - match strict? with - | some _ => - let forced := - (List.finRange header.modelDim).foldl - (fun acc i => acc + stack.1 i + stack.2 i) (0 : Rat) - logTiming s!"forced transformer stack sum {forced}" - | none => pure () - return stack) - logTiming "after transformer stack bounds" - logTiming "enter final layer norm bounds" - let bounds ← timePure "final layer norm bounds" (fun () => - Sound.Bounds.layerNormIntervalBounds (eps := header.layerNormEps) - finalLn.gamma finalLn.beta stack.1 stack.2) - logTiming "exit final layer norm bounds" - return Except.ok { lo := bounds.1, hi := bounds.2 } - else - return Except.error - s!"layer norm epsilon {header.layerNormEps} must be positive" - -end IO - -end Nfp diff --git a/Nfp/IO/HeadScore.lean b/Nfp/IO/HeadScore.lean deleted file mode 100644 index 7bcec9a..0000000 --- a/Nfp/IO/HeadScore.lean +++ /dev/null @@ -1,60 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Core.Basic -public import Nfp.Sound.Linear.FinFold - -/-! 
-Pure helpers for building cached dot-abs functions for head scoring. --/ - -public section - -namespace Nfp - -namespace IO - -/-- Build a cached dot-abs function from Q/K absolute bounds using tasks. -/ -def dotAbsFromQKV {seq dHead : Nat} - (qAbs kAbs : Fin seq → Fin dHead → Rat) : Fin seq → Fin seq → Rat := - let rowTasks : Array (Task (Array Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - Array.ofFn (fun k : Fin seq => - Sound.Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)))) - let cache : Array (Array Rat) := - Array.ofFn (fun q : Fin seq => - (rowTasks[q.1]'(by - simp [rowTasks, q.isLt])).get) - fun q k => - let row := cache[q.1]'(by - simp [cache, q.isLt]) - row[k.1]'(by - have hrow : row.size = seq := by - simp [row, cache, rowTasks, Task.spawn] - simp [hrow, k.isLt]) - -private theorem dotAbsFromQKV_spec {seq dHead : Nat} - (qAbs kAbs : Fin seq → Fin dHead → Rat) : - dotAbsFromQKV qAbs kAbs = - let rowTasks : Array (Task (Array Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - Array.ofFn (fun k : Fin seq => - Sound.Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)))) - let cache : Array (Array Rat) := - Array.ofFn (fun q : Fin seq => - (rowTasks[q.1]'(by - simp [rowTasks, q.isLt])).get) - fun q k => - let row := cache[q.1]'(by - simp [cache, q.isLt]) - row[k.1]'(by - have hrow : row.size = seq := by - simp [row, cache, rowTasks, Task.spawn] - simp [hrow, k.isLt]) := rfl - -end IO - -end Nfp diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean index e718a73..29771dc 100644 --- a/Nfp/IO/InductionHead.lean +++ b/Nfp/IO/InductionHead.lean @@ -2,11 +2,8 @@ module -public import Nfp.IO.InductionHead.Basic public import Nfp.IO.InductionHead.Cert -public import Nfp.IO.InductionHead.Circuit -public import Nfp.IO.InductionHead.Nonvacuous /-! -IO helpers for induction-head certificate construction. +IO helpers for induction-head certificate checking. 
-/ diff --git a/Nfp/IO/InductionHead/Basic.lean b/Nfp/IO/InductionHead/Basic.lean deleted file mode 100644 index b5de6c2..0000000 --- a/Nfp/IO/InductionHead/Basic.lean +++ /dev/null @@ -1,1408 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Data.List.Range -public import Nfp.IO.Pure -public import Nfp.IO.NfptPure -public import Nfp.IO.HeadScore -public import Nfp.IO.Timing -public import Nfp.IO.Util -public import Nfp.Circuit.Cert.LogitDiff -public import Nfp.Circuit.Cert.ResidualInterval -public import Nfp.Sound.Induction -public import Nfp.Sound.Induction.HeadBounds -public import Nfp.Sound.Induction.LogitDiff -public import Nfp.Sound.Linear.FinFold - -/-! -IO helpers for induction-head certificate construction. --/ - -public section - -namespace Nfp - -namespace IO - -private def unwrapTaskResult {α : Type} (res : Except IO.Error α) : IO α := - match res with - | .ok a => pure a - | .error e => throw e - -/-- Configure timing output and heartbeat reporting. -/ -def configureTiming (timing? : Option Nat) (heartbeatMs? : Option Nat) : IO Unit := do - match timing? with - | some v => setTimingStdout (v ≠ 0) - | none => pure () - match heartbeatMs? with - | some v => - setTimingHeartbeatMs (UInt32.ofNat v) - if timing?.isNone && (v != 0) then - setTimingStdout true - | none => pure () - -/-- Translate CLI split-budget options into a split config. -/ -def splitConfigFromOptions - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
: Option Nat) : - Sound.InductionHeadSplitConfig := - let base := Sound.defaultInductionHeadSplitConfig - { base with - splitBudgetQ := splitBudgetQ?.getD base.splitBudgetQ - splitBudgetK := splitBudgetK?.getD base.splitBudgetK - splitBudgetDiffBase := splitBudgetDiffBase?.getD base.splitBudgetDiffBase - splitBudgetDiffRefined := splitBudgetDiffRefined?.getD base.splitBudgetDiffRefined } - -open Nfp.Circuit - -private def valueBoundsModeFromEnv : IO (Option Bool) := do - match (← IO.getEnv "NFP_VALUE_BOUNDS_MODE") with - | some "common" => return some true - | some "cached" => return some false - | _ => return none - -/-- Load induction head inputs from disk. -/ -def loadInductionHeadInputs (path : System.FilePath) : - IO (Except String (Sigma (fun seq => - Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead))))) := do - let t0 ← monoUsNow - let data ← IO.FS.readFile path - let t1 ← monoUsNow - timingPrint s!"timing: read head input file {t1 - t0} us" - let t2 ← monoUsNow - let parsed := - match Pure.parseInductionHeadInputs data with - | Except.error msg => Except.error msg - | Except.ok v => Except.ok v - let t3 ← monoUsNow - timingPrint s!"timing: parse head input file {t3 - t2} us" - return parsed - -/-- Render a rational for logging. -/ -def ratToString (x : Rat) : String := - toString x - -/-- Render an optional rational for logging. -/ -def ratOptToString (x : Option Rat) : String := - match x with - | some v => ratToString v - | none => "none" - -/-- Check whether logit-diff debug logging is enabled. -/ -def logitDiffDebugEnabled : IO Bool := do - return (← IO.getEnv "NFP_LOGITDIFF_DEBUG").isSome - -/-- Check whether logit-diff debug should exit early after dumping a witness. -/ -def logitDiffDebugEarlyExitEnabled : IO Bool := do - return (← IO.getEnv "NFP_LOGITDIFF_DEBUG_EARLY_EXIT").isSome - -/-- Check whether logit-diff refinement debug output is enabled. 
-/ -def logitDiffRefineEnabled : IO Bool := do - return (← IO.getEnv "NFP_LOGITDIFF_REFINE").isSome - -/-- Parse an optional query index for alternative logit-diff bound diagnostics. -/ -def logitDiffAltBoundQuery : IO (Option Nat) := do - match (← IO.getEnv "NFP_LOGITDIFF_ALT_BOUND_Q") with - | none => return none - | some txt => - match txt.toNat? with - | some n => return some n - | none => - IO.eprintln s!"warn: invalid NFP_LOGITDIFF_ALT_BOUND_Q={txt}" - return none - -/-- Parse an optional query index for q-only logit-diff diagnostics. -/ -def logitDiffQueryOnly : IO (Option Nat) := do - match (← IO.getEnv "NFP_LOGITDIFF_Q_ONLY") with - | none => return none - | some txt => - match txt.toNat? with - | some n => return some n - | none => - IO.eprintln s!"warn: invalid NFP_LOGITDIFF_Q_ONLY={txt}" - return none - -/-- Check whether q-only logit-diff diagnostics should include refined weight bounds. -/ -def logitDiffQueryOnlyRefineEnabled : IO Bool := do - return (← IO.getEnv "NFP_LOGITDIFF_Q_ONLY_REFINE").isSome - -/-- Check whether q-only logit-diff diagnostics should include refined value bounds. -/ -def logitDiffQueryOnlyValsEnabled : IO Bool := do - return (← IO.getEnv "NFP_LOGITDIFF_Q_ONLY_VALS").isSome - -/-- Check whether q-only logit-diff diagnostics should exit early. -/ -def logitDiffQueryOnlyEarlyExitEnabled : IO Bool := do - return (← IO.getEnv "NFP_LOGITDIFF_Q_ONLY_EARLY_EXIT").isSome - -private def renderResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) : String := - let header := s!"dim {n}" - let lines := - (List.finRange n).foldr (fun i acc => - s!"lo {i.val} {ratToString (c.lo i)}" :: - s!"hi {i.val} {ratToString (c.hi i)}" :: acc) [] - String.intercalate "\n" (header :: lines) - -private def emitResidualIntervalCert {n : Nat} (c : Circuit.ResidualIntervalCert n) - (outPath? : Option System.FilePath) : IO Unit := do - let payload := renderResidualIntervalCert c - match outPath? 
with - | some path => IO.FS.writeFile path (payload ++ "\n") - | none => IO.println payload - -/-- Emit q-only logit-diff diagnostics, returning whether early exit was requested. -/ -def emitLogitDiffQueryOnly {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cfg : Sound.InductionHeadSplitConfig) - (cache : Sound.InductionHeadCoreCache seq dModel dHead) - (cert : Sound.InductionHeadCert seq) - (logitCache : Sound.LogitDiffCache seq) : IO Bool := do - match (← logitDiffQueryOnly) with - | none => return false - | some qNat => - if hq : qNat < seq then - let q : Fin seq := ⟨qNat, hq⟩ - let prev := cert.prev q - let epsAt : Fin seq → Rat := logitCache.epsAt - let others : Finset (Fin seq) := - (Finset.univ : Finset (Fin seq)).erase prev - IO.eprintln - s!"debug: q-only q={qNat} prev={prev.1} \ - epsAt={ratToString (epsAt q)}" - if (← logitDiffQueryOnlyValsEnabled) then - let valsLo : Fin seq → Rat := logitCache.valsLo - let loAt : Rat := - if h : others.Nonempty then - others.inf' h valsLo - else - cert.values.lo - let valsPrevLo := valsLo prev - let delta := valsPrevLo - loAt - let gap := epsAt q * max (0 : Rat) delta - let lbAtQ := valsPrevLo - gap - IO.eprintln - s!"debug: q-only loAt={ratToString loAt} \ - valsPrevLo={ratToString valsPrevLo} \ - lbAtQ={ratToString lbAtQ}" - if (← logitDiffQueryOnlyRefineEnabled) then - let refineBudget := max 1 cfg.splitBudgetDiffRefined - let spec := Sound.refineSpecForQueryWithWeightOnes inputs cache q refineBudget - let weightBoundAt := Sound.weightBoundAtOverlay inputs cache spec - let epsAtRef := Sound.epsAtOverlay cache weightBoundAt q - IO.eprintln - s!"debug: q-only refined budget={refineBudget} \ - epsAt={ratToString epsAtRef}" - if (← logitDiffQueryOnlyValsEnabled) then - let valBudget := Sound.refineBudgetBoost refineBudget - let valKeys := Sound.loAtKeysAt inputs cache q - let valsLoRef : Fin seq → Rat := - Sound.valsLoOverlay inputs cache valBudget valKeys - let loAtRef : 
Rat := - if h : others.Nonempty then - others.inf' h valsLoRef - else - cert.values.lo - let valsPrevLoRef := valsLoRef prev - let deltaRef := valsPrevLoRef - loAtRef - let gapRef := epsAtRef * max (0 : Rat) deltaRef - let lbAtQRef := valsPrevLoRef - gapRef - IO.eprintln - s!"debug: q-only refined loAt={ratToString loAtRef} \ - valsPrevLo={ratToString valsPrevLoRef} \ - lbAtQ={ratToString lbAtQRef}" - let earlyExit := (← logitDiffQueryOnlyEarlyExitEnabled) || - (← logitDiffDebugEarlyExitEnabled) - if earlyExit then - IO.eprintln "debug: early exit requested (q-only)" - return true - return false - else - IO.eprintln s!"warn: NFP_LOGITDIFF_Q_ONLY={qNat} out of range (seq={seq})" - return false - -private def buildHeadOutputIntervalFromInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (outPath? : Option System.FilePath) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildHeadOutputIntervalFromHead? inputs with - | none => - IO.eprintln "error: head output interval rejected" - return 2 - | some result => - emitResidualIntervalCert result.cert outPath? 
- if outPath?.isSome then - let activeCount := result.active.card - IO.println - s!"ok: head output interval built (seq={seq}, dim={dModel}, active={activeCount})" - return 0 - -private def headScoreBoundsFromDotAbsTimed {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) : - IO (Sound.HeadScoreBounds seq dModel dHead) := do - timePure "head: score bounds" (fun () => - Sound.headScoreBoundsFromDotAbs inputs dotAbs) - -private def headScoreBoundsFromQAbsKAbsTimed {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Rat) - (dotAbs : Fin seq → Fin seq → Rat) : - IO (Sound.HeadScoreBounds seq dModel dHead) := do - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else scoreBaseAbs q k - let kAbsMax : Fin dHead → Rat := fun d => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d) - let dotAbsUpper : Fin seq → Rat := fun q => - Sound.Linear.dotFin dHead (fun d => qAbs q d) kAbsMax - let scoreHiUpper : Fin seq → Rat := fun q => - max inputs.maskValue (|inputs.scale| * dotAbsUpper q) - let marginTasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if q ∈ inputs.active then - let prev := inputs.prev q - let scoreLoPrev := scoreLo q prev - scoreLoPrev - scoreHiUpper q - else - (0 : Rat))) - let marginAt : Fin seq → Rat := fun q => - (marginTasks[q.1]'(by - simp [marginTasks, q.isLt])).get - let epsTasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - 
(marginTasks[q.1]'(by - simp [marginTasks, q.isLt])).map (fun m => - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m))) - let epsAt : Fin seq → Rat := fun q => - (epsTasks[q.1]'(by - simp [epsTasks, q.isLt])).get - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - let result : Sound.HeadScoreBounds seq dModel dHead := - { dotAbs := dotAbs - scoreBaseAbs := scoreBaseAbs - scoreAbs := fun q k => if masked q k then |inputs.maskValue| else scoreBaseAbs q k - scoreLo := scoreLo - scoreHi := scoreHi - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } - return result - -private def checkInductionHeadInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cfg : Sound.InductionHeadSplitConfig) - (minActive? : Option Nat) (minLogitDiff? : Option Rat) - (minMargin maxEps : Rat) (skipLogitDiff : Bool) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - logTiming "start: head build induction cert" - timingPrint "timing: head build induction cert start" - timingFlush - let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" - if verboseTiming.isSome then - timingPrint s!"timing: head dims seq={seq} dModel={dModel} dHead={dHead}" - timingPrint s!"timing: head active card={inputs.active.card}" - timingFlush - let precompute := (← IO.getEnv "NFP_TIMING_PRECOMPUTE").isSome - if precompute then - timingPrint "timing: head ln bounds start" - timingFlush - let lnBounds ← timePure "head: ln bounds" (fun () => - Sound.headLnBounds inputs) - timingPrint "timing: head ln bounds done" - timingFlush - timingPrint "timing: head qkv bounds start" - timingFlush - let lnLo := lnBounds.1 - let lnHi := lnBounds.2 - let qkv ← timePure "head: qkv bounds" (fun () 
=> - Sound.headQKVBounds inputs lnLo lnHi) - timingPrint "timing: head qkv bounds done" - timingFlush - if verboseTiming.isSome then - timingPrint "timing: head qkv abs force start" - timingFlush - let tAbs0 ← monoUsNow - for q in List.finRange seq do - for d in List.finRange dHead do - let _ := qkv.qAbs q d - let _ := qkv.kAbs q d - pure () - let tAbs1 ← monoUsNow - timingPrint s!"timing: head qkv abs force {tAbs1 - tAbs0} us" - timingFlush - timingPrint "timing: head score/value bounds spawn start" - timingFlush - let tSpawn0 ← monoUsNow - if verboseTiming.isSome then - timingPrint "timing: head score dotAbs tasks start" - timingFlush - let dotAbs ← timePure "head: score dotAbs tasks" (fun () => - dotAbsFromQKV qkv.qAbs qkv.kAbs) - if verboseTiming.isSome then - timingPrint "timing: head score dotAbs tasks done" - timingFlush - if verboseTiming.isSome then - timingPrint "timing: head score dotAbs force start" - timingFlush - let tForce0 ← monoUsNow - match List.finRange seq with - | [] => - timingPrint "timing: head score dotAbs force skipped (empty seq)" - | q :: _ => - match List.finRange seq with - | [] => - timingPrint "timing: head score dotAbs force skipped (empty seq)" - | k :: _ => - let _ := dotAbs q k - pure () - let tForce1 ← monoUsNow - timingPrint s!"timing: head score dotAbs force {tForce1 - tForce0} us" - timingFlush - let inlineVals := (← IO.getEnv "NFP_TIMING_VALUE_INLINE").isSome - let valueMode? ← valueBoundsModeFromEnv - let useCommon := valueMode?.getD false - let (valsInline?, valsTask?) 
:= - if inlineVals then - let vals := - if useCommon then - Sound.headValueBoundsCommonDen inputs qkv.vLo qkv.vHi - else - Sound.headValueBounds inputs qkv.vLo qkv.vHi - (some vals, none) - else - let task := - if useCommon then - Sound.headValueBoundsCommonDenTask inputs qkv.vLo qkv.vHi - else - Sound.headValueBoundsTask inputs qkv.vLo qkv.vHi - (none, some task) - let activeList := (List.finRange seq).filter (fun q => q ∈ inputs.active) - if verboseTiming.isSome then - timeHeadScoreMarginRaw inputs dotAbs activeList - let tSpawn1 ← monoUsNow - timingPrint s!"timing: head score/value bounds spawn {tSpawn1 - tSpawn0} us" - timingFlush - let skipScoreBounds := (← IO.getEnv "NFP_TIMING_SKIP_SCORE_BOUNDS").isSome - let scoreTaskOpt ← - if skipScoreBounds then - timingPrint "timing: head score bounds skipped" - pure none - else - timingPrint "timing: head score bounds from dotAbs start" - timingFlush - let exactMargin := (← IO.getEnv "NFP_TIMING_EXACT_MARGIN").isSome - let action := - if exactMargin then - headScoreBoundsFromDotAbsTimed inputs dotAbs - else - headScoreBoundsFromQAbsKAbsTimed inputs qkv.qAbs qkv.kAbs dotAbs - let t ← action.asTask - pure (some t) - if verboseTiming.isSome then - timingPrint "timing: head value parts start" - timingFlush - timingPrint "timing: head value dirHead start" - timingFlush - let tDir0 ← monoUsNow - let dirHead := Sound.headValueDirHead inputs - match List.finRange dHead with - | [] => - timingPrint "timing: head value dirHead forced skipped (empty dHead)" - | d :: _ => - let _ := dirHead d - pure () - let tDir1 ← monoUsNow - timingPrint s!"timing: head value dirHead {tDir1 - tDir0} us" - timingFlush - timingPrint "timing: head value valsLo start" - timingFlush - let tLo0 ← monoUsNow - let valsLo := Sound.headValueValsLo inputs qkv.vLo qkv.vHi - match List.finRange seq with - | [] => - timingPrint "timing: head value valsLo forced skipped (empty seq)" - | k :: _ => - let _ := valsLo k - pure () - let tLo1 ← monoUsNow - 
timingPrint s!"timing: head value valsLo {tLo1 - tLo0} us" - timingFlush - timingPrint "timing: head value valsHi start" - timingFlush - let tHi0 ← monoUsNow - let valsHi := Sound.headValueValsHi inputs qkv.vLo qkv.vHi - match List.finRange seq with - | [] => - timingPrint "timing: head value valsHi forced skipped (empty seq)" - | k :: _ => - let _ := valsHi k - pure () - let tHi1 ← monoUsNow - timingPrint s!"timing: head value valsHi {tHi1 - tHi0} us" - timingFlush - timingPrint "timing: head value lo start" - timingFlush - let tLo2 ← monoUsNow - let _ := Sound.headValueLo valsLo - let tLo3 ← monoUsNow - timingPrint s!"timing: head value lo {tLo3 - tLo2} us" - timingFlush - timingPrint "timing: head value hi start" - timingFlush - let tHi2 ← monoUsNow - let _ := Sound.headValueHi valsHi - let tHi3 ← monoUsNow - timingPrint s!"timing: head value hi {tHi3 - tHi2} us" - timingFlush - timingPrint "timing: head value parts done" - timingFlush - timingPrint "timing: head value bounds start" - timingFlush - let tVals0 ← monoUsNow - let vals ← - match valsInline?, valsTask? 
with - | some vals, _ => - timePure "head: value bounds inline" (fun () => vals) - | none, some valsTask => - timePure "head: value bounds wait" (fun () => valsTask.get) - | none, none => - timePure "head: value bounds inline" (fun () => - Sound.headValueBounds inputs qkv.vLo qkv.vHi) - let tVals1 ← monoUsNow - timingPrint s!"timing: head value bounds {tVals1 - tVals0} us" - timingFlush - let scoreOpt ← - match scoreTaskOpt with - | none => pure none - | some scoreTask => do - let res ← IO.wait scoreTask - let score ← unwrapTaskResult res - timingPrint "timing: head score bounds from dotAbs done" - timingFlush - pure (some score) - match scoreOpt with - | none => pure () - | some score => - if verboseTiming.isSome then - timeHeadScoreSampleGap inputs score - if verboseTiming.isSome then - timeHeadScoreMarginList activeList score - if verboseTiming.isSome then - timeHeadScoreFieldForces score - if verboseTiming.isSome then - timingPrint "timing: head score bounds force start" - timingFlush - let tScore0 ← monoUsNow - let _ := score.margin - let _ := score.eps - let tScore1 ← monoUsNow - timingPrint s!"timing: head score bounds force {tScore1 - tScore0} us" - timingFlush - let coreStages := (← IO.getEnv "NFP_TIMING_CORE_STAGES").isSome - let coreStagesOnly := (← IO.getEnv "NFP_TIMING_CORE_STAGES_ONLY").isSome - if coreStages then - timeInductionHeadCoreStages inputs - if coreStagesOnly then - return 0 - let breakdown := (← IO.getEnv "NFP_TIMING_BREAKDOWN").isSome - if breakdown then - let lnBounds ← timePureWithHeartbeat "breakdown: ln bounds" (fun () => - Sound.headLnBounds inputs) - timingPrint "timing: breakdown ln bounds force start" - timingFlush - let tLn0 ← monoUsNow - for q in List.finRange seq do - for i in List.finRange dModel do - let _ := lnBounds.1 q i - let _ := lnBounds.2 q i - pure () - let tLn1 ← monoUsNow - timingPrint s!"timing: breakdown ln bounds force {tLn1 - tLn0} us" - timingFlush - let qkv ← timePureWithHeartbeat "breakdown: qkv bounds" (fun 
() => - Sound.headQKVBounds inputs lnBounds.1 lnBounds.2) - timingPrint "timing: breakdown qkv bounds force start" - timingFlush - let tQkv0 ← monoUsNow - for q in List.finRange seq do - for d in List.finRange dHead do - let _ := qkv.qLo q d - let _ := qkv.qHi q d - let _ := qkv.kLo q d - let _ := qkv.kHi q d - let _ := qkv.vLo q d - let _ := qkv.vHi q d - let _ := qkv.qAbs q d - let _ := qkv.kAbs q d - pure () - let tQkv1 ← monoUsNow - timingPrint s!"timing: breakdown qkv bounds force {tQkv1 - tQkv0} us" - timingFlush - let dotAbs : Fin seq → Fin seq → Rat := fun q k => - Sound.Linear.dotFin dHead (fun d => qkv.qAbs q d) (fun d => qkv.kAbs k d) - let dotAbsRowTasks : - Array (Task { row : Array Rat // row.size = seq }) ← - timePureWithHeartbeat "breakdown: score dotAbs rows" (fun () => - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩))) - let dotAbsRowDefault : Task { row : Array Rat // row.size = seq } := - Task.spawn (fun _ => ⟨Array.ofFn (fun _ : Fin seq => (0 : Rat)), by simp⟩) - timingPrint "timing: breakdown score dotAbs force start" - timingFlush - let tDot0 ← monoUsNow - for q in List.finRange seq do - let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get - let _ := row - pure () - let tDot1 ← monoUsNow - timingPrint s!"timing: breakdown score dotAbs force {tDot1 - tDot0} us" - timingFlush - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scaleAbs : Rat := |inputs.scale| - let marginAtRaw : Fin seq → Rat := fun q => - let row := (dotAbsRowTasks.getD q.1 dotAbsRowDefault).get - if q ∈ inputs.active then - let rowArr := row.1 - let prev := inputs.prev q - let dotAbsPrev := rowArr.getD prev.1 0 - if masked q prev then - let scoreLoPrev := inputs.maskValue - let scoreHiAt : Fin seq → Rat := fun k => - if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × 
Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let v := scoreLoPrev - scoreHiAt k - match acc.1 with - | none => (some v, acc.2) - | some cur => (some (min cur v), acc.2) - let acc := Sound.Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min unmaskedMin maskedGap - | some unmaskedMin, false => unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - let scoreLoPrev := -(scaleAbs * dotAbsPrev) - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let raw := -(dotAbsPrev + rowArr.getD k.1 0) - match acc.1 with - | none => (some raw, acc.2) - | some cur => (some (min cur raw), acc.2) - let acc := Sound.Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap - | some unmaskedMin, false => scaleAbs * unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - (0 : Rat) - let marginAtCached : Fin seq → Rat ← - timePureWithHeartbeat "breakdown: score margin cache" (fun () => - Sound.Bounds.cacheBoundThunk marginAtRaw) - timingPrint "timing: breakdown score margin force start" - timingFlush - let tMargin0 ← monoUsNow - for q in List.finRange seq do - let m := marginAtCached q - forceRat m - pure () - let tMargin1 ← monoUsNow - timingPrint s!"timing: breakdown score margin force {tMargin1 - tMargin0} us" - timingFlush - let epsAtRaw : Fin seq → Rat := fun q => - let m := marginAtCached q - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m) - let epsAtCached : Fin seq → Rat ← - timePureWithHeartbeat "breakdown: score eps cache" (fun () => - Sound.Bounds.cacheBoundThunk epsAtRaw) - timingPrint "timing: breakdown score eps force start" - timingFlush - let tEps0 ← 
monoUsNow - for q in List.finRange seq do - let e := epsAtCached q - forceRat e - pure () - let tEps1 ← monoUsNow - timingPrint s!"timing: breakdown score eps force {tEps1 - tEps0} us" - timingFlush - let valsLo ← timePureWithHeartbeat "breakdown: value valsLo" (fun () => - Sound.headValueValsLo inputs qkv.vLo qkv.vHi) - timingPrint "timing: breakdown value valsLo force start" - timingFlush - let tValsLo0 ← monoUsNow - for k in List.finRange seq do - let v := valsLo k - forceRat v - pure () - let tValsLo1 ← monoUsNow - timingPrint s!"timing: breakdown value valsLo force {tValsLo1 - tValsLo0} us" - timingFlush - let valsHi ← timePureWithHeartbeat "breakdown: value valsHi" (fun () => - Sound.headValueValsHi inputs qkv.vLo qkv.vHi) - timingPrint "timing: breakdown value valsHi force start" - timingFlush - let tValsHi0 ← monoUsNow - for k in List.finRange seq do - let v := valsHi k - forceRat v - pure () - let tValsHi1 ← monoUsNow - timingPrint s!"timing: breakdown value valsHi force {tValsHi1 - tValsHi0} us" - timingFlush - let heartbeatMsProgress ← heartbeatMs - let taskMin (t1 t2 : Task Rat) : Task Rat := - Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) - let taskMax (t1 t2 : Task Rat) : Task Rat := - Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) - let reduceMinTasksWithProgress (tasks : Array (Task Rat)) : - IO Rat := do - let n := tasks.size - if n = 0 then - pure (0 : Rat) - else - let chunkSize : Nat := 16 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := tasks.getD start defaultTask - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => taskMin acc (tasks.getD i defaultTask)) init) - if heartbeatMsProgress ≠ 0 then - let mut 
finished := 0 - let mut remaining := chunkTasks.size - while finished < remaining do - IO.sleep heartbeatMsProgress - let mut count := 0 - for t in chunkTasks do - if (← IO.hasFinished t) then - count := count + 1 - finished := count - remaining := chunkTasks.size - if finished < remaining then - timingPrint s!"timing: breakdown value lo progress {finished}/{remaining}" - timingFlush - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - pure ((rest.foldl (fun acc i => taskMin acc (chunkTasks.getD i defaultTask)) init).get) - let reduceMaxTasksWithProgress (tasks : Array (Task Rat)) : - IO Rat := do - let n := tasks.size - if n = 0 then - pure (0 : Rat) - else - let chunkSize : Nat := 16 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := tasks.getD start defaultTask - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => taskMax acc (tasks.getD i defaultTask)) init) - if heartbeatMsProgress ≠ 0 then - let mut finished := 0 - let mut remaining := chunkTasks.size - while finished < remaining do - IO.sleep heartbeatMsProgress - let mut count := 0 - for t in chunkTasks do - if (← IO.hasFinished t) then - count := count + 1 - finished := count - remaining := chunkTasks.size - if finished < remaining then - timingPrint s!"timing: breakdown value hi progress {finished}/{remaining}" - timingFlush - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - pure ((rest.foldl (fun acc i => taskMax acc (chunkTasks.getD i defaultTask)) init).get) - if (← IO.getEnv "NFP_TIMING_TASK_PROGRESS").isSome then - let tasksLo := - (List.finRange seq).map (fun k => 
Task.spawn (fun _ => valsLo k)) - let tasksHi := - (List.finRange seq).map (fun k => Task.spawn (fun _ => valsHi k)) - let _ ← timePureWithHeartbeat "breakdown: value lo progress" (fun () => - reduceMinTasksWithProgress tasksLo.toArray) - let _ ← timePureWithHeartbeat "breakdown: value hi progress" (fun () => - reduceMaxTasksWithProgress tasksHi.toArray) - else - let loTask := Sound.headValueLoTask valsLo - let hiTask := Sound.headValueHiTask valsHi - let heartbeatMs ← heartbeatMs - let tLo0 ← monoUsNow - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished loTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished loTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: breakdown: value lo running {now - tLo0} us" - timingFlush - let lo := loTask.get - let tLo1 ← monoUsNow - timingPrint s!"timing: breakdown: value lo {tLo1 - tLo0} us" - timingFlush - let tHi0 ← monoUsNow - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished hiTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished hiTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: breakdown: value hi running {now - tHi0} us" - timingFlush - let hi := hiTask.get - let tHi1 ← monoUsNow - timingPrint s!"timing: breakdown: value hi {tHi1 - tHi0} us" - timingFlush - let _ := lo - let _ := hi - if (← IO.getEnv "NFP_TIMING_SEQ_REDUCE").isSome then - let loSeq ← timePureWithHeartbeat "breakdown: value lo seq" (fun () => - match List.finRange seq with - | [] => (0 : Rat) - | k :: ks => - let init := valsLo k - ks.foldl (fun acc k => min acc (valsLo k)) init) - let hiSeq ← timePureWithHeartbeat "breakdown: value hi seq" (fun () => - match List.finRange seq with - | [] => (0 : Rat) - | k :: ks => - let init := valsHi k - ks.foldl (fun acc k => max acc (valsHi k)) init) - let _ := loSeq - let _ := hiSeq - let tCert0 ← monoUsNow - let certTask : - Task - (Option { cache : Sound.InductionHeadCoreCache seq dModel 
dHead // - Sound.InductionHeadCertSound inputs cache.cert }) := - Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHeadWithCache? cfg inputs with - | none => none - | some ⟨cache, hcert⟩ => - let _ := cache.cert.active.card - some ⟨cache, hcert⟩) - let heartbeatMs ← heartbeatMs - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished certTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished certTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: head build induction cert running {now - tCert0} us" - timingFlush - let certOpt ← IO.wait certTask - let tCert1 ← monoUsNow - logTiming s!"done: head build induction cert {tCert1 - tCert0} us" - timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" - timingPrint "timing: head build induction cert returned" - timingFlush - match certOpt with - | none => - IO.eprintln "error: head inputs rejected" - return 2 - | some ⟨cache, _hcert⟩ => - let cert := cache.cert - timingPrint "timing: head active count start" - timingFlush - let activeCount := cert.active.card - timingPrint "timing: head active count done" - timingFlush - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {ratToString cert.margin} \ - below minimum {ratToString minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {ratToString cert.eps} \ - above maximum {ratToString maxEps}" - return 2 - if skipLogitDiff then - IO.println - s!"ok: induction head certificate built (seq={seq}, active={activeCount}, \ - margin={ratToString cert.margin}, eps={ratToString cert.eps}, \ - note=logit-diff skipped)" - return 0 - timingPrint "timing: head tol start" - timingFlush - let tol := cert.eps * 
(cert.values.hi - cert.values.lo) - timingPrint "timing: head tol done" - timingFlush - let effectiveMinLogitDiff := - match minLogitDiff? with - | some v => some v - | none => some (0 : Rat) - let logitCache := Nfp.Sound.logitDiffCache cert - let qOnlyExit ← - emitLogitDiffQueryOnly inputs cfg cache cert logitCache - if qOnlyExit then - return 2 - let emitLogitDiffDebug (info : Nfp.Sound.LogitDiffAtLoDebug seq) : IO Unit := do - IO.eprintln - s!"debug: logitDiffLB0 witness q={info.q.1}, prev={info.prev.1}" - IO.eprintln - s!"debug: eps={ratToString info.eps}, \ - valsPrevLo={ratToString info.valsPrevLo}, \ - loAt={ratToString info.loAt}, \ - lo={ratToString info.lo}" - IO.eprintln - s!"debug: valsPrevLoMinusLoAt={ratToString info.valsPrevLoMinusLoAt}, \ - gap={ratToString info.gap}, \ - fAtQ={ratToString (info.valsPrevLo - info.gap)}, \ - lbAtQ={ratToString info.lbAtQ}" - let weightBoundAt := cert.weightBoundAt - let step : (Rat × Nat × Rat) → Fin seq → (Rat × Nat × Rat) := - fun acc k => - if k = info.prev then - acc - else - let w := weightBoundAt info.q k - let sum := acc.1 + w - let ones := if w = (1 : Rat) then acc.2.1 + 1 else acc.2.1 - let maxW := if w > acc.2.2 then w else acc.2.2 - (sum, ones, maxW) - let acc := Sound.Linear.foldlFin seq step (0, 0, 0) - IO.eprintln - s!"debug: epsAt={ratToString (cert.epsAt info.q)}, \ - weightSum={ratToString acc.1}, ones={acc.2.1}, \ - maxWeight={ratToString acc.2.2}" - let valsLo := logitCache.valsLo - let stepOnes : Array String → Fin seq → Array String := - fun acc k => - if k = info.prev then - acc - else - let w := weightBoundAt info.q k - if w = (1 : Rat) then - acc.push - s!"k={k.1} valsLo={ratToString (valsLo k)}" - else - acc - let ones := Sound.Linear.foldlFin seq stepOnes #[] - let onesMsg := - if ones.isEmpty then - "none" - else - String.intercalate ", " ones.toList - IO.eprintln s!"debug: weightBoundAt=1 keys: {onesMsg}" - let stepLoAt : Array String → Fin seq → Array String := - fun acc k => - if k = 
info.prev then - acc - else if valsLo k = info.loAt then - acc.push - s!"k={k.1} w={ratToString (weightBoundAt info.q k)}" - else - acc - let loAtKeys := Sound.Linear.foldlFin seq stepLoAt #[] - let loAtMsg := - if loAtKeys.isEmpty then - "none" - else - String.intercalate ", " loAtKeys.toList - IO.eprintln s!"debug: loAt keys: {loAtMsg}" - let scoreLoPrev := cache.scoreLoPrev info.q - let stepAlt : - (Rat × Nat × Nat) → Fin seq → (Rat × Nat × Nat) := - fun acc k => - if k = info.prev then - acc - else - let g := scoreLoPrev - cache.scoreHi info.q k - let nonneg := if g ≥ (0 : Rat) then acc.2.1 + 1 else acc.2.1 - let gtNegOne := if g > (-1 : Rat) then acc.2.2 + 1 else acc.2.2 - let expLB := - if g ≥ (0 : Rat) then - (1 : Rat) + g + g * g / (2 : Rat) - else - max (0 : Rat) ((1 : Rat) + g) - let w := (1 : Rat) / ((1 : Rat) + expLB) - (acc.1 + w, nonneg, gtNegOne) - let accAlt := Sound.Linear.foldlFin seq stepAlt (0, 0, 0) - IO.eprintln - s!"debug: alt-exp epsAt={ratToString accAlt.1}, \ - g>=0={accAlt.2.1}, g>-1={accAlt.2.2}" - let stepMin : Option Rat → Fin seq → Option Rat := - fun acc k => - if k = info.prev then - acc - else - let g := scoreLoPrev - cache.scoreHi info.q k - match acc with - | none => some g - | some cur => some (min cur g) - let minGap := Sound.Linear.foldlFin seq stepMin none - IO.eprintln s!"debug: alt-exp min(scoreLoPrev-scoreHi)={ratOptToString minGap}" - if (← logitDiffRefineEnabled) then - let refineBudget := max 1 cfg.splitBudgetDiffRefined - let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q refineBudget - IO.eprintln - s!"debug: refine budget={refineBudget}, \ - refineKeys.card={refineKeys.card}" - let refineSpec := - Sound.refineSpecForQueryWithWeightOnes inputs cache info.q refineBudget - let refinedLB? := - Sound.logitDiffLowerBoundRefinedFromCache - inputs cache cert logitCache refineSpec - match refinedLB? 
with - | none => - IO.eprintln "debug: refined logitDiffLB0 none" - | some lb => - IO.eprintln - s!"debug: refined logitDiffLB0={ratToString lb}" - logTiming "start: head logit-diff lower bound" - timingPrint "timing: head logit-diff lower bound start" - timingFlush - profileLogitDiffWeighted cert logitCache - let altQuery? ← logitDiffAltBoundQuery - match altQuery? with - | none => pure () - | some qNat => - if hq : qNat < seq then - let q : Fin seq := ⟨qNat, hq⟩ - let prev := cert.prev q - let scoreLoPrev := cache.scoreLoPrev q - let stepAlt : - (Rat × Nat × Nat) → Fin seq → (Rat × Nat × Nat) := - fun acc k => - if k = prev then - acc - else - let g := scoreLoPrev - cache.scoreHi q k - let nonneg := if g ≥ (0 : Rat) then acc.2.1 + 1 else acc.2.1 - let gtNegOne := if g > (-1 : Rat) then acc.2.2 + 1 else acc.2.2 - let expLB := - if g ≥ (0 : Rat) then - (1 : Rat) + g + g * g / (2 : Rat) - else - max (0 : Rat) ((1 : Rat) + g) - let w := (1 : Rat) / ((1 : Rat) + expLB) - (acc.1 + w, nonneg, gtNegOne) - let accAlt := Sound.Linear.foldlFin seq stepAlt (0, 0, 0) - let stepMin : Option Rat → Fin seq → Option Rat := - fun acc k => - if k = prev then - acc - else - let g := scoreLoPrev - cache.scoreHi q k - match acc with - | none => some g - | some cur => some (min cur g) - let minGap := Sound.Linear.foldlFin seq stepMin none - IO.eprintln - s!"debug: alt-exp q={qNat} prev={prev.1} \ - epsAt={ratToString accAlt.1} \ - g>=0={accAlt.2.1} g>-1={accAlt.2.2} \ - minGap={ratOptToString minGap}" - if (← logitDiffDebugEarlyExitEnabled) then - IO.eprintln "debug: early exit requested (alt bound)" - return 2 - else - IO.eprintln - s!"warn: NFP_LOGITDIFF_ALT_BOUND_Q={qNat} out of range (seq={seq})" - let earlyExit? ← - if (← logitDiffDebugEnabled) && (← logitDiffDebugEarlyExitEnabled) then - let debug? ← timePureWithHeartbeat - "head: logit-diff lower bound debug" (fun () => - Nfp.Sound.logitDiffLowerBoundAtLoDebug cert logitCache) - match debug? 
with - | none => - IO.eprintln "debug: logitDiffLB0 witness not found" - | some ⟨info, _⟩ => - emitLogitDiffDebug info - IO.eprintln "debug: early exit requested" - pure (some ()) - else - pure none - match earlyExit? with - | some _ => return 2 - | none => pure () - let weightedTask? : Option (Task (Option Rat)) := none - let logitDiffLB0? ← timePureWithHeartbeat - "head: logit-diff lower bound unweighted" (fun () => - Nfp.Sound.logitDiffLowerBoundRefineOnDemand inputs cache cert logitCache) - if (← logitDiffDebugEnabled) then - match logitDiffLB0? with - | some lb0 => - if lb0 ≤ 0 then - match Nfp.Sound.logitDiffLowerBoundAtLoDebug cert logitCache with - | none => - IO.eprintln "debug: logitDiffLB0 witness not found" - | some ⟨info, _⟩ => - emitLogitDiffDebug info - | none => pure () - let needsWeighted : Bool := - match logitDiffLB0? with - | none => true - | some lb0 => - if lb0 ≤ 0 then - true - else - match minLogitDiff? with - | some minLogitDiff => lb0 < minLogitDiff - | none => false - let logitDiffWeighted? ← - if needsWeighted then - match weightedTask? with - | some task => - timePureWithHeartbeat - "head: logit-diff lower bound weighted" (fun () => - task.get) - | none => - timePureWithHeartbeat - "head: logit-diff lower bound weighted" (fun () => - Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) - else - pure none - let logitDiffLB? : Option Rat := - match logitDiffLB0?, logitDiffWeighted? with - | some lb0, some lb1 => some (max lb0 lb1) - | some lb0, none => some lb0 - | none, some lb1 => some lb1 - | none, none => none - let boundLabel : String := - match logitDiffLB0?, logitDiffWeighted? with - | some _, some _ => "max" - | none, some _ => "weighted" - | some _, none => "eps" - | none, none => "none" - logTiming "done: head logit-diff lower bound" - match logitDiffLB? 
with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - if logitDiffLB ≤ 0 then - if (← logitDiffDebugEnabled) then - IO.eprintln - s!"debug: logitDiffLB0={ratOptToString logitDiffLB0?}, \ - logitDiffWeighted={ratOptToString logitDiffWeighted?}, \ - logitDiffLB={ratToString logitDiffLB}, \ - bound={boundLabel}" - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - is not strictly positive" - return 2 - let violation? : Option Rat := - match minLogitDiff? with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - below minimum {ratToString minLogitDiff}" - return 2 - | none => pure () - let tol := cert.eps * (cert.values.hi - cert.values.lo) - IO.println - s!"ok: nonvacuous induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB}, \ - bound={boundLabel})" - return 0 -/-- Build and check induction certificates from exact head inputs. -/ -def runInductionCertifyHead (inputsPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) - (skipLogitDiff : Bool) : - IO UInt32 := do - warnDeprecated - "certify_head builds certificates from head inputs; use explicit certs \ - via `nfp induction certify --cert` or `nfp induction head_cert_check`." - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? 
- let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedInputs ← timePhase "load head inputs" <| - loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps - skipLogitDiff - -/-- Build and check induction certificates from a model binary. -/ -def runInductionCertifyHeadModel (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) - (skipLogitDiff : Bool) : - IO UInt32 := do - warnDeprecated - "certify_head_model builds certificates from a model file; use explicit certs \ - via `nfp induction certify --cert` or `nfp induction head_cert_check`." - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? 
- match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? shiftPrev) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? minMargin maxEps - skipLogitDiff - -/-- Heuristic logit-diff direction derived from prompt tokens. -/ -def deriveDirectionFromTokens {seq : Nat} (tokens : Fin seq → Nat) : - Except String (Nat × Nat) := do - let tokenArr : Array Nat := Array.ofFn (fun i : Fin seq => tokens i) - let n := tokenArr.size - if n < 2 then - throw "token sequence must have length at least 2" - let lastTok := tokenArr.getD (n - 1) 0 - let prevIdx? := - (List.range (n - 1)).reverse.find? (fun i => - tokenArr.getD i lastTok = lastTok) - let targetTok := - match prevIdx? 
with - | some i => tokenArr.getD (i + 1) lastTok - | none => lastTok - let neg0 := tokenArr.getD (n - 2) lastTok - let neg := - if neg0 = targetTok then - if lastTok ≠ targetTok then - lastTok - else if targetTok ≠ 0 then - 0 - else - 1 - else - neg0 - return (targetTok, neg) - -/-- Build and check induction certificates from a model binary, deriving direction tokens from the -prompt sequence. -/ -def runInductionCertifyHeadModelAuto (modelPath : System.FilePath) - (layer head : Nat) (period? : Option Nat) (shiftPrev : Bool) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) - (skipLogitDiff : Bool) : - IO UInt32 := do - warnDeprecated - "certify_head_model_auto builds certificates from a model file; use explicit certs \ - via `nfp induction certify --cert` or `nfp induction head_cert_check`." - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let tokensE ← timePure "read prompt tokens" (fun () => - NfptPure.readTokens data start header) - match tokensE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok tokens => - match deriveDirectionFromTokens tokens with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨dirTarget, dirNegative⟩ => - IO.println - s!"info: direction-target={dirTarget} direction-negative={dirNegative}" - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? shiftPrev) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputs inputs splitCfg minActive? minLogitDiff? - minMargin maxEps skipLogitDiff - -/-- Build head-output interval bounds from exact head inputs. -/ -def runInductionHeadInterval (inputsPath : System.FilePath) - (outPath? : Option System.FilePath) : IO UInt32 := do - warnDeprecated - "head_interval builds interval bounds from head inputs; use explicit interval certs instead." - let parsedInputs ← loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - buildHeadOutputIntervalFromInputs inputs outPath? - -/-- Build head-output interval bounds from a model binary. 
-/ -def runInductionHeadIntervalModel (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) - (outPath? : Option System.FilePath) : IO UInt32 := do - warnDeprecated - "head_interval_model builds interval bounds from a model file; \ - use explicit interval certs instead." - let data ← IO.FS.readBinFile modelPath - match NfptPure.parseHeader data with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - match - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? shiftPrev - with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - buildHeadOutputIntervalFromInputs inputs outPath? - -end IO - -end Nfp diff --git a/Nfp/IO/InductionHead/Cert.lean b/Nfp/IO/InductionHead/Cert.lean index 36dc961..317d520 100644 --- a/Nfp/IO/InductionHead/Cert.lean +++ b/Nfp/IO/InductionHead/Cert.lean @@ -285,7 +285,7 @@ private def finalizeState {seq : Nat} (hpos : 0 < seq) (st : ParseState seq) : st.active else (Finset.univ : Finset (Fin seq)).erase defaultPrev - return + pure { eps := eps epsAt := epsAtFun weightBoundAt := weightBoundAtFun @@ -417,4 +417,3 @@ def runInductionHeadCertCheck (certPath : System.FilePath) end IO end Nfp - diff --git a/Nfp/IO/InductionHead/Circuit.lean b/Nfp/IO/InductionHead/Circuit.lean deleted file mode 100644 index ef6b580..0000000 --- a/Nfp/IO/InductionHead/Circuit.lean +++ /dev/null @@ -1,83 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.IO.InductionHead.Basic - -/-! -IO helpers for induction-circuit checks (previous-token head + induction head). --/ - -public section - -namespace Nfp - -namespace IO - -/-- Check a two-head induction circuit directly from a model binary. - -The induction head is certified with shifted `prev` (canonical circuit), while -the previous-token head uses the unshifted period-1 map. 
--/ -def runInductionCertifyCircuitModel (modelPath : System.FilePath) - (prevLayer prevHead indLayer indHead dirTarget dirNegative : Nat) (period : Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) - (skipLogitDiff : Bool) : - IO UInt32 := do - warnDeprecated - "certify_circuit_model builds certificates from a model file; \ - use explicit certs for each head instead." - let prevCode ← - runInductionCertifyHeadModel - modelPath - prevLayer - prevHead - dirTarget - dirNegative - (some 1) - false - minActive? - minLogitDiffStr? - minMarginStr? - maxEpsStr? - timing? - heartbeatMs? - splitBudgetQ? - splitBudgetK? - splitBudgetDiffBase? - splitBudgetDiffRefined? - skipLogitDiff - if prevCode ≠ 0 then - return prevCode - let indCode ← - runInductionCertifyHeadModel - modelPath - indLayer - indHead - dirTarget - dirNegative - (some period) - true - minActive? - minLogitDiffStr? - minMarginStr? - maxEpsStr? - timing? - heartbeatMs? - splitBudgetQ? - splitBudgetK? - splitBudgetDiffBase? - splitBudgetDiffRefined? - skipLogitDiff - if indCode ≠ 0 then - return indCode - IO.println - "ok: circuit head certificates built (prev-token head + shifted-prev induction head)" - return 0 - -end IO - -end Nfp diff --git a/Nfp/IO/InductionHead/Nonvacuous.lean b/Nfp/IO/InductionHead/Nonvacuous.lean deleted file mode 100644 index ba74348..0000000 --- a/Nfp/IO/InductionHead/Nonvacuous.lean +++ /dev/null @@ -1,531 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -import Nfp.IO.InductionHead.Basic - -/-! -IO helpers for nonvacuous induction-head certificate checks. --/ - -public section - -namespace Nfp - -namespace IO - -/-- Build and check induction certificates from exact head inputs. 
-/ -private def checkInductionHeadInputsNonvacuous {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cfg : Sound.InductionHeadSplitConfig) - (minActive? : Option Nat) (minLogitDiff? : Option Rat) - (minMargin? : Option Rat) (maxEps : Rat) : IO UInt32 := do - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - logTiming "start: head build induction cert" - timingPrint "timing: head build induction cert start" - timingFlush - let tCert0 ← monoUsNow - let certTask : - Task - (Option { cache : Sound.InductionHeadCoreCache seq dModel dHead // - Sound.InductionHeadCertSound inputs cache.cert }) := - Task.spawn (prio := Task.Priority.dedicated) (fun _ => - match Sound.buildInductionCertFromHeadWithCache? cfg inputs with - | none => none - | some ⟨cache, hcert⟩ => - let _ := cache.cert.active.card - some ⟨cache, hcert⟩) - let heartbeatMs ← heartbeatMs - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished certTask) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished certTask) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: head build induction cert running {now - tCert0} us" - timingFlush - let certOpt ← IO.wait certTask - let tCert1 ← monoUsNow - logTiming s!"done: head build induction cert {tCert1 - tCert0} us" - timingPrint s!"timing: head build induction cert {tCert1 - tCert0} us" - timingPrint "timing: head build induction cert returned" - timingFlush - match certOpt with - | none => - IO.eprintln "error: head inputs rejected" - return 2 - | some ⟨cache, _hcert⟩ => - let cert := cache.cert - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if maxEps < cert.eps then - 
IO.eprintln - s!"error: eps {ratToString cert.eps} above maximum {ratToString maxEps}" - return 2 - let marginViolation? : Option Rat := - match minMargin? with - | none => none - | some minMargin => - if cert.margin < minMargin then - some minMargin - else - none - match marginViolation? with - | some minMargin => - IO.eprintln - s!"error: margin {ratToString cert.margin} \ - below minimum {ratToString minMargin}" - return 2 - | none => pure () - logTiming "start: head logit-diff lower bound" - timingPrint "timing: head logit-diff lower bound start" - timingFlush - let logitCache := Nfp.Sound.logitDiffCache cert - let qOnlyExit ← - emitLogitDiffQueryOnly inputs cfg cache cert logitCache - if qOnlyExit then - return 2 - let emitLogitDiffDebug (info : Nfp.Sound.LogitDiffAtLoDebug seq) : IO Unit := do - IO.eprintln - s!"debug: logitDiffLB0 witness q={info.q.1}, prev={info.prev.1}" - IO.eprintln - s!"debug: eps={ratToString info.eps}, \ - valsPrevLo={ratToString info.valsPrevLo}, \ - loAt={ratToString info.loAt}, \ - lo={ratToString info.lo}" - IO.eprintln - s!"debug: valsPrevLoMinusLoAt={ratToString info.valsPrevLoMinusLoAt}, \ - gap={ratToString info.gap}, \ - fAtQ={ratToString (info.valsPrevLo - info.gap)}, \ - lbAtQ={ratToString info.lbAtQ}" - let weightBoundAt := cert.weightBoundAt - let step : (Rat × Nat × Rat) → Fin seq → (Rat × Nat × Rat) := - fun acc k => - if k = info.prev then - acc - else - let w := weightBoundAt info.q k - let sum := acc.1 + w - let ones := if w = (1 : Rat) then acc.2.1 + 1 else acc.2.1 - let maxW := if w > acc.2.2 then w else acc.2.2 - (sum, ones, maxW) - let acc := Sound.Linear.foldlFin seq step (0, 0, 0) - IO.eprintln - s!"debug: epsAt={ratToString (cert.epsAt info.q)}, \ - weightSum={ratToString acc.1}, ones={acc.2.1}, \ - maxWeight={ratToString acc.2.2}" - let valsLo := logitCache.valsLo - let stepOnes : Array String → Fin seq → Array String := - fun acc k => - if k = info.prev then - acc - else - let w := weightBoundAt info.q k - if 
w = (1 : Rat) then - acc.push - s!"k={k.1} valsLo={ratToString (valsLo k)}" - else - acc - let ones := Sound.Linear.foldlFin seq stepOnes #[] - let onesMsg := - if ones.isEmpty then - "none" - else - String.intercalate ", " ones.toList - IO.eprintln s!"debug: weightBoundAt=1 keys: {onesMsg}" - let stepLoAt : Array String → Fin seq → Array String := - fun acc k => - if k = info.prev then - acc - else if valsLo k = info.loAt then - acc.push - s!"k={k.1} w={ratToString (weightBoundAt info.q k)}" - else - acc - let loAtKeys := Sound.Linear.foldlFin seq stepLoAt #[] - let loAtMsg := - if loAtKeys.isEmpty then - "none" - else - String.intercalate ", " loAtKeys.toList - IO.eprintln s!"debug: loAt keys: {loAtMsg}" - let scoreLoPrev := cache.scoreLoPrev info.q - let stepAlt : - (Rat × Nat × Nat) → Fin seq → (Rat × Nat × Nat) := - fun acc k => - if k = info.prev then - acc - else - let g := scoreLoPrev - cache.scoreHi info.q k - let nonneg := if g ≥ (0 : Rat) then acc.2.1 + 1 else acc.2.1 - let gtNegOne := if g > (-1 : Rat) then acc.2.2 + 1 else acc.2.2 - let expLB := - if g ≥ (0 : Rat) then - (1 : Rat) + g + g * g / (2 : Rat) - else - max (0 : Rat) ((1 : Rat) + g) - let w := (1 : Rat) / ((1 : Rat) + expLB) - (acc.1 + w, nonneg, gtNegOne) - let accAlt := Sound.Linear.foldlFin seq stepAlt (0, 0, 0) - IO.eprintln - s!"debug: alt-exp epsAt={ratToString accAlt.1}, \ - g>=0={accAlt.2.1}, g>-1={accAlt.2.2}" - let stepMin : Option Rat → Fin seq → Option Rat := - fun acc k => - if k = info.prev then - acc - else - let g := scoreLoPrev - cache.scoreHi info.q k - match acc with - | none => some g - | some cur => some (min cur g) - let minGap := Sound.Linear.foldlFin seq stepMin none - IO.eprintln s!"debug: alt-exp min(scoreLoPrev-scoreHi)={ratOptToString minGap}" - if (← logitDiffRefineEnabled) then - let refineBudget := max 1 cfg.splitBudgetDiffRefined - let refineKeys := Sound.refineKeysAtWithWeightOnes inputs cache info.q refineBudget - IO.eprintln - s!"debug: refine 
budget={refineBudget}, \ - refineKeys.card={refineKeys.card}" - let refineSpec := - Sound.refineSpecForQueryWithWeightOnes inputs cache info.q refineBudget - let refinedLB? := - Sound.logitDiffLowerBoundRefinedFromCache - inputs cache cert logitCache refineSpec - match refinedLB? with - | none => - IO.eprintln "debug: refined logitDiffLB0 none" - | some lb => - IO.eprintln - s!"debug: refined logitDiffLB0={ratToString lb}" - let profiling ← logitDiffProfileEnabled - if profiling then - profileLogitDiffWeighted cert logitCache - else - pure () - let altQuery? ← logitDiffAltBoundQuery - match altQuery? with - | none => pure () - | some qNat => - if hq : qNat < seq then - let q : Fin seq := ⟨qNat, hq⟩ - let prev := cert.prev q - let scoreLoPrev := cache.scoreLoPrev q - let stepAlt : - (Rat × Nat × Nat) → Fin seq → (Rat × Nat × Nat) := - fun acc k => - if k = prev then - acc - else - let g := scoreLoPrev - cache.scoreHi q k - let nonneg := if g ≥ (0 : Rat) then acc.2.1 + 1 else acc.2.1 - let gtNegOne := if g > (-1 : Rat) then acc.2.2 + 1 else acc.2.2 - let expLB := - if g ≥ (0 : Rat) then - (1 : Rat) + g + g * g / (2 : Rat) - else - max (0 : Rat) ((1 : Rat) + g) - let w := (1 : Rat) / ((1 : Rat) + expLB) - (acc.1 + w, nonneg, gtNegOne) - let accAlt := Sound.Linear.foldlFin seq stepAlt (0, 0, 0) - let stepMin : Option Rat → Fin seq → Option Rat := - fun acc k => - if k = prev then - acc - else - let g := scoreLoPrev - cache.scoreHi q k - match acc with - | none => some g - | some cur => some (min cur g) - let minGap := Sound.Linear.foldlFin seq stepMin none - IO.eprintln - s!"debug: alt-exp q={qNat} prev={prev.1} \ - epsAt={ratToString accAlt.1} \ - g>=0={accAlt.2.1} g>-1={accAlt.2.2} \ - minGap={ratOptToString minGap}" - if (← logitDiffDebugEarlyExitEnabled) then - IO.eprintln "debug: early exit requested (alt bound)" - return 2 - else - IO.eprintln - s!"warn: NFP_LOGITDIFF_ALT_BOUND_Q={qNat} out of range (seq={seq})" - let earlyExit? 
← - if (← logitDiffDebugEnabled) && (← logitDiffDebugEarlyExitEnabled) then - let debug? ← timePureWithHeartbeat - "head: logit-diff lower bound debug" (fun () => - Nfp.Sound.logitDiffLowerBoundAtLoDebug cert logitCache) - match debug? with - | none => - IO.eprintln "debug: logitDiffLB0 witness not found" - | some ⟨info, _⟩ => - emitLogitDiffDebug info - IO.eprintln "debug: early exit requested" - pure (some ()) - else - pure none - match earlyExit? with - | some _ => return 2 - | none => pure () - let weightedTask? : Option (Task (Option Rat)) := - if profiling then - none - else - some (Task.spawn (fun _ => - Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache)) - let logitDiffLB0? ← timePureWithHeartbeat - "head: logit-diff lower bound unweighted" (fun () => - Nfp.Sound.logitDiffLowerBoundRefineOnDemand inputs cache cert logitCache) - if (← logitDiffDebugEnabled) then - match logitDiffLB0? with - | some lb0 => - if lb0 ≤ 0 then - match Nfp.Sound.logitDiffLowerBoundAtLoDebug cert logitCache with - | none => - IO.eprintln "debug: logitDiffLB0 witness not found" - | some ⟨info, _⟩ => emitLogitDiffDebug info - | none => pure () - let needsWeighted : Bool := - match logitDiffLB0? with - | none => true - | some lb0 => - if lb0 ≤ 0 then - true - else - match minLogitDiff? with - | some minLogitDiff => lb0 < minLogitDiff - | none => false - let logitDiffWeighted? ← - if needsWeighted then - match weightedTask? with - | some task => - timePureWithHeartbeat - "head: logit-diff lower bound weighted" (fun () => - task.get) - | none => - timePureWithHeartbeat - "head: logit-diff lower bound weighted" (fun () => - Nfp.Sound.logitDiffLowerBoundWeightedFromCache cert logitCache) - else - pure none - let logitDiffLB? : Option Rat := - match logitDiffLB0?, logitDiffWeighted? 
with - | some lb0, some lb1 => some (max lb0 lb1) - | some lb0, none => some lb0 - | none, some lb1 => some lb1 - | none, none => none - let boundLabel : String := - match logitDiffLB0?, logitDiffWeighted? with - | some _, some _ => "max" - | none, some _ => "weighted" - | some _, none => "eps" - | none, none => "none" - logTiming "done: head logit-diff lower bound" - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - if logitDiffLB ≤ 0 then - if (← logitDiffDebugEnabled) then - IO.eprintln - s!"debug: logitDiffLB0={ratOptToString logitDiffLB0?}, \ - logitDiffWeighted={ratOptToString logitDiffWeighted?}, \ - logitDiffLB={ratToString logitDiffLB}, \ - bound={boundLabel}" - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - is not strictly positive" - return 2 - let violation? : Option Rat := - match minLogitDiff? with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {ratToString logitDiffLB} \ - below minimum {ratToString minLogitDiff}" - return 2 - | none => pure () - let tol := cert.eps * (cert.values.hi - cert.values.lo) - IO.println - s!"ok: nonvacuous induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={ratToString tol}, logitDiffLB={ratToString logitDiffLB}, \ - bound={boundLabel})" - return 0 - -/-- Build and check a strictly positive induction logit-diff bound from head inputs. -/ -def runInductionCertifyHeadNonvacuous (inputsPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
: Option Nat) : - IO UInt32 := do - warnDeprecated - "certify_head_nonvacuous builds certificates from head inputs; \ - use explicit certs via `nfp induction certify_nonvacuous --cert`." - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedInputs ← timePhase "load head inputs" <| - loadInductionHeadInputs inputsPath - match parsedInputs with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨_seq, ⟨_dModel, ⟨_dHead, inputs⟩⟩⟩ => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - -/-- Build and check a strictly positive induction logit-diff bound from a model binary. -/ -def runInductionCertifyHeadModelNonvacuous (modelPath : System.FilePath) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? : Option Nat) : - IO UInt32 := do - warnDeprecated - "certify_head_model_nonvacuous builds certificates from a model file; \ - use explicit certs via `nfp induction certify_nonvacuous --cert`." - configureTiming timing? heartbeatMs? 
- let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? shiftPrev) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - -/-- Build and check a strictly positive induction logit-diff bound from a model binary, deriving -direction tokens from the prompt sequence. -/ -def runInductionCertifyHeadModelAutoNonvacuous (modelPath : System.FilePath) - (layer head : Nat) (period? : Option Nat) (shiftPrev : Bool) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) - (timing? : Option Nat) (heartbeatMs? : Option Nat) - (splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? 
: Option Nat) : - IO UInt32 := do - warnDeprecated - "certify_head_model_auto_nonvacuous builds certificates from a model file; \ - use explicit certs via `nfp induction certify_nonvacuous --cert`." - configureTiming timing? heartbeatMs? - let splitCfg := - splitConfigFromOptions splitBudgetQ? splitBudgetK? splitBudgetDiffBase? splitBudgetDiffRefined? - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - logTiming "start: read model file" - timingPrint "timing: read model file start" - timingFlush - let data ← timePhase "read model file" <| IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - let tokensE ← timePure "read prompt tokens" (fun () => - NfptPure.readTokens data start header) - match tokensE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok tokens => - match deriveDirectionFromTokens tokens with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨dirTarget, dirNegative⟩ => - IO.println - s!"info: direction-target={dirTarget} direction-negative={dirNegative}" - let inputsE ← timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head dirTarget dirNegative period? 
shiftPrev) - match inputsE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok inputs => - checkInductionHeadInputsNonvacuous inputs splitCfg minActive? minLogitDiff? - minMargin? maxEps - -end IO - -end Nfp diff --git a/Nfp/IO/NfptPure.lean b/Nfp/IO/NfptPure.lean deleted file mode 100644 index 483e405..0000000 --- a/Nfp/IO/NfptPure.lean +++ /dev/null @@ -1,804 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Data.List.Range -public import Nfp.Core.Basic -public import Nfp.Model.Gpt2 -public import Nfp.Model.InductionHead -public import Nfp.Model.InductionPrompt - -/-! -Pure parsing utilities for `NFP_BINARY_V1` model files. - -These helpers parse headers and extract selected weight slices as rational values. --/ - -public section - -namespace Nfp - -namespace IO - -namespace NfptPure - -/-- Required header fields for NFP binary models. -/ -structure NfptHeader where - /-- Number of transformer layers. -/ - numLayers : Nat - /-- Number of attention heads per layer. -/ - numHeads : Nat - /-- Model dimension. -/ - modelDim : Nat - /-- Head dimension. -/ - headDim : Nat - /-- MLP hidden dimension. -/ - hiddenDim : Nat - /-- Vocabulary size. -/ - vocabSize : Nat - /-- Sequence length used in the binary. -/ - seqLen : Nat - /-- LayerNorm epsilon parameter. -/ - layerNormEps : Rat - -/-- Array with a fixed size proof. -/ -structure SizedArray (n : Nat) (α : Type) where - /-- Underlying array data. -/ - data : Array α - /-- Size proof for the array. -/ - size_eq : data.size = n - -/-- Index into a `SizedArray` using a `Fin`. -/ -def SizedArray.get {n : Nat} {α : Type} (arr : SizedArray n α) (i : Fin n) : α := - arr.data[i.val]'(by simp [arr.size_eq]) - -private def parseNat (s : String) : Except String Nat := - match s.toNat? 
with - | some n => Except.ok n - | none => Except.error s!"expected Nat, got '{s}'" - -private def splitKV (line : String) : Option (String × String) := - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none - -private def readHeaderField (name : String) (fields : List (String × String)) : - Except String Nat := do - match fields.find? (fun kv => kv.1 = name) with - | some kv => parseNat kv.2 - | none => throw s!"missing header field '{name}'" - -private def parseInt (s : String) : Except String Int := - match s.toInt? with - | some n => Except.ok n - | none => Except.error s!"expected Int, got '{s}'" - -private def pow10 (k : Nat) : Nat := - Nat.pow 10 k - -private def parseRatScientific (s : String) : Except String Rat := do - let s := s.trim - let (sign, rest) := - if s.startsWith "-" then - (-1, s.drop 1) - else if s.startsWith "+" then - (1, s.drop 1) - else - (1, s) - let parts := rest.toLower.splitOn "e" - let (mant, expStr?) ← - match parts with - | [m] => pure (m, none) - | [m, e] => pure (m, some e) - | _ => throw s!"invalid scientific literal '{s}'" - let (intPart, fracPart) ← - match mant.splitOn "." with - | [i] => pure (i, "") - | [i, f] => pure (i, f) - | _ => throw s!"invalid decimal literal '{s}'" - let digits := intPart ++ fracPart - if digits = "" then - throw s!"invalid decimal literal '{s}'" - let n ← parseNat digits - let scale := fracPart.length - let base : Rat := - (Rat.ofInt (sign * Int.ofNat n)) / Rat.ofInt (Int.ofNat (pow10 scale)) - let exp ← - match expStr? 
with - | none => pure (0 : Int) - | some e => parseInt e - if exp ≥ 0 then - let k := Int.toNat exp - pure (ratRoundDown (base * Rat.ofInt (Int.ofNat (pow10 k)))) - else - let k := Int.toNat (-exp) - pure (ratRoundDown (base / Rat.ofInt (Int.ofNat (pow10 k)))) - -private def readHeaderFieldRat (names : List String) (fields : List (String × String)) : - Except String Rat := do - let rec loop : List String → Option String - | [] => none - | name :: rest => - match fields.find? (fun kv => kv.1 = name) with - | some kv => some kv.2 - | none => loop rest - match loop names with - | some raw => parseRatScientific raw - | none => throw s!"missing header field '{String.intercalate "|" names}'" - -private def sentinelBytes : ByteArray := - "BINARY_START\n".toUTF8 - -private def findSentinel (data : ByteArray) : Option Nat := - let n := data.size - let m := sentinelBytes.size - if m ≤ n then - let maxStart := n - m - let rec loop (i : Nat) (remaining : Nat) : Option Nat := - match remaining with - | 0 => none - | Nat.succ rem => - let ok := - (List.range m).all (fun j => data.get! (i + j) = sentinelBytes.get! j) - if ok then - some i - else - loop (i + 1) rem - loop 0 (maxStart + 1) - else - none - -/-- Parse the NFP binary header and return the binary start offset. -/ -def parseHeader (data : ByteArray) : Except String (NfptHeader × Nat) := do - let idx ← - match findSentinel data with - | some i => pure i - | none => throw "missing BINARY_START sentinel" - let headerBytes := data.extract 0 idx - let headerStr ← - match String.fromUTF8? 
headerBytes with - | some s => pure s - | none => throw "invalid UTF-8 in header" - let lines := headerStr.splitOn "\n" |>.filter (· ≠ "") - match lines with - | [] => throw "empty header" - | magic :: rest => - if magic != "NFP_BINARY_V1" then - throw s!"unexpected magic '{magic}'" - let fields := rest.filterMap splitKV - let numLayers ← readHeaderField "num_layers" fields - let numHeads ← readHeaderField "num_heads" fields - let modelDim ← readHeaderField "model_dim" fields - let headDim ← readHeaderField "head_dim" fields - let hiddenDim ← readHeaderField "hidden_dim" fields - let vocabSize ← readHeaderField "vocab_size" fields - let seqLen ← readHeaderField "seq_len" fields - let layerNormEps ← readHeaderFieldRat ["layer_norm_eps", "eps"] fields - if numLayers = 0 then - throw "num_layers must be positive" - if numHeads = 0 then - throw "num_heads must be positive" - if modelDim = 0 then - throw "model_dim must be positive" - if headDim = 0 then - throw "head_dim must be positive" - if hiddenDim = 0 then - throw "hidden_dim must be positive" - if vocabSize = 0 then - throw "vocab_size must be positive" - if seqLen = 0 then - throw "seq_len must be positive" - let start := idx + sentinelBytes.size - return ({ numLayers := numLayers - numHeads := numHeads - modelDim := modelDim - headDim := headDim - hiddenDim := hiddenDim - vocabSize := vocabSize - seqLen := seqLen - layerNormEps := layerNormEps }, start) - -private def pow2 (k : Nat) : Nat := - Nat.pow 2 k - -private def getBits (n hi lo : Nat) : Nat := - (n / pow2 lo) % pow2 (hi - lo + 1) - -private def ratOfFloatBits (bits : Nat) : Option Rat := - let signBit := getBits bits 63 63 - let expBits := getBits bits 62 52 - let mantBits := getBits bits 51 0 - let sign : Int := if signBit = 0 then 1 else -1 - if expBits = 2047 then - none - else if expBits = 0 then - if mantBits = 0 then - some 0 - else - let num : Int := sign * Int.ofNat mantBits - some (ratOfIntWithPrec num 1074) - else - let mant := mantBits + 
pow2 52 - let exp : Int := Int.ofNat expBits - 1023 - let shift : Int := exp - 52 - let prec : Int := -shift - some (ratOfIntWithPrec (sign * Int.ofNat mant) prec) - -private def readNatLE (data : ByteArray) (off : Nat) (count : Nat) : Option Nat := - if off + count ≤ data.size then - let rec loop (i : Nat) (acc : Nat) : Nat := - if i < count then - let byte := data.get! (off + i) - loop (i + 1) (acc + byte.toNat * pow2 (8 * i)) - else - acc - some (loop 0 0) - else - none - -private def readI32 (data : ByteArray) (off : Nat) : Option Int := do - let bits ← readNatLE data off 4 - let two31 := pow2 31 - let two32 := pow2 32 - if bits < two31 then - some (Int.ofNat bits) - else - some (Int.ofNat bits - Int.ofNat two32) - -private def readF64Rat (data : ByteArray) (off : Nat) : Option Rat := do - let bits ← readNatLE data off 8 - ratOfFloatBits bits - -private def bytesI32 (n : Nat) : Nat := - n * 4 - -private def bytesF64 (n : Nat) : Nat := - n * 8 - -private def sqrtNat? (n : Nat) : Option Nat := - let k := Nat.sqrt n - if k * k = n then - some k - else - none - -private def scaleOfHeadDim (dHead : Nat) : Except String Rat := do - match sqrtNat? 
dHead with - | some k => - if k = 0 then - throw "head_dim must be positive" - else - pure (ratRoundDown (Rat.ofInt 1 / Rat.ofInt (Int.ofNat k))) - | none => - throw "head_dim must be a perfect square to compute scale" - -private def matrixIndex {rows cols : Nat} (i : Fin rows) (j : Fin cols) : Fin (rows * cols) := - let idx := i.val * cols + j.val - have hstep : i.val * cols + j.val < (i.val + 1) * cols := by - have h' : i.val * cols + j.val < i.val * cols + cols := - Nat.add_lt_add_left j.isLt _ - have hmul : (i.val + 1) * cols = i.val * cols + cols := by - simpa [Nat.succ_eq_add_one] using (Nat.succ_mul i.val cols) - exact hmul ▸ h' - have hle : (i.val + 1) * cols ≤ rows * cols := - Nat.mul_le_mul_right cols (Nat.succ_le_iff.mpr i.isLt) - ⟨idx, lt_of_lt_of_le hstep hle⟩ - -private def readF64ListAux (data : ByteArray) (off : Nat) : - Nat → List Rat → Except String (List Rat) - | 0, acc => Except.ok acc.reverse - | Nat.succ n, acc => - match readF64Rat data off with - | some v => readF64ListAux data (off + bytesF64 1) n (v :: acc) - | none => Except.error s!"invalid f64 at offset {off}" - -private theorem readF64ListAux_length (data : ByteArray) : - ∀ (off n : Nat) (acc xs : List Rat), - readF64ListAux data off n acc = Except.ok xs → - xs.length = acc.length + n := by - intro off n acc xs h - induction n generalizing off acc xs with - | zero => - have h' := h - simp only [readF64ListAux] at h' - cases h' - simp - | succ n ih => - cases hread : readF64Rat data off with - | none => - have h' := h - simp only [readF64ListAux, hread] at h' - cases h' - | some v => - have h' := h - simp only [readF64ListAux, hread] at h' - have hlen := ih (off := off + bytesF64 1) (acc := v :: acc) (xs := xs) h' - simpa [List.length, Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hlen - -private def readF64List (data : ByteArray) (off : Nat) (count : Nat) : - Except String {xs : List Rat // xs.length = count} := - match h : readF64ListAux data off count [] with - | Except.error 
msg => Except.error msg - | Except.ok xs => - have hlen : - xs.length = count := by - simpa using readF64ListAux_length (data := data) (off := off) - (n := count) (acc := []) (xs := xs) h - Except.ok ⟨xs, hlen⟩ - -private def readI32ListAux (data : ByteArray) (off : Nat) : - Nat → List Int → Except String (List Int) - | 0, acc => Except.ok acc.reverse - | Nat.succ n, acc => - match readI32 data off with - | some v => readI32ListAux data (off + bytesI32 1) n (v :: acc) - | none => Except.error s!"invalid i32 at offset {off}" - -private theorem readI32ListAux_length (data : ByteArray) : - ∀ (off n : Nat) (acc xs : List Int), - readI32ListAux data off n acc = Except.ok xs → - xs.length = acc.length + n := by - intro off n acc xs h - induction n generalizing off acc xs with - | zero => - have h' := h - simp only [readI32ListAux] at h' - cases h' - simp - | succ n ih => - cases hread : readI32 data off with - | none => - have h' := h - simp only [readI32ListAux, hread] at h' - cases h' - | some v => - have h' := h - simp only [readI32ListAux, hread] at h' - have hlen := ih (off := off + bytesI32 1) (acc := v :: acc) (xs := xs) h' - simpa [List.length, Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] using hlen - -private def readI32List (data : ByteArray) (off : Nat) (count : Nat) : - Except String {xs : List Int // xs.length = count} := - match h : readI32ListAux data off count [] with - | Except.error msg => Except.error msg - | Except.ok xs => - have hlen : - xs.length = count := by - simpa using readI32ListAux_length (data := data) (off := off) - (n := count) (acc := []) (xs := xs) h - Except.ok ⟨xs, hlen⟩ - -private def readF64Matrix (data : ByteArray) (off : Nat) (rows cols : Nat) : - Except String (Fin rows → Fin cols → Rat) := do - let count := rows * cols - let ⟨vals, hlen⟩ ← readF64List data off count - let hlen' : vals.length = rows * cols := by - simpa using hlen - let mat : Fin rows → Fin cols → Rat := fun i j => - let idx := matrixIndex i j - let hidx : 
idx.val < vals.length := lt_of_lt_of_eq idx.isLt hlen'.symm - vals.get ⟨idx.val, hidx⟩ - return mat - -private def readF64Vec (data : ByteArray) (off : Nat) (count : Nat) : - Except String (Fin count → Rat) := do - let ⟨vals, hlen⟩ ← readF64List data off count - let hlen' : vals.length = count := by - simpa using hlen - let vec : Fin count → Rat := fun i => - vals.get ⟨i.val, lt_of_lt_of_eq i.isLt hlen'.symm⟩ - return vec - -private def f64CountPerHead (h : NfptHeader) : Nat := - 4 * h.modelDim * h.headDim + 3 * h.headDim - -private def f64CountPerLayer (h : NfptHeader) : Nat := - h.numHeads * f64CountPerHead h + - (2 * h.modelDim * h.hiddenDim + h.hiddenDim) + - (6 * h.modelDim) - -private def f64CountBeforeUnembed (h : NfptHeader) : Nat := - h.seqLen * h.modelDim + - h.numLayers * f64CountPerLayer h + - (2 * h.modelDim) - -private def f64CountBeforeHeads (h : NfptHeader) : Nat := - h.seqLen * h.modelDim - -/-- Byte offset from the binary start to the unembedding matrix. -/ -def unembedOffset (h : NfptHeader) : Nat := - bytesI32 h.seqLen + bytesF64 (f64CountBeforeUnembed h) - -private def finalLayerNormOffset (h : NfptHeader) : Nat := - bytesI32 h.seqLen + - bytesF64 (f64CountBeforeHeads h + h.numLayers * f64CountPerLayer h) - -/-- Read input embeddings stored in the binary. -/ -def readEmbeddings (data : ByteArray) (start : Nat) (h : NfptHeader) : - Except String (Fin h.seqLen → Fin h.modelDim → Rat) := do - let base := start + bytesI32 h.seqLen - readF64Matrix data base h.seqLen h.modelDim - -/-- Read input token ids stored in the binary. 
-/ -def readTokens (data : ByteArray) (start : Nat) (h : NfptHeader) : - Except String (Fin h.seqLen → Nat) := do - let ⟨vals, hlen⟩ ← readI32List data start h.seqLen - let ok := vals.all (fun z => decide (0 ≤ z)) - if !ok then - throw "token ids must be nonnegative" - let hlen' : vals.length = h.seqLen := by - simpa using hlen - let tokens : Fin h.seqLen → Nat := fun i => - Int.toNat (vals.get ⟨i.val, lt_of_lt_of_eq i.isLt hlen'.symm⟩) - return tokens - -private def headOffset (h : NfptHeader) (layer head : Nat) : Nat := - bytesI32 h.seqLen + - bytesF64 (f64CountBeforeHeads h + - layer * f64CountPerLayer h + - head * f64CountPerHead h) - -private def layerExtrasOffset (h : NfptHeader) (layer : Nat) : Nat := - bytesI32 h.seqLen + - bytesF64 (f64CountBeforeHeads h + - layer * f64CountPerLayer h + - h.numHeads * f64CountPerHead h) - -/-- Read attention head weights and biases for a specific layer/head. -/ -def readHeadWeights (data : ByteArray) (start : Nat) (h : NfptHeader) - (layer head : Nat) : - Except String (Model.Gpt2HeadWeights h.modelDim h.headDim) := do - if layer < h.numLayers then - if head < h.numHeads then - let base := start + headOffset h layer head - let wq ← readF64Matrix data base h.modelDim h.headDim - let offbq := base + bytesF64 (h.modelDim * h.headDim) - let bq ← readF64Vec data offbq h.headDim - let offwk := offbq + bytesF64 h.headDim - let wk ← readF64Matrix data offwk h.modelDim h.headDim - let offbk := offwk + bytesF64 (h.modelDim * h.headDim) - let bk ← readF64Vec data offbk h.headDim - let offwv := offbk + bytesF64 h.headDim - let wv ← readF64Matrix data offwv h.modelDim h.headDim - let offbv := offwv + bytesF64 (h.modelDim * h.headDim) - let bv ← readF64Vec data offbv h.headDim - let offwo := offbv + bytesF64 h.headDim - let woRaw ← readF64Matrix data offwo h.headDim h.modelDim - let wo : Fin h.modelDim → Fin h.headDim → Rat := fun i j => woRaw j i - return { wq := wq, bq := bq, wk := wk, bk := bk, wv := wv, bv := bv, wo := wo } - else - 
throw s!"head index out of range: {head}" - else - throw s!"layer index out of range: {layer}" - -private def readLayerAttnBiasLn1 (data : ByteArray) (start : Nat) (h : NfptHeader) - (layer : Nat) : - Except String ((Fin h.modelDim → Rat) × (Fin h.modelDim → Rat) × - (Fin h.modelDim → Rat)) := do - if layer < h.numLayers then - let base := start + layerExtrasOffset h layer - let attnBias ← readF64Vec data base h.modelDim - let offWIn := base + bytesF64 h.modelDim - let offBIn := offWIn + bytesF64 (h.modelDim * h.hiddenDim) - let offWOut := offBIn + bytesF64 h.hiddenDim - let offBOut := offWOut + bytesF64 (h.hiddenDim * h.modelDim) - let offLn1Gamma := offBOut + bytesF64 h.modelDim - let ln1Gamma ← readF64Vec data offLn1Gamma h.modelDim - let offLn1Beta := offLn1Gamma + bytesF64 h.modelDim - let ln1Beta ← readF64Vec data offLn1Beta h.modelDim - return (attnBias, ln1Gamma, ln1Beta) - else - throw s!"layer index out of range: {layer}" - -/-- Read GPT-2 layer parameters (MLP + LayerNorm) from the model binary. 
-/ -def readLayerSlice (data : ByteArray) (start : Nat) (h : NfptHeader) - (layer : Nat) : Except String (Model.Gpt2LayerSlice h.modelDim h.hiddenDim) := do - if layer < h.numLayers then - let base := start + layerExtrasOffset h layer - let attnBias ← readF64Vec data base h.modelDim - let offWIn := base + bytesF64 h.modelDim - let mlpWIn ← readF64Matrix data offWIn h.modelDim h.hiddenDim - let offBIn := offWIn + bytesF64 (h.modelDim * h.hiddenDim) - let mlpBIn ← readF64Vec data offBIn h.hiddenDim - let offWOut := offBIn + bytesF64 h.hiddenDim - let mlpWOut ← readF64Matrix data offWOut h.hiddenDim h.modelDim - let offBOut := offWOut + bytesF64 (h.hiddenDim * h.modelDim) - let mlpBOut ← readF64Vec data offBOut h.modelDim - let offLn1Gamma := offBOut + bytesF64 h.modelDim - let ln1Gamma ← readF64Vec data offLn1Gamma h.modelDim - let offLn1Beta := offLn1Gamma + bytesF64 h.modelDim - let ln1Beta ← readF64Vec data offLn1Beta h.modelDim - let offLn2Gamma := offLn1Beta + bytesF64 h.modelDim - let ln2Gamma ← readF64Vec data offLn2Gamma h.modelDim - let offLn2Beta := offLn2Gamma + bytesF64 h.modelDim - let ln2Beta ← readF64Vec data offLn2Beta h.modelDim - return { attnBias := attnBias - mlpWIn := mlpWIn - mlpBIn := mlpBIn - mlpWOut := mlpWOut - mlpBOut := mlpBOut - ln1Gamma := ln1Gamma - ln1Beta := ln1Beta - ln2Gamma := ln2Gamma - ln2Beta := ln2Beta } - else - throw s!"layer index out of range: {layer}" - -/-- Read all GPT-2 layer slices from the model binary. 
-/ -def readLayerSlices (data : ByteArray) (start : Nat) (h : NfptHeader) : - Except String (SizedArray h.numLayers (Model.Gpt2LayerSlice h.modelDim h.hiddenDim)) := do - let slices ← (List.finRange h.numLayers).foldlM - (fun (acc : Array (Model.Gpt2LayerSlice h.modelDim h.hiddenDim)) layer => do - let slice ← readLayerSlice data start h layer.val - pure (acc.push slice)) - (#[] : Array (Model.Gpt2LayerSlice h.modelDim h.hiddenDim)) - if hlen : slices.size = h.numLayers then - return { data := slices, size_eq := hlen } - else - throw "internal error: layer slice count mismatch" - -/-- Read all attention head weights from the model binary. -/ -def readLayerHeads (data : ByteArray) (start : Nat) (h : NfptHeader) : - Except String - (SizedArray h.numLayers - (SizedArray h.numHeads (Model.Gpt2HeadWeights h.modelDim h.headDim))) := do - let layers ← (List.finRange h.numLayers).foldlM - (fun (acc : Array - (SizedArray h.numHeads (Model.Gpt2HeadWeights h.modelDim h.headDim))) layer => do - let heads ← (List.finRange h.numHeads).foldlM - (fun (accHead : Array (Model.Gpt2HeadWeights h.modelDim h.headDim)) head => do - let weights ← readHeadWeights data start h layer.val head.val - pure (accHead.push weights)) - (#[] : Array (Model.Gpt2HeadWeights h.modelDim h.headDim)) - if hlen : heads.size = h.numHeads then - let headArray : SizedArray h.numHeads (Model.Gpt2HeadWeights h.modelDim h.headDim) := - { data := heads, size_eq := hlen } - pure (acc.push headArray) - else - throw "internal error: head count mismatch") - (#[] : Array - (SizedArray h.numHeads (Model.Gpt2HeadWeights h.modelDim h.headDim))) - if hlen : layers.size = h.numLayers then - return { data := layers, size_eq := hlen } - else - throw "internal error: layer head count mismatch" - -/-- Read the final LayerNorm parameters from the model binary. 
-/ -def readFinalLayerNorm (data : ByteArray) (start : Nat) (h : NfptHeader) : - Except String (Model.Gpt2FinalLayerNorm h.modelDim) := do - let base := start + finalLayerNormOffset h - let gamma ← readF64Vec data base h.modelDim - let offBeta := base + bytesF64 h.modelDim - let beta ← readF64Vec data offBeta h.modelDim - return { gamma := gamma, beta := beta } - -/-- Read a single unembedding column as exact rationals. -/ -def readUnembedColumn (data : ByteArray) (start : Nat) (h : NfptHeader) (col : Nat) : - Except String (Fin h.modelDim → Rat) := do - if col < h.vocabSize then - let base := start + unembedOffset h - let rows := List.range h.modelDim - let vals ← rows.mapM (fun row => do - let off := base + bytesF64 (row * h.vocabSize + col) - match readF64Rat data off with - | some v => pure v - | none => throw s!"invalid f64 at offset {off}") - if hlen : vals.length = h.modelDim then - let vec : Fin h.modelDim → Rat := fun i => - vals.get ⟨i.val, by simp [hlen]⟩ - return vec - else - throw "internal error: unembed column length mismatch" - else - throw s!"column out of range: {col}" - -/-- Read induction-head inputs directly from the model binary. -/ -def buildInductionHeadInputs (h : NfptHeader) (scale : Rat) - (tokens : Fin h.seqLen → Nat) - (embed : Fin h.seqLen → Fin h.modelDim → Rat) - (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) - (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) - (dirTarget dirNegative : Nat) - (colTarget colNegative : Fin h.modelDim → Rat) - (period? : Option Nat) (shiftPrev : Bool) : - Model.InductionHeadInputs h.seqLen h.modelDim h.headDim := - let direction : Fin h.modelDim → Rat := fun i => colTarget i - colNegative i - let directionSpec : Circuit.DirectionSpec := - { target := dirTarget, negative := dirNegative } - let active := - match period? 
with - | some period => - if shiftPrev then - Model.activeOfPeriodShift (seq := h.seqLen) period - else - Model.activeOfPeriod (seq := h.seqLen) period - | none => - if shiftPrev then - Model.activeOfTokensShift (seq := h.seqLen) tokens - else - Model.activeOfTokens (seq := h.seqLen) tokens - let prev := - match period? with - | some period => - if shiftPrev then - Model.prevOfPeriodShift (seq := h.seqLen) period - else - Model.prevOfPeriod (seq := h.seqLen) period - | none => - if shiftPrev then - Model.prevOfTokensShift (seq := h.seqLen) tokens - else - Model.prevOfTokens (seq := h.seqLen) tokens - { scale := scale - active := active - prev := prev - embed := embed - lnEps := h.layerNormEps - ln1Gamma := ln1Gamma - ln1Beta := ln1Beta - wq := weights.wq - bq := weights.bq - wk := weights.wk - bk := weights.bk - wv := weights.wv - bv := weights.bv - wo := weights.wo - attnBias := attnBias - maskCausal := true - maskValue := (-10000 : Rat) - directionSpec := directionSpec - direction := direction } - -/-- Definitional characterization of `buildInductionHeadInputs`. -/ -private theorem buildInductionHeadInputs_def (h : NfptHeader) (scale : Rat) - (tokens : Fin h.seqLen → Nat) - (embed : Fin h.seqLen → Fin h.modelDim → Rat) - (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) - (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) - (dirTarget dirNegative : Nat) - (colTarget colNegative : Fin h.modelDim → Rat) - (period? : Option Nat) (shiftPrev : Bool) : - buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative period? shiftPrev = - { scale := scale - active := - match period? with - | some period => - if shiftPrev then - Model.activeOfPeriodShift (seq := h.seqLen) period - else - Model.activeOfPeriod (seq := h.seqLen) period - | none => - if shiftPrev then - Model.activeOfTokensShift (seq := h.seqLen) tokens - else - Model.activeOfTokens (seq := h.seqLen) tokens - prev := - match period? 
with - | some period => - if shiftPrev then - Model.prevOfPeriodShift (seq := h.seqLen) period - else - Model.prevOfPeriod (seq := h.seqLen) period - | none => - if shiftPrev then - Model.prevOfTokensShift (seq := h.seqLen) tokens - else - Model.prevOfTokens (seq := h.seqLen) tokens - embed := embed - lnEps := h.layerNormEps - ln1Gamma := ln1Gamma - ln1Beta := ln1Beta - wq := weights.wq - bq := weights.bq - wk := weights.wk - bk := weights.bk - wv := weights.wv - bv := weights.bv - wo := weights.wo - attnBias := attnBias - maskCausal := true - maskValue := (-10000 : Rat) - directionSpec := { target := dirTarget, negative := dirNegative } - direction := fun i => colTarget i - colNegative i } := rfl - -/-- `buildInductionHeadInputs` uses the supplied direction ids and columns. -/ -theorem buildInductionHeadInputs_direction_def (h : NfptHeader) (scale : Rat) - (tokens : Fin h.seqLen → Nat) - (embed : Fin h.seqLen → Fin h.modelDim → Rat) - (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) - (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) - (dirTarget dirNegative : Nat) - (colTarget colNegative : Fin h.modelDim → Rat) - (period? : Option Nat) (shiftPrev : Bool) : - let inputs := - buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative period? shiftPrev - inputs.directionSpec = { target := dirTarget, negative := dirNegative } ∧ - inputs.direction = fun i => colTarget i - colNegative i := by - simp [buildInductionHeadInputs] - -/-- `buildInductionHeadInputs` derives `prev`/`active` from tokens or a fixed period. -/ -theorem buildInductionHeadInputs_prev_active_def (h : NfptHeader) (scale : Rat) - (tokens : Fin h.seqLen → Nat) - (embed : Fin h.seqLen → Fin h.modelDim → Rat) - (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) - (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) - (dirTarget dirNegative : Nat) - (colTarget colNegative : Fin h.modelDim → Rat) - (period? 
: Option Nat) (shiftPrev : Bool) : - let inputs := - buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative period? shiftPrev - inputs.active = - (match period? with - | some period => - if shiftPrev then - Model.activeOfPeriodShift (seq := h.seqLen) period - else - Model.activeOfPeriod (seq := h.seqLen) period - | none => - if shiftPrev then - Model.activeOfTokensShift (seq := h.seqLen) tokens - else - Model.activeOfTokens (seq := h.seqLen) tokens) ∧ - inputs.prev = - (match period? with - | some period => - if shiftPrev then - Model.prevOfPeriodShift (seq := h.seqLen) period - else - Model.prevOfPeriod (seq := h.seqLen) period - | none => - if shiftPrev then - Model.prevOfTokensShift (seq := h.seqLen) tokens - else - Model.prevOfTokens (seq := h.seqLen) tokens) := by - constructor <;> rfl - -/-- Active queries pick the maximal matching prior token when `period? = none`. -/ -theorem buildInductionHeadInputs_prev_spec_of_active (h : NfptHeader) (scale : Rat) - (tokens : Fin h.seqLen → Nat) - (embed : Fin h.seqLen → Fin h.modelDim → Rat) - (weights : Model.Gpt2HeadWeights h.modelDim h.headDim) - (attnBias ln1Gamma ln1Beta : Fin h.modelDim → Rat) - (dirTarget dirNegative : Nat) - (colTarget colNegative : Fin h.modelDim → Rat) : - ∀ {q}, - q ∈ (buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative none false).active → - let p := - (buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative none false).prev q - p < q ∧ tokens p = tokens q ∧ - ∀ k, k < q → tokens k = tokens q → k ≤ p := by - intro q hq - have hq' : q ∈ Model.activeOfTokens (seq := h.seqLen) tokens := by - simpa [buildInductionHeadInputs] using hq - have hspec := Model.prevOfTokens_spec_of_active (tokens := tokens) (q := q) hq' - simpa [buildInductionHeadInputs] using hspec - -/-- Read 
induction-head inputs directly from the model binary. - -`shiftPrev` selects between the unshifted prompt map (`prev = q - period`) and the -shifted map (`prev = q - period + 1`), with analogous token-derived versions. --/ -def readInductionHeadInputs (data : ByteArray) (start : Nat) (h : NfptHeader) - (layer head dirTarget dirNegative : Nat) (period? : Option Nat) (shiftPrev : Bool) : - Except String (Model.InductionHeadInputs h.seqLen h.modelDim h.headDim) := do - let scale ← scaleOfHeadDim h.headDim - let tokens ← readTokens data start h - let embed ← readEmbeddings data start h - let weights ← readHeadWeights data start h layer head - let (attnBias, ln1Gamma, ln1Beta) ← readLayerAttnBiasLn1 data start h layer - let colTarget ← readUnembedColumn data start h dirTarget - let colNegative ← readUnembedColumn data start h dirNegative - pure <| - buildInductionHeadInputs h scale tokens embed weights attnBias ln1Gamma ln1Beta - dirTarget dirNegative colTarget colNegative period? shiftPrev - -end NfptPure - -end IO - -end Nfp diff --git a/Nfp/IO/Run.lean b/Nfp/IO/Run.lean deleted file mode 100644 index d48e4f3..0000000 --- a/Nfp/IO/Run.lean +++ /dev/null @@ -1,9 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.IO.Run.Basic - -/-! -IO entrypoints used by the CLI. 
--/ diff --git a/Nfp/IO/Run/Basic.lean b/Nfp/IO/Run/Basic.lean deleted file mode 100644 index 301ff46..0000000 --- a/Nfp/IO/Run/Basic.lean +++ /dev/null @@ -1,818 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.IO.Checks -public import Nfp.IO.Derive -public import Nfp.IO.HeadScore -public import Nfp.IO.InductionHead -public import Nfp.IO.Loaders -public import Nfp.IO.NfptPure -public import Nfp.IO.Timing -public import Nfp.IO.Util -public import Nfp.Circuit.Cert.DownstreamLinear -public import Nfp.Circuit.Cert.LogitDiff -public import Nfp.Circuit.Cert.ResidualBound -public import Nfp.Circuit.Cert.ResidualInterval -public import Nfp.Sound.Bounds.MatrixNorm -public import Nfp.Sound.Bounds.Transformer -public import Nfp.Sound.Induction -public import Nfp.Sound.Induction.HeadBounds -public import Nfp.Sound.Induction.LogitDiff -public import Nfp.Sound.Linear.FinFold - -/-! -IO entrypoints used by the CLI. --/ - -public section - -namespace Nfp -namespace IO -open Nfp.Circuit - -/-- Check induction certificates and print a short status line. -/ -def runInductionCertify (scoresPath : System.FilePath) - (valuesPath? : Option System.FilePath) (minActive? : Option Nat) - (minLogitDiffStr? : Option String) (minMarginStr? : Option String) - (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - if minLogitDiff?.isSome && valuesPath?.isNone then - IO.eprintln "error: min-logit-diff requires --values" - return 2 - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - match valuesPath? with - | none => - IO.println - s!"ok: softmax-margin certificate accepted \ - (seq={seq}, active={activeCount})" - return 0 - | some valuesPath => - let parsedValues ← loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let tol := cert.eps * (certVals'.hi - certVals'.lo) - let logitDiffLB? 
:= - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, tol={tol}, \ - logitDiffLB={logitDiffLB})" - return 0 -/-- Build and check induction certificates from raw scores/values. -/ -def runInductionCertifySound (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (minActive? : Option Nat) - (minLogitDiffStr? : Option String) (minMarginStr? : Option String) - (maxEpsStr? : Option String) : IO UInt32 := do - warnDeprecated - "certify_sound builds certificates from raw scores/values; use explicit certs \ - via `nfp induction certify` or `nfp induction head_cert_check`." - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← loadSoftmaxMarginRaw scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, raw⟩ => - match seq with - | 0 => - IO.eprintln "error: seq must be positive" - return 2 - | Nat.succ n => - let seq := Nat.succ n - let _ : NeZero seq := ⟨by simp⟩ - match Sound.buildSoftmaxMarginCert? raw.active raw.prev raw.scores raw.weights with - | none => - IO.eprintln "error: softmax-margin inputs rejected" - return 2 - | some ⟨cert, _⟩ => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← loadValueRangeRaw valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, rawVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln - s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let rawVals' : Pure.ValueRangeRaw seq := by - simpa [hseq'] using rawVals - match Sound.buildValueRangeCert? rawVals'.vals rawVals'.direction with - | none => - IO.eprintln "error: value-range inputs rejected" - return 2 - | some ⟨certVals, _⟩ => - let tol := cert.eps * (certVals.hi - certVals.lo) - let logitDiffLB? 
:= - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals.lo certVals.hi certVals.vals - let effectiveMinLogitDiff := - match minLogitDiff?, certVals.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return 2 - | some logitDiffLB => - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if logitDiffLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: logitDiffLB {logitDiffLB} \ - below minimum {minLogitDiff}" - return 2 - | none => - IO.println - s!"ok: induction bound certified \ - (seq={seq}, active={activeCount}, \ - tol={tol}, logitDiffLB={logitDiffLB})" - return 0 -/-- Check end-to-end induction certificates with a downstream error bound. -/ -def runInductionCertifyEndToEnd (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (downstreamPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let logitDiffLB? 
← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let parsedDownstream ← loadDownstreamLinearCert downstreamPath - match parsedDownstream with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok downstream => - let downstreamOk := Circuit.checkDownstreamLinearCert downstream - if downstreamOk then - let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if finalLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end logitDiffLB {finalLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: end-to-end induction bound certified \ - (seq={seq}, active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstream.error}, \ - finalLB={finalLB})" - return 0 - else - IO.eprintln "error: downstream certificate rejected" - return 2 -/-- Check end-to-end induction certificates with a downstream matrix. -/ -def runInductionCertifyEndToEndMatrix (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (matrixPath : System.FilePath) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - warnDeprecated - "certify_end_to_end_matrix builds downstream bounds from a raw matrix payload; \ - use a downstream cert instead." - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? 
- let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? => do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match 
valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let logitDiffLB? ← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - let parsedMatrix ← loadDownstreamMatrixRaw matrixPath - match parsedMatrix with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨rows, ⟨cols, raw⟩⟩ => - let inputBound := raw.inputBound - if hneg : inputBound < 0 then - IO.eprintln - s!"error: input-bound {inputBound} must be nonnegative" - return 2 - else - have hinput : 0 ≤ inputBound := by - exact le_of_not_gt hneg - let W : Matrix (Fin rows) (Fin cols) Rat := raw.entries - let downstream := - (Sound.Bounds.buildDownstreamLinearCert W inputBound hinput).1 - let finalLB := logitDiffLB - downstream.error - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if finalLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end logitDiffLB {finalLB} \ - below minimum {minLogitDiff}" - return (2 : UInt32) - | none => - IO.println - s!"ok: end-to-end induction bound certified \ - (seq={seq}, active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstream.error}, \ - finalLB={finalLB})" - return 0 -/-- Check end-to-end induction certificates using a model file and residual bounds - (loaded from disk or derived from the model). 
-/ -def runInductionCertifyEndToEndModel (scoresPath : System.FilePath) - (valuesPath : System.FilePath) (modelPath : System.FilePath) - (residualIntervalPath? : Option System.FilePath) - (layer? : Option Nat) (head? : Option Nat) (period? : Option Nat) - (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do - warnDeprecated - "certify_end_to_end_model derives residual bounds from a model file; \ - use an explicit residual-interval cert instead." - let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? - let minMargin?E := parseRatOpt "min-margin" minMarginStr? - let maxEps?E := parseRatOpt "max-eps" maxEpsStr? - match minLogitDiff?E, minMargin?E, maxEps?E with - | Except.error msg, _, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, Except.error msg, _ => - IO.eprintln s!"error: {msg}" - return 2 - | _, _, Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> do - let minMargin := minMargin?.getD (0 : Rat) - let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) - let parsedScores ← timePhase "load softmax cert" <| - loadSoftmaxMarginCert scoresPath - match parsedScores with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seq, cert⟩ => - let scoresOk ← timePhase "check softmax cert" <| - checkSoftmaxMargin seq cert - match scoresOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let activeCount := cert.active.card - let defaultMinActive := max 1 (seq / 8) - let minActive := minActive?.getD defaultMinActive - if activeCount < minActive then - IO.eprintln - s!"error: active queries {activeCount} below minimum {minActive}" - return 2 - if cert.margin < minMargin then - IO.eprintln - s!"error: margin {cert.margin} below minimum {minMargin}" - return 2 - if maxEps < cert.eps then - IO.eprintln - s!"error: eps {cert.eps} above maximum {maxEps}" - return 2 - let parsedValues ← timePhase "load value cert" <| - loadValueRangeCert valuesPath - match parsedValues with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨seqVals, certVals⟩ => - if hseq : seqVals ≠ seq then - IO.eprintln s!"error: seq mismatch (scores={seq}, values={seqVals})" - return 2 - else - have hseq' : seqVals = seq := by - exact (not_ne_iff).1 hseq - let certVals' : ValueRangeCert seq := by - simpa [hseq'] using certVals - let valuesOk ← timePhase "check value cert" <| - checkValueRange seq certVals' - match valuesOk with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 2 - | Except.ok () => - let logitDiffLB? 
← timePure "logit-diff lower bound" (fun () => - Circuit.logitDiffLowerBound cert.active cert.prev cert.eps - certVals'.lo certVals'.hi certVals'.vals) - let effectiveMinLogitDiff := - match minLogitDiff?, certVals'.direction with - | some v, _ => some v - | none, some _ => some (0 : Rat) - | none, none => none - match logitDiffLB? with - | none => - IO.eprintln "error: empty active set for logit-diff bound" - return (2 : UInt32) - | some logitDiffLB => - match certVals'.direction with - | none => - IO.eprintln - "error: value-range certificate missing direction \ - metadata" - return 2 - | some dirSpec => - let data ← timePhase "read model file" <| - IO.FS.readBinFile modelPath - let headerE ← timePure "parse model header" (fun () => - NfptPure.parseHeader data) - match headerE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok ⟨header, start⟩ => - if hseq : header.seqLen = seq then - let active? : Option (Finset (Fin header.seqLen)) := - if hactive : cert.active.Nonempty then - some (by simpa [hseq] using cert.active) - else - none - let residualCertE : Except String - (ResidualIntervalCert header.modelDim) ← - match residualIntervalPath? with - | some residualIntervalPath => do - let parsedResidual ← - timePhase "load residual interval" <| - loadResidualIntervalCert residualIntervalPath - match parsedResidual with - | Except.error msg => pure (Except.error msg) - | Except.ok ⟨dim, residualCert⟩ => - if hdim : dim = header.modelDim then - let residualCert' : - ResidualIntervalCert header.modelDim := by - simpa [hdim] using residualCert - pure (Except.ok residualCert') - else - pure (Except.error - s!"residual interval dim {dim} \ - does not match model dim {header.modelDim}") - | none => - deriveResidualIntervalFromModel data start header - active? 
- match residualCertE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok residualCert' => - let residualOk ← - timePure "check residual interval" (fun () => - Circuit.checkResidualIntervalCert residualCert') - if residualOk then - let dirPos := dirSpec.target - let dirNeg := dirSpec.negative - if layer?.isSome != head?.isSome then - IO.eprintln - "error: --layer and --head must be provided \ - together" - return 2 - let headChoice? : Option (Nat × Nat) := - match layer?, head? with - | some layer, some head => some (layer, head) - | _, _ => none - if period?.isSome && headChoice?.isNone then - IO.eprintln - "warning: --period ignored without \ - --layer/--head" - let colTargetE ← - timePure "read unembed column target" (fun () => - NfptPure.readUnembedColumn - data start header dirPos) - match colTargetE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok colTarget => - let colNegE ← - timePure "read unembed column negative" (fun () => - NfptPure.readUnembedColumn - data start header dirNeg) - match colNegE with - | Except.error msg => - IO.eprintln s!"error: {msg}" - return 1 - | Except.ok colNeg => - let dirVec : - Fin header.modelDim → Rat := - fun i => colTarget i - colNeg i - let dotIntervalAbs := - Sound.Bounds.dotIntervalAbsBound - let intervalErrorFromHead? : - Model.InductionHeadInputs - seq header.modelDim header.headDim → - ResidualIntervalCert header.modelDim → - Option Rat := - fun inputs residual => by - classical - match hseq0 : seq with - | 0 => exact none - | Nat.succ n => - let _ : NeZero seq := by - exact ⟨by simp [hseq0]⟩ - match - Sound.buildHeadOutputIntervalFromHead? 
- inputs with - | none => exact none - | some result => - exact some - (dotIntervalAbs - dirVec - (fun i => - residual.lo i - - result.cert.hi i) - (fun i => - residual.hi i - - result.cert.lo i)) - let downstreamError ← - timePure "downstream error" (fun () => - dotIntervalAbs - dirVec - residualCert'.lo - residualCert'.hi) - let finalLB := logitDiffLB - downstreamError - let intervalError? ← - match headChoice? with - | none => pure none - | some (layer, head) => do - let inputsE ← - timePure "read head inputs" (fun () => - NfptPure.readInductionHeadInputs - data start header layer head - dirPos dirNeg period? false) - match inputsE with - | Except.error msg => - IO.eprintln s!"warning: {msg}" - pure none - | Except.ok inputs => - let inputs' : - Model.InductionHeadInputs - seq header.modelDim - header.headDim := by - simpa [hseq] using inputs - let inputsAligned : - Model.InductionHeadInputs - seq header.modelDim - header.headDim := - { inputs' with - active := cert.active - prev := cert.prev } - let intervalError? ← - timePure - "head output interval" - (fun () => - intervalErrorFromHead? - inputsAligned - residualCert') - match intervalError? with - | none => - IO.eprintln - "warning: head output interval \ - rejected" - pure none - | some intervalError => - pure (some intervalError) - let intervalLB? := - intervalError?.map (fun err => - logitDiffLB - err) - let effectiveLB := - match intervalLB? with - | some intervalLB => max finalLB intervalLB - | none => finalLB - let violation? : Option Rat := - match effectiveMinLogitDiff with - | none => none - | some minLogitDiff => - if effectiveLB < minLogitDiff then - some minLogitDiff - else - none - match violation? with - | some minLogitDiff => - IO.eprintln - s!"error: end-to-end bound \ - {effectiveLB} below minimum \ - {minLogitDiff}" - return (2 : UInt32) - | none => - match intervalLB? 
with - | none => - IO.println - s!"ok: end-to-end induction \ - bound certified (seq={seq}, \ - active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstreamError}, \ - finalLB={finalLB})" - | some intervalLB => - let intervalError := - logitDiffLB - intervalLB - IO.println - s!"ok: end-to-end induction \ - bound certified (seq={seq}, \ - active={activeCount}, \ - logitDiffLB={logitDiffLB}, \ - downstreamError={downstreamError}, \ - finalLB={finalLB}, \ - intervalError={intervalError}, \ - intervalLB={intervalLB}, \ - effectiveLB={effectiveLB})" - return 0 - else - IO.eprintln - "error: residual-interval certificate rejected" - return 2 - else - IO.eprintln - s!"error: model seq {header.seqLen} \ - does not match cert seq {seq}" - return 2 -end IO -end Nfp diff --git a/Nfp/IO/Timing.lean b/Nfp/IO/Timing.lean deleted file mode 100644 index e755868..0000000 --- a/Nfp/IO/Timing.lean +++ /dev/null @@ -1,454 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Data.List.Range -public import Nfp.IO.HeadScore -public import Nfp.Model.InductionHead -public import Nfp.Sound.Induction.HeadBounds -public import Nfp.Sound.Induction.LogitDiff - -/-! -Small IO helpers for profiling slow phases. --/ - -public section - -namespace Nfp - -namespace IO - -open Sound - -/-- Current monotonic time in microseconds. -/ -def monoUsNow : IO Nat := do - let t ← IO.monoNanosNow - return t / 1000 - -/-! Timing configuration -/ - -/-- Runtime configuration for timing output. -/ -structure TimingConfig where - /-- Optional stdout override for timing output. -/ - stdout? : Option Bool - /-- Optional heartbeat interval override (ms). -/ - heartbeatMs? : Option UInt32 - deriving Inhabited - -/-- Mutable timing configuration (overrides environment defaults). -/ -initialize timingConfig : IO.Ref TimingConfig ← - IO.mkRef { stdout? := none, heartbeatMs? := none } - -/-- Enable or disable timing stdout output. 
-/ -def setTimingStdout (enabled : Bool) : IO Unit := do - timingConfig.modify (fun cfg => { cfg with stdout? := some enabled }) - -/-- Override the heartbeat interval (ms). -/ -def setTimingHeartbeatMs (ms : UInt32) : IO Unit := do - timingConfig.modify (fun cfg => { cfg with heartbeatMs? := some ms }) - -/-- Resolve whether timing output should be printed. -/ -def timingStdoutEnabled : IO Bool := do - let cfg ← timingConfig.get - match cfg.stdout? with - | some enabled => return enabled - | none => - match (← IO.getEnv "NFP_TIMING_STDOUT") with - | some "1" => return true - | some "true" => return true - | some "yes" => return true - | _ => return false - -/-- Resolve the heartbeat interval (ms), respecting overrides. -/ -def timingHeartbeatMs : IO UInt32 := do - let cfg ← timingConfig.get - match cfg.heartbeatMs? with - | some ms => return ms - | none => - let defaultMs : Nat := 0 - let ms := - (← IO.getEnv "NFP_TIMING_HEARTBEAT_MS").bind String.toNat? |>.getD defaultMs - return UInt32.ofNat ms - -/-- Resolve the heartbeat interval (ms) for long-running induction cert builds. -/ -def heartbeatMs : IO UInt32 := - timingHeartbeatMs - -/-- Print a timing line only when stdout timing is enabled. -/ -def timingPrint (line : String) : IO Unit := do - if (← timingStdoutEnabled) then - IO.println line - else - pure () - -/-- Flush stdout only when timing output is enabled. -/ -def timingFlush : IO Unit := do - if (← timingStdoutEnabled) then - let h ← IO.getStdout - h.flush - else - pure () - -/-- Append a timing log line to `NFP_TIMING_LOG` when set. -/ -def logTiming (line : String) : IO Unit := do - match (← IO.getEnv "NFP_TIMING_LOG") with - | some path => - let h ← IO.FS.Handle.mk (System.FilePath.mk path) IO.FS.Mode.append - h.putStr (line ++ "\n") - h.flush - | none => pure () - -/-- Time an IO phase and print the duration when timing output is enabled. 
-/ -def timePhase {α : Type} (label : String) (act : IO α) : IO α := do - logTiming s!"start: {label}" - let t0 ← monoUsNow - let res ← act - let t1 ← monoUsNow - logTiming s!"done: {label} {t1 - t0} us" - timingPrint s!"timing: {label} {t1 - t0} us" - return res - -/-- Time an IO phase supplied as a thunk and print the duration when timing output is enabled. -/ -def timePhaseThunk {α : Type} (label : String) (act : Unit → IO α) : IO α := do - logTiming s!"start: {label}" - let t0 ← monoUsNow - let res ← act () - let t1 ← monoUsNow - logTiming s!"done: {label} {t1 - t0} us" - timingPrint s!"timing: {label} {t1 - t0} us" - return res - -/-- Time a pure thunk and print the duration when timing output is enabled. -/ -def timePure {α : Type} (label : String) (f : Unit → α) : IO α := do - logTiming s!"start: {label}" - let t0 ← monoUsNow - let res := f () - let t1 ← monoUsNow - logTiming s!"done: {label} {t1 - t0} us" - timingPrint s!"timing: {label} {t1 - t0} us" - return res - -/-- Time a pure thunk, printing heartbeat updates while it runs. -/ -def timePureWithHeartbeat {α : Type} (label : String) (f : Unit → α) : IO α := do - let t0 ← monoUsNow - timingPrint s!"timing: {label} start" - timingFlush - let task : Task α := Task.spawn (fun _ => f ()) - let heartbeatMs ← heartbeatMs - if heartbeatMs ≠ 0 then - let mut finished := (← IO.hasFinished task) - while !finished do - IO.sleep heartbeatMs - finished := (← IO.hasFinished task) - if !finished then - let now ← monoUsNow - timingPrint s!"timing: {label} running {now - t0} us" - timingFlush - let res ← IO.wait task - let t1 ← monoUsNow - timingPrint s!"timing: {label} {t1 - t0} us" - return res - -/-- Flush stdout immediately for interleaved timing output. -/ -def flushStdout : IO Unit := do - let h ← IO.getStdout - h.flush - -/-- Force a sample score-gap computation for timing. 
-/ -def timeHeadScoreSampleGap {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do - timingPrint "timing: head score sample gap start" - timingFlush - let t0 ← monoUsNow - match List.finRange seq with - | [] => - timingPrint "timing: head score sample gap skipped (empty seq)" - | q :: _ => - let _ := score.scoreLo q (inputs.prev q) - let _ := score.scoreHi q (inputs.prev q) - let _ := score.scoreLo q (inputs.prev q) - score.scoreHi q (inputs.prev q) - pure () - let t1 ← monoUsNow - timingPrint s!"timing: head score sample gap {t1 - t0} us" - timingFlush - -/-- Force marginAt evaluation over the active list for timing. -/ -def timeHeadScoreMarginList {seq dModel dHead : Nat} - (activeList : List (Fin seq)) - (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do - timingPrint "timing: head score marginAt list start" - timingFlush - let t0 ← monoUsNow - for q in activeList do - let _ := score.marginAt q - pure () - let t1 ← monoUsNow - timingPrint s!"timing: head score marginAt list {t1 - t0} us" - timingFlush - -/-- Force marginAt evaluation without constructing the full score bounds record. 
-/ -def timeHeadScoreMarginRaw {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) - (activeList : List (Fin seq)) : IO Unit := do - timingPrint "timing: head score marginRaw list start" - timingFlush - let t0 ← monoUsNow - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - scoreBaseAbs q k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let maskedKeys : Fin seq → Finset (Fin seq) := fun q => - if inputs.maskCausal = true then - (otherKeys q).filter (fun k => q < k) - else - (∅ : Finset (Fin seq)) - let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => - (otherKeys q) \ (maskedKeys q) - let maskedGap : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - inputs.maskValue - let scoreGap : Fin seq → Fin seq → Rat := fun q k => - scoreLo q (inputs.prev q) - scoreHi q k - let marginAtRaw : Fin seq → Rat := fun q => - let other := unmaskedKeys q - let maskedSet := maskedKeys q - if hunmasked : other.Nonempty then - let unmaskedMin := other.inf' hunmasked (fun k => scoreGap q k) - if _hmasked : maskedSet.Nonempty then - min unmaskedMin (maskedGap q) - else - unmaskedMin - else - if _hmasked : maskedSet.Nonempty then - maskedGap q - else - (0 : Rat) - for q in activeList do - let _ := marginAtRaw q - pure () - let t1 ← monoUsNow - timingPrint s!"timing: head score marginRaw list {t1 - t0} us" - timingFlush - -/-- Force individual score-bound fields to locate slow evaluations. 
-/ -def timeHeadScoreFieldForces {seq dModel dHead : Nat} - (score : Sound.HeadScoreBounds seq dModel dHead) : IO Unit := do - timingPrint "timing: head score field force start" - timingFlush - let timeOne (label : String) (f : Unit → IO Unit) : IO Unit := do - let t0 ← monoUsNow - f () - let t1 ← monoUsNow - timingPrint s!"timing: head score field {label} {t1 - t0} us" - timingFlush - match List.finRange seq with - | [] => - timingPrint "timing: head score field force skipped (empty seq)" - timingFlush - | q :: _ => - match List.finRange seq with - | [] => - timingPrint "timing: head score field force skipped (empty seq)" - timingFlush - | k :: _ => - timeOne "scoreBaseAbs" (fun _ => do let _ := score.scoreBaseAbs q k; pure ()) - timeOne "scoreAbs" (fun _ => do let _ := score.scoreAbs q k; pure ()) - timeOne "scoreLo" (fun _ => do let _ := score.scoreLo q k; pure ()) - timeOne "scoreHi" (fun _ => do let _ := score.scoreHi q k; pure ()) - timeOne "marginAt" (fun _ => do let _ := score.marginAt q; pure ()) - timeOne "epsAt" (fun _ => do let _ := score.epsAt q; pure ()) - timeOne "margin" (fun _ => do let _ := score.margin; pure ()) - timeOne "eps" (fun _ => do let _ := score.eps; pure ()) - timingPrint "timing: head score field force done" - timingFlush - -/-- Force a rational to help isolate cached computations. -/ -def forceRat (x : Rat) : IO Unit := do - if x = x then - pure () - else - pure () - -/-- Report detailed timing for weighted logit-diff components when enabled. -/ -def logitDiffProfileEnabled : IO Bool := do - return (← IO.getEnv "NFP_TIMING_LOGITDIFF_PROFILE").isSome - -/-- Profile weighted logit-diff sub-steps when logit-diff profiling is enabled. 
-/ -def profileLogitDiffWeighted {seq : Nat} - (cert : Sound.InductionHeadCert seq) - (cache : Sound.LogitDiffCache seq) : IO Unit := do - if !(← logitDiffProfileEnabled) then - pure () - else - timingPrint "timing: logit-diff profile start" - timingFlush - let valsLoArr ← timePureWithHeartbeat "logit-diff profile: valsLo force" (fun () => - Array.ofFn (fun q : Fin seq => cache.valsLo q)) - let weightRows ← timePureWithHeartbeat "logit-diff profile: weightBoundAt force" (fun () => - Array.ofFn (fun q : Fin seq => - Array.ofFn (fun k : Fin seq => cert.weightBoundAt q k))) - let valsLo : Fin seq → Rat := fun k => - valsLoArr.getD k.1 (0 : Rat) - let weightBoundAt : Fin seq → Fin seq → Rat := fun q k => - let row := weightRows.getD q.1 #[] - row.getD k.1 (0 : Rat) - let _ ← timePureWithHeartbeat "logit-diff profile: weighted gap sum" (fun () => - Array.ofFn (fun q : Fin seq => - let valsLoPrev := valsLo (cert.prev q) - Linear.sumFin seq (fun k => - let diff := valsLoPrev - valsLo k - weightBoundAt q k * max (0 : Rat) diff))) - let _ ← timePureWithHeartbeat "logit-diff profile: weighted min" (fun () => - let gap : Fin seq → Rat := fun q => - let valsLoPrev := valsLo (cert.prev q) - Linear.sumFin seq (fun k => - let diff := valsLoPrev - valsLo k - weightBoundAt q k * max (0 : Rat) diff) - let f : Fin seq → Rat := fun q => valsLo (cert.prev q) - gap q - if h : cert.active.Nonempty then - let _ := cert.active.inf' h f - () - else - ()) - -/-- Profile the core induction-head bounds used by the sound certificate builder. 
-/ -def timeInductionHeadCoreStages {seq dModel dHead : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) : IO Unit := do - timingPrint "timing: core stages start" - timingFlush - let lnBounds ← timePureWithHeartbeat "core: ln bounds" (fun () => - Sound.headLnBounds inputs) - let lnLo := lnBounds.1 - let lnHi := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Sound.Bounds.cacheBoundTask (fun q => - Sound.Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr ← timePureWithHeartbeat "core: lnAbsMax force" (fun () => - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q)) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr.getD q.1 (0 : Rat) - let lnAbsMaxMax : Rat := - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by - simp [univ] - univ.sup' hnonempty (fun q => lnAbsMax q) - let qAbsRowTasks : Array (Task (Array Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - Array.ofFn (fun d : Fin dHead => - Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - |inputs.bq d|))) - let qAbsBaseArr ← timePureWithHeartbeat "core: qAbs base force" (fun () => - let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) - Array.ofFn (fun q : Fin seq => - (qAbsRowTasks.getD q.1 defaultTask).get)) - let qAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := qAbsBaseArr.getD q.1 #[] - row.getD d.1 (0 : Rat) - let kAbsRowTasks : Array (Task (Array Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - Array.ofFn (fun d : Fin dHead => - Sound.Bounds.dotIntervalAbsBound (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - |inputs.bk d|))) - let kAbsBaseArr ← timePureWithHeartbeat "core: kAbs base force" (fun () => - let defaultTask : Task (Array Rat) := Task.spawn (fun _ => #[]) - Array.ofFn (fun q : Fin seq => - (kAbsRowTasks.getD q.1 defaultTask).get)) - let kAbsBase : Fin seq → Fin dHead → Rat := fun q d => - let row := kAbsBaseArr.getD q.1 #[] - row.getD d.1 (0 
: Rat) - let dotAbs ← timePureWithHeartbeat "core: dotAbs tasks" (fun () => - dotAbsFromQKV qAbsBase kAbsBase) - let _ ← timePureWithHeartbeat "core: dotAbs force" (fun () => - match List.finRange seq with - | [] => (0 : Rat) - | q :: _ => - match List.finRange seq with - | [] => (0 : Rat) - | k :: _ => dotAbs q k) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else -scoreBaseAbs q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then inputs.maskValue else scoreBaseAbs q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreLoPrev q - scoreHi q k) - else - (0 : Rat) - else - (0 : Rat) - let margin ← timePureWithHeartbeat "core: margin" (fun () => - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat)) - let marginNeg ← timePureWithHeartbeat "core: margin < 0" (fun () => - decide (margin < 0)) - let verboseTiming ← IO.getEnv "NFP_TIMING_VERBOSE" - if verboseTiming.isSome then - timingPrint s!"timing: core: margin neg={marginNeg}" - let tEps0 ← monoUsNow - timingPrint "timing: core: eps start" - timingFlush - let eps := - if marginNeg then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - let tEps1 ← monoUsNow - timingPrint s!"timing: core: eps {tEps1 - tEps0} us" - timingFlush - let _ := marginAt - let dirHeadVec ← timePureWithHeartbeat "core: dir head vec" (fun () => - Sound.dirHeadVecOfInputs inputs) - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDir : Fin dModel → Rat := 
- Sound.Bounds.cacheBoundTask (fun j => - Sound.Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let _ ← timePureWithHeartbeat "core: wvDir force" (fun () => - Array.ofFn (fun j : Fin dModel => wvDir j)) - let bDir ← timePureWithHeartbeat "core: bDir" (fun () => - Sound.Linear.dotFin dHead dirHead (fun d => inputs.bv d)) - let valsAbsBase ← timePureWithHeartbeat "core: valsAbsBase" (fun () => - Sound.Linear.sumFin dModel (fun j => |wvDir j|) * lnAbsMaxMax) - let valsLoBase := bDir - valsAbsBase - let valsHiBase := bDir + valsAbsBase - let valsLo : Fin seq → Rat := fun _ => valsLoBase - let valsHi : Fin seq → Rat := fun _ => valsHiBase - let _ ← timePureWithHeartbeat "core: value bounds" (fun () => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by - simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - (lo, hi)) - timingPrint "timing: core stages done" - timingFlush - -end IO - -end Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean b/Nfp/Sound/Bounds/MatrixNorm/Basic.lean index 5e9bd32..e431416 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean +++ b/Nfp/Sound/Bounds/MatrixNorm/Basic.lean @@ -8,8 +8,6 @@ public import Mathlib.Algebra.Order.Ring.Abs public import Mathlib.Data.Fintype.Basic public import Mathlib.Data.Matrix.Mul public import Mathlib.Data.Real.Basic -public import Nfp.Circuit.Cert.DownstreamLinear -public import Nfp.Circuit.Cert.ResidualInterval public import Nfp.Core.Basic public import Nfp.Sound.Bounds.MatrixNorm.Interval public import Nfp.Sound.Linear.FinFold @@ -103,17 +101,6 @@ theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) 0 ≤ downstreamErrorFromBounds W bound := by simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound -/-- Build a residual-interval certificate by applying a matrix to an input interval. 
-/ -def buildResidualIntervalCertFromMatrix {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : - {c : Circuit.ResidualIntervalCert m // Circuit.ResidualIntervalBounds c} := by - let lo' := mulVecIntervalLower W lo hi - let hi' := mulVecIntervalUpper W lo hi - refine ⟨{ lo := lo', hi := hi' }, ?_⟩ - refine { lo_le_hi := ?_ } - intro i - exact mulVecIntervalLower_le_upper W lo hi hlohi i - /-- Summed absolute row entries factor out a scalar bound. -/ theorem sum_abs_row_mul_eq_rowSum_mul {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) (inputBound : Rat) : @@ -164,22 +151,6 @@ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) mul_le_mul_of_nonneg_right hle hinput exact hrow.trans hmul -/-- Build a downstream linear certificate from a matrix and input bound. -/ -def buildDownstreamLinearCert {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (inputBound : Rat) (hinput : 0 ≤ inputBound) : - {c : Circuit.DownstreamLinearCert // Circuit.DownstreamLinearBounds c} := by - let gain := rowSumNorm W - let error := gain * inputBound - refine ⟨{ error := error, gain := gain, inputBound := inputBound }, ?_⟩ - refine - { error_nonneg := ?_ - gain_nonneg := ?_ - input_nonneg := hinput - error_eq := rfl } - · exact mul_nonneg (rowSumNorm_nonneg W) hinput - · exact rowSumNorm_nonneg W - - end Bounds end Sound diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean index 07459fd..49acad6 100644 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ b/Nfp/Sound/Induction/Core/Basic.lean @@ -28,74 +28,6 @@ open scoped BigOperators open Nfp.Circuit open Nfp.Sound.Bounds variable {seq : Nat} -/-- Build and certify a softmax-margin certificate from exact scores/weights. -/ -def buildSoftmaxMarginCert? 
[NeZero seq] - (active : Finset (Fin seq)) - (prev : Fin seq → Fin seq) - (scores : Fin seq → Fin seq → Rat) - (weights : Fin seq → Fin seq → Rat) : - Option {c : SoftmaxMarginCert seq // checkSoftmaxMarginCert c = true} := by - classical - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (prev q) - let epsAt : Fin seq → Rat := fun q => - let other := otherKeys q - let maxOther := - if h : other.Nonempty then - other.sup' h (fun k => weights q k) - else - (0 : Rat) - let deficit := (1 : Rat) - weights q (prev q) - max maxOther deficit - let marginAt : Fin seq → Rat := fun q => - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scores q (prev q) - scores q k) - else - (0 : Rat) - let eps := - if h : active.Nonempty then - active.sup' h epsAt - else - (0 : Rat) - let margin := - if h : active.Nonempty then - active.inf' h marginAt - else - (0 : Rat) - let cert : SoftmaxMarginCert seq := - { eps := eps - margin := margin - active := active - prev := prev - scores := scores - weights := weights } - if h : checkSoftmaxMarginCert cert = true then - exact some ⟨cert, h⟩ - else - exact none -/-- Build and certify a value-range certificate from exact values. -/ -def buildValueRangeCert? [NeZero seq] - (vals : Fin seq → Rat) - (direction : Option DirectionSpec) : - Option {c : ValueRangeCert seq // checkValueRangeCert c = true} := by - classical - let _ : Nonempty (Fin seq) := by - refine ⟨⟨0, ?_⟩⟩ - exact Nat.pos_of_ne_zero (NeZero.ne seq) - let univ : Finset (Fin seq) := Finset.univ - let hnonempty : univ.Nonempty := Finset.univ_nonempty - let lo := univ.inf' hnonempty vals - let hi := univ.sup' hnonempty vals - let cert : ValueRangeCert seq := - { lo := lo - hi := hi - vals := vals - direction := direction } - if h : checkValueRangeCert cert = true then - exact some ⟨cert, h⟩ - else - exact none /-- Cached bounds and derived quantities for induction-head core certificates. 
-/ structure InductionHeadCoreCache (seq dModel dHead : Nat) where /-- Cached LayerNorm bound pair. -/ From a79bd40620f88c8ecd336da567bfbaa497f622ba Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 08:07:17 +0100 Subject: [PATCH 217/244] Remove induction cert generators from Sound --- Nfp/Sound/Induction.lean | 10 +- Nfp/Sound/Induction/Core.lean | 4 +- Nfp/Sound/Induction/Core/Basic.lean | 1191 --------------- Nfp/Sound/Induction/CoreSound.lean | 9 - Nfp/Sound/Induction/CoreSound/Basic.lean | 11 - .../CoreSound/Basic/CacheBounds.lean | 615 -------- .../Induction/CoreSound/Basic/CertSound.lean | 1316 ----------------- .../CoreSound/Basic/DefaultSound.lean | 29 - Nfp/Sound/Induction/CoreSound/Values.lean | 229 --- Nfp/Sound/Induction/HeadBounds.lean | 9 - Nfp/Sound/Induction/HeadBounds/Basic.lean | 1242 ---------------- Nfp/Sound/Induction/HeadOutput.lean | 372 +---- Nfp/Sound/Induction/LogitDiff.lean | 217 +-- Nfp/Sound/Induction/Refine.lean | 546 ------- Nfp/Sound/Induction/RefineSound.lean | 907 ------------ 15 files changed, 8 insertions(+), 6699 deletions(-) delete mode 100644 Nfp/Sound/Induction/Core/Basic.lean delete mode 100644 Nfp/Sound/Induction/CoreSound.lean delete mode 100644 Nfp/Sound/Induction/CoreSound/Basic.lean delete mode 100644 Nfp/Sound/Induction/CoreSound/Basic/CacheBounds.lean delete mode 100644 Nfp/Sound/Induction/CoreSound/Basic/CertSound.lean delete mode 100644 Nfp/Sound/Induction/CoreSound/Basic/DefaultSound.lean delete mode 100644 Nfp/Sound/Induction/CoreSound/Values.lean delete mode 100644 Nfp/Sound/Induction/HeadBounds.lean delete mode 100644 Nfp/Sound/Induction/HeadBounds/Basic.lean delete mode 100644 Nfp/Sound/Induction/Refine.lean delete mode 100644 Nfp/Sound/Induction/RefineSound.lean diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index 0343827..a1360a4 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -3,18 +3,14 @@ module public import Nfp.Sound.Induction.Core -public import 
Nfp.Sound.Induction.CoreSound public import Nfp.Sound.Induction.EndToEnd -public import Nfp.Sound.Induction.HeadBounds public import Nfp.Sound.Induction.HeadOutput public import Nfp.Sound.Induction.LogitDiff public import Nfp.Sound.Induction.OneHot -public import Nfp.Sound.Induction.Refine -public import Nfp.Sound.Induction.RefineSound /-! -Sound builders for induction certificates. +Soundness lemmas for induction certificates. -This module re-exports the core constructions, head-output interval bounds, -and logit-diff helpers. +This module re-exports the core definitions, head-output interval predicates, +and logit-diff helpers that operate on explicit certificates. -/ diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean index d87c878..b67eca5 100644 --- a/Nfp/Sound/Induction/Core.lean +++ b/Nfp/Sound/Induction/Core.lean @@ -2,8 +2,8 @@ module -public import Nfp.Sound.Induction.Core.Basic +public import Nfp.Sound.Induction.CoreDefs /-! -Core definitions and constructors for induction certificates. +Core definitions for induction certificates. 
-/ diff --git a/Nfp/Sound/Induction/Core/Basic.lean b/Nfp/Sound/Induction/Core/Basic.lean deleted file mode 100644 index 49acad6..0000000 --- a/Nfp/Sound/Induction/Core/Basic.lean +++ /dev/null @@ -1,1191 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later -module - -public import Mathlib.Algebra.BigOperators.Group.Finset.Basic -public import Mathlib.Algebra.Order.BigOperators.Group.Finset -public import Mathlib.Algebra.Order.Field.Basic -public import Nfp.Core.Basic -public import Mathlib.Data.Finset.Lattice.Fold -public import Nfp.Circuit.Cert.ResidualInterval -public import Nfp.Circuit.Cert.SoftmaxMargin -public import Nfp.Circuit.Cert.ValueRange -public import Nfp.Sound.Bounds.Attention -public import Nfp.Sound.Bounds.Cache -public import Nfp.Sound.Bounds.LayerNorm -public import Nfp.Sound.Bounds.LayerNorm.InvStd -public import Nfp.Sound.Bounds.MatrixNorm -public import Nfp.Sound.Induction.CoreDefs -public import Nfp.Sound.Induction.OneHot -public import Nfp.Sound.Linear.FinFold -/-! Sound builders for induction certificates; recompute bounds inside Lean from exact inputs and -derive softmax tolerances from score margins rather than trusting external weight dumps. -/ - -public section - -namespace Nfp -namespace Sound -open scoped BigOperators -open Nfp.Circuit -open Nfp.Sound.Bounds -variable {seq : Nat} -/-- Cached bounds and derived quantities for induction-head core certificates. -/ -structure InductionHeadCoreCache (seq dModel dHead : Nat) where - /-- Cached LayerNorm bound pair. -/ - lnBounds : (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) - /-- LayerNorm lower bounds. -/ - lnLo : Fin seq → Fin dModel → Rat - /-- LayerNorm upper bounds. -/ - lnHi : Fin seq → Fin dModel → Rat - /-- Tasks for LayerNorm absolute maxima. -/ - lnAbsMaxTask : Fin seq → Rat - /-- Cached LayerNorm absolute maxima. -/ - lnAbsMaxArr : Array Rat - /-- LayerNorm absolute-max lookup. -/ - lnAbsMax : Fin seq → Rat - /-- Tasks for inverse-std bounds. 
-/ - invStdBoundsTasks : Array (Task (Rat × Rat)) - /-- Cached inverse-std bounds. -/ - invStdBoundsArr : Array (Rat × Rat) - /-- Inverse-std lower bounds. -/ - invStdLo : Fin seq → Rat - /-- Inverse-std upper bounds. -/ - invStdHi : Fin seq → Rat - /-- Cached query base terms. -/ - qBaseArr : Array Rat - /-- Query base lookup. -/ - qBase : Fin dHead → Rat - /-- Cached key base terms. -/ - kBaseArr : Array Rat - /-- Key base lookup. -/ - kBase : Fin dHead → Rat - /-- Tasks for query coefficient rows. -/ - qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) - /-- Cached query coefficient rows. -/ - qCoeffArr : Array { row : Array Rat // row.size = dHead } - /-- Query coefficient lookup. -/ - qCoeff : Fin seq → Fin dHead → Rat - /-- Tasks for key coefficient rows. -/ - kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) - /-- Cached key coefficient rows. -/ - kCoeffArr : Array { row : Array Rat // row.size = dHead } - /-- Key coefficient lookup. -/ - kCoeff : Fin seq → Fin dHead → Rat - /-- Query lower bounds. -/ - qLo : Fin seq → Fin dHead → Rat - /-- Query upper bounds. -/ - qHi : Fin seq → Fin dHead → Rat - /-- Key lower bounds. -/ - kLo : Fin seq → Fin dHead → Rat - /-- Key upper bounds. -/ - kHi : Fin seq → Fin dHead → Rat - /-- Query absolute bounds. -/ - qAbs : Fin seq → Fin dHead → Rat - /-- Key absolute bounds. -/ - kAbs : Fin seq → Fin dHead → Rat - /-- Cached max query abs bounds. -/ - qAbsMaxArr : Array Rat - /-- Max query abs bound lookup. -/ - qAbsMax : Fin dHead → Rat - /-- Cached max key abs bounds. -/ - kAbsMaxArr : Array Rat - /-- Max key abs bound lookup. -/ - kAbsMax : Fin dHead → Rat - /-- Causal mask predicate. -/ - masked : Fin seq → Fin seq → Prop - /-- Split budget for query dims. -/ - splitBudgetQ : Nat - /-- Split budget for key dims. -/ - splitBudgetK : Nat - /-- Split budget for base diff dims. -/ - splitBudgetDiffBase : Nat - /-- Split budget for refined diff dims. 
-/ - splitBudgetDiffRefined : Nat - /-- Split dims for query bounds. -/ - splitDimsQ : Fin seq → List (Fin dHead) - /-- Split dims for key bounds. -/ - splitDimsK : Fin seq → Fin seq → List (Fin dHead) - /-- Split dims for diff bounds with budget. -/ - splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) - /-- Split dims for base diff bounds. -/ - splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) - /-- Split dims for refined diff bounds. -/ - splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) - /-- Tasks for dot-product interval rows. -/ - dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) - /-- Tasks for base diff dot rows. -/ - dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) - /-- Dot-product lower bounds. -/ - dotLo : Fin seq → Fin seq → Rat - /-- Dot-product upper bounds. -/ - dotHi : Fin seq → Fin seq → Rat - /-- Base diff dot-product lower bounds. -/ - dotDiffLoBase : Fin seq → Fin seq → Rat - /-- Base diff dot-product upper bounds. -/ - dotDiffHiBase : Fin seq → Fin seq → Rat - /-- Dot-product absolute bounds. -/ - dotAbs : Fin seq → Fin seq → Rat - /-- Base score absolute bounds. -/ - scoreBaseAbs : Fin seq → Fin seq → Rat - /-- Score lower bounds. -/ - scoreLo : Fin seq → Fin seq → Rat - /-- Score upper bounds. -/ - scoreHi : Fin seq → Fin seq → Rat - /-- Score lower bounds at prev key. -/ - scoreLoPrev : Fin seq → Rat - /-- Base score-gap lower bounds. -/ - scoreGapLoBase : Fin seq → Fin seq → Rat - /-- Other-key set for each query. -/ - otherKeys : Fin seq → Finset (Fin seq) - /-- Worst key candidate per query. -/ - worstKey : Fin seq → Option (Fin seq) - /-- Refined diff dot-product lower bounds. -/ - dotDiffLo : Fin seq → Fin seq → Rat - /-- Refined diff dot-product upper bounds. -/ - dotDiffHi : Fin seq → Fin seq → Rat - /-- Score-gap lower bounds. -/ - scoreGapLo : Fin seq → Fin seq → Rat - /-- Margin per query. -/ - marginAt : Fin seq → Rat - /-- Epsilon per query. 
-/ - epsAt : Fin seq → Rat - /-- Per-key weight bounds derived from score gaps. -/ - weightBoundAt : Fin seq → Fin seq → Rat - /-- Global margin. -/ - margin : Rat - /-- Global epsilon. -/ - eps : Rat - /-- Cached direction head vector. -/ - dirHeadVec : Vector Rat dHead - /-- Direction head lookup. -/ - dirHead : Fin dHead → Rat - /-- Value-direction weight dot products. -/ - wvDir : Fin dModel → Rat - /-- Direction bias term. -/ - bDir : Rat - /-- Value lower bounds. -/ - valsLo : Fin seq → Rat - /-- Value upper bounds. -/ - valsHi : Fin seq → Rat - /-- Universe of query indices. -/ - univ : Finset (Fin seq) - /-- Global value lower bound. -/ - lo : Rat - /-- Global value upper bound. -/ - hi : Rat - /-- Value-interval certificate. -/ - valCert : ValueInterval seq - /-- Induction-head certificate. -/ - cert : InductionHeadCert seq - -/-- Cached certificate-related fields derived from score gaps. -/ -structure InductionHeadCertFields (seq : Nat) where - /-- Margin per query. -/ - marginAt : Fin seq → Rat - /-- Base weight bounds derived from score gaps. -/ - weightBoundAtBase : Fin seq → Fin seq → Rat - /-- Cached base weight bounds. -/ - weightBoundAtBaseCached : Fin seq → Fin seq → Rat - /-- Base epsilon per query. -/ - epsAtBase : Fin seq → Rat - /-- Epsilon per query. -/ - epsAt : Fin seq → Rat - /-- Per-key weight bounds derived from score gaps. -/ - weightBoundAt : Fin seq → Fin seq → Rat - /-- Global margin. -/ - margin : Rat - /-- Global epsilon. -/ - eps : Rat - -/-- Build certificate-related cached fields from score gaps. 
-/ -def buildInductionHeadCertFields [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (otherKeys : Fin seq → Finset (Fin seq)) - (scoreGapLo : Fin seq → Fin seq → Rat) : InductionHeadCertFields seq := by - let marginAt : Fin seq → Rat := fun q => - if hq : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreGapLo q k) - else - (0 : Rat) - else - (0 : Rat) - let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => - if hk : k = inputs.prev q then - (0 : Rat) - else - let gap := scoreGapLo q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap) - let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := - Bounds.cacheBound2Task weightBoundAtBase - let epsAtBase : Fin seq → Rat := fun q => - let other := otherKeys q - let total := other.sum (fun k => weightBoundAtBaseCached q k) - min (1 : Rat) total - let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase - let weightBoundAt : Fin seq → Fin seq → Rat := weightBoundAtBaseCached - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if h : inputs.active.Nonempty then - inputs.active.sup' h epsAt - else - (0 : Rat) - exact - { marginAt := marginAt - weightBoundAtBase := weightBoundAtBase - weightBoundAtBaseCached := weightBoundAtBaseCached - epsAtBase := epsAtBase - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - eps := eps } - -/-- Unfolding lemma for `buildInductionHeadCertFields`. 
-/ -theorem buildInductionHeadCertFields_def [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (otherKeys : Fin seq → Finset (Fin seq)) - (scoreGapLo : Fin seq → Fin seq → Rat) : - buildInductionHeadCertFields inputs otherKeys scoreGapLo = - (let marginAt : Fin seq → Rat := fun q => - if _ : q ∈ inputs.active then - let other := otherKeys q - if h : other.Nonempty then - other.inf' h (fun k => scoreGapLo q k) - else - (0 : Rat) - else - (0 : Rat) - let weightBoundAtBase : Fin seq → Fin seq → Rat := fun q k => - if _ : k = inputs.prev q then - (0 : Rat) - else - let gap := scoreGapLo q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap) - let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := - Bounds.cacheBound2Task weightBoundAtBase - let epsAtBase : Fin seq → Rat := fun q => - let other := otherKeys q - let total := other.sum (fun k => weightBoundAtBaseCached q k) - min (1 : Rat) total - let epsAt : Fin seq → Rat := Bounds.cacheBoundThunk epsAtBase - let weightBoundAt : Fin seq → Fin seq → Rat := weightBoundAtBaseCached - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if h : inputs.active.Nonempty then - inputs.active.sup' h epsAt - else - (0 : Rat) - { marginAt := marginAt - weightBoundAtBase := weightBoundAtBase - weightBoundAtBaseCached := weightBoundAtBaseCached - epsAtBase := epsAtBase - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - eps := eps }) := by - rfl - -/-- The `eps` field of `buildInductionHeadCertFields` is the active supremum when nonempty. 
-/ -theorem buildInductionHeadCertFields_eps_eq [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (otherKeys : Fin seq → Finset (Fin seq)) - (scoreGapLo : Fin seq → Fin seq → Rat) : - (buildInductionHeadCertFields inputs otherKeys scoreGapLo).eps = - if h : inputs.active.Nonempty then - inputs.active.sup' h - (buildInductionHeadCertFields inputs otherKeys scoreGapLo).epsAt - else - (0 : Rat) := by - rfl - -/-- Build an induction-head certificate from cached fields. -/ -def inductionHeadCertOfCacheFields [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (eps : Rat) (epsAt : Fin seq → Rat) (weightBoundAt : Fin seq → Fin seq → Rat) - (margin : Rat) (valCert : ValueInterval seq) : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } - -/-- Unfolding lemma for `inductionHeadCertOfCacheFields`. -/ -theorem inductionHeadCertOfCacheFields_def [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (eps : Rat) (epsAt : Fin seq → Rat) (weightBoundAt : Fin seq → Fin seq → Rat) - (margin : Rat) (valCert : ValueInterval seq) : - inductionHeadCertOfCacheFields inputs eps epsAt weightBoundAt margin valCert = - { eps := eps - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } := by - rfl - -/-- Build a value-interval certificate from per-query bounds. 
-/ -def buildInductionHeadValCert [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (valsLo valsHi : Fin seq → Rat) : ValueInterval seq := by - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - exact - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec } - -/-- Unfolding lemma for `buildInductionHeadValCert`. -/ -theorem buildInductionHeadValCert_def [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (valsLo valsHi : Fin seq → Rat) : - buildInductionHeadValCert inputs valsLo valsHi = - (let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - { lo := lo - hi := hi - valsLo := valsLo - valsHi := valsHi - direction := some inputs.directionSpec }) := by - rfl - -/-- `buildInductionHeadValCert` preserves the provided lower bounds. -/ -theorem buildInductionHeadValCert_valsLo [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (valsLo valsHi : Fin seq → Rat) : - (buildInductionHeadValCert inputs valsLo valsHi).valsLo = valsLo := by - rfl - -/-- `buildInductionHeadValCert` preserves the provided upper bounds. -/ -theorem buildInductionHeadValCert_valsHi [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (valsLo valsHi : Fin seq → Rat) : - (buildInductionHeadValCert inputs valsLo valsHi).valsHi = valsHi := by - rfl - -/-- `buildInductionHeadValCert` yields pointwise value bounds from interval bounds. 
-/ -theorem buildInductionHeadValCert_bounds_at [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (valsReal : Fin seq → Real) - (valsLo valsHi : Fin seq → Rat) - (hvals : ∀ k, (valsLo k : Real) ≤ valsReal k ∧ valsReal k ≤ (valsHi k : Real)) : - ∀ k, - ((buildInductionHeadValCert inputs valsLo valsHi).valsLo k : Real) ≤ valsReal k ∧ - valsReal k ≤ ((buildInductionHeadValCert inputs valsLo valsHi).valsHi k : Real) := by - intro k - have hprojLo : - (buildInductionHeadValCert inputs valsLo valsHi).valsLo = valsLo := by - exact buildInductionHeadValCert_valsLo (inputs := inputs) valsLo valsHi - have hprojHi : - (buildInductionHeadValCert inputs valsLo valsHi).valsHi = valsHi := by - exact buildInductionHeadValCert_valsHi (inputs := inputs) valsLo valsHi - have hlo : - ((buildInductionHeadValCert inputs valsLo valsHi).valsLo k : Real) = valsLo k := by - exact congrArg (fun r : Rat => (r : Real)) (congrArg (fun f => f k) hprojLo) - have hhi : - ((buildInductionHeadValCert inputs valsLo valsHi).valsHi k : Real) = valsHi k := by - exact congrArg (fun r : Rat => (r : Real)) (congrArg (fun f => f k) hprojHi) - simpa [hlo, hhi] using hvals k - -/-- `buildInductionHeadValCert` satisfies `ValueIntervalBounds` from pointwise value bounds. 
-/ -theorem buildInductionHeadValCert_bounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (valsReal : Fin seq → Real) - (valsLo valsHi : Fin seq → Rat) - (hvals : ∀ k, (valsLo k : Real) ≤ valsReal k ∧ valsReal k ≤ (valsHi k : Real)) : - ValueIntervalBounds (vals := valsReal) (buildInductionHeadValCert inputs valsLo valsHi) := by - let valCert := buildInductionHeadValCert inputs valsLo valsHi - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - refine - { lo_le_hi := ?_ - lo_le_valsLo := ?_ - vals_bounds := ?_ - valsHi_le_hi := ?_ } - · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ - have hmem0 : k0 ∈ univ := hk0 - have hlo : (valCert.lo : Real) ≤ (valCert.valsLo k0 : Real) := by - have hloRat : valCert.lo ≤ valCert.valsLo k0 := by - change lo ≤ valsLo k0 - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k0)).2 ⟨k0, hmem0, le_rfl⟩ - simpa [ratToReal_def] using ratToReal_le_of_le hloRat - have hvals : (valCert.valsLo k0 : Real) ≤ valsReal k0 ∧ - valsReal k0 ≤ (valCert.valsHi k0 : Real) := by - exact buildInductionHeadValCert_bounds_at (inputs := inputs) - (valsReal := valsReal) (valsLo := valsLo) (valsHi := valsHi) hvals k0 - have hhi : (valCert.valsHi k0 : Real) ≤ (valCert.hi : Real) := by - have hhiRat : valCert.valsHi k0 ≤ valCert.hi := by - change valsHi k0 ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k0)).2 ⟨k0, ⟨hmem0, le_rfl⟩⟩ - simpa [ratToReal_def] using ratToReal_le_of_le hhiRat - have hreal : (valCert.lo : Real) ≤ (valCert.hi : Real) := - le_trans hlo (le_trans hvals.1 (le_trans hvals.2 hhi)) - have hreal' : ratToReal valCert.lo ≤ ratToReal valCert.hi := by - simpa [ratToReal_def] using hreal - exact (ratToReal_le_iff (x := valCert.lo) (y := valCert.hi)).1 hreal' - · 
intro k - have hloRat : valCert.lo ≤ valCert.valsLo k := by - change lo ≤ valsLo k - dsimp [lo] - refine (Finset.inf'_le_iff (s := univ) (H := hnonempty) - (f := valsLo) (a := valsLo k)).2 ⟨k, by simp [univ], le_rfl⟩ - simpa [ratToReal_def] using ratToReal_le_of_le hloRat - · exact buildInductionHeadValCert_bounds_at (inputs := inputs) - (valsReal := valsReal) (valsLo := valsLo) (valsHi := valsHi) hvals - · intro k - have hhiRat : valCert.valsHi k ≤ valCert.hi := by - change valsHi k ≤ hi - dsimp [hi] - refine (Finset.le_sup'_iff (s := univ) (H := hnonempty) - (f := valsHi) (a := valsHi k)).2 ⟨k, by simp [univ], le_rfl⟩ - simpa [ratToReal_def] using ratToReal_le_of_le hhiRat - -/-- Build an induction-head certificate from cached fields and value bounds. -/ -def buildInductionHeadCert [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (certFields : InductionHeadCertFields seq) - (valCert : ValueInterval seq) : InductionHeadCert seq := - inductionHeadCertOfCacheFields inputs certFields.eps certFields.epsAt - certFields.weightBoundAt certFields.margin valCert - -/-- Unfolding lemma for `buildInductionHeadCert`. -/ -theorem buildInductionHeadCert_def [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (certFields : InductionHeadCertFields seq) - (valCert : ValueInterval seq) : - buildInductionHeadCert inputs certFields valCert = - inductionHeadCertOfCacheFields inputs certFields.eps certFields.epsAt - certFields.weightBoundAt certFields.margin valCert := by - rfl - -/-- `buildInductionHeadCert` preserves the provided value certificate. 
-/ -theorem buildInductionHeadCert_values [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (certFields : InductionHeadCertFields seq) - (valCert : ValueInterval seq) : - (buildInductionHeadCert inputs certFields valCert).values = valCert := by - rfl - -/-- Build cached core quantities for induction-head certificates. -/ -def buildInductionHeadCoreCacheWith [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) : - InductionHeadCoreCache seq dModel dHead := by - classical - let lnBounds := Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Bounds.cacheBoundTask (fun q => - Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr : Array Rat := - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr[q.1]'(by - have hsize : lnAbsMaxArr.size = seq := by - simp [lnAbsMaxArr] - simp [hsize]) - let invStdBoundsTasks : Array (Task (Rat × Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) - let invStdBoundsArr : Array (Rat × Rat) := - Array.ofFn (fun q : Fin seq => - (invStdBoundsTasks[q.1]'(by - have hsize : invStdBoundsTasks.size = seq := by - simp [invStdBoundsTasks] - simp [hsize])).get) - let invStdLo : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - have hsize : invStdBoundsArr.size = seq := by - simp [invStdBoundsArr] - simp [hsize])).1 - let invStdHi : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - have hsize : invStdBoundsArr.size = seq := by - simp [invStdBoundsArr] - simp [hsize])).2 - let qBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => 
inputs.ln1Beta j) + - inputs.bq d) - let qBase : Fin dHead → Rat := fun d => - qBaseArr[d.1]'(by - have hsize : qBaseArr.size = dHead := by - simp [qBaseArr] - simp [hsize]) - let kBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + - inputs.bk d) - let kBase : Fin dHead → Rat := fun d => - kBaseArr[d.1]'(by - have hsize : kBaseArr.size = dHead := by - simp [kBaseArr] - simp [hsize]) - let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) coeff), - by simp⟩)) - let qCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qCoeffRowTasks[q.1]'(by - have hsize : qCoeffRowTasks.size = seq := by - simp [qCoeffRowTasks] - simp [hsize])).get) - let qCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := qCoeffArr[q.1]'(by - have hsize : qCoeffArr.size = seq := by - simp [qCoeffArr] - simp [hsize]) - row.1[d.1]'(by - simp [row.2]) - let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) coeff), - by simp⟩)) - let kCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kCoeffRowTasks[q.1]'(by - have hsize : kCoeffRowTasks.size = seq := by - simp [kCoeffRowTasks] - simp [hsize])).get) - let kCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := kCoeffArr[q.1]'(by - have hsize : kCoeffArr.size = seq := by - simp [kCoeffArr] - simp [hsize]) - 
row.1[d.1]'(by - simp [row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.1 - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.2 - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.1 - let kHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.2 - let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| - let qAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => qAbs q d)) - let qAbsMax : Fin dHead → Rat := fun d => - qAbsMaxArr[d.1]'(by - have hsize : qAbsMaxArr.size = dHead := by - simp [qAbsMaxArr] - simp [hsize]) - let kAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d)) - let kAbsMax : Fin dHead → Rat := fun d => - kAbsMaxArr[d.1]'(by - have hsize : kAbsMaxArr.size = dHead := by - simp [kAbsMaxArr] - simp [hsize]) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := cfg.splitBudgetQ - let splitBudgetK : Nat := cfg.splitBudgetK - let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase - let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined - let finRangeHead : List (Fin dHead) := List.finRange dHead - let finRangeSeq : List (Fin seq) := List.finRange seq - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - if splitBudgetQ = 0 then - [] - else - let ambig 
:= - finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ - let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - if splitBudgetK = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some 
b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK - let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := fun budget q k => - if budget = 0 then - [] - else - let prev := inputs.prev q - let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d - let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d - let ambig := - finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) - let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 : List (Fin dHead) := top2 ambig - let dims2 : List (Fin dHead) := - top2 (ambig.filter (fun d => decide (d ∉ dims1))) - let memDims2 : Fin dHead → Bool := fun d => - dims2.any (fun d' => decide (d' = d)) - let dims3 : List (Fin dHead) := - top2 - ((ambig.filter (fun d => decide (d ∉ dims1))).filter - (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take budget - let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffBase - let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffRefined - let dotRowTasks : Array (Task { row : Array 
(Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if hq : q ∈ inputs.active then - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsDiff := splitDimsDiffBase q k - let prev := inputs.prev q - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)), - by simp⟩ - else - ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := 
fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLoBase q k - else - inputs.scale * dotDiffHiBase q k - let scoreGapLoBase : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoBaseRaw - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - -- Skip worst-key refinement when base/refined budgets match to avoid duplicate score-gap work. 
- let worstKey : Fin seq → Option (Fin seq) := - if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then - fun _ => none - else - let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => - if hq : q ∈ inputs.active then - let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (scoreGapLoBase q k, k)).2 - else - none - let worstKeyArr : Array (Thunk (Option (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) - fun q => - let t := worstKeyArr[q.1]'(by - simp [worstKeyArr, q.isLt]) - Thunk.get t - let refineKeys : Fin seq → Finset (Fin seq) := - if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then - fun _ => ∅ - else - let refineKeysRaw : Fin seq → Finset (Fin seq) := fun q => - let base : Finset (Fin seq) := - match worstKey q with - | some k => {k} - | none => ∅ - if hq : q ∈ inputs.active then - let other := otherKeys q - base ∪ other.filter (fun k => decide (scoreGapLoBase q k < 0)) - else - base - let refineKeysArr : Array (Thunk (Finset (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => refineKeysRaw q)) - fun q => - let t := refineKeysArr[q.1]'(by - simp [refineKeysArr, q.isLt]) - Thunk.get t - let dotDiffLoHi : (Fin seq → Fin seq → Rat) × (Fin seq → Fin seq → Rat) := - if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then - (dotDiffLoBase, dotDiffHiBase) - else - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - if hk : k ∈ refineKeys q then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - let dotDiffHi : Fin seq → Fin seq → Rat := fun q 
k => - if hk : k ∈ refineKeys q then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 - else - dotDiffHiBase q k - (dotDiffLo, dotDiffHi) - let dotDiffLo : Fin seq → Fin seq → Rat := dotDiffLoHi.1 - let dotDiffHi : Fin seq → Fin seq → Rat := dotDiffLoHi.2 - let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLo q k - else - inputs.scale * dotDiffHi q k - let scoreGapLo : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoRaw - let certFields := buildInductionHeadCertFields inputs otherKeys scoreGapLo - let marginAt : Fin seq → Rat := certFields.marginAt - let weightBoundAtBase : Fin seq → Fin seq → Rat := certFields.weightBoundAtBase - let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := certFields.weightBoundAtBaseCached - let epsAtBase : Fin seq → Rat := certFields.epsAtBase - let epsAt : Fin seq → Rat := certFields.epsAt - let weightBoundAt : Fin seq → Fin seq → Rat := certFields.weightBoundAt - let margin : Rat := certFields.margin - let eps : Rat := certFields.eps - let dirHeadVec := dirHeadVecOfInputs inputs - let dirHead : Fin dHead → Rat := fun d => dirHeadVec.get d - let wvDirTask : Fin dModel → Rat := - Bounds.cacheBoundTask (fun j => - Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - let wvDirArr : Array Rat := Array.ofFn wvDirTask - let wvDir : Fin dModel → Rat := fun j => - wvDirArr[j.1]'(by - have hsize : wvDirArr.size = dModel := by - simp [wvDirArr] - simp [hsize, j.isLt]) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsLo : Fin seq → Rat := fun q => - bDir + 
Bounds.dotIntervalLower (fun j => wvDir j) (lnLo q) (lnHi q) - let valsHi : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalUpper (fun j => wvDir j) (lnLo q) (lnHi q) - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := buildInductionHeadValCert inputs valsLo valsHi - let cert : InductionHeadCert seq := - buildInductionHeadCert inputs certFields valCert - exact - { lnBounds := lnBounds - lnLo := lnLo - lnHi := lnHi - lnAbsMaxTask := lnAbsMaxTask - lnAbsMaxArr := lnAbsMaxArr - lnAbsMax := lnAbsMax - invStdBoundsTasks := invStdBoundsTasks - invStdBoundsArr := invStdBoundsArr - invStdLo := invStdLo - invStdHi := invStdHi - qBaseArr := qBaseArr - qBase := qBase - kBaseArr := kBaseArr - kBase := kBase - qCoeffRowTasks := qCoeffRowTasks - qCoeffArr := qCoeffArr - qCoeff := qCoeff - kCoeffRowTasks := kCoeffRowTasks - kCoeffArr := kCoeffArr - kCoeff := kCoeff - qLo := qLo - qHi := qHi - kLo := kLo - kHi := kHi - qAbs := qAbs - kAbs := kAbs - qAbsMaxArr := qAbsMaxArr - qAbsMax := qAbsMax - kAbsMaxArr := kAbsMaxArr - kAbsMax := kAbsMax - masked := masked - splitBudgetQ := splitBudgetQ - splitBudgetK := splitBudgetK - splitBudgetDiffBase := splitBudgetDiffBase - splitBudgetDiffRefined := splitBudgetDiffRefined - splitDimsQ := splitDimsQ - splitDimsK := splitDimsK - splitDimsDiffCore := splitDimsDiffCore - splitDimsDiffBase := splitDimsDiffBase - splitDimsDiffRefined := splitDimsDiffRefined - dotRowTasks := dotRowTasks - dotDiffRowTasksBase := dotDiffRowTasksBase - dotLo := dotLo - dotHi := dotHi - dotDiffLoBase := dotDiffLoBase - dotDiffHiBase := dotDiffHiBase - dotAbs := dotAbs - scoreBaseAbs := scoreBaseAbs - scoreLo := scoreLo - scoreHi := scoreHi - scoreLoPrev := scoreLoPrev - scoreGapLoBase := scoreGapLoBase - otherKeys := otherKeys - worstKey := worstKey - dotDiffLo := dotDiffLo - dotDiffHi := dotDiffHi - 
scoreGapLo := scoreGapLo - marginAt := marginAt - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - eps := eps - dirHeadVec := dirHeadVec - dirHead := dirHead - wvDir := wvDir - bDir := bDir - valsLo := valsLo - valsHi := valsHi - univ := univ - lo := lo - hi := hi - valCert := valCert - cert := cert } - -/-- Build cached core quantities for induction-head certificates using the default split budgets. -/ -def buildInductionHeadCoreCache [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - InductionHeadCoreCache seq dModel dHead := - buildInductionHeadCoreCacheWith defaultInductionHeadSplitConfig inputs - -/-- The cached certificate is built from cache fields. -/ -theorem buildInductionHeadCoreCache_cert_eq [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - (buildInductionHeadCoreCache inputs).cert = - { eps := (buildInductionHeadCoreCache inputs).eps - epsAt := (buildInductionHeadCoreCache inputs).epsAt - weightBoundAt := (buildInductionHeadCoreCache inputs).weightBoundAt - margin := (buildInductionHeadCoreCache inputs).margin - active := inputs.active - prev := inputs.prev - values := (buildInductionHeadCoreCache inputs).valCert } := by - rfl - -/-- The cached certificate is built from cache fields (custom split config). 
-/ -theorem buildInductionHeadCoreCacheWith_cert_eq [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) : - (buildInductionHeadCoreCacheWith cfg inputs).cert = - { eps := (buildInductionHeadCoreCacheWith cfg inputs).eps - epsAt := (buildInductionHeadCoreCacheWith cfg inputs).epsAt - weightBoundAt := (buildInductionHeadCoreCacheWith cfg inputs).weightBoundAt - margin := (buildInductionHeadCoreCacheWith cfg inputs).margin - active := inputs.active - prev := inputs.prev - values := (buildInductionHeadCoreCacheWith cfg inputs).valCert } := by - rfl -/-- Build induction certificates from exact head inputs (core computation). -/ -def buildInductionCertFromHeadCoreWith? [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option (InductionHeadCert seq) := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : 0 < sqrtLower inputs.lnEps - · by_cases hmodel : dModel = 0 - · exact none - · by_cases hactive : inputs.active.Nonempty - · exact some (buildInductionHeadCoreCacheWith cfg inputs).cert - · exact none - · exact none - · exact none - -/-- Build induction certificates from exact head inputs using the default split budgets. -/ -def buildInductionCertFromHeadCore? [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option (InductionHeadCert seq) := - buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs - -/-- Unfolding lemma for `buildInductionCertFromHeadCore?`. -/ -theorem buildInductionCertFromHeadCore?_def [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - buildInductionCertFromHeadCore? inputs = - buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs := by - simp [buildInductionCertFromHeadCore?] - -/-- `buildInductionCertFromHeadCoreWith?` succeeds under the guard conditions. 
-/ -theorem buildInductionCertFromHeadCoreWith?_eq_some [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) (hactive : inputs.active.Nonempty) : - buildInductionCertFromHeadCoreWith? cfg inputs = - some (buildInductionHeadCoreCacheWith cfg inputs).cert := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] - -/-- `buildInductionCertFromHeadCoreWith?` fails when `dModel = 0`. -/ -theorem buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel = 0) : - buildInductionCertFromHeadCoreWith? cfg inputs = none := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel] - -/-- `buildInductionCertFromHeadCoreWith?` fails when `active` is empty. -/ -theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_active - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : - buildInductionCertFromHeadCoreWith? cfg inputs = none := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt, hmodel, hactive] - -/-- `buildInductionCertFromHeadCoreWith?` fails when the sqrt lower bound is nonpositive. -/ -theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : - buildInductionCertFromHeadCoreWith? 
cfg inputs = none := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps, hSqrt] - -/-- `buildInductionCertFromHeadCoreWith?` fails when `lnEps` is nonpositive. -/ -theorem buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : ¬0 < inputs.lnEps) : - buildInductionCertFromHeadCoreWith? cfg inputs = none := by - classical - simp [buildInductionCertFromHeadCoreWith?, hEps] - -/-- `buildInductionCertFromHeadCore?` succeeds under the guard conditions. -/ -theorem buildInductionCertFromHeadCore?_eq_some [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) (hactive : inputs.active.Nonempty) : - buildInductionCertFromHeadCore? inputs = - some (buildInductionHeadCoreCache inputs).cert := by - classical - simpa [buildInductionCertFromHeadCore?, buildInductionHeadCoreCache] using - (buildInductionCertFromHeadCoreWith?_eq_some - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) - hEps hSqrt hmodel hactive) - -/-- `buildInductionCertFromHeadCore?` fails when `dModel = 0`. -/ -theorem buildInductionCertFromHeadCore?_eq_none_of_model_eq_zero [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel = 0) : - buildInductionCertFromHeadCore? inputs = none := by - classical - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt hmodel) - -/-- `buildInductionCertFromHeadCore?` fails when `active` is empty. 
-/ -theorem buildInductionCertFromHeadCore?_eq_none_of_not_active [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) (hactive : ¬inputs.active.Nonempty) : - buildInductionCertFromHeadCore? inputs = none := by - classical - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_eq_none_of_not_active - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt hmodel hactive) - -/-- `buildInductionCertFromHeadCore?` fails when the sqrt lower bound is nonpositive. -/ -theorem buildInductionCertFromHeadCore?_eq_none_of_not_sqrt [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : ¬0 < sqrtLower inputs.lnEps) : - buildInductionCertFromHeadCore? inputs = none := by - classical - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps hSqrt) - -/-- `buildInductionCertFromHeadCore?` fails when `lnEps` is nonpositive. -/ -theorem buildInductionCertFromHeadCore?_eq_none_of_not_eps [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : ¬0 < inputs.lnEps) : - buildInductionCertFromHeadCore? inputs = none := by - classical - simpa [buildInductionCertFromHeadCore?] using - (buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps - (cfg := defaultInductionHeadSplitConfig) (inputs := inputs) hEps) - -end Sound -end Nfp diff --git a/Nfp/Sound/Induction/CoreSound.lean b/Nfp/Sound/Induction/CoreSound.lean deleted file mode 100644 index 6542435..0000000 --- a/Nfp/Sound/Induction/CoreSound.lean +++ /dev/null @@ -1,9 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Sound.Induction.CoreSound.Basic - -/-! 
-Soundness proofs for induction-head core certificates. --/ diff --git a/Nfp/Sound/Induction/CoreSound/Basic.lean b/Nfp/Sound/Induction/CoreSound/Basic.lean deleted file mode 100644 index 19bde24..0000000 --- a/Nfp/Sound/Induction/CoreSound/Basic.lean +++ /dev/null @@ -1,11 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Sound.Induction.CoreSound.Basic.CertSound -public import Nfp.Sound.Induction.CoreSound.Basic.CacheBounds -public import Nfp.Sound.Induction.CoreSound.Basic.DefaultSound - -/-! -Core soundness proofs for induction-head certificates. --/ diff --git a/Nfp/Sound/Induction/CoreSound/Basic/CacheBounds.lean b/Nfp/Sound/Induction/CoreSound/Basic/CacheBounds.lean deleted file mode 100644 index 626a97f..0000000 --- a/Nfp/Sound/Induction/CoreSound/Basic/CacheBounds.lean +++ /dev/null @@ -1,615 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later -module - -import all Nfp.Sound.Induction.Core.Basic -public import Nfp.Sound.Induction.Core -public import Nfp.Sound.Induction.CoreSound.Values - -public section - -namespace Nfp -namespace Sound -open scoped BigOperators -open Nfp.Circuit -open Nfp.Sound.Bounds -variable {seq : Nat} -/-- Bounds for cached projections and scores from `buildInductionHeadCoreCacheWith`. 
-/ -theorem buildInductionHeadCoreCacheWith_bounds - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) : - (∀ q d, - ((buildInductionHeadCoreCacheWith cfg inputs).qLo q d : Real) ≤ - qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ - ((buildInductionHeadCoreCacheWith cfg inputs).qHi q d : Real)) ∧ - (∀ q d, - ((buildInductionHeadCoreCacheWith cfg inputs).kLo q d : Real) ≤ - kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ - ((buildInductionHeadCoreCacheWith cfg inputs).kHi q d : Real)) ∧ - (∀ q k, - ((buildInductionHeadCoreCacheWith cfg inputs).scoreLo q k : Real) ≤ - scoresRealOfInputs inputs q k ∧ - scoresRealOfInputs inputs q k ≤ - ((buildInductionHeadCoreCacheWith cfg inputs).scoreHi q k : Real)) := by - classical - set cache := buildInductionHeadCoreCacheWith cfg inputs with hcache - have dotFin_cast {n : Nat} (f g : Fin n → Rat) : - (Linear.dotFin n f g : Real) = - dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by - simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] - let lnCoeff : Fin seq → Fin dModel → Rat := fun q j => - inputs.ln1Gamma j * (inputs.embed q j - mean (inputs.embed q)) - let invStd : Fin seq → Real := fun q => - (Real.sqrt ((varianceRat (inputs.embed q) : Real) + (inputs.lnEps : Real)))⁻¹ - have hmeanRat : - ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by - intro q - have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by - simp [mean_def, hmodel, ratRoundDown_def] - simpa [ratToReal_def] using congrArg ratToReal hmu_rat - have hln_affine : - ∀ q j, - lnRealOfInputs inputs q j = - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - intro q j - have hmu := hmeanRat q - simp [lnRealOfInputs_def, Bounds.layerNormReal_def, hmodel, lnCoeff, hmu, invStd, - add_comm, mul_assoc, 
-mul_eq_mul_left_iff, -mul_eq_mul_right_iff] - have hln_fun : - ∀ q, - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - intro q - funext j - exact hln_affine q j - let invStdBoundsTasks : Array (Task (Rat × Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) - let invStdBoundsArr : Array (Rat × Rat) := - Array.ofFn (fun q : Fin seq => - (invStdBoundsTasks[q.1]'(by - simp [invStdBoundsTasks])).get) - let invStdLo : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - simp [invStdBoundsArr])).1 - let invStdHi : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - simp [invStdBoundsArr])).2 - have hinv_bounds : - ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by - intro q - simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, - Bounds.invStdBounds_def, Task.spawn, Array.getElem_ofFn] using - (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) - hmodel hEps hSqrt) - have proj_bounds - (w : Fin dModel → Fin dHead → Rat) - (b base : Fin dHead → Rat) - (coeff : Fin seq → Fin dHead → Rat) - (hbase : ∀ d, - (base d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (b d : Real)) - (hcoeff : ∀ q d, - (coeff q d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real))) : - ∀ q d, - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ∧ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ≤ - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by - intro q d - have hinv : (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := - hinv_bounds q - have hln_fun_q : - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - 
exact hln_fun q - have hdot_add : - dotProduct (fun j => (w j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real) * invStd q) := by - simpa using - (Nfp.Sound.Linear.dotProduct_add_right - (x := fun j => (w j d : Real)) - (y := fun j => (inputs.ln1Beta j : Real)) - (z := fun j => (lnCoeff q j : Real) * invStd q)) - have hdot_coeff : - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real) * invStd q) = - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q := by - simpa using - (Nfp.Sound.Linear.dotProduct_mul_right - (x := fun j => (w j d : Real)) - (y := fun j => (lnCoeff q j : Real)) - (a := invStd q)) - have hreal : - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) = - (base d : Real) + (coeff q d : Real) * invStd q := by - calc - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) + - (b d : Real) := by - simp [hln_fun_q] - _ = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q + - (b d : Real) := by - simp [hdot_add, hdot_coeff, add_assoc] - _ = - (dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (b d : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q := by ac_rfl - _ = (base d : Real) + (coeff q d : Real) * invStd q := by - simp [hbase, hcoeff] - have hscale : - let bounds := scaleInterval (coeff q d) (invStdLo q) (invStdHi q) - (bounds.1 : Real) ≤ (coeff q d : Real) * invStd q ∧ - (coeff q d : Real) * invStd q ≤ (bounds.2 : Real) := by - exact 
scaleInterval_bounds_real (x := coeff q d) (lo := invStdLo q) - (hi := invStdHi q) (y := invStd q) hinv.1 hinv.2 - have hlow : - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) := by - simpa [hreal] using add_le_add_left hscale.1 (base d : Real) - have hhigh : - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ≤ - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by - simpa [hreal] using add_le_add_left hscale.2 (base d : Real) - exact ⟨hlow, hhigh⟩ - let qBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + - inputs.bq d) - let qBase : Fin dHead → Rat := fun d => - qBaseArr[d.1]'(by - simp [qBaseArr]) - let kBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + - inputs.bk d) - let kBase : Fin dHead → Rat := fun d => - kBaseArr[d.1]'(by - simp [kBaseArr]) - let coeffRowTasks : - (Fin dModel → Fin dHead → Rat) → - Array (Task { row : Array Rat // row.size = dHead }) := - fun w => - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => w j d) coeff), - by simp⟩)) - let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - coeffRowTasks inputs.wq - let qCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qCoeffRowTasks[q.1]'(by - simp [qCoeffRowTasks, coeffRowTasks])).get) - let qCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := qCoeffArr[q.1]'(by - simp [qCoeffArr]) - row.1[d.1]'(by - simp [row.2]) - let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - 
coeffRowTasks inputs.wk - let kCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kCoeffRowTasks[q.1]'(by - simp [kCoeffRowTasks, coeffRowTasks])).get) - let kCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := kCoeffArr[q.1]'(by - simp [kCoeffArr]) - row.1[d.1]'(by - simp [row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.1 - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.2 - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.1 - let kHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.2 - let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| - let qAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => qAbs q d)) - let qAbsMax : Fin dHead → Rat := fun d => - qAbsMaxArr[d.1]'(by - simp [qAbsMaxArr]) - let kAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d)) - let kAbsMax : Fin dHead → Rat := fun d => - kAbsMaxArr[d.1]'(by - simp [kAbsMaxArr]) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := cfg.splitBudgetQ - let splitBudgetK : Nat := cfg.splitBudgetK - let finRangeHead : List (Fin dHead) := List.finRange dHead - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - if splitBudgetQ = 0 then - [] - else - let 
ambig := - finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ - let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - if splitBudgetK = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < kHi k d)) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | 
(some b1, none) => [b1.2] - | (none, _) => [] - let dims1 := top2 ambig - let dims2 := top2 (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - have dotLo_eq : - ∀ q k, - dotLo q k = - if masked q k then - (0 : Rat) - else - let dimsQ := splitDimsQ q - let dimsK := splitDimsK q k - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)).1 := by - intro q k - classical - by_cases hmk : masked q k - · simp [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hmk] - · simp [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hmk] - have dotHi_eq : - ∀ q k, - dotHi q k = - if masked q k then - (0 : Rat) - else - let dimsQ := splitDimsQ q - let dimsK := 
splitDimsK q k - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)).2 := by - intro q k - classical - by_cases hmk : masked q k - · simp [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hmk] - · simp [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hmk] - have hq_bounds_local : - ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by - intro q d - have hbase : - ∀ d, - (qBase d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bq d : Real) := by - intro d - simp [qBase, qBaseArr, dotFin_cast] - have hcoeff : - ∀ q' d, - (qCoeff q' d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (lnCoeff q' j : Real)) := by - intro q' d - simpa [qCoeff, qCoeffArr, qCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using - (dotFin_cast (f := fun j => inputs.wq j d) - (g := fun j => - inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) - have h := proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) - (coeff := qCoeff) hbase hcoeff q d - simpa [qLo, qHi, qRealOfInputs_def] using h - have hk_bounds_local : - ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ (kHi q d : Real) := by - intro q d - have hbase : - ∀ d, - (kBase d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bk d : Real) := by - intro d - simp [kBase, kBaseArr, dotFin_cast] - have hcoeff : - ∀ q' d, - (kCoeff q' d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (lnCoeff q' j : Real)) := by - intro q' d - simpa [kCoeff, kCoeffArr, kCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using - (dotFin_cast (f := fun j => inputs.wk j d) - (g := fun j => - inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) - have h := proj_bounds (w := 
inputs.wk) (b := inputs.bk) (base := kBase) - (coeff := kCoeff) hbase hcoeff q d - simpa [kLo, kHi, kRealOfInputs_def] using h - let scoresReal := scoresRealOfInputs inputs - have scoresReal_eq_base_of_not_masked : - ∀ q k, ¬ masked q k → - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - intro q k hnot - by_cases hcausal : inputs.maskCausal - · have hnot_lt : ¬ q < k := by - intro hlt - exact hnot ⟨hcausal, hlt⟩ - have hle : k ≤ q := le_of_not_gt hnot_lt - simp [scoresReal, scoresRealOfInputs_def, hcausal, hle] - · simp [scoresReal, scoresRealOfInputs_def, hcausal] - have scoresReal_eq_masked : - ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by - intro q k hmask - have hmask' : inputs.maskCausal = true ∧ q < k := by - simpa [masked] using hmask - have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 - simp [scoresReal, scoresRealOfInputs_def, hmask'.1, hle] - have hscore_bounds_local : - ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ - scoresReal q k ≤ (scoreHi q k : Real) := by - intro q k - let base := - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - have hdot_bounds (hnot : ¬ masked q k) : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - have hq := hq_bounds_local q - have hk := hk_bounds_local k - have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq d).1 - have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq d).2 - have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => - (hk d).1 - have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => - (hk d).2 - have hspec := - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - 
(dims1 := splitDimsQ q) (dims2 := splitDimsK q k) - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hlow' : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa [dotLo_eq, hnot] using hspec.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - simpa [dotHi_eq, hnot] using hspec.2 - exact ⟨hlow', hhigh'⟩ - have hscore_base_bounds (hnot : ¬ masked q k) : - (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := by - simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale - have hdot := hdot_bounds hnot - have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real - have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real - constructor - · simpa [scoreLo, masked, hnot, hscale, base] using hlow - · simpa [scoreHi, masked, hnot, hscale, base] using hhigh - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := by - simpa [ratToReal_def] using - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hdot := hdot_bounds hnot - have hlow := mul_le_mul_of_nonpos_left hdot.2 hscale_real - have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real - constructor - · simpa [scoreLo, masked, hnot, hscale, base] using hlow - · simpa [scoreHi, masked, hnot, hscale, base] using hhigh - by_cases hcausal : inputs.maskCausal - · by_cases hle : k ≤ q - · have hnot : ¬ q < k := not_lt_of_ge hle - have hnot_masked : ¬ masked q k := fun hmk => hnot hmk.2 - have hscore_eq : scoresReal q k = base := - scoresReal_eq_base_of_not_masked q k hnot_masked - have hbase := hscore_base_bounds 
hnot_masked - constructor - · simpa [hscore_eq] using hbase.1 - · simpa [hscore_eq] using hbase.2 - · have hlt : q < k := lt_of_not_ge hle - have hmask : masked q k := ⟨hcausal, hlt⟩ - have hscore : scoresReal q k = (inputs.maskValue : Real) := - scoresReal_eq_masked q k hmask - constructor - · simp [hscore, scoreLo, hmask, masked] - · simp [hscore, scoreHi, hmask, masked] - · have hnot_masked : ¬ masked q k := by - simp [masked, hcausal] - have hscore_eq : scoresReal q k = base := - scoresReal_eq_base_of_not_masked q k hnot_masked - have hbase := hscore_base_bounds hnot_masked - constructor - · simpa [hscore_eq] using hbase.1 - · simpa [hscore_eq] using hbase.2 - have hlocal : - (∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qHi q d : Real)) ∧ - (∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ (kHi q d : Real)) ∧ - (∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ - scoresReal q k ≤ (scoreHi q k : Real)) := by - exact ⟨hq_bounds_local, hk_bounds_local, hscore_bounds_local⟩ - simpa (config := { zeta := false }) [hcache, buildInductionHeadCoreCacheWith] using - hlocal - -/-- Query bounds for `buildInductionHeadCoreCacheWith`. -/ -theorem buildInductionHeadCoreCacheWith_q_bounds - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) : - ∀ q d, - ((buildInductionHeadCoreCacheWith cfg inputs).qLo q d : Real) ≤ - qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ - ((buildInductionHeadCoreCacheWith cfg inputs).qHi q d : Real) := by - exact (buildInductionHeadCoreCacheWith_bounds cfg inputs hEps hSqrt hmodel).1 - -/-- Key bounds for `buildInductionHeadCoreCacheWith`. 
-/ -theorem buildInductionHeadCoreCacheWith_k_bounds - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) : - ∀ q d, - ((buildInductionHeadCoreCacheWith cfg inputs).kLo q d : Real) ≤ - kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ - ((buildInductionHeadCoreCacheWith cfg inputs).kHi q d : Real) := by - exact (buildInductionHeadCoreCacheWith_bounds cfg inputs hEps hSqrt hmodel).2.1 - -/-- Score bounds for `buildInductionHeadCoreCacheWith`. -/ -theorem buildInductionHeadCoreCacheWith_score_bounds - [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) - (hEps : 0 < inputs.lnEps) (hSqrt : 0 < sqrtLower inputs.lnEps) - (hmodel : dModel ≠ 0) : - ∀ q k, - ((buildInductionHeadCoreCacheWith cfg inputs).scoreLo q k : Real) ≤ - scoresRealOfInputs inputs q k ∧ - scoresRealOfInputs inputs q k ≤ - ((buildInductionHeadCoreCacheWith cfg inputs).scoreHi q k : Real) := by - exact (buildInductionHeadCoreCacheWith_bounds cfg inputs hEps hSqrt hmodel).2.2 - - -end Sound -end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Basic/CertSound.lean b/Nfp/Sound/Induction/CoreSound/Basic/CertSound.lean deleted file mode 100644 index 1981093..0000000 --- a/Nfp/Sound/Induction/CoreSound/Basic/CertSound.lean +++ /dev/null @@ -1,1316 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later -module - -import all Nfp.Sound.Induction.Core.Basic -public import Nfp.Sound.Induction.Core -public import Nfp.Sound.Induction.CoreSound.Values - -public section - -namespace Nfp -namespace Sound -open scoped BigOperators -open Nfp.Circuit -open Nfp.Sound.Bounds -variable {seq : Nat} -/-- Soundness for `buildInductionCertFromHeadCoreWith?`. 
-/ -theorem buildInductionCertFromHeadCoreWith?_sound [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) - (hcore : buildInductionCertFromHeadCoreWith? cfg inputs = some c) : - InductionHeadCertSound inputs c := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : 0 < sqrtLower inputs.lnEps - · by_cases hmodel : dModel = 0 - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_model_eq_zero - (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · by_cases hactive : inputs.active.Nonempty - · let lnBounds := Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - let lnLo : Fin seq → Fin dModel → Rat := lnBounds.1 - let lnHi : Fin seq → Fin dModel → Rat := lnBounds.2 - let lnAbsMaxTask : Fin seq → Rat := - Bounds.cacheBoundTask (fun q => - Bounds.intervalAbsBound (lnLo q) (lnHi q)) - let lnAbsMaxArr : Array Rat := - Array.ofFn (fun q : Fin seq => lnAbsMaxTask q) - let lnAbsMax : Fin seq → Rat := fun q => - lnAbsMaxArr[q.1]'(by - simp [lnAbsMaxArr]) - let invStdBoundsTasks : Array (Task (Rat × Rat)) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => invStdBounds inputs.lnEps (inputs.embed q))) - let invStdBoundsArr : Array (Rat × Rat) := - Array.ofFn (fun q : Fin seq => - (invStdBoundsTasks[q.1]'(by - simp [invStdBoundsTasks])).get) - let invStdLo : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - simp [invStdBoundsArr])).1 - let invStdHi : Fin seq → Rat := fun q => - (invStdBoundsArr[q.1]'(by - simp [invStdBoundsArr])).2 - let lnCoeff : Fin seq → Fin dModel → Rat := fun q j => - inputs.ln1Gamma j * (inputs.embed q j - mean (inputs.embed q)) - let invStd : Fin seq → Real := fun q => - (Real.sqrt 
((varianceRat (inputs.embed q) : Real) + (inputs.lnEps : Real)))⁻¹ - have hmeanRat : - ∀ q, (mean (inputs.embed q) : Real) = meanRat (inputs.embed q) := by - intro q - have hmu_rat : mean (inputs.embed q) = meanRat (inputs.embed q) := by - simp [mean_def, hmodel, ratRoundDown_def] - simpa [ratToReal_def] using congrArg ratToReal hmu_rat - have hln_affine : - ∀ q j, - lnRealOfInputs inputs q j = - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - intro q j - have hmu := hmeanRat q - simp [lnRealOfInputs_def, Bounds.layerNormReal_def, hmodel, lnCoeff, hmu, invStd, - add_comm, mul_assoc, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] - have hln_fun : - ∀ q, - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - intro q - funext j - exact hln_affine q j - have hinv_bounds : - ∀ q, (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := by - intro q - simpa [invStd, invStdLo, invStdHi, invStdBoundsArr, invStdBoundsTasks, - Bounds.invStdBounds_def, Task.spawn, Array.getElem_ofFn] using - (Bounds.invStdBounds_spec (eps := inputs.lnEps) (x := inputs.embed q) - hmodel hEps hSqrt) - let qBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wq j d) (fun j => inputs.ln1Beta j) + - inputs.bq d) - let qBase : Fin dHead → Rat := fun d => - qBaseArr[d.1]'(by - simp [qBaseArr]) - let kBaseArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wk j d) (fun j => inputs.ln1Beta j) + - inputs.bk d) - let kBase : Fin dHead → Rat := fun d => - kBaseArr[d.1]'(by - simp [kBaseArr]) - let coeffRowTasks : - (Fin dModel → Fin dHead → Rat) → - Array (Task { row : Array Rat // row.size = dHead }) := - fun w => - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let μ := mean (inputs.embed q) - let coeff : Fin dModel → Rat := fun j => - inputs.ln1Gamma j * (inputs.embed q j - μ) - ⟨Array.ofFn (fun d : Fin dHead => - Linear.dotFin 
dModel (fun j => w j d) coeff), - by simp⟩)) - let qCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - coeffRowTasks inputs.wq - let qCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (qCoeffRowTasks[q.1]'(by - simp [qCoeffRowTasks, coeffRowTasks])).get) - let qCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := qCoeffArr[q.1]'(by - simp [qCoeffArr]) - row.1[d.1]'(by - simp [row.2]) - let kCoeffRowTasks : Array (Task { row : Array Rat // row.size = dHead }) := - coeffRowTasks inputs.wk - let kCoeffArr : Array { row : Array Rat // row.size = dHead } := - Array.ofFn (fun q : Fin seq => - (kCoeffRowTasks[q.1]'(by - simp [kCoeffRowTasks, coeffRowTasks])).get) - let kCoeff : Fin seq → Fin dHead → Rat := fun q d => - let row := kCoeffArr[q.1]'(by - simp [kCoeffArr]) - row.1[d.1]'(by - simp [row.2]) - let qLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.1 - let qHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (qCoeff q d) (invStdLo q) (invStdHi q) - qBase d + bounds.2 - let kLo : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.1 - let kHi : Fin seq → Fin dHead → Rat := fun q d => - let bounds := scaleInterval (kCoeff q d) (invStdLo q) (invStdHi q) - kBase d + bounds.2 - let qAbs : Fin seq → Fin dHead → Rat := fun q d => max |qLo q d| |qHi q d| - let kAbs : Fin seq → Fin dHead → Rat := fun q d => max |kLo q d| |kHi q d| - let qAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun q => qAbs q d)) - let qAbsMax : Fin dHead → Rat := fun d => - qAbsMaxArr[d.1]'(by - simp [qAbsMaxArr]) - let kAbsMaxArr : Array Rat := - Array.ofFn (fun d : Fin dHead => - let univ : Finset (Fin seq) 
:= Finset.univ - have hnonempty : univ.Nonempty := Finset.univ_nonempty - univ.sup' hnonempty (fun k => kAbs k d)) - let kAbsMax : Fin dHead → Rat := fun d => - kAbsMaxArr[d.1]'(by - simp [kAbsMaxArr]) - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let splitBudgetQ : Nat := cfg.splitBudgetQ - let splitBudgetK : Nat := cfg.splitBudgetK - let splitBudgetDiffBase : Nat := cfg.splitBudgetDiffBase - let splitBudgetDiffRefined : Nat := cfg.splitBudgetDiffRefined - let top2ByScore : - (Fin dHead → Rat) → List (Fin dHead) → List (Fin dHead) := fun score ambig => - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let finRangeHead : List (Fin dHead) := List.finRange dHead - let finRangeSeq : List (Fin seq) := List.finRange seq - let splitDimsQ : Fin seq → List (Fin dHead) := fun q => - if splitBudgetQ = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (qLo q d < 0 ∧ 0 < qHi q d)) - let score : Fin dHead → Rat := fun d => (qHi q d - qLo q d) * kAbsMax d - let dims1 := top2ByScore score ambig - let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetQ - let splitDimsK : Fin seq → Fin seq → List (Fin dHead) := fun q k => - if splitBudgetK = 0 then - [] - else - let ambig := - finRangeHead.filter (fun d => decide (kLo k d < 0 ∧ 0 < 
kHi k d)) - let score : Fin dHead → Rat := fun d => (kHi k d - kLo k d) * qAbs q d - let dims1 := top2ByScore score ambig - let dims2 := top2ByScore score (ambig.filter (fun d => decide (d ∉ dims1))) - (dims1 ++ dims2).take splitBudgetK - let splitDimsDiffCore : Nat → Fin seq → Fin seq → List (Fin dHead) := - fun budget q k => - if budget = 0 then - [] - else - let prev := inputs.prev q - let diffLo : Fin dHead → Rat := fun d => kLo prev d - kHi k d - let diffHi : Fin dHead → Rat := fun d => kHi prev d - kLo k d - let ambig := - finRangeHead.filter (fun d => decide (diffLo d < 0 ∧ 0 < diffHi d)) - let score : Fin dHead → Rat := fun d => (diffHi d - diffLo d) * qAbs q d - let step - (best : Option (Rat × Fin dHead) × Option (Rat × Fin dHead)) - (d : Fin dHead) : - Option (Rat × Fin dHead) × Option (Rat × Fin dHead) := - let s := score d - match best with - | (none, none) => (some (s, d), none) - | (some b1, none) => - if b1.1 < s then (some (s, d), some b1) else (some b1, some (s, d)) - | (some b1, some b2) => - if b1.1 < s then (some (s, d), some b1) - else if b2.1 < s then (some b1, some (s, d)) else (some b1, some b2) - | (none, some b2) => - if b2.1 < s then (some (s, d), some b2) else (some b2, some (s, d)) - let top2 : List (Fin dHead) → List (Fin dHead) := fun ambig => - match ambig.foldl step (none, none) with - | (some b1, some b2) => [b1.2, b2.2] - | (some b1, none) => [b1.2] - | (none, _) => [] - let dims1 : List (Fin dHead) := top2 ambig - let dims2 : List (Fin dHead) := - top2 (ambig.filter (fun d => decide (d ∉ dims1))) - let memDims2 : Fin dHead → Bool := fun d => - dims2.any (fun d' => decide (d' = d)) - let dims3 : List (Fin dHead) := - top2 - ((ambig.filter (fun d => decide (d ∉ dims1))).filter - (fun d => !memDims2 d)) - (dims1 ++ dims2 ++ dims3).take budget - let splitDimsDiffBase : Fin seq → Fin seq → List (Fin dHead) := - splitDimsDiffCore splitBudgetDiffBase - let splitDimsDiffRefined : Fin seq → Fin seq → List (Fin dHead) := - 
splitDimsDiffCore splitBudgetDiffRefined - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsK := splitDimsK q k - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsK - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotDiffRowTasksBase : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if hq : q ∈ inputs.active then - let dimsQ := splitDimsQ q - ⟨Array.ofFn (fun k : Fin seq => - if masked q k then - (0, 0) - else - let dimsDiff := splitDimsDiffBase q k - let prev := inputs.prev q - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)), - by simp⟩ - else - ⟨Array.ofFn (fun _ : Fin seq => (0, 0)), by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotDiffLoBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotDiffHiBase : Fin seq → Fin seq → Rat := fun q k => - let row := (dotDiffRowTasksBase[q.1]'(by - simp [dotDiffRowTasksBase, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreBaseAbs : Fin seq → Fin seq → Rat 
:= fun q k => - |inputs.scale| * dotAbs q k - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if hscale : 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - let scoreLoPrev : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - let scoreGapLoBaseRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLoBase q k - else - inputs.scale * dotDiffHiBase q k - let scoreGapLoBase : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoBaseRaw - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let worstKey : Fin seq → Option (Fin seq) := - if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then - fun _ => none - else - let worstKeyRaw : Fin seq → Option (Fin seq) := fun q => - if hq : q ∈ inputs.active then - let ks := finRangeSeq.filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (scoreGapLoBase q k, k)).2 - else - none - let worstKeyArr : Array (Thunk (Option (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => worstKeyRaw q)) - fun q => - let t := worstKeyArr[q.1]'(by - simp [worstKeyArr, q.isLt]) - Thunk.get t - let refineKeys : Fin seq → Finset (Fin seq) := - if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then - fun _ => ∅ - else - let refineKeysRaw : Fin seq → Finset (Fin seq) := fun q => - let base : Finset (Fin seq) := - match worstKey 
q with - | some k => {k} - | none => ∅ - if hq : q ∈ inputs.active then - let other := otherKeys q - base ∪ other.filter (fun k => decide (scoreGapLoBase q k < 0)) - else - base - let refineKeysArr : Array (Thunk (Finset (Fin seq))) := - Array.ofFn (fun q => Thunk.mk (fun _ => refineKeysRaw q)) - fun q => - let t := refineKeysArr[q.1]'(by - simp [refineKeysArr, q.isLt]) - Thunk.get t - let dotDiffLoHi : (Fin seq → Fin seq → Rat) × (Fin seq → Fin seq → Rat) := - if h : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase then - (dotDiffLoBase, dotDiffHiBase) - else - let dotDiffLo : Fin seq → Fin seq → Rat := fun q k => - if hk : k ∈ refineKeys q then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).1 - else - dotDiffLoBase q k - let dotDiffHi : Fin seq → Fin seq → Rat := fun q k => - if hk : k ∈ refineKeys q then - let dimsQ := splitDimsQ q - let dimsDiff := splitDimsDiffRefined q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo prev d - kHi k d) - (fun d => kHi prev d - kLo k d)).2 - else - dotDiffHiBase q k - (dotDiffLo, dotDiffHi) - let dotDiffLo : Fin seq → Fin seq → Rat := dotDiffLoHi.1 - let dotDiffHi : Fin seq → Fin seq → Rat := dotDiffLoHi.2 - let scoreGapLoRaw : Fin seq → Fin seq → Rat := fun q k => - if masked q (inputs.prev q) then - scoreLoPrev q - scoreHi q k - else if masked q k then - scoreLoPrev q - inputs.maskValue - else if hscale : 0 ≤ inputs.scale then - inputs.scale * dotDiffLo q k - else - inputs.scale * dotDiffHi q k - let scoreGapLo : Fin seq → Fin seq → Rat := - Bounds.cacheBound2 scoreGapLoRaw - let certFields := buildInductionHeadCertFields inputs otherKeys scoreGapLo - let marginAt : Fin seq 
→ Rat := certFields.marginAt - let weightBoundAtBase : Fin seq → Fin seq → Rat := certFields.weightBoundAtBase - let weightBoundAtBaseCached : Fin seq → Fin seq → Rat := - certFields.weightBoundAtBaseCached - let epsAtBase : Fin seq → Rat := certFields.epsAtBase - let epsAt : Fin seq → Rat := certFields.epsAt - let weightBoundAt : Fin seq → Fin seq → Rat := certFields.weightBoundAt - let margin : Rat := certFields.margin - let eps : Rat := certFields.eps - have hseq : (1 : Nat) ≤ seq := - Nat.succ_le_iff.mpr (Nat.pos_of_ne_zero (NeZero.ne seq)) - let dirHead : Fin dHead → Rat := fun d => (dirHeadVecOfInputs inputs).get d - let wvDirRaw : Fin dModel → Rat := fun j => - Linear.dotFin dHead dirHead (fun d => inputs.wv j d) - let wvDirTask : Fin dModel → Rat := Bounds.cacheBoundTask wvDirRaw - let wvDirArr : Array Rat := Array.ofFn wvDirTask - let wvDir : Fin dModel → Rat := fun j => - wvDirArr[j.1]'(by - have hsize : wvDirArr.size = dModel := by - simp [wvDirArr] - simp [hsize, j.isLt]) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let valsLo : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalLower wvDir (lnLo q) (lnHi q) - let valsHi : Fin seq → Rat := fun q => - bDir + Bounds.dotIntervalUpper wvDir (lnLo q) (lnHi q) - let hvalsLo : ∀ k, - valsLo k = bDir + Bounds.dotIntervalLower wvDir (lnLo k) (lnHi k) := fun _ => rfl - let hvalsHi : ∀ k, - valsHi k = bDir + Bounds.dotIntervalUpper wvDir (lnLo k) (lnHi k) := fun _ => rfl - let univ : Finset (Fin seq) := Finset.univ - have hnonempty : univ.Nonempty := by simp [univ] - let lo := univ.inf' hnonempty valsLo - let hi := univ.sup' hnonempty valsHi - let valCert : ValueInterval seq := buildInductionHeadValCert inputs valsLo valsHi - let cert : InductionHeadCert seq := - { eps := eps - epsAt := epsAt - weightBoundAt := weightBoundAt - margin := margin - active := inputs.active - prev := inputs.prev - values := valCert } - have hcore' : buildInductionCertFromHeadCoreWith? 
cfg inputs = some cert := by - have hcore'' : - buildInductionCertFromHeadCoreWith? cfg inputs = - some (buildInductionHeadCoreCacheWith cfg inputs).cert := - buildInductionCertFromHeadCoreWith?_eq_some - (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive - have hcert_eq : - (buildInductionHeadCoreCacheWith cfg inputs).cert = cert := by - rfl - simpa [hcert_eq] using hcore'' - have hc : c = cert := by - have hcert : cert = c := by - exact Option.some.inj (hcore'.symm.trans hcore) - simpa using hcert.symm - subst hc - have hln_bounds : - ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ - lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by - intro q i - have hln := - Bounds.layerNormBounds_spec (eps := inputs.lnEps) - (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) - (x := inputs.embed q) hmodel hEps hSqrt - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs_def, - Bounds.cacheBoundPair2_apply_left, Bounds.cacheBoundPair2_apply_right] using - hln i - have dotFin_cast {n : Nat} (f g : Fin n → Rat) : - (Linear.dotFin n f g : Real) = - dotProduct (fun j => (f j : Real)) (fun j => (g j : Real)) := by - simp [Linear.dotFin_def, Linear.sumFin_eq_sum_univ, dotProduct] - have proj_bounds - (w : Fin dModel → Fin dHead → Rat) - (b base : Fin dHead → Rat) - (coeff : Fin seq → Fin dHead → Rat) - (hbase : ∀ d, - (base d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (b d : Real)) - (hcoeff : ∀ q d, - (coeff q d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real))) : - ∀ q d, - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ∧ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ≤ - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by - intro q d - have hinv : (invStdLo q : Real) ≤ invStd q ∧ invStd q ≤ (invStdHi q : Real) := - 
hinv_bounds q - have hln_fun_q : - lnRealOfInputs inputs q = - fun j => (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q := by - exact hln_fun q - have hdot_add : - dotProduct (fun j => (w j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real) * invStd q) := by - simpa using - (Nfp.Sound.Linear.dotProduct_add_right - (x := fun j => (w j d : Real)) - (y := fun j => (inputs.ln1Beta j : Real)) - (z := fun j => (lnCoeff q j : Real) * invStd q)) - have hdot_coeff : - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real) * invStd q) = - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q := by - simpa using - (Nfp.Sound.Linear.dotProduct_mul_right - (x := fun j => (w j d : Real)) - (y := fun j => (lnCoeff q j : Real)) - (a := invStd q)) - have hreal : - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) = - (base d : Real) + (coeff q d : Real) * invStd q := by - calc - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) = - dotProduct (fun j => (w j d : Real)) - (fun j => - (inputs.ln1Beta j : Real) + (lnCoeff q j : Real) * invStd q) + - (b d : Real) := by - simp [hln_fun_q] - _ = - dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q + - (b d : Real) := by - simp [hdot_add, hdot_coeff, add_assoc] - _ = - (dotProduct (fun j => (w j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (b d : Real)) + - dotProduct (fun j => (w j d : Real)) - (fun j => (lnCoeff q j : Real)) * - invStd q := by ac_rfl - _ = (base d : Real) + (coeff q d : Real) * invStd q := by - simp [hbase, hcoeff] - have hscale : - let bounds := scaleInterval (coeff q d) (invStdLo q) (invStdHi q) - 
(bounds.1 : Real) ≤ (coeff q d : Real) * invStd q ∧ - (coeff q d : Real) * invStd q ≤ (bounds.2 : Real) := by - exact scaleInterval_bounds_real (x := coeff q d) (lo := invStdLo q) - (hi := invStdHi q) (y := invStd q) hinv.1 hinv.2 - have hlow : - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).1 : Rat) ≤ - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) := by - simpa [hreal] using add_le_add_left hscale.1 (base d : Real) - have hhigh : - dotProduct (fun j => (w j d : Real)) (lnRealOfInputs inputs q) + - (b d : Real) ≤ - (base d + (scaleInterval (coeff q d) (invStdLo q) (invStdHi q)).2 : Rat) := by - simpa [hreal] using add_le_add_left hscale.2 (base d : Real) - exact ⟨hlow, hhigh⟩ - have hq_bounds : - ∀ q d, (qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (qHi q d : Real) := by - intro q d - have hbase : - ∀ d, - (qBase d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bq d : Real) := by - intro d - simp [qBase, qBaseArr, dotFin_cast] - have hcoeff : - ∀ q' d, - (qCoeff q' d : Real) = - dotProduct (fun j => (inputs.wq j d : Real)) - (fun j => (lnCoeff q' j : Real)) := by - intro q' d - simpa [qCoeff, qCoeffArr, qCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using - (dotFin_cast (f := fun j => inputs.wq j d) - (g := fun j => - inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) - have h := proj_bounds (w := inputs.wq) (b := inputs.bq) (base := qBase) - (coeff := qCoeff) hbase hcoeff q d - simpa [qLo, qHi, qRealOfInputs_def] using h - have hk_bounds : - ∀ q d, (kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ (kHi q d : Real) := by - intro q d - have hbase : - ∀ d, - (kBase d : Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (inputs.ln1Beta j : Real)) + - (inputs.bk d : Real) := by - intro d - simp [kBase, kBaseArr, dotFin_cast] - have hcoeff : - ∀ q' d, - (kCoeff q' d : 
Real) = - dotProduct (fun j => (inputs.wk j d : Real)) - (fun j => (lnCoeff q' j : Real)) := by - intro q' d - simpa [kCoeff, kCoeffArr, kCoeffRowTasks, coeffRowTasks, lnCoeff, Task.spawn] using - (dotFin_cast (f := fun j => inputs.wk j d) - (g := fun j => - inputs.ln1Gamma j * (inputs.embed q' j - mean (inputs.embed q')))) - have h := proj_bounds (w := inputs.wk) (b := inputs.bk) (base := kBase) - (coeff := kCoeff) hbase hcoeff q d - simpa [kLo, kHi, kRealOfInputs_def] using h - let scoresReal := scoresRealOfInputs inputs - have scoresReal_eq_base_of_not_masked : - ∀ q k, ¬ masked q k → - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - intro q k hnot - by_cases hcausal : inputs.maskCausal - · have hnot_lt : ¬ q < k := by - intro hlt - exact hnot ⟨hcausal, hlt⟩ - have hle : k ≤ q := le_of_not_gt hnot_lt - simp [scoresReal, scoresRealOfInputs_def, hcausal, hle] - · simp [scoresReal, scoresRealOfInputs_def, hcausal] - have scoresReal_eq_masked : - ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by - intro q k hmask - have hmask' : inputs.maskCausal = true ∧ q < k := by - simpa [masked] using hmask - have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 - simp [scoresReal, scoresRealOfInputs_def, hmask'.1, hle] - have hscore_bounds : - ∀ q k, (scoreLo q k : Real) ≤ scoresReal q k ∧ - scoresReal q k ≤ (scoreHi q k : Real) := by - intro q k - let base := - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) - have hdot_bounds (hnot : ¬ masked q k) : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - have hq := hq_bounds q - have hk := hk_bounds k - have hlo1 : ∀ d, (qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq d).1 - have hhi1 : 
∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq d).2 - have hlo2 : ∀ d, (kLo k d : Real) ≤ kRealOfInputs inputs k d := fun d => - (hk d).1 - have hhi2 : ∀ d, kRealOfInputs inputs k d ≤ (kHi k d : Real) := fun d => - (hk d).2 - have hspec := - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := splitDimsQ q) (dims2 := splitDimsK q k) - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo k d) (hi2 := fun d => kHi k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hlow' : - (dotLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa [dotLo, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] - using hspec.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) ≤ (dotHi q k : Real) := by - simpa [dotHi, dotRowTasks, Task.spawn, Array.getElem_ofFn, hnot] - using hspec.2 - exact ⟨hlow', hhigh'⟩ - have hscore_base_bounds (hnot : ¬ masked q k) : - (scoreLo q k : Real) ≤ base ∧ base ≤ (scoreHi q k : Real) := by - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - by - simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale - have hdot := hdot_bounds hnot - have hlow := mul_le_mul_of_nonneg_left hdot.1 hscale_real - have hhigh := mul_le_mul_of_nonneg_left hdot.2 hscale_real - have hscoreLo : scoreLo q k = inputs.scale * dotLo q k := by - simp [scoreLo, masked, hnot, hscale] - have hscoreHi : scoreHi q k = inputs.scale * dotHi q k := by - simp [scoreHi, masked, hnot, hscale] - constructor - · simpa [hscoreLo, base, Rat.cast_mul] using hlow - · simpa [hscoreHi, base, Rat.cast_mul] using hhigh - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := - by - simpa [ratToReal_def] using - (ratToReal_nonpos_iff (x := 
inputs.scale)).2 hscale_nonpos - have hdot := hdot_bounds hnot - have hlow := mul_le_mul_of_nonpos_left hdot.2 hscale_real - have hhigh := mul_le_mul_of_nonpos_left hdot.1 hscale_real - have hscoreLo : scoreLo q k = inputs.scale * dotHi q k := by - simp [scoreLo, masked, hnot, hscale] - have hscoreHi : scoreHi q k = inputs.scale * dotLo q k := by - simp [scoreHi, masked, hnot, hscale] - constructor - · simpa [hscoreLo, base, Rat.cast_mul] using hlow - · simpa [hscoreHi, base, Rat.cast_mul] using hhigh - by_cases hcausal : inputs.maskCausal - · by_cases hle : k ≤ q - · have hnot : ¬ q < k := not_lt_of_ge hle - have hnot_masked : ¬ masked q k := fun hmk => hnot hmk.2 - have hscore_eq : scoresReal q k = base := - scoresReal_eq_base_of_not_masked q k hnot_masked - have hbase := hscore_base_bounds hnot_masked - constructor - · simpa [hscore_eq] using hbase.1 - · simpa [hscore_eq] using hbase.2 - · have hlt : q < k := lt_of_not_ge hle - have hmask : masked q k := ⟨hcausal, hlt⟩ - have hscore : scoresReal q k = (inputs.maskValue : Real) := - scoresReal_eq_masked q k hmask - constructor - · simp [hscore, scoreLo, hmask] - · simp [hscore, scoreHi, hmask] - · have hnot_masked : ¬ masked q k := by - simp [masked, hcausal] - have hscore_eq : scoresReal q k = base := - scoresReal_eq_base_of_not_masked q k hnot_masked - have hbase := hscore_base_bounds hnot_masked - constructor - · simpa [hscore_eq] using hbase.1 - · simpa [hscore_eq] using hbase.2 - have hdot_diff_bounds : - ∀ q, q ∈ inputs.active → ∀ k, ¬ masked q k → - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - intro q hq k hmask - have hq_bounds' := hq_bounds q - have hkprev := hk_bounds (inputs.prev q) - have hk := hk_bounds k - have hlo1 : ∀ d, (qLo q 
d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq_bounds' d).1 - have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (qHi q d : Real) := fun d => - (hq_bounds' d).2 - have hlo2 : - ∀ d, - (kLo (inputs.prev q) d - kHi k d : Rat) ≤ - (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by - intro d - have hprev_lo := (hkprev d).1 - have hk_hi := (hk d).2 - have h := sub_le_sub hprev_lo hk_hi - simpa [ratToReal_sub] using h - have hhi2 : - ∀ d, - (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ - (kHi (inputs.prev q) d - kLo k d : Rat) := by - intro d - have hprev_hi := (hkprev d).2 - have hk_lo := (hk d).1 - have h := sub_le_sub hprev_hi hk_lo - simpa [ratToReal_sub] using h - have hspec (dimsDiff : List (Fin dHead)) := - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := splitDimsQ q) (dims2 := dimsDiff) - (lo1 := fun d => qLo q d) (hi1 := fun d => qHi q d) - (lo2 := fun d => kLo (inputs.prev q) d - kHi k d) - (hi2 := fun d => kHi (inputs.prev q) d - kLo k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => - kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hspecBase := hspec (splitDimsDiffBase q k) - have hspecRef := hspec (splitDimsDiffRefined q k) - have hspecBase_bounds : - (dotDiffLoBase q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHiBase q k : Real) := by - refine ⟨?_, ?_⟩ - · simpa [dotDiffLoBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, - Array.getElem_ofFn] using hspecBase.1 - · simpa [dotDiffHiBase, dotDiffRowTasksBase, hq, hmask, Task.spawn, - Array.getElem_ofFn] using hspecBase.2 - by_cases hbudget : cfg.splitBudgetDiffRefined = cfg.splitBudgetDiffBase - · simpa [dotDiffLo, dotDiffHi, 
dotDiffLoHi, hbudget] using hspecBase_bounds - · by_cases hmem : k ∈ refineKeys q - · have hlow' : - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, dotDiffLoHi, hbudget, hmem] using hspecRef.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, dotDiffLoHi, hbudget, hmem] using hspecRef.2 - exact ⟨hlow', hhigh'⟩ - · have hlow' : - (dotDiffLo q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLo, dotDiffLoHi, hbudget, hmem] using - hspecBase_bounds.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ (dotDiffHi q k : Real) := by - simpa [dotDiffHi, dotDiffLoHi, hbudget, hmem] using - hspecBase_bounds.2 - exact ⟨hlow', hhigh'⟩ - have hmarginAt_le : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - marginAt q ≤ scoreGapLo q k := by - intro q hq k hk - have hmem : k ∈ otherKeys q := by simp [otherKeys, hk] - have hnonempty : (otherKeys q).Nonempty := ⟨k, hmem⟩ - have hmarginAt_eq : - marginAt q = - (otherKeys q).inf' hnonempty (fun k => scoreGapLo q k) := by - simp [marginAt, certFields, buildInductionHeadCertFields_def, hq, hnonempty] - have hle : - (otherKeys q).inf' hnonempty (fun k => scoreGapLo q k) ≤ scoreGapLo q k := by - exact (Finset.inf'_le_iff (s := otherKeys q) (H := hnonempty) - (f := fun k => scoreGapLo q k) (a := scoreGapLo q k)).2 - ⟨k, hmem, le_rfl⟩ - simpa [hmarginAt_eq] using hle - have hscore_gap_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - by_cases 
hprevmask : masked q (inputs.prev q) - · have hscore_hi : scoresReal q k ≤ (scoreHi q k : Real) := - (hscore_bounds q k).2 - have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by - have hprev_bounds := hscore_bounds q (inputs.prev q) - simpa [scoreLoPrev] using hprev_bounds.1 - have hsum_le' : - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k ≤ - (scoreLoPrev q : Real) := by - have hsub : - (scoreLoPrev q : Real) - (scoreHi q k : Real) ≤ - (scoreLoPrev q : Real) - scoresReal q k := - sub_le_sub_left hscore_hi (scoreLoPrev q : Real) - calc - (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k - ≤ (scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by - simpa [add_comm, add_left_comm, add_assoc] using - (add_le_add_left hsub (scoresReal q k)) - _ = (scoreLoPrev q : Real) := by - simp [sub_add_cancel] - calc - scoresReal q k + (scoreGapLo q k : Real) - = (scoreLoPrev q : Real) - (scoreHi q k : Real) + scoresReal q k := by - simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, add_comm] - _ ≤ (scoreLoPrev q : Real) := hsum_le' - _ ≤ scoresReal q (inputs.prev q) := hscore_prev - · by_cases hmask : masked q k - · have hscore_prev : (scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := by - have hprev_bounds := hscore_bounds q (inputs.prev q) - simpa [scoreLoPrev] using hprev_bounds.1 - have hscore_k : scoresReal q k = (inputs.maskValue : Real) := - scoresReal_eq_masked q k hmask - calc - scoresReal q k + (scoreGapLo q k : Real) - = (inputs.maskValue : Real) + (scoreLoPrev q : Real) - - (inputs.maskValue : Real) := by - simp [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscore_k] - _ = (scoreLoPrev q : Real) := by - simp [add_sub_cancel_left] - _ ≤ scoresReal q (inputs.prev q) := hscore_prev - · have hdiff := hdot_diff_bounds q hq k hmask - have hgap_le : - (scoreGapLo q k : Real) ≤ - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d 
=> kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := - by - simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale - have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real - simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscale] using hle - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := - by - simpa [ratToReal_def] using - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real - simpa [scoreGapLo, scoreGapLoRaw, Bounds.cacheBound2_apply, - hprevmask, hmask, hscale] using hle - have hscore_prev : - scoresReal q (inputs.prev q) = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) := by - simpa using - (scoresReal_eq_base_of_not_masked q (inputs.prev q) hprevmask) - have hscore_k : - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa using (scoresReal_eq_base_of_not_masked q k hmask) - have hdot_sub : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) = - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - classical - simpa using - (Nfp.Sound.Linear.dotProduct_sub_right - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => kRealOfInputs inputs (inputs.prev q) d) - (z := fun d => kRealOfInputs inputs k d)) - have hscore_diff : - scoresReal q (inputs.prev q) - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs 
(inputs.prev q) d - - kRealOfInputs inputs k d) := by - calc - scoresReal q (inputs.prev q) - scoresReal q k - = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simp [hscore_prev, hscore_k] - _ = - (inputs.scale : Real) * - (dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)) := by - simp [mul_sub] - _ = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simp [hdot_sub] - have hgap_le' : - (scoreGapLo q k : Real) ≤ - scoresReal q (inputs.prev q) - scoresReal q k := by - simpa [hscore_diff] using hgap_le - have hgap_add := - add_le_add_right hgap_le' (scoresReal q k) - have hgap_add' : - scoresReal q k + (scoreGapLo q k : Real) ≤ - scoresReal q (inputs.prev q) := by - have hcancel : - scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) = - scoresReal q (inputs.prev q) := by - calc - scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) - = - scoresReal q k + scoresReal q (inputs.prev q) - - scoresReal q k := by - symm - exact add_sub_assoc (scoresReal q k) - (scoresReal q (inputs.prev q)) (scoresReal q k) - _ = scoresReal q (inputs.prev q) := by - simp [add_sub_cancel_left] - calc - scoresReal q k + (scoreGapLo q k : Real) - ≤ scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) := - hgap_add - _ = scoresReal q (inputs.prev q) := hcancel - exact hgap_add' - let softmaxWeights := Circuit.softmaxWeights scoresReal - let weights : Fin seq → Fin seq → Real := fun q k => - Circuit.softmax (scoresReal q) k - let others : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin 
seq)).erase (inputs.prev q) - have hscore_margin_real_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (marginAt q : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hmargin_le : marginAt q ≤ scoreGapLo q k := - hmarginAt_le q hq k hk - have hmargin_le_real : (marginAt q : Real) ≤ (scoreGapLo q k : Real) := - by - simpa [ratToReal_def] using ratToReal_le_of_le hmargin_le - have hscore_gap := hscore_gap_real_at q hq k hk - have hstep := add_le_add_right hmargin_le_real (scoresReal q k) - have hstep' : - scoresReal q k + (marginAt q : Real) ≤ - scoresReal q k + (scoreGapLo q k : Real) := by - exact hstep - exact hstep'.trans hscore_gap - have hscore_margin_real : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - scoresReal q k + (margin : Real) ≤ scoresReal q (inputs.prev q) := by - intro q hq k hk - have hmargin_le : margin ≤ marginAt q := by - have hmem : q ∈ inputs.active := hq - have hnonempty : inputs.active.Nonempty := hactive - have hle := - (Finset.inf'_le_iff (s := inputs.active) (H := hnonempty) - (f := marginAt) (a := marginAt q)).2 ⟨q, hmem, le_rfl⟩ - simpa [margin, certFields, buildInductionHeadCertFields_def, hnonempty] using hle - have hmargin_le_real : (margin : Real) ≤ (marginAt q : Real) := - by - simpa [ratToReal_def] using ratToReal_le_of_le hmargin_le - have hscore := hscore_margin_real_at q hq k hk - have hscore' : - (marginAt q : Real) + scoresReal q k ≤ scoresReal q (inputs.prev q) := by - simpa [add_comm] using hscore - have hstep := add_le_add_right hmargin_le_real (scoresReal q k) - have hstep' : - scoresReal q k + (margin : Real) ≤ (marginAt q : Real) + scoresReal q k := by - calc - scoresReal q k + (margin : Real) ≤ scoresReal q k + (marginAt q : Real) := hstep - _ = (marginAt q : Real) + scoresReal q k := by - simp [add_comm] - exact hstep'.trans hscore' - have hweightBoundAtBase : - ∀ q k, k ≠ inputs.prev q → - weightBoundAtBase q k = - if scoreGapLo q k < 0 then - (1 : Rat) - else - 
ratDivUp 1 (1 + scoreGapLo q k) := by - intro q k hk - simp [weightBoundAtBase, certFields, buildInductionHeadCertFields_def, hk] - have hweightBoundAt : - ∀ q k, - weightBoundAt q k = weightBoundAtBase q k := by - intro q k - by_cases hk : k = inputs.prev q - · simp [weightBoundAt, weightBoundAtBase, certFields, - buildInductionHeadCertFields_def, Bounds.cacheBound2Task_apply, - hk] - · simp [weightBoundAt, weightBoundAtBase, certFields, - buildInductionHeadCertFields_def, Bounds.cacheBound2Task_apply, - hk] - have hepsAt : - ∀ q, epsAt q = - min (1 : Rat) - ((otherKeys q).sum (fun k => - if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k))) := by - intro q - have hsum : - (otherKeys q).sum (fun k => - Bounds.cacheBound2Task - (fun q k => - if k = inputs.prev q then - (0 : Rat) - else if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k)) q k) = - (otherKeys q).sum (fun k => - if scoreGapLo q k < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + scoreGapLo q k)) := by - refine Finset.sum_congr rfl ?_ - intro k hk - have hk' : k ≠ inputs.prev q := by - have hk'' : k ∈ Finset.univ.erase (inputs.prev q) := by - simpa [otherKeys] using hk - exact (Finset.mem_erase.mp hk'').1 - simp [Bounds.cacheBound2Task_apply, hk'] - simpa [epsAt, epsAtBase, certFields, buildInductionHeadCertFields_def, hsum] using - (Bounds.cacheBoundThunk_apply (f := epsAtBase) q) - have oneHot_bounds_at : - ∀ q, q ∈ inputs.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) - (fun q' => q' = q) inputs.prev weights := by - intro q hq - exact - Sound.oneHot_bounds_at_of_scoreGapLo - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (scoreGapLo := scoreGapLo) - (epsAt := epsAt) - (hepsAt := hepsAt) - (hscore_gap_real_at := hscore_gap_real_at) - q hq - have weight_bounds_at : - ∀ q, q ∈ inputs.active → ∀ k, k ≠ inputs.prev q → - weights q k ≤ (weightBoundAt q k : Real) := by - intro q hq k hk - have 
hbound_base : - weights q k ≤ (weightBoundAtBase q k : Real) := by - exact - Sound.weight_bound_at_of_scoreGapLo - (active := inputs.active) - (prev := inputs.prev) - (scoresReal := scoresReal) - (scoreGapLo := scoreGapLo) - (weightBoundAt := weightBoundAtBase) - (hweightBoundAt := hweightBoundAtBase) - (hscore_gap_real_at := hscore_gap_real_at) - q hq k hk - have hweightBoundAt_real : - (weightBoundAt q k : Real) = - (weightBoundAtBase q k : Real) := by - have hbase : weightBoundAt q k = weightBoundAtBase q k := - hweightBoundAt q k - have hbase' : - ratToReal (weightBoundAt q k) = ratToReal (weightBoundAtBase q k) := - congrArg ratToReal hbase - simpa [ratToReal_def] using hbase' - simpa [hweightBoundAt_real] using hbound_base - have hepsAt_le_eps : - ∀ q, q ∈ inputs.active → epsAt q ≤ eps := by - intro q hq - have hle : - epsAt q ≤ inputs.active.sup' hactive epsAt := by - exact - (Finset.le_sup'_iff (s := inputs.active) (H := hactive) - (f := epsAt) (a := epsAt q)).2 ⟨q, hq, le_rfl⟩ - have heps_def : - eps = inputs.active.sup' hactive epsAt := by - have heps_def' : - certFields.eps = - if h : inputs.active.Nonempty then - inputs.active.sup' h certFields.epsAt - else - (0 : Rat) := by - simpa [certFields] using - (buildInductionHeadCertFields_eps_eq - (inputs := inputs) (otherKeys := otherKeys) (scoreGapLo := scoreGapLo)) - have heps_def'' : - certFields.eps = inputs.active.sup' hactive certFields.epsAt := by - simpa [hactive] using heps_def' - simpa [eps, epsAt] using heps_def'' - simpa [heps_def] using hle - have hepsAt_le_eps_real : - ∀ q, q ∈ inputs.active → (epsAt q : Real) ≤ (eps : Real) := by - intro q hq - simpa [ratToReal_def] using ratToReal_le_of_le (hepsAt_le_eps q hq) - have hsoftmax_bounds : - Layers.SoftmaxMarginBoundsOn (Val := Real) (eps : Real) (margin : Real) - (fun q => q ∈ inputs.active) inputs.prev scoresReal weights := by - classical - refine - { score_margin := ?_ - nonneg := ?_ - sum_one := ?_ - prev_large := ?_ - other_le := ?_ } - · 
intro q hq k hk - exact hscore_margin_real q hq k hk - · intro q _ k - simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using - softmaxWeights.nonneg q k - · intro q _ - simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using - softmaxWeights.sum_one q - · intro q hq - have honehot := oneHot_bounds_at q hq - have hprev := honehot.prev_large q rfl - have hle : - weights q (inputs.prev q) + (epsAt q : Real) ≤ - weights q (inputs.prev q) + (eps : Real) := by - have hle' := - add_le_add_left (hepsAt_le_eps_real q hq) (weights q (inputs.prev q)) - calc - weights q (inputs.prev q) + (epsAt q : Real) - = (epsAt q : Real) + weights q (inputs.prev q) := by - exact add_comm _ _ - _ ≤ (eps : Real) + weights q (inputs.prev q) := hle' - _ = weights q (inputs.prev q) + (eps : Real) := by - exact add_comm _ _ - exact hprev.trans hle - · intro q hq k hk - have honehot := oneHot_bounds_at q hq - have hother := honehot.other_le q rfl k hk - exact hother.trans (hepsAt_le_eps_real q hq) - have hwvDirRaw : - ∀ j, wvDirRaw j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := by - intro j - rfl - have hwvDirTask : ∀ j, wvDirTask j = wvDirRaw j := by - intro j - simpa [wvDirTask] using (Bounds.cacheBoundTask_apply wvDirRaw j) - have hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := by - intro j - have hwv : wvDir j = wvDirTask j := by - simp only [wvDir, wvDirArr, Array.getElem_ofFn] - have hwv' : wvDir j = wvDirRaw j := by - exact hwv.trans (hwvDirTask j) - calc - wvDir j = wvDirRaw j := hwv' - _ = Linear.dotFin dHead dirHead (fun d => inputs.wv j d) := hwvDirRaw j - have hbDir : - bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d) := by - rfl - have hvals_bounds : - ValueIntervalBounds (vals := valsRealOfInputs inputs) valCert := - valCert_bounds_of_ln_bounds (inputs := inputs) (dirHead := dirHead) (hdirHead := rfl) - (wvDir := wvDir) (bDir := bDir) (hwvDir := hwvDir) (hbDir := hbDir) - (lnLo := lnLo) (lnHi := 
lnHi) (valsLo := valsLo) (valsHi := valsHi) - (hvalsLo := hvalsLo) (hvalsHi := hvalsHi) (hln := hln_bounds) - have hcert_eps : cert.eps = eps := by rfl - have hcert_margin : cert.margin = margin := by rfl - have hcert_active : cert.active = inputs.active := by rfl - have hcert_prev : cert.prev = inputs.prev := by rfl - have hcert_epsAt : cert.epsAt = epsAt := by rfl - have hcert_weight : cert.weightBoundAt = weightBoundAt := by rfl - have hcert_values : cert.values = valCert := by rfl - refine - { softmax_bounds := ?_ - oneHot_bounds_at := ?_ - weight_bounds_at := ?_ - value_bounds := ?_ } - · simpa [hcert_eps, hcert_margin, hcert_active, hcert_prev] using hsoftmax_bounds - · intro q hq - simpa [hcert_epsAt, hcert_active, hcert_prev] using oneHot_bounds_at q hq - · intro q hq k hk - simpa [hcert_weight, hcert_active, hcert_prev] using weight_bounds_at q hq k hk - · simpa [hcert_values] using hvals_bounds - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_active - (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel hactive - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_sqrt - (cfg := cfg) (inputs := inputs) hEps hSqrt - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - · have : False := by - have hnone := - buildInductionCertFromHeadCoreWith?_eq_none_of_not_eps - (cfg := cfg) (inputs := inputs) hEps - have hcore' : - (none : Option (InductionHeadCert seq)) = some c := by - exact hnone.symm.trans hcore - cases hcore' - exact this.elim - - -end Sound -end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Basic/DefaultSound.lean b/Nfp/Sound/Induction/CoreSound/Basic/DefaultSound.lean deleted file mode 100644 index 5171a42..0000000 --- 
a/Nfp/Sound/Induction/CoreSound/Basic/DefaultSound.lean +++ /dev/null @@ -1,29 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later -module - -import all Nfp.Sound.Induction.Core.Basic -public import Nfp.Sound.Induction.Core -public import Nfp.Sound.Induction.CoreSound.Basic.CertSound - -public section - -namespace Nfp -namespace Sound -open scoped BigOperators -open Nfp.Circuit -open Nfp.Sound.Bounds -variable {seq : Nat} -/-- Soundness for `buildInductionCertFromHeadCore?`. -/ -theorem buildInductionCertFromHeadCore?_sound - [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) - (hcore : buildInductionCertFromHeadCore? inputs = some c) : - InductionHeadCertSound inputs c := by - have hcore' : - buildInductionCertFromHeadCoreWith? defaultInductionHeadSplitConfig inputs = some c := by - simpa [buildInductionCertFromHeadCore?_def] using hcore - exact - buildInductionCertFromHeadCoreWith?_sound - (cfg := defaultInductionHeadSplitConfig) inputs c hcore' -end Sound -end Nfp diff --git a/Nfp/Sound/Induction/CoreSound/Values.lean b/Nfp/Sound/Induction/CoreSound/Values.lean deleted file mode 100644 index 68e0962..0000000 --- a/Nfp/Sound/Induction/CoreSound/Values.lean +++ /dev/null @@ -1,229 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later -module - -public import Mathlib.Algebra.BigOperators.Group.Finset.Basic -public import Nfp.Sound.Induction.CoreDefs -public import Nfp.Sound.Induction.Core.Basic -public import Nfp.Sound.Linear.FinFold - -/-! -Helper lemmas for value-direction bounds in induction-head soundness. - -These isolate the algebra needed to rewrite direction-value projections into -dot products over cached `wvDir`/`bDir` terms. --/ - -public section - -namespace Nfp - -namespace Sound - -open scoped BigOperators - -open Nfp.Sound.Linear - -variable {seq dModel dHead : Nat} - -/-- Cast a cached `wvDir` dot to a Real-valued sum over head weights. 
-/ -theorem wvDir_real_eq_sum (inputs : Model.InductionHeadInputs seq dModel dHead) - (dirHead : Fin dHead → Rat) (wvDir : Fin dModel → Rat) - (hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) : - ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by - intro j - calc - (wvDir j : Real) - = ((∑ d, dirHead d * inputs.wv j d : Rat) : Real) := by - simp [hwvDir j, Linear.dotFin_eq_dotProduct, dotProduct] - _ = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := by - simp - -/-- Cast a cached `bDir` dot to a Real-valued sum over head biases. -/ -theorem bDir_real_eq_sum (inputs : Model.InductionHeadInputs seq dModel dHead) - (dirHead : Fin dHead → Rat) (bDir : Rat) - (hbDir : bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d)) : - (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by - calc - (bDir : Real) - = ((∑ d, dirHead d * inputs.bv d : Rat) : Real) := by - simp [hbDir, Linear.dotFin_eq_dotProduct, dotProduct] - _ = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := by - simp - -/-- Rewrite direction values using cached `wvDir` and `bDir` sums. 
-/ -theorem valsReal_eq_of_dir (inputs : Model.InductionHeadInputs seq dModel dHead) - (dirHead : Fin dHead → Rat) - (hdirHead : dirHead = fun d => (dirHeadVecOfInputs inputs).get d) - (wvDir : Fin dModel → Rat) (bDir : Rat) - (hdir_wv : ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) - (hdir_bv : (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real)) : - ∀ k, - valsRealOfInputs inputs k = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + (bDir : Real) := by - intro k - classical - have hdirHead_real : - (fun d => (dirHeadVecOfInputs inputs).get d : Fin dHead → Real) = - fun d => (dirHead d : Real) := by - funext d - simp [hdirHead] - have hdot_add : - dotProduct (fun d => (dirHead d : Real)) - (fun d => - dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k) + - (inputs.bv d : Real)) = - dotProduct (fun d => (dirHead d : Real)) - (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) + - dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) := by - simpa using - (Nfp.Sound.Linear.dotProduct_add_right - (x := fun d => (dirHead d : Real)) - (y := fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) - (z := fun d => (inputs.bv d : Real))) - have hdot_wv : - dotProduct (fun d => (dirHead d : Real)) - (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) := by - calc - dotProduct (fun d => (dirHead d : Real)) - (fun d => dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k)) = - ∑ d, (dirHead d : Real) * ∑ j, - (inputs.wv j d : Real) * lnRealOfInputs inputs k j := by - simp [dotProduct] - _ = ∑ d, ∑ j, - (dirHead d : Real) * - ((inputs.wv j d : Real) * lnRealOfInputs inputs k j) := by - simp [Finset.mul_sum] - _ = ∑ j, ∑ d, - (dirHead d : Real) * - ((inputs.wv j d : Real) * lnRealOfInputs inputs k j) := by - simpa 
using - (Finset.sum_comm (s := (Finset.univ : Finset (Fin dHead))) - (t := (Finset.univ : Finset (Fin dModel))) - (f := fun d j => - (dirHead d : Real) * ((inputs.wv j d : Real) * lnRealOfInputs inputs k j))) - _ = ∑ j, (∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) * - lnRealOfInputs inputs k j := by - refine Finset.sum_congr rfl ?_ - intro j _ - have hsum : - (∑ d, (dirHead d : Real) * (inputs.wv j d : Real)) * - lnRealOfInputs inputs k j = - ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) * - lnRealOfInputs inputs k j := by - simp [Finset.sum_mul, mul_assoc] - simpa [mul_assoc] using hsum.symm - _ = ∑ j, (wvDir j : Real) * lnRealOfInputs inputs k j := by - refine Finset.sum_congr rfl ?_ - intro j _ - simp [hdir_wv j] - _ = dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) := by - simp [dotProduct] - calc - valsRealOfInputs inputs k = - dotProduct (fun d => (dirHead d : Real)) - (fun d => - dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs k) + - (inputs.bv d : Real)) := by - simp [valsRealOfInputs_def, vRealOfInputs_def, hdirHead_real] - _ = - dotProduct (fun d => (dirHead d : Real)) - (fun d => dotProduct (fun j => (inputs.wv j d : Real)) - (lnRealOfInputs inputs k)) + - dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) := hdot_add - _ = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + - dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) := by - simp [hdot_wv] - _ = - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + (bDir : Real) := by - have hb : - dotProduct (fun d => (dirHead d : Real)) (fun d => (inputs.bv d : Real)) = - (bDir : Real) := by - simpa [dotProduct] using hdir_bv.symm - simp [hb] - -/-- Bound `valsRealOfInputs` using cached `wvDir`/`bDir` and logit interval bounds. 
-/ -theorem valsReal_bounds_at_of_ln_bounds (inputs : Model.InductionHeadInputs seq dModel dHead) - (dirHead : Fin dHead → Rat) - (hdirHead : dirHead = fun d => (dirHeadVecOfInputs inputs).get d) - (wvDir : Fin dModel → Rat) (bDir : Rat) - (hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - (hbDir : bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d)) - (lnLo lnHi : Fin seq → Fin dModel → Rat) - (valsLo valsHi : Fin seq → Rat) - (hvalsLo : - ∀ k, valsLo k = bDir + Bounds.dotIntervalLower wvDir (lnLo k) (lnHi k)) - (hvalsHi : - ∀ k, valsHi k = bDir + Bounds.dotIntervalUpper wvDir (lnLo k) (lnHi k)) - (hln : - ∀ k j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j ∧ - lnRealOfInputs inputs k j ≤ (lnHi k j : Real)) : - ∀ k, - (valsLo k : Rat) ≤ valsRealOfInputs inputs k ∧ - valsRealOfInputs inputs k ≤ (valsHi k : Rat) := by - intro k - have hdir_wv : - ∀ j, (wvDir j : Real) = ∑ d, (dirHead d : Real) * (inputs.wv j d : Real) := - wvDir_real_eq_sum inputs dirHead wvDir hwvDir - have hdir_bv : - (bDir : Real) = ∑ d, (dirHead d : Real) * (inputs.bv d : Real) := - bDir_real_eq_sum inputs dirHead bDir hbDir - have hvals_eq := - valsReal_eq_of_dir inputs dirHead hdirHead wvDir bDir hdir_wv hdir_bv k - have hlo : ∀ j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j := - fun j => (hln k j).1 - have hhi : ∀ j, lnRealOfInputs inputs k j ≤ (lnHi k j : Real) := - fun j => (hln k j).2 - have hlow' : - (bDir + Bounds.dotIntervalLower wvDir (lnLo k) (lnHi k) : Rat) ≤ - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + (bDir : Real) := by - simpa [Rat.cast_add, add_comm] using - (Bounds.dotIntervalLower_le_dotProduct_real_add - (n := dModel) (v := wvDir) - (lo := lnLo k) (hi := lnHi k) - (x := lnRealOfInputs inputs k) (b := (bDir : Real)) - hlo hhi) - have hhigh' : - dotProduct (fun j => (wvDir j : Real)) (lnRealOfInputs inputs k) + (bDir : Real) ≤ - (bDir + Bounds.dotIntervalUpper wvDir (lnLo k) (lnHi k) : Rat) := by - simpa 
[Rat.cast_add, add_comm] using - (Bounds.dotProduct_le_dotIntervalUpper_real_add - (n := dModel) (v := wvDir) - (lo := lnLo k) (hi := lnHi k) - (x := lnRealOfInputs inputs k) (b := (bDir : Real)) - hlo hhi) - constructor - · rw [hvalsLo k, hvals_eq] - exact hlow' - · rw [hvalsHi k, hvals_eq] - exact hhigh' - -/-- Build `ValueIntervalBounds` from logit interval bounds for `buildInductionHeadValCert`. -/ -theorem valCert_bounds_of_ln_bounds [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (dirHead : Fin dHead → Rat) - (hdirHead : dirHead = fun d => (dirHeadVecOfInputs inputs).get d) - (wvDir : Fin dModel → Rat) (bDir : Rat) - (hwvDir : ∀ j, wvDir j = Linear.dotFin dHead dirHead (fun d => inputs.wv j d)) - (hbDir : bDir = Linear.dotFin dHead dirHead (fun d => inputs.bv d)) - (lnLo lnHi : Fin seq → Fin dModel → Rat) - (valsLo valsHi : Fin seq → Rat) - (hvalsLo : - ∀ k, valsLo k = bDir + Bounds.dotIntervalLower wvDir (lnLo k) (lnHi k)) - (hvalsHi : - ∀ k, valsHi k = bDir + Bounds.dotIntervalUpper wvDir (lnLo k) (lnHi k)) - (hln : - ∀ k j, (lnLo k j : Real) ≤ lnRealOfInputs inputs k j ∧ - lnRealOfInputs inputs k j ≤ (lnHi k j : Real)) : - ValueIntervalBounds (vals := valsRealOfInputs inputs) - (buildInductionHeadValCert inputs valsLo valsHi) := by - have hvals_bounds_at := - valsReal_bounds_at_of_ln_bounds inputs dirHead hdirHead wvDir bDir hwvDir hbDir - lnLo lnHi valsLo valsHi hvalsLo hvalsHi hln - exact buildInductionHeadValCert_bounds (inputs := inputs) - (valsReal := valsRealOfInputs inputs) (valsLo := valsLo) (valsHi := valsHi) - hvals_bounds_at - -end Sound -end Nfp diff --git a/Nfp/Sound/Induction/HeadBounds.lean b/Nfp/Sound/Induction/HeadBounds.lean deleted file mode 100644 index d582fad..0000000 --- a/Nfp/Sound/Induction/HeadBounds.lean +++ /dev/null @@ -1,9 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Sound.Induction.HeadBounds.Basic - -/-! 
-Helper bounds for head-induction certificate construction. --/ diff --git a/Nfp/Sound/Induction/HeadBounds/Basic.lean b/Nfp/Sound/Induction/HeadBounds/Basic.lean deleted file mode 100644 index 4d7b4d2..0000000 --- a/Nfp/Sound/Induction/HeadBounds/Basic.lean +++ /dev/null @@ -1,1242 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Core.Basic -public import Mathlib.Data.Finset.Basic -public import Mathlib.Data.List.Range -public import Mathlib.Data.Vector.Defs -public import Nfp.Model.InductionHead -public import Nfp.Sound.Bounds.Attention -public import Nfp.Sound.Bounds.LayerNorm -public import Nfp.Sound.Bounds.MatrixNorm -public import Nfp.Sound.Linear.FinFold - -/-! -Helper bounds for head-induction certificate construction. - -These are pure precomputations that are useful for profiling and staging. --/ - -public section - -namespace Nfp - -namespace Sound - -open Nfp.Sound.Bounds - -variable {seq : Nat} - -private def taskMin (t1 t2 : Task Rat) : Task Rat := - Task.bind t1 (fun v1 => Task.map (fun v2 => min v1 v2) t2) - -private def taskMax (t1 t2 : Task Rat) : Task Rat := - Task.bind t1 (fun v1 => Task.map (fun v2 => max v1 v2) t2) - -/-! Small lemmas for extracting `get` from task folds. -/ - -/-- `taskMin` exposes its `get` as a plain `min` on task results. -/ -private theorem taskMin_get (t1 t2 : Task Rat) : - (taskMin t1 t2).get = min t1.get t2.get := by - rfl - -/-- `taskMax` exposes its `get` as a plain `max` on task results. -/ -private theorem taskMax_get (t1 t2 : Task Rat) : - (taskMax t1 t2).get = max t1.get t2.get := by - rfl - -/-- Pull `get` through a `List.foldl` when the step is `get`-compatible. 
-/ -private theorem foldl_task_get_eq {α β : Type} (step : Task β → α → Task β) (step' : β → α → β) - (hstep : ∀ acc a, (step acc a).get = step' acc.get a) : - ∀ (xs : List α) (acc : Task β), - (List.foldl step acc xs).get = List.foldl step' acc.get xs - | [], acc => rfl - | x :: xs, acc => by - simpa [List.foldl, hstep] using foldl_task_get_eq step step' hstep xs (step acc x) - -/-- `List.foldl` over `taskMin` exposes a fold over `min` on task results. -/ -private theorem foldl_taskMin_get_eq {α : Type} (f : α → Task Rat) (xs : List α) - (init : Task Rat) : - (List.foldl (fun acc a => taskMin acc (f a)) init xs).get = - List.foldl (fun acc a => min acc (f a).get) init.get xs := by - refine - foldl_task_get_eq - (step := fun acc a => taskMin acc (f a)) - (step' := fun acc a => min acc (f a).get) - (hstep := ?_) - xs init - intro acc a - simp [taskMin_get] - -/-- `List.foldl` over `taskMax` exposes a fold over `max` on task results. -/ -private theorem foldl_taskMax_get_eq {α : Type} (f : α → Task Rat) (xs : List α) - (init : Task Rat) : - (List.foldl (fun acc a => taskMax acc (f a)) init xs).get = - List.foldl (fun acc a => max acc (f a).get) init.get xs := by - refine - foldl_task_get_eq - (step := fun acc a => taskMax acc (f a)) - (step' := fun acc a => max acc (f a).get) - (hstep := ?_) - xs init - intro acc a - simp [taskMax_get] - -/-- `Array.get?` + `Option.getD` followed by `Task.get` agrees with `getD` on values. -/ -private theorem task_getD_ofFn {n : Nat} (f : Fin n → Rat) (i : Nat) : - ((Array.ofFn fun c => ({ get := f c } : Task Rat))[i]?.getD { get := (0 : Rat) }).get = - (Array.ofFn f)[i]?.getD (0 : Rat) := by - by_cases h : i < n - · simp [h, Array.size_ofFn] - · simp [h, Array.size_ofFn] - -/-! Helpers for reducing cached arrays without extra allocation. -/ - -/-- Reduce an array of rational bounds to its minimum (defaulting to `0` on empty arrays). 
-/ -private def reduceMinArray (arr : Array Rat) : Rat := - let init := arr.getD 0 (0 : Rat) - arr.foldl (fun acc x => min acc x) init - -/-- Reduce an array of rational bounds to its maximum (defaulting to `0` on empty arrays). -/ -private def reduceMaxArray (arr : Array Rat) : Rat := - let init := arr.getD 0 (0 : Rat) - arr.foldl (fun acc x => max acc x) init - -/-- Reduce a `Fin seq`-indexed function using the chunked sequential algorithm. -/ -private def reduceFnChunked [NeZero seq] (vals : Fin seq → Rat) - (combine : Rat → Rat → Rat) : Rat := - let n := seq - if n = 0 then - (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let chunkVals : Array Rat := - Array.ofFn (fun c : Fin chunks => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init) - let init := chunkVals.getD 0 (0 : Rat) - let rest := (List.range (chunkVals.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init - -/-- Unfold `reduceFnChunked` to its chunked sequential definition. 
-/ -private theorem reduceFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) - (combine : Rat → Rat → Rat) : - reduceFnChunked (seq := seq) vals combine = - let n := seq - if n = 0 then - (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let chunkVals : Array Rat := - Array.ofFn (fun c : Fin chunks => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init) - let init := chunkVals.getD 0 (0 : Rat) - let rest := (List.range (chunkVals.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => combine acc (chunkVals.getD i 0)) init := rfl - -/-- Reduce a `Fin seq`-indexed function in parallel using chunked tasks. 
-/ -private def reduceFnTask [NeZero seq] (vals : Fin seq → Rat) - (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : Task Rat := - let n := seq - if n = 0 then - Task.pure (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - Task.spawn (fun _ => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init)) - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init - -/-- Unfold `reduceFnTask` to its chunked-task definition. 
-/ -private theorem reduceFnTask_spec [NeZero seq] (vals : Fin seq → Rat) - (combine : Rat → Rat → Rat) (combineTask : Task Rat → Task Rat → Task Rat) : - reduceFnTask (seq := seq) vals combine combineTask = - let n := seq - if n = 0 then - Task.pure (0 : Rat) - else - let chunkSize : Nat := 256 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let defaultTask : Task Rat := Task.pure (0 : Rat) - let chunkTasks : Array (Task Rat) := - Array.ofFn (fun c : Fin chunks => - Task.spawn (fun _ => - let start := c.val * chunkSize - let stop := Nat.min n (start + chunkSize) - let init := vals (idxs.getD start defaultIdx) - if stop ≤ start + 1 then - init - else - let rest := (List.range (stop - start - 1)).map (fun i => start + i + 1) - rest.foldl (fun acc i => combine acc (vals (idxs.getD i defaultIdx))) init)) - let init := chunkTasks.getD 0 defaultTask - let rest := (List.range (chunkTasks.size - 1)).map (fun i => i + 1) - rest.foldl (fun acc i => combineTask acc (chunkTasks.getD i defaultTask)) init := rfl - -private def reduceMinFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := - reduceFnTask vals min taskMin - -private def reduceMaxFnTask [NeZero seq] (vals : Fin seq → Rat) : Task Rat := - reduceFnTask vals max taskMax - -/-- Chunked sequential minimum over a `Fin seq`-indexed function. -/ -private def reduceMinFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := - reduceFnChunked vals min - -/-- Unfold `reduceMinFnChunked` to `reduceFnChunked` with `min`. -/ -private theorem reduceMinFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : - reduceMinFnChunked vals = reduceFnChunked vals min := rfl - -/-- Chunked sequential maximum over a `Fin seq`-indexed function. 
-/ -private def reduceMaxFnChunked [NeZero seq] (vals : Fin seq → Rat) : Rat := - reduceFnChunked vals max - -/-- Unfold `reduceMaxFnChunked` to `reduceFnChunked` with `max`. -/ -private theorem reduceMaxFnChunked_spec [NeZero seq] (vals : Fin seq → Rat) : - reduceMaxFnChunked vals = reduceFnChunked vals max := rfl - -/-- The chunked parallel min-reduction task returns the sequential chunked result. -/ -private theorem reduceMinFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : - (reduceMinFnTask vals).get = reduceMinFnChunked vals := by - classical - have hseq : seq ≠ 0 := NeZero.ne (n := seq) - simp [reduceMinFnTask, reduceMinFnChunked, reduceFnTask, reduceFnChunked, hseq, - Task.spawn, foldl_taskMin_get_eq, task_getD_ofFn] - -/-- The chunked parallel max-reduction task returns the sequential chunked result. -/ -private theorem reduceMaxFnTask_get_eq [NeZero seq] (vals : Fin seq → Rat) : - (reduceMaxFnTask vals).get = reduceMaxFnChunked vals := by - classical - have hseq : seq ≠ 0 := NeZero.ne (n := seq) - simp [reduceMaxFnTask, reduceMaxFnChunked, reduceFnTask, reduceFnChunked, hseq, - Task.spawn, foldl_taskMax_get_eq, task_getD_ofFn] - -/-- Cached direction head for head inputs. -/ -private def dirHeadVecOfInputs {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := - Vector.ofFn (fun d : Fin dHead => - Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) - -/-- LayerNorm bounds used by the induction-head builder. 
-/ -def headLnBounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := - Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) - -private theorem headLnBounds_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - headLnBounds inputs = - Bounds.cacheBoundPair2 (fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q)) := rfl - -/-- Q/K/V bounds used by the induction-head builder. -/ -structure HeadQKVBounds (seq dModel dHead : Nat) where - /-- Q lower bounds. -/ - qLo : Fin seq → Fin dHead → Rat - /-- Q upper bounds. -/ - qHi : Fin seq → Fin dHead → Rat - /-- K lower bounds. -/ - kLo : Fin seq → Fin dHead → Rat - /-- K upper bounds. -/ - kHi : Fin seq → Fin dHead → Rat - /-- V lower bounds. -/ - vLo : Fin seq → Fin dHead → Rat - /-- V upper bounds. -/ - vHi : Fin seq → Fin dHead → Rat - /-- Q absolute bounds. -/ - qAbs : Fin seq → Fin dHead → Rat - /-- K absolute bounds. -/ - kAbs : Fin seq → Fin dHead → Rat - -/-- Compute Q/K/V bounds from LayerNorm bounds. 
-/ -def headQKVBounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (lnLo lnHi : Fin seq → Fin dModel → Rat) : - HeadQKVBounds seq dModel dHead := - let qLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let qHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let kLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let kHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let vLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let vHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let qAbs := - Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) - let kAbs := - Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) - { qLo := qLo - qHi := qHi - kLo := kLo - kHi := kHi - vLo := vLo - vHi := vHi - qAbs := qAbs - kAbs := kAbs } - -private theorem headQKVBounds_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (lnLo lnHi : Fin seq → Fin dModel → Rat) : - headQKVBounds inputs lnLo lnHi = - let qLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let qHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wq j d) (lnLo q) (lnHi q) + - inputs.bq d) - let kLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wk j d) (lnLo q) (lnHi q) + - inputs.bk d) - let kHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wk j 
d) (lnLo q) (lnHi q) + - inputs.bk d) - let vLo := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalLowerUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let vHi := - Bounds.cacheBound2 (fun q d => - Bounds.dotIntervalUpperUnnorm (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - inputs.bv d) - let qAbs := - Bounds.cacheBound2 (fun q d => max |qLo q d| |qHi q d|) - let kAbs := - Bounds.cacheBound2 (fun q d => max |kLo q d| |kHi q d|) - { qLo := qLo - qHi := qHi - kLo := kLo - kHi := kHi - vLo := vLo - vHi := vHi - qAbs := qAbs - kAbs := kAbs } := rfl - -/-- Score and margin bounds used by the induction-head builder. -/ -structure HeadScoreBounds (seq dModel dHead : Nat) where - /-- Absolute dot-product bound. -/ - dotAbs : Fin seq → Fin seq → Rat - /-- Base score absolute bound. -/ - scoreBaseAbs : Fin seq → Fin seq → Rat - /-- Score absolute bound with causal masking. -/ - scoreAbs : Fin seq → Fin seq → Rat - /-- Score lower bound. -/ - scoreLo : Fin seq → Fin seq → Rat - /-- Score upper bound. -/ - scoreHi : Fin seq → Fin seq → Rat - /-- Margin per query. -/ - marginAt : Fin seq → Rat - /-- Epsilon per query. -/ - epsAt : Fin seq → Rat - /-- Global margin. -/ - margin : Rat - /-- Global epsilon. -/ - eps : Rat - -/-- Compute score and margin bounds from cached score lower/upper bounds. 
-/ -def headScoreBoundsFromCaches [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) - (scoreLo scoreHi : Fin seq → Fin seq → Rat) : - HeadScoreBounds seq dModel dHead := - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let scoreBaseAbs : Fin seq → Fin seq → Rat := fun q k => - |inputs.scale| * dotAbs q k - let scoreAbs : Fin seq → Fin seq → Rat := fun q k => - if masked q k then |inputs.maskValue| else scoreBaseAbs q k - let otherKeys : Fin seq → Finset (Fin seq) := fun q => - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - let maskedKeys : Fin seq → Finset (Fin seq) := fun q => - if inputs.maskCausal = true then - (otherKeys q).filter (fun k => q < k) - else - (∅ : Finset (Fin seq)) - let unmaskedKeys : Fin seq → Finset (Fin seq) := fun q => - (otherKeys q) \ (maskedKeys q) - let maskedGap : Fin seq → Rat := fun q => - scoreLo q (inputs.prev q) - inputs.maskValue - let marginTasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - if q ∈ inputs.active then - let other := unmaskedKeys q - let masked := maskedKeys q - if hunmasked : other.Nonempty then - let unmaskedMin := other.inf' hunmasked (fun k => - scoreLo q (inputs.prev q) - scoreHi q k) - if hmasked : masked.Nonempty then - min unmaskedMin (maskedGap q) - else - unmaskedMin - else - if hmasked : masked.Nonempty then - maskedGap q - else - (0 : Rat) - else - (0 : Rat))) - let marginAt : Fin seq → Rat := fun q => - (marginTasks[q.1]'(by - simp [marginTasks, q.isLt])).get - let epsTasks : Array (Task Rat) := - Array.ofFn (fun q : Fin seq => - (marginTasks[q.1]'(by - simp [marginTasks, q.isLt])).map (fun m => - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m))) - let epsAt : Fin seq → Rat := fun q => - (epsTasks[q.1]'(by - simp [epsTasks, q.isLt])).get - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - 
else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - { dotAbs := dotAbs - scoreBaseAbs := scoreBaseAbs - scoreAbs := scoreAbs - scoreLo := scoreLo - scoreHi := scoreHi - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } - -/-- Compute score and margin bounds from dot-product absolute bounds. -/ -def headScoreBoundsFromDotAbs [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) : HeadScoreBounds seq dModel dHead := - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scaleAbs : Rat := |inputs.scale| - let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else -base - let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else base - let marginAtRaw : Fin seq → Rat := fun q => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - if q ∈ inputs.active then - let rowArr := row.1 - let prev := inputs.prev q - let dotAbsPrev := rowArr.getD prev.1 0 - if masked q prev then - let scoreLoPrev := inputs.maskValue - let scoreHiAt : Fin seq → Rat := fun k => - if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let v := scoreLoPrev - 
scoreHiAt k - match acc.1 with - | none => (some v, acc.2) - | some cur => (some (min cur v), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min unmaskedMin maskedGap - | some unmaskedMin, false => unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - let scoreLoPrev := -(scaleAbs * dotAbsPrev) - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let raw := -(dotAbsPrev + rowArr.getD k.1 0) - match acc.1 with - | none => (some raw, acc.2) - | some cur => (some (min cur raw), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap - | some unmaskedMin, false => scaleAbs * unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - (0 : Rat) - let marginAtCached := Bounds.cacheBoundThunk marginAtRaw - let marginAt : Fin seq → Rat := fun q => - marginAtCached q - let epsAtRaw : Fin seq → Rat := fun q => - let m := marginAt q - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m) - let epsAtCached := Bounds.cacheBoundThunk epsAtRaw - let epsAt : Fin seq → Rat := fun q => - epsAtCached q - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - { dotAbs := dotAbs - scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k - scoreAbs := fun q k => - if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k - scoreLo := scoreLoCached - scoreHi := scoreHiCached - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } - -private theorem headScoreBoundsFromDotAbs_spec [NeZero seq] {dModel dHead : Nat} - (inputs : 
Model.InductionHeadInputs seq dModel dHead) - (dotAbs : Fin seq → Fin seq → Rat) : - headScoreBoundsFromDotAbs inputs dotAbs = - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scaleAbs : Rat := |inputs.scale| - let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else -base - let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else base - let marginAtRaw : Fin seq → Rat := fun q => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - if q ∈ inputs.active then - let rowArr := row.1 - let prev := inputs.prev q - let dotAbsPrev := rowArr.getD prev.1 0 - if masked q prev then - let scoreLoPrev := inputs.maskValue - let scoreHiAt : Fin seq → Rat := fun k => - if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let v := scoreLoPrev - scoreHiAt k - match acc.1 with - | none => (some v, acc.2) - | some cur => (some (min cur v), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min unmaskedMin maskedGap - | some unmaskedMin, false => unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - let scoreLoPrev := -(scaleAbs * dotAbsPrev) - let maskedGap := scoreLoPrev - inputs.maskValue - let 
step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let raw := -(dotAbsPrev + rowArr.getD k.1 0) - match acc.1 with - | none => (some raw, acc.2) - | some cur => (some (min cur raw), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap - | some unmaskedMin, false => scaleAbs * unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - (0 : Rat) - let marginAtCached := Bounds.cacheBoundThunk marginAtRaw - let marginAt : Fin seq → Rat := fun q => - marginAtCached q - let epsAtRaw : Fin seq → Rat := fun q => - let m := marginAt q - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m) - let epsAtCached := Bounds.cacheBoundThunk epsAtRaw - let epsAt : Fin seq → Rat := fun q => - epsAtCached q - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - { dotAbs := dotAbs - scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k - scoreAbs := fun q k => - if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k - scoreLo := scoreLoCached - scoreHi := scoreHiCached - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } := rfl - -/-- Compute score and margin bounds from Q/K interval bounds. 
-/ -def headScoreBoundsFromIntervals [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : - HeadScoreBounds seq dModel dHead := - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => - dotIntervalLowerUpper2CommonDen (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi - -private theorem headScoreBoundsFromIntervals_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qLo qHi kLo kHi : Fin seq → Fin dHead → Rat) : - headScoreBoundsFromIntervals inputs qLo qHi kLo kHi = - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotRowTasks : Array (Task { row : Array (Rat × Rat) // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq 
=> - dotIntervalLowerUpper2CommonDen (fun d => qLo q d) (fun d => qHi q d) - (fun d => kLo k d) (fun d => kHi k d)), - by simp⟩)) - let dotLo : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.1 - let dotHi : Fin seq → Fin seq → Rat := fun q k => - let row := (dotRowTasks[q.1]'(by - simp [dotRowTasks, q.isLt])).get - let entry := row.1[k.1]'(by - simp [row.2, k.isLt]) - entry.2 - let dotAbs : Fin seq → Fin seq → Rat := fun q k => max |dotLo q k| |dotHi q k| - let scoreLo : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if 0 ≤ inputs.scale then - inputs.scale * dotLo q k - else - inputs.scale * dotHi q k - let scoreHi : Fin seq → Fin seq → Rat := fun q k => - if masked q k then - inputs.maskValue - else - if 0 ≤ inputs.scale then - inputs.scale * dotHi q k - else - inputs.scale * dotLo q k - headScoreBoundsFromCaches inputs dotAbs scoreLo scoreHi := rfl - -/-- Compute score and margin bounds from Q/K absolute bounds. 
-/ -def headScoreBounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Rat) : - HeadScoreBounds seq dModel dHead := - headScoreBoundsFromDotAbs inputs (fun q k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d)) - -private theorem headScoreBounds_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (qAbs kAbs : Fin seq → Fin dHead → Rat) : - headScoreBounds inputs qAbs kAbs = - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - let dotAbs : Fin seq → Fin seq → Rat := fun q k => - Linear.dotFin dHead (fun d => qAbs q d) (fun d => kAbs k d) - let dotAbsRowTasks : Array (Task { row : Array Rat // row.size = seq }) := - Array.ofFn (fun q : Fin seq => - Task.spawn (fun _ => - ⟨Array.ofFn (fun k : Fin seq => dotAbs q k), by simp⟩)) - let scaleAbs : Rat := |inputs.scale| - let scoreLoCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else -base - let scoreHiCached : Fin seq → Fin seq → Rat := fun q k => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - let base := scaleAbs * row.1.getD k.1 0 - if masked q k then inputs.maskValue else base - let marginAtRaw : Fin seq → Rat := fun q => - let row := (dotAbsRowTasks[q.1]'(by - simp [dotAbsRowTasks, q.isLt])).get - if q ∈ inputs.active then - let rowArr := row.1 - let prev := inputs.prev q - let dotAbsPrev := rowArr.getD prev.1 0 - if masked q prev then - let scoreLoPrev := inputs.maskValue - let scoreHiAt : Fin seq → Rat := fun k => - if masked q k then inputs.maskValue else scaleAbs * rowArr.getD k.1 0 - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked 
q k then - (acc.1, true) - else - let v := scoreLoPrev - scoreHiAt k - match acc.1 with - | none => (some v, acc.2) - | some cur => (some (min cur v), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min unmaskedMin maskedGap - | some unmaskedMin, false => unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - let scoreLoPrev := -(scaleAbs * dotAbsPrev) - let maskedGap := scoreLoPrev - inputs.maskValue - let step : - (Option Rat × Bool) → Fin seq → (Option Rat × Bool) := - fun acc k => - if k = prev then - acc - else if masked q k then - (acc.1, true) - else - let raw := -(dotAbsPrev + rowArr.getD k.1 0) - match acc.1 with - | none => (some raw, acc.2) - | some cur => (some (min cur raw), acc.2) - let acc := Linear.foldlFin seq step (none, false) - match acc.1, acc.2 with - | some unmaskedMin, true => min (scaleAbs * unmaskedMin) maskedGap - | some unmaskedMin, false => scaleAbs * unmaskedMin - | none, true => maskedGap - | none, false => (0 : Rat) - else - (0 : Rat) - let marginAtCached := Bounds.cacheBoundThunk marginAtRaw - let marginAt : Fin seq → Rat := fun q => - marginAtCached q - let epsAtRaw : Fin seq → Rat := fun q => - let m := marginAt q - if m < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + m) - let epsAtCached := Bounds.cacheBoundThunk epsAtRaw - let epsAt : Fin seq → Rat := fun q => - epsAtCached q - let margin : Rat := - if h : inputs.active.Nonempty then - inputs.active.inf' h marginAt - else - (0 : Rat) - let eps : Rat := - if margin < 0 then - (1 : Rat) - else - ratDivUp (seq - 1) (1 + margin) - { dotAbs := dotAbs - scoreBaseAbs := fun q k => |inputs.scale| * dotAbs q k - scoreAbs := fun q k => - if masked q k then |inputs.maskValue| else |inputs.scale| * dotAbs q k - scoreLo := scoreLoCached - scoreHi := scoreHiCached - marginAt := marginAt - epsAt := epsAt - margin := margin - eps := eps } := rfl - -/-- Value bounds used by the induction-head builder. 
-/ -structure HeadValueBounds (seq dModel dHead : Nat) where - /-- Value lower bounds. -/ - valsLo : Fin seq → Rat - /-- Value upper bounds. -/ - valsHi : Fin seq → Rat - /-- Global value lower bound. -/ - lo : Rat - /-- Global value upper bound. -/ - hi : Rat - -/-- Cached direction vector for value bounds. -/ -def headValueDirHead {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin dHead → Rat := - let dirHeadVec := dirHeadVecOfInputs inputs - fun d => dirHeadVec.get d - -private theorem headValueDirHead_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - headValueDirHead inputs = - let dirHeadVec := dirHeadVecOfInputs inputs - fun d => dirHeadVec.get d := rfl - -/-- Cached lower value bounds from V intervals. -/ -def headValueValsLoArray {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - -/-- Unfold `headValueValsLoArray` to its `Array.ofFn` definition. -/ -private theorem headValueValsLoArray_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoArray inputs vLo vHi = - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl - -/-- Cached lower value bounds from V intervals. 
-/ -def headValueValsLo {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := - let arr := headValueValsLoArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) - -private theorem headValueValsLo_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLo inputs vLo vHi = - let arr := headValueValsLoArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) := rfl - -/-- Cached lower value bounds from V intervals using a common-denominator sum. -/ -def headValueValsLoCommonDenArray {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - headValueValsLoArray inputs vLo vHi - -/-- Unfold `headValueValsLoCommonDenArray` to its `Array.ofFn` definition. -/ -private theorem headValueValsLoCommonDenArray_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoCommonDenArray inputs vLo vHi = - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) := rfl - -/-- Cached lower value bounds from V intervals using a common-denominator sum. 
-/ -def headValueValsLoCommonDen {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := - let arr := headValueValsLoCommonDenArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) - -private theorem headValueValsLoCommonDen_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoCommonDen inputs vLo vHi = - let arr := headValueValsLoCommonDenArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) := rfl - -/-- Common-denominator lower bounds agree with cached rational bounds pointwise. -/ -theorem headValueValsLoCommonDenArray_eq {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoCommonDenArray inputs vLo vHi = headValueValsLoArray inputs vLo vHi := by - rfl - -theorem headValueValsLoCommonDen_eq {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsLoCommonDen inputs vLo vHi = headValueValsLo inputs vLo vHi := by - funext k - simp [headValueValsLoCommonDen, headValueValsLo, headValueValsLoCommonDenArray_eq] - -/-- Cached upper value bounds from V intervals. -/ -def headValueValsHiArray {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - -/-- Unfold `headValueValsHiArray` to its `Array.ofFn` definition. 
-/ -private theorem headValueValsHiArray_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiArray inputs vLo vHi = - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl - -/-- Cached upper value bounds from V intervals. -/ -def headValueValsHi {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := - let arr := headValueValsHiArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) - -private theorem headValueValsHi_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHi inputs vLo vHi = - let arr := headValueValsHiArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) := rfl - -/-- Cached upper value bounds from V intervals using a common-denominator sum. -/ -def headValueValsHiCommonDenArray {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Array Rat := - headValueValsHiArray inputs vLo vHi - -/-- Unfold `headValueValsHiCommonDenArray` to its `Array.ofFn` definition. -/ -private theorem headValueValsHiCommonDenArray_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiCommonDenArray inputs vLo vHi = - let dirHead := headValueDirHead inputs - Array.ofFn (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) := rfl - -/-- Cached upper value bounds from V intervals using a common-denominator sum. 
-/ -def headValueValsHiCommonDen {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : Fin seq → Rat := - let arr := headValueValsHiCommonDenArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) - -private theorem headValueValsHiCommonDen_spec {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiCommonDen inputs vLo vHi = - let arr := headValueValsHiCommonDenArray inputs vLo vHi - fun k => arr.getD k.1 (0 : Rat) := rfl - -/-- Common-denominator upper bounds agree with cached rational bounds pointwise. -/ -theorem headValueValsHiCommonDenArray_eq {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiCommonDenArray inputs vLo vHi = headValueValsHiArray inputs vLo vHi := by - rfl - -theorem headValueValsHiCommonDen_eq {seq dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueValsHiCommonDen inputs vLo vHi = headValueValsHi inputs vLo vHi := by - funext k - simp [headValueValsHiCommonDen, headValueValsHi, headValueValsHiCommonDenArray_eq] - -/-- Global lower value bound from an array of per-key values. -/ -def headValueLoArray (valsLo : Array Rat) : Rat := - reduceMinArray valsLo - -/-- Unfold `headValueLoArray` to its reduction helper. -/ -private theorem headValueLoArray_spec (valsLo : Array Rat) : - headValueLoArray valsLo = reduceMinArray valsLo := rfl - -/-- Global lower value bound from cached per-key values. -/ -def headValueLo [NeZero seq] (valsLo : Fin seq → Rat) : Rat := - headValueLoArray (Array.ofFn valsLo) - -private theorem headValueLo_spec [NeZero seq] (valsLo : Fin seq → Rat) : - headValueLo valsLo = headValueLoArray (Array.ofFn valsLo) := rfl - -/-- Task wrapper for `headValueLo`. 
-/ -def headValueLoTask [NeZero seq] (valsLo : Fin seq → Rat) : Task Rat := - reduceMinFnTask valsLo - -private theorem headValueLoTask_spec [NeZero seq] (valsLo : Fin seq → Rat) : - headValueLoTask valsLo = reduceMinFnTask valsLo := rfl - -/-- Chunked task reduction agrees with the sequential chunked value bound. -/ -private theorem headValueLoTask_get_eq [NeZero seq] (valsLo : Fin seq → Rat) : - (headValueLoTask valsLo).get = reduceMinFnChunked valsLo := by - simp [headValueLoTask_spec, reduceMinFnTask_get_eq] - -/-- Global upper value bound from an array of per-key values. -/ -def headValueHiArray (valsHi : Array Rat) : Rat := - reduceMaxArray valsHi - -/-- Unfold `headValueHiArray` to its reduction helper. -/ -private theorem headValueHiArray_spec (valsHi : Array Rat) : - headValueHiArray valsHi = reduceMaxArray valsHi := rfl - -/-- Global upper value bound from cached per-key values. -/ -def headValueHi [NeZero seq] (valsHi : Fin seq → Rat) : Rat := - headValueHiArray (Array.ofFn valsHi) - -private theorem headValueHi_spec [NeZero seq] (valsHi : Fin seq → Rat) : - headValueHi valsHi = headValueHiArray (Array.ofFn valsHi) := rfl - -/-- Task wrapper for `headValueHi`. -/ -def headValueHiTask [NeZero seq] (valsHi : Fin seq → Rat) : Task Rat := - reduceMaxFnTask valsHi - -private theorem headValueHiTask_spec [NeZero seq] (valsHi : Fin seq → Rat) : - headValueHiTask valsHi = reduceMaxFnTask valsHi := rfl - -/-- Chunked task reduction agrees with the sequential chunked value bound. -/ -private theorem headValueHiTask_get_eq [NeZero seq] (valsHi : Fin seq → Rat) : - (headValueHiTask valsHi).get = reduceMaxFnChunked valsHi := by - simp [headValueHiTask_spec, reduceMaxFnTask_get_eq] - -/-- Build `HeadValueBounds` from precomputed arrays. 
-/ -private def headValueBoundsOfArrays {seq dModel dHead : Nat} - (valsLoArr valsHiArr : Array Rat) : HeadValueBounds seq dModel dHead := - let valsLo : Fin seq → Rat := fun k => valsLoArr.getD k.1 (0 : Rat) - let valsHi : Fin seq → Rat := fun k => valsHiArr.getD k.1 (0 : Rat) - let lo := headValueLoArray valsLoArr - let hi := headValueHiArray valsHiArr - { valsLo := valsLo, valsHi := valsHi, lo := lo, hi := hi } - -/-- Build a cached bounds array in parallel from a per-key computation. -/ -private def buildBoundArrayTask [NeZero seq] (f : Fin seq → Rat) : Task (Array Rat) := - let n := seq - let chunkSize : Nat := 64 - let chunks : Nat := (n + chunkSize - 1) / chunkSize - let hpos : 0 < seq := Nat.pos_of_ne_zero (by simpa using (NeZero.ne (n := seq))) - let defaultIdx : Fin seq := ⟨0, hpos⟩ - let idxs : Array (Fin seq) := Array.ofFn (fun i : Fin seq => i) - let chunkTasks : List (Task (Array Rat)) := - (List.range chunks).map (fun c => - Task.spawn (fun _ => - let start := c * chunkSize - let stop := Nat.min n (start + chunkSize) - let vals := - (List.range (stop - start)).map (fun i => - f (idxs.getD (start + i) defaultIdx)) - vals.toArray)) - Task.mapList (fun xs => xs.foldl (fun acc arr => acc ++ arr) #[]) chunkTasks - -/-- Compute value bounds from V interval bounds. 
-/ -def headValueBounds [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - HeadValueBounds seq dModel dHead := - let valsLoArr := headValueValsLoArray inputs vLo vHi - let valsHiArr := headValueValsHiArray inputs vLo vHi - headValueBoundsOfArrays valsLoArr valsHiArr - -private theorem headValueBounds_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBounds inputs vLo vHi = - let valsLoArr := headValueValsLoArray inputs vLo vHi - let valsHiArr := headValueValsHiArray inputs vLo vHi - headValueBoundsOfArrays valsLoArr valsHiArr := rfl - -/-- Compute value bounds from V interval bounds in parallel. -/ -def headValueBoundsTask [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - Task (HeadValueBounds seq dModel dHead) := - let dirHead := headValueDirHead inputs - let valsLoTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - let valsHiTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - Task.bind valsLoTask (fun valsLoArr => - Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) - -/-- Unfold `headValueBoundsTask` to its task graph. 
-/ -private theorem headValueBoundsTask_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBoundsTask inputs vLo vHi = - let dirHead := headValueDirHead inputs - let valsLoTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - let valsHiTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - Task.bind valsLoTask (fun valsLoArr => - Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) := rfl - -/-- Compute value bounds from V interval bounds using a common-denominator sum. -/ -def headValueBoundsCommonDen [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - HeadValueBounds seq dModel dHead := - let valsLoArr := headValueValsLoCommonDenArray inputs vLo vHi - let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi - headValueBoundsOfArrays valsLoArr valsHiArr - -private theorem headValueBoundsCommonDen_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBoundsCommonDen inputs vLo vHi = - let valsLoArr := headValueValsLoCommonDenArray inputs vLo vHi - let valsHiArr := headValueValsHiCommonDenArray inputs vLo vHi - headValueBoundsOfArrays valsLoArr valsHiArr := rfl - -/-- Compute value bounds from V intervals using a common-denominator sum in parallel. 
-/ -def headValueBoundsCommonDenTask [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - Task (HeadValueBounds seq dModel dHead) := - let dirHead := headValueDirHead inputs - let valsLoTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - let valsHiTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - Task.bind valsLoTask (fun valsLoArr => - Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) - -/-- Unfold `headValueBoundsCommonDenTask` to its task graph. -/ -private theorem headValueBoundsCommonDenTask_spec [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBoundsCommonDenTask inputs vLo vHi = - let dirHead := headValueDirHead inputs - let valsLoTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalLowerCommonDen dirHead (vLo k) (vHi k)) - let valsHiTask := buildBoundArrayTask (fun k => - Bounds.dotIntervalUpperCommonDen dirHead (vLo k) (vHi k)) - Task.bind valsLoTask (fun valsLoArr => - Task.map (fun valsHiArr => headValueBoundsOfArrays valsLoArr valsHiArr) valsHiTask) := rfl - -theorem headValueBoundsCommonDen_eq [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) - (vLo vHi : Fin seq → Fin dHead → Rat) : - headValueBoundsCommonDen inputs vLo vHi = headValueBounds inputs vLo vHi := by - classical - simp [headValueBoundsCommonDen, headValueBounds, headValueValsLoCommonDenArray_eq, - headValueValsHiCommonDenArray_eq] - -end Sound - -end Nfp diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index af9ff4c..2723975 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -2,8 +2,8 @@ module -public import Aesop -public import Nfp.Sound.Induction.CoreSound 
+public import Nfp.Circuit.Cert.ResidualInterval +public import Nfp.Sound.Induction.CoreDefs /-! Head-output interval certificates for induction heads. @@ -15,58 +15,7 @@ namespace Nfp namespace Sound -open scoped BigOperators - open Nfp.Circuit -open Nfp.Sound.Bounds - -variable {seq : Nat} - -/-- Build and certify induction certificates from exact head inputs. -/ -def buildInductionCertFromHeadWith? [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option {c : InductionHeadCert seq // InductionHeadCertSound inputs c} := by - classical - cases hcore : buildInductionCertFromHeadCoreWith? cfg inputs with - | none => exact none - | some c => - exact some ⟨c, buildInductionCertFromHeadCoreWith?_sound (cfg := cfg) inputs c hcore⟩ - -/-- Build and certify induction certificates from exact head inputs, retaining the core cache. -/ -def buildInductionCertFromHeadWithCache? [NeZero seq] {dModel dHead : Nat} - (cfg : InductionHeadSplitConfig) - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option {cache : InductionHeadCoreCache seq dModel dHead // - InductionHeadCertSound inputs cache.cert} := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : 0 < sqrtLower inputs.lnEps - · by_cases hmodel : dModel = 0 - · exact none - · by_cases hactive : inputs.active.Nonempty - · let cache := buildInductionHeadCoreCacheWith cfg inputs - have hmodel' : dModel ≠ 0 := by - exact hmodel - have hcore : - buildInductionCertFromHeadCoreWith? cfg inputs = some cache.cert := by - simpa [cache] using - (buildInductionCertFromHeadCoreWith?_eq_some - (cfg := cfg) (inputs := inputs) hEps hSqrt hmodel' hactive) - exact some ⟨cache, - buildInductionCertFromHeadCoreWith?_sound (cfg := cfg) inputs cache.cert hcore⟩ - · exact none - · exact none - · exact none - -/-- Build and certify induction certificates from exact head inputs using the default split -budgets. 
-/ -def buildInductionCertFromHead? [NeZero seq] {dModel dHead : Nat} - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option {c : InductionHeadCert seq // InductionHeadCertSound inputs c} := - buildInductionCertFromHeadWith? defaultInductionHeadSplitConfig inputs - -section HeadOutputInterval variable {seq dModel dHead : Nat} @@ -117,325 +66,8 @@ structure HeadOutputIntervalSound [NeZero seq] (c.lo i : Real) ≤ headOutput inputs q i ∧ headOutput inputs q i ≤ (c.hi i : Real) -/-- Certified head-output interval data for a specific active set. -/ -structure HeadOutputIntervalResult [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) where - /-- Active queries covered by the interval bounds. -/ - active : Finset (Fin seq) - /-- Residual-interval certificate for head outputs. -/ - cert : Circuit.ResidualIntervalCert dModel - /-- Soundness proof for the interval bounds. -/ - sound : HeadOutputIntervalSound inputs active cert - -/-- Build residual-interval bounds for head outputs on active queries. -/ -def buildHeadOutputIntervalFromHead? [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option (HeadOutputIntervalResult inputs) := by - classical - cases seq with - | zero => - cases (NeZero.ne (n := (0 : Nat)) rfl) - | succ n => - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : 0 < sqrtLower inputs.lnEps - · by_cases hmodel : dModel = 0 - · exact none - · cases hbuild : buildInductionCertFromHead? 
inputs with - | none => exact none - | some certWithProof => - rcases certWithProof with ⟨cert, hcert⟩ - let lnBounds : Fin (Nat.succ n) → (Fin dModel → Rat) × (Fin dModel → Rat) := - fun q => - Bounds.layerNormBounds inputs.lnEps inputs.ln1Gamma inputs.ln1Beta - (inputs.embed q) - let lnLo : Fin (Nat.succ n) → Fin dModel → Rat := fun q => (lnBounds q).1 - let lnHi : Fin (Nat.succ n) → Fin dModel → Rat := fun q => (lnBounds q).2 - let vLo : Fin (Nat.succ n) → Fin dHead → Rat := fun q d => - dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let vHi : Fin (Nat.succ n) → Fin dHead → Rat := fun q d => - dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + inputs.bv d - let headValueLo : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => - dotIntervalLower (fun d => inputs.wo i d) (vLo k) (vHi k) - let headValueHi : Fin (Nat.succ n) → Fin dModel → Rat := fun k i => - dotIntervalUpper (fun d => inputs.wo i d) (vLo k) (vHi k) - have hln_bounds : - ∀ q i, (lnLo q i : Real) ≤ lnRealOfInputs inputs q i ∧ - lnRealOfInputs inputs q i ≤ (lnHi q i : Real) := by - intro q i - have hln := - Bounds.layerNormBounds_spec (eps := inputs.lnEps) - (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) - (x := inputs.embed q) hmodel hEps hSqrt - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs_def] using hln i - have hv_bounds : - ∀ q d, (vLo q d : Real) ≤ vRealOfInputs inputs q d ∧ - vRealOfInputs inputs q d ≤ (vHi q d : Real) := by - intro q d - have hln := hln_bounds q - have hlo : ∀ j, (lnLo q j : Real) ≤ lnRealOfInputs inputs q j := fun j => - (hln j).1 - have hhi : ∀ j, lnRealOfInputs inputs q j ≤ (lnHi q j : Real) := fun j => - (hln j).2 - have hlow' : - dotIntervalLower (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - (inputs.bv d : Real) ≤ - dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + - (inputs.bv d : Real) := - by - simpa using - dotIntervalLower_le_dotProduct_real_add - (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := 
lnHi q) - (x := lnRealOfInputs inputs q) (b := (inputs.bv d : Real)) hlo hhi - have hhigh' : - dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + - (inputs.bv d : Real) ≤ - dotIntervalUpper (fun j => inputs.wv j d) (lnLo q) (lnHi q) + - (inputs.bv d : Real) := - by - simpa using - dotProduct_le_dotIntervalUpper_real_add - (v := fun j => inputs.wv j d) - (lo := lnLo q) (hi := lnHi q) - (x := lnRealOfInputs inputs q) (b := (inputs.bv d : Real)) hlo hhi - constructor - · simpa [vLo, vRealOfInputs_def, Bounds.cacheBound2_apply, - Bounds.dotIntervalLowerCachedRat_eq, ratToReal_add] using - hlow' - · simpa [vHi, vRealOfInputs_def, Bounds.cacheBound2_apply, - Bounds.dotIntervalUpperCachedRat_eq, ratToReal_add] using - hhigh' - have hhead_bounds : - ∀ k i, (headValueLo k i : Real) ≤ headValueRealOfInputs inputs k i ∧ - headValueRealOfInputs inputs k i ≤ (headValueHi k i : Real) := by - intro k i - have hv := hv_bounds k - have hlo : ∀ d, (vLo k d : Real) ≤ vRealOfInputs inputs k d := fun d => (hv d).1 - have hhi : ∀ d, vRealOfInputs inputs k d ≤ (vHi k d : Real) := fun d => (hv d).2 - have hlow := - dotIntervalLower_le_dotProduct_real (v := fun d => inputs.wo i d) - (lo := vLo k) (hi := vHi k) - (x := fun d => vRealOfInputs inputs k d) hlo hhi - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := fun d => inputs.wo i d) - (lo := vLo k) (hi := vHi k) - (x := fun d => vRealOfInputs inputs k d) hlo hhi - constructor - · simpa [headValueLo, headValueRealOfInputs_def] using hlow - · simpa [headValueHi, headValueRealOfInputs_def] using hhigh - let scoresReal : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := - scoresRealOfInputs inputs - let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => - Circuit.softmax (scoresReal q) k - let activeSet : Finset (Fin (Nat.succ n)) := cert.active - let univ : Finset (Fin (Nat.succ n)) := Finset.univ - have huniv : univ.Nonempty := by simp [univ] - let loVal : Fin dModel → Rat := fun i => - univ.inf' 
huniv (fun k => headValueLo k i) - let hiVal : Fin dModel → Rat := fun i => - univ.sup' huniv (fun k => headValueHi k i) - have hvalsBoundsReal : - ∀ i, Layers.ValueRangeBounds (Val := Real) - (loVal i : Real) (hiVal i : Real) - (fun k => headValueRealOfInputs inputs k i) := by - intro i - have hloVal : ∀ k, loVal i ≤ headValueLo k i := by - intro k - dsimp [loVal] - refine (Finset.inf'_le_iff (s := univ) (H := huniv) - (f := fun k => headValueLo k i) (a := headValueLo k i)).2 ?_ - exact ⟨k, by simp [univ], le_rfl⟩ - have hhiVal : ∀ k, headValueHi k i ≤ hiVal i := by - intro k - dsimp [hiVal] - refine (Finset.le_sup'_iff (s := univ) (H := huniv) - (f := fun k => headValueHi k i) (a := headValueHi k i)).2 ?_ - exact ⟨k, ⟨by simp [univ], le_rfl⟩⟩ - refine { lo_le_hi := ?_, lo_le := ?_, le_hi := ?_ } - · rcases (Finset.univ_nonempty : univ.Nonempty) with ⟨k0, hk0⟩ - have hloRat : loVal i ≤ headValueLo k0 i := hloVal k0 - have hhiRat : headValueHi k0 i ≤ hiVal i := hhiVal k0 - have hbounds := hhead_bounds k0 i - have hloReal : (loVal i : Real) ≤ (headValueLo k0 i : Real) := by - simpa [ratToReal_def] using ratToReal_le_of_le hloRat - have hhiReal : (headValueHi k0 i : Real) ≤ (hiVal i : Real) := by - simpa [ratToReal_def] using ratToReal_le_of_le hhiRat - have hreal : (loVal i : Real) ≤ (hiVal i : Real) := by - exact le_trans hloReal (le_trans hbounds.1 (le_trans hbounds.2 hhiReal)) - exact hreal - · intro k - have hloRat : loVal i ≤ headValueLo k i := hloVal k - have hbounds := hhead_bounds k i - have hloReal : (loVal i : Real) ≤ (headValueLo k i : Real) := by - simpa [ratToReal_def] using ratToReal_le_of_le hloRat - exact hloReal.trans hbounds.1 - · intro k - have hhiRat : headValueHi k i ≤ hiVal i := hhiVal k - have hbounds := hhead_bounds k i - have hhiReal : (headValueHi k i : Real) ≤ (hiVal i : Real) := by - simpa [ratToReal_def] using ratToReal_le_of_le hhiRat - exact hbounds.2.trans hhiReal - have hsoftmax : - Layers.SoftmaxMarginBoundsOn (Val := Real) - 
(cert.eps : Real) (cert.margin : Real) - (fun q => q ∈ activeSet) cert.prev scoresReal weights := by - simpa [scoresReal, weights, activeSet] using hcert.softmax_bounds - have hweights : - Layers.OneHotApproxBoundsOnActive (Val := Real) (cert.eps : Real) - (fun q => q ∈ activeSet) cert.prev weights := - Layers.oneHotApproxBoundsOnActive_of_softmaxMargin - (Val := Real) - (ε := (cert.eps : Real)) - (margin := (cert.margin : Real)) - (active := fun q => q ∈ activeSet) - (prev := cert.prev) - (scores := scoresReal) - (weights := weights) - hsoftmax - have happrox : - ∀ i, Layers.InductionSpecApproxOn (Val := Real) (n := n) - ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real))) - (fun q => q ∈ activeSet) cert.prev - (fun q => dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i)) - (fun k => headValueRealOfInputs inputs k i) := by - intro i - exact - Layers.inductionSpecApproxOn_of_oneHotApprox_valueRange - (Val := Real) - (n := n) - (ε := (cert.eps : Real)) - (lo := (loVal i : Real)) - (hi := (hiVal i : Real)) - (active := fun q => q ∈ activeSet) - (prev := cert.prev) - (weights := weights) - (vals := fun k => headValueRealOfInputs inputs k i) - (hweights := hweights) - (hvals := hvalsBoundsReal i) - let delta : Fin dModel → Rat := fun i => hiVal i - loVal i - let boundLoRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => - headValueLo (cert.prev q) i - cert.eps * delta i - let boundHiRat : Fin (Nat.succ n) → Fin dModel → Rat := fun q i => - headValueHi (cert.prev q) i + cert.eps * delta i - let loOut : Fin dModel → Rat := fun i => - if h : activeSet.Nonempty then - activeSet.inf' h (fun q => boundLoRat q i) - else - 0 - let hiOut : Fin dModel → Rat := fun i => - if h : activeSet.Nonempty then - activeSet.sup' h (fun q => boundHiRat q i) - else - 0 - have hout : - ∀ q, q ∈ activeSet → ∀ i, - (loOut i : Real) ≤ headOutput inputs q i ∧ - headOutput inputs q i ≤ (hiOut i : Real) := by - intro q hq i - have hactive : activeSet.Nonempty := ⟨q, hq⟩ - 
have hspec : - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ - headValueRealOfInputs inputs (cert.prev q) i + - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ∧ - headValueRealOfInputs inputs (cert.prev q) i ≤ - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) + - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by - have happrox' : - ∀ q, q ∈ activeSet → - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ - headValueRealOfInputs inputs (cert.prev q) i + - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ∧ - headValueRealOfInputs inputs (cert.prev q) i ≤ - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) + - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by - simpa [Layers.InductionSpecApproxOn_def] using (happrox i) - exact happrox' q hq - have hout_def : - headOutput inputs q i = - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by - simp [headOutput, headOutputWithScores, scoresReal, weights] - have hprev_bounds := hhead_bounds (cert.prev q) i - have hupper : - headOutput inputs q i ≤ (boundHiRat q i : Real) := by - have hupper' : - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ - headValueRealOfInputs inputs (cert.prev q) i + - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by - exact hspec.1 - have hupper'' : - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) ≤ - (headValueHi (cert.prev q) i : Real) + - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) := by - have hprev_bounds' := - (add_le_add_iff_right - ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))).2 - hprev_bounds.2 - exact le_trans hupper' hprev_bounds' - simpa - [hout_def, boundHiRat, delta, ratToReal_add, ratToReal_mul, - ratToReal_sub] using - hupper'' - have hlower : - (boundLoRat q i : Real) ≤ headOutput inputs q i := by - have hlower' : - (headValueRealOfInputs inputs (cert.prev 
q) i : Real) - - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by - exact (sub_le_iff_le_add).2 hspec.2 - have hlower'' : - (headValueLo (cert.prev q) i : Real) - - (cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)) ≤ - dotProduct (weights q) (fun k => headValueRealOfInputs inputs k i) := by - refine le_trans (sub_le_sub_right hprev_bounds.1 - ((cert.eps : Real) * ((hiVal i : Real) - (loVal i : Real)))) ?_ - exact hlower' - simpa [hout_def, boundLoRat, delta, ratToReal_mul, ratToReal_sub] using - hlower'' - have hlo : - (loOut i : Real) ≤ (boundLoRat q i : Real) := by - have hloRat : loOut i ≤ boundLoRat q i := by - simpa [loOut, hactive] using - (Finset.inf'_le - (s := activeSet) - (f := fun q => boundLoRat q i) - (b := q) hq) - simpa [ratToReal_def] using ratToReal_le_of_le hloRat - have hhi : - (boundHiRat q i : Real) ≤ (hiOut i : Real) := by - have hhiRat : boundHiRat q i ≤ hiOut i := by - simpa [hiOut, hactive] using - (Finset.le_sup' - (s := activeSet) - (f := fun q => boundHiRat q i) - (b := q) hq) - simpa [ratToReal_def] using ratToReal_le_of_le hhiRat - exact ⟨le_trans hlo hlower, le_trans hupper hhi⟩ - have hbounds : Circuit.ResidualIntervalBounds { lo := loOut, hi := hiOut } := by - refine { lo_le_hi := ?_ } - intro i - by_cases hactive : activeSet.Nonempty - · rcases hactive with ⟨q, hq⟩ - have hout_i := hout q hq i - have hreal : ratToReal (loOut i) ≤ ratToReal (hiOut i) := by - simpa [ratToReal_def] using le_trans hout_i.1 hout_i.2 - exact (ratToReal_le_iff (x := loOut i) (y := hiOut i)).1 hreal - · simp [loOut, hiOut, hactive] - let certOut : Circuit.ResidualIntervalCert dModel := { lo := loOut, hi := hiOut } - exact some - { active := activeSet - cert := certOut - sound := - { bounds := hbounds - output_mem := by - intro q hq i - exact hout q hq i } } - · exact none - · exact none - end -end HeadOutputInterval - end Sound end Nfp diff --git 
a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 2554f68..970daaf 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -6,9 +6,9 @@ public import Aesop public import Mathlib.Data.List.MinMax public import Mathlib.Data.Vector.Basic public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.Sound.Bounds.Cache public import Nfp.Sound.Bounds.MatrixNorm.Interval public import Nfp.Sound.Induction.HeadOutput -public import Nfp.Sound.Induction.Refine /-! Logit-diff bounds derived from induction certificates. @@ -314,157 +314,6 @@ theorem logitDiffLowerBoundFromCacheWithEpsVals_def Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo := by rfl -/-- Refined unweighted logit-diff lower bound using an overlayed `epsAt`. -/ -def logitDiffLowerBoundRefinedFromCache - (inputs : Model.InductionHeadInputs seq dModel dHead) - (core : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (cache : LogitDiffCache seq) - (spec : InductionHeadRefineSpec seq) : Option Rat := - let weightBoundAt := weightBoundAtOverlay inputs core spec - let epsAt := epsAtOverlay core weightBoundAt - logitDiffLowerBoundFromCacheWithEps c cache epsAt - -/-- Unfolding lemma for `logitDiffLowerBoundRefinedFromCache`. -/ -theorem logitDiffLowerBoundRefinedFromCache_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (core : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (cache : LogitDiffCache seq) - (spec : InductionHeadRefineSpec seq) : - logitDiffLowerBoundRefinedFromCache inputs core c cache spec = - let weightBoundAt := weightBoundAtOverlay inputs core spec - let epsAt := epsAtOverlay core weightBoundAt - logitDiffLowerBoundFromCacheWithEps c cache epsAt := by - rfl - -/-- Refine-on-demand unweighted logit-diff bound using a supplied refinement spec. 
-/ -def logitDiffLowerBoundRefineOnDemandWithSpec - (inputs : Model.InductionHeadInputs seq dModel dHead) - (core : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (cache : LogitDiffCache seq) - (spec : InductionHeadRefineSpec seq) : Option Rat := - match logitDiffLowerBoundFromCache c cache with - | none => none - | some lb0 => - if lb0 ≤ 0 then - match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with - | some lb1 => some (max lb0 lb1) - | none => some lb0 - else - some lb0 - -/-- Unfolding lemma for `logitDiffLowerBoundRefineOnDemandWithSpec`. -/ -theorem logitDiffLowerBoundRefineOnDemandWithSpec_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (core : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (cache : LogitDiffCache seq) - (spec : InductionHeadRefineSpec seq) : - logitDiffLowerBoundRefineOnDemandWithSpec inputs core c cache spec = - match logitDiffLowerBoundFromCache c cache with - | none => none - | some lb0 => - if lb0 ≤ 0 then - match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with - | some lb1 => some (max lb0 lb1) - | none => some lb0 - else - some lb0 := by - rfl - -/-- Refine-on-demand unweighted logit-diff bound, refining only the argmin query. -/ -def logitDiffLowerBoundRefineOnDemand - (inputs : Model.InductionHeadInputs seq dModel dHead) - (core : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : Option Rat := - match logitDiffLowerBoundFromCache c cache with - | none => none - | some lb0 => - if lb0 ≤ 0 then - match logitDiffLowerBoundArgminFromCache c cache with - | none => some lb0 - | some q0 => - let refineBudget := max 1 core.splitBudgetDiffRefined - let spec := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget - match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with - | none => some lb0 - | some lb1 => - let lb01 := max lb0 lb1 - let lbWeight? 
: Option Rat := - if lb01 ≤ 0 then - let refineBudget' := refineBudgetBoost refineBudget - let spec' := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget' - match logitDiffLowerBoundRefinedFromCache inputs core c cache spec' with - | some lb2 => some (max lb01 lb2) - | none => some lb01 - else - some lb01 - match lbWeight? with - | none => some lb01 - | some lbWeight => - if lbWeight ≤ 0 then - let valBudget := refineBudgetBoost refineBudget - let valCount := refineLowValueCount refineBudget - let valKeys := - loAtKeysAt inputs core q0 ∪ - lowValueKeysAt inputs core q0 valCount - let valsLo := valsLoOverlay inputs core valBudget valKeys - let lbRefined? := - logitDiffLowerBoundFromCacheWithEpsVals c cache.epsAt valsLo - match lbRefined? with - | some lb2 => some (max lbWeight lb2) - | none => some lbWeight - else - some lbWeight - else - some lb0 - -/-- Unfolding lemma for `logitDiffLowerBoundRefineOnDemand`. -/ -theorem logitDiffLowerBoundRefineOnDemand_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (core : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : - logitDiffLowerBoundRefineOnDemand inputs core c cache = - match logitDiffLowerBoundFromCache c cache with - | none => none - | some lb0 => - if lb0 ≤ 0 then - match logitDiffLowerBoundArgminFromCache c cache with - | none => some lb0 - | some q0 => - let refineBudget := max 1 core.splitBudgetDiffRefined - let spec := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget - match logitDiffLowerBoundRefinedFromCache inputs core c cache spec with - | none => some lb0 - | some lb1 => - let lb01 := max lb0 lb1 - let lbWeight? 
: Option Rat := - if lb01 ≤ 0 then - let refineBudget' := refineBudgetBoost refineBudget - let spec' := refineSpecForQueryWithWeightOnes inputs core q0 refineBudget' - match logitDiffLowerBoundRefinedFromCache inputs core c cache spec' with - | some lb2 => some (max lb01 lb2) - | none => some lb01 - else - some lb01 - match lbWeight? with - | none => some lb01 - | some lbWeight => - if lbWeight ≤ 0 then - let valBudget := refineBudgetBoost refineBudget - let valCount := refineLowValueCount refineBudget - let valKeys := - loAtKeysAt inputs core q0 ∪ - lowValueKeysAt inputs core q0 valCount - let valsLo := valsLoOverlay inputs core valBudget valKeys - let lbRefined? := - logitDiffLowerBoundFromCacheWithEpsVals c cache.epsAt valsLo - match lbRefined? with - | some lb2 => some (max lbWeight lb2) - | none => some lbWeight - else - some lbWeight - else - some lb0 := by - rfl /-- Weighted logit-diff lower bound from a shared cache. -/ def logitDiffLowerBoundWeightedFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : @@ -1096,70 +945,6 @@ theorem logitDiffLowerBoundFromCertBest_le max_le_iff.mpr ⟨h0le, h1le⟩ simpa [ratToReal_max, ratToReal_def] using hmax -/-- Certified logit-diff lower bound derived from exact head inputs. -/ -structure InductionLogitLowerBoundResult - (inputs : Model.InductionHeadInputs seq dModel dHead) where - /-- Induction certificate built from the head inputs. -/ - cert : InductionHeadCert seq - /-- Soundness proof for the induction certificate. -/ - sound : InductionHeadCertSound inputs cert - /-- Reported lower bound on logit diff. -/ - lb : Rat - /-- `lb` is computed from `logitDiffLowerBoundFromCert`. -/ - lb_def : logitDiffLowerBoundFromCert cert = some lb - /-- The lower bound is sound on active queries. -/ - lb_sound : ∀ q, q ∈ cert.active → (lb : Real) ≤ headLogitDiff inputs q - -/-- Nonvacuous logit-diff bound (strictly positive). 
-/ -structure InductionLogitLowerBoundNonvacuous - (inputs : Model.InductionHeadInputs seq dModel dHead) where - /-- Base logit-diff bound data. -/ - base : InductionLogitLowerBoundResult inputs - /-- The reported bound is strictly positive. -/ - lb_pos : 0 < base.lb - -/-- Build a logit-diff lower bound from exact head inputs. -/ -def buildInductionLogitLowerBoundFromHead? - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option (InductionLogitLowerBoundResult inputs) := by - classical - cases hcert : buildInductionCertFromHead? inputs with - | none => exact none - | some certWithProof => - rcases certWithProof with ⟨cert, hsound⟩ - cases hlb : logitDiffLowerBoundFromCert cert with - | none => exact none - | some lb => - refine some ?_ - refine - { cert := cert - sound := hsound - lb := lb - lb_def := hlb - lb_sound := ?_ } - intro q hq - exact - logitDiffLowerBoundFromCert_le - (inputs := inputs) - (c := cert) - (hsound := hsound) - (lb := lb) - (hbound := hlb) - (q := q) - hq - -/-- Build a strictly positive logit-diff lower bound from exact head inputs. -/ -def buildInductionLogitLowerBoundNonvacuous? - (inputs : Model.InductionHeadInputs seq dModel dHead) : - Option (InductionLogitLowerBoundNonvacuous inputs) := by - classical - cases hbase : buildInductionLogitLowerBoundFromHead? inputs with - | none => exact none - | some base => - by_cases hpos : 0 < base.lb - · exact some ⟨base, hpos⟩ - · exact none - end WithNeZero /-! End-to-end lower bounds from head certificates plus residual intervals. -/ diff --git a/Nfp/Sound/Induction/Refine.lean b/Nfp/Sound/Induction/Refine.lean deleted file mode 100644 index 9f3b9ab..0000000 --- a/Nfp/Sound/Induction/Refine.lean +++ /dev/null @@ -1,546 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Sound.Induction.Core - -/-! -Refine-on-demand helpers for induction-head bounds. 
- -These definitions reuse cached core bounds to compute tightened score gaps and -weight bounds for selected query/key pairs without rebuilding the full cache. --/ - -public section - -namespace Nfp - -namespace Sound - -variable {seq dModel dHead : Nat} - -/-- Specification for refining per-key bounds. -/ -structure InductionHeadRefineSpec (seq : Nat) where - /-- Keys to refine for each query. -/ - refineKeys : Fin seq → Finset (Fin seq) - /-- Split budget for refined diff bounds. -/ - splitBudgetDiffRefined : Nat - -/-- Heuristic boost for refinement budgets. -/ -def refineBudgetBoost (budget : Nat) : Nat := - max (budget + 1) (2 * budget) - -/-- Unfolding lemma for `refineBudgetBoost`. -/ -theorem refineBudgetBoost_def (budget : Nat) : - refineBudgetBoost budget = max (budget + 1) (2 * budget) := by - rfl - -/-- Heuristic cap on the number of top-weight keys to refine. -/ -def refineTopWeightCount (budget : Nat) : Nat := - min 8 (max 1 (2 * budget)) - -/-- Unfolding lemma for `refineTopWeightCount`. -/ -theorem refineTopWeightCount_def (budget : Nat) : - refineTopWeightCount budget = min 8 (max 1 (2 * budget)) := by - rfl - -/-- Heuristic cap on the number of low-value keys to refine. -/ -def refineLowValueCount (budget : Nat) : Nat := - min 8 (max 1 (2 * budget)) - -/-- Unfolding lemma for `refineLowValueCount`. -/ -theorem refineLowValueCount_def (budget : Nat) : - refineLowValueCount budget = min 8 (max 1 (2 * budget)) := by - rfl - -/-- Scale used for refined value bounds. -/ -def valRefineScale (budget : Nat) : Nat := - Bounds.sqrtLowerScale * refineBudgetBoost budget - -/-- Unfolding lemma for `valRefineScale`. -/ -theorem valRefineScale_def (budget : Nat) : - valRefineScale budget = Bounds.sqrtLowerScale * refineBudgetBoost budget := by - rfl - -/-- Worst key under the base score-gap lower bound (excluding `prev`). 
-/ -def worstKeyBase - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : Option (Fin seq) := - let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := cache.scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (cache.scoreGapLoBase q k, k)).2 - -/-- Unfolding lemma for `worstKeyBase`. -/ -theorem worstKeyBase_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : - worstKeyBase inputs cache q = - let ks := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) - match ks with - | [] => none - | k :: ks => - let step (best : Rat × Fin seq) (k : Fin seq) := - let s := cache.scoreGapLoBase q k - if s ≤ best.1 then (s, k) else best - some (ks.foldl step (cache.scoreGapLoBase q k, k)).2 := by - rfl - -/-- Keys whose base weight bounds are already `1`. -/ -def weightOneKeysAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : Finset (Fin seq) := - let others : Finset (Fin seq) := - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - others.filter (fun k => decide (cache.weightBoundAt q k = (1 : Rat))) - -/-- Unfolding lemma for `weightOneKeysAt`. -/ -theorem weightOneKeysAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : - weightOneKeysAt inputs cache q = - let others : Finset (Fin seq) := - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - others.filter (fun k => decide (cache.weightBoundAt q k = (1 : Rat))) := by - rfl - -/-- Keys attaining the per-query lower-value minimum (excluding `prev`). 
-/ -def loAtKeysAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : Finset (Fin seq) := - let others : Finset (Fin seq) := - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - if h : others.Nonempty then - let lo := others.inf' h cache.valsLo - others.filter (fun k => decide (cache.valsLo k = lo)) - else - ∅ - -/-- Unfolding lemma for `loAtKeysAt`. -/ -theorem loAtKeysAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : - loAtKeysAt inputs cache q = - let others : Finset (Fin seq) := - (Finset.univ : Finset (Fin seq)).erase (inputs.prev q) - if h : others.Nonempty then - let lo := others.inf' h cache.valsLo - others.filter (fun k => decide (cache.valsLo k = lo)) - else - ∅ := by - rfl - -/-- Top-weight keys for a query (excluding `prev`), capped by `count`. -/ -def topWeightKeysAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (count : Nat) : Finset (Fin seq) := - if count = 0 then - ∅ - else - let others := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) - let weighted : Array (Rat × Fin seq) := - others.toArray.map (fun k => (cache.weightBoundAt q k, k)) - let sorted := weighted.qsort (fun a b => a.1 > b.1) - let keys := (sorted.toList.take count).map (fun p => p.2) - keys.foldr (fun k acc => insert k acc) ∅ - -/-- Unfolding lemma for `topWeightKeysAt`. 
-/ -theorem topWeightKeysAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (count : Nat) : - topWeightKeysAt inputs cache q count = - if count = 0 then - ∅ - else - let others := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) - let weighted : Array (Rat × Fin seq) := - others.toArray.map (fun k => (cache.weightBoundAt q k, k)) - let sorted := weighted.qsort (fun a b => a.1 > b.1) - let keys := (sorted.toList.take count).map (fun p => p.2) - keys.foldr (fun k acc => insert k acc) ∅ := by - rfl - -/-- Low-value keys for a query (excluding `prev`), capped by `count`. -/ -def lowValueKeysAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (count : Nat) : Finset (Fin seq) := - if count = 0 then - ∅ - else - let others := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) - let valued : Array (Rat × Fin seq) := - others.toArray.map (fun k => (cache.valsLo k, k)) - let sorted := valued.qsort (fun a b => a.1 < b.1) - let keys := (sorted.toList.take count).map (fun p => p.2) - keys.foldr (fun k acc => insert k acc) ∅ - -/-- Unfolding lemma for `lowValueKeysAt`. -/ -theorem lowValueKeysAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (count : Nat) : - lowValueKeysAt inputs cache q count = - if count = 0 then - ∅ - else - let others := (List.finRange seq).filter (fun k => decide (k ≠ inputs.prev q)) - let valued : Array (Rat × Fin seq) := - others.toArray.map (fun k => (cache.valsLo k, k)) - let sorted := valued.qsort (fun a b => a.1 < b.1) - let keys := (sorted.toList.take count).map (fun p => p.2) - keys.foldr (fun k acc => insert k acc) ∅ := by - rfl - -/-- Refinement keys for a query, seeded by negative base gaps and the worst key. 
-/ -def refineKeysAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : Finset (Fin seq) := - let neg := - (cache.otherKeys q).filter (fun k => decide (cache.scoreGapLoBase q k < 0)) - match worstKeyBase inputs cache q with - | none => neg - | some k => insert k neg - -/-- Unfolding lemma for `refineKeysAt`. -/ -theorem refineKeysAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) : - refineKeysAt inputs cache q = - let neg := - (cache.otherKeys q).filter (fun k => decide (cache.scoreGapLoBase q k < 0)) - match worstKeyBase inputs cache q with - | none => neg - | some k => insert k neg := by - rfl - -/-- Refinement keys that also include weight-one, `loAt`-minimizing, and top-weight keys. -/ -def refineKeysAtWithWeightOnes - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (budget : Nat) : Finset (Fin seq) := - let topCount := refineTopWeightCount budget - refineKeysAt inputs cache q ∪ - weightOneKeysAt inputs cache q ∪ - loAtKeysAt inputs cache q ∪ - topWeightKeysAt inputs cache q topCount - -/-- Unfolding lemma for `refineKeysAtWithWeightOnes`. -/ -theorem refineKeysAtWithWeightOnes_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (budget : Nat) : - refineKeysAtWithWeightOnes inputs cache q budget = - let topCount := refineTopWeightCount budget - refineKeysAt inputs cache q ∪ - weightOneKeysAt inputs cache q ∪ - loAtKeysAt inputs cache q ∪ - topWeightKeysAt inputs cache q topCount := by - rfl - -/-- Refinement spec focused on a single query. 
-/ -def refineSpecForQuery - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (budget : Nat) : InductionHeadRefineSpec seq := - let keys := refineKeysAt inputs cache q - { refineKeys := fun q' => if _ : q' = q then keys else ∅ - splitBudgetDiffRefined := budget } - -/-- Unfolding lemma for `refineSpecForQuery`. -/ -theorem refineSpecForQuery_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (budget : Nat) : - refineSpecForQuery inputs cache q budget = - let keys := refineKeysAt inputs cache q - { refineKeys := fun q' => if _ : q' = q then keys else ∅ - splitBudgetDiffRefined := budget } := by - rfl - -/-- Refinement spec for a single query, including weight-one, `loAt`-minimizing, and top-weight -keys. -/ -def refineSpecForQueryWithWeightOnes - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (budget : Nat) : InductionHeadRefineSpec seq := - let keys := refineKeysAtWithWeightOnes inputs cache q budget - { refineKeys := fun q' => if _ : q' = q then keys else ∅ - splitBudgetDiffRefined := budget } - -/-- Unfolding lemma for `refineSpecForQueryWithWeightOnes`. -/ -theorem refineSpecForQueryWithWeightOnes_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (q : Fin seq) (budget : Nat) : - refineSpecForQueryWithWeightOnes inputs cache q budget = - let keys := refineKeysAtWithWeightOnes inputs cache q budget - { refineKeys := fun q' => if _ : q' = q then keys else ∅ - splitBudgetDiffRefined := budget } := by - rfl - -/-- Refined value lower bound at a single key (fallbacks to base bounds if disabled). 
-/ -def valsLoRefinedAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (k : Fin seq) : Rat := - let scale := valRefineScale budget - if _ : 0 < inputs.lnEps then - if _ : 0 < Bounds.sqrtLowerWithScale scale inputs.lnEps then - if _ : dModel = 0 then - cache.cert.values.valsLo k - else - let dirHead : Fin dHead → Rat := fun d => (dirHeadVecOfInputs inputs).get d - let wvDir : Fin dModel → Rat := - fun j => Linear.dotFin dHead dirHead (fun d => inputs.wv j d) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let lnBounds := - Bounds.layerNormBoundsWithScale scale inputs.lnEps inputs.ln1Gamma inputs.ln1Beta - (inputs.embed k) - let lnLo := lnBounds.1 - let lnHi := lnBounds.2 - bDir + Bounds.dotIntervalLower wvDir lnLo lnHi - else - cache.cert.values.valsLo k - else - cache.cert.values.valsLo k - -/-- Unfolding lemma for `valsLoRefinedAt`. -/ -theorem valsLoRefinedAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (k : Fin seq) : - valsLoRefinedAt inputs cache budget k = - let scale := valRefineScale budget - if _ : 0 < inputs.lnEps then - if _ : 0 < Bounds.sqrtLowerWithScale scale inputs.lnEps then - if _ : dModel = 0 then - cache.cert.values.valsLo k - else - let dirHead : Fin dHead → Rat := fun d => (dirHeadVecOfInputs inputs).get d - let wvDir : Fin dModel → Rat := - fun j => Linear.dotFin dHead dirHead (fun d => inputs.wv j d) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let lnBounds := - Bounds.layerNormBoundsWithScale scale inputs.lnEps inputs.ln1Gamma inputs.ln1Beta - (inputs.embed k) - let lnLo := lnBounds.1 - let lnHi := lnBounds.2 - bDir + Bounds.dotIntervalLower wvDir lnLo lnHi - else - cache.cert.values.valsLo k - else - cache.cert.values.valsLo k := by - rfl - -/-- Overlay refined value lower bounds on a subset of keys. 
-/ -def valsLoOverlay - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (refineKeys : Finset (Fin seq)) : Fin seq → Rat := fun k => - if k ∈ refineKeys then - valsLoRefinedAt inputs cache budget k - else - cache.cert.values.valsLo k - -/-- Unfolding lemma for `valsLoOverlay`. -/ -theorem valsLoOverlay_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (refineKeys : Finset (Fin seq)) (k : Fin seq) : - valsLoOverlay inputs cache budget refineKeys k = - if k ∈ refineKeys then - valsLoRefinedAt inputs cache budget k - else - cache.cert.values.valsLo k := by - rfl - -/-- Refined diff dot-product lower bound at a single `(q,k)` pair. -/ -def dotDiffLoRefinedAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (q k : Fin seq) : Rat := - let dimsQ := cache.splitDimsQ q - let dimsDiff := cache.splitDimsDiffCore budget q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => cache.qLo q d) (fun d => cache.qHi q d) - (fun d => cache.kLo prev d - cache.kHi k d) - (fun d => cache.kHi prev d - cache.kLo k d)).1 - -/-- Unfolding lemma for `dotDiffLoRefinedAt`. 
-/ -theorem dotDiffLoRefinedAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (q k : Fin seq) : - dotDiffLoRefinedAt inputs cache budget q k = - let dimsQ := cache.splitDimsQ q - let dimsDiff := cache.splitDimsDiffCore budget q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => cache.qLo q d) (fun d => cache.qHi q d) - (fun d => cache.kLo prev d - cache.kHi k d) - (fun d => cache.kHi prev d - cache.kLo k d)).1 := by - rfl - -/-- Refined diff dot-product upper bound at a single `(q,k)` pair. -/ -def dotDiffHiRefinedAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (q k : Fin seq) : Rat := - let dimsQ := cache.splitDimsQ q - let dimsDiff := cache.splitDimsDiffCore budget q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => cache.qLo q d) (fun d => cache.qHi q d) - (fun d => cache.kLo prev d - cache.kHi k d) - (fun d => cache.kHi prev d - cache.kLo k d)).2 - -/-- Unfolding lemma for `dotDiffHiRefinedAt`. -/ -theorem dotDiffHiRefinedAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (q k : Fin seq) : - dotDiffHiRefinedAt inputs cache budget q k = - let dimsQ := cache.splitDimsQ q - let dimsDiff := cache.splitDimsDiffCore budget q k - let prev := inputs.prev q - (_root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth dimsQ dimsDiff - (fun d => cache.qLo q d) (fun d => cache.qHi q d) - (fun d => cache.kLo prev d - cache.kHi k d) - (fun d => cache.kHi prev d - cache.kLo k d)).2 := by - rfl - -/-- Refined score-gap lower bound at `(q,k)` using a custom diff budget. 
-/ -def scoreGapLoRefinedAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (q k : Fin seq) : Rat := - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - if masked q (inputs.prev q) then - cache.scoreLoPrev q - cache.scoreHi q k - else if masked q k then - cache.scoreLoPrev q - inputs.maskValue - else if _ : 0 ≤ inputs.scale then - inputs.scale * dotDiffLoRefinedAt inputs cache budget q k - else - inputs.scale * dotDiffHiRefinedAt inputs cache budget q k - -/-- Unfolding lemma for `scoreGapLoRefinedAt`. -/ -theorem scoreGapLoRefinedAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (q k : Fin seq) : - scoreGapLoRefinedAt inputs cache budget q k = - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - if masked q (inputs.prev q) then - cache.scoreLoPrev q - cache.scoreHi q k - else if masked q k then - cache.scoreLoPrev q - inputs.maskValue - else if _ : 0 ≤ inputs.scale then - inputs.scale * dotDiffLoRefinedAt inputs cache budget q k - else - inputs.scale * dotDiffHiRefinedAt inputs cache budget q k := by - rfl - -/-- Refined per-key weight bound at `(q,k)` derived from refined score gaps. -/ -def weightBoundAtRefinedAt - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (q k : Fin seq) : Rat := - if _ : k = inputs.prev q then - (0 : Rat) - else - let gap := scoreGapLoRefinedAt inputs cache budget q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap) - -/-- Unfolding lemma for `weightBoundAtRefinedAt`. 
-/ -theorem weightBoundAtRefinedAt_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (q k : Fin seq) : - weightBoundAtRefinedAt inputs cache budget q k = - if _ : k = inputs.prev q then - (0 : Rat) - else - let gap := scoreGapLoRefinedAt inputs cache budget q k - if gap < 0 then - (1 : Rat) - else - ratDivUp 1 (1 + gap) := by - rfl - -/-- Overlay that refines only selected `(q,k)` weight bounds. -/ -def weightBoundAtOverlay - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (spec : InductionHeadRefineSpec seq) : - Fin seq → Fin seq → Rat := fun q k => - if _ : k = inputs.prev q then - (0 : Rat) - else if _ : k ∈ spec.refineKeys q then - weightBoundAtRefinedAt inputs cache spec.splitBudgetDiffRefined q k - else - cache.weightBoundAt q k - -/-- Unfolding lemma for `weightBoundAtOverlay`. -/ -theorem weightBoundAtOverlay_def - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (spec : InductionHeadRefineSpec seq) - (q k : Fin seq) : - weightBoundAtOverlay inputs cache spec q k = - if _ : k = inputs.prev q then - (0 : Rat) - else if _ : k ∈ spec.refineKeys q then - weightBoundAtRefinedAt inputs cache spec.splitBudgetDiffRefined q k - else - cache.weightBoundAt q k := by - rfl - -/-- Overlayed eps bound derived from overlayed per-key bounds. -/ -def epsAtOverlay - (cache : InductionHeadCoreCache seq dModel dHead) - (weightBoundAt : Fin seq → Fin seq → Rat) : - Fin seq → Rat := fun q => - let other : Finset (Fin seq) := - (Finset.univ : Finset (Fin seq)).erase (cache.cert.prev q) - let total := other.sum (fun k => weightBoundAt q k) - min (1 : Rat) total - -/-- Unfolding lemma for `epsAtOverlay`. 
-/ -theorem epsAtOverlay_def - (cache : InductionHeadCoreCache seq dModel dHead) - (weightBoundAt : Fin seq → Fin seq → Rat) - (q : Fin seq) : - epsAtOverlay cache weightBoundAt q = - let other : Finset (Fin seq) := - (Finset.univ : Finset (Fin seq)).erase (cache.cert.prev q) - let total := other.sum (fun k => weightBoundAt q k) - min (1 : Rat) total := by - rfl - -end Sound - -end Nfp diff --git a/Nfp/Sound/Induction/RefineSound.lean b/Nfp/Sound/Induction/RefineSound.lean deleted file mode 100644 index 9a09962..0000000 --- a/Nfp/Sound/Induction/RefineSound.lean +++ /dev/null @@ -1,907 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Sound.Induction.LogitDiffSound -public import Nfp.Sound.Induction.OneHot -public import Nfp.Sound.Induction.Refine -public import Nfp.Sound.Induction.CoreSound.Values - -/-! -Soundness lemmas for refine-on-demand overlays. --/ - -public section - -namespace Nfp - -namespace Sound - -open Nfp.Circuit -open Nfp.Sound.Bounds - -variable {seq dModel dHead : Nat} - -/-- Refined score-gap bounds are sound when cache score and KV bounds are sound. 
-/ -theorem scoreGapLoRefinedAt_real_at_of_bounds - [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) - (active : Finset (Fin seq)) - (hq_bounds : - ∀ q d, (cache.qLo q d : Real) ≤ qRealOfInputs inputs q d ∧ - qRealOfInputs inputs q d ≤ (cache.qHi q d : Real)) - (hk_bounds : - ∀ q d, (cache.kLo q d : Real) ≤ kRealOfInputs inputs q d ∧ - kRealOfInputs inputs q d ≤ (cache.kHi q d : Real)) - (hscore_prev : - ∀ q, q ∈ active → - (cache.scoreLoPrev q : Real) ≤ scoresRealOfInputs inputs q (inputs.prev q)) - (hscore_hi : - ∀ q k, scoresRealOfInputs inputs q k ≤ (cache.scoreHi q k : Real)) : - ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → - scoresRealOfInputs inputs q k + - (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ - scoresRealOfInputs inputs q (inputs.prev q) := by - classical - let scoresReal := scoresRealOfInputs inputs - let masked : Fin seq → Fin seq → Prop := fun q k => - inputs.maskCausal = true ∧ q < k - have scoresReal_eq_base_of_not_masked : - ∀ q k, ¬ masked q k → - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - intro q k hnot - by_cases hcausal : inputs.maskCausal - · have hnot_lt : ¬ q < k := by - intro hlt - exact hnot ⟨hcausal, hlt⟩ - have hle : k ≤ q := le_of_not_gt hnot_lt - simp [scoresReal, scoresRealOfInputs_def, hcausal, hle] - · simp [scoresReal, scoresRealOfInputs_def, hcausal] - have scoresReal_eq_masked : - ∀ q k, masked q k → scoresReal q k = (inputs.maskValue : Real) := by - intro q k hmask - have hmask' : inputs.maskCausal = true ∧ q < k := by - simpa [masked] using hmask - have hle : ¬ k ≤ q := not_le_of_gt hmask'.2 - simp [scoresReal, scoresRealOfInputs_def, hmask'.1, hle] - have hdot_diff_bounds : - ∀ q k, - (dotDiffLoRefinedAt inputs cache budget q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d 
- - kRealOfInputs inputs k d) ∧ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ - (dotDiffHiRefinedAt inputs cache budget q k : Real) := by - intro q k - have hlo1 : ∀ d, (cache.qLo q d : Real) ≤ qRealOfInputs inputs q d := fun d => - (hq_bounds q d).1 - have hhi1 : ∀ d, qRealOfInputs inputs q d ≤ (cache.qHi q d : Real) := fun d => - (hq_bounds q d).2 - have hlo2 : - ∀ d, - (cache.kLo (inputs.prev q) d - cache.kHi k d : Rat) ≤ - (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) := by - intro d - have hprev_lo := (hk_bounds (inputs.prev q) d).1 - have hk_hi := (hk_bounds k d).2 - have h := sub_le_sub hprev_lo hk_hi - simpa [ratToReal_sub] using h - have hhi2 : - ∀ d, - (kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) ≤ - (cache.kHi (inputs.prev q) d - cache.kLo k d : Rat) := by - intro d - have hprev_hi := (hk_bounds (inputs.prev q) d).2 - have hk_lo := (hk_bounds k d).1 - have h := sub_le_sub hprev_hi hk_lo - simpa [ratToReal_sub] using h - have hspec := - _root_.Nfp.Sound.Bounds.dotIntervalLowerUpper2SignSplitBoth_spec_real - (dims1 := cache.splitDimsQ q) (dims2 := cache.splitDimsDiffCore budget q k) - (lo1 := fun d => cache.qLo q d) (hi1 := fun d => cache.qHi q d) - (lo2 := fun d => cache.kLo (inputs.prev q) d - cache.kHi k d) - (hi2 := fun d => cache.kHi (inputs.prev q) d - cache.kLo k d) - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => - kRealOfInputs inputs (inputs.prev q) d - kRealOfInputs inputs k d) - hlo1 hhi1 hlo2 hhi2 - have hlow' : - (dotDiffLoRefinedAt inputs cache budget q k : Real) ≤ - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simpa [dotDiffLoRefinedAt_def] using hspec.1 - have hhigh' : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) ≤ 
(dotDiffHiRefinedAt inputs cache budget q k : Real) := by - simpa [dotDiffHiRefinedAt_def] using hspec.2 - exact ⟨hlow', hhigh'⟩ - intro q hq k hk - by_cases hprevmask : masked q (inputs.prev q) - · have hscore_hi' : scoresReal q k ≤ (cache.scoreHi q k : Real) := - hscore_hi q k - have hscore_prev' : (cache.scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := - hscore_prev q hq - have hsum_le' : - (cache.scoreLoPrev q : Real) - (cache.scoreHi q k : Real) + scoresReal q k ≤ - (cache.scoreLoPrev q : Real) := by - have hsub : - (cache.scoreLoPrev q : Real) - (cache.scoreHi q k : Real) ≤ - (cache.scoreLoPrev q : Real) - scoresReal q k := - sub_le_sub_left hscore_hi' (cache.scoreLoPrev q : Real) - calc - (cache.scoreLoPrev q : Real) - (cache.scoreHi q k : Real) + scoresReal q k - ≤ (cache.scoreLoPrev q : Real) - scoresReal q k + scoresReal q k := by - simpa [add_comm, add_left_comm, add_assoc] using - (add_le_add_left hsub (scoresReal q k)) - _ = (cache.scoreLoPrev q : Real) := by - simp [sub_add_cancel] - calc - scoresReal q k + (scoreGapLoRefinedAt inputs cache budget q k : Real) - = (cache.scoreLoPrev q : Real) - (cache.scoreHi q k : Real) + scoresReal q k := by - simp [scoreGapLoRefinedAt_def, hprevmask, masked, add_comm] - _ ≤ (cache.scoreLoPrev q : Real) := hsum_le' - _ ≤ scoresReal q (inputs.prev q) := hscore_prev' - · by_cases hmask : masked q k - · have hscore_prev' : (cache.scoreLoPrev q : Real) ≤ scoresReal q (inputs.prev q) := - hscore_prev q hq - have hscore_k : scoresReal q k = (inputs.maskValue : Real) := - scoresReal_eq_masked q k hmask - have hmask' : inputs.maskCausal = true ∧ q < k := by - simpa [masked] using hmask - have hnot_lt_prev : ¬ q < inputs.prev q := by - intro hlt - exact hprevmask ⟨hmask'.1, hlt⟩ - calc - scoresReal q k + (scoreGapLoRefinedAt inputs cache budget q k : Real) - = (inputs.maskValue : Real) + (cache.scoreLoPrev q : Real) - - (inputs.maskValue : Real) := by - simp [scoreGapLoRefinedAt_def, hmask', hnot_lt_prev, hscore_k] - _ 
= (cache.scoreLoPrev q : Real) := by - simp [add_sub_cancel_left] - _ ≤ scoresReal q (inputs.prev q) := hscore_prev' - · have hdiff := hdot_diff_bounds q k - have hgap_le : - (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - by_cases hscale : 0 ≤ inputs.scale - · have hscale_real : 0 ≤ (inputs.scale : Real) := by - simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hscale - have hle := mul_le_mul_of_nonneg_left hdiff.1 hscale_real - simpa [scoreGapLoRefinedAt_def, hprevmask, hmask, hscale, masked] using hle - · have hscale_nonpos : inputs.scale ≤ 0 := - le_of_lt (lt_of_not_ge hscale) - have hscale_real : (inputs.scale : Real) ≤ 0 := by - simpa [ratToReal_def] using - (ratToReal_nonpos_iff (x := inputs.scale)).2 hscale_nonpos - have hle := mul_le_mul_of_nonpos_left hdiff.2 hscale_real - simpa [scoreGapLoRefinedAt_def, hprevmask, hmask, hscale, masked] using hle - have hscore_prev : - scoresReal q (inputs.prev q) = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) := by - simpa using - (scoresReal_eq_base_of_not_masked q (inputs.prev q) hprevmask) - have hscore_k : - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simpa using (scoresReal_eq_base_of_not_masked q k hmask) - have hdot_sub : - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) = - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - classical - simpa using - (Nfp.Sound.Linear.dotProduct_sub_right - (x := fun d => qRealOfInputs inputs q d) - (y := fun d => 
kRealOfInputs inputs (inputs.prev q) d) - (z := fun d => kRealOfInputs inputs k d)) - have hscore_diff : - scoresReal q (inputs.prev q) - scoresReal q k = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - calc - scoresReal q (inputs.prev q) - scoresReal q k - = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d) := by - simp [hscore_prev, hscore_k] - _ = - (inputs.scale : Real) * - (dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d) - - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs k d)) := by - simp [mul_sub] - _ = - (inputs.scale : Real) * - dotProduct (fun d => qRealOfInputs inputs q d) - (fun d => kRealOfInputs inputs (inputs.prev q) d - - kRealOfInputs inputs k d) := by - simp [hdot_sub] - have hgap_le' : - (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ - scoresReal q (inputs.prev q) - scoresReal q k := by - simpa [hscore_diff] using hgap_le - have hgap_add := add_le_add_right hgap_le' (scoresReal q k) - have hgap_add' : - scoresReal q k + (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ - scoresReal q (inputs.prev q) := by - have hcancel : - scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) = - scoresReal q (inputs.prev q) := by - calc - scoresReal q k + (scoresReal q (inputs.prev q) - scoresReal q k) - = - scoresReal q k + scoresReal q (inputs.prev q) - - scoresReal q k := by - symm - exact add_sub_assoc (scoresReal q k) - (scoresReal q (inputs.prev q)) (scoresReal q k) - _ = scoresReal q (inputs.prev q) := by - simp [add_sub_cancel_left] - calc - scoresReal q k + (scoreGapLoRefinedAt inputs cache budget q k : Real) - ≤ scoresReal q k + (scoresReal q 
(inputs.prev q) - scoresReal q k) := hgap_add - _ = scoresReal q (inputs.prev q) := hcancel - exact hgap_add' - -/-- Refined per-key weight bounds are sound when refined score gaps are sound. -/ -theorem weight_bound_at_refinedAt_of_scoreGapLo - [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) - (active : Finset (Fin seq)) - (hscore_gap_real_at : - ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → - scoresRealOfInputs inputs q k + - (scoreGapLoRefinedAt inputs cache budget q k : Real) ≤ - scoresRealOfInputs inputs q (inputs.prev q)) : - ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtRefinedAt inputs cache budget q k : Real) := by - classical - intro q hq k hk - refine - Sound.weight_bound_at_of_scoreGapLo - (active := active) - (prev := inputs.prev) - (scoresReal := scoresRealOfInputs inputs) - (scoreGapLo := scoreGapLoRefinedAt inputs cache budget) - (weightBoundAt := weightBoundAtRefinedAt inputs cache budget) - (hweightBoundAt := ?_) - (hscore_gap_real_at := hscore_gap_real_at) - q hq k hk - intro q' k' hk' - simp [weightBoundAtRefinedAt_def, hk'] - -/-- Overlayed per-key bounds are sound when base and refined bounds are sound. 
-/ -theorem weight_bounds_at_overlay_of_refined - (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (spec : InductionHeadRefineSpec seq) - (active : Finset (Fin seq)) - (hbase : - ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (cache.weightBoundAt q k : Real)) - (hrefine : - ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → k ∈ spec.refineKeys q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtRefinedAt inputs cache spec.splitBudgetDiffRefined q k : Real)) : - ∀ q, q ∈ active → ∀ k, k ≠ inputs.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real) := by - classical - intro q hq k hk - by_cases hmem : k ∈ spec.refineKeys q - · have h := hrefine q hq k hk hmem - simpa [weightBoundAtOverlay_def, hk, hmem] using h - · have h := hbase q hq k hk - simpa [weightBoundAtOverlay_def, hk, hmem] using h - -/-- One-hot bounds derived from an overlayed per-key bound. 
-/ -theorem oneHot_bounds_at_overlay - [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) - (hcert : c = cache.cert) - (spec : InductionHeadRefineSpec seq) - (hweight_overlay : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real)) : - ∀ q, q ∈ c.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) - (epsAtOverlay cache (weightBoundAtOverlay inputs cache spec) q : Real) - (fun q' => q' = q) c.prev - (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) := by - classical - intro q hq - refine - Sound.oneHot_bounds_at_of_weight_bounds - (active := c.active) - (prev := c.prev) - (scoresReal := scoresRealOfInputs inputs) - (weightBoundAt := weightBoundAtOverlay inputs cache spec) - (epsAt := epsAtOverlay cache (weightBoundAtOverlay inputs cache spec)) - (hepsAt := ?_) - (hweight_bounds := ?_) q hq - · intro q' - cases hcert - simp [epsAtOverlay_def] - · intro q' hq' k hk - exact hweight_overlay q' hq' k hk - -/-- Refined value lower bounds are sound when LayerNorm bounds are sound. 
-/ -theorem valsLoRefinedAt_le_valsReal - [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) - (hsound : InductionHeadCertSound inputs cache.cert) - (k : Fin seq) : - (valsLoRefinedAt inputs cache budget k : Real) ≤ valsRealOfInputs inputs k := by - classical - by_cases hEps : 0 < inputs.lnEps - · by_cases hSqrt : - 0 < Bounds.sqrtLowerWithScale (valRefineScale budget) inputs.lnEps - · by_cases hmodel : dModel = 0 - · have hvals := hsound.value_bounds.vals_bounds k - simpa [valsLoRefinedAt_def, hEps, hSqrt, hmodel] using hvals.1 - · let scale : Nat := valRefineScale budget - let dirHead : Fin dHead → Rat := fun d => (dirHeadVecOfInputs inputs).get d - let wvDir : Fin dModel → Rat := - fun j => Linear.dotFin dHead dirHead (fun d => inputs.wv j d) - let bDir : Rat := - Linear.dotFin dHead dirHead (fun d => inputs.bv d) - let lnBounds : - Fin seq → (Fin dModel → Rat) × (Fin dModel → Rat) := fun k' => - Bounds.layerNormBoundsWithScale scale inputs.lnEps inputs.ln1Gamma inputs.ln1Beta - (inputs.embed k') - let lnLo : Fin seq → Fin dModel → Rat := fun k' => (lnBounds k').1 - let lnHi : Fin seq → Fin dModel → Rat := fun k' => (lnBounds k').2 - let valsLo : Fin seq → Rat := fun k' => - bDir + Bounds.dotIntervalLower wvDir (lnLo k') (lnHi k') - let valsHi : Fin seq → Rat := fun k' => - bDir + Bounds.dotIntervalUpper wvDir (lnLo k') (lnHi k') - have hscale_pos : 0 < scale := by - have hbase : 0 < Bounds.sqrtLowerScale := by - simp [Bounds.sqrtLowerScale_def] - have hboost : 0 < refineBudgetBoost budget := by - have hle : budget + 1 ≤ refineBudgetBoost budget := by - simp [refineBudgetBoost_def] - exact lt_of_lt_of_le (Nat.succ_pos budget) hle - simpa [scale, valRefineScale_def] using Nat.mul_pos hbase hboost - have hln : - ∀ k' j, (lnLo k' j : Real) ≤ lnRealOfInputs inputs k' j ∧ - lnRealOfInputs inputs k' j ≤ (lnHi k' j : Real) := by - intro k' j - have hln' := - 
Bounds.layerNormBoundsWithScale_spec (scale := scale) - (eps := inputs.lnEps) (gamma := inputs.ln1Gamma) (beta := inputs.ln1Beta) - (x := inputs.embed k') hmodel hEps hSqrt hscale_pos - simpa [lnBounds, lnLo, lnHi, lnRealOfInputs_def] using hln' j - have hvals := - valsReal_bounds_at_of_ln_bounds (inputs := inputs) - (dirHead := dirHead) (hdirHead := rfl) - (wvDir := wvDir) (bDir := bDir) - (hwvDir := by intro j; rfl) - (hbDir := by rfl) - (lnLo := lnLo) (lnHi := lnHi) - (valsLo := valsLo) (valsHi := valsHi) - (hvalsLo := by intro k'; rfl) - (hvalsHi := by intro k'; rfl) - (hln := hln) - have hvals_k := (hvals k).1 - simpa [valsLoRefinedAt_def, hEps, hSqrt, hmodel, scale, lnBounds, lnLo, lnHi, - valsLo, wvDir, bDir] using hvals_k - · have hvals := hsound.value_bounds.vals_bounds k - simpa [valsLoRefinedAt_def, hEps, hSqrt] using hvals.1 - · have hvals := hsound.value_bounds.vals_bounds k - simpa [valsLoRefinedAt_def, hEps] using hvals.1 - -/-- Overlayed value lower bounds remain sound. -/ -theorem valsLoOverlay_le_valsReal - [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (budget : Nat) (refineKeys : Finset (Fin seq)) - (hsound : InductionHeadCertSound inputs cache.cert) : - ∀ k, (valsLoOverlay inputs cache budget refineKeys k : Real) ≤ valsRealOfInputs inputs k := by - intro k - by_cases hmem : k ∈ refineKeys - · simpa [valsLoOverlay_def, hmem] using - (valsLoRefinedAt_le_valsReal (inputs := inputs) (cache := cache) - (budget := budget) (hsound := hsound) k) - · have hvals := hsound.value_bounds.vals_bounds k - simpa [valsLoOverlay_def, hmem] using hvals.1 - -/-- The refined unweighted logit-diff lower bound is sound on active queries. 
-/ -theorem logitDiffLowerBoundRefinedFromCache_le - [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (logitCache : LogitDiffCache seq) - (spec : InductionHeadRefineSpec seq) - (hcert : c = cache.cert) - (hcache : logitCache = logitDiffCache c) - (hsound : InductionHeadCertSound inputs c) - (hweight_overlay : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real)) - {lb : Rat} - (hbound : logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec = some lb) - {q : Fin seq} (hq : q ∈ c.active) : - (lb : Real) ≤ headLogitDiff inputs q := by - classical - have honeHot : - ∀ q, q ∈ c.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) - (epsAtOverlay cache (weightBoundAtOverlay inputs cache spec) q : Real) - (fun q' => q' = q) c.prev - (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) := by - intro q hq - exact oneHot_bounds_at_overlay (inputs := inputs) (cache := cache) (c := c) (hcert := hcert) - (spec := spec) (hweight_overlay := hweight_overlay) q hq - have hbound' : - logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) - (epsAtOverlay cache (weightBoundAtOverlay inputs cache spec)) = some lb := by - simpa [logitDiffLowerBoundRefinedFromCache_def, hcache] using hbound - exact - logitDiffLowerBoundFromCacheWithEps_le - (inputs := inputs) - (c := c) - (epsAtCustom := epsAtOverlay cache (weightBoundAtOverlay inputs cache spec)) - (hsound := hsound) - (honeHot := honeHot) - (hbound := hbound') - (hq := hq) - -/-- Refine-on-demand logit-diff lower bound using a supplied refinement spec is sound. 
-/ -theorem logitDiffLowerBoundRefineOnDemandWithSpec_le - [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (logitCache : LogitDiffCache seq) - (spec : InductionHeadRefineSpec seq) - (hcert : c = cache.cert) - (hcache : logitCache = logitDiffCache c) - (hsound : InductionHeadCertSound inputs c) - (hweight_overlay : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real)) - {lb : Rat} - (hbound : - logitDiffLowerBoundRefineOnDemandWithSpec inputs cache c logitCache spec = some lb) - {q : Fin seq} (hq : q ∈ c.active) : - (lb : Real) ≤ headLogitDiff inputs q := by - classical - have honeHot : - ∀ q, q ∈ c.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) - ((logitDiffCache c).epsAt q : Real) - (fun q' => q' = q) c.prev - (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) := by - intro q hq - have h := hsound.oneHot_bounds_at q hq - have heps : (logitDiffCache c).epsAt q = c.epsAt q := by - simp [logitDiffCache_def, Bounds.cacheBoundTask_apply] - simpa [heps] using h - have hbase_le : - ∀ {lb0 : Rat}, - logitDiffLowerBoundFromCache c logitCache = some lb0 → - (lb0 : Real) ≤ headLogitDiff inputs q := by - intro lb0 hbound0 - have hbound0' : - logitDiffLowerBoundFromCache c (logitDiffCache c) = some lb0 := by - simpa [hcache] using hbound0 - have hbound0'' : - logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) - (logitDiffCache c).epsAt = some lb0 := by - simpa [logitDiffLowerBoundFromCache_eq_withEps] using hbound0' - exact - logitDiffLowerBoundFromCacheWithEps_le - (inputs := inputs) - (c := c) - (epsAtCustom := (logitDiffCache c).epsAt) - (hsound := hsound) - (honeHot := honeHot) - (hbound := hbound0'') - (hq := hq) - cases h0 : logitDiffLowerBoundFromCache c logitCache with - | none => - simp [logitDiffLowerBoundRefineOnDemandWithSpec_def, h0] at 
hbound - | some lb0 => - by_cases hnonpos : lb0 ≤ 0 - · cases h1 : logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec with - | none => - have hlb : lb = lb0 := by - simpa [logitDiffLowerBoundRefineOnDemandWithSpec_def, h0, hnonpos, h1] using - hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - simpa [hlb] using hbase - | some lb1 => - have hlb : lb = max lb0 lb1 := by - simpa [logitDiffLowerBoundRefineOnDemandWithSpec_def, h0, hnonpos, h1] using - hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - have hrefine := - logitDiffLowerBoundRefinedFromCache_le - (inputs := inputs) - (cache := cache) - (c := c) - (logitCache := logitCache) - (spec := spec) - (hcert := hcert) - (hcache := hcache) - (hsound := hsound) - (hweight_overlay := hweight_overlay) - (hbound := h1) - (hq := hq) - have hmax' : - max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hbase, hrefine⟩ - have hmax : (max lb0 lb1 : Real) ≤ headLogitDiff inputs q := by - simpa [ratToReal_max] using hmax' - simpa [hlb] using hmax - · have hlb : lb = lb0 := by - simpa [logitDiffLowerBoundRefineOnDemandWithSpec_def, h0, hnonpos] using hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - simpa [hlb] using hbase - -/-- Refine-on-demand logit-diff lower bound using argmin refinement keys is sound. 
-/ -theorem logitDiffLowerBoundRefineOnDemand_le - [NeZero seq] (inputs : Model.InductionHeadInputs seq dModel dHead) - (cache : InductionHeadCoreCache seq dModel dHead) - (c : InductionHeadCert seq) (logitCache : LogitDiffCache seq) - (hcert : c = cache.cert) - (hcache : logitCache = logitDiffCache c) - (hsound : InductionHeadCertSound inputs c) - (hweight_overlay : - ∀ q0 : Fin seq, ∀ refineBudget : Nat, - let spec := refineSpecForQueryWithWeightOnes inputs cache q0 refineBudget - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real)) - {lb : Rat} - (hbound : logitDiffLowerBoundRefineOnDemand inputs cache c logitCache = some lb) - {q : Fin seq} (hq : q ∈ c.active) : - (lb : Real) ≤ headLogitDiff inputs q := by - classical - have hbase_le : - ∀ {lb0 : Rat}, - logitDiffLowerBoundFromCache c logitCache = some lb0 → - (lb0 : Real) ≤ headLogitDiff inputs q := by - intro lb0 hbound0 - have hbound0' : - logitDiffLowerBoundFromCache c (logitDiffCache c) = some lb0 := by - simpa [hcache] using hbound0 - have hbound0'' : - logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) - (logitDiffCache c).epsAt = some lb0 := by - simpa [logitDiffLowerBoundFromCache_eq_withEps] using hbound0' - exact - logitDiffLowerBoundFromCacheWithEps_le - (inputs := inputs) - (c := c) - (epsAtCustom := (logitDiffCache c).epsAt) - (hsound := hsound) - (honeHot := by - intro q' hq' - have h := hsound.oneHot_bounds_at q' hq' - have heps : (logitDiffCache c).epsAt q' = c.epsAt q' := by - simp [logitDiffCache_def, Bounds.cacheBoundTask_apply] - simpa [heps] using h) - (hbound := hbound0'') - (hq := hq) - cases h0 : logitDiffLowerBoundFromCache c logitCache with - | none => - simp [logitDiffLowerBoundRefineOnDemand_def, h0] at hbound - | some lb0 => - by_cases hnonpos : lb0 ≤ 0 - · cases hargmin : logitDiffLowerBoundArgminFromCache c logitCache with - | none => - have hlb : lb = lb0 := by - simpa 
[logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin] using - hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - simpa [hlb] using hbase - | some q0 => - let refineBudget := max 1 cache.splitBudgetDiffRefined - let spec := refineSpecForQueryWithWeightOnes inputs cache q0 refineBudget - cases h1 : - logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec with - | none => - have hlb : lb = lb0 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, - spec, refineBudget] using hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - simpa [hlb] using hbase - | some lb1 => - let lb01 := max lb0 lb1 - by_cases hnonpos1 : lb01 ≤ 0 - · let refineBudget' := refineBudgetBoost refineBudget - let spec' := refineSpecForQueryWithWeightOnes inputs cache q0 refineBudget' - cases h2 : - logitDiffLowerBoundRefinedFromCache inputs cache c logitCache spec' with - | none => - let valBudget := refineBudgetBoost refineBudget - let valCount := refineLowValueCount refineBudget - let valKeys := - loAtKeysAt inputs cache q0 ∪ - lowValueKeysAt inputs cache q0 valCount - let valsLo := valsLoOverlay inputs cache valBudget valKeys - cases hval : - logitDiffLowerBoundFromCacheWithEpsVals c logitCache.epsAt valsLo with - | none => - have hlb : lb = lb01 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, - h1, spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', - valBudget, valCount, valKeys, valsLo, hval] using hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - have hweight_overlay' : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real) := by - simpa [spec, refineBudget] using hweight_overlay q0 refineBudget - have hrefine := - logitDiffLowerBoundRefinedFromCache_le - (inputs := inputs) - (cache := cache) - (c := c) - (logitCache := logitCache) - (spec := spec) - (hcert := hcert) - (hcache := hcache) - (hsound := hsound) 
- (hweight_overlay := hweight_overlay') - (hbound := h1) - (hq := hq) - have hmax' : - max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hbase, hrefine⟩ - have hmax : (lb01 : Real) ≤ headLogitDiff inputs q := by - simpa [lb01, ratToReal_max] using hmax' - simpa [hlb] using hmax - | some lb2 => - have hlb : lb = max lb01 lb2 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, - h1, spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', - valBudget, valCount, valKeys, valsLo, hval] using hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - have hweight_overlay' : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real) := by - simpa [spec, refineBudget] using hweight_overlay q0 refineBudget - have hrefine := - logitDiffLowerBoundRefinedFromCache_le - (inputs := inputs) - (cache := cache) - (c := c) - (logitCache := logitCache) - (spec := spec) - (hcert := hcert) - (hcache := hcache) - (hsound := hsound) - (hweight_overlay := hweight_overlay') - (hbound := h1) - (hq := hq) - have hmax' : - max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hbase, hrefine⟩ - have hmax : (lb01 : Real) ≤ headLogitDiff inputs q := by - simpa [lb01, ratToReal_max] using hmax' - have hsound_cache : InductionHeadCertSound inputs cache.cert := by - simpa [hcert] using hsound - have hvalsLo : - ∀ k, (valsLo k : Real) ≤ valsRealOfInputs inputs k := by - exact valsLoOverlay_le_valsReal (inputs := inputs) (cache := cache) - (budget := valBudget) (refineKeys := valKeys) hsound_cache - have honeHot : - ∀ q, q ∈ c.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) - ((logitDiffCache c).epsAt q : Real) - (fun q' => q' = q) c.prev - (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) := - by - intro q' hq' - have h := hsound.oneHot_bounds_at q' hq' - have heps : (logitDiffCache c).epsAt q' = 
c.epsAt q' := by - simp [logitDiffCache_def, Bounds.cacheBoundTask_apply] - simpa [heps] using h - have hval' : - logitDiffLowerBoundFromCacheWithEpsVals c (logitDiffCache c).epsAt - valsLo = some lb2 := by - simpa [hcache] using hval - have hrefine_val := - logitDiffLowerBoundFromCacheWithEpsVals_le - (inputs := inputs) - (c := c) - (epsAtCustom := (logitDiffCache c).epsAt) - (valsLoCustom := valsLo) - (hsound := hsound) - (honeHot := honeHot) - (hvalsLo := hvalsLo) - (hbound := hval') - (hq := hq) - have hmax' : - max (lb01 : Real) (lb2 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hmax, hrefine_val⟩ - have hmax : (max lb01 lb2 : Real) ≤ headLogitDiff inputs q := by - simpa [ratToReal_max] using hmax' - simpa [hlb] using hmax - | some lb2 => - have hbase := hbase_le (lb0 := lb0) h0 - have hweight_overlay' : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real) := by - simpa [spec, refineBudget] using hweight_overlay q0 refineBudget - have hweight_overlay'' : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec' q k : Real) := by - simpa [spec', refineBudget'] using hweight_overlay q0 refineBudget' - have hrefine := - logitDiffLowerBoundRefinedFromCache_le - (inputs := inputs) - (cache := cache) - (c := c) - (logitCache := logitCache) - (spec := spec) - (hcert := hcert) - (hcache := hcache) - (hsound := hsound) - (hweight_overlay := hweight_overlay') - (hbound := h1) - (hq := hq) - have hrefine' := - logitDiffLowerBoundRefinedFromCache_le - (inputs := inputs) - (cache := cache) - (c := c) - (logitCache := logitCache) - (spec := spec') - (hcert := hcert) - (hcache := hcache) - (hsound := hsound) - (hweight_overlay := hweight_overlay'') - (hbound := h2) - (hq := hq) - have hmax01' : - max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hbase, 
hrefine⟩ - have hmax01 : (lb01 : Real) ≤ headLogitDiff inputs q := by - simpa [lb01, ratToReal_max] using hmax01' - have hmax' : - max (lb01 : Real) (lb2 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hmax01, hrefine'⟩ - have hmax_weight : (max lb01 lb2 : Real) ≤ headLogitDiff inputs q := by - simpa [ratToReal_max] using hmax' - let lbWeight : Rat := max lb01 lb2 - by_cases hweight_nonpos : lbWeight ≤ 0 - · let valBudget := refineBudgetBoost refineBudget - let valCount := refineLowValueCount refineBudget - let valKeys := - loAtKeysAt inputs cache q0 ∪ - lowValueKeysAt inputs cache q0 valCount - let valsLo := valsLoOverlay inputs cache valBudget valKeys - cases hval : - logitDiffLowerBoundFromCacheWithEpsVals c logitCache.epsAt valsLo with - | none => - have hlb : lb = lbWeight := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, - h1, spec, refineBudget, lb01, hnonpos1, h2, spec', - refineBudget', lbWeight, hweight_nonpos, valBudget, valKeys, - valCount, valsLo, hval] using hbound.symm - simpa [hlb, lbWeight] using hmax_weight - | some lb3 => - have hlb : lb = max lbWeight lb3 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, - h1, spec, refineBudget, lb01, hnonpos1, h2, spec', - refineBudget', lbWeight, hweight_nonpos, valBudget, valKeys, - valCount, valsLo, hval] using hbound.symm - have hsound_cache : InductionHeadCertSound inputs cache.cert := by - simpa [hcert] using hsound - have hvalsLo : - ∀ k, (valsLo k : Real) ≤ valsRealOfInputs inputs k := by - exact valsLoOverlay_le_valsReal (inputs := inputs) (cache := cache) - (budget := valBudget) (refineKeys := valKeys) hsound_cache - have honeHot : - ∀ q, q ∈ c.active → - Layers.OneHotApproxBoundsOnActive (Val := Real) - ((logitDiffCache c).epsAt q : Real) - (fun q' => q' = q) c.prev - (fun q' k => - Circuit.softmax (scoresRealOfInputs inputs q') k) := by - intro q' hq' - have h := hsound.oneHot_bounds_at q' hq' - have heps : (logitDiffCache c).epsAt 
q' = c.epsAt q' := by - simp [logitDiffCache_def, Bounds.cacheBoundTask_apply] - simpa [heps] using h - have hval' : - logitDiffLowerBoundFromCacheWithEpsVals c (logitDiffCache c).epsAt - valsLo = some lb3 := by - simpa [hcache] using hval - have hrefine_val := - logitDiffLowerBoundFromCacheWithEpsVals_le - (inputs := inputs) - (c := c) - (epsAtCustom := (logitDiffCache c).epsAt) - (valsLoCustom := valsLo) - (hsound := hsound) - (honeHot := honeHot) - (hvalsLo := hvalsLo) - (hbound := hval') - (hq := hq) - have hmax' : - max (lbWeight : Real) (lb3 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨by simpa [lbWeight] using hmax_weight, - hrefine_val⟩ - have hmax : (max lbWeight lb3 : Real) ≤ headLogitDiff inputs q := by - simpa [ratToReal_max] using hmax' - simpa [hlb] using hmax - · have hlb : lb = lbWeight := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, - spec, refineBudget, lb01, hnonpos1, h2, spec', refineBudget', - lbWeight, hweight_nonpos] using hbound.symm - simpa [hlb, lbWeight] using hmax_weight - · have hlb : lb = lb01 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos, hargmin, h1, - spec, refineBudget, lb01, hnonpos1] using hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - have hweight_overlay' : - ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → - Circuit.softmax (scoresRealOfInputs inputs q) k ≤ - (weightBoundAtOverlay inputs cache spec q k : Real) := by - simpa [spec, refineBudget] using hweight_overlay q0 refineBudget - have hrefine := - logitDiffLowerBoundRefinedFromCache_le - (inputs := inputs) - (cache := cache) - (c := c) - (logitCache := logitCache) - (spec := spec) - (hcert := hcert) - (hcache := hcache) - (hsound := hsound) - (hweight_overlay := hweight_overlay') - (hbound := h1) - (hq := hq) - have hmax' : - max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := by - exact max_le_iff.mpr ⟨hbase, hrefine⟩ - have hmax : (lb01 : Real) ≤ headLogitDiff inputs q := by - simpa [lb01, 
ratToReal_max] using hmax' - simpa [hlb] using hmax - · have hlb : lb = lb0 := by - simpa [logitDiffLowerBoundRefineOnDemand_def, h0, hnonpos] using hbound.symm - have hbase := hbase_le (lb0 := lb0) h0 - simpa [hlb] using hbase - -end Sound - -end Nfp From fd265ad5e65a8eb3f03e7d2bbaf4a11a33481c4b Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 08:11:21 +0100 Subject: [PATCH 218/244] Remove residual cert packaging in bounds --- Nfp/Sound/Bounds/Transformer/Basic.lean | 46 ------------------------- Nfp/Sound/Induction/EndToEnd.lean | 15 ++++---- 2 files changed, 6 insertions(+), 55 deletions(-) diff --git a/Nfp/Sound/Bounds/Transformer/Basic.lean b/Nfp/Sound/Bounds/Transformer/Basic.lean index 579018e..d9f134e 100644 --- a/Nfp/Sound/Bounds/Transformer/Basic.lean +++ b/Nfp/Sound/Bounds/Transformer/Basic.lean @@ -517,52 +517,6 @@ theorem gpt2ResidualIntervalBoundsActive_spec simpa [bounds, gpt2ResidualIntervalBoundsActive, final, baseLo, baseHi] using hbounds q hq i -/-- Package GPT-2 residual bounds into a residual-interval certificate. 
-/ -theorem gpt2ResidualIntervalBoundsActive_sound - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Rat) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed - let cert : Circuit.ResidualIntervalCert dModel := { lo := bounds.1, hi := bounds.2 } - Circuit.ResidualIntervalBounds cert ∧ - ∀ q, q ∈ active → ∀ i, - (cert.lo i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ∧ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (cert.hi i : Real) := by - classical - intro bounds cert - have hspec : - ∀ q, q ∈ active → ∀ i, - (bounds.1 i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ∧ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by - simpa [bounds] using - (gpt2ResidualIntervalBoundsActive_spec (active := active) (hactive := hactive) - (eps := eps) (layers := layers) (heads := heads) (finalLn := finalLn) - (scores := scores) (embed := embed) (hne := hne) (heps := heps) (hsqrt := hsqrt)) - have hbounds : Circuit.ResidualIntervalBounds cert := by - refine { lo_le_hi := ?_ } - intro i - rcases hactive with ⟨q0, hq0⟩ - have hq := hspec q0 hq0 i - have hreal : (bounds.1 i : Real) ≤ (bounds.2 i : Real) := hq.1.trans hq.2 - have hreal' : ratToReal (bounds.1 i) ≤ ratToReal (bounds.2 i) := by - simpa [ratToReal_def] using hreal - exact (ratToReal_le_iff (x := 
bounds.1 i) (y := bounds.2 i)).1 hreal' - refine And.intro hbounds ?_ - intro q hq i - have hq' := hspec q hq i - simpa [cert] using hq' - end Bounds end Sound diff --git a/Nfp/Sound/Induction/EndToEnd.lean b/Nfp/Sound/Induction/EndToEnd.lean index c5866ce..9a28bc1 100644 --- a/Nfp/Sound/Induction/EndToEnd.lean +++ b/Nfp/Sound/Induction/EndToEnd.lean @@ -31,13 +31,13 @@ theorem logitDiffLowerBound_end_to_end_gpt2 (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < Bounds.sqrtLower eps) (hactive : inputs.active.Nonempty) : - let bounds := + let bounds := Bounds.gpt2ResidualIntervalBoundsActive inputs.active hactive eps layers heads finalLn inputs.embed - let output : Fin seq → Fin dModel → Real := + let output : Fin seq → Fin dModel → Real := fun q i => Bounds.transformerStackFinalReal eps finalLn layers heads scores (fun q i => (inputs.embed q i : Real)) q i - ∀ q, q ∈ inputs.active → + ∀ q, q ∈ inputs.active → (lb : Real) - (Bounds.dotIntervalAbsBound inputs.direction (fun i => bounds.1 i - headCert.hi i) @@ -45,12 +45,11 @@ theorem logitDiffLowerBound_end_to_end_gpt2 dotProduct (fun i => (inputs.direction i : Real)) (fun i => output q i) := by classical intro bounds output q hq - let cert : Circuit.ResidualIntervalCert dModel := { lo := bounds.1, hi := bounds.2 } have hbounds : ∀ q, q ∈ inputs.active → ∀ i, (bounds.1 i : Real) ≤ output q i ∧ output q i ≤ (bounds.2 i : Real) := by - have hsound := - Bounds.gpt2ResidualIntervalBoundsActive_sound + simpa [bounds, output] using + (Bounds.gpt2ResidualIntervalBoundsActive_spec (active := inputs.active) (hactive := hactive) (eps := eps) @@ -61,9 +60,7 @@ theorem logitDiffLowerBound_end_to_end_gpt2 (embed := inputs.embed) (hne := hne) (heps := heps) - (hsqrt := hsqrt) - rcases (by simpa [bounds, cert, output] using hsound) with ⟨_, hmem⟩ - exact hmem + (hsqrt := hsqrt)) have hhead_out := hhead.output_mem have h := logitDiffLowerBound_with_output_intervals 
From b28b7c5b06520a6a1f183c705c51498a74e131d5 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 08:18:04 +0100 Subject: [PATCH 219/244] Move IO parsing helpers to untrusted Parse namespace --- Nfp/IO/InductionHead/Cert.lean | 4 ++-- Nfp/IO/Loaders.lean | 24 +++++++++---------- Nfp/IO/Parse.lean | 14 +++++++++++ Nfp/IO/{Pure => Parse}/Basic.lean | 4 ++-- Nfp/IO/{Pure => Parse}/Downstream.lean | 8 +++---- Nfp/IO/{Pure => Parse}/InductionHead.lean | 6 ++--- .../{Pure => Parse}/InductionHead/Bytes.lean | 6 ++--- Nfp/IO/{Pure => Parse}/Residual.lean | 8 +++---- Nfp/IO/{Pure => Parse}/SoftmaxMargin.lean | 4 ++-- .../{Pure => Parse}/SoftmaxMargin/Cert.lean | 8 +++---- Nfp/IO/{Pure => Parse}/SoftmaxMargin/Raw.lean | 8 +++---- .../{Pure => Parse}/SoftmaxMargin/Shared.lean | 6 ++--- Nfp/IO/{Pure => Parse}/ValueRange.lean | 4 ++-- Nfp/IO/{Pure => Parse}/ValueRange/Cert.lean | 8 +++---- Nfp/IO/{Pure => Parse}/ValueRange/Raw.lean | 8 +++---- Nfp/IO/{Pure => Parse}/ValueRange/Shared.lean | 6 ++--- Nfp/IO/Pure.lean | 14 ----------- Nfp/IO/Util.lean | 4 ++-- 18 files changed, 72 insertions(+), 72 deletions(-) create mode 100644 Nfp/IO/Parse.lean rename Nfp/IO/{Pure => Parse}/Basic.lean (98%) rename Nfp/IO/{Pure => Parse}/Downstream.lean (98%) rename Nfp/IO/{Pure => Parse}/InductionHead.lean (86%) rename Nfp/IO/{Pure => Parse}/InductionHead/Bytes.lean (99%) rename Nfp/IO/{Pure => Parse}/Residual.lean (97%) rename Nfp/IO/{Pure => Parse}/SoftmaxMargin.lean (54%) rename Nfp/IO/{Pure => Parse}/SoftmaxMargin/Cert.lean (94%) rename Nfp/IO/{Pure => Parse}/SoftmaxMargin/Raw.lean (94%) rename Nfp/IO/{Pure => Parse}/SoftmaxMargin/Shared.lean (98%) rename Nfp/IO/{Pure => Parse}/ValueRange.lean (55%) rename Nfp/IO/{Pure => Parse}/ValueRange/Cert.lean (93%) rename Nfp/IO/{Pure => Parse}/ValueRange/Raw.lean (93%) rename Nfp/IO/{Pure => Parse}/ValueRange/Shared.lean (98%) delete mode 100644 Nfp/IO/Pure.lean diff --git a/Nfp/IO/InductionHead/Cert.lean 
b/Nfp/IO/InductionHead/Cert.lean index 317d520..55555f0 100644 --- a/Nfp/IO/InductionHead/Cert.lean +++ b/Nfp/IO/InductionHead/Cert.lean @@ -5,7 +5,7 @@ module public import Mathlib.Data.Finset.Insert public import Nfp.Circuit.Cert.InductionHead public import Nfp.Circuit.Cert.LogitDiff -public import Nfp.IO.Pure.Basic +public import Nfp.IO.Parse.Basic public import Nfp.IO.Util /-! @@ -19,7 +19,7 @@ namespace Nfp namespace IO open Nfp.Circuit -open Nfp.IO.Pure +open Nfp.IO.Parse namespace InductionHeadCert diff --git a/Nfp/IO/Loaders.lean b/Nfp/IO/Loaders.lean index 8d0ac1f..5f1336f 100644 --- a/Nfp/IO/Loaders.lean +++ b/Nfp/IO/Loaders.lean @@ -2,7 +2,7 @@ module -public import Nfp.IO.Pure +public import Nfp.IO.Parse public import Nfp.Circuit.Cert.LogitDiff public import Nfp.Circuit.Cert.DownstreamLinear public import Nfp.Circuit.Cert.ResidualBound @@ -24,50 +24,50 @@ open Nfp.Circuit def loadSoftmaxMarginCert (path : System.FilePath) : IO (Except String (Sigma SoftmaxMarginCert)) := do let data ← IO.FS.readFile path - return Pure.parseSoftmaxMarginCert data + return Parse.parseSoftmaxMarginCert data /-- Load raw softmax-margin inputs from disk. -/ def loadSoftmaxMarginRaw (path : System.FilePath) : - IO (Except String (Sigma Pure.SoftmaxMarginRaw)) := do + IO (Except String (Sigma Parse.SoftmaxMarginRaw)) := do let data ← IO.FS.readFile path - return Pure.parseSoftmaxMarginRaw data + return Parse.parseSoftmaxMarginRaw data /-- Load a value-range certificate from disk. -/ def loadValueRangeCert (path : System.FilePath) : IO (Except String (Sigma ValueRangeCert)) := do let data ← IO.FS.readFile path - return Pure.parseValueRangeCert data + return Parse.parseValueRangeCert data /-- Load a downstream linear certificate from disk. 
-/ def loadDownstreamLinearCert (path : System.FilePath) : IO (Except String DownstreamLinearCert) := do let data ← IO.FS.readFile path - return Pure.parseDownstreamLinearCert data + return Parse.parseDownstreamLinearCert data /-- Load a downstream matrix payload from disk. -/ def loadDownstreamMatrixRaw (path : System.FilePath) : IO (Except String (Sigma (fun rows => - Sigma (fun cols => Pure.DownstreamMatrixRaw rows cols)))) := do + Sigma (fun cols => Parse.DownstreamMatrixRaw rows cols)))) := do let data ← IO.FS.readFile path - return Pure.parseDownstreamMatrixRaw data + return Parse.parseDownstreamMatrixRaw data /-- Load a residual-bound certificate from disk. -/ def loadResidualBoundCert (path : System.FilePath) : IO (Except String (Sigma (fun n => ResidualBoundCert n))) := do let data ← IO.FS.readFile path - return Pure.parseResidualBoundCert data + return Parse.parseResidualBoundCert data /-- Load a residual-interval certificate from disk. -/ def loadResidualIntervalCert (path : System.FilePath) : IO (Except String (Sigma (fun n => ResidualIntervalCert n))) := do let data ← IO.FS.readFile path - return Pure.parseResidualIntervalCert data + return Parse.parseResidualIntervalCert data /-- Load raw value-range inputs from disk. -/ def loadValueRangeRaw (path : System.FilePath) : - IO (Except String (Sigma Pure.ValueRangeRaw)) := do + IO (Except String (Sigma Parse.ValueRangeRaw)) := do let data ← IO.FS.readFile path - return Pure.parseValueRangeRaw data + return Parse.parseValueRangeRaw data end IO diff --git a/Nfp/IO/Parse.lean b/Nfp/IO/Parse.lean new file mode 100644 index 0000000..82bd51b --- /dev/null +++ b/Nfp/IO/Parse.lean @@ -0,0 +1,14 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.IO.Parse.Basic +public import Nfp.IO.Parse.Downstream +public import Nfp.IO.Parse.InductionHead +public import Nfp.IO.Parse.Residual +public import Nfp.IO.Parse.SoftmaxMargin +public import Nfp.IO.Parse.ValueRange + +/-! 
+Aggregator for pure CLI parsing helpers. +-/ diff --git a/Nfp/IO/Pure/Basic.lean b/Nfp/IO/Parse/Basic.lean similarity index 98% rename from Nfp/IO/Pure/Basic.lean rename to Nfp/IO/Parse/Basic.lean index 81ccb4a..0dfd2c6 100644 --- a/Nfp/IO/Pure/Basic.lean +++ b/Nfp/IO/Parse/Basic.lean @@ -14,7 +14,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse /-- Split a line into whitespace-separated tokens. -/ def splitWords (line : String) : List String := @@ -72,7 +72,7 @@ def parseRat (s : String) : Except String Rat := do | _ => throw s!"invalid rational '{s}'" -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/Downstream.lean b/Nfp/IO/Parse/Downstream.lean similarity index 98% rename from Nfp/IO/Pure/Downstream.lean rename to Nfp/IO/Parse/Downstream.lean index cf6dba8..2711bdc 100644 --- a/Nfp/IO/Pure/Downstream.lean +++ b/Nfp/IO/Parse/Downstream.lean @@ -3,10 +3,10 @@ module public import Nfp.Circuit.Cert.DownstreamLinear -public import Nfp.IO.Pure.Basic +public import Nfp.IO.Parse.Basic /-! -Pure parsing helpers for downstream linear and matrix payloads. +Parse parsing helpers for downstream linear and matrix payloads. -/ public section @@ -15,7 +15,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse open Nfp.Circuit @@ -199,7 +199,7 @@ def parseDownstreamMatrixRaw (input : String) : let raw ← finalizeDownstreamMatrixState st return ⟨rows, ⟨cols, raw⟩⟩ -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/InductionHead.lean b/Nfp/IO/Parse/InductionHead.lean similarity index 86% rename from Nfp/IO/Pure/InductionHead.lean rename to Nfp/IO/Parse/InductionHead.lean index 0c3adcf..f6bc47f 100644 --- a/Nfp/IO/Pure/InductionHead.lean +++ b/Nfp/IO/Parse/InductionHead.lean @@ -2,7 +2,7 @@ module -public import Nfp.IO.Pure.InductionHead.Bytes +public import Nfp.IO.Parse.InductionHead.Bytes /-! Parsing helpers for induction-head input payloads. 
@@ -14,7 +14,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse /-- Parse a raw induction head input payload from text. -/ def parseInductionHeadInputs (input : String) : @@ -22,7 +22,7 @@ def parseInductionHeadInputs (input : String) : Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead)))) := do parseInductionHeadInputsBytes input.toUTF8 -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/InductionHead/Bytes.lean b/Nfp/IO/Parse/InductionHead/Bytes.lean similarity index 99% rename from Nfp/IO/Pure/InductionHead/Bytes.lean rename to Nfp/IO/Parse/InductionHead/Bytes.lean index 9db33d4..8b18d55 100644 --- a/Nfp/IO/Pure/InductionHead/Bytes.lean +++ b/Nfp/IO/Parse/InductionHead/Bytes.lean @@ -3,7 +3,7 @@ module public import Mathlib.Data.Finset.Insert -public import Nfp.IO.Pure.Basic +public import Nfp.IO.Parse.Basic public import Nfp.Model.InductionHead /-! @@ -16,7 +16,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse private def kwSeq : ByteArray := "seq".toUTF8 private def kwDModel : ByteArray := "d_model".toUTF8 @@ -783,7 +783,7 @@ def parseInductionHeadInputsBytes (data : ByteArray) : let inputs ← finalizeHeadState hpos st return ⟨seq, ⟨dModel, ⟨dHead, inputs⟩⟩⟩ -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/Residual.lean b/Nfp/IO/Parse/Residual.lean similarity index 97% rename from Nfp/IO/Pure/Residual.lean rename to Nfp/IO/Parse/Residual.lean index 97e67b9..3769a79 100644 --- a/Nfp/IO/Pure/Residual.lean +++ b/Nfp/IO/Parse/Residual.lean @@ -4,10 +4,10 @@ module public import Nfp.Circuit.Cert.ResidualBound public import Nfp.Circuit.Cert.ResidualInterval -public import Nfp.IO.Pure.Basic +public import Nfp.IO.Parse.Basic /-! -Pure parsing helpers for residual-bound and residual-interval certificates. +Parse parsing helpers for residual-bound and residual-interval certificates. 
-/ public section @@ -16,7 +16,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse open Nfp.Circuit @@ -133,7 +133,7 @@ def parseResidualIntervalCert (input : String) : return ⟨dim, cert⟩ | _ => throw "expected header 'dim '" -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/SoftmaxMargin.lean b/Nfp/IO/Parse/SoftmaxMargin.lean similarity index 54% rename from Nfp/IO/Pure/SoftmaxMargin.lean rename to Nfp/IO/Parse/SoftmaxMargin.lean index c771a7d..648a6e9 100644 --- a/Nfp/IO/Pure/SoftmaxMargin.lean +++ b/Nfp/IO/Parse/SoftmaxMargin.lean @@ -2,8 +2,8 @@ module -public import Nfp.IO.Pure.SoftmaxMargin.Cert -public import Nfp.IO.Pure.SoftmaxMargin.Raw +public import Nfp.IO.Parse.SoftmaxMargin.Cert +public import Nfp.IO.Parse.SoftmaxMargin.Raw /-! Aggregator for softmax-margin parsing helpers. diff --git a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean b/Nfp/IO/Parse/SoftmaxMargin/Cert.lean similarity index 94% rename from Nfp/IO/Pure/SoftmaxMargin/Cert.lean rename to Nfp/IO/Parse/SoftmaxMargin/Cert.lean index 97ba361..94892f0 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Cert.lean +++ b/Nfp/IO/Parse/SoftmaxMargin/Cert.lean @@ -3,10 +3,10 @@ module public import Nfp.Circuit.Cert.SoftmaxMargin -public import Nfp.IO.Pure.SoftmaxMargin.Shared +public import Nfp.IO.Parse.SoftmaxMargin.Shared /-! -Pure parsing helpers for softmax-margin certificates. +Parse parsing helpers for softmax-margin certificates. 
-/ public section @@ -15,7 +15,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse open Nfp.Circuit @@ -76,7 +76,7 @@ def parseSoftmaxMarginCert (input : String) : let cert ← finalizeState hpos st return ⟨seq, cert⟩ -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean b/Nfp/IO/Parse/SoftmaxMargin/Raw.lean similarity index 94% rename from Nfp/IO/Pure/SoftmaxMargin/Raw.lean rename to Nfp/IO/Parse/SoftmaxMargin/Raw.lean index 6869e00..8bfeb3e 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Raw.lean +++ b/Nfp/IO/Parse/SoftmaxMargin/Raw.lean @@ -3,10 +3,10 @@ module public import Nfp.Circuit.Cert.SoftmaxMargin -public import Nfp.IO.Pure.SoftmaxMargin.Shared +public import Nfp.IO.Parse.SoftmaxMargin.Shared /-! -Pure parsing helpers for raw softmax-margin inputs. +Parse parsing helpers for raw softmax-margin inputs. -/ public section @@ -15,7 +15,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse open Nfp.Circuit @@ -77,7 +77,7 @@ def parseSoftmaxMarginRaw (input : String) : let raw ← finalizeRawState hpos st return ⟨seq, raw⟩ -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean b/Nfp/IO/Parse/SoftmaxMargin/Shared.lean similarity index 98% rename from Nfp/IO/Pure/SoftmaxMargin/Shared.lean rename to Nfp/IO/Parse/SoftmaxMargin/Shared.lean index 4930a67..0ac29c9 100644 --- a/Nfp/IO/Pure/SoftmaxMargin/Shared.lean +++ b/Nfp/IO/Parse/SoftmaxMargin/Shared.lean @@ -3,7 +3,7 @@ module public import Mathlib.Data.Finset.Insert -public import Nfp.IO.Pure.Basic +public import Nfp.IO.Parse.Basic /-! Shared parsing helpers for softmax-margin payloads. 
@@ -15,7 +15,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse namespace SoftmaxMargin @@ -138,7 +138,7 @@ def parseSeq (tokens : List (List String)) : Except String Nat := do end SoftmaxMargin -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/ValueRange.lean b/Nfp/IO/Parse/ValueRange.lean similarity index 55% rename from Nfp/IO/Pure/ValueRange.lean rename to Nfp/IO/Parse/ValueRange.lean index f41810b..102943f 100644 --- a/Nfp/IO/Pure/ValueRange.lean +++ b/Nfp/IO/Parse/ValueRange.lean @@ -2,8 +2,8 @@ module -public import Nfp.IO.Pure.ValueRange.Cert -public import Nfp.IO.Pure.ValueRange.Raw +public import Nfp.IO.Parse.ValueRange.Cert +public import Nfp.IO.Parse.ValueRange.Raw /-! Aggregator for value-range parsing helpers. diff --git a/Nfp/IO/Pure/ValueRange/Cert.lean b/Nfp/IO/Parse/ValueRange/Cert.lean similarity index 93% rename from Nfp/IO/Pure/ValueRange/Cert.lean rename to Nfp/IO/Parse/ValueRange/Cert.lean index ee7c14f..e13d6e1 100644 --- a/Nfp/IO/Pure/ValueRange/Cert.lean +++ b/Nfp/IO/Parse/ValueRange/Cert.lean @@ -3,10 +3,10 @@ module public import Nfp.Circuit.Cert.ValueRange -public import Nfp.IO.Pure.ValueRange.Shared +public import Nfp.IO.Parse.ValueRange.Shared /-! -Pure parsing helpers for value-range certificates. +Parse parsing helpers for value-range certificates. 
-/ public section @@ -15,7 +15,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse open Nfp.Circuit @@ -60,7 +60,7 @@ def parseValueRangeCert (input : String) : let cert ← finalizeValueState st return ⟨seq, cert⟩ -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/ValueRange/Raw.lean b/Nfp/IO/Parse/ValueRange/Raw.lean similarity index 93% rename from Nfp/IO/Pure/ValueRange/Raw.lean rename to Nfp/IO/Parse/ValueRange/Raw.lean index 7807093..b6411c2 100644 --- a/Nfp/IO/Pure/ValueRange/Raw.lean +++ b/Nfp/IO/Parse/ValueRange/Raw.lean @@ -3,10 +3,10 @@ module public import Nfp.Circuit.Cert.ValueRange -public import Nfp.IO.Pure.ValueRange.Shared +public import Nfp.IO.Parse.ValueRange.Shared /-! -Pure parsing helpers for raw value-range inputs. +Parse parsing helpers for raw value-range inputs. -/ public section @@ -15,7 +15,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse open Nfp.Circuit @@ -59,7 +59,7 @@ def parseValueRangeRaw (input : String) : let raw ← finalizeValueRawState st return ⟨seq, raw⟩ -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure/ValueRange/Shared.lean b/Nfp/IO/Parse/ValueRange/Shared.lean similarity index 98% rename from Nfp/IO/Pure/ValueRange/Shared.lean rename to Nfp/IO/Parse/ValueRange/Shared.lean index 93a8fc5..e51800c 100644 --- a/Nfp/IO/Pure/ValueRange/Shared.lean +++ b/Nfp/IO/Parse/ValueRange/Shared.lean @@ -3,7 +3,7 @@ module public import Nfp.Circuit.Cert.ValueRange -public import Nfp.IO.Pure.Basic +public import Nfp.IO.Parse.Basic /-! Shared parsing helpers for value-range payloads. 
@@ -15,7 +15,7 @@ namespace Nfp namespace IO -namespace Pure +namespace Parse namespace ValueRange @@ -110,7 +110,7 @@ def parseSeq (tokens : List (List String)) : Except String Nat := do end ValueRange -end Pure +end Parse end IO diff --git a/Nfp/IO/Pure.lean b/Nfp/IO/Pure.lean deleted file mode 100644 index 0119f01..0000000 --- a/Nfp/IO/Pure.lean +++ /dev/null @@ -1,14 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.IO.Pure.Basic -public import Nfp.IO.Pure.Downstream -public import Nfp.IO.Pure.InductionHead -public import Nfp.IO.Pure.Residual -public import Nfp.IO.Pure.SoftmaxMargin -public import Nfp.IO.Pure.ValueRange - -/-! -Aggregator for pure CLI parsing helpers. --/ diff --git a/Nfp/IO/Util.lean b/Nfp/IO/Util.lean index 0f183eb..00eaa42 100644 --- a/Nfp/IO/Util.lean +++ b/Nfp/IO/Util.lean @@ -2,7 +2,7 @@ module -public import Nfp.IO.Pure +public import Nfp.IO.Parse /-! Small shared helpers for IO parsing. @@ -20,7 +20,7 @@ def parseRatOpt (label : String) (raw? : Option String) : match raw? with | none => Except.ok none | some raw => - match Pure.parseRat raw with + match Parse.parseRat raw with | Except.ok v => Except.ok (some v) | Except.error msg => Except.error s!"invalid {label}: {msg}" From aadcd7ac9044245c396b8b129e6b14d8c099686e Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 08:22:40 +0100 Subject: [PATCH 220/244] Move GPT-2 head input builders out of Sound --- Nfp/{Sound => }/Gpt2/HeadInputs.lean | 6 +----- Nfp/Sound.lean | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) rename Nfp/{Sound => }/Gpt2/HeadInputs.lean (98%) diff --git a/Nfp/Sound/Gpt2/HeadInputs.lean b/Nfp/Gpt2/HeadInputs.lean similarity index 98% rename from Nfp/Sound/Gpt2/HeadInputs.lean rename to Nfp/Gpt2/HeadInputs.lean index 10526c5..fb3eb1e 100644 --- a/Nfp/Sound/Gpt2/HeadInputs.lean +++ b/Nfp/Gpt2/HeadInputs.lean @@ -7,7 +7,7 @@ public import Nfp.Model.InductionHead public import Nfp.Model.InductionPrompt /-! 
-Sound builder for GPT-2 induction head inputs. +Untrusted builder for GPT-2 induction head inputs. This converts exact GPT-2 head slices into `InductionHeadInputs` using a periodic prompt description. The construction is purely definitional and is @@ -18,8 +18,6 @@ public section namespace Nfp -namespace Sound - namespace Gpt2 open Nfp.Model @@ -140,6 +138,4 @@ theorem buildInductionHeadInputsShift_prev_spec {seq dModel dHead vocab : Nat} end Gpt2 -end Sound - end Nfp diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index dddb21c..d7ad656 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -2,7 +2,6 @@ module -public import Nfp.Sound.Gpt2.HeadInputs public import Nfp.Sound.Induction public import Nfp.Sound.Bounds public import Nfp.Sound.Linear.FinFold From 3212c4d07d8b0635d421f1fbad96b5ad8ec24fa8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 08:35:21 +0100 Subject: [PATCH 221/244] Move bounds and linear helpers out of Sound --- Nfp/Bounds.lean | 17 +++++++++++++++++ Nfp/{Sound => }/Bounds/Attention.lean | 10 ++++------ Nfp/{Sound => }/Bounds/Cache.lean | 2 -- Nfp/{Sound => }/Bounds/Gelu.lean | 2 -- Nfp/Bounds/LayerNorm.lean | 12 ++++++++++++ Nfp/{Sound => }/Bounds/LayerNorm/Basic.lean | 8 +++----- Nfp/{Sound => }/Bounds/LayerNorm/InvStd.lean | 6 ++---- .../Bounds/LayerNorm/MeanVariance.lean | 2 -- .../Bounds/LayerNorm/SqrtBounds.lean | 2 -- Nfp/{Sound => }/Bounds/MatrixNorm.lean | 4 ++-- Nfp/{Sound => }/Bounds/MatrixNorm/Basic.lean | 6 ++---- Nfp/{Sound => }/Bounds/MatrixNorm/Interval.lean | 4 +--- Nfp/{Sound => }/Bounds/Mlp.lean | 8 +++----- Nfp/{Sound => }/Bounds/Transformer.lean | 4 ++-- Nfp/{Sound => }/Bounds/Transformer/Basic.lean | 10 ++++------ .../Bounds/Transformer/Embedding.lean | 2 -- Nfp/{Sound => }/Bounds/UnnormRat.lean | 4 +--- Nfp/{Sound => }/Linear/FinFold.lean | 2 -- Nfp/Sound.lean | 2 -- Nfp/Sound/Bounds.lean | 17 ----------------- Nfp/Sound/Bounds/LayerNorm.lean | 12 ------------ Nfp/Sound/Induction/CoreDefs.lean | 8 ++++---- 
Nfp/Sound/Induction/EndToEnd.lean | 2 +- Nfp/Sound/Induction/LogitDiff.lean | 4 ++-- 24 files changed, 60 insertions(+), 90 deletions(-) create mode 100644 Nfp/Bounds.lean rename Nfp/{Sound => }/Bounds/Attention.lean (99%) rename Nfp/{Sound => }/Bounds/Cache.lean (99%) rename Nfp/{Sound => }/Bounds/Gelu.lean (99%) create mode 100644 Nfp/Bounds/LayerNorm.lean rename Nfp/{Sound => }/Bounds/LayerNorm/Basic.lean (99%) rename Nfp/{Sound => }/Bounds/LayerNorm/InvStd.lean (97%) rename Nfp/{Sound => }/Bounds/LayerNorm/MeanVariance.lean (99%) rename Nfp/{Sound => }/Bounds/LayerNorm/SqrtBounds.lean (99%) rename Nfp/{Sound => }/Bounds/MatrixNorm.lean (56%) rename Nfp/{Sound => }/Bounds/MatrixNorm/Basic.lean (98%) rename Nfp/{Sound => }/Bounds/MatrixNorm/Interval.lean (99%) rename Nfp/{Sound => }/Bounds/Mlp.lean (98%) rename Nfp/{Sound => }/Bounds/Transformer.lean (54%) rename Nfp/{Sound => }/Bounds/Transformer/Basic.lean (99%) rename Nfp/{Sound => }/Bounds/Transformer/Embedding.lean (99%) rename Nfp/{Sound => }/Bounds/UnnormRat.lean (95%) rename Nfp/{Sound => }/Linear/FinFold.lean (99%) delete mode 100644 Nfp/Sound/Bounds.lean delete mode 100644 Nfp/Sound/Bounds/LayerNorm.lean diff --git a/Nfp/Bounds.lean b/Nfp/Bounds.lean new file mode 100644 index 0000000..fd95fa5 --- /dev/null +++ b/Nfp/Bounds.lean @@ -0,0 +1,17 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Bounds.Attention +public import Nfp.Bounds.Cache +public import Nfp.Bounds.Gelu +public import Nfp.Bounds.LayerNorm +public import Nfp.Bounds.LayerNorm.InvStd +public import Nfp.Bounds.MatrixNorm +public import Nfp.Bounds.Mlp +public import Nfp.Bounds.Transformer +public import Nfp.Bounds.UnnormRat + +/-! +Aggregator for untrusted interval bounds. 
+-/ diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Bounds/Attention.lean similarity index 99% rename from Nfp/Sound/Bounds/Attention.lean rename to Nfp/Bounds/Attention.lean index 085bc1b..935104d 100644 --- a/Nfp/Sound/Bounds/Attention.lean +++ b/Nfp/Bounds/Attention.lean @@ -9,10 +9,10 @@ public import Mathlib.Data.Real.Basic public import Nfp.Circuit.Layers.Softmax public import Nfp.Core.Basic public import Nfp.Model.Gpt2 -public import Nfp.Sound.Bounds.Cache -public import Nfp.Sound.Bounds.LayerNorm -public import Nfp.Sound.Bounds.MatrixNorm -public import Nfp.Sound.Bounds.Mlp +public import Nfp.Bounds.Cache +public import Nfp.Bounds.LayerNorm +public import Nfp.Bounds.MatrixNorm +public import Nfp.Bounds.Mlp /-! Interval bounds for multi-head attention and transformer layers. @@ -22,7 +22,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -401,6 +400,5 @@ theorem transformerLayerBounds_spec {seq dModel dHead numHeads hidden : Nat} [Ne end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/Cache.lean b/Nfp/Bounds/Cache.lean similarity index 99% rename from Nfp/Sound/Bounds/Cache.lean rename to Nfp/Bounds/Cache.lean index a88f9d8..a4f8101 100644 --- a/Nfp/Sound/Bounds/Cache.lean +++ b/Nfp/Bounds/Cache.lean @@ -12,7 +12,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -219,6 +218,5 @@ theorem cacheBoundPair2TaskElem_apply_right {m n : Nat} end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Bounds/Gelu.lean similarity index 99% rename from Nfp/Sound/Bounds/Gelu.lean rename to Nfp/Bounds/Gelu.lean index 432abec..77edfce 100644 --- a/Nfp/Sound/Bounds/Gelu.lean +++ b/Nfp/Bounds/Gelu.lean @@ -16,7 +16,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -162,6 +161,5 @@ theorem geluInterval_bounds {lo hi : Rat} {x : Real} end Bounds -end Sound end Nfp diff --git a/Nfp/Bounds/LayerNorm.lean b/Nfp/Bounds/LayerNorm.lean new file mode 100644 index 0000000..28d0317 --- /dev/null 
+++ b/Nfp/Bounds/LayerNorm.lean @@ -0,0 +1,12 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Bounds.LayerNorm.Basic +public import Nfp.Bounds.LayerNorm.InvStd +public import Nfp.Bounds.LayerNorm.MeanVariance +public import Nfp.Bounds.LayerNorm.SqrtBounds + +/-! +LayerNorm bounds and supporting lemmas. +-/ diff --git a/Nfp/Sound/Bounds/LayerNorm/Basic.lean b/Nfp/Bounds/LayerNorm/Basic.lean similarity index 99% rename from Nfp/Sound/Bounds/LayerNorm/Basic.lean rename to Nfp/Bounds/LayerNorm/Basic.lean index 998402a..24f1a98 100644 --- a/Nfp/Sound/Bounds/LayerNorm/Basic.lean +++ b/Nfp/Bounds/LayerNorm/Basic.lean @@ -11,9 +11,9 @@ public import Mathlib.Data.Real.Sqrt public import Mathlib.Data.Rat.BigOperators public import Mathlib.Data.Rat.Cast.Order public import Nfp.Core.Basic -public import Nfp.Sound.Bounds.LayerNorm.MeanVariance -public import Nfp.Sound.Bounds.LayerNorm.SqrtBounds -public import Nfp.Sound.Linear.FinFold +public import Nfp.Bounds.LayerNorm.MeanVariance +public import Nfp.Bounds.LayerNorm.SqrtBounds +public import Nfp.Linear.FinFold /-! LayerNorm interval bounds for rational inputs. @@ -26,7 +26,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -1025,6 +1024,5 @@ theorem layerNormIntervalBounds_spec_real {n : Nat} end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean b/Nfp/Bounds/LayerNorm/InvStd.lean similarity index 97% rename from Nfp/Sound/Bounds/LayerNorm/InvStd.lean rename to Nfp/Bounds/LayerNorm/InvStd.lean index 4e2e03f..c4dadec 100644 --- a/Nfp/Sound/Bounds/LayerNorm/InvStd.lean +++ b/Nfp/Bounds/LayerNorm/InvStd.lean @@ -2,8 +2,8 @@ module -public import Nfp.Sound.Bounds.LayerNorm.MeanVariance -public import Nfp.Sound.Bounds.LayerNorm.SqrtBounds +public import Nfp.Bounds.LayerNorm.MeanVariance +public import Nfp.Bounds.LayerNorm.SqrtBounds /-! Inverse-standard-deviation bounds for LayerNorm. 
@@ -16,7 +16,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -135,6 +134,5 @@ theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Bounds/LayerNorm/MeanVariance.lean similarity index 99% rename from Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean rename to Nfp/Bounds/LayerNorm/MeanVariance.lean index e07069b..ef3bf3f 100644 --- a/Nfp/Sound/Bounds/LayerNorm/MeanVariance.lean +++ b/Nfp/Bounds/LayerNorm/MeanVariance.lean @@ -22,7 +22,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -242,6 +241,5 @@ theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Rat) end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean b/Nfp/Bounds/LayerNorm/SqrtBounds.lean similarity index 99% rename from Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean rename to Nfp/Bounds/LayerNorm/SqrtBounds.lean index 40362ea..c1f0828 100644 --- a/Nfp/Sound/Bounds/LayerNorm/SqrtBounds.lean +++ b/Nfp/Bounds/LayerNorm/SqrtBounds.lean @@ -20,7 +20,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -826,6 +825,5 @@ theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Bounds/MatrixNorm.lean similarity index 56% rename from Nfp/Sound/Bounds/MatrixNorm.lean rename to Nfp/Bounds/MatrixNorm.lean index 988f4bd..d278f37 100644 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ b/Nfp/Bounds/MatrixNorm.lean @@ -2,8 +2,8 @@ module -public import Nfp.Sound.Bounds.MatrixNorm.Basic -public import Nfp.Sound.Bounds.MatrixNorm.Interval +public import Nfp.Bounds.MatrixNorm.Basic +public import Nfp.Bounds.MatrixNorm.Interval /-! Matrix norm and interval bound helpers for downstream certificates. 
diff --git a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean b/Nfp/Bounds/MatrixNorm/Basic.lean similarity index 98% rename from Nfp/Sound/Bounds/MatrixNorm/Basic.lean rename to Nfp/Bounds/MatrixNorm/Basic.lean index e431416..23d6ace 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Basic.lean +++ b/Nfp/Bounds/MatrixNorm/Basic.lean @@ -9,8 +9,8 @@ public import Mathlib.Data.Fintype.Basic public import Mathlib.Data.Matrix.Mul public import Mathlib.Data.Real.Basic public import Nfp.Core.Basic -public import Nfp.Sound.Bounds.MatrixNorm.Interval -public import Nfp.Sound.Linear.FinFold +public import Nfp.Bounds.MatrixNorm.Interval +public import Nfp.Linear.FinFold /-! Row-sum matrix norms for downstream linear certificates. @@ -23,7 +23,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -153,6 +152,5 @@ theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean b/Nfp/Bounds/MatrixNorm/Interval.lean similarity index 99% rename from Nfp/Sound/Bounds/MatrixNorm/Interval.lean rename to Nfp/Bounds/MatrixNorm/Interval.lean index 1e99329..4395611 100644 --- a/Nfp/Sound/Bounds/MatrixNorm/Interval.lean +++ b/Nfp/Bounds/MatrixNorm/Interval.lean @@ -8,7 +8,7 @@ public import Mathlib.Algebra.Order.Ring.Abs public import Mathlib.Data.Matrix.Mul public import Mathlib.Data.Real.Basic public import Nfp.Core.Basic -public import Nfp.Sound.Linear.FinFold +public import Nfp.Linear.FinFold /-! Interval bounds for dot products and matrix-vector products. 
@@ -20,7 +20,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -918,6 +917,5 @@ theorem mulVecIntervalLower_le_upper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/Mlp.lean b/Nfp/Bounds/Mlp.lean similarity index 98% rename from Nfp/Sound/Bounds/Mlp.lean rename to Nfp/Bounds/Mlp.lean index 1af5c3f..66eed18 100644 --- a/Nfp/Sound/Bounds/Mlp.lean +++ b/Nfp/Bounds/Mlp.lean @@ -4,9 +4,9 @@ module public import Mathlib.Algebra.BigOperators.Group.Finset.Basic public import Nfp.Core.Basic -public import Nfp.Sound.Bounds.Gelu -public import Nfp.Sound.Bounds.LayerNorm -public import Nfp.Sound.Bounds.MatrixNorm +public import Nfp.Bounds.Gelu +public import Nfp.Bounds.LayerNorm +public import Nfp.Bounds.MatrixNorm /-! Interval bounds for GPT-2 MLP blocks (linear + GELU + linear). @@ -16,7 +16,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -293,6 +292,5 @@ theorem layerNormAbsMlpResidualBounds_spec {n hidden : Nat} end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/Transformer.lean b/Nfp/Bounds/Transformer.lean similarity index 54% rename from Nfp/Sound/Bounds/Transformer.lean rename to Nfp/Bounds/Transformer.lean index 2104172..23a059a 100644 --- a/Nfp/Sound/Bounds/Transformer.lean +++ b/Nfp/Bounds/Transformer.lean @@ -2,8 +2,8 @@ module -public import Nfp.Sound.Bounds.Transformer.Basic -public import Nfp.Sound.Bounds.Transformer.Embedding +public import Nfp.Bounds.Transformer.Basic +public import Nfp.Bounds.Transformer.Embedding /-! Transformer-stack interval bounds and supporting lemmas. 
diff --git a/Nfp/Sound/Bounds/Transformer/Basic.lean b/Nfp/Bounds/Transformer/Basic.lean similarity index 99% rename from Nfp/Sound/Bounds/Transformer/Basic.lean rename to Nfp/Bounds/Transformer/Basic.lean index d9f134e..c5388e3 100644 --- a/Nfp/Sound/Bounds/Transformer/Basic.lean +++ b/Nfp/Bounds/Transformer/Basic.lean @@ -7,10 +7,10 @@ public import Mathlib.Data.List.Range public import Mathlib.Data.Real.Basic public import Nfp.Circuit.Cert.ResidualInterval public import Nfp.Model.Gpt2 -public import Nfp.Sound.Bounds.Attention -public import Nfp.Sound.Bounds.LayerNorm -public import Nfp.Sound.Bounds.Transformer.Embedding -public import Nfp.Sound.Linear.FinFold +public import Nfp.Bounds.Attention +public import Nfp.Bounds.LayerNorm +public import Nfp.Bounds.Transformer.Embedding +public import Nfp.Linear.FinFold /-! Interval bounds for transformer stacks and final LayerNorm outputs. @@ -20,7 +20,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -519,6 +518,5 @@ theorem gpt2ResidualIntervalBoundsActive_spec end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/Transformer/Embedding.lean b/Nfp/Bounds/Transformer/Embedding.lean similarity index 99% rename from Nfp/Sound/Bounds/Transformer/Embedding.lean rename to Nfp/Bounds/Transformer/Embedding.lean index 394eeb0..9f1943a 100644 --- a/Nfp/Sound/Bounds/Transformer/Embedding.lean +++ b/Nfp/Bounds/Transformer/Embedding.lean @@ -15,7 +15,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -128,6 +127,5 @@ theorem intervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Bounds/UnnormRat.lean b/Nfp/Bounds/UnnormRat.lean similarity index 95% rename from Nfp/Sound/Bounds/UnnormRat.lean rename to Nfp/Bounds/UnnormRat.lean index 90824fa..1c3b5f6 100644 --- a/Nfp/Sound/Bounds/UnnormRat.lean +++ b/Nfp/Bounds/UnnormRat.lean @@ -3,7 +3,7 @@ module public import Nfp.Core.Basic -public import Nfp.Sound.Linear.FinFold +public import 
Nfp.Linear.FinFold /-! Unnormalized rational arithmetic. @@ -16,7 +16,6 @@ public section namespace Nfp -namespace Sound namespace Bounds @@ -59,6 +58,5 @@ theorem UnnormRat.toRat_sumFin (n : Nat) (f : Fin n → UnnormRat) : end Bounds -end Sound end Nfp diff --git a/Nfp/Sound/Linear/FinFold.lean b/Nfp/Linear/FinFold.lean similarity index 99% rename from Nfp/Sound/Linear/FinFold.lean rename to Nfp/Linear/FinFold.lean index 1d72965..f4114f6 100644 --- a/Nfp/Sound/Linear/FinFold.lean +++ b/Nfp/Linear/FinFold.lean @@ -18,7 +18,6 @@ public section namespace Nfp -namespace Sound namespace Linear @@ -126,6 +125,5 @@ theorem dotProduct_mul_right {n : Nat} (x y : Fin n → Real) (a : Real) : end Linear -end Sound end Nfp diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index d7ad656..26a7f9d 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -3,8 +3,6 @@ module public import Nfp.Sound.Induction -public import Nfp.Sound.Bounds -public import Nfp.Sound.Linear.FinFold /-! Sound certificate builders and verified helpers. diff --git a/Nfp/Sound/Bounds.lean b/Nfp/Sound/Bounds.lean deleted file mode 100644 index b5d1015..0000000 --- a/Nfp/Sound/Bounds.lean +++ /dev/null @@ -1,17 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Sound.Bounds.Attention -public import Nfp.Sound.Bounds.Cache -public import Nfp.Sound.Bounds.Gelu -public import Nfp.Sound.Bounds.LayerNorm -public import Nfp.Sound.Bounds.LayerNorm.InvStd -public import Nfp.Sound.Bounds.MatrixNorm -public import Nfp.Sound.Bounds.Mlp -public import Nfp.Sound.Bounds.Transformer -public import Nfp.Sound.Bounds.UnnormRat - -/-! -Aggregator for sound interval bounds. 
--/ diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean deleted file mode 100644 index ab110cf..0000000 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ /dev/null @@ -1,12 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Sound.Bounds.LayerNorm.Basic -public import Nfp.Sound.Bounds.LayerNorm.InvStd -public import Nfp.Sound.Bounds.LayerNorm.MeanVariance -public import Nfp.Sound.Bounds.LayerNorm.SqrtBounds - -/-! -LayerNorm bounds and supporting lemmas. --/ diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index 3444c49..ec7294e 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -8,9 +8,9 @@ public import Nfp.Circuit.Layers.Induction public import Nfp.Circuit.Layers.Softmax public import Nfp.Core.Basic public import Nfp.Model.InductionHead -public import Nfp.Sound.Bounds.LayerNorm -public import Nfp.Sound.Bounds.MatrixNorm -public import Nfp.Sound.Linear.FinFold +public import Nfp.Bounds.LayerNorm +public import Nfp.Bounds.MatrixNorm +public import Nfp.Linear.FinFold /-! Core definitions for induction-head certificates. 
@@ -27,7 +27,7 @@ namespace Sound open scoped BigOperators open Nfp.Circuit -open Nfp.Sound.Bounds +open Nfp.Bounds variable {seq : Nat} diff --git a/Nfp/Sound/Induction/EndToEnd.lean b/Nfp/Sound/Induction/EndToEnd.lean index 9a28bc1..400911c 100644 --- a/Nfp/Sound/Induction/EndToEnd.lean +++ b/Nfp/Sound/Induction/EndToEnd.lean @@ -2,7 +2,7 @@ module -public import Nfp.Sound.Bounds.Transformer +public import Nfp.Bounds.Transformer public import Nfp.Sound.Induction.HeadOutput public import Nfp.Sound.Induction.LogitDiff diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 970daaf..875df15 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -6,8 +6,8 @@ public import Aesop public import Mathlib.Data.List.MinMax public import Mathlib.Data.Vector.Basic public import Nfp.Circuit.Cert.LogitDiff -public import Nfp.Sound.Bounds.Cache -public import Nfp.Sound.Bounds.MatrixNorm.Interval +public import Nfp.Bounds.Cache +public import Nfp.Bounds.MatrixNorm.Interval public import Nfp.Sound.Induction.HeadOutput /-! 
From 24e0c9e900039acd1fac41194f557640f8c02fa6 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 09:04:38 +0100 Subject: [PATCH 222/244] Align cert pipeline to Python generation and Lean checks --- CLAIMS.md | 65 +- Nfp/Gpt2/HeadInputs.lean | 141 -- Nfp/IO/Loaders.lean | 22 +- Nfp/IO/Parse.lean | 1 - Nfp/IO/Parse/Downstream.lean | 128 +- Nfp/IO/Parse/InductionHead.lean | 29 - Nfp/IO/Parse/InductionHead/Bytes.lean | 790 ----------- Nfp/IO/Parse/SoftmaxMargin.lean | 1 - Nfp/IO/Parse/SoftmaxMargin/Raw.lean | 84 -- Nfp/IO/Parse/ValueRange.lean | 1 - Nfp/IO/Parse/ValueRange/Raw.lean | 66 - Nfp/Model/Gpt2.lean | 2 +- Nfp/Sound.lean | 2 +- Nfp/Sound/Induction/CoreDefs.lean | 2 +- README.md | 199 +-- SOUNDNESS_LIMITATIONS.md | 51 +- docs/induction_cert_audit.md | 99 +- scripts/build_downstream_linear_cert.py | 4 +- scripts/build_gpt2_head_inputs.py | 369 ----- scripts/build_gpt2_induction_cert.py | 133 +- .../build_gpt2_induction_cert_from_binary.py | 102 +- scripts/build_residual_bound_cert.py | 2 +- scripts/build_residual_interval_cert.py | 3 +- scripts/certify_induction_head.py | 204 --- scripts/ci_sanity_forward_step.py | 206 --- scripts/demo_gpt2_induction_sound.sh | 35 - scripts/demo_gpt2_sound.sh | 31 - scripts/demo_tiny_induction_cert.sh | 26 - scripts/demo_tiny_local_binary.sh | 28 - scripts/discover_gpt2_induction_targets.py | 1251 ----------------- scripts/sanity_forward_step.py | 139 -- scripts/scan_gpt2_induction_sound.py | 377 ----- scripts/sweep_gpt2_induction_nonvacuous.py | 370 ----- 33 files changed, 274 insertions(+), 4689 deletions(-) delete mode 100644 Nfp/Gpt2/HeadInputs.lean delete mode 100644 Nfp/IO/Parse/InductionHead.lean delete mode 100644 Nfp/IO/Parse/InductionHead/Bytes.lean delete mode 100644 Nfp/IO/Parse/SoftmaxMargin/Raw.lean delete mode 100644 Nfp/IO/Parse/ValueRange/Raw.lean delete mode 100644 scripts/build_gpt2_head_inputs.py delete mode 100644 scripts/certify_induction_head.py delete mode 100644 
scripts/ci_sanity_forward_step.py delete mode 100755 scripts/demo_gpt2_induction_sound.sh delete mode 100755 scripts/demo_gpt2_sound.sh delete mode 100755 scripts/demo_tiny_induction_cert.sh delete mode 100755 scripts/demo_tiny_local_binary.sh delete mode 100644 scripts/discover_gpt2_induction_targets.py delete mode 100644 scripts/sanity_forward_step.py delete mode 100755 scripts/scan_gpt2_induction_sound.py delete mode 100644 scripts/sweep_gpt2_induction_nonvacuous.py diff --git a/CLAIMS.md b/CLAIMS.md index eccb279..c7f0514 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -9,68 +9,43 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - Softmax-margin certificate soundness: `checkSoftmaxMarginCert` implies `SoftmaxMarginBoundsOn`. - Value-range certificate soundness: `checkValueRangeCert` implies `ValueRangeBounds`. -- Induction-head certificate soundness: `InductionHeadCertSound` holds whenever - `buildInductionCertFromHeadCoreWith?` returns a certificate for the given inputs. -- Logit-diff lower bound lemmas: `logitDiffLowerBound_le` and - `logitDiffLowerBoundFromCert_le`. +- Induction-head certificate soundness: `checkInductionHeadCert` implies + `InductionHeadCertBounds`. +- Logit-diff lower bound lemmas: `logitDiffLowerBound_le`, `logitDiffLowerBoundAt_le`, and + `logitDiffLowerBoundWeightedAt_le`. - Bridge lemmas composing head logit-diff bounds with head outputs and residual - interval bounds: `headLogitDiff_eq_direction_dot_headOutput` and - `logitDiffLowerBound_with_residual`, plus interval-composition - `logitDiffLowerBound_with_output_intervals`. + interval bounds: `headLogitDiff_eq_direction_dot_headOutput`, + `logitDiffLowerBound_with_residual`, and `logitDiffLowerBound_with_output_intervals`. - Downstream linear certificate soundness: `checkDownstreamLinearCert` implies `DownstreamLinearBounds`. - Residual-interval certificate soundness: `checkResidualIntervalCert` implies `ResidualIntervalBounds`. 
-- GPT-2 residual interval bounds from model slices are sound for - `transformerStackFinalReal` on active positions (`gpt2ResidualIntervalBoundsActive_sound`). - End-to-end direction-dot lower bounds on `transformerStackFinalReal` can be derived by - composing head logit-diff bounds with head/output intervals + composing head logit-diff bounds with residual interval bounds (`logitDiffLowerBound_end_to_end_gpt2`). - Row-sum matrix norm bounds for `mulVec` under uniform input magnitude. - Tanh-GELU bounds and interval propagation through MLP layers. -- Interval bounds for multi-head attention and full transformer-layer residual blocks. -- Interval bounds for transformer stacks and final LayerNorm outputs. +- Interval bounds for multi-head attention, transformer-layer residual blocks, transformer + stacks, and final LayerNorm outputs. ## Soundly checked by the trusted CLI -- `nfp induction certify` verifies head-level induction certificates from either a head-input - file or a model binary, and can compute a logit-diff lower bound. -- `nfp induction certify_nonvacuous` requires a strictly positive logit-diff lower bound. -- `nfp induction advanced certify_sound` recomputes `eps`/`margin` and `lo`/`hi` from raw - entries and verifies the resulting certificates. -- `nfp induction advanced certify_head` recomputes scores/values from exact head inputs and - verifies the resulting induction certificate (experimental, potentially slow). -- `nfp induction advanced certify_head_model` reads a model binary, derives head inputs in Lean, - and verifies the resulting induction certificate (includes attention projection biases and - derives `prev`/active from the stored token sequence by default, and builds the logit-diff - direction vector from the target/negative unembedding columns). 
-- `nfp induction advanced certify_head_model_auto` derives the logit-diff direction from the - prompt tokens stored in the model file before running the same head-input checker (the - direction vector still uses the unembedding columns). -- `nfp induction advanced certify_end_to_end` composes a head-level logit-diff lower bound with - a downstream error certificate (arithmetic consistency only). -- `nfp induction advanced certify_end_to_end_matrix` computes a downstream bound from a matrix - payload using verified row-sum norms, then composes it with the head-level logit-diff lower - bound. -- `nfp induction advanced certify_end_to_end_model` derives the unembedding direction from an - `NFP_BINARY_V1` model file, computes a downstream error bound from either a supplied - residual-interval certificate or a verified model-derived interval, and composes it with the - head-level logit-diff lower bound (optionally using `--layer/--head` to add head-output - interval bounds for a tighter end-to-end check). +- `nfp induction certify`, `nfp induction certify_nonvacuous`, and + `nfp induction head_cert_check` verify explicit induction-head certificates from a single + cert file, optionally enforcing minimum `active`, `margin`, `eps`, and logit-diff gates. ## Untrusted / heuristic -- Python helpers that generate certificates from GPT-2 weights or head inputs: - `scripts/build_gpt2_induction_cert.py`, `scripts/build_gpt2_head_inputs.py`, +- Python helpers that generate explicit certificates from GPT-2 weights or `.nfpt` files: + `scripts/build_gpt2_induction_cert.py`, `scripts/build_gpt2_induction_cert_from_binary.py`, + `scripts/build_residual_interval_cert.py`, `scripts/build_residual_bound_cert.py`, and `scripts/build_downstream_linear_cert.py`. -- The head-input extractor now emits attention projection biases and LayerNorm metadata, but - the Lean-side computation still ignores LayerNorm and the shared attention output bias. 
-- External residual-interval scripts remain untrusted; model-derived bounds are now available. -- Any downstream error bound provided externally (outside the matrix-payload path). +- Exporters and dataset generators for `.nfpt` model files. +- Any choice of prompts, directions, or candidate heads used by certificate generators. ## Not yet proven +- A verified extraction pipeline from model weights to explicit certificates. - End-to-end claims about GPT-2 logits or Jacobians derived from certificates. -- Sound, verified downstream bounds computed from GPT-2 weights inside Lean. -- A full end-to-end bridge from head certificates to full-model logit bounds - (beyond the head-output + residual-interval composition). +- A full bridge from explicit certificates to complete model semantics (beyond head-level + and residual-interval compositions). diff --git a/Nfp/Gpt2/HeadInputs.lean b/Nfp/Gpt2/HeadInputs.lean deleted file mode 100644 index fb3eb1e..0000000 --- a/Nfp/Gpt2/HeadInputs.lean +++ /dev/null @@ -1,141 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Model.Gpt2 -public import Nfp.Model.InductionHead -public import Nfp.Model.InductionPrompt - -/-! -Untrusted builder for GPT-2 induction head inputs. - -This converts exact GPT-2 head slices into `InductionHeadInputs` using a -periodic prompt description. The construction is purely definitional and is -captured by an explicit theorem, so the trusted core does not hide any logic. --/ - -public section - -namespace Nfp - -namespace Gpt2 - -open Nfp.Model - -/-- -Build induction-head inputs from a GPT-2 head slice and prompt period. - -This uses the unshifted periodic prompt (`prev = q - period`), i.e. it matches -the current token rather than the canonical induction copy target. 
--/ -def buildInductionHeadInputs {seq dModel dHead vocab : Nat} - (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : - Model.InductionHeadInputs seq dModel dHead := - { scale := slice.scale - active := activeOfPeriod (seq := seq) period - prev := prevOfPeriod (seq := seq) period - embed := slice.embed - lnEps := slice.lnEps - ln1Gamma := slice.ln1Gamma - ln1Beta := slice.ln1Beta - wq := slice.wq - bq := slice.bq - wk := slice.wk - bk := slice.bk - wv := slice.wv - bv := slice.bv - wo := slice.wo - attnBias := slice.attnBias - maskCausal := true - maskValue := (-10000 : Rat) - directionSpec := slice.direction.spec - direction := slice.directionVec } - -/-- Definitional characterization of `buildInductionHeadInputs`. -/ -theorem buildInductionHeadInputs_def {seq dModel dHead vocab : Nat} - (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : - buildInductionHeadInputs slice period = - { scale := slice.scale - active := activeOfPeriod (seq := seq) period - prev := prevOfPeriod (seq := seq) period - embed := slice.embed - lnEps := slice.lnEps - ln1Gamma := slice.ln1Gamma - ln1Beta := slice.ln1Beta - wq := slice.wq - bq := slice.bq - wk := slice.wk - bk := slice.bk - wv := slice.wv - bv := slice.bv - wo := slice.wo - attnBias := slice.attnBias - maskCausal := true - maskValue := (-10000 : Rat) - directionSpec := slice.direction.spec - direction := slice.directionVec } := by - simp [buildInductionHeadInputs] - -/-- -Build induction-head inputs using the canonical shifted periodic prompt -(`prev = q - period + 1`, with `0 < period`). When `1 < period`, every active -query has `prev q < q`. 
--/ -def buildInductionHeadInputsShift {seq dModel dHead vocab : Nat} - (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : - Model.InductionHeadInputs seq dModel dHead := - { scale := slice.scale - active := activeOfPeriodShift (seq := seq) period - prev := prevOfPeriodShift (seq := seq) period - embed := slice.embed - lnEps := slice.lnEps - ln1Gamma := slice.ln1Gamma - ln1Beta := slice.ln1Beta - wq := slice.wq - bq := slice.bq - wk := slice.wk - bk := slice.bk - wv := slice.wv - bv := slice.bv - wo := slice.wo - attnBias := slice.attnBias - maskCausal := true - maskValue := (-10000 : Rat) - directionSpec := slice.direction.spec - direction := slice.directionVec } - -/-- Definitional characterization of `buildInductionHeadInputsShift`. -/ -theorem buildInductionHeadInputsShift_def {seq dModel dHead vocab : Nat} - (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : - buildInductionHeadInputsShift slice period = - { scale := slice.scale - active := activeOfPeriodShift (seq := seq) period - prev := prevOfPeriodShift (seq := seq) period - embed := slice.embed - lnEps := slice.lnEps - ln1Gamma := slice.ln1Gamma - ln1Beta := slice.ln1Beta - wq := slice.wq - bq := slice.bq - wk := slice.wk - bk := slice.bk - wv := slice.wv - bv := slice.bv - wo := slice.wo - attnBias := slice.attnBias - maskCausal := true - maskValue := (-10000 : Rat) - directionSpec := slice.direction.spec - direction := slice.directionVec } := by - simp [buildInductionHeadInputsShift] - -/-- `buildInductionHeadInputsShift` satisfies the shifted-period prev/active spec. 
-/ -theorem buildInductionHeadInputsShift_prev_spec {seq dModel dHead vocab : Nat} - (slice : Gpt2HeadSlice seq dModel dHead vocab) (period : Nat) : - InductionPrevSpecPeriodShift (seq := seq) period - (buildInductionHeadInputsShift slice period) := by - constructor <;> simp [buildInductionHeadInputsShift] - -end Gpt2 - -end Nfp diff --git a/Nfp/IO/Loaders.lean b/Nfp/IO/Loaders.lean index 5f1336f..8d1259e 100644 --- a/Nfp/IO/Loaders.lean +++ b/Nfp/IO/Loaders.lean @@ -3,13 +3,12 @@ module public import Nfp.IO.Parse -public import Nfp.Circuit.Cert.LogitDiff public import Nfp.Circuit.Cert.DownstreamLinear public import Nfp.Circuit.Cert.ResidualBound public import Nfp.Circuit.Cert.ResidualInterval /-! -IO loaders for certificates and raw inputs. +IO loaders for certificates. -/ public section @@ -26,12 +25,6 @@ def loadSoftmaxMarginCert (path : System.FilePath) : let data ← IO.FS.readFile path return Parse.parseSoftmaxMarginCert data -/-- Load raw softmax-margin inputs from disk. -/ -def loadSoftmaxMarginRaw (path : System.FilePath) : - IO (Except String (Sigma Parse.SoftmaxMarginRaw)) := do - let data ← IO.FS.readFile path - return Parse.parseSoftmaxMarginRaw data - /-- Load a value-range certificate from disk. -/ def loadValueRangeCert (path : System.FilePath) : IO (Except String (Sigma ValueRangeCert)) := do @@ -44,13 +37,6 @@ def loadDownstreamLinearCert (path : System.FilePath) : let data ← IO.FS.readFile path return Parse.parseDownstreamLinearCert data -/-- Load a downstream matrix payload from disk. -/ -def loadDownstreamMatrixRaw (path : System.FilePath) : - IO (Except String (Sigma (fun rows => - Sigma (fun cols => Parse.DownstreamMatrixRaw rows cols)))) := do - let data ← IO.FS.readFile path - return Parse.parseDownstreamMatrixRaw data - /-- Load a residual-bound certificate from disk. 
-/ def loadResidualBoundCert (path : System.FilePath) : IO (Except String (Sigma (fun n => ResidualBoundCert n))) := do @@ -63,12 +49,6 @@ def loadResidualIntervalCert (path : System.FilePath) : let data ← IO.FS.readFile path return Parse.parseResidualIntervalCert data -/-- Load raw value-range inputs from disk. -/ -def loadValueRangeRaw (path : System.FilePath) : - IO (Except String (Sigma Parse.ValueRangeRaw)) := do - let data ← IO.FS.readFile path - return Parse.parseValueRangeRaw data - end IO end Nfp diff --git a/Nfp/IO/Parse.lean b/Nfp/IO/Parse.lean index 82bd51b..3bc69fa 100644 --- a/Nfp/IO/Parse.lean +++ b/Nfp/IO/Parse.lean @@ -4,7 +4,6 @@ module public import Nfp.IO.Parse.Basic public import Nfp.IO.Parse.Downstream -public import Nfp.IO.Parse.InductionHead public import Nfp.IO.Parse.Residual public import Nfp.IO.Parse.SoftmaxMargin public import Nfp.IO.Parse.ValueRange diff --git a/Nfp/IO/Parse/Downstream.lean b/Nfp/IO/Parse/Downstream.lean index 2711bdc..f950aec 100644 --- a/Nfp/IO/Parse/Downstream.lean +++ b/Nfp/IO/Parse/Downstream.lean @@ -6,7 +6,7 @@ public import Nfp.Circuit.Cert.DownstreamLinear public import Nfp.IO.Parse.Basic /-! -Parse parsing helpers for downstream linear and matrix payloads. +Parse parsing helpers for downstream linear certificates. 
-/ public section @@ -73,132 +73,6 @@ def parseDownstreamLinearCert (input : String) : let st ← tokens.foldlM (fun st t => parseDownstreamLinearLine st t) st0 finalizeDownstreamLinearState st -private def initPrevOpt (n : Nat) : Array (Option (Fin n)) := - Array.replicate n none - -private def initActiveBits (n : Nat) : Array Bool := - Array.replicate n false - -private def activeFromBits {n : Nat} (bits : Array Bool) : Finset (Fin n) := - (Finset.univ : Finset (Fin n)).filter (fun i => bits.getD i.1 false) - -private def arrayAllSome {α : Type} (arr : Array (Option α)) : Bool := - (List.range arr.size).all (fun i => (arr.getD i none).isSome) - -private def matAllSome {α : Type} (mat : Array (Array (Option α))) : Bool := - (List.range mat.size).all (fun i => arrayAllSome (mat.getD i #[])) - -/-- Raw downstream matrix payload with an input bound. -/ -structure DownstreamMatrixRaw (rows cols : Nat) where - /-- Input magnitude bound. -/ - inputBound : Rat - /-- Matrix entries. -/ - entries : Fin rows → Fin cols → Rat - -private structure DownstreamMatrixParseState (rows cols : Nat) where - inputBound : Option Rat - entries : Fin rows → Fin cols → Option Rat - -private def initDownstreamMatrixState (rows cols : Nat) : - DownstreamMatrixParseState rows cols := - { inputBound := none, entries := fun _ _ => none } - -private def setRectEntry {rows cols : Nat} (mat : Fin rows → Fin cols → Option Rat) - (i j : Nat) (v : Rat) : Except String (Fin rows → Fin cols → Option Rat) := do - if hi : i < rows then - if hj : j < cols then - let iFin : Fin rows := ⟨i, hi⟩ - let jFin : Fin cols := ⟨j, hj⟩ - match mat iFin jFin with - | some _ => - throw s!"duplicate matrix entry at ({i}, {j})" - | none => - let mat' : Fin rows → Fin cols → Option Rat := fun i' j' => - if i' = iFin then - if j' = jFin then - some v - else - mat i' j' - else - mat i' j' - return mat' - else - throw s!"index out of range: col={j}" - else - throw s!"index out of range: row={i}" - -private def 
parseDownstreamMatrixLine {rows cols : Nat} - (st : DownstreamMatrixParseState rows cols) (tokens : List String) : - Except String (DownstreamMatrixParseState rows cols) := do - match tokens with - | ["input-bound", val] => - if st.inputBound.isSome then - throw "duplicate input-bound entry" - else - return { st with inputBound := some (← parseRat val) } - | ["w", i, j, val] => - let mat ← setRectEntry st.entries (← parseNat i) (← parseNat j) (← parseRat val) - return { st with entries := mat } - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeDownstreamMatrixState {rows cols : Nat} - (st : DownstreamMatrixParseState rows cols) : - Except String (DownstreamMatrixRaw rows cols) := do - let inputBound ← - match st.inputBound with - | some v => pure v - | none => throw "missing input-bound entry" - if !finsetAll (Finset.univ : Finset (Fin rows)) (fun i => - finsetAll (Finset.univ : Finset (Fin cols)) (fun j => (st.entries i j).isSome)) then - throw "missing matrix entries" - let entries : Fin rows → Fin cols → Rat := fun i j => - (st.entries i j).getD 0 - return { inputBound := inputBound, entries := entries } - -/-- Parse a downstream matrix payload from text. -/ -def parseDownstreamMatrixRaw (input : String) : - Except String (Sigma (fun rows => Sigma (fun cols => DownstreamMatrixRaw rows cols))) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let mut rows? : Option Nat := none - let mut cols? : Option Nat := none - for t in tokens do - match t with - | ["rows", n] => - if rows?.isSome then - throw "duplicate rows entry" - else - rows? := some (← parseNat n) - | ["cols", n] => - if cols?.isSome then - throw "duplicate cols entry" - else - cols? := some (← parseNat n) - | _ => pure () - let rows ← - match rows? with - | some v => pure v - | none => throw "missing rows entry" - let cols ← - match cols? 
with - | some v => pure v - | none => throw "missing cols entry" - match rows, cols with - | 0, _ => throw "rows must be positive" - | _, 0 => throw "cols must be positive" - | Nat.succ r, Nat.succ c => - let rows := Nat.succ r - let cols := Nat.succ c - let st0 := initDownstreamMatrixState rows cols - let st ← tokens.foldlM (fun st t => - match t with - | ["rows", _] => pure st - | ["cols", _] => pure st - | _ => parseDownstreamMatrixLine st t) st0 - let raw ← finalizeDownstreamMatrixState st - return ⟨rows, ⟨cols, raw⟩⟩ - end Parse end IO diff --git a/Nfp/IO/Parse/InductionHead.lean b/Nfp/IO/Parse/InductionHead.lean deleted file mode 100644 index f6bc47f..0000000 --- a/Nfp/IO/Parse/InductionHead.lean +++ /dev/null @@ -1,29 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.IO.Parse.InductionHead.Bytes - -/-! -Parsing helpers for induction-head input payloads. --/ - -public section - -namespace Nfp - -namespace IO - -namespace Parse - -/-- Parse a raw induction head input payload from text. -/ -def parseInductionHeadInputs (input : String) : - Except String (Sigma (fun seq => - Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead)))) := do - parseInductionHeadInputsBytes input.toUTF8 - -end Parse - -end IO - -end Nfp diff --git a/Nfp/IO/Parse/InductionHead/Bytes.lean b/Nfp/IO/Parse/InductionHead/Bytes.lean deleted file mode 100644 index 8b18d55..0000000 --- a/Nfp/IO/Parse/InductionHead/Bytes.lean +++ /dev/null @@ -1,790 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Data.Finset.Insert -public import Nfp.IO.Parse.Basic -public import Nfp.Model.InductionHead - -/-! -Parsing helpers for induction-head input payloads from UTF-8 bytes. 
--/ - -public section - -namespace Nfp - -namespace IO - -namespace Parse - -private def kwSeq : ByteArray := "seq".toUTF8 -private def kwDModel : ByteArray := "d_model".toUTF8 -private def kwDHead : ByteArray := "d_head".toUTF8 -private def kwScale : ByteArray := "scale".toUTF8 -private def kwActive : ByteArray := "active".toUTF8 -private def kwPrev : ByteArray := "prev".toUTF8 -private def kwEmbed : ByteArray := "embed".toUTF8 -private def kwLnEps : ByteArray := "ln_eps".toUTF8 -private def kwLn1Gamma : ByteArray := "ln1_gamma".toUTF8 -private def kwLn1Beta : ByteArray := "ln1_beta".toUTF8 -private def kwAttnBias : ByteArray := "attn_bias".toUTF8 -private def kwMask : ByteArray := "mask".toUTF8 -private def kwMaskValue : ByteArray := "mask_value".toUTF8 -private def kwCausal : ByteArray := "causal".toUTF8 -private def kwNone : ByteArray := "none".toUTF8 -private def kwDirection : ByteArray := "direction".toUTF8 -private def kwDirectionTarget : ByteArray := "direction-target".toUTF8 -private def kwDirectionNegative : ByteArray := "direction-negative".toUTF8 - -private structure ByteToken where - start : Nat - stop : Nat - -private def tokenLen (t : ByteToken) : Nat := - t.stop - t.start - -private def tokenEq (data : ByteArray) (t : ByteToken) (kw : ByteArray) : Bool := Id.run do - if tokenLen t != kw.size then - return false - let mut i := 0 - while i < kw.size do - if data.get! (t.start + i) != kw.get! i then - return false - i := i + 1 - return true - -private def parseNatBytesCore (data : ByteArray) (i stop : Nat) (acc : Nat) : - Except String Nat := - if h : i < stop then - let b := data.get! 
i - if b >= 48 && b <= 57 then - parseNatBytesCore data (i + 1) stop (acc * 10 + (b.toNat - 48)) - else - Except.error "expected Nat" - else - Except.ok acc -termination_by stop - i - -private def parseNatBytesSpec (data : ByteArray) (t : ByteToken) : Except String Nat := - if tokenLen t = 0 then - throw "expected Nat" - else - parseNatBytesCore data t.start t.stop 0 - -private def parseNatBytes (data : ByteArray) (t : ByteToken) : Except String Nat := - parseNatBytesSpec data t - -private theorem parseNatBytes_eq_spec (data : ByteArray) (t : ByteToken) : - parseNatBytes data t = parseNatBytesSpec data t := by - rfl - -private def parseIntBytesSpec (data : ByteArray) (t : ByteToken) : Except String Int := do - if tokenLen t = 0 then - throw "expected Int" - let first := data.get! t.start - if first = 45 then - let t' : ByteToken := { start := t.start + 1, stop := t.stop } - let n ← parseNatBytesSpec data t' - return -Int.ofNat n - else - let n ← parseNatBytesSpec data t - return Int.ofNat n - -private def parseIntBytes (data : ByteArray) (t : ByteToken) : Except String Int := - parseIntBytesSpec data t - -private theorem parseIntBytes_eq_spec (data : ByteArray) (t : ByteToken) : - parseIntBytes data t = parseIntBytesSpec data t := by - rfl - -private def findSlash (data : ByteArray) (i stop : Nat) : Option Nat := - if h : i < stop then - if data.get! 
i = 47 then - some i - else - findSlash data (i + 1) stop - else - none -termination_by stop - i - -private def parseRatBytesSpec (data : ByteArray) (t : ByteToken) : Except String Rat := do - match findSlash data t.start t.stop with - | none => - return ratRoundDown (Rat.ofInt (← parseIntBytesSpec data t)) - | some s => - let numTok : ByteToken := { start := t.start, stop := s } - let denTok : ByteToken := { start := s + 1, stop := t.stop } - let n ← parseIntBytesSpec data numTok - let d ← parseNatBytesSpec data denTok - if d = 0 then - throw "invalid rational: zero denominator" - else - return ratRoundDown (Rat.divInt n (Int.ofNat d)) - -private def parseRatBytes (data : ByteArray) (t : ByteToken) : Except String Rat := - parseRatBytesSpec data t - -private theorem parseRatBytes_eq_spec (data : ByteArray) (t : ByteToken) : - parseRatBytes data t = parseRatBytesSpec data t := by - rfl - -private def nextLineBounds (data : ByteArray) (start : Nat) : Nat × Nat × Nat := - Id.run do - let mut i := start - let lineStart := start - while i < data.size do - let b := data.get! i - if b == 10 || b == 13 then - let lineEnd := i - let mut j := i + 1 - if b == 13 && j < data.size && data.get! j == 10 then - j := j + 1 - return (j, lineStart, lineEnd) - i := i + 1 - return (data.size, lineStart, data.size) - -private def skipSpaces (data : ByteArray) (i lineEnd : Nat) : Nat := - Id.run do - let mut j := i - while j < lineEnd do - let b := data.get! j - if b == 32 || b == 9 then - j := j + 1 - else - break - return j - -private def readToken (data : ByteArray) (i lineEnd : Nat) : - Option (ByteToken × Nat) := - Id.run do - let j := skipSpaces data i lineEnd - if j >= lineEnd then - return none - let start := j - let mut k := j - while k < lineEnd do - let b := data.get! 
k - if b == 32 || b == 9 then - break - k := k + 1 - return some ({ start := start, stop := k }, k) - -private def expectToken (data : ByteArray) (i lineEnd : Nat) : - Except String (ByteToken × Nat) := do - match readToken data i lineEnd with - | some out => return out - | none => throw "expected token" - -private def ensureNoMoreTokens (data : ByteArray) (i lineEnd : Nat) : - Except String Unit := do - let j := skipSpaces data i lineEnd - if j < lineEnd then - throw "unrecognized line" - -private def parseNatAt (data : ByteArray) (i lineEnd : Nat) : - Except String (Nat × Nat) := do - let (tok, i') ← expectToken data i lineEnd - let n ← parseNatBytes data tok - return (n, i') - -private def parseRatAt (data : ByteArray) (i lineEnd : Nat) : - Except String (Rat × Nat) := do - let (tok, i') ← expectToken data i lineEnd - let r ← parseRatBytes data tok - return (r, i') - -private def setVecEntry (n : Nat) (vec : Array (Option Rat)) - (i : Nat) (v : Rat) : - Except String (Array (Option Rat)) := do - if i < n then - match vec.getD i none with - | some _ => - throw s!"duplicate entry for index={i}" - | none => - let vec' := vec.set! i (some v) - return vec' - else - throw s!"index out of range: i={i}" - -private def setMatEntry (rows cols : Nat) (mat : Array (Array (Option Rat))) - (i j : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do - if i < rows then - if j < cols then - let row := mat.getD i #[] - match row.getD j none with - | some _ => - throw s!"duplicate entry for index=({i}, {j})" - | none => - let row' := row.set! j (some v) - let mat' := mat.set! 
i row' - return mat' - else - throw s!"index out of range: j={j}" - else - throw s!"index out of range: i={i}" - -private def initVecOpt (n : Nat) : Array (Option Rat) := - Array.replicate n none - -private def initMatOpt (rows cols : Nat) : Array (Array (Option Rat)) := - Array.replicate rows (initVecOpt cols) - -private def initPrevOpt (n : Nat) : Array (Option (Fin n)) := - Array.replicate n none - -private def initActiveBits (n : Nat) : Array Bool := - Array.replicate n false - -private def activeFromBits {n : Nat} (bits : Array Bool) : Finset (Fin n) := - (Finset.univ : Finset (Fin n)).filter (fun i => bits.getD i.1 false) - -private def arrayAllSome {α : Type} (arr : Array (Option α)) : Bool := - (List.range arr.size).all (fun i => (arr.getD i none).isSome) - -private def matAllSome {α : Type} (mat : Array (Array (Option α))) : Bool := - (List.range mat.size).all (fun i => arrayAllSome (mat.getD i #[])) - -private structure HeadParseState (seq dModel dHead : Nat) where - scale : Option Rat - activeBits : Array Bool - activeSeen : Bool - prev : Array (Option (Fin seq)) - embed : Array (Array (Option Rat)) - lnEps : Option Rat - ln1Gamma : Array (Option Rat) - ln1Beta : Array (Option Rat) - wq : Array (Array (Option Rat)) - bq : Array (Option Rat) - wk : Array (Array (Option Rat)) - bk : Array (Option Rat) - wv : Array (Array (Option Rat)) - bv : Array (Option Rat) - wo : Array (Array (Option Rat)) - attnBias : Array (Option Rat) - maskCausal : Option Bool - maskValue : Option Rat - directionTarget : Option Nat - directionNegative : Option Nat - direction : Array (Option Rat) - -private def initHeadState (seq dModel dHead : Nat) : HeadParseState seq dModel dHead := - { scale := none - activeBits := initActiveBits seq - activeSeen := false - prev := initPrevOpt seq - embed := initMatOpt seq dModel - lnEps := none - ln1Gamma := initVecOpt dModel - ln1Beta := initVecOpt dModel - wq := initMatOpt dModel dHead - bq := initVecOpt dHead - wk := initMatOpt dModel dHead 
- bk := initVecOpt dHead - wv := initMatOpt dModel dHead - bv := initVecOpt dHead - wo := initMatOpt dModel dHead - attnBias := initVecOpt dModel - maskCausal := none - maskValue := none - directionTarget := none - directionNegative := none - direction := initVecOpt dModel } - -private def setHeadActive {seq dModel dHead : Nat} - (st : HeadParseState seq dModel dHead) (q : Nat) : - Except String (HeadParseState seq dModel dHead) := do - if q < seq then - return { st with activeBits := st.activeBits.set! q true, activeSeen := true } - else - throw s!"active index out of range: q={q}" - -private def setHeadPrev {seq dModel dHead : Nat} - (st : HeadParseState seq dModel dHead) (q k : Nat) : - Except String (HeadParseState seq dModel dHead) := do - if q < seq then - if hk : k < seq then - let kFin : Fin seq := ⟨k, hk⟩ - match st.prev.getD q none with - | some _ => - throw s!"duplicate prev entry for q={q}" - | none => - return { st with prev := st.prev.set! q (some kFin) } - else - throw s!"prev index out of range: k={k}" - else - throw s!"prev index out of range: q={q}" - -private def parseHeadLine {seq dModel dHead : Nat} (st : HeadParseState seq dModel dHead) - (tokens : List String) : Except String (HeadParseState seq dModel dHead) := do - match tokens with - | ["scale", val] => - if st.scale.isSome then - throw "duplicate scale entry" - else - return { st with scale := some (← parseRat val) } - | ["active", q] => - setHeadActive st (← parseNat q) - | ["prev", q, k] => - setHeadPrev st (← parseNat q) (← parseNat k) - | ["embed", q, d, val] => - let mat ← - setMatEntry seq dModel st.embed (← parseNat q) (← parseNat d) (← parseRat val) - return { st with embed := mat } - | ["ln_eps", val] => - if st.lnEps.isSome then - throw "duplicate ln_eps entry" - else - return { st with lnEps := some (← parseRat val) } - | ["ln1_gamma", d, val] => - let vec ← setVecEntry dModel st.ln1Gamma (← parseNat d) (← parseRat val) - return { st with ln1Gamma := vec } - | ["ln1_beta", d, 
val] => - let vec ← setVecEntry dModel st.ln1Beta (← parseNat d) (← parseRat val) - return { st with ln1Beta := vec } - | ["wq", i, j, val] => - let mat ← - setMatEntry dModel dHead st.wq (← parseNat i) (← parseNat j) (← parseRat val) - return { st with wq := mat } - | ["bq", j, val] => - let vec ← setVecEntry dHead st.bq (← parseNat j) (← parseRat val) - return { st with bq := vec } - | ["wk", i, j, val] => - let mat ← - setMatEntry dModel dHead st.wk (← parseNat i) (← parseNat j) (← parseRat val) - return { st with wk := mat } - | ["bk", j, val] => - let vec ← setVecEntry dHead st.bk (← parseNat j) (← parseRat val) - return { st with bk := vec } - | ["wv", i, j, val] => - let mat ← - setMatEntry dModel dHead st.wv (← parseNat i) (← parseNat j) (← parseRat val) - return { st with wv := mat } - | ["bv", j, val] => - let vec ← setVecEntry dHead st.bv (← parseNat j) (← parseRat val) - return { st with bv := vec } - | ["wo", i, j, val] => - let mat ← - setMatEntry dModel dHead st.wo (← parseNat i) (← parseNat j) (← parseRat val) - return { st with wo := mat } - | ["attn_bias", d, val] => - let vec ← setVecEntry dModel st.attnBias (← parseNat d) (← parseRat val) - return { st with attnBias := vec } - | ["mask", kind] => - if st.maskCausal.isSome then - throw "duplicate mask entry" - else - match kind with - | "causal" => return { st with maskCausal := some true } - | "none" => return { st with maskCausal := some false } - | _ => throw "mask must be 'causal' or 'none'" - | ["mask_value", val] => - if st.maskValue.isSome then - throw "duplicate mask_value entry" - else - return { st with maskValue := some (← parseRat val) } - | ["direction-target", tok] => - if st.directionTarget.isSome then - throw "duplicate direction-target entry" - else - return { st with directionTarget := some (← parseNat tok) } - | ["direction-negative", tok] => - if st.directionNegative.isSome then - throw "duplicate direction-negative entry" - else - return { st with directionNegative := some (← 
parseNat tok) } - | ["direction", d, val] => - let vec ← setVecEntry dModel st.direction (← parseNat d) (← parseRat val) - return { st with direction := vec } - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def parseHeadLineBytes {seq dModel dHead : Nat} (data : ByteArray) - (st : HeadParseState seq dModel dHead) (lineStart lineEnd : Nat) : - Except String (HeadParseState seq dModel dHead) := do - let i0 := skipSpaces data lineStart lineEnd - if i0 >= lineEnd then - return st - if data.get! i0 = 35 then - return st - match readToken data i0 lineEnd with - | none => return st - | some (t0, i1) => - let len := tokenLen t0 - let b0 := data.get! t0.start - match b0 with - | 115 => -- s - if len = kwSeq.size && tokenEq data t0 kwSeq then - return st - else if len = kwScale.size && tokenEq data t0 kwScale then - if st.scale.isSome then - throw "duplicate scale entry" - else - let (t1, i2) ← expectToken data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - return { st with scale := some (← parseRatBytes data t1) } - else - throw "unrecognized line" - | 97 => -- a - if len = kwActive.size && tokenEq data t0 kwActive then - let (q, i2) ← parseNatAt data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - setHeadActive st q - else if len = kwAttnBias.size && tokenEq data t0 kwAttnBias then - let (d, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseRatAt data i2 lineEnd - ensureNoMoreTokens data i3 lineEnd - let vec ← setVecEntry dModel st.attnBias d v - return { st with attnBias := vec } - else - throw "unrecognized line" - | 112 => -- p - if len = kwPrev.size && tokenEq data t0 kwPrev then - let (q, i2) ← parseNatAt data i1 lineEnd - let (k, i3) ← parseNatAt data i2 lineEnd - ensureNoMoreTokens data i3 lineEnd - setHeadPrev st q k - else - throw "unrecognized line" - | 101 => -- e - if len = kwEmbed.size && tokenEq data t0 kwEmbed then - let (q, i2) ← parseNatAt data i1 lineEnd - let (d, i3) ← parseNatAt data i2 lineEnd - let (v, i4) 
← parseRatAt data i3 lineEnd - ensureNoMoreTokens data i4 lineEnd - let mat ← setMatEntry seq dModel st.embed q d v - return { st with embed := mat } - else - throw "unrecognized line" - | 108 => -- l - if len = kwLnEps.size && tokenEq data t0 kwLnEps then - if st.lnEps.isSome then - throw "duplicate ln_eps entry" - else - let (v, i2) ← parseRatAt data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - return { st with lnEps := some v } - else if len = kwLn1Gamma.size && tokenEq data t0 kwLn1Gamma then - let (d, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseRatAt data i2 lineEnd - ensureNoMoreTokens data i3 lineEnd - let vec ← setVecEntry dModel st.ln1Gamma d v - return { st with ln1Gamma := vec } - else if len = kwLn1Beta.size && tokenEq data t0 kwLn1Beta then - let (d, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseRatAt data i2 lineEnd - ensureNoMoreTokens data i3 lineEnd - let vec ← setVecEntry dModel st.ln1Beta d v - return { st with ln1Beta := vec } - else - throw "unrecognized line" - | 119 => -- w - if len = 2 then - let b1 := data.get! (t0.start + 1) - let (i, i2) ← parseNatAt data i1 lineEnd - let (j, i3) ← parseNatAt data i2 lineEnd - let (v, i4) ← parseRatAt data i3 lineEnd - ensureNoMoreTokens data i4 lineEnd - if b1 = 113 then - let mat ← setMatEntry dModel dHead st.wq i j v - return { st with wq := mat } - else if b1 = 107 then - let mat ← setMatEntry dModel dHead st.wk i j v - return { st with wk := mat } - else if b1 = 118 then - let mat ← setMatEntry dModel dHead st.wv i j v - return { st with wv := mat } - else if b1 = 111 then - let mat ← setMatEntry dModel dHead st.wo i j v - return { st with wo := mat } - else - throw "unrecognized line" - else - throw "unrecognized line" - | 98 => -- b - if len = 2 then - let b1 := data.get! 
(t0.start + 1) - let (j, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseRatAt data i2 lineEnd - ensureNoMoreTokens data i3 lineEnd - if b1 = 113 then - let vec ← setVecEntry dHead st.bq j v - return { st with bq := vec } - else if b1 = 107 then - let vec ← setVecEntry dHead st.bk j v - return { st with bk := vec } - else if b1 = 118 then - let vec ← setVecEntry dHead st.bv j v - return { st with bv := vec } - else - throw "unrecognized line" - else - throw "unrecognized line" - | 109 => -- m - if len = kwMask.size && tokenEq data t0 kwMask then - if st.maskCausal.isSome then - throw "duplicate mask entry" - else - let (t1, i2) ← expectToken data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - if tokenEq data t1 kwCausal then - return { st with maskCausal := some true } - else if tokenEq data t1 kwNone then - return { st with maskCausal := some false } - else - throw "mask must be 'causal' or 'none'" - else if len = kwMaskValue.size && tokenEq data t0 kwMaskValue then - if st.maskValue.isSome then - throw "duplicate mask_value entry" - else - let (v, i2) ← parseRatAt data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - return { st with maskValue := some v } - else - throw "unrecognized line" - | 100 => -- d - if len = kwDModel.size && tokenEq data t0 kwDModel then - return st - else if len = kwDHead.size && tokenEq data t0 kwDHead then - return st - else if len = kwDirection.size && tokenEq data t0 kwDirection then - let (d, i2) ← parseNatAt data i1 lineEnd - let (v, i3) ← parseRatAt data i2 lineEnd - ensureNoMoreTokens data i3 lineEnd - let vec ← setVecEntry dModel st.direction d v - return { st with direction := vec } - else if len = kwDirectionTarget.size && tokenEq data t0 kwDirectionTarget then - if st.directionTarget.isSome then - throw "duplicate direction-target entry" - else - let (v, i2) ← parseNatAt data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - return { st with directionTarget := some v } - else if len = kwDirectionNegative.size && 
tokenEq data t0 kwDirectionNegative then - if st.directionNegative.isSome then - throw "duplicate direction-negative entry" - else - let (v, i2) ← parseNatAt data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - return { st with directionNegative := some v } - else - throw "unrecognized line" - | _ => - throw "unrecognized line" - -private def parseHeaderLineBytes (data : ByteArray) (lineStart lineEnd : Nat) - (seq? dModel? dHead? : Option Nat) : - Except String (Option Nat × Option Nat × Option Nat) := do - let i0 := skipSpaces data lineStart lineEnd - if i0 >= lineEnd then - return (seq?, dModel?, dHead?) - if data.get! i0 = 35 then - return (seq?, dModel?, dHead?) - match readToken data i0 lineEnd with - | none => return (seq?, dModel?, dHead?) - | some (t0, i1) => - if tokenEq data t0 kwSeq then - let (v, i2) ← parseNatAt data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - if seq?.isSome then - throw "duplicate seq entry" - else - return (some v, dModel?, dHead?) - else if tokenEq data t0 kwDModel then - let (v, i2) ← parseNatAt data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - if dModel?.isSome then - throw "duplicate d_model entry" - else - return (seq?, some v, dHead?) - else if tokenEq data t0 kwDHead then - let (v, i2) ← parseNatAt data i1 lineEnd - ensureNoMoreTokens data i2 lineEnd - if dHead?.isSome then - throw "duplicate d_head entry" - else - return (seq?, dModel?, some v) - else - return (seq?, dModel?, dHead?) 
- -private def finalizeHeadState {seq dModel dHead : Nat} (hpos : 0 < seq) - (st : HeadParseState seq dModel dHead) : - Except String (Model.InductionHeadInputs seq dModel dHead) := do - let scale ← - match st.scale with - | some v => pure v - | none => throw "missing scale entry" - if !arrayAllSome st.prev then - throw "missing prev entries" - if !matAllSome st.embed then - throw "missing embed entries" - let lnEps ← - match st.lnEps with - | some v => pure v - | none => throw "missing ln_eps entry" - if !arrayAllSome st.ln1Gamma then - throw "missing ln1_gamma entries" - if !arrayAllSome st.ln1Beta then - throw "missing ln1_beta entries" - if !matAllSome st.wq then - throw "missing wq entries" - if !arrayAllSome st.bq then - throw "missing bq entries" - if !matAllSome st.wk then - throw "missing wk entries" - if !arrayAllSome st.bk then - throw "missing bk entries" - if !matAllSome st.wv then - throw "missing wv entries" - if !arrayAllSome st.bv then - throw "missing bv entries" - if !matAllSome st.wo then - throw "missing wo entries" - if !arrayAllSome st.attnBias then - throw "missing attn_bias entries" - if !arrayAllSome st.direction then - throw "missing direction entries" - let directionSpec ← - match st.directionTarget, st.directionNegative with - | some target, some negative => pure { target := target, negative := negative } - | _, _ => - throw "direction metadata requires both direction-target and direction-negative" - let defaultPrev : Fin seq := ⟨0, hpos⟩ - let prevFun : Fin seq → Fin seq := fun q => - (st.prev.getD q.1 none).getD defaultPrev - let embedArr : Array (Array Rat) := - st.embed.map (fun row => row.map (fun v => v.getD 0)) - let ln1GammaArr : Array Rat := - st.ln1Gamma.map (fun v => v.getD 0) - let ln1BetaArr : Array Rat := - st.ln1Beta.map (fun v => v.getD 0) - let wqArr : Array (Array Rat) := - st.wq.map (fun row => row.map (fun v => v.getD 0)) - let bqArr : Array Rat := - st.bq.map (fun v => v.getD 0) - let wkArr : Array (Array Rat) := - 
st.wk.map (fun row => row.map (fun v => v.getD 0)) - let bkArr : Array Rat := - st.bk.map (fun v => v.getD 0) - let wvArr : Array (Array Rat) := - st.wv.map (fun row => row.map (fun v => v.getD 0)) - let bvArr : Array Rat := - st.bv.map (fun v => v.getD 0) - let woArr : Array (Array Rat) := - st.wo.map (fun row => row.map (fun v => v.getD 0)) - let attnBiasArr : Array Rat := - st.attnBias.map (fun v => v.getD 0) - let directionArr : Array Rat := - st.direction.map (fun v => v.getD 0) - let embedFun : Fin seq → Fin dModel → Rat := fun q d => - (embedArr.getD q.1 #[]).getD d.1 0 - let ln1GammaFun : Fin dModel → Rat := fun d => - ln1GammaArr.getD d.1 0 - let ln1BetaFun : Fin dModel → Rat := fun d => - ln1BetaArr.getD d.1 0 - let wqFun : Fin dModel → Fin dHead → Rat := fun i j => - (wqArr.getD i.1 #[]).getD j.1 0 - let bqFun : Fin dHead → Rat := fun j => - bqArr.getD j.1 0 - let wkFun : Fin dModel → Fin dHead → Rat := fun i j => - (wkArr.getD i.1 #[]).getD j.1 0 - let bkFun : Fin dHead → Rat := fun j => - bkArr.getD j.1 0 - let wvFun : Fin dModel → Fin dHead → Rat := fun i j => - (wvArr.getD i.1 #[]).getD j.1 0 - let bvFun : Fin dHead → Rat := fun j => - bvArr.getD j.1 0 - let woFun : Fin dModel → Fin dHead → Rat := fun i j => - (woArr.getD i.1 #[]).getD j.1 0 - let attnBiasFun : Fin dModel → Rat := fun d => - attnBiasArr.getD d.1 0 - let maskCausal := st.maskCausal.getD false - let maskValue := - match st.maskValue with - | some v => v - | none => if maskCausal then (-10000 : Rat) else 0 - let directionFun : Fin dModel → Rat := fun d => - directionArr.getD d.1 0 - let active := - if st.activeSeen then - activeFromBits st.activeBits - else - (Finset.univ : Finset (Fin seq)).erase defaultPrev - pure - { scale := scale - active := active - prev := prevFun - embed := embedFun - lnEps := lnEps - ln1Gamma := ln1GammaFun - ln1Beta := ln1BetaFun - wq := wqFun - bq := bqFun - wk := wkFun - bk := bkFun - wv := wvFun - bv := bvFun - wo := woFun - attnBias := attnBiasFun - 
maskCausal := maskCausal - maskValue := maskValue - directionSpec := directionSpec - direction := directionFun } - -/-- Parse a raw induction head input payload from UTF-8 bytes. -/ -def parseInductionHeadInputsBytes (data : ByteArray) : - Except String (Sigma (fun seq => - Sigma (fun dModel => Sigma (fun dHead => Model.InductionHeadInputs seq dModel dHead)))) := do - let mut seq? : Option Nat := none - let mut dModel? : Option Nat := none - let mut dHead? : Option Nat := none - let mut i := 0 - let mut afterDims := 0 - let mut haveDims := false - while i < data.size && !haveDims do - let (i', lineStart, lineEnd) := nextLineBounds data i - i := i' - afterDims := i' - let (seqNew, dModelNew, dHeadNew) ← - parseHeaderLineBytes data lineStart lineEnd seq? dModel? dHead? - seq? := seqNew - dModel? := dModelNew - dHead? := dHeadNew - if seq?.isSome && dModel?.isSome && dHead?.isSome then - haveDims := true - let seq ← - match seq? with - | some v => pure v - | none => throw "missing seq entry" - let dModel ← - match dModel? with - | some v => pure v - | none => throw "missing d_model entry" - let dHead ← - match dHead? 
with - | some v => pure v - | none => throw "missing d_head entry" - match seq, dModel, dHead with - | 0, _, _ => throw "seq must be positive" - | _, 0, _ => throw "d_model must be positive" - | _, _, 0 => throw "d_head must be positive" - | Nat.succ n, Nat.succ m, Nat.succ h => - let seq := Nat.succ n - let dModel := Nat.succ m - let dHead := Nat.succ h - let hpos : 0 < seq := Nat.succ_pos n - let st0 : HeadParseState seq dModel dHead := initHeadState seq dModel dHead - let mut st := st0 - let mut j := afterDims - while j < data.size do - let (j', lineStart, lineEnd) := nextLineBounds data j - j := j' - st ← parseHeadLineBytes data st lineStart lineEnd - let inputs ← finalizeHeadState hpos st - return ⟨seq, ⟨dModel, ⟨dHead, inputs⟩⟩⟩ - -end Parse - -end IO - -end Nfp diff --git a/Nfp/IO/Parse/SoftmaxMargin.lean b/Nfp/IO/Parse/SoftmaxMargin.lean index 648a6e9..0d504e5 100644 --- a/Nfp/IO/Parse/SoftmaxMargin.lean +++ b/Nfp/IO/Parse/SoftmaxMargin.lean @@ -3,7 +3,6 @@ module public import Nfp.IO.Parse.SoftmaxMargin.Cert -public import Nfp.IO.Parse.SoftmaxMargin.Raw /-! Aggregator for softmax-margin parsing helpers. diff --git a/Nfp/IO/Parse/SoftmaxMargin/Raw.lean b/Nfp/IO/Parse/SoftmaxMargin/Raw.lean deleted file mode 100644 index 8bfeb3e..0000000 --- a/Nfp/IO/Parse/SoftmaxMargin/Raw.lean +++ /dev/null @@ -1,84 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Circuit.Cert.SoftmaxMargin -public import Nfp.IO.Parse.SoftmaxMargin.Shared - -/-! -Parse parsing helpers for raw softmax-margin inputs. --/ - -public section - -namespace Nfp - -namespace IO - -namespace Parse - -open Nfp.Circuit - -/-- Raw softmax-margin payload without `eps`/`margin`. -/ -structure SoftmaxMarginRaw (seq : Nat) where - /-- Active queries for which bounds are required. -/ - active : Finset (Fin seq) - /-- `prev` selector for induction-style attention. -/ - prev : Fin seq → Fin seq - /-- Score matrix entries. 
-/ - scores : Fin seq → Fin seq → Rat - /-- Attention weight entries. -/ - weights : Fin seq → Fin seq → Rat - -private def finalizeRawState {seq : Nat} (hpos : 0 < seq) - (st : SoftmaxMargin.ParseState seq) : Except String (SoftmaxMarginRaw seq) := do - if !st.prev.all Option.isSome then - throw "missing prev entries" - if !st.scores.all (fun row => row.all Option.isSome) then - throw "missing score entries" - if !st.weights.all (fun row => row.all Option.isSome) then - throw "missing weight entries" - let defaultPrev : Fin seq := ⟨0, hpos⟩ - let prevFun : Fin seq → Fin seq := fun q => - (st.prev[q.1]!).getD defaultPrev - let scoresFun : Fin seq → Fin seq → Rat := fun q k => - let row := st.scores[q.1]! - (row[k.1]!).getD 0 - let weightsFun : Fin seq → Fin seq → Rat := fun q k => - let row := st.weights[q.1]! - (row[k.1]!).getD 0 - let active := - if st.activeSeen then - st.active - else - (Finset.univ : Finset (Fin seq)).erase defaultPrev - pure - { active := active - prev := prevFun - scores := scoresFun - weights := weightsFun } - -/-- Parse a raw softmax-margin payload from text (ignores any `eps`/`margin`). 
-/ -def parseSoftmaxMarginRaw (input : String) : - Except String (Sigma SoftmaxMarginRaw) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let seq ← SoftmaxMargin.parseSeq tokens - match seq with - | 0 => throw "seq must be positive" - | Nat.succ n => - let seq := Nat.succ n - let hpos : 0 < seq := Nat.succ_pos n - let st0 : SoftmaxMargin.ParseState seq := SoftmaxMargin.initState seq - let st ← tokens.foldlM (fun st t => - match t with - | ["seq", _] => pure st - | _ => SoftmaxMargin.parseLine st t) st0 - let raw ← finalizeRawState hpos st - return ⟨seq, raw⟩ - -end Parse - -end IO - -end Nfp diff --git a/Nfp/IO/Parse/ValueRange.lean b/Nfp/IO/Parse/ValueRange.lean index 102943f..6d706a1 100644 --- a/Nfp/IO/Parse/ValueRange.lean +++ b/Nfp/IO/Parse/ValueRange.lean @@ -3,7 +3,6 @@ module public import Nfp.IO.Parse.ValueRange.Cert -public import Nfp.IO.Parse.ValueRange.Raw /-! Aggregator for value-range parsing helpers. diff --git a/Nfp/IO/Parse/ValueRange/Raw.lean b/Nfp/IO/Parse/ValueRange/Raw.lean deleted file mode 100644 index b6411c2..0000000 --- a/Nfp/IO/Parse/ValueRange/Raw.lean +++ /dev/null @@ -1,66 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Circuit.Cert.ValueRange -public import Nfp.IO.Parse.ValueRange.Shared - -/-! -Parse parsing helpers for raw value-range inputs. --/ - -public section - -namespace Nfp - -namespace IO - -namespace Parse - -open Nfp.Circuit - -/-- Raw value-range payload without `lo`/`hi` bounds. -/ -structure ValueRangeRaw (seq : Nat) where - /-- Value entries. -/ - vals : Fin seq → Rat - /-- Optional logit-diff direction metadata. 
-/ - direction : Option Circuit.DirectionSpec - -private def finalizeValueRawState {seq : Nat} (st : ValueRange.ParseState seq) : - Except String (ValueRangeRaw seq) := do - if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then - throw "missing value entries" - let valsFun : Fin seq → Rat := fun k => - (st.vals k).getD 0 - let direction ← - match st.directionTarget, st.directionNegative with - | none, none => pure none - | some target, some negative => - pure (some { target := target, negative := negative }) - | _, _ => - throw "direction metadata requires both direction-target and direction-negative" - return { vals := valsFun, direction := direction } - -/-- Parse a raw value-range payload from text (ignores any `lo`/`hi`). -/ -def parseValueRangeRaw (input : String) : - Except String (Sigma ValueRangeRaw) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let seq ← ValueRange.parseSeq tokens - match seq with - | 0 => throw "seq must be positive" - | Nat.succ n => - let seq := Nat.succ n - let st0 : ValueRange.ParseState seq := ValueRange.initState seq - let st ← tokens.foldlM (fun st t => - match t with - | ["seq", _] => pure st - | _ => ValueRange.parseLine st t) st0 - let raw ← finalizeValueRawState st - return ⟨seq, raw⟩ - -end Parse - -end IO - -end Nfp diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean index 9efb892..61d23a7 100644 --- a/Nfp/Model/Gpt2.lean +++ b/Nfp/Model/Gpt2.lean @@ -9,7 +9,7 @@ public import Nfp.Circuit.Cert.ValueRange Exact GPT-2 slices for induction certification and downstream bounds. This module holds token embeddings, head projection weights, and per-layer -MLP/LayerNorm parameters needed to build `InductionHeadInputs` and downstream +MLP/LayerNorm parameters used to define `InductionHeadInputs` and downstream bound computations. 
-/ diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean index 26a7f9d..a535626 100644 --- a/Nfp/Sound.lean +++ b/Nfp/Sound.lean @@ -5,5 +5,5 @@ module public import Nfp.Sound.Induction /-! -Sound certificate builders and verified helpers. +Soundness theorems and verified helpers. -/ diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index ec7294e..e620385 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -15,7 +15,7 @@ public import Nfp.Linear.FinFold /-! Core definitions for induction-head certificates. -These definitions are shared across induction certificate builders and checkers. +These definitions are shared across induction certificate checkers and proofs. -/ public section diff --git a/README.md b/README.md index 70fcdfc..1763292 100644 --- a/README.md +++ b/README.md @@ -36,236 +36,81 @@ The authoritative module map and invariants are tracked in `AGENTS.md`. High-level layout: - `Nfp/Core`, `Nfp/Prob`, `Nfp/Mixer`, `Nfp/System`: core math infrastructure. - `Nfp/Circuit`: circuits, typed interfaces, and layer wiring (attention, induction). -- `Nfp/Sound`: sound builders and verified helpers. +- `Nfp/Sound`: soundness theorems and verified helpers. - `Nfp/IO`, `Nfp/Cli`: parsing and CLI entrypoints. ## Induction Certification (prototype) -The current prototype checks **head-level induction certificates** and can optionally compose them -with a **downstream error bound**. Certificates are produced by **untrusted** helper scripts and -verified by the CLI. +The current prototype checks **explicit induction-head certificates**. Certificates are produced +by **untrusted** Python scripts and verified by the Lean CLI; no model forward pass runs in Lean. 
### Build a head certificate (untrusted) -Note: the discovery/scan/sweep helper scripts use **one-based** layer/head -indices (literature-aligned), default to **bigram prefix matching** for -`prev`, and **rank by attention score** unless you explicitly switch to -logit-diff mode. The Lean CLI now expects **one-based** layer/head indices by -default; pass `--zero-based` to use legacy zero-based indices. - -For canonical prefix-matching benchmarks, `scripts/scan_gpt2_induction_sound.py` -supports `--synthetic` to generate repeated-random pattern prompts and score -attention/copying on that distribution. - ```bash python scripts/build_gpt2_induction_cert.py \ --output reports/gpt2_induction.cert \ --layer 5 --head 1 --seq 32 --pattern-length 16 \ - --values-out reports/gpt2_induction.values --value-dim 0 \ --active-eps-max 1/2 ``` -If you want values aligned to a logit-diff direction, add: +Optional direction metadata: ``` --direction-target --direction-negative ``` -### Verify a head certificate (trusted checker) - -```bash -lake exe nfp induction certify \ - --scores reports/gpt2_induction.cert \ - --values reports/gpt2_induction.values -``` - -Non-vacuity gates (optional): - -``` ---min-margin --max-eps --min-active --min-logit-diff -``` - -### Recompute bounds inside Lean (sound builder) +If you already have an `NFP_BINARY_V1` model file: ```bash -lake exe nfp induction certify_sound \ - --scores reports/gpt2_induction.cert \ - --values reports/gpt2_induction.values -``` - -This ignores any `eps`/`margin`/`lo`/`hi` lines and recomputes them from the raw entries. - -### Compute exact head inputs inside Lean (experimental) - -```bash -lake exe nfp induction certify_head --inputs reports/gpt2_induction.head -``` - -This path recomputes scores/values in Lean from exact head inputs. It is **experimental** and can -be slow for nontrivial sequence lengths. 
- -You can also derive the head inputs directly from an `NFP_BINARY_V1` model file: - -```bash -lake exe nfp induction certify_head_model \ - --model models/gpt2_rigorous_with_gelu_kind_seq32.nfpt \ +python scripts/build_gpt2_induction_cert_from_binary.py \ + --model models/gpt2_rigorous.nfpt \ --layer 5 --head 1 \ - --direction-target 1 --direction-negative 2 + --direction-target 1 --direction-negative 2 \ + --output reports/gpt2_induction.cert ``` -By default, `certify_head_model` derives the `prev` map and active set from the -token sequence stored in the model file. Use `--period ` to override with a -fixed periodic prompt. - -### GPT2-small (model-driven) - -To certify induction heads from GPT2-small weights, export a model binary and -let the CLI derive the logit-diff direction from the stored prompt tokens -(prefix matching: [A][B] ... [A] -> [B]): - -```bash -python scripts/export_gpt2.py models/gpt2_small.nfpt - -lake exe nfp induction certify_head_model_auto \ - --model models/gpt2_small.nfpt \ - --layer 5 --head 1 -``` - -Use `--period ` to override the prompt period derived from tokens. - -### End-to-end check with downstream bound (prototype) +### Verify a head certificate (trusted checker) ```bash -python scripts/build_downstream_linear_cert.py \ - --output reports/gpt2_downstream.cert \ - --gain 3/2 --input-bound 5/4 - -lake exe nfp induction certify_end_to_end \ - --scores reports/gpt2_induction.cert \ - --values reports/gpt2_induction.values \ - --downstream reports/gpt2_downstream.cert +lake exe nfp induction certify --cert reports/gpt2_induction.cert ``` -The downstream certificate is **checked for internal arithmetic consistency** but is externally -computed. 
You can also compute the downstream bound inside Lean from a matrix payload: +Optional gates: -```bash -lake exe nfp induction certify_end_to_end_matrix \ - --scores reports/gpt2_induction.cert \ - --values reports/gpt2_induction.values \ - --matrix reports/gpt2_downstream.matrix ``` - -Or derive the downstream matrix directly from an `NFP_BINARY_V1` model file -(currently uses the unembedding direction only). If `--residual-interval` is omitted, -the tool derives a conservative residual interval from the model: - -```bash -lake exe nfp induction certify_end_to_end_model \ - --scores reports/gpt2_induction.cert \ - --values reports/gpt2_induction.values \ - --model models/gpt2_rigorous.nfpt +--min-active --min-margin --max-eps --min-logit-diff ``` -To use an external residual-interval certificate instead, include -`--residual-interval reports/gpt2_residual.interval`. - ## File formats -### Softmax-margin certificate +### Induction-head certificate ``` seq +direction-target +direction-negative eps margin active prev score weight -``` - -`active ` lines declare the queries on which bounds are required; if omitted, the checker -defaults to all nonzero queries. - -### Value-range certificate - -``` -seq -direction-target -direction-negative +eps-at +weight-bound lo hi val +val-lo +val-hi ``` -`direction-*` lines are optional metadata for directional (logit-diff) values. - -### Downstream linear certificate - -``` -error -gain -input-bound -``` - -The checker enforces `error = gain * input-bound` and nonnegativity of all fields. - -### Downstream matrix payload - -``` -rows -cols -input-bound -w -``` - -The checker computes a row-sum norm bound from the matrix entries. - -### Residual-interval certificate - -``` -dim -lo -hi -``` - -Each `lo`/`hi` entry supplies an interval bound for residual vector coordinate `i`, -used to compute downstream error. 
- -### Head input format (for `certify_head`) - -``` -seq -d_model -d_head -scale -direction-target -direction-negative -direction -active -prev -embed -ln_eps -ln1_gamma -ln1_beta -wq -bq -wk -bk -wv -bv -wo -attn_bias -``` - -All `direction`, `embed`, and projection matrices must be fully specified. If no `active` lines +`direction-*` lines are optional metadata; if present, both must appear. If no `active` lines appear, the checker defaults to all nonzero queries. ## Soundness boundary - Untrusted scripts may use floating-point numerics to generate candidate certificates. -- The CLI **only verifies** certificate constraints inside Lean; it does not search for witnesses. -- Downstream error certificates are currently **not derived in Lean** (work in progress). +- The CLI **only verifies** explicit certificates; it does not search for witnesses or run models. For known gaps, see `SOUNDNESS_LIMITATIONS.md`. diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 2ea573f..b02e687 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -5,42 +5,23 @@ It is intentionally brief and focused on the soundness boundary. ## Current limitations -- The trusted CLI only **checks certificates**; it does not search for witnesses or run a model. -- Induction certificates are **head-level** (softmax-margin + value-range + logit-diff lower bound), - and they are conditional on the supplied `prev`, `active`, and `direction` inputs. They do **not** +- The trusted CLI only **checks explicit certificates**; it does not search for witnesses or + run model evaluation. +- Induction certificates are **head-level** (softmax-margin + value-interval + logit-diff lower + bound) and conditional on the supplied `prev`, `active`, and `direction` inputs. They do **not** yet imply end-to-end model behavior. -- Downstream error bounds can be computed from a **matrix payload** inside Lean. 
A model-based - path exists, but it currently uses only the unembedding direction and derives residual - intervals via conservative interval propagation (ignoring attention-score structure), - which can be loose. -- The `certify_head` path uses a **head-input file** extracted by an untrusted script; the extractor - now includes attention projection biases and LayerNorm metadata, but the Lean-side computation - still ignores the shared attention output bias. -- The `certify_head_model` path derives head inputs from the model binary in Lean, includes - attention projection biases and LayerNorm metadata, and derives `prev`/active from the stored - token sequence by default, but still ignores the shared attention output bias. It currently - requires `head_dim` to be a perfect square to represent the scale as an exact rational. -- The `certify_head_model_auto` path derives the logit-diff direction from the stored prompt - tokens using a heuristic; use explicit direction tokens for fixed claims. -- The certification does not yet prove end-to-end behavioral induction claims. For - `certify_head_model` with `period? = none`, `prev` is derived from tokens and is the maximal - prior match, but other inputs (head-input files or explicit periods) still rely on supplied - `prev` maps. The chosen direction still assumes the unembedding columns encode token logits. -- There is now a sound interval-composition lemma that combines head logit-diff bounds with - head/output intervals via subtraction, but it does not model how head outputs propagate - through subsequent LN/MLP blocks (so tight end-to-end claims remain open). -- The GPT-2 end-to-end bound currently relies on these coarse intervals, so it can be - conservative or vacuous unless the downstream intervals are tightened. -- Performance: exact head-input recomputation in Lean can be slow for nontrivial sequence lengths. 
-- There is no bridge theorem connecting certificate validity to a full circuit/model semantics - statement (for example, a formal statement about logits under a transformer block stack). +- Direction metadata (`direction-target`, `direction-negative`) is untrusted and assumes that the + unembedding columns represent token logits. +- The active set is user-supplied (or defaulted by the parser); bounds only hold for + `q ∈ active`. +- Residual and downstream bounds are provided as explicit certificates; there is no verified + end-to-end model derivation of these bounds inside Lean. +- Performance: checking large certificates can be expensive for long sequences. ## Remaining work -- Tighten model-derived residual intervals (e.g., use attention-weight certificates or - score-aware bounds) to avoid vacuity. -- Replace untrusted extraction with a verified parser for model weight slices. -- Prove or verify that `prev` and `direction` are derived from token-level semantics. -- Add a formal bridge from certificates to circuit semantics and (eventually) to end-to-end - transformer claims. -- Improve performance for the exact head-input path without weakening soundness. +- Prove or verify that `prev`, `active`, and `direction` are derived from token-level semantics. +- Add a verified extraction pipeline from model weights to explicit certificates. +- Tighten residual and downstream interval bounds to avoid vacuity. +- Extend the bridge from certificates to full circuit/model semantics and (eventually) to + end-to-end transformer claims. diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index d7ed707..43724bc 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -5,33 +5,19 @@ induction heads, and spell out the scope and limitations of that claim. 
## Formal proof chain (Lean) -- `buildInductionCertFromHeadCoreWith?` returns a certificate under explicit guards - (`lnEps > 0`, `sqrtLower lnEps > 0`, `dModel ≠ 0`, `active.Nonempty`), so the - computation is only claimed when these preconditions hold - (`Nfp/Sound/Induction/Core.lean`). -- `buildInductionHeadInputs_def` shows the model-derived head inputs are - definitional: `prev`/`active` are computed from tokens (or a fixed period), - and the `direction` vector is the unembedding-column difference for the - provided target/negative token ids (`Nfp/IO/NfptPure.lean`). -- `buildInductionHeadInputs_prev_spec_of_active` and - `prevOfTokens_spec_of_active` prove that when `period? = none`, - every active query has a maximal prior matching token in `prev` - (`Nfp/IO/NfptPure.lean`, `Nfp/Model/InductionPrompt.lean`). -- `buildInductionCertFromHeadWith?` wraps the core computation and returns - a proof-carrying certificate `⟨c, InductionHeadCertSound inputs c⟩` - (`Nfp/Sound/Induction/HeadOutput.lean`). -- `buildInductionCertFromHeadCoreWith?_sound` proves that any returned certificate - satisfies `InductionHeadCertSound`, i.e. the softmax-margin bounds, one-hot - bounds, and value-interval bounds that define the head-level certificate - (`Nfp/Sound/Induction/CoreSound.lean`). -- `buildInductionLogitLowerBoundFromHead?` and - `buildInductionLogitLowerBoundNonvacuous?` lift the head certificate to a - logit-diff lower bound; the key lemma `logitDiffLowerBoundFromCert_le` shows - the bound is sound on active queries (`Nfp/Sound/Induction/LogitDiff.lean`). -- `logitDiffLowerBound_end_to_end_gpt2` combines head logit-diff bounds, head - output intervals, and GPT-2 stack output intervals to give a direction lower - bound on `transformerStackFinalReal` - (`Nfp/Sound/Induction/EndToEnd.lean`, `Nfp/Sound/Bounds/Transformer.lean`). +- Explicit induction-head certificates are parsed from text in + `Nfp/IO/InductionHead/Cert.lean`. 
+- `checkInductionHeadCert` and `checkInductionHeadCert_sound` show that a + passing certificate satisfies `InductionHeadCertBounds` + (`Nfp/Circuit/Cert/InductionHead.lean`). +- `logitDiffLowerBoundAt` plus `logitDiffLowerBoundAt_le` give a certified lower + bound on the logit-diff contribution derived from the certificate’s values + (`Nfp/Circuit/Cert/LogitDiff.lean`). +- `headLogitDiff_eq_direction_dot_headOutput`, `logitDiffLowerBound_with_residual`, + and `logitDiffLowerBound_with_output_intervals` compose head-level logit-diff + bounds with output intervals (`Nfp/Sound/Induction/LogitDiff.lean`). +- `logitDiffLowerBound_end_to_end_gpt2` instantiates the composition for GPT-2 + stack outputs (`Nfp/Sound/Induction/EndToEnd.lean`). ## Mechanistic mapping (Transformer Circuits) @@ -50,63 +36,30 @@ This is direct mechanistic evidence in the Transformer Circuits sense: it ties parameters (Q/K/V/O + LayerNorm) to certified bounds on attention and value contributions, but only for the specific inputs and direction supplied. -Sources referenced for the mechanistic framing: -- `transformer-circuits-framework.md` (QK/OV decomposition). -- `induction-heads.md` (induction head behavior definition). -- `foundations.md` (reverse-engineering framing and feature decomposition). - ## Preconditions and scope limits These proofs are sufficient for a **conditional** certification claim: -if the inputs are correct and the builder returns a certificate, then the -head-level bounds hold. They are **not** sufficient for a global claim that a -head “is an induction head” without additional assumptions. +if the explicit certificate passes the checker, then the head-level bounds hold. +They are **not** sufficient for a global claim that a head “is an induction head” +without additional assumptions. Key assumptions and limitations: -- For `certify_head_model` with `period? = none`, `prev`/`active` are derived - from tokens and `prev` is the maximal prior match. 
For head-input files or - when `period?` is set explicitly, `prev` remains a user-supplied input. -- The `--prev-shift` flag switches to the **shifted** `prev` map (`q - period + 1` - or the token-shifted analogue). This aligns the head-level certificate with - the canonical induction circuit (previous-token head → induction head), but - it is still a head-level approximation rather than a verified two-head - composition. -- The `certify_circuit_model` CLI uses shifted `prev` for the induction head - by default, while the previous-token head uses the unshifted period-1 map. -- The certificate proves a logit-diff bound along the supplied `direction` - vector. For model-derived inputs, this vector is the target-minus-negative - unembedding column difference, but we still assume that the unembedding - matrix represents the model’s logit map. -- The active set is user-supplied and can be strict; bounds only hold for - `q ∈ active`, not all positions. -- There is now a formal bridge from head logit-diff bounds plus residual interval - bounds to a direction lower bound on `headOutput + residual`, but full - end-to-end model logits still require verified residual bounds through the - rest of the stack. - We now have a theorem packaging GPT-2 residual interval bounds derived from - model slices into a sound `ResidualIntervalCert`, but it is not yet connected - to the head-level logit-diff contribution inside the full stack. - A new lemma composes head logit-diff bounds with *both* head-output intervals - and downstream output intervals, yielding a sound lower bound on the direction - dot of the downstream output (via interval subtraction), and we now instantiate - this for GPT-2 stack outputs via `logitDiffLowerBound_end_to_end_gpt2`. +- `prev`, `active`, and `direction` are user-supplied or produced by untrusted + scripts; Lean does not (yet) verify their derivation from token-level semantics. 
+- The active set can be strict; bounds only hold for `q ∈ active`, not all positions. +- The direction metadata assumes the unembedding columns encode the model’s logit map. +- End-to-end claims rely on external residual/downstream interval certificates; the + current checker only verifies those certificates once provided. ## Conclusion Yes—**within the formal scope** of the current definitions, the proofs are enough to claim that we can certify induction-head behavior at the head level: they certify attention to a specified `prev` index and a logit-diff lower bound -along a specified direction. We now have a bridge that composes those bounds -with residual interval bounds to certify `headOutput + residual`, but we still -need a proof that the inputs correspond to the behavioral induction-head -definition on actual sequences and that residual bounds are derived from full -model semantics. +along a specified direction, conditional on an explicit certificate. ## Next steps -- Formalize the relationship between `directionSpec` and the logit-diff vector - derived from unembedding (so the certified direction matches token-level claims). -- Add a proof or verified derivation that the `prev` mapping corresponds to the - induction pattern for a given prompt sequence. -- Extend the bridge to full transformer stacks by deriving residual interval - bounds from verified layer/block semantics. +- Add a verified extraction pipeline from model weights to explicit certificates. +- Prove that `prev`, `active`, and `direction` correspond to token-level semantics. +- Tighten residual/downstream interval bounds to strengthen end-to-end claims. diff --git a/scripts/build_downstream_linear_cert.py b/scripts/build_downstream_linear_cert.py index 206da9c..d1640cb 100644 --- a/scripts/build_downstream_linear_cert.py +++ b/scripts/build_downstream_linear_cert.py @@ -4,8 +4,8 @@ """ Build a downstream linear certificate from externally computed bounds. 
-This script is untrusted: it only formats rational inputs into the certificate -format expected by `nfp induction certify_end_to_end`. +This script is untrusted: it only formats rational inputs into the downstream +linear certificate format expected by the Lean checker. Usage: python scripts/build_downstream_linear_cert.py \ diff --git a/scripts/build_gpt2_head_inputs.py b/scripts/build_gpt2_head_inputs.py deleted file mode 100644 index 5b1c32c..0000000 --- a/scripts/build_gpt2_head_inputs.py +++ /dev/null @@ -1,369 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Build an induction head input file from an NFP_BINARY_V1 model. - -This is an untrusted helper that extracts a single head slice plus the -prompt embeddings from an `.nfpt` file and writes the text format consumed by -`nfp induction certify_head`. -""" - -from __future__ import annotations - -import argparse -import math -import struct -from fractions import Fraction -from pathlib import Path -from typing import Dict, Tuple - -import numpy as np - - -def rat_from_float_exact(x: float) -> Fraction: - if not math.isfinite(x): - raise SystemExit(f"non-finite float encountered: {x}") - num, den = x.as_integer_ratio() - return Fraction(num, den) - - -def rat_to_str(q: Fraction) -> str: - if q.denominator == 1: - return str(q.numerator) - return f"{q.numerator}/{q.denominator}" - - -def parse_header(f) -> Dict[str, str]: - header: Dict[str, str] = {} - magic = f.readline().decode("ascii").strip() - if magic != "NFP_BINARY_V1": - raise SystemExit(f"Unsupported magic header: {magic}") - while True: - line = f.readline() - if line == b"": - raise SystemExit("Unexpected EOF while reading header.") - text = line.decode("ascii").strip() - if text == "BINARY_START": - break - if "=" in text: - key, value = text.split("=", 1) - header[key.strip()] = value.strip() - return header - - -def read_i32(f, count: int) -> np.ndarray: - raw = f.read(count * 4) - if len(raw) != count * 4: - 
raise SystemExit("Unexpected EOF while reading int32 payload.") - return np.frombuffer(raw, dtype=" np.ndarray: - raw = f.read(count * 8) - if len(raw) != count * 8: - raise SystemExit("Unexpected EOF while reading float64 payload.") - return np.frombuffer(raw, dtype=" None: - offset = count * 8 - f.seek(offset, 1) - - -def build_prev(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - prev = np.zeros_like(tokens) - active = np.zeros_like(tokens, dtype=bool) - last_seen: Dict[int, int] = {} - for idx, tok in enumerate(tokens.tolist()): - if idx == 0: - prev[idx] = 0 - active[idx] = False - else: - if tok in last_seen: - prev[idx] = last_seen[tok] - active[idx] = True - else: - prev[idx] = 0 - active[idx] = False - last_seen[tok] = idx - return prev, active - - -def read_head_weights( - f, - num_layers: int, - num_heads: int, - model_dim: int, - head_dim: int, - hidden_dim: int, - layer: int, - head: int, -) -> Tuple[ - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, -]: - target = (layer, head) - wq = wk = wv = wo = None - bq = bk = bv = None - attn_bias = ln1_gamma = ln1_beta = None - for layer_idx in range(num_layers): - for head_idx in range(num_heads): - wq_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bq_block = read_f64(f, head_dim) # b_Q - wk_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bk_block = read_f64(f, head_dim) # b_K - wv_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bv_block = read_f64(f, head_dim) # b_V - wo_block = read_f64(f, head_dim * model_dim).reshape(head_dim, model_dim) - if (layer_idx, head_idx) == target: - wq = wq_block - wk = wk_block - wv = wv_block - wo = wo_block - bq = bq_block - bk = bk_block - bv = bv_block - # Skip per-layer non-head data. 
- attn_bias_block = read_f64(f, model_dim) # attn_bias - skip_f64(f, model_dim * hidden_dim) # w_in - skip_f64(f, hidden_dim) # b_in - skip_f64(f, hidden_dim * model_dim) # w_out - skip_f64(f, model_dim) # b_out - ln1_gamma_block = read_f64(f, model_dim) # ln1_gamma - ln1_beta_block = read_f64(f, model_dim) # ln1_beta - skip_f64(f, model_dim) # ln2_gamma - skip_f64(f, model_dim) # ln2_beta - if layer_idx == layer: - attn_bias = attn_bias_block - ln1_gamma = ln1_gamma_block - ln1_beta = ln1_beta_block - if ( - wq is None - or wk is None - or wv is None - or wo is None - or bq is None - or bk is None - or bv is None - or attn_bias is None - or ln1_gamma is None - or ln1_beta is None - ): - raise SystemExit("Failed to locate head weights.") - return wq, bq, wk, bk, wv, bv, wo, attn_bias, ln1_gamma, ln1_beta - - -def read_unembed_columns( - f, - start: int, - model_dim: int, - vocab_size: int, - target: int, - negative: int, -) -> Tuple[np.ndarray, np.ndarray]: - row_bytes = vocab_size * 8 - col_t = np.zeros(model_dim, dtype=np.float64) - col_n = np.zeros(model_dim, dtype=np.float64) - for row in range(model_dim): - base = start + row * row_bytes - f.seek(base + target * 8) - col_t[row] = struct.unpack(" None: - seq, model_dim = embeddings.shape - _, head_dim = wq.shape - with path.open("w", encoding="ascii") as f: - f.write(f"seq {seq}\n") - f.write(f"d_model {model_dim}\n") - f.write(f"d_head {head_dim}\n") - f.write(f"scale {rat_to_str(scale)}\n") - for q, flag in enumerate(active.tolist()): - if flag: - f.write(f"active {q}\n") - for q, k in enumerate(prev.tolist()): - f.write(f"prev {q} {k}\n") - for q in range(seq): - for d in range(model_dim): - f.write(f"embed {q} {d} {rat_to_str(rat_from_float_exact(float(embeddings[q, d])))}\n") - f.write(f"ln_eps {rat_to_str(ln_eps)}\n") - for d in range(model_dim): - f.write(f"ln1_gamma {d} {rat_to_str(rat_from_float_exact(float(ln1_gamma[d])))}\n") - for d in range(model_dim): - f.write(f"ln1_beta {d} 
{rat_to_str(rat_from_float_exact(float(ln1_beta[d])))}\n") - for i in range(model_dim): - for j in range(head_dim): - f.write(f"wq {i} {j} {rat_to_str(rat_from_float_exact(float(wq[i, j])))}\n") - for j in range(head_dim): - f.write(f"bq {j} {rat_to_str(rat_from_float_exact(float(bq[j])))}\n") - for i in range(model_dim): - for j in range(head_dim): - f.write(f"wk {i} {j} {rat_to_str(rat_from_float_exact(float(wk[i, j])))}\n") - for j in range(head_dim): - f.write(f"bk {j} {rat_to_str(rat_from_float_exact(float(bk[j])))}\n") - for i in range(model_dim): - for j in range(head_dim): - f.write(f"wv {i} {j} {rat_to_str(rat_from_float_exact(float(wv[i, j])))}\n") - for j in range(head_dim): - f.write(f"bv {j} {rat_to_str(rat_from_float_exact(float(bv[j])))}\n") - for i in range(model_dim): - for j in range(head_dim): - f.write(f"wo {i} {j} {rat_to_str(rat_from_float_exact(float(wo[i, j])))}\n") - for d in range(model_dim): - f.write(f"attn_bias {d} {rat_to_str(rat_from_float_exact(float(attn_bias[d])))}\n") - f.write(f"mask {'causal' if mask_causal else 'none'}\n") - f.write(f"mask_value {rat_to_str(mask_value)}\n") - f.write(f"direction-target {direction_target}\n") - f.write(f"direction-negative {direction_negative}\n") - for d in range(model_dim): - f.write(f"direction {d} {rat_to_str(rat_from_float_exact(float(direction[d])))}\n") - - -def main() -> None: - ap = argparse.ArgumentParser(description=__doc__) - ap.add_argument("--model", type=Path, required=True, help="Path to NFP_BINARY_V1 model") - ap.add_argument("--layer", type=int, required=True, help="Layer index") - ap.add_argument("--head", type=int, required=True, help="Head index") - ap.add_argument("--output", type=Path, required=True, help="Path for the head input file") - ap.add_argument("--direction-target", type=int, required=True, help="Target token id") - ap.add_argument("--direction-negative", type=int, required=True, help="Negative token id") - args = ap.parse_args() - - if not args.model.exists(): - 
raise SystemExit(f"Missing model file: {args.model}") - - with args.model.open("rb") as f: - header = parse_header(f) - num_layers = int(header["num_layers"]) - num_heads = int(header["num_heads"]) - model_dim = int(header["model_dim"]) - head_dim = int(header["head_dim"]) - vocab_size = int(header["vocab_size"]) - seq_len = int(header["seq_len"]) - hidden_dim = int(header["hidden_dim"]) - - if args.layer < 0 or args.layer >= num_layers: - raise SystemExit("layer index out of range") - if args.head < 0 or args.head >= num_heads: - raise SystemExit("head index out of range") - if not (0 <= args.direction_target < vocab_size): - raise SystemExit("direction-target out of vocab range") - if not (0 <= args.direction_negative < vocab_size): - raise SystemExit("direction-negative out of vocab range") - - tokens = read_i32(f, seq_len) - embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) - - wq, bq, wk, bk, wv, bv, wo_raw, attn_bias, ln1_gamma, ln1_beta = read_head_weights( - f, - num_layers, - num_heads, - model_dim, - head_dim, - hidden_dim, - args.layer, - args.head, - ) - - # Skip final layer norm parameters. - skip_f64(f, model_dim) # ln_f_gamma - skip_f64(f, model_dim) # ln_f_beta - - unembed_start = f.tell() - col_target, col_negative = read_unembed_columns( - f, - unembed_start, - model_dim, - vocab_size, - args.direction_target, - args.direction_negative, - ) - - prev, active = build_prev(tokens) - direction = col_target - col_negative - scale_denom = int(math.isqrt(head_dim)) - if scale_denom * scale_denom != head_dim: - scale = rat_from_float_exact(1.0 / math.sqrt(head_dim)) - else: - scale = Fraction(1, scale_denom) - - # Stored W_O is (head_dim, model_dim); transpose to model_dim × head_dim. 
- wo = wo_raw.T - - mask_causal = True - mask_value = Fraction(-10000, 1) - - args.output.parent.mkdir(parents=True, exist_ok=True) - ln_eps_raw = header.get("layer_norm_eps") - if ln_eps_raw is None: - raise SystemExit("Missing layer_norm_eps in header.") - try: - ln_eps = Fraction(ln_eps_raw) - except ValueError: - ln_eps = rat_from_float_exact(float(ln_eps_raw)) - write_head_inputs( - args.output, - scale, - tokens, - embeddings, - prev, - active, - wq, - bq, - wk, - bk, - wv, - bv, - wo, - attn_bias, - ln_eps, - ln1_gamma, - ln1_beta, - mask_causal, - mask_value, - args.direction_target, - args.direction_negative, - direction, - ) - - print(f"Wrote head inputs to {args.output}") - print(f"seq={seq_len} d_model={model_dim} d_head={head_dim}") - - -if __name__ == "__main__": - main() diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index 80be5f1..fa206df 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -2,11 +2,11 @@ # SPDX-License-Identifier: AGPL-3.0-or-later """ -Build a softmax-margin certificate for a GPT-2-small induction head. +Build an induction-head certificate for a GPT-2-small induction head. This script is untrusted and uses floating-point arithmetic to produce a -rational certificate compatible with `nfp induction certify`. Active -induction positions are recorded as `active ` lines in the output. +rational induction-head certificate compatible with `nfp induction certify`. +Active induction positions are recorded as `active ` lines in the output. 
Usage: python scripts/build_gpt2_induction_cert.py --output reports/gpt2_induction.cert \ @@ -121,6 +121,42 @@ def write_scores(path: Path, seq: int, prev: np.ndarray, scores, weights, eps=No for k in range(seq): f.write(f"weight {q} {k} {rat_to_str(weights[q][k])}\n") +def write_induction_cert(path: Path, seq: int, prev: np.ndarray, scores, weights, + eps, margin, active, eps_at, weight_bound_at, + vals, direction_target=None, direction_negative=None) -> None: + lo = min(vals) + hi = max(vals) + with path.open("w", encoding="ascii") as f: + f.write(f"seq {seq}\n") + if direction_target is not None and direction_negative is not None: + f.write(f"direction-target {direction_target}\n") + f.write(f"direction-negative {direction_negative}\n") + f.write(f"eps {rat_to_str(eps)}\n") + f.write(f"margin {rat_to_str(margin)}\n") + if active is not None: + for q in active: + f.write(f"active {q}\n") + for q, k in enumerate(prev.tolist()): + f.write(f"prev {q} {k}\n") + for q in range(seq): + for k in range(seq): + f.write(f"score {q} {k} {rat_to_str(scores[q][k])}\n") + for q in range(seq): + for k in range(seq): + f.write(f"weight {q} {k} {rat_to_str(weights[q][k])}\n") + for q in range(seq): + f.write(f"eps-at {q} {rat_to_str(eps_at[q])}\n") + for q in range(seq): + for k in range(seq): + f.write(f"weight-bound {q} {k} {rat_to_str(weight_bound_at[q][k])}\n") + f.write(f"lo {rat_to_str(lo)}\n") + f.write(f"hi {rat_to_str(hi)}\n") + for k, val in enumerate(vals): + val_str = rat_to_str(val) + f.write(f"val {k} {val_str}\n") + f.write(f"val-lo {k} {val_str}\n") + f.write(f"val-hi {k} {val_str}\n") + def write_value_range(path: Path, seq: int, values, decimals: int, direction_target=None, direction_negative=None) -> None: @@ -191,17 +227,19 @@ def main() -> None: margin = None eps_by_q: dict[int, Fraction] = {} margin_by_q: dict[int, Fraction] = {} - for q in candidate_positions: + for q in range(args.seq): prev_q = prev[q] prev_w = weights_rat[q][prev_q] - max_other = 
max(weights_rat[q][k] for k in range(args.seq) if k != prev_q) + if args.seq == 1: + max_other = Fraction(0) + else: + max_other = max(weights_rat[q][k] for k in range(args.seq) if k != prev_q) deficit = Fraction(1) - prev_w eps_by_q[q] = max(max_other, deficit) diffs = [scores_rat[q][prev_q] - scores_rat[q][k] for k in range(args.seq) if k != prev_q] - if diffs: - margin_by_q[q] = min(diffs) + margin_by_q[q] = min(diffs) if diffs else Fraction(0) active_positions = candidate_positions eps_threshold = Fraction(args.active_eps_max) @@ -209,6 +247,10 @@ def main() -> None: if not active_positions and candidate_positions: print("Warning: no active positions satisfy active-eps-max; certificate may be vacuous.") + if not active_positions and args.seq > 1: + if candidate_positions: + print("Warning: no active positions satisfy active-eps-max; using all nonzero queries.") + active_positions = list(range(1, args.seq)) if active_positions: eps = max(eps_by_q[q] for q in active_positions) margin = min((margin_by_q[q] for q in active_positions), default=Fraction(0)) @@ -218,10 +260,59 @@ def main() -> None: if candidate_positions: print(f"Active positions: {len(active_positions)}/{len(candidate_positions)}") + eps_at = [] + for q in range(args.seq): + prev_q = prev[q] + if args.seq == 1: + max_other = Fraction(0) + else: + max_other = max(weights_rat[q][k] for k in range(args.seq) if k != prev_q) + deficit = Fraction(1) - weights_rat[q][prev_q] + eps_at.append(max(max_other, deficit)) + weight_bound_at = weights_rat + + direction_target = None + direction_negative = None + if (args.direction_target is None) != (args.direction_negative is None): + raise SystemExit("direction-target and direction-negative must be provided together") + if args.direction_target is not None: + wte = model.wte.weight.detach().cpu().numpy() + if args.direction_target < 0 or args.direction_target >= wte.shape[0]: + raise SystemExit("direction-target out of vocab range") + if args.direction_negative < 
0 or args.direction_negative >= wte.shape[0]: + raise SystemExit("direction-negative out of vocab range") + direction_target = args.direction_target + direction_negative = args.direction_negative + direction = wte[direction_target] - wte[direction_negative] + head_dim = model.config.n_embd // model.config.n_head + start, end = args.head * head_dim, (args.head + 1) * head_dim + w_o = model.h[args.layer].attn.c_proj.weight.detach().cpu().numpy()[:, start:end] + dir_head = w_o.T @ direction + vals = values @ dir_head + else: + if args.value_dim < 0 or args.value_dim >= values.shape[1]: + raise SystemExit(f"value-dim must be in [0, {values.shape[1] - 1}]") + vals = values[:, args.value_dim] + + vals_rat = [rat_from_float(float(vals[k]), args.decimals) for k in range(args.seq)] + output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) - write_scores(output_path, args.seq, prev, scores_rat, weights_rat, - eps=eps, margin=margin, active=active_positions) + write_induction_cert( + output_path, + args.seq, + prev, + scores_rat, + weights_rat, + eps, + margin, + active_positions, + eps_at, + weight_bound_at, + vals_rat, + direction_target=direction_target, + direction_negative=direction_negative, + ) if args.scores_out: scores_path = Path(args.scores_out) @@ -232,27 +323,9 @@ def main() -> None: if args.values_out: values_path = Path(args.values_out) values_path.parent.mkdir(parents=True, exist_ok=True) - if (args.direction_target is None) != (args.direction_negative is None): - raise SystemExit("direction-target and direction-negative must be provided together") - if args.direction_target is not None: - wte = model.wte.weight.detach().cpu().numpy() - if args.direction_target < 0 or args.direction_target >= wte.shape[0]: - raise SystemExit("direction-target out of vocab range") - if args.direction_negative < 0 or args.direction_negative >= wte.shape[0]: - raise SystemExit("direction-negative out of vocab range") - direction = 
wte[args.direction_target] - wte[args.direction_negative] - head_dim = model.config.n_embd // model.config.n_head - start, end = args.head * head_dim, (args.head + 1) * head_dim - w_o = model.h[args.layer].attn.c_proj.weight.detach().cpu().numpy()[:, start:end] - dir_head = w_o.T @ direction - dir_vals = values @ dir_head - write_value_range(values_path, args.seq, dir_vals, args.decimals, - direction_target=args.direction_target, - direction_negative=args.direction_negative) - else: - if args.value_dim < 0 or args.value_dim >= values.shape[1]: - raise SystemExit(f"value-dim must be in [0, {values.shape[1] - 1}]") - write_value_range(values_path, args.seq, values[:, args.value_dim], args.decimals) + write_value_range(values_path, args.seq, vals, args.decimals, + direction_target=direction_target, + direction_negative=direction_negative) print(f"Wrote certificate to {output_path}") if args.scores_out: diff --git a/scripts/build_gpt2_induction_cert_from_binary.py b/scripts/build_gpt2_induction_cert_from_binary.py index 8229320..c7862d3 100644 --- a/scripts/build_gpt2_induction_cert_from_binary.py +++ b/scripts/build_gpt2_induction_cert_from_binary.py @@ -2,10 +2,10 @@ # SPDX-License-Identifier: AGPL-3.0-or-later """ -Build a softmax-margin certificate and value-range certificate from an NFP_BINARY_V1 model. +Build an induction-head certificate from an NFP_BINARY_V1 model. -This is untrusted and uses floating-point arithmetic to produce rational certificates -compatible with `nfp induction certify`. +This is untrusted and uses floating-point arithmetic to produce a rational +induction-head certificate compatible with `nfp induction certify`. 
""" from __future__ import annotations @@ -191,11 +191,19 @@ def softmax(scores: np.ndarray) -> np.ndarray: return exp / exp.sum(axis=1, keepdims=True) -def write_softmax_cert(path: Path, seq: int, prev: np.ndarray, +def write_induction_cert(path: Path, seq: int, prev: np.ndarray, scores_rat, weights_rat, eps: Fraction, - margin: Fraction, active_positions) -> None: + margin: Fraction, active_positions, + eps_at, weight_bound_at, vals_rat, + direction_target: int | None, + direction_negative: int | None) -> None: + lo = min(vals_rat) + hi = max(vals_rat) with path.open("w", encoding="ascii") as f: f.write(f"seq {seq}\n") + if direction_target is not None and direction_negative is not None: + f.write(f"direction-target {direction_target}\n") + f.write(f"direction-negative {direction_negative}\n") f.write(f"eps {rat_to_str(eps)}\n") f.write(f"margin {rat_to_str(margin)}\n") for q in active_positions: @@ -208,6 +216,18 @@ def write_softmax_cert(path: Path, seq: int, prev: np.ndarray, for q in range(seq): for k in range(seq): f.write(f"weight {q} {k} {rat_to_str(weights_rat[q][k])}\n") + for q in range(seq): + f.write(f"eps-at {q} {rat_to_str(eps_at[q])}\n") + for q in range(seq): + for k in range(seq): + f.write(f"weight-bound {q} {k} {rat_to_str(weight_bound_at[q][k])}\n") + f.write(f"lo {rat_to_str(lo)}\n") + f.write(f"hi {rat_to_str(hi)}\n") + for k, val in enumerate(vals_rat): + val_str = rat_to_str(val) + f.write(f"val {k} {val_str}\n") + f.write(f"val-lo {k} {val_str}\n") + f.write(f"val-hi {k} {val_str}\n") def write_value_range(path: Path, seq: int, values, decimals: int, @@ -231,9 +251,10 @@ def main() -> None: ap.add_argument("--model", type=Path, required=True, help="Path to NFP_BINARY_V1 model") ap.add_argument("--layer", type=int, required=True, help="Layer index") ap.add_argument("--head", type=int, required=True, help="Head index") - ap.add_argument("--output", type=Path, required=True, help="Path for softmax certificate") - 
ap.add_argument("--values-out", type=Path, required=True, - help="Path for value-range certificate") + ap.add_argument("--output", type=Path, required=True, + help="Path for induction-head certificate") + ap.add_argument("--values-out", type=Path, + help="Optional path for a value-range certificate") ap.add_argument("--direction-target", type=int, required=True, help="Target token id for logit-diff direction") ap.add_argument("--direction-negative", type=int, required=True, @@ -329,22 +350,28 @@ def main() -> None: eps_by_q: dict[int, Fraction] = {} margin_by_q: dict[int, Fraction] = {} - for q in candidate_positions: + for q in range(seq_len): prev_q = prev[q] prev_w = weights_rat[q][prev_q] - max_other = max(weights_rat[q][k] for k in range(seq_len) if k != prev_q) + if seq_len == 1: + max_other = Fraction(0) + else: + max_other = max(weights_rat[q][k] for k in range(seq_len) if k != prev_q) deficit = Fraction(1) - prev_w eps_by_q[q] = max(max_other, deficit) diffs = [scores_rat[q][prev_q] - scores_rat[q][k] for k in range(seq_len) if k != prev_q] - if diffs: - margin_by_q[q] = min(diffs) + margin_by_q[q] = min(diffs) if diffs else Fraction(0) active_positions = [q for q in candidate_positions if eps_by_q[q] <= active_eps_max] if not active_positions and candidate_positions: print("Warning: no active positions satisfy active-eps-max; certificate may be vacuous.") + if not active_positions and seq_len > 1: + if candidate_positions: + print("Warning: no active positions satisfy active-eps-max; using all nonzero queries.") + active_positions = list(range(1, seq_len)) if active_positions: eps = max(eps_by_q[q] for q in active_positions) margin = min((margin_by_q[q] for q in active_positions), default=Fraction(0)) @@ -352,23 +379,50 @@ def main() -> None: eps = Fraction(0) margin = Fraction(0) - output_path = args.output - output_path.parent.mkdir(parents=True, exist_ok=True) - write_softmax_cert(output_path, seq_len, prev, scores_rat, weights_rat, eps, margin, - 
active_positions) + eps_at = [] + for q in range(seq_len): + prev_q = prev[q] + if seq_len == 1: + max_other = Fraction(0) + else: + max_other = max(weights_rat[q][k] for k in range(seq_len) if k != prev_q) + deficit = Fraction(1) - weights_rat[q][prev_q] + eps_at.append(max(max_other, deficit)) + weight_bound_at = weights_rat wo = wo_raw.T direction = col_target - col_negative dir_head = wo.T @ direction dir_vals = v @ dir_head - values_path = args.values_out - values_path.parent.mkdir(parents=True, exist_ok=True) - write_value_range(values_path, seq_len, dir_vals, args.decimals, - direction_target=args.direction_target, - direction_negative=args.direction_negative) - - print(f"Wrote softmax certificate to {output_path}") - print(f"Wrote value-range certificate to {values_path}") + vals_rat = [rat_from_float(float(dir_vals[k]), args.decimals) for k in range(seq_len)] + + output_path = args.output + output_path.parent.mkdir(parents=True, exist_ok=True) + write_induction_cert( + output_path, + seq_len, + prev, + scores_rat, + weights_rat, + eps, + margin, + active_positions, + eps_at, + weight_bound_at, + vals_rat, + args.direction_target, + args.direction_negative, + ) + + if args.values_out: + values_path = args.values_out + values_path.parent.mkdir(parents=True, exist_ok=True) + write_value_range(values_path, seq_len, dir_vals, args.decimals, + direction_target=args.direction_target, + direction_negative=args.direction_negative) + print(f"Wrote value-range certificate to {values_path}") + + print(f"Wrote induction-head certificate to {output_path}") if candidate_positions: print(f"Active positions: {len(active_positions)}/{len(candidate_positions)}") diff --git a/scripts/build_residual_bound_cert.py b/scripts/build_residual_bound_cert.py index a34b5c6..7f54c51 100644 --- a/scripts/build_residual_bound_cert.py +++ b/scripts/build_residual_bound_cert.py @@ -7,7 +7,7 @@ This script is untrusted. 
It computes per-coordinate absolute bounds by taking maxima over a fixed input sequence (optionally restricted to active positions from a softmax-margin certificate). The resulting bounds are -rounded up to rationals for checking by `nfp induction certify_end_to_end_model`. +rounded up to rationals for Lean-side checking. Usage: uv run scripts/build_residual_bound_cert.py \ diff --git a/scripts/build_residual_interval_cert.py b/scripts/build_residual_interval_cert.py index accfd61..51896d2 100644 --- a/scripts/build_residual_interval_cert.py +++ b/scripts/build_residual_interval_cert.py @@ -7,8 +7,7 @@ This script is untrusted. It computes per-coordinate min/max bounds by taking extrema over a fixed input sequence (optionally restricted to active positions from a softmax-margin certificate). The resulting intervals are -expanded slightly and rounded outwards to rationals for checking by -`nfp induction certify_end_to_end_model`. +expanded slightly and rounded outwards to rationals for Lean-side checking. Usage: uv run scripts/build_residual_interval_cert.py \ diff --git a/scripts/certify_induction_head.py b/scripts/certify_induction_head.py deleted file mode 100644 index b183df9..0000000 --- a/scripts/certify_induction_head.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Certify a single induction head from a model binary. - -This script is a small wrapper around -`nfp induction certify_head_model_auto(_nonvacuous)` and optionally -creates a diagnostic prompt model with repeated patterns. 
-""" - -from __future__ import annotations - -import argparse -import os -import shutil -import subprocess -import sys -from pathlib import Path - - -def resolve_nfp_cmd(nfp_bin: str | None) -> list[str]: - if nfp_bin: - return [nfp_bin] - env_bin = os.environ.get("NFP_BIN") - if env_bin: - return [env_bin] - local_bin = Path(".lake/build/bin/nfp") - if local_bin.exists(): - return [str(local_bin)] - return ["lake", "exe", "nfp"] - - -def ensure_model( - model_path: Path, - *, - seq_len: int, - pattern_len: int, - seed: int, - vocab_min: int, - vocab_max: int, - min_word_length: int, - allow_no_leading_space: bool, - model_name: str, -) -> None: - if model_path.exists(): - return - model_path.parent.mkdir(parents=True, exist_ok=True) - generator = [ - sys.executable, - "scripts/generate_rigorous_induction.py", - "--output", - str(model_path), - "--seq-len", - str(seq_len), - "--pattern-len", - str(pattern_len), - "--seed", - str(seed), - "--vocab-min", - str(vocab_min), - "--vocab-max", - str(vocab_max), - "--min-word-length", - str(min_word_length), - "--model", - model_name, - ] - if allow_no_leading_space: - generator.append("--allow-no-leading-space") - if shutil.which("uv"): - generator = ["uv", "run"] + generator - subprocess.run(generator, check=True) - - -def main() -> int: - parser = argparse.ArgumentParser( - description="Certify an induction head from a model binary." 
- ) - parser.add_argument("--model", default="models/gpt2_rigorous.nfpt") - parser.add_argument("--layer", type=int, required=True) - parser.add_argument("--head", type=int, required=True) - parser.add_argument("--period", type=int) - parser.add_argument("--prev-shift", action="store_true") - parser.add_argument("--nonvacuous", action="store_true") - parser.add_argument("--zero-based", action="store_true") - parser.add_argument("--min-active", type=int) - parser.add_argument("--min-logit-diff", type=str) - parser.add_argument("--min-margin", type=str) - parser.add_argument("--max-eps", type=str) - parser.add_argument("--nfp-bin", help="Path to nfp binary") - parser.add_argument( - "--preset", - choices=["fast", "balanced", "tight"], - help="Split-budget preset for streamlined certify", - ) - parser.add_argument("--timing", type=int) - parser.add_argument("--heartbeat-ms", type=int) - parser.add_argument("--split-budget-q", type=int) - parser.add_argument("--split-budget-k", type=int) - parser.add_argument("--split-budget-diff-base", type=int) - parser.add_argument("--split-budget-diff-refined", type=int) - parser.add_argument("--skip-logit-diff", action="store_true") - - parser.add_argument( - "--ensure-model", - action="store_true", - help="Generate a diagnostic model if the path does not exist", - ) - parser.add_argument("--seq-len", type=int, default=256) - parser.add_argument("--pattern-len", type=int, default=20) - parser.add_argument("--seed", type=int, default=1337) - parser.add_argument("--vocab-min", type=int, default=1000) - parser.add_argument("--vocab-max", type=int, default=5000) - parser.add_argument("--min-word-length", type=int, default=4) - parser.add_argument("--model-name", default="gpt2") - parser.add_argument("--allow-no-leading-space", action="store_true") - - args = parser.parse_args() - - model_path = Path(args.model) - if args.ensure_model: - ensure_model( - model_path, - seq_len=args.seq_len, - pattern_len=args.pattern_len, - 
seed=args.seed, - vocab_min=args.vocab_min, - vocab_max=args.vocab_max, - min_word_length=args.min_word_length, - allow_no_leading_space=args.allow_no_leading_space, - model_name=args.model_name, - ) - if not model_path.exists(): - print(f"error: model not found at {model_path}", file=sys.stderr) - return 1 - - use_advanced = any( - val is not None - for val in ( - args.split_budget_q, - args.split_budget_k, - args.split_budget_diff_base, - args.split_budget_diff_refined, - ) - ) - if use_advanced: - subcmd = "certify_head_model_auto" - if args.nonvacuous: - subcmd = "certify_head_model_auto_nonvacuous" - else: - subcmd = "certify" - if args.nonvacuous: - subcmd = "certify_nonvacuous" - - cmd = resolve_nfp_cmd(args.nfp_bin) + ["induction"] - if use_advanced: - cmd.append("advanced") - cmd += [ - subcmd, - "--model", - str(model_path), - "--layer", - str(args.layer), - "--head", - str(args.head), - ] - if args.zero_based: - cmd.append("--zero-based") - if args.period is not None: - cmd += ["--period", str(args.period)] - if args.prev_shift: - cmd.append("--prev-shift") - if args.min_active is not None: - cmd += ["--min-active", str(args.min_active)] - if args.min_logit_diff is not None: - cmd += ["--min-logit-diff", args.min_logit_diff] - if args.min_margin is not None: - cmd += ["--min-margin", args.min_margin] - if args.max_eps is not None: - cmd += ["--max-eps", args.max_eps] - if args.preset is not None and not use_advanced: - cmd += ["--preset", args.preset] - if args.timing is not None: - cmd += ["--timing", str(args.timing)] - if args.heartbeat_ms is not None: - cmd += ["--heartbeat-ms", str(args.heartbeat_ms)] - if args.split_budget_q is not None: - cmd += ["--split-budget-q", str(args.split_budget_q)] - if args.split_budget_k is not None: - cmd += ["--split-budget-k", str(args.split_budget_k)] - if args.split_budget_diff_base is not None: - cmd += ["--split-budget-diff-base", str(args.split_budget_diff_base)] - if args.split_budget_diff_refined is not None: 
- cmd += ["--split-budget-diff-refined", str(args.split_budget_diff_refined)] - if args.skip_logit_diff: - cmd.append("--skip-logit-diff") - - proc = subprocess.run(cmd) - return proc.returncode - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/ci_sanity_forward_step.py b/scripts/ci_sanity_forward_step.py deleted file mode 100644 index a2668dd..0000000 --- a/scripts/ci_sanity_forward_step.py +++ /dev/null @@ -1,206 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -CI sanity check: export a tiny GPT-2-like model to `.nfpt` and compare one forward step -against PyTorch. - -This is meant to be cheap enough for CI while still catching semantic drift (e.g. missing -attention biases, wrong GeLU constant, etc.). - -Usage (CI): - python3 scripts/ci_sanity_forward_step.py --model sshleifer/tiny-gpt2 --seqLen 8 -""" - -from __future__ import annotations - -import argparse -import subprocess -import tempfile -from pathlib import Path - -import numpy as np - -try: - import torch - from transformers import GPT2Model -except ImportError as e: - raise SystemExit( - "Missing deps. 
Install with e.g.: pip install torch transformers\n" - f"ImportError: {e}" - ) - - -def write_header(f, **fields: object) -> None: - f.write(b"NFP_BINARY_V1\n") - for key, value in fields.items(): - f.write(f"{key}={value}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - -def write_i32(f, data: np.ndarray) -> None: - arr = np.ascontiguousarray(data, dtype=" None: - arr = np.ascontiguousarray(data, dtype=" np.ndarray: - if param is None: - return np.zeros(size, dtype=np.float64) - return param.detach().cpu().numpy() - - -def export_nfpt(model: GPT2Model, seq_len: int, out_path: Path) -> np.ndarray: - cfg = model.config - num_layers = int(cfg.n_layer) - num_heads = int(cfg.n_head) - model_dim = int(cfg.n_embd) - head_dim = model_dim // num_heads - hidden_dim = int(cfg.n_inner or 4 * model_dim) - vocab_size = int(cfg.vocab_size) - layer_norm_eps = float(cfg.layer_norm_epsilon) - - # Deterministic token sequence. - tokens = (np.arange(seq_len, dtype=np.int64) % min(32, vocab_size)).astype(np.int64) - - # GPT-2 embeddings: wte[token] + wpe[pos]. 
- wte = model.wte.weight.detach().cpu().numpy() - wpe = model.wpe.weight.detach().cpu().numpy() - embeddings = wte[tokens] + wpe[:seq_len] - - with out_path.open("wb") as f: - write_header( - f, - num_layers=num_layers, - num_heads=num_heads, - model_dim=model_dim, - head_dim=head_dim, - hidden_dim=hidden_dim, - vocab_size=vocab_size, - seq_len=seq_len, - layer_norm_eps=layer_norm_eps, - gelu_kind="tanh", - ) - - write_i32(f, tokens) - write_f64(f, embeddings) - - for layer_idx in range(num_layers): - block = model.h[layer_idx] - - c_attn_w = block.attn.c_attn.weight.detach().cpu().numpy() # (d, 3d) - c_attn_b = get_bias(block.attn.c_attn.bias, 3 * model_dim) # (3d,) - c_proj_w = block.attn.c_proj.weight.detach().cpu().numpy() # (d, d) - c_proj_b = get_bias(block.attn.c_proj.bias, model_dim) # (d,) - - W_Q_all = c_attn_w[:, 0:model_dim] - W_K_all = c_attn_w[:, model_dim : 2 * model_dim] - W_V_all = c_attn_w[:, 2 * model_dim : 3 * model_dim] - b_Q_all = c_attn_b[0:model_dim] - b_K_all = c_attn_b[model_dim : 2 * model_dim] - b_V_all = c_attn_b[2 * model_dim : 3 * model_dim] - - for h in range(num_heads): - start, end = h * head_dim, (h + 1) * head_dim - write_f64(f, W_Q_all[:, start:end]) - write_f64(f, b_Q_all[start:end]) - write_f64(f, W_K_all[:, start:end]) - write_f64(f, b_K_all[start:end]) - write_f64(f, W_V_all[:, start:end]) - write_f64(f, b_V_all[start:end]) - write_f64(f, c_proj_w[start:end, :]) - - write_f64(f, c_proj_b) - - write_f64(f, block.mlp.c_fc.weight.detach().cpu().numpy()) - write_f64(f, get_bias(block.mlp.c_fc.bias, hidden_dim)) - write_f64(f, block.mlp.c_proj.weight.detach().cpu().numpy()) - write_f64(f, get_bias(block.mlp.c_proj.bias, model_dim)) - - write_f64(f, block.ln_1.weight.detach().cpu().numpy()) - write_f64(f, get_bias(block.ln_1.bias, model_dim)) - write_f64(f, block.ln_2.weight.detach().cpu().numpy()) - write_f64(f, get_bias(block.ln_2.bias, model_dim)) - - write_f64(f, model.ln_f.weight.detach().cpu().numpy()) - write_f64(f, 
get_bias(model.ln_f.bias, model_dim)) - write_f64(f, wte.T) - - return tokens - - -def run_lean_dump(model_path: Path, layer: int, pos: int, take: int) -> np.ndarray: - exe = Path(".lake/build/bin/nfp") - if not exe.exists(): - raise SystemExit("Missing `.lake/build/bin/nfp`. Run `lake build nfp` first.") - cmd = [ - str(exe), - "dump", - "--kind", - "afterLayer", - "--layer", - str(layer), - "--pos", - str(pos), - "--take", - str(take), - str(model_path), - ] - out = subprocess.check_output(cmd, text=True) - lines = [ln.strip() for ln in out.splitlines() if ln.strip()] - dump_i = None - for i, ln in enumerate(lines): - if ln.startswith("DUMP "): - dump_i = i - if dump_i is None or dump_i + 2 >= len(lines): - raise SystemExit(f"Unexpected Lean dump output:\n{out}") - vals = [float(x) for x in lines[dump_i + 2].split()] - return np.asarray(vals, dtype=np.float64) - - -def run_torch(model: GPT2Model, tokens: np.ndarray, layer: int, pos: int, take: int) -> np.ndarray: - input_ids = torch.tensor(tokens.reshape(1, -1), dtype=torch.long) - with torch.no_grad(): - out = model(input_ids=input_ids, output_hidden_states=True) - hs = out.hidden_states - if hs is None: - raise SystemExit("Expected output_hidden_states=True to return hidden_states.") - x = hs[layer + 1][0, pos, :take].detach().cpu().numpy() - return x.astype(np.float64) - - -def main() -> None: - ap = argparse.ArgumentParser() - ap.add_argument("--model", default="sshleifer/tiny-gpt2") - ap.add_argument("--seqLen", type=int, default=8) - ap.add_argument("--layer", type=int, default=0) - ap.add_argument("--pos", type=int, default=0) - ap.add_argument("--take", type=int, default=16) - ap.add_argument("--tol", type=float, default=1e-3) - args = ap.parse_args() - - m = GPT2Model.from_pretrained(args.model) - m = m.to(dtype=torch.float64) - m.eval() - - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "ci_tiny.nfpt" - toks = export_nfpt(m, args.seqLen, path) - lean = run_lean_dump(path, args.layer, 
args.pos, args.take) - torch_x = run_torch(m, toks, args.layer, args.pos, args.take) - diff = float(np.max(np.abs(lean - torch_x))) - print(f"max_abs_diff={diff:.6g} tol={args.tol:.6g}") - if not np.isfinite(diff) or diff > args.tol: - print("FAIL") - print("lean :", lean.tolist()) - print("torch:", torch_x.tolist()) - raise SystemExit(1) - print("OK") - - -if __name__ == "__main__": - main() diff --git a/scripts/demo_gpt2_induction_sound.sh b/scripts/demo_gpt2_induction_sound.sh deleted file mode 100755 index cd3a8ae..0000000 --- a/scripts/demo_gpt2_induction_sound.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -MODEL_PATH="${1:-models/gpt2_rigorous.nfpt}" -REPORT_PATH="${2:-reports/gpt2_induction_sound_scan.txt}" -EXTRA_ARGS=() -if [ "$#" -gt 2 ]; then - EXTRA_ARGS=("${@:3}") -fi - -PYTHON_BIN="python" -if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then - PYTHON_BIN="python3" -fi - -if ! "$PYTHON_BIN" scripts/ensure_gelu_kind.py --check "$MODEL_PATH"; then - PATCHED_PATH="${MODEL_PATH%.nfpt}_with_gelu_kind.nfpt" - "$PYTHON_BIN" scripts/ensure_gelu_kind.py "$MODEL_PATH" \ - --output "$PATCHED_PATH" \ - --default tanh - MODEL_PATH="$PATCHED_PATH" -fi - -if [ "${#EXTRA_ARGS[@]}" -gt 0 ]; then - "$PYTHON_BIN" scripts/scan_gpt2_induction_sound.py \ - --model "$MODEL_PATH" \ - --output "$REPORT_PATH" \ - "${EXTRA_ARGS[@]}" -else - "$PYTHON_BIN" scripts/scan_gpt2_induction_sound.py \ - --model "$MODEL_PATH" \ - --output "$REPORT_PATH" -fi - -echo "Report written to $REPORT_PATH" diff --git a/scripts/demo_gpt2_sound.sh b/scripts/demo_gpt2_sound.sh deleted file mode 100755 index 67660b9..0000000 --- a/scripts/demo_gpt2_sound.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -MODEL_PATH="${1:-models/gpt2.nfpt}" -REPORT_PATH="${2:-reports/gpt2_sound_demo.txt}" - -mkdir -p "$(dirname "$MODEL_PATH")" "$(dirname "$REPORT_PATH")" - -PYTHON_BIN="python" -if ! 
command -v "$PYTHON_BIN" >/dev/null 2>&1; then - PYTHON_BIN="python3" -fi - -if [ ! -f "$MODEL_PATH" ]; then - if command -v uv >/dev/null 2>&1; then - uv run python scripts/export_gpt2.py "$MODEL_PATH" - else - "$PYTHON_BIN" scripts/export_gpt2.py "$MODEL_PATH" - fi -fi - -if ! "$PYTHON_BIN" scripts/ensure_gelu_kind.py --check "$MODEL_PATH"; then - PATCHED_PATH="${MODEL_PATH%.nfpt}_with_gelu_kind.nfpt" - "$PYTHON_BIN" scripts/ensure_gelu_kind.py "$MODEL_PATH" \ - --output "$PATCHED_PATH" \ - --default tanh - MODEL_PATH="$PATCHED_PATH" -fi - -lake exe nfp certify "$MODEL_PATH" --output "$REPORT_PATH" -echo "Report written to $REPORT_PATH" diff --git a/scripts/demo_tiny_induction_cert.sh b/scripts/demo_tiny_induction_cert.sh deleted file mode 100755 index f4e1ebc..0000000 --- a/scripts/demo_tiny_induction_cert.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -MODEL_TEXT="${1:-tests/fixtures/tiny_sound_model.nfpt}" -INPUT_TEXT="${2:-tests/fixtures/tiny_sound_input.nfpt}" -BINARY_PATH="${3:-tests/fixtures/tiny_sound_binary.nfpt}" -REPORT_PATH="${4:-reports/tiny_induction_cert.txt}" - -mkdir -p "$(dirname "$BINARY_PATH")" "$(dirname "$REPORT_PATH")" - -PYTHON_BIN="python" -if ! 
command -v "$PYTHON_BIN" >/dev/null 2>&1; then - PYTHON_BIN="python3" -fi - -"$PYTHON_BIN" scripts/convert_text_fixture_to_binary.py \ - --model "$MODEL_TEXT" \ - --input "$INPUT_TEXT" \ - --output "$BINARY_PATH" - -lake exe nfp induction_cert "$BINARY_PATH" \ - --layer1 0 --head1 0 --layer2 0 --head2 0 --coord 0 \ - --offset1 -1 --offset2 -1 --target 2 --negative 1 \ - --delta 0.05 --output "$REPORT_PATH" - -echo "Report written to $REPORT_PATH" diff --git a/scripts/demo_tiny_local_binary.sh b/scripts/demo_tiny_local_binary.sh deleted file mode 100755 index 28a2018..0000000 --- a/scripts/demo_tiny_local_binary.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -MODEL_TEXT="${1:-tests/fixtures/tiny_sound_model.nfpt}" -INPUT_TEXT="${2:-tests/fixtures/tiny_sound_input.nfpt}" -BINARY_PATH="${3:-tests/fixtures/tiny_sound_binary.nfpt}" -REPORT_PATH="${4:-reports/tiny_sound_local_binary.txt}" - -mkdir -p "$(dirname "$BINARY_PATH")" "$(dirname "$REPORT_PATH")" - -if [ "${USE_UV:-0}" = "1" ] && command -v uv >/dev/null 2>&1; then - uv run python scripts/convert_text_fixture_to_binary.py \ - --model "$MODEL_TEXT" \ - --input "$INPUT_TEXT" \ - --output "$BINARY_PATH" -else - PYTHON_BIN="python" - if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then - PYTHON_BIN="python3" - fi - "$PYTHON_BIN" scripts/convert_text_fixture_to_binary.py \ - --model "$MODEL_TEXT" \ - --input "$INPUT_TEXT" \ - --output "$BINARY_PATH" -fi - -lake exe nfp certify "$BINARY_PATH" --delta 0.05 --output "$REPORT_PATH" -echo "Report written to $REPORT_PATH" diff --git a/scripts/discover_gpt2_induction_targets.py b/scripts/discover_gpt2_induction_targets.py deleted file mode 100644 index df146ce..0000000 --- a/scripts/discover_gpt2_induction_targets.py +++ /dev/null @@ -1,1251 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Discover promising GPT-2 induction heads and logit-diff directions from an NFP binary. 
- -This script is untrusted: it uses floating-point arithmetic to score candidates -and optionally invokes the Lean verifier (`nfp induction certify_head_model_nonvacuous`) -to confirm nonvacuous bounds when scoring by logit-diff. - -Layer/head indices are one-based to align with the mechanistic interpretability -literature. - -By default, `prev`/active are built from bigram prefix matches (the token at -q-1 maps to the *following* token after its previous occurrence), and heads are -ranked by attention to `prev`. Use `--score-mode=copy` or `--score-mode=attn_copy` -to include OV/copying alignment in the ranking. Use `--use-activations` to -score heads using real layer activations from a HuggingFace GPT-2 model rather -than the embedding-only approximation stored in the NFP file. -""" - -from __future__ import annotations - -import argparse -import json -import os -import subprocess -import math -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, Iterable, List, Tuple - -import numpy as np - - -@dataclass(frozen=True) -class HeadResult: - layer: int - head: int - target: int - negative: int - logit_lb: float - eps: float - margin: float - min_prev: float - value_range: float - active: int - prev_mean: float - prev_median: float - prev_top1_frac: float - - -@dataclass(frozen=True) -class AttnResult: - layer: int - head: int - score: float - prev_mean: float - prev_median: float - prev_top1_frac: float - copy_mean: float - copy_weighted_mean: float - eps: float - margin: float - active: int - - -@dataclass(frozen=True) -class StripeResult: - layer: int - head: int - score: float - stripe_mean: float - stripe_median: float - stripe_top1_frac: float - eps: float - margin: float - active: int - - -@dataclass(frozen=True) -class CircuitResult: - prev_layer: int - prev_head: int - induction_layer: int - induction_head: int - score: float - prev_mean: float - prev_median: float - prev_top1_frac: float - stripe_mean: float - stripe_median: 
float - stripe_top1_frac: float - - -@dataclass(frozen=True) -class CircuitCopyResult: - prev_layer: int - prev_head: int - induction_layer: int - induction_head: int - score: float - prev_mean: float - prev_median: float - prev_top1_frac: float - stripe_mean: float - stripe_median: float - stripe_top1_frac: float - copy_mean: float - copy_weighted_mean: float - - -def parse_header(f) -> Dict[str, str]: - header: Dict[str, str] = {} - magic = f.readline().decode("ascii").strip() - if magic != "NFP_BINARY_V1": - raise SystemExit(f"Unsupported magic header: {magic}") - while True: - line = f.readline() - if line == b"": - raise SystemExit("Unexpected EOF while reading header.") - text = line.decode("ascii").strip() - if text == "BINARY_START": - break - if "=" in text: - key, value = text.split("=", 1) - header[key.strip()] = value.strip() - return header - - -def read_i32(f, count: int) -> np.ndarray: - raw = f.read(count * 4) - if len(raw) != count * 4: - raise SystemExit("Unexpected EOF while reading int32 payload.") - return np.frombuffer(raw, dtype=" np.ndarray: - raw = f.read(count * 8) - if len(raw) != count * 8: - raise SystemExit("Unexpected EOF while reading float64 payload.") - return np.frombuffer(raw, dtype=" None: - offset = count * 8 - f.seek(offset, 1) - - -def build_prev(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - prev = np.zeros_like(tokens) - active = np.zeros_like(tokens, dtype=bool) - last_seen: Dict[int, int] = {} - for idx, tok in enumerate(tokens.tolist()): - if idx == 0: - prev[idx] = 0 - active[idx] = False - else: - if tok in last_seen: - prev[idx] = last_seen[tok] - active[idx] = True - else: - prev[idx] = 0 - active[idx] = False - last_seen[tok] = idx - return prev, active - - -def build_prev_bigram(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """ - Bigram-prefix induction prev: for each position q>=1, look at token at q-1, - find its previous occurrence index j, and set prev[q] = j + 1. 
- - Example (tokens = [1,2,1,3,2,1]): - prev = [0,0,0,1,0,2], active = [F,F,F,T,F,T] - """ - prev_token, active_token = build_prev(tokens) - prev = np.zeros_like(tokens) - active = np.zeros_like(tokens, dtype=bool) - if tokens.size <= 1: - return prev, active - prev_shift = prev_token[:-1] + 1 - active_shift = active_token[:-1] - prev[1:] = np.where(active_shift, prev_shift, 0) - active[1:] = active_shift - return prev, active - - -def build_prev_period(seq_len: int, period: int) -> Tuple[np.ndarray, np.ndarray]: - prev = np.zeros(seq_len, dtype=np.int64) - active = np.zeros(seq_len, dtype=bool) - idx = np.arange(seq_len) - mask = idx >= period - prev[mask] = idx[mask] - period - active[mask] = True - return prev, active - - -def build_prev_period_shift(seq_len: int, period: int) -> Tuple[np.ndarray, np.ndarray]: - prev = np.zeros(seq_len, dtype=np.int64) - active = np.zeros(seq_len, dtype=bool) - if period <= 0: - return prev, active - idx = np.arange(seq_len) - mask = idx >= period - prev[mask] = idx[mask] - period + 1 - active[mask] = True - return prev, active - - -def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float) -> np.ndarray: - mean = x.mean(axis=1, keepdims=True) - var = ((x - mean) ** 2).mean(axis=1, keepdims=True) - x_hat = (x - mean) / np.sqrt(var + eps) - return x_hat * gamma + beta - - -def skip_head_weights(f, model_dim: int, head_dim: int) -> None: - skip_f64(f, model_dim * head_dim) # wq - skip_f64(f, head_dim) # bq - skip_f64(f, model_dim * head_dim) # wk - skip_f64(f, head_dim) # bk - skip_f64(f, model_dim * head_dim) # wv - skip_f64(f, head_dim) # bv - skip_f64(f, head_dim * model_dim) # wo - - -def skip_layer_weights( - f, - model_dim: int, - head_dim: int, - num_heads: int, - hidden_dim: int, -) -> None: - for _ in range(num_heads): - skip_head_weights(f, model_dim, head_dim) - skip_f64(f, model_dim) # attn bias - skip_f64(f, model_dim * hidden_dim) - skip_f64(f, hidden_dim) - skip_f64(f, hidden_dim * model_dim) - 
skip_f64(f, model_dim) - skip_f64(f, model_dim) # ln1 gamma - skip_f64(f, model_dim) # ln1 beta - skip_f64(f, model_dim) # ln2 gamma - skip_f64(f, model_dim) # ln2 beta - - -def load_hf_model_and_states(tokens: np.ndarray, model_name: str, device: str): - try: - import torch - from transformers import AutoModel - except ImportError as exc: # pragma: no cover - optional dependency - raise SystemExit( - "Activation mode requires torch + transformers. " - "Install them (e.g., `uv run --with torch --with transformers ...`)." - ) from exc - - torch.set_grad_enabled(False) - model = AutoModel.from_pretrained(model_name) - model.eval() - model.to(device) - input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0) - outputs = model(input_ids, output_hidden_states=True, use_cache=False) - hidden_states = outputs.hidden_states - if hidden_states is None: - raise SystemExit("HuggingFace model did not return hidden states.") - return model, hidden_states - - -def get_transformer_blocks(model): - if hasattr(model, "transformer"): - return model.transformer.h - if hasattr(model, "h"): - return model.h - raise SystemExit("Unsupported HuggingFace model structure (missing transformer blocks).") - - -def compute_head_data_from_activations( - model, - hidden_states, - layers: List[int], - heads: List[int], - prev: np.ndarray, - active_positions: List[int], - prev_indices: np.ndarray, - head_dim: int, - seq_len: int, - stripe_prev: np.ndarray | None = None, - stripe_positions: List[int] | None = None, - prevtok_prev: np.ndarray | None = None, - prevtok_positions: List[int] | None = None, -) -> Tuple[ - Dict[ - Tuple[int, int], - Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], - ], - Dict[Tuple[int, int], Tuple[float, float, float, float, float]], - Dict[Tuple[int, int], Tuple[float, float, float, float, float]], -]: - try: - import torch - except ImportError as exc: # pragma: no cover - optional dependency - raise 
SystemExit("Activation mode requires torch.") from exc - - def split_heads(x: "torch.Tensor", num_heads: int, head_dim_local: int) -> "torch.Tensor": - batch, seq, _ = x.shape - x = x.reshape(batch, seq, num_heads, head_dim_local) - return x.permute(0, 2, 1, 3) - - blocks = get_transformer_blocks(model) - head_data: Dict[ - Tuple[int, int], - Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], - ] = {} - stripe_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} - prevtok_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} - device = hidden_states[0].device - causal_mask = torch.triu( - torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), - diagonal=1, - ) - for layer_idx in layers: - block = blocks[layer_idx] - hidden = hidden_states[layer_idx] - ln = block.ln_1(hidden) - qkv = block.attn.c_attn(ln) - split_size = getattr(block.attn, "split_size", qkv.shape[-1] // 3) - q, k, v = qkv.split(split_size, dim=2) - num_heads = getattr(block.attn, "num_heads", q.shape[-1] // head_dim) - head_dim_local = getattr(block.attn, "head_dim", head_dim) - if head_dim_local != head_dim: - raise SystemExit("HuggingFace head_dim does not match NFP header.") - scale = 1.0 / math.sqrt(head_dim_local) - q = split_heads(q, num_heads, head_dim_local) - k = split_heads(k, num_heads, head_dim_local) - v = split_heads(v, num_heads, head_dim_local) - scores = torch.matmul(q, k.transpose(-2, -1)) * scale - scores = scores.masked_fill(causal_mask, -10000.0) - weights = torch.softmax(scores, dim=-1) - wo_full = block.attn.c_proj.weight - - for head_idx in heads: - weights_head = weights[0, head_idx] - scores_head = scores[0, head_idx] - weights_np = weights_head.detach().cpu().numpy() - scores_np = scores_head.detach().cpu().numpy() - eps, margin, prev_mean, prev_median, prev_top1 = compute_eps_margin( - weights_np, scores_np, prev, active_positions - ) - if stripe_prev is not None and stripe_positions is not 
None: - eps_s, margin_s, stripe_mean, stripe_median, stripe_top1 = compute_eps_margin( - weights_np, scores_np, stripe_prev, stripe_positions - ) - stripe_data[(layer_idx, head_idx)] = ( - stripe_mean, - stripe_median, - stripe_top1, - eps_s, - margin_s, - ) - if prevtok_prev is not None and prevtok_positions is not None: - eps_p, margin_p, prevtok_mean, prevtok_median, prevtok_top1 = compute_eps_margin( - weights_np, scores_np, prevtok_prev, prevtok_positions - ) - prevtok_data[(layer_idx, head_idx)] = ( - prevtok_mean, - prevtok_median, - prevtok_top1, - eps_p, - margin_p, - ) - prev_weights = weights_np[np.array(active_positions), prev_indices] - - v_head = v[0, head_idx] - v_np = v_head.detach().cpu().numpy() - start = head_idx * head_dim_local - end = start + head_dim_local - wo_np = wo_full[start:end, :].detach().cpu().numpy() - - head_data[(layer_idx, head_idx)] = ( - v_np, - wo_np, - prev_weights, - eps, - margin, - prev_mean, - prev_median, - prev_top1, - ) - return head_data, stripe_data, prevtok_data - - -def softmax(scores: np.ndarray) -> np.ndarray: - shift = scores - scores.max(axis=1, keepdims=True) - exp = np.exp(shift) - return exp / exp.sum(axis=1, keepdims=True) - - -def parse_index_list(raw: str | None, max_value: int) -> List[int] | None: - if raw is None: - return None - raw = raw.strip() - if raw.lower() == "all": - return list(range(max_value)) - out: List[int] = [] - for part in raw.split(","): - part = part.strip() - if not part: - continue - idx = int(part) - if idx <= 0 or idx > max_value: - raise ValueError(f"index {idx} out of range [1,{max_value}]") - out.append(idx - 1) - return out - - -def resolve_nfp_cmd(nfp_bin: str | None) -> List[str]: - if nfp_bin: - return [nfp_bin] - env_bin = os.environ.get("NFP_BIN") - if env_bin: - return [env_bin] - local_bin = Path(".lake/build/bin/nfp") - if local_bin.exists(): - return [str(local_bin)] - return ["lake", "exe", "nfp"] - - -def read_unembed_column( - f, - start: int, - model_dim: int, - 
vocab_size: int, - col: int, -) -> np.ndarray: - if col < 0 or col >= vocab_size: - raise ValueError(f"column {col} out of range") - row_bytes = vocab_size * 8 - data = np.zeros(model_dim, dtype=np.float64) - for row in range(model_dim): - base = start + row * row_bytes - f.seek(base + col * 8) - data[row] = np.frombuffer(f.read(8), dtype=" Tuple[float, float, float, float, float]: - eps_vals: List[float] = [] - margin_vals: List[float] = [] - prev_vals: List[float] = [] - max_other_vals: List[float] = [] - for q in active_positions: - prev_q = int(prev[q]) - prev_w = weights[q, prev_q] - max_other = np.max(np.delete(weights[q], prev_q)) - eps_vals.append(max(max_other, 1.0 - prev_w)) - diffs = scores[q, prev_q] - np.delete(scores[q], prev_q) - margin_vals.append(float(np.min(diffs)) if diffs.size > 0 else 0.0) - prev_vals.append(float(prev_w)) - max_other_vals.append(float(max_other)) - if not eps_vals: - return 0.0, 0.0, 0.0, 0.0, 0.0 - prev_arr = np.asarray(prev_vals, dtype=np.float64) - max_other_arr = np.asarray(max_other_vals, dtype=np.float64) - prev_mean = float(prev_arr.mean()) - prev_median = float(np.median(prev_arr)) - prev_top1 = float(np.mean(prev_arr >= max_other_arr)) - return max(eps_vals), min(margin_vals), prev_mean, prev_median, prev_top1 - - -def compute_copy_scores( - ov: np.ndarray, - weights_prev: np.ndarray, - columns: Dict[int, np.ndarray], - tokens: np.ndarray, - prev: np.ndarray, - active_positions: Iterable[int], -) -> Tuple[float, float]: - copy_vals: List[float] = [] - copy_weighted_vals: List[float] = [] - for idx, q in enumerate(active_positions): - tok = int(tokens[q]) - col = columns.get(tok) - if col is None: - continue - prev_q = int(prev[q]) - val = float(ov[prev_q] @ col) - copy_vals.append(val) - copy_weighted_vals.append(float(weights_prev[idx]) * val) - if not copy_vals: - return 0.0, 0.0 - return float(np.mean(copy_vals)), float(np.mean(copy_weighted_vals)) - - -def format_result(result: HeadResult) -> str: - layer = 
result.layer + 1 - head = result.head + 1 - return ( - f"L{layer}H{head} target={result.target} " - f"negative={result.negative} logitLB={result.logit_lb:.6f} " - f"eps={result.eps:.6f} margin={result.margin:.6f} " - f"minPrev={result.min_prev:.6f} range={result.value_range:.6f} " - f"prevMean={result.prev_mean:.6f} prevMedian={result.prev_median:.6f} " - f"prevTop1={result.prev_top1_frac:.3f} active={result.active}" - ) - - -def format_attn_result(result: AttnResult) -> str: - layer = result.layer + 1 - head = result.head + 1 - return ( - f"L{layer}H{head} score={result.score:.6f} " - f"prevMean={result.prev_mean:.6f} prevMedian={result.prev_median:.6f} " - f"prevTop1={result.prev_top1_frac:.3f} " - f"copyMean={result.copy_mean:.6f} copyWeighted={result.copy_weighted_mean:.6f} " - f"eps={result.eps:.6f} margin={result.margin:.6f} active={result.active}" - ) - - -def format_stripe_result(result: StripeResult) -> str: - layer = result.layer + 1 - head = result.head + 1 - return ( - f"L{layer}H{head} score={result.score:.6f} " - f"stripeMean={result.stripe_mean:.6f} stripeMedian={result.stripe_median:.6f} " - f"stripeTop1={result.stripe_top1_frac:.3f} " - f"eps={result.eps:.6f} margin={result.margin:.6f} active={result.active}" - ) - - -def format_circuit_result(result: CircuitResult) -> str: - prev_layer = result.prev_layer + 1 - prev_head = result.prev_head + 1 - ind_layer = result.induction_layer + 1 - ind_head = result.induction_head + 1 - return ( - f"prev=L{prev_layer}H{prev_head} ind=L{ind_layer}H{ind_head} " - f"score={result.score:.6f} prevMean={result.prev_mean:.6f} " - f"prevMedian={result.prev_median:.6f} prevTop1={result.prev_top1_frac:.3f} " - f"stripeMean={result.stripe_mean:.6f} stripeMedian={result.stripe_median:.6f} " - f"stripeTop1={result.stripe_top1_frac:.3f}" - ) - - -def format_circuit_copy_result(result: CircuitCopyResult) -> str: - prev_layer = result.prev_layer + 1 - prev_head = result.prev_head + 1 - ind_layer = result.induction_layer + 1 - 
ind_head = result.induction_head + 1 - return ( - f"prev=L{prev_layer}H{prev_head} ind=L{ind_layer}H{ind_head} " - f"score={result.score:.6f} prevMean={result.prev_mean:.6f} " - f"stripeMean={result.stripe_mean:.6f} copyMean={result.copy_mean:.6f} " - f"copyWeighted={result.copy_weighted_mean:.6f} " - f"prevMedian={result.prev_median:.6f} prevTop1={result.prev_top1_frac:.3f} " - f"stripeMedian={result.stripe_median:.6f} stripeTop1={result.stripe_top1_frac:.3f}" - ) - - -def main() -> int: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--model", required=True, type=Path, help="Path to NFP_BINARY_V1 model") - parser.add_argument("--max-tokens", type=int, default=32, - help="Maximum unique tokens from the prompt to consider") - parser.add_argument("--top", type=int, default=20, help="Number of results to report") - parser.add_argument("--verify-top", type=int, default=0, - help="Run verifier on the top N candidates") - parser.add_argument( - "--score-mode", - choices=["attn", "copy", "attn_copy", "stripe", "circuit", "circuit_copy", "logit"], - default="attn", - help=( - "Rank heads by attention to prev (attn), OV copy score (copy), " - "attention-weighted copy score (attn_copy), induction stripe attention " - "(stripe), circuit pairing (circuit), circuit + copy (circuit_copy), " - "or logit lower bound (logit)." 
- ), - ) - parser.add_argument( - "--min-score", - type=float, - default=0.0, - help="Minimum score threshold for the selected score mode.", - ) - parser.add_argument( - "--min-copy", - type=float, - default=None, - help="Optional minimum OV copy score.", - ) - parser.add_argument("--min-eps", type=float, default=0.5, - help="Filter candidates with eps above this value") - parser.add_argument("--min-margin", type=float, default=0.0, - help="Filter candidates with margin below this value") - parser.add_argument("--min-logit-lb", type=float, default=0.0, - help="Filter candidates with logit lower bound below this value") - parser.add_argument("--layers", help="Comma-separated layer list or 'all'") - parser.add_argument("--heads", help="Comma-separated head list or 'all'") - parser.add_argument("--period", type=int, help="Optional prompt period override") - parser.add_argument( - "--prev-mode", - choices=["bigram", "token", "period", "period_shift"], - default="bigram", - help=( - "Choose prev/active construction (default: bigram prefix match). " - "period_shift uses q-period+1." - ), - ) - parser.add_argument( - "--stripe-period", - type=int, - help="Period for induction stripe scoring (required for --score-mode=stripe).", - ) - parser.add_argument("--output", type=Path, default=Path("reports/gpt2_induction_discover.txt")) - parser.add_argument("--json-out", type=Path, help="Optional JSON output path") - parser.add_argument("--nfp-bin", help="Path to nfp binary (defaults to .lake/build/bin/nfp)") - parser.add_argument( - "--use-activations", - action="store_true", - help="Use HuggingFace GPT-2 activations for Q/K/V instead of embedding-only approximation.", - ) - parser.add_argument( - "--hf-model", - default="gpt2", - help="HuggingFace model name or path (activation mode).", - ) - parser.add_argument( - "--device", - default="cpu", - help="Torch device for activation mode (e.g. 
cpu, cuda, mps).", - ) - args = parser.parse_args() - - if args.max_tokens <= 1: - raise SystemExit("max-tokens must be at least 2") - if args.verify_top > 0 and args.score_mode != "logit": - raise SystemExit("--verify-top requires --score-mode=logit") - if args.score_mode in {"stripe", "circuit", "circuit_copy"} and args.stripe_period is None: - raise SystemExit("--score-mode=stripe/circuit requires --stripe-period") - if args.score_mode in {"circuit", "circuit_copy"} and args.prev_mode != "bigram": - raise SystemExit("--score-mode=circuit requires --prev-mode=bigram") - - if not args.model.exists(): - raise SystemExit(f"Missing model file: {args.model}") - - with args.model.open("rb") as f: - header = parse_header(f) - num_layers = int(header["num_layers"]) - num_heads = int(header["num_heads"]) - model_dim = int(header["model_dim"]) - head_dim = int(header["head_dim"]) - vocab_size = int(header["vocab_size"]) - seq_len = int(header["seq_len"]) - hidden_dim = int(header["hidden_dim"]) - ln_eps = float(header.get("layer_norm_eps", header.get("eps", "0"))) - - layers = parse_index_list(args.layers, num_layers) or list(range(num_layers)) - heads = parse_index_list(args.heads, num_heads) or list(range(num_heads)) - - tokens = read_i32(f, seq_len) - if args.use_activations: - skip_f64(f, seq_len * model_dim) - embeddings = None - else: - embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) - - if args.prev_mode not in {"period", "period_shift"} and args.period is not None: - raise SystemExit("--period is incompatible with --prev-mode=token/bigram") - if args.prev_mode in {"period", "period_shift"} and args.period is None: - raise SystemExit("--prev-mode=period/period_shift requires --period") - - if args.prev_mode == "period": - period = int(args.period) - prev, active_mask = build_prev_period(seq_len, period) - elif args.prev_mode == "period_shift": - period = int(args.period) - prev, active_mask = build_prev_period_shift(seq_len, period) - elif 
args.prev_mode == "bigram": - prev, active_mask = build_prev_bigram(tokens) - else: - prev, active_mask = build_prev(tokens) - - active_positions = [int(i) for i, flag in enumerate(active_mask) if flag] - if not active_positions: - raise SystemExit("No active positions found in the prompt") - prev_indices = prev[np.array(active_positions, dtype=np.int64)] - stripe_prev = None - stripe_positions = None - if args.score_mode in {"stripe", "circuit", "circuit_copy"}: - stripe_period = int(args.stripe_period) - stripe_prev, stripe_active = build_prev_period(seq_len, stripe_period) - stripe_positions = [int(i) for i, flag in enumerate(stripe_active) if flag] - if not stripe_positions: - raise SystemExit("No stripe positions found for the requested period") - prevtok_prev = None - prevtok_positions = None - if args.score_mode in {"circuit", "circuit_copy"}: - prevtok_prev, prevtok_active = build_prev_period(seq_len, 1) - prevtok_positions = [int(i) for i, flag in enumerate(prevtok_active) if flag] - if not prevtok_positions: - raise SystemExit("No previous-token positions found") - - prompt_tokens = sorted({int(tok) for tok in tokens.tolist()}) - unique_tokens = [] - seen = set() - for tok in tokens.tolist(): - if tok not in seen: - seen.add(tok) - unique_tokens.append(int(tok)) - if len(unique_tokens) >= args.max_tokens: - break - if len(unique_tokens) < 2: - raise SystemExit("Need at least two unique tokens to form directions") - - head_data: Dict[ - Tuple[int, int], - Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float, float, float], - ] = {} - stripe_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} - prevtok_data: Dict[Tuple[int, int], Tuple[float, float, float, float, float]] = {} - - hf_model = None - hf_states = None - if args.use_activations: - hf_model, hf_states = load_hf_model_and_states(tokens, args.hf_model, args.device) - config = getattr(hf_model, "config", None) - if config is not None: - if getattr(config, "n_layer", 
num_layers) != num_layers: - raise SystemExit("HuggingFace model layer count does not match NFP header.") - if getattr(config, "n_head", num_heads) != num_heads: - raise SystemExit("HuggingFace model head count does not match NFP header.") - if getattr(config, "n_embd", model_dim) != model_dim: - raise SystemExit("HuggingFace model dimension does not match NFP header.") - if getattr(config, "vocab_size", vocab_size) != vocab_size: - raise SystemExit("HuggingFace vocab size does not match NFP header.") - if getattr(config, "n_positions", seq_len) < seq_len: - raise SystemExit("Prompt length exceeds HuggingFace model context.") - if len(hf_states) < num_layers + 1: - raise SystemExit("Hidden state count is smaller than expected.") - if hf_states[0].shape[1] != seq_len: - raise SystemExit("Hidden state sequence length does not match NFP header.") - for layer_idx in range(num_layers): - if args.use_activations: - skip_layer_weights(f, model_dim, head_dim, num_heads, hidden_dim) - continue - head_weights = [] - for _ in range(num_heads): - wq = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bq = read_f64(f, head_dim) - wk = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bk = read_f64(f, head_dim) - wv = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bv = read_f64(f, head_dim) - wo = read_f64(f, head_dim * model_dim).reshape(head_dim, model_dim) - head_weights.append((wq, bq, wk, bk, wv, bv, wo)) - _attn_bias = read_f64(f, model_dim) - skip_f64(f, model_dim * hidden_dim) - skip_f64(f, hidden_dim) - skip_f64(f, hidden_dim * model_dim) - skip_f64(f, model_dim) - ln1_gamma = read_f64(f, model_dim) - ln1_beta = read_f64(f, model_dim) - skip_f64(f, model_dim) - skip_f64(f, model_dim) - - if layer_idx not in layers: - continue - - ln = layer_norm(embeddings, ln1_gamma, ln1_beta, ln_eps) - scale = 1.0 / np.sqrt(head_dim) - for head_idx in heads: - wq, bq, wk, bk, wv, bv, wo = head_weights[head_idx] - q = ln @ wq + bq - k = 
ln @ wk + bk - v = ln @ wv + bv - - scores = scale * (q @ k.T) - mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1) - scores = scores.copy() - scores[mask] = -10000.0 - weights = softmax(scores) - - eps, margin, prev_mean, prev_median, prev_top1 = compute_eps_margin( - weights, scores, prev, active_positions - ) - if stripe_prev is not None and stripe_positions is not None: - eps_s, margin_s, stripe_mean, stripe_median, stripe_top1 = compute_eps_margin( - weights, scores, stripe_prev, stripe_positions - ) - stripe_data[(layer_idx, head_idx)] = ( - stripe_mean, - stripe_median, - stripe_top1, - eps_s, - margin_s, - ) - if prevtok_prev is not None and prevtok_positions is not None: - eps_p, margin_p, prevtok_mean, prevtok_median, prevtok_top1 = compute_eps_margin( - weights, scores, prevtok_prev, prevtok_positions - ) - prevtok_data[(layer_idx, head_idx)] = ( - prevtok_mean, - prevtok_median, - prevtok_top1, - eps_p, - margin_p, - ) - prev_weights = weights[np.array(active_positions), prev_indices] - head_data[(layer_idx, head_idx)] = ( - v, - wo, - prev_weights, - eps, - margin, - prev_mean, - prev_median, - prev_top1, - ) - - ln_f_gamma = read_f64(f, model_dim) - _ln_f_beta = read_f64(f, model_dim) - _ = ln_f_gamma - unembed_start = f.tell() - - columns: Dict[int, np.ndarray] = {} - for tok in prompt_tokens: - columns[tok] = read_unembed_column( - f, - unembed_start, - model_dim, - vocab_size, - tok, - ) - - if args.use_activations: - head_data, stripe_data, prevtok_data = compute_head_data_from_activations( - hf_model, - hf_states, - layers, - heads, - prev, - active_positions, - prev_indices, - head_dim, - seq_len, - stripe_prev=stripe_prev, - stripe_positions=stripe_positions, - prevtok_prev=prevtok_prev, - prevtok_positions=prevtok_positions, - ) - - results: List[HeadResult] = [] - attn_results: List[AttnResult] = [] - stripe_results: List[StripeResult] = [] - circuit_results: List[CircuitResult] = [] - circuit_copy_results: List[CircuitCopyResult] = 
[] - for (layer_idx, head_idx), ( - v, - wo, - prev_weights, - eps, - margin, - prev_mean, - prev_median, - prev_top1, - ) in head_data.items(): - if args.score_mode == "logit": - if eps > args.min_eps or margin < args.min_margin: - continue - if args.score_mode == "stripe": - stripe = stripe_data.get((layer_idx, head_idx)) - if stripe is None: - continue - stripe_mean, stripe_median, stripe_top1, eps_s, margin_s = stripe - score = stripe_mean - if score < args.min_score: - continue - stripe_results.append( - StripeResult( - layer=layer_idx, - head=head_idx, - score=score, - stripe_mean=stripe_mean, - stripe_median=stripe_median, - stripe_top1_frac=stripe_top1, - eps=eps_s, - margin=margin_s, - active=len(stripe_positions) if stripe_positions is not None else 0, - ) - ) - continue - if args.score_mode == "circuit": - stripe = stripe_data.get((layer_idx, head_idx)) - if stripe is None: - continue - stripe_mean, stripe_median, stripe_top1, _eps_s, _margin_s = stripe - best_prev: CircuitResult | None = None - for (prev_layer, prev_head), prev_stats in prevtok_data.items(): - if prev_layer >= layer_idx: - continue - prev_mean, prev_median, prev_top1, _eps_p, _margin_p = prev_stats - score = prev_mean * stripe_mean - if score < args.min_score: - continue - candidate = CircuitResult( - prev_layer=prev_layer, - prev_head=prev_head, - induction_layer=layer_idx, - induction_head=head_idx, - score=score, - prev_mean=prev_mean, - prev_median=prev_median, - prev_top1_frac=prev_top1, - stripe_mean=stripe_mean, - stripe_median=stripe_median, - stripe_top1_frac=stripe_top1, - ) - if best_prev is None or candidate.score > best_prev.score: - best_prev = candidate - if best_prev is not None: - circuit_results.append(best_prev) - continue - if args.score_mode == "circuit_copy": - stripe = stripe_data.get((layer_idx, head_idx)) - if stripe is None: - continue - stripe_mean, stripe_median, stripe_top1, _eps_s, _margin_s = stripe - ov = v @ wo - copy_mean, copy_weighted_mean = 
compute_copy_scores( - ov, - prev_weights, - columns, - tokens, - prev, - active_positions, - ) - if args.min_copy is not None and copy_mean < args.min_copy: - continue - copy_score = max(copy_mean, 0.0) - best_prev: CircuitCopyResult | None = None - for (prev_layer, prev_head), prev_stats in prevtok_data.items(): - if prev_layer >= layer_idx: - continue - prev_mean, prev_median, prev_top1, _eps_p, _margin_p = prev_stats - score = prev_mean * stripe_mean * copy_score - if score < args.min_score: - continue - candidate = CircuitCopyResult( - prev_layer=prev_layer, - prev_head=prev_head, - induction_layer=layer_idx, - induction_head=head_idx, - score=score, - prev_mean=prev_mean, - prev_median=prev_median, - prev_top1_frac=prev_top1, - stripe_mean=stripe_mean, - stripe_median=stripe_median, - stripe_top1_frac=stripe_top1, - copy_mean=copy_mean, - copy_weighted_mean=copy_weighted_mean, - ) - if best_prev is None or candidate.score > best_prev.score: - best_prev = candidate - if best_prev is not None: - circuit_copy_results.append(best_prev) - continue - if args.score_mode != "logit": - ov = v @ wo - copy_mean, copy_weighted_mean = compute_copy_scores( - ov, - prev_weights, - columns, - tokens, - prev, - active_positions, - ) - if args.min_copy is not None and copy_mean < args.min_copy: - continue - if args.score_mode == "attn": - score = prev_mean - elif args.score_mode == "copy": - score = copy_mean - else: - score = copy_weighted_mean - if score < args.min_score: - continue - attn_results.append( - AttnResult( - layer=layer_idx, - head=head_idx, - score=score, - prev_mean=prev_mean, - prev_median=prev_median, - prev_top1_frac=prev_top1, - copy_mean=copy_mean, - copy_weighted_mean=copy_weighted_mean, - eps=eps, - margin=margin, - active=len(active_positions), - ) - ) - continue - proj: Dict[int, np.ndarray] = {} - for tok in unique_tokens: - dir_head = wo @ columns[tok] - proj[tok] = v @ dir_head - best: HeadResult | None = None - for target in unique_tokens: - 
vals_target = proj[target] - for negative in unique_tokens: - if target == negative: - continue - vals = vals_target - proj[negative] - vals_prev = vals[prev_indices] - min_prev = float(vals_prev.min()) if vals_prev.size else 0.0 - value_range = float(vals.max() - vals.min()) - logit_lb = min_prev - eps * value_range - if logit_lb < args.min_logit_lb: - continue - candidate = HeadResult( - layer=layer_idx, - head=head_idx, - target=target, - negative=negative, - logit_lb=logit_lb, - eps=eps, - margin=margin, - min_prev=min_prev, - value_range=value_range, - active=len(active_positions), - prev_mean=prev_mean, - prev_median=prev_median, - prev_top1_frac=prev_top1, - ) - if best is None or candidate.logit_lb > best.logit_lb: - best = candidate - if best is not None: - results.append(best) - - if args.score_mode == "circuit_copy": - circuit_copy_results.sort(key=lambda r: r.score, reverse=True) - elif args.score_mode == "circuit": - circuit_results.sort(key=lambda r: r.score, reverse=True) - elif args.score_mode == "stripe": - stripe_results.sort(key=lambda r: r.score, reverse=True) - elif args.score_mode != "logit": - attn_results.sort(key=lambda r: r.score, reverse=True) - else: - results.sort(key=lambda r: r.logit_lb, reverse=True) - args.output.parent.mkdir(parents=True, exist_ok=True) - active_count = len(stripe_positions) if args.score_mode in {"stripe", "circuit", "circuit_copy"} and stripe_positions else len(active_positions) - with args.output.open("w", encoding="ascii") as f: - f.write("Induction discovery (approximate ranking)\n") - f.write(f"model={args.model}\n") - f.write(f"score_mode={args.score_mode}\n") - f.write(f"use_activations={args.use_activations}\n") - if args.use_activations: - f.write(f"hf_model={args.hf_model} device={args.device}\n") - f.write(f"tokens={len(unique_tokens)} active={active_count}\n") - if args.score_mode in {"stripe", "circuit", "circuit_copy"}: - f.write(f"stripe_period={args.stripe_period}\n") - f.write( - 
f"min-eps={args.min_eps} min-margin={args.min_margin} " - f"min-logit-lb={args.min_logit_lb} min-score={args.min_score} " - f"min-copy={args.min_copy}\n" - ) - if args.score_mode == "circuit_copy": - for rank, result in enumerate(circuit_copy_results[: args.top], start=1): - f.write(f"{rank:02d} {format_circuit_copy_result(result)}\n") - elif args.score_mode == "circuit": - for rank, result in enumerate(circuit_results[: args.top], start=1): - f.write(f"{rank:02d} {format_circuit_result(result)}\n") - elif args.score_mode == "stripe": - for rank, result in enumerate(stripe_results[: args.top], start=1): - f.write(f"{rank:02d} {format_stripe_result(result)}\n") - elif args.score_mode != "logit": - for rank, result in enumerate(attn_results[: args.top], start=1): - f.write(f"{rank:02d} {format_attn_result(result)}\n") - else: - for rank, result in enumerate(results[: args.top], start=1): - f.write(f"{rank:02d} {format_result(result)}\n") - - print(f"Wrote report to {args.output}") - if args.score_mode == "circuit_copy": - for rank, result in enumerate(circuit_copy_results[: args.top], start=1): - print(f"{rank:02d} {format_circuit_copy_result(result)}") - elif args.score_mode == "circuit": - for rank, result in enumerate(circuit_results[: args.top], start=1): - print(f"{rank:02d} {format_circuit_result(result)}") - elif args.score_mode == "stripe": - for rank, result in enumerate(stripe_results[: args.top], start=1): - print(f"{rank:02d} {format_stripe_result(result)}") - elif args.score_mode != "logit": - for rank, result in enumerate(attn_results[: args.top], start=1): - print(f"{rank:02d} {format_attn_result(result)}") - else: - for rank, result in enumerate(results[: args.top], start=1): - print(f"{rank:02d} {format_result(result)}") - - if args.json_out is not None: - args.json_out.parent.mkdir(parents=True, exist_ok=True) - payload = { - "model": str(args.model), - "tokens": len(unique_tokens), - "active": active_count, - "score_mode": args.score_mode, - 
"min_eps": args.min_eps, - "min_margin": args.min_margin, - "min_logit_lb": args.min_logit_lb, - "min_score": args.min_score, - "min_copy": args.min_copy, - "use_activations": args.use_activations, - "hf_model": args.hf_model if args.use_activations else None, - "device": args.device if args.use_activations else None, - "stripe_period": args.stripe_period if args.score_mode == "stripe" else None, - } - if args.score_mode == "circuit_copy": - payload["results"] = [ - { - "rank": rank, - "prev_layer": r.prev_layer + 1, - "prev_head": r.prev_head + 1, - "induction_layer": r.induction_layer + 1, - "induction_head": r.induction_head + 1, - "score": r.score, - "prev_mean": r.prev_mean, - "prev_median": r.prev_median, - "prev_top1_frac": r.prev_top1_frac, - "stripe_mean": r.stripe_mean, - "stripe_median": r.stripe_median, - "stripe_top1_frac": r.stripe_top1_frac, - "copy_mean": r.copy_mean, - "copy_weighted_mean": r.copy_weighted_mean, - } - for rank, r in enumerate(circuit_copy_results[: args.top], start=1) - ] - elif args.score_mode == "circuit": - payload["results"] = [ - { - "rank": rank, - "prev_layer": r.prev_layer + 1, - "prev_head": r.prev_head + 1, - "induction_layer": r.induction_layer + 1, - "induction_head": r.induction_head + 1, - "score": r.score, - "prev_mean": r.prev_mean, - "prev_median": r.prev_median, - "prev_top1_frac": r.prev_top1_frac, - "stripe_mean": r.stripe_mean, - "stripe_median": r.stripe_median, - "stripe_top1_frac": r.stripe_top1_frac, - } - for rank, r in enumerate(circuit_results[: args.top], start=1) - ] - elif args.score_mode == "stripe": - payload["results"] = [ - { - "rank": rank, - "layer": r.layer + 1, - "head": r.head + 1, - "score": r.score, - "stripe_mean": r.stripe_mean, - "stripe_median": r.stripe_median, - "stripe_top1_frac": r.stripe_top1_frac, - "eps": r.eps, - "margin": r.margin, - "active": r.active, - } - for rank, r in enumerate(stripe_results[: args.top], start=1) - ] - elif args.score_mode != "logit": - 
payload["results"] = [ - { - "rank": rank, - "layer": r.layer + 1, - "head": r.head + 1, - "score": r.score, - "prev_mean": r.prev_mean, - "prev_median": r.prev_median, - "prev_top1_frac": r.prev_top1_frac, - "copy_mean": r.copy_mean, - "copy_weighted_mean": r.copy_weighted_mean, - "eps": r.eps, - "margin": r.margin, - "active": r.active, - } - for rank, r in enumerate(attn_results[: args.top], start=1) - ] - else: - payload["results"] = [ - { - "rank": rank, - "layer": r.layer + 1, - "head": r.head + 1, - "target": r.target, - "negative": r.negative, - "logit_lb": r.logit_lb, - "eps": r.eps, - "margin": r.margin, - "min_prev": r.min_prev, - "value_range": r.value_range, - "prev_mean": r.prev_mean, - "prev_median": r.prev_median, - "prev_top1_frac": r.prev_top1_frac, - "active": r.active, - } - for rank, r in enumerate(results[: args.top], start=1) - ] - args.json_out.write_text(json.dumps(payload, indent=2), encoding="ascii") - - if args.verify_top > 0 and results: - nfp_cmd = resolve_nfp_cmd(args.nfp_bin) - print("\nVerifying top candidates with Lean checker:") - for result in results[: args.verify_top]: - cmd = nfp_cmd + [ - "induction", - "certify_head_model_nonvacuous", - "--model", - str(args.model), - "--layer", - str(result.layer), - "--head", - str(result.head), - "--direction-target", - str(result.target), - "--direction-negative", - str(result.negative), - ] - if args.period is not None: - cmd += ["--period", str(args.period)] - proc = subprocess.run(cmd, capture_output=True, text=True) - status = "ok" if proc.returncode == 0 else "fail" - stdout = proc.stdout.strip().replace("\n", " ") - stderr = proc.stderr.strip().replace("\n", " ") - print(f"{status} {result.layer}/{result.head} tgt={result.target} neg={result.negative}") - if stdout: - print(f" out: {stdout}") - if stderr: - print(f" err: {stderr}") - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/sanity_forward_step.py b/scripts/sanity_forward_step.py 
deleted file mode 100644 index d700a9a..0000000 --- a/scripts/sanity_forward_step.py +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Sanity check: compare one forward step against PyTorch / HuggingFace GPT-2. - -This is a correctness guardrail to ensure that the exported `.nfpt` semantics match what -the Lean-side executable actually analyzes (including attention biases). - -Usage: - uv run scripts/sanity_forward_step.py models/gpt2_rigorous.nfpt --layer 0 --pos 0 --take 16 -""" - -from __future__ import annotations - -import argparse -import subprocess -from pathlib import Path - -import numpy as np - -try: - import torch - from transformers import GPT2Model -except ImportError as e: - raise SystemExit( - "Missing deps. Install with e.g.: uv add torch transformers\n" - f"ImportError: {e}" - ) - - -def parse_tokens_from_nfpt(path: Path) -> list[int]: - with path.open("rb") as f: - magic = None - header: dict[str, str] = {} - while True: - line = f.readline() - if line == b"": - raise SystemExit("Unexpected EOF while reading header.") - text = line.decode("ascii").rstrip("\n") - if magic is None: - magic = text.strip() - if text.strip() == "BINARY_START": - break - if "=" in text: - key, value = text.split("=", 1) - header[key.strip()] = value.strip() - - if magic != "NFP_BINARY_V1": - raise SystemExit(f"Unsupported magic header: {magic}") - seq_len_raw = header.get("seq_len") - if seq_len_raw is None: - raise SystemExit("Missing seq_len in header.") - - seq_len = int(seq_len_raw) - raw = f.read(seq_len * 4) - if len(raw) != seq_len * 4: - raise SystemExit("Unexpected EOF while reading TOKENS section.") - toks = np.frombuffer(raw, dtype=" np.ndarray: - exe = Path(".lake/build/bin/nfp") - if not exe.exists(): - raise SystemExit("Missing `.lake/build/bin/nfp`. 
Run `lake build nfp` first.") - cmd = [ - str(exe), - "dump", - "--kind", - "afterLayer", - "--layer", - str(layer), - "--pos", - str(pos), - "--take", - str(take), - str(model_path), - ] - out = subprocess.check_output(cmd, text=True) - lines = [ln.strip() for ln in out.splitlines() if ln.strip()] - dump_i = None - for i, ln in enumerate(lines): - if ln.startswith("DUMP "): - dump_i = i - if dump_i is None or dump_i + 2 >= len(lines): - raise SystemExit(f"Unexpected Lean dump output:\n{out}") - vals = [float(x) for x in lines[dump_i + 2].split()] - return np.asarray(vals, dtype=np.float64) - - -def run_torch(model_name: str, tokens: list[int], layer: int, pos: int, take: int) -> np.ndarray: - m = GPT2Model.from_pretrained(model_name) - # Lean `Float` is IEEE-754 double; run PyTorch in float64 to avoid large float32-vs-float64 drift - # (especially through attention softmax on long sequences). - m = m.to(dtype=torch.float64) - m.eval() - input_ids = torch.tensor([tokens], dtype=torch.long) - with torch.no_grad(): - out = m(input_ids=input_ids, output_hidden_states=True) - hs = out.hidden_states - if hs is None: - raise SystemExit("Expected output_hidden_states=True to return hidden_states.") - # hs[0] = embeddings, hs[layer+1] = after block `layer`. 
- x = hs[layer + 1][0, pos, :take].detach().cpu().numpy() - return x.astype(np.float64) - - -def main() -> None: - ap = argparse.ArgumentParser() - ap.add_argument("model", type=Path) - ap.add_argument("--hf", default="gpt2", help="HuggingFace model name (default: gpt2)") - ap.add_argument("--layer", type=int, default=0) - ap.add_argument("--pos", type=int, default=0) - ap.add_argument("--take", type=int, default=16) - ap.add_argument("--tol", type=float, default=1e-3) - args = ap.parse_args() - - if not args.model.exists(): - raise SystemExit(f"Missing file: {args.model}") - tokens = parse_tokens_from_nfpt(args.model) - if len(tokens) == 0: - raise SystemExit("No tokens parsed from TOKENS section.") - - lean = run_lean_dump(args.model, args.layer, args.pos, args.take) - torch_x = run_torch(args.hf, tokens, args.layer, args.pos, args.take) - - diff = np.max(np.abs(lean - torch_x)) - print(f"max_abs_diff={diff:.6g} tol={args.tol:.6g}") - if not np.isfinite(diff) or diff > args.tol: - print("FAIL") - print("lean :", lean.tolist()) - print("torch:", torch_x.tolist()) - raise SystemExit(1) - print("OK") - - -if __name__ == "__main__": - main() diff --git a/scripts/scan_gpt2_induction_sound.py b/scripts/scan_gpt2_induction_sound.py deleted file mode 100755 index 861af47..0000000 --- a/scripts/scan_gpt2_induction_sound.py +++ /dev/null @@ -1,377 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Scan GPT-2 induction head candidates with attention/copy or logit-diff bounds. - -This script: -1) Ensures a GPT-2 "rigorous induction" binary model exists (or generates one - from repeated random patterns via --synthetic). -2) Uses the untrusted discovery helper to propose head candidates. -3) Optionally runs `nfp induction certify_head_model_nonvacuous` in logit mode. - -Layer/head indices are one-based (literature-aligned). `prev` defaults to bigram -prefix matching. 
-""" - -from __future__ import annotations - -import argparse -import json -import os -import shutil -import struct -import subprocess -import sys -from concurrent.futures import ThreadPoolExecutor, as_completed -from fractions import Fraction -from pathlib import Path - - -def run_cmd(cmd: list[str]) -> str: - proc = subprocess.run(cmd, check=True, capture_output=True, text=True) - return proc.stdout - - -def ensure_model( - model_path: Path, - *, - seq_len: int = 256, - pattern_len: int = 20, - seed: int = 1337, - vocab_min: int = 1000, - vocab_max: int = 5000, - min_word_length: int = 4, - allow_no_leading_space: bool = False, - model_name: str = "gpt2", -) -> None: - if model_path.exists(): - return - model_path.parent.mkdir(parents=True, exist_ok=True) - generator = [ - sys.executable, - "scripts/generate_rigorous_induction.py", - "--output", - str(model_path), - "--seq-len", - str(seq_len), - "--pattern-len", - str(pattern_len), - "--seed", - str(seed), - "--vocab-min", - str(vocab_min), - "--vocab-max", - str(vocab_max), - "--min-word-length", - str(min_word_length), - "--model", - model_name, - ] - if allow_no_leading_space: - generator.append("--allow-no-leading-space") - if shutil.which("uv"): - generator = ["uv", "run"] + generator - subprocess.run(generator, check=True) - - -def resolve_nfp_cmd(nfp_bin: str | None) -> list[str]: - if nfp_bin: - return [nfp_bin] - env_bin = os.environ.get("NFP_BIN") - if env_bin: - return [env_bin] - local_bin = Path(".lake/build/bin/nfp") - if local_bin.exists(): - return [str(local_bin)] - return ["lake", "exe", "nfp"] - - -def read_header_and_tokens(path: Path) -> tuple[dict[str, str], list[int]]: - header: dict[str, str] = {} - with path.open("rb") as f: - while True: - line = f.readline() - if not line: - raise ValueError("unexpected EOF while reading header") - text = line.decode("ascii").strip() - if text == "BINARY_START": - break - if "=" in text: - key, value = text.split("=", 1) - header[key.strip()] = 
value.strip() - seq_len_raw = header.get("seq_len") - if seq_len_raw is None: - raise ValueError("header missing seq_len") - seq_len = int(seq_len_raw) - token_bytes = f.read(seq_len * 4) - if len(token_bytes) != seq_len * 4: - raise ValueError("unexpected EOF while reading tokens") - tokens = list(struct.unpack("<" + "i" * seq_len, token_bytes)) - return header, tokens - - -def derive_target_negative(tokens: list[int]) -> tuple[int, int]: - if len(tokens) < 2: - raise ValueError("need at least 2 tokens to derive target/negative") - last = tokens[-1] - prev_idx = None - for i in range(len(tokens) - 2, -1, -1): - if tokens[i] == last: - prev_idx = i - break - if prev_idx is not None and prev_idx + 1 < len(tokens): - target = tokens[prev_idx + 1] - else: - target = last - negative = tokens[-2] - if negative == target: - negative = last if last != target else (0 if target != 0 else 1) - return target, negative - - -def parse_logit_lb(output: str) -> Fraction | None: - for line in output.splitlines(): - if "logitDiffLB=" not in line: - continue - for token in line.split(): - if token.startswith("logitDiffLB="): - value = token.split("=", 1)[1].strip("),") - try: - return Fraction(value) - except ValueError: - return None - return None - - -def main() -> int: - parser = argparse.ArgumentParser() - parser.add_argument("--model", default="models/gpt2_rigorous.nfpt") - parser.add_argument("--top", type=int, default=8) - parser.add_argument("--maxSeqLen", type=int, default=256) - parser.add_argument("--jobs", type=int, default=1) - parser.add_argument("--fast", action="store_true") - parser.add_argument("--nfp-bin", help="Path to nfp binary (defaults to .lake/build/bin/nfp)") - parser.add_argument("--min-eps", type=float, default=0.5) - parser.add_argument("--min-margin", type=float, default=0.0) - parser.add_argument("--min-logit-lb", type=float, default=0.0) - parser.add_argument("--min-score", type=float, default=0.0) - parser.add_argument("--min-copy", type=float) - 
parser.add_argument( - "--score-mode", - choices=["attn", "copy", "attn_copy", "logit"], - default="attn", - help="Rank by attention/copy score or logit-diff bound.", - ) - parser.add_argument("--layers", help="Comma-separated layer list or 'all'") - parser.add_argument("--heads", help="Comma-separated head list or 'all'") - parser.add_argument("--period", type=int) - parser.add_argument( - "--prev-mode", - choices=["bigram", "token", "period", "period_shift"], - default="bigram", - help="Choose prev/active construction (forwarded to discovery).", - ) - parser.add_argument( - "--synthetic", - action="store_true", - help="Generate a repeated-random pattern prompt (prefix-matching benchmark).", - ) - parser.add_argument("--synthetic-seq-len", type=int, default=256) - parser.add_argument("--synthetic-pattern-len", type=int, default=20) - parser.add_argument("--synthetic-seed", type=int, default=1337) - parser.add_argument("--synthetic-vocab-min", type=int, default=1000) - parser.add_argument("--synthetic-vocab-max", type=int, default=5000) - parser.add_argument("--synthetic-min-word-length", type=int, default=4) - parser.add_argument("--synthetic-allow-no-leading-space", action="store_true") - parser.add_argument("--synthetic-model", default="gpt2") - parser.add_argument("--output", default="reports/gpt2_induction_sound_scan.txt") - args = parser.parse_args() - args.jobs = max(1, args.jobs) - top_arg = any(a.startswith("--top") for a in sys.argv[1:]) - max_seq_len_arg = any(a.startswith("--maxSeqLen") for a in sys.argv[1:]) - if args.fast and not top_arg and args.top == parser.get_default("top"): - args.top = 4 - - model_arg = any(a.startswith("--model") for a in sys.argv[1:]) - if args.synthetic and not model_arg: - model_path = Path( - "models/" - f"gpt2_rigorous_seq{args.synthetic_seq_len}" - f"_pat{args.synthetic_pattern_len}" - f"_seed{args.synthetic_seed}.nfpt" - ) - else: - model_path = Path(args.model) - if args.synthetic: - ensure_model( - model_path, - 
seq_len=args.synthetic_seq_len, - pattern_len=args.synthetic_pattern_len, - seed=args.synthetic_seed, - vocab_min=args.synthetic_vocab_min, - vocab_max=args.synthetic_vocab_max, - min_word_length=args.synthetic_min_word_length, - allow_no_leading_space=args.synthetic_allow_no_leading_space, - model_name=args.synthetic_model, - ) - else: - ensure_model(model_path) - nfp_cmd = resolve_nfp_cmd(args.nfp_bin) - - header, tokens = read_header_and_tokens(model_path) - seq_len = int(header.get("seq_len", "0")) - if args.fast and not max_seq_len_arg and args.maxSeqLen == parser.get_default("maxSeqLen"): - if seq_len <= 128: - args.maxSeqLen = 128 - if seq_len > args.maxSeqLen: - print( - f"Error: seq_len {seq_len} exceeds maxSeqLen {args.maxSeqLen}", - file=sys.stderr, - ) - return 1 - target, negative = derive_target_negative(tokens) - - discover_json = Path(args.output).with_suffix(".json") - discover_txt = Path(args.output).with_suffix(".discover.txt") - discover_cmd = [ - sys.executable, - "scripts/discover_gpt2_induction_targets.py", - "--model", - str(model_path), - "--top", - str(args.top), - "--score-mode", - args.score_mode, - "--min-eps", - str(args.min_eps), - "--min-margin", - str(args.min_margin), - "--min-logit-lb", - str(args.min_logit_lb), - "--min-score", - str(args.min_score), - "--output", - str(discover_txt), - "--json-out", - str(discover_json), - ] - if args.min_copy is not None: - discover_cmd += ["--min-copy", str(args.min_copy)] - if args.layers is not None: - discover_cmd += ["--layers", args.layers] - if args.heads is not None: - discover_cmd += ["--heads", args.heads] - if args.period is not None: - discover_cmd += ["--period", str(args.period)] - if args.prev_mode != "bigram": - discover_cmd += ["--prev-mode", args.prev_mode] - run_cmd(discover_cmd) - payload = json.loads(discover_json.read_text(encoding="ascii")) - candidates = payload.get("results", []) - if not candidates: - print("No induction candidates found.", file=sys.stderr) - return 1 
- - results: list[tuple[Fraction, dict[str, int]]] = [] - - if args.score_mode == "logit": - def run_cert(candidate: dict[str, int]) -> tuple[dict[str, int], Fraction | None]: - layer = int(candidate["layer"]) - 1 - head = int(candidate["head"]) - 1 - target_id = int(candidate.get("target", target)) - negative_id = int(candidate.get("negative", negative)) - cmd = nfp_cmd + [ - "induction", - "certify_head_model_nonvacuous", - "--model", - str(model_path), - "--layer", - str(layer), - "--head", - str(head), - "--direction-target", - str(target_id), - "--direction-negative", - str(negative_id), - ] - if args.period is not None: - cmd += ["--period", str(args.period)] - try: - cert_out = run_cmd(cmd) - except subprocess.CalledProcessError: - return candidate, None - return candidate, parse_logit_lb(cert_out) - - if args.jobs == 1: - for candidate in candidates: - candidate_out, logit_lb = run_cert(candidate) - if logit_lb is None: - continue - results.append((logit_lb, candidate_out)) - else: - with ThreadPoolExecutor(max_workers=args.jobs) as executor: - futures = {executor.submit(run_cert, candidate): candidate for candidate in candidates} - for future in as_completed(futures): - candidate_out, logit_lb = future.result() - if logit_lb is None: - continue - results.append((logit_lb, candidate_out)) - - if not results: - print("No sound logit bounds produced.", file=sys.stderr) - return 1 - - results.sort(key=lambda x: x[0], reverse=True) - out_path = Path(args.output) - out_path.parent.mkdir(parents=True, exist_ok=True) - with out_path.open("w", encoding="ascii") as f: - if args.score_mode == "logit": - f.write("SOUND induction scan (logitDiffLB ranking)\n") - else: - f.write("Induction attention scan (prev-attn ranking)\n") - f.write(f"model={model_path}\n") - if args.synthetic: - f.write( - "synthetic=" - f"seq{args.synthetic_seq_len}_pat{args.synthetic_pattern_len}_" - f"seed{args.synthetic_seed}\n" - ) - f.write(f"target={target} negative={negative}\n") - 
eps_header = header.get("layer_norm_eps") or header.get("eps") or "unknown" - f.write(f"top={args.top} eps={eps_header}\n") - if args.score_mode == "logit": - for rank, (lb, candidate) in enumerate(results, start=1): - layer = int(candidate["layer"]) - head = int(candidate["head"]) - target_id = int(candidate.get("target", target)) - negative_id = int(candidate.get("negative", negative)) - f.write( - f"{rank:02d} L{layer}H{head} " - f"target={target_id} negative={negative_id} logitDiffLB={lb}\n" - ) - else: - for rank, candidate in enumerate(candidates[: args.top], start=1): - layer = int(candidate["layer"]) - head = int(candidate["head"]) - score = candidate.get("score") - prev_mean = candidate.get("prev_mean") - prev_median = candidate.get("prev_median") - prev_top1 = candidate.get("prev_top1_frac") - copy_mean = candidate.get("copy_mean") - copy_weighted = candidate.get("copy_weighted_mean") - eps = candidate.get("eps") - margin = candidate.get("margin") - f.write( - f"{rank:02d} L{layer}H{head} score={score} " - f"prevMean={prev_mean} prevMedian={prev_median} prevTop1={prev_top1} " - f"copyMean={copy_mean} copyWeighted={copy_weighted} " - f"eps={eps} margin={margin}\n" - ) - - print(f"Report written to {out_path}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/sweep_gpt2_induction_nonvacuous.py b/scripts/sweep_gpt2_induction_nonvacuous.py deleted file mode 100644 index 7e91a33..0000000 --- a/scripts/sweep_gpt2_induction_nonvacuous.py +++ /dev/null @@ -1,370 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Sweep prompt parameters and evaluate induction head scores for GPT-2. - -This is untrusted orchestration: discovery uses floating-point math and only -Lean verification results are treated as definitive in logit mode. - -Layer/head indices are one-based (literature-aligned). `prev` defaults to bigram -prefix matching. 
This sweep generates repeated-random patterns to benchmark -prefix-matching scores across seeds/lengths. -""" - -from __future__ import annotations - -import argparse -import json -import re -import shutil -import subprocess -from dataclasses import dataclass -from pathlib import Path -from typing import Iterable, List - - -LOGIT_RE = re.compile(r"logitDiffLB=([^\s\)]+)") - - -@dataclass(frozen=True) -class VerifyResult: - ok: bool - logit_lb: str | None - stdout: str - stderr: str - - -def parse_int_list(raw: str, name: str) -> List[int]: - items: List[int] = [] - for part in raw.split(","): - part = part.strip() - if not part: - continue - try: - items.append(int(part)) - except ValueError as exc: - raise ValueError(f"invalid {name} entry: {part}") from exc - if not items: - raise ValueError(f"{name} list is empty") - return items - - -def resolve_python_cmd() -> List[str]: - if shutil.which("uv"): - return ["uv", "run"] - return ["python3"] - - -def resolve_nfp_cmd(nfp_bin: str | None) -> List[str]: - if nfp_bin: - return [nfp_bin] - local_bin = Path(".lake/build/bin/nfp") - if local_bin.exists(): - return [str(local_bin)] - return ["lake", "exe", "nfp"] - - -def run_cmd(cmd: Iterable[str], check: bool = True) -> subprocess.CompletedProcess: - return subprocess.run(list(cmd), check=check, capture_output=True, text=True) - - -def ensure_model( - generator: Path, - output: Path, - seq_len: int, - pattern_len: int, - seed: int, -) -> None: - if output.exists(): - return - output.parent.mkdir(parents=True, exist_ok=True) - cmd = resolve_python_cmd() + [ - str(generator), - "--output", - str(output), - "--seq-len", - str(seq_len), - "--pattern-len", - str(pattern_len), - "--seed", - str(seed), - ] - run_cmd(cmd, check=True) - - -def run_discovery( - discover_script: Path, - model: Path, - max_tokens: int, - top: int, - min_eps: float, - min_margin: float, - min_logit_lb: float, - min_score: float, - min_copy: float | None, - score_mode: str, - period: int | None, - 
output_dir: Path, - prev_mode: str, -) -> list[dict]: - output_dir.mkdir(parents=True, exist_ok=True) - json_out = output_dir / f"{model.stem}.json" - cmd = resolve_python_cmd() + [ - str(discover_script), - "--model", - str(model), - "--max-tokens", - str(max_tokens), - "--top", - str(top), - "--min-eps", - str(min_eps), - "--min-margin", - str(min_margin), - "--min-logit-lb", - str(min_logit_lb), - "--min-score", - str(min_score), - "--score-mode", - score_mode, - "--json-out", - str(json_out), - ] - if min_copy is not None: - cmd += ["--min-copy", str(min_copy)] - if period is not None: - cmd += ["--period", str(period)] - if prev_mode != "bigram": - cmd += ["--prev-mode", prev_mode] - run_cmd(cmd, check=True) - payload = json.loads(json_out.read_text(encoding="ascii")) - return payload.get("results", []) - - -def verify_candidate( - nfp_cmd: List[str], - model: Path, - layer: int, - head: int, - target: int, - negative: int, - period: int | None, -) -> VerifyResult: - cmd = nfp_cmd + [ - "induction", - "certify_head_model_nonvacuous", - "--model", - str(model), - "--layer", - str(layer), - "--head", - str(head), - "--direction-target", - str(target), - "--direction-negative", - str(negative), - ] - if period is not None: - cmd += ["--period", str(period)] - proc = run_cmd(cmd, check=False) - stdout = proc.stdout.strip() - stderr = proc.stderr.strip() - logit_lb = None - match = LOGIT_RE.search(stdout) - if match: - logit_lb = match.group(1) - return VerifyResult(proc.returncode == 0, logit_lb, stdout, stderr) - - -def write_csv_row(path: Path, row: dict) -> None: - header = [ - "model_path", - "seq_len", - "pattern_len", - "seed", - "layer", - "head", - "target", - "negative", - "score_mode", - "score", - "prev_mean", - "prev_median", - "prev_top1_frac", - "copy_mean", - "copy_weighted_mean", - "approx_logit_lb", - "approx_eps", - "approx_margin", - "approx_min_prev", - "approx_value_range", - "active", - "period", - "verify_status", - "verify_logit_lb", - ] - 
new_file = not path.exists() or path.stat().st_size == 0 - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("a", encoding="ascii") as f: - if new_file: - f.write(",".join(header) + "\n") - f.write(",".join(str(row.get(col, "")) for col in header) + "\n") - - -def main() -> int: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--output", type=Path, default=Path("reports/gpt2_induction_sweep.csv")) - parser.add_argument("--model-dir", type=Path, default=Path("models")) - parser.add_argument("--generator", type=Path, default=Path("scripts/generate_rigorous_induction.py")) - parser.add_argument("--discover", type=Path, default=Path("scripts/discover_gpt2_induction_targets.py")) - parser.add_argument("--seq-lens", default="64") - parser.add_argument("--pattern-lens", default="16") - parser.add_argument("--seeds", default="1337") - parser.add_argument("--max-tokens", type=int, default=32) - parser.add_argument("--top", type=int, default=10) - parser.add_argument("--verify-top", type=int, default=3) - parser.add_argument("--min-eps", type=float, default=0.5) - parser.add_argument("--min-margin", type=float, default=0.0) - parser.add_argument("--min-logit-lb", type=float, default=0.0) - parser.add_argument("--min-score", type=float, default=0.0) - parser.add_argument("--min-copy", type=float) - parser.add_argument( - "--score-mode", - choices=["attn", "copy", "attn_copy", "logit"], - default="attn", - help="Rank by attention/copy score or logit-diff bound.", - ) - parser.add_argument("--use-period", action="store_true", - help="Use pattern length as the period override") - parser.add_argument( - "--prev-mode", - choices=["bigram", "token", "period", "period_shift"], - default="bigram", - help="Choose prev/active construction (forwarded to discovery).", - ) - parser.add_argument("--nfp-bin", help="Path to nfp binary") - parser.add_argument("--discovery-dir", type=Path, default=Path("reports/discovery")) - args = 
parser.parse_args() - - seq_lens = parse_int_list(args.seq_lens, "seq-lens") - pattern_lens = parse_int_list(args.pattern_lens, "pattern-lens") - seeds = parse_int_list(args.seeds, "seeds") - - nfp_cmd = resolve_nfp_cmd(args.nfp_bin) - - for seq_len in seq_lens: - for pattern_len in pattern_lens: - for seed in seeds: - model_name = f"gpt2_rigorous_seq{seq_len}_pat{pattern_len}_seed{seed}.nfpt" - model_path = args.model_dir / model_name - ensure_model(args.generator, model_path, seq_len, pattern_len, seed) - period = pattern_len if args.use_period else None - results = run_discovery( - args.discover, - model_path, - args.max_tokens, - args.top, - args.min_eps, - args.min_margin, - args.min_logit_lb, - args.min_score, - args.min_copy, - args.score_mode, - period, - args.discovery_dir, - args.prev_mode, - ) - if not results: - print( - f"no candidates for seq={seq_len} pat={pattern_len} seed={seed}", - flush=True, - ) - continue - if args.score_mode == "logit": - for result in results[: args.verify_top]: - layer = result["layer"] - 1 - head = result["head"] - 1 - verify = verify_candidate( - nfp_cmd, - model_path, - layer, - head, - result["target"], - result["negative"], - period, - ) - status = "ok" if verify.ok else "fail" - if verify.ok: - print( - f"verified L{result['layer']}H{result['head']} " - f"seq={seq_len} pat={pattern_len} seed={seed}", - flush=True, - ) - row = { - "model_path": model_path, - "seq_len": seq_len, - "pattern_len": pattern_len, - "seed": seed, - "layer": result["layer"], - "head": result["head"], - "target": result["target"], - "negative": result["negative"], - "score_mode": args.score_mode, - "score": "", - "prev_mean": result.get("prev_mean", ""), - "prev_median": result.get("prev_median", ""), - "prev_top1_frac": result.get("prev_top1_frac", ""), - "copy_mean": result.get("copy_mean", ""), - "copy_weighted_mean": result.get("copy_weighted_mean", ""), - "approx_logit_lb": result["logit_lb"], - "approx_eps": result["eps"], - 
"approx_margin": result["margin"], - "approx_min_prev": result["min_prev"], - "approx_value_range": result["value_range"], - "active": result["active"], - "period": period if period is not None else "", - "verify_status": status, - "verify_logit_lb": verify.logit_lb or "", - } - write_csv_row(args.output, row) - if not verify.ok: - if verify.stdout: - print(f" out: {verify.stdout}", flush=True) - if verify.stderr: - print(f" err: {verify.stderr}", flush=True) - else: - for result in results[: args.top]: - row = { - "model_path": model_path, - "seq_len": seq_len, - "pattern_len": pattern_len, - "seed": seed, - "layer": result["layer"], - "head": result["head"], - "target": result.get("target", ""), - "negative": result.get("negative", ""), - "score_mode": args.score_mode, - "score": result.get("score", ""), - "prev_mean": result.get("prev_mean", ""), - "prev_median": result.get("prev_median", ""), - "prev_top1_frac": result.get("prev_top1_frac", ""), - "copy_mean": result.get("copy_mean", ""), - "copy_weighted_mean": result.get("copy_weighted_mean", ""), - "approx_logit_lb": result.get("logit_lb", ""), - "approx_eps": result.get("eps", ""), - "approx_margin": result.get("margin", ""), - "approx_min_prev": result.get("min_prev", ""), - "approx_value_range": result.get("value_range", ""), - "active": result.get("active", ""), - "period": period if period is not None else "", - "verify_status": "", - "verify_logit_lb": "", - } - write_csv_row(args.output, row) - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) From 16033ed434fe8e0671395091813b1d0b070769c5 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 09:27:12 +0100 Subject: [PATCH 223/244] Limit induction certification to head-level parity --- CLAIMS.md | 27 +- Nfp/Bounds.lean | 2 - Nfp/Bounds/Attention.lean | 1 - Nfp/Bounds/MatrixNorm.lean | 10 - Nfp/Bounds/MatrixNorm/Basic.lean | 156 ---- Nfp/Bounds/MatrixNorm/Interval.lean | 921 ------------------------ Nfp/Bounds/Mlp.lean | 
1 - Nfp/Bounds/Transformer.lean | 10 - Nfp/Bounds/Transformer/Basic.lean | 522 -------------- Nfp/Bounds/Transformer/Embedding.lean | 131 ---- Nfp/Circuit/Cert.lean | 3 - Nfp/Circuit/Cert/DownstreamLinear.lean | 66 -- Nfp/Circuit/Cert/ResidualBound.lean | 51 -- Nfp/Circuit/Cert/ResidualInterval.lean | 53 -- Nfp/IO/Loaders.lean | 54 -- Nfp/IO/Parse.lean | 2 - Nfp/IO/Parse/Downstream.lean | 80 -- Nfp/IO/Parse/Residual.lean | 140 ---- Nfp/Model/Gpt2.lean | 6 +- Nfp/Model/InductionCircuit.lean | 2 +- Nfp/Sound/Induction.lean | 5 +- Nfp/Sound/Induction/CoreDefs.lean | 1 - Nfp/Sound/Induction/EndToEnd.lean | 83 --- Nfp/Sound/Induction/HeadOutput.lean | 16 +- Nfp/Sound/Induction/LogitDiff.lean | 122 +--- SOUNDNESS_LIMITATIONS.md | 7 +- docs/induction_cert_audit.md | 10 +- scripts/build_downstream_linear_cert.py | 66 -- scripts/build_residual_bound_cert.py | 151 ---- scripts/build_residual_interval_cert.py | 209 ------ 30 files changed, 19 insertions(+), 2889 deletions(-) delete mode 100644 Nfp/Bounds/MatrixNorm.lean delete mode 100644 Nfp/Bounds/MatrixNorm/Basic.lean delete mode 100644 Nfp/Bounds/MatrixNorm/Interval.lean delete mode 100644 Nfp/Bounds/Transformer.lean delete mode 100644 Nfp/Bounds/Transformer/Basic.lean delete mode 100644 Nfp/Bounds/Transformer/Embedding.lean delete mode 100644 Nfp/Circuit/Cert/DownstreamLinear.lean delete mode 100644 Nfp/Circuit/Cert/ResidualBound.lean delete mode 100644 Nfp/Circuit/Cert/ResidualInterval.lean delete mode 100644 Nfp/IO/Loaders.lean delete mode 100644 Nfp/IO/Parse/Downstream.lean delete mode 100644 Nfp/IO/Parse/Residual.lean delete mode 100644 Nfp/Sound/Induction/EndToEnd.lean delete mode 100644 scripts/build_downstream_linear_cert.py delete mode 100644 scripts/build_residual_bound_cert.py delete mode 100644 scripts/build_residual_interval_cert.py diff --git a/CLAIMS.md b/CLAIMS.md index c7f0514..44f453d 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -13,20 +13,9 @@ what is untrusted/heuristic, and what is not yet proven in the 
tabula rasa rewri `InductionHeadCertBounds`. - Logit-diff lower bound lemmas: `logitDiffLowerBound_le`, `logitDiffLowerBoundAt_le`, and `logitDiffLowerBoundWeightedAt_le`. -- Bridge lemmas composing head logit-diff bounds with head outputs and residual - interval bounds: `headLogitDiff_eq_direction_dot_headOutput`, - `logitDiffLowerBound_with_residual`, and `logitDiffLowerBound_with_output_intervals`. -- Downstream linear certificate soundness: `checkDownstreamLinearCert` implies - `DownstreamLinearBounds`. -- Residual-interval certificate soundness: `checkResidualIntervalCert` implies - `ResidualIntervalBounds`. -- End-to-end direction-dot lower bounds on `transformerStackFinalReal` can be derived by - composing head logit-diff bounds with residual interval bounds - (`logitDiffLowerBound_end_to_end_gpt2`). -- Row-sum matrix norm bounds for `mulVec` under uniform input magnitude. -- Tanh-GELU bounds and interval propagation through MLP layers. -- Interval bounds for multi-head attention, transformer-layer residual blocks, transformer - stacks, and final LayerNorm outputs. +- The head logit-diff equals the direction dot product of the head output + (`headLogitDiff_eq_direction_dot_headOutput`). +- Row-stochastic attention/one-hot bounds for induction heads and related interval lemmas. ## Soundly checked by the trusted CLI @@ -36,10 +25,9 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri ## Untrusted / heuristic -- Python helpers that generate explicit certificates from GPT-2 weights or `.nfpt` files: - `scripts/build_gpt2_induction_cert.py`, `scripts/build_gpt2_induction_cert_from_binary.py`, - `scripts/build_residual_interval_cert.py`, `scripts/build_residual_bound_cert.py`, and - `scripts/build_downstream_linear_cert.py`. +- Python helpers that generate explicit induction-head certificates from GPT-2 weights or + `.nfpt` files: `scripts/build_gpt2_induction_cert.py`, + `scripts/build_gpt2_induction_cert_from_binary.py`. 
- Exporters and dataset generators for `.nfpt` model files. - Any choice of prompts, directions, or candidate heads used by certificate generators. @@ -47,5 +35,4 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri - A verified extraction pipeline from model weights to explicit certificates. - End-to-end claims about GPT-2 logits or Jacobians derived from certificates. -- A full bridge from explicit certificates to complete model semantics (beyond head-level - and residual-interval compositions). +- A full bridge from explicit head certificates to complete model semantics. diff --git a/Nfp/Bounds.lean b/Nfp/Bounds.lean index fd95fa5..eeb5226 100644 --- a/Nfp/Bounds.lean +++ b/Nfp/Bounds.lean @@ -7,9 +7,7 @@ public import Nfp.Bounds.Cache public import Nfp.Bounds.Gelu public import Nfp.Bounds.LayerNorm public import Nfp.Bounds.LayerNorm.InvStd -public import Nfp.Bounds.MatrixNorm public import Nfp.Bounds.Mlp -public import Nfp.Bounds.Transformer public import Nfp.Bounds.UnnormRat /-! diff --git a/Nfp/Bounds/Attention.lean b/Nfp/Bounds/Attention.lean index 935104d..94aa802 100644 --- a/Nfp/Bounds/Attention.lean +++ b/Nfp/Bounds/Attention.lean @@ -11,7 +11,6 @@ public import Nfp.Core.Basic public import Nfp.Model.Gpt2 public import Nfp.Bounds.Cache public import Nfp.Bounds.LayerNorm -public import Nfp.Bounds.MatrixNorm public import Nfp.Bounds.Mlp /-! diff --git a/Nfp/Bounds/MatrixNorm.lean b/Nfp/Bounds/MatrixNorm.lean deleted file mode 100644 index d278f37..0000000 --- a/Nfp/Bounds/MatrixNorm.lean +++ /dev/null @@ -1,10 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Bounds.MatrixNorm.Basic -public import Nfp.Bounds.MatrixNorm.Interval - -/-! -Matrix norm and interval bound helpers for downstream certificates. 
--/ diff --git a/Nfp/Bounds/MatrixNorm/Basic.lean b/Nfp/Bounds/MatrixNorm/Basic.lean deleted file mode 100644 index 23d6ace..0000000 --- a/Nfp/Bounds/MatrixNorm/Basic.lean +++ /dev/null @@ -1,156 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Algebra.BigOperators.Fin -public import Mathlib.Algebra.Order.BigOperators.Group.Finset -public import Mathlib.Algebra.Order.Ring.Abs -public import Mathlib.Data.Fintype.Basic -public import Mathlib.Data.Matrix.Mul -public import Mathlib.Data.Real.Basic -public import Nfp.Core.Basic -public import Nfp.Bounds.MatrixNorm.Interval -public import Nfp.Linear.FinFold - -/-! -Row-sum matrix norms for downstream linear certificates. - -These bounds are used to compute verified downstream error certificates -from explicit Rat matrices. --/ - -public section - -namespace Nfp - - -namespace Bounds - -open scoped BigOperators - -/-- Row-sum of absolute values for a matrix row. -/ -def rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : Rat := - Linear.sumFin n (fun j => |W i j|) - -/-- Weighted row-sum using per-coordinate bounds. -/ -def rowSumWeighted {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) : Rat := - Linear.sumFin n (fun j => |W i j| * bound j) - -/-- Maximum row-sum norm (defaults to `0` on empty matrices). -/ -def rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : Rat := - Linear.foldlFin m (fun acc i => max acc (rowSum W i)) 0 - -/-- Maximum weighted row-sum (defaults to `0` on empty matrices). -/ -def rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : Rat := - Linear.foldlFin m (fun acc i => max acc (rowSumWeighted W bound i)) 0 - -/-- Row-sums are nonnegative. 
-/ -theorem rowSum_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : - 0 ≤ rowSum W i := by - simpa [rowSum, Linear.sumFin_eq_sum_univ] using - (Finset.sum_nonneg (fun j _ => abs_nonneg (W i j))) - -/-- Weighted row-sums are nonnegative under nonnegative bounds. -/ -theorem rowSumWeighted_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) (hbound : ∀ j, 0 ≤ bound j) : - 0 ≤ rowSumWeighted W bound i := by - classical - have hsum : 0 ≤ ∑ j, |W i j| * bound j := by - refine Finset.sum_nonneg ?_ - intro j _ - exact mul_nonneg (abs_nonneg (W i j)) (hbound j) - simpa [rowSumWeighted, Linear.sumFin_eq_sum_univ] using hsum - -/-- Each row-sum is bounded by the row-sum norm. -/ -theorem rowSum_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) (i : Fin m) : - rowSum W i ≤ rowSumNorm W := by - simpa [rowSumNorm] using - (foldlFin_max_ge (f := fun j => rowSum W j) i) - -/-- Each weighted row-sum is bounded by the weighted row-sum norm. -/ -theorem rowSumWeighted_le_rowSumWeightedNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) (i : Fin m) : - rowSumWeighted W bound i ≤ rowSumWeightedNorm W bound := by - simpa [rowSumWeightedNorm] using - (foldlFin_max_ge (f := fun j => rowSumWeighted W bound j) i) - -/-- The row-sum norm is nonnegative. -/ -theorem rowSumNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) : - 0 ≤ rowSumNorm W := by - simpa [rowSumNorm] using - (foldlFin_max_ge_init (f := fun i => rowSum W i) (init := (0 : Rat))) - -/-- Weighted row-sum norm is nonnegative. -/ -theorem rowSumWeightedNorm_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : - 0 ≤ rowSumWeightedNorm W bound := by - simpa [rowSumWeightedNorm] using - (foldlFin_max_ge_init (f := fun i => rowSumWeighted W bound i) (init := (0 : Rat))) - -/-- Downstream error from per-coordinate residual bounds. 
-/ -def downstreamErrorFromBounds {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : Rat := - rowSumWeightedNorm W bound - -/-- `downstreamErrorFromBounds` is nonnegative. -/ -theorem downstreamErrorFromBounds_nonneg {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (bound : Fin n → Rat) : - 0 ≤ downstreamErrorFromBounds W bound := by - simpa [downstreamErrorFromBounds] using rowSumWeightedNorm_nonneg W bound - -/-- Summed absolute row entries factor out a scalar bound. -/ -theorem sum_abs_row_mul_eq_rowSum_mul {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (i : Fin m) (inputBound : Rat) : - (∑ j, |W i j| * inputBound) = rowSum W i * inputBound := by - have hsum : - (∑ j, |W i j|) * inputBound = ∑ j, |W i j| * inputBound := by - simpa using - (Finset.sum_mul - (s := (Finset.univ : Finset (Fin n))) - (f := fun j => |W i j|) - (a := inputBound)) - simpa [rowSum, Linear.sumFin_eq_sum_univ] using hsum.symm - -/-- Row-sum norm bounds a matrix-vector product under a uniform input bound. -/ -theorem abs_mulVec_le_rowSum {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (x : Fin n → Rat) (inputBound : Rat) - (hx : ∀ j, |x j| ≤ inputBound) : - ∀ i, |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := by - intro i - have h1 : |∑ j, W i j * x j| ≤ ∑ j, |W i j * x j| := by - simpa using - (Finset.abs_sum_le_sum_abs - (f := fun j => W i j * x j) - (s := (Finset.univ : Finset (Fin n)))) - have h2 : ∑ j, |W i j * x j| ≤ ∑ j, |W i j| * inputBound := by - refine Finset.sum_le_sum ?_ - intro j _ - have hnonneg : 0 ≤ |W i j| := abs_nonneg (W i j) - calc - |W i j * x j| = |W i j| * |x j| := by - simp [abs_mul] - _ ≤ |W i j| * inputBound := by - exact mul_le_mul_of_nonneg_left (hx j) hnonneg - have h3 : ∑ j, |W i j| * inputBound = rowSum W i * inputBound := - sum_abs_row_mul_eq_rowSum_mul W i inputBound - simpa [Matrix.mulVec, dotProduct] using h1.trans (h2.trans_eq h3) - -/-- Row-sum norm bounds a matrix-vector product under a uniform input bound. 
-/ -theorem abs_mulVec_le_rowSumNorm {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (x : Fin n → Rat) (inputBound : Rat) - (hx : ∀ j, |x j| ≤ inputBound) (hinput : 0 ≤ inputBound) : - ∀ i, |Matrix.mulVec W x i| ≤ rowSumNorm W * inputBound := by - intro i - have hrow : |Matrix.mulVec W x i| ≤ rowSum W i * inputBound := - abs_mulVec_le_rowSum W x inputBound hx i - have hle : rowSum W i ≤ rowSumNorm W := rowSum_le_rowSumNorm W i - have hmul : rowSum W i * inputBound ≤ rowSumNorm W * inputBound := - mul_le_mul_of_nonneg_right hle hinput - exact hrow.trans hmul - -end Bounds - - -end Nfp diff --git a/Nfp/Bounds/MatrixNorm/Interval.lean b/Nfp/Bounds/MatrixNorm/Interval.lean deleted file mode 100644 index 4395611..0000000 --- a/Nfp/Bounds/MatrixNorm/Interval.lean +++ /dev/null @@ -1,921 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Algebra.BigOperators.Fin -public import Mathlib.Algebra.Order.BigOperators.Group.Finset -public import Mathlib.Algebra.Order.Ring.Abs -public import Mathlib.Data.Matrix.Mul -public import Mathlib.Data.Real.Basic -public import Nfp.Core.Basic -public import Nfp.Linear.FinFold - -/-! -Interval bounds for dot products and matrix-vector products. - -This module isolates interval-bound helpers used across downstream certificates. 
--/ - -public section - -namespace Nfp - - -namespace Bounds - -open scoped BigOperators - -lemma foldl_max_ge_init {α : Type _} (f : α → Rat) : - ∀ (l : List α) (init : Rat), - init ≤ l.foldl (fun acc x => max acc (f x)) init := by - intro l init - induction l generalizing init with - | nil => - simp - | cons a l ih => - simpa [List.foldl] using - le_trans (le_max_left _ _) (ih (max init (f a))) - -lemma foldl_max_ge_mem {α : Type _} (f : α → Rat) : - ∀ (l : List α) (a : α) (init : Rat), - a ∈ l → f a ≤ l.foldl (fun acc x => max acc (f x)) init := by - intro l a init hmem - induction l generalizing init with - | nil => - cases hmem - | cons b l ih => - rcases (List.mem_cons.mp hmem) with rfl | hmem - · simpa [List.foldl] using - le_trans (le_max_right _ _) - (foldl_max_ge_init (f := f) l (max init (f a))) - · simpa [List.foldl] using ih (init := max init (f b)) hmem - -lemma foldlFin_max_ge_init {n : Nat} (f : Fin n → Rat) (init : Rat) : - init ≤ Linear.foldlFin n (fun acc j => max acc (f j)) init := by - classical - simpa [Linear.foldlFin_eq_foldl, Fin.foldl_eq_foldl_finRange] using - (foldl_max_ge_init (f := f) (List.finRange n) init) - -lemma foldlFin_max_ge {n : Nat} (f : Fin n → Rat) (i : Fin n) : - f i ≤ Linear.foldlFin n (fun acc j => max acc (f j)) 0 := by - classical - have hmem : i ∈ List.finRange n := by - simp - simpa [Linear.foldlFin_eq_foldl, Fin.foldl_eq_foldl_finRange] using - (foldl_max_ge_mem (f := f) (List.finRange n) i 0 hmem) - -/-- Lower interval endpoint for a dot product with per-coordinate bounds. -/ -def dotIntervalLower {n : Nat} (v lo hi : Fin n → Rat) : Rat := - Linear.sumFin n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) - -/-- Upper interval endpoint for a dot product with per-coordinate bounds. -/ -def dotIntervalUpper {n : Nat} (v lo hi : Fin n → Rat) : Rat := - Linear.sumFin n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - -/-- Lower interval endpoint for a product of two intervals. 
-/ -def mulIntervalLower (a b c d : Rat) : Rat := - min (min (a * c) (a * d)) (min (b * c) (b * d)) - -/-- Upper interval endpoint for a product of two intervals. -/ -def mulIntervalUpper (a b c d : Rat) : Rat := - max (max (a * c) (a * d)) (max (b * c) (b * d)) - -/-- `x * y` lies between `min (a * y) (b * y)` and `max (a * y) (b * y)` when `a ≤ x ≤ b`. -/ -lemma mul_between_of_bounds {a b x y : Rat} (hx : a ≤ x) (hx' : x ≤ b) : - min (a * y) (b * y) ≤ x * y ∧ x * y ≤ max (a * y) (b * y) := by - have hab : a ≤ b := le_trans hx hx' - by_cases hy : 0 ≤ y - · have hmin : min (a * y) (b * y) = a * y := by - have hle : a * y ≤ b * y := by - exact mul_le_mul_of_nonneg_right hab hy - exact min_eq_left hle - have hmax : max (a * y) (b * y) = b * y := by - have hle : a * y ≤ b * y := by - exact mul_le_mul_of_nonneg_right hab hy - exact max_eq_right hle - constructor - · have hmul : a * y ≤ x * y := by - exact mul_le_mul_of_nonneg_right hx hy - simpa [hmin] using hmul - · have hmul : x * y ≤ b * y := by - exact mul_le_mul_of_nonneg_right hx' hy - simpa [hmax] using hmul - · have hy' : y ≤ 0 := le_of_not_ge hy - have hmin : min (a * y) (b * y) = b * y := by - have hle : b * y ≤ a * y := by - exact mul_le_mul_of_nonpos_right hab hy' - exact min_eq_right hle - have hmax : max (a * y) (b * y) = a * y := by - have hle : b * y ≤ a * y := by - exact mul_le_mul_of_nonpos_right hab hy' - exact max_eq_left hle - constructor - · have hmul : b * y ≤ x * y := by - exact mul_le_mul_of_nonpos_right hx' hy' - simpa [hmin] using hmul - · have hmul : x * y ≤ a * y := by - exact mul_le_mul_of_nonpos_right hx hy' - simpa [hmax] using hmul - -/-- Lower interval endpoint bounds `x * y` when both factors are interval-bounded. 
-/ -lemma mulIntervalLower_le_mul {a b c d x y : Rat} - (hx : a ≤ x) (hx' : x ≤ b) (hy : c ≤ y) (hy' : y ≤ d) : - mulIntervalLower a b c d ≤ x * y := by - have hAy : - min (a * c) (a * d) ≤ a * y := by - have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := a) hy hy' - simpa only [mul_comm] using h.1 - have hBy : - min (b * c) (b * d) ≤ b * y := by - have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := b) hy hy' - simpa only [mul_comm] using h.1 - have hmin : - min (min (a * c) (a * d)) (min (b * c) (b * d)) ≤ min (a * y) (b * y) := by - apply le_min - · exact le_trans (min_le_left _ _) hAy - · exact le_trans (min_le_right _ _) hBy - have hxy := (mul_between_of_bounds (a := a) (b := b) (x := x) (y := y) hx hx').1 - exact le_trans hmin hxy - -/-- Upper interval endpoint bounds `x * y` when both factors are interval-bounded. -/ -lemma mul_le_mulIntervalUpper {a b c d x y : Rat} - (hx : a ≤ x) (hx' : x ≤ b) (hy : c ≤ y) (hy' : y ≤ d) : - x * y ≤ mulIntervalUpper a b c d := by - have hAy : - a * y ≤ max (a * c) (a * d) := by - have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := a) hy hy' - simpa only [mul_comm] using h.2 - have hBy : - b * y ≤ max (b * c) (b * d) := by - have h := mul_between_of_bounds (a := c) (b := d) (x := y) (y := b) hy hy' - simpa only [mul_comm] using h.2 - have hmax : - max (a * y) (b * y) ≤ max (max (a * c) (a * d)) (max (b * c) (b * d)) := by - apply max_le - · exact le_trans hAy (le_max_left _ _) - · exact le_trans hBy (le_max_right _ _) - have hxy := (mul_between_of_bounds (a := a) (b := b) (x := x) (y := y) hx hx').2 - exact le_trans hxy hmax - -/-- Lower interval endpoint for a dot product with bounds on both vectors. -/ -def dotIntervalLower2 {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat := - Linear.sumFin n (fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - -/-- Upper interval endpoint for a dot product with bounds on both vectors. 
-/ -def dotIntervalUpper2 {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat := - Linear.sumFin n (fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - -/-- Lower/upper interval endpoints for a dot product with bounds on both vectors. -/ -def dotIntervalLowerUpper2CommonDen {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat × Rat := - Linear.foldlFin n - (fun acc j => - (acc.1 + mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j), - acc.2 + mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j))) - (0, 0) - -/-! Sign-splitting bounds. -/ - -/-- Clamp a single coordinate interval to be nonnegative or nonpositive. -/ -def clampAt {n : Nat} (i : Fin n) (nonneg : Bool) (lo hi : Fin n → Rat) : - (Fin n → Rat) × (Fin n → Rat) := - if nonneg then - (fun j => if j = i then max 0 (lo j) else lo j, hi) - else - (lo, fun j => if j = i then min 0 (hi j) else hi j) - -/-- Lower/upper interval endpoints with sign-splitting on selected coordinates. -/ -def dotIntervalLowerUpper2SignSplit {n : Nat} (dims : List (Fin n)) - (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat × Rat := - match dims with - | [] => - dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2 - | i :: rest => - let boundsPos := - let clamped := clampAt i true lo1 hi1 - dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 - let boundsNeg := - let clamped := clampAt i false lo1 hi1 - dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 - (min boundsPos.1 boundsNeg.1, max boundsPos.2 boundsNeg.2) - -/-- Lower/upper interval endpoints with sign-splitting on both sides. -/ -def dotIntervalLowerUpper2SignSplitBoth {n : Nat} (dims1 dims2 : List (Fin n)) - (lo1 hi1 lo2 hi2 : Fin n → Rat) : Rat × Rat := - let bounds1 := dotIntervalLowerUpper2SignSplit dims1 lo1 hi1 lo2 hi2 - let bounds2 := dotIntervalLowerUpper2SignSplit dims2 lo2 hi2 lo1 hi1 - (max bounds1.1 bounds2.1, min bounds1.2 bounds2.2) - -/-- Sum of lower interval products bounds the dot-product sum (Rat). 
-/ -private theorem sum_mulIntervalLower_le_sum_mul {n : Nat} - (lo1 hi1 lo2 hi2 x y : Fin n → Rat) - (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) - (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : - ∑ j, mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) ≤ - ∑ j, x j * y j := by - refine Finset.sum_le_sum ?_ - intro j _ - exact mulIntervalLower_le_mul (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) - -/-- Sum of products is bounded by the upper interval products (Rat). -/ -private theorem sum_mul_le_sum_mulIntervalUpper {n : Nat} - (lo1 hi1 lo2 hi2 x y : Fin n → Rat) - (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) - (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : - ∑ j, x j * y j ≤ - ∑ j, mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) := by - refine Finset.sum_le_sum ?_ - intro j _ - exact mul_le_mulIntervalUpper (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) - -theorem dotIntervalLower2_le_dotProduct {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n → Rat) - (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) - (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : - dotIntervalLower2 lo1 hi1 lo2 hi2 ≤ dotProduct x y := by - classical - have hsum := - sum_mulIntervalLower_le_sum_mul - (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) (x := x) (y := y) - hlo1 hhi1 hlo2 hhi2 - simpa [dotIntervalLower2, Linear.sumFin_eq_sum_univ, dotProduct] using hsum - -theorem dotProduct_le_dotIntervalUpper2 {n : Nat} (lo1 hi1 lo2 hi2 x y : Fin n → Rat) - (hlo1 : ∀ j, lo1 j ≤ x j) (hhi1 : ∀ j, x j ≤ hi1 j) - (hlo2 : ∀ j, lo2 j ≤ y j) (hhi2 : ∀ j, y j ≤ hi2 j) : - dotProduct x y ≤ dotIntervalUpper2 lo1 hi1 lo2 hi2 := by - classical - have hsum := - sum_mul_le_sum_mulIntervalUpper - (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) (x := x) (y := y) - hlo1 hhi1 hlo2 hhi2 - simpa [dotIntervalUpper2, Linear.sumFin_eq_sum_univ, dotProduct] using hsum -/-- Lower interval endpoint using a shared-denominator accumulator. 
-/ -def dotIntervalLowerCommonDen {n : Nat} (v lo hi : Fin n → Rat) : Rat := - Linear.sumFinCommonDen n (fun j => if 0 ≤ v j then v j * lo j else v j * hi j) - -/-- Upper interval endpoint using a shared-denominator accumulator. -/ -def dotIntervalUpperCommonDen {n : Nat} (v lo hi : Fin n → Rat) : Rat := - Linear.sumFinCommonDen n (fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - -/-- Lower/upper interval endpoints computed in a single pass. -/ -def dotIntervalLowerUpperCommonDen {n : Nat} (v lo hi : Fin n → Rat) : Rat × Rat := - Linear.foldlFin n - (fun acc j => - (acc.1 + if 0 ≤ v j then v j * lo j else v j * hi j, - acc.2 + if 0 ≤ v j then v j * hi j else v j * lo j)) - (0, 0) - -/-- Lower interval endpoint using unnormalized accumulation. -/ -def dotIntervalLowerUnnorm {n : Nat} (v lo hi : Fin n → Rat) : Rat := - dotIntervalLower v lo hi - -/-- Upper interval endpoint using unnormalized accumulation. -/ -def dotIntervalUpperUnnorm {n : Nat} (v lo hi : Fin n → Rat) : Rat := - dotIntervalUpper v lo hi - -theorem dotIntervalLowerCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : - dotIntervalLowerCommonDen v lo hi = dotIntervalLower v lo hi := by - simp only [dotIntervalLowerCommonDen, dotIntervalLower, Linear.sumFinCommonDen_eq_sumFin] - -theorem dotIntervalUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : - dotIntervalUpperCommonDen v lo hi = dotIntervalUpper v lo hi := by - simp only [dotIntervalUpperCommonDen, dotIntervalUpper, Linear.sumFinCommonDen_eq_sumFin] - -private lemma foldl_pair {α : Type _} (xs : List α) (f g : α → Rat) (a b : Rat) : - xs.foldl (fun acc x => (acc.1 + f x, acc.2 + g x)) (a, b) = - (xs.foldl (fun acc x => acc + f x) a, xs.foldl (fun acc x => acc + g x) b) := by - induction xs generalizing a b with - | nil => - simp - | cons x xs ih => - simp [List.foldl, ih] - -theorem dotIntervalLowerUpper2CommonDen_fst {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : - (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).1 = - dotIntervalLower2 lo1 
hi1 lo2 hi2 := by - classical - have hpair := - foldl_pair (xs := List.finRange n) - (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (a := 0) (b := 0) - have hfst := congrArg Prod.fst hpair - simpa [dotIntervalLowerUpper2CommonDen, dotIntervalLower2, Linear.foldlFin_eq_foldl, - Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using hfst - -theorem dotIntervalLowerUpper2CommonDen_snd {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : - (dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2).2 = - dotIntervalUpper2 lo1 hi1 lo2 hi2 := by - classical - have hpair := - foldl_pair (xs := List.finRange n) - (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (g := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j)) - (a := 0) (b := 0) - have hsnd := congrArg Prod.snd hpair - simpa [dotIntervalLowerUpper2CommonDen, dotIntervalUpper2, Linear.foldlFin_eq_foldl, - Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using hsnd - -theorem dotIntervalLowerUpper2CommonDen_eq {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) : - dotIntervalLowerUpper2CommonDen lo1 hi1 lo2 hi2 = - (dotIntervalLower2 lo1 hi1 lo2 hi2, dotIntervalUpper2 lo1 hi1 lo2 hi2) := by - ext <;> simp only [dotIntervalLowerUpper2CommonDen_fst, dotIntervalLowerUpper2CommonDen_snd] - -theorem dotIntervalLowerUpperCommonDen_fst {n : Nat} (v lo hi : Fin n → Rat) : - (dotIntervalLowerUpperCommonDen v lo hi).1 = dotIntervalLowerCommonDen v lo hi := by - classical - have hpair := - foldl_pair (xs := List.finRange n) - (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) - (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - (a := 0) (b := 0) - have hfst := congrArg Prod.fst hpair - simpa [dotIntervalLowerUpperCommonDen, dotIntervalLowerCommonDen, - Linear.foldlFin_eq_foldl, Linear.sumFinCommonDen_eq_sumFin, - Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using hfst - -theorem 
dotIntervalLowerUpperCommonDen_snd {n : Nat} (v lo hi : Fin n → Rat) : - (dotIntervalLowerUpperCommonDen v lo hi).2 = dotIntervalUpperCommonDen v lo hi := by - classical - have hpair := - foldl_pair (xs := List.finRange n) - (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j) - (g := fun j => if 0 ≤ v j then v j * hi j else v j * lo j) - (a := 0) (b := 0) - have hsnd := congrArg Prod.snd hpair - simpa [dotIntervalLowerUpperCommonDen, dotIntervalUpperCommonDen, - Linear.foldlFin_eq_foldl, Linear.sumFinCommonDen_eq_sumFin, - Linear.sumFin_eq_list_foldl, Fin.foldl_eq_foldl_finRange] using hsnd - -/-- Single-pass lower/upper endpoints agree with the common-denominator bounds. -/ -theorem dotIntervalLowerUpperCommonDen_eq {n : Nat} (v lo hi : Fin n → Rat) : - dotIntervalLowerUpperCommonDen v lo hi = - (dotIntervalLowerCommonDen v lo hi, dotIntervalUpperCommonDen v lo hi) := by - ext <;> simp only [dotIntervalLowerUpperCommonDen_fst, dotIntervalLowerUpperCommonDen_snd] - -private theorem dotIntervalLowerUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : - dotIntervalLowerUnnorm v lo hi = dotIntervalLower v lo hi := rfl - -private theorem dotIntervalUpperUnnorm_eq {n : Nat} (v lo hi : Fin n → Rat) : - dotIntervalUpperUnnorm v lo hi = dotIntervalUpper v lo hi := rfl - -/-! Cached endpoints. -/ - -/-- Cached-array lower interval endpoint for a dot product using normalized rational sums. 
-/ -def dotIntervalLowerCachedRat {n : Nat} (v lo hi : Fin n → Rat) : Rat := - let vArr := Array.ofFn v - let loArr := Array.ofFn lo - let hiArr := Array.ofFn hi - Linear.sumFin n (fun j => - let vj := vArr[j.1]'(by - have hsize : vArr.size = n := by simp [vArr] - simp [hsize, j.isLt]) - let loj := loArr[j.1]'(by - have hsize : loArr.size = n := by simp [loArr] - simp [hsize, j.isLt]) - let hij := hiArr[j.1]'(by - have hsize : hiArr.size = n := by simp [hiArr] - simp [hsize, j.isLt]) - if 0 ≤ vj then vj * loj else vj * hij) - -/-- Cached-array upper interval endpoint for a dot product using normalized rational sums. -/ -def dotIntervalUpperCachedRat {n : Nat} (v lo hi : Fin n → Rat) : Rat := - let vArr := Array.ofFn v - let loArr := Array.ofFn lo - let hiArr := Array.ofFn hi - Linear.sumFin n (fun j => - let vj := vArr[j.1]'(by - have hsize : vArr.size = n := by simp [vArr] - simp [hsize, j.isLt]) - let loj := loArr[j.1]'(by - have hsize : loArr.size = n := by simp [loArr] - simp [hsize, j.isLt]) - let hij := hiArr[j.1]'(by - have hsize : hiArr.size = n := by simp [hiArr] - simp [hsize, j.isLt]) - if 0 ≤ vj then vj * hij else vj * loj) - -theorem dotIntervalLowerCachedRat_eq {n : Nat} (v lo hi : Fin n → Rat) : - dotIntervalLowerCachedRat v lo hi = dotIntervalLower v lo hi := by - classical - simp only [dotIntervalLowerCachedRat, dotIntervalLower, Linear.sumFin_eq_list_foldl, - Array.getElem_ofFn] - -theorem dotIntervalUpperCachedRat_eq {n : Nat} (v lo hi : Fin n → Rat) : - dotIntervalUpperCachedRat v lo hi = dotIntervalUpper v lo hi := by - classical - simp only [dotIntervalUpperCachedRat, dotIntervalUpper, Linear.sumFin_eq_list_foldl, - Array.getElem_ofFn] - -/-! Absolute bounds. -/ - -/-- Absolute bound from interval endpoints for a dot product. -/ -def dotIntervalAbsBound {n : Nat} (v lo hi : Fin n → Rat) : Rat := - max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| - -/-- Lower interval endpoint for a matrix-vector product under input intervals. 
-/ -def mulVecIntervalLower {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) : Fin m → Rat := - fun i => dotIntervalLower (fun j => W i j) lo hi - -/-- Upper interval endpoint for a matrix-vector product under input intervals. -/ -def mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) : Fin m → Rat := - fun i => dotIntervalUpper (fun j => W i j) lo hi - -theorem dotIntervalLower_le_dotProduct {n : Nat} (v lo hi x : Fin n → Rat) - (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - dotIntervalLower v lo hi ≤ dotProduct v x := by - classical - simp only [dotIntervalLower, Linear.sumFin_eq_sum_univ, dotProduct] - refine Finset.sum_le_sum ?_ - intro j _ - by_cases hv : 0 ≤ v j - · have hmul : v j * lo j ≤ v j * x j := by - exact mul_le_mul_of_nonneg_left (hlo j) hv - simpa [hv] using hmul - · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) - have hmul : v j * hi j ≤ v j * x j := by - have hmul' : hi j * v j ≤ x j * v j := by - exact mul_le_mul_of_nonpos_right (hhi j) hv' - simpa only [mul_comm] using hmul' - simpa [hv] using hmul - -theorem dotProduct_le_dotIntervalUpper {n : Nat} (v lo hi x : Fin n → Rat) - (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - dotProduct v x ≤ dotIntervalUpper v lo hi := by - classical - simp only [dotIntervalUpper, Linear.sumFin_eq_sum_univ, dotProduct] - refine Finset.sum_le_sum ?_ - intro j _ - by_cases hv : 0 ≤ v j - · have hmul : v j * x j ≤ v j * hi j := by - exact mul_le_mul_of_nonneg_left (hhi j) hv - simpa [hv] using hmul - · have hv' : v j ≤ 0 := le_of_lt (lt_of_not_ge hv) - have hmul : v j * x j ≤ v j * lo j := by - have hmul' : x j * v j ≤ lo j * v j := by - exact mul_le_mul_of_nonpos_right (hlo j) hv' - simpa only [mul_comm] using hmul' - simpa [hv] using hmul - -theorem abs_le_max_abs_abs_of_interval {a b x : Rat} (hlo : a ≤ x) (hhi : x ≤ b) : - |x| ≤ max |a| |b| := by - exact abs_le_max_abs_abs hlo hhi - -/-- Global absolute bound from interval endpoints. 
-/ -def intervalAbsBound {n : Nat} (lo hi : Fin n → Rat) : Rat := - Linear.foldlFin n (fun acc i => max acc (max |lo i| |hi i|)) 0 - -/-- `intervalAbsBound` dominates each endpoint absolute value. -/ -theorem max_abs_le_intervalAbsBound {n : Nat} (lo hi : Fin n → Rat) (i : Fin n) : - max |lo i| |hi i| ≤ intervalAbsBound lo hi := by - simpa [intervalAbsBound] using - (foldlFin_max_ge (f := fun j => max |lo j| |hi j|) i) - -/-- `intervalAbsBound` bounds any element inside the interval. -/ -theorem abs_le_intervalAbsBound {n : Nat} (lo hi x : Fin n → Rat) - (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) (i : Fin n) : - |x i| ≤ intervalAbsBound lo hi := by - have hbound : |x i| ≤ max |lo i| |hi i| := - abs_le_max_abs_abs_of_interval (hlo i) (hhi i) - have hsup : max |lo i| |hi i| ≤ intervalAbsBound lo hi := - max_abs_le_intervalAbsBound lo hi i - exact le_trans hbound hsup - -theorem abs_dotProduct_le_dotIntervalAbsBound {n : Nat} (v lo hi x : Fin n → Rat) - (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - |dotProduct v x| ≤ dotIntervalAbsBound v lo hi := by - have hlow : dotIntervalLower v lo hi ≤ dotProduct v x := - dotIntervalLower_le_dotProduct v lo hi x hlo hhi - have hhigh : dotProduct v x ≤ dotIntervalUpper v lo hi := - dotProduct_le_dotIntervalUpper v lo hi x hlo hhi - have habs : |dotProduct v x| ≤ - max |dotIntervalLower v lo hi| |dotIntervalUpper v lo hi| := - abs_le_max_abs_abs_of_interval hlow hhigh - simpa [dotIntervalAbsBound] using habs - -/-! Real-valued bounds from rational intervals. 
-/ - -lemma mul_between_of_bounds_real {a b x y : Real} (hx : a ≤ x) (hx' : x ≤ b) : - min (a * y) (b * y) ≤ x * y ∧ x * y ≤ max (a * y) (b * y) := by - have hab : a ≤ b := le_trans hx hx' - by_cases hy : 0 ≤ y - · have hmin : min (a * y) (b * y) = a * y := by - have hle : a * y ≤ b * y := by - exact mul_le_mul_of_nonneg_right hab hy - exact min_eq_left hle - have hmax : max (a * y) (b * y) = b * y := by - have hle : a * y ≤ b * y := by - exact mul_le_mul_of_nonneg_right hab hy - exact max_eq_right hle - constructor - · have hmul : a * y ≤ x * y := by - exact mul_le_mul_of_nonneg_right hx hy - simpa [hmin] using hmul - · have hmul : x * y ≤ b * y := by - exact mul_le_mul_of_nonneg_right hx' hy - simpa [hmax] using hmul - · have hy' : y ≤ 0 := le_of_not_ge hy - have hmin : min (a * y) (b * y) = b * y := by - have hle : b * y ≤ a * y := by - exact mul_le_mul_of_nonpos_right hab hy' - exact min_eq_right hle - have hmax : max (a * y) (b * y) = a * y := by - have hle : b * y ≤ a * y := by - exact mul_le_mul_of_nonpos_right hab hy' - exact max_eq_left hle - constructor - · have hmul : b * y ≤ x * y := by - exact mul_le_mul_of_nonpos_right hx' hy' - simpa [hmin] using hmul - · have hmul : x * y ≤ a * y := by - exact mul_le_mul_of_nonpos_right hx hy' - simpa [hmax] using hmul - -lemma mulIntervalLower_le_mul_real {a b c d : Rat} {x y : Real} - (hx : (a : Real) ≤ x) (hx' : x ≤ (b : Real)) - (hy : (c : Real) ≤ y) (hy' : y ≤ (d : Real)) : - (mulIntervalLower a b c d : Real) ≤ x * y := by - have hAy : - min ((a : Real) * (c : Real)) ((a : Real) * (d : Real)) ≤ (a : Real) * y := by - have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) - (y := (a : Real)) hy hy' - simpa only [mul_comm] using h.1 - have hBy : - min ((b : Real) * (c : Real)) ((b : Real) * (d : Real)) ≤ (b : Real) * y := by - have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) - (y := (b : Real)) hy hy' - simpa only [mul_comm] using h.1 - have hmin : - min 
(min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) - (min ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) ≤ - min ((a : Real) * y) ((b : Real) * y) := by - apply le_min - · exact le_trans (min_le_left _ _) hAy - · exact le_trans (min_le_right _ _) hBy - have hxy := (mul_between_of_bounds_real (a := (a : Real)) (b := (b : Real)) (x := x) - (y := y) hx hx').1 - have hcast : - (mulIntervalLower a b c d : Real) = - min (min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) - (min ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := by - simp only [mulIntervalLower, Rat.cast_min, Rat.cast_mul] - calc - (mulIntervalLower a b c d : Real) - = min (min ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) - (min ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := hcast - _ ≤ min ((a : Real) * y) ((b : Real) * y) := hmin - _ ≤ x * y := hxy - -lemma mul_le_mulIntervalUpper_real {a b c d : Rat} {x y : Real} - (hx : (a : Real) ≤ x) (hx' : x ≤ (b : Real)) - (hy : (c : Real) ≤ y) (hy' : y ≤ (d : Real)) : - x * y ≤ (mulIntervalUpper a b c d : Real) := by - have hAy : - (a : Real) * y ≤ max ((a : Real) * (c : Real)) ((a : Real) * (d : Real)) := by - have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) - (y := (a : Real)) hy hy' - simpa only [mul_comm] using h.2 - have hBy : - (b : Real) * y ≤ max ((b : Real) * (c : Real)) ((b : Real) * (d : Real)) := by - have h := mul_between_of_bounds_real (a := (c : Real)) (b := (d : Real)) (x := y) - (y := (b : Real)) hy hy' - simpa only [mul_comm] using h.2 - have hmax : - max ((a : Real) * y) ((b : Real) * y) ≤ - max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) - (max ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := by - apply max_le - · exact le_trans hAy (le_max_left _ _) - · exact le_trans hBy (le_max_right _ _) - have hxy := (mul_between_of_bounds_real (a := (a : Real)) (b := (b : Real)) (x := x) - (y := y) hx hx').2 - have hcast : - (mulIntervalUpper a b c d : Real) = 
- max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) - (max ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := by - simp only [mulIntervalUpper, Rat.cast_max, Rat.cast_mul] - calc - x * y ≤ max ((a : Real) * y) ((b : Real) * y) := hxy - _ ≤ max (max ((a : Real) * (c : Real)) ((a : Real) * (d : Real))) - (max ((b : Real) * (c : Real)) ((b : Real) * (d : Real))) := hmax - _ = (mulIntervalUpper a b c d : Real) := hcast.symm - -/-- Sum of lower interval products bounds the dot-product sum (Real). -/ -private theorem sum_mulIntervalLower_le_sum_mul_real {n : Nat} - (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) - (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) - (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : - (∑ j, (mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real)) ≤ - ∑ j, x j * y j := by - refine Finset.sum_le_sum ?_ - intro j _ - exact mulIntervalLower_le_mul_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) - -/-- Sum of products is bounded by the upper interval products (Real). 
-/ -private theorem sum_mul_le_sum_mulIntervalUpper_real {n : Nat} - (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) - (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) - (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : - ∑ j, x j * y j ≤ - ∑ j, (mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by - refine Finset.sum_le_sum ?_ - intro j _ - exact mul_le_mulIntervalUpper_real (hlo1 j) (hhi1 j) (hlo2 j) (hhi2 j) - -theorem dotIntervalLower2_le_dotProduct_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) - (x y : Fin n → Real) - (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) - (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : - (dotIntervalLower2 lo1 hi1 lo2 hi2 : Real) ≤ dotProduct x y := by - classical - have hcast : - (dotIntervalLower2 lo1 hi1 lo2 hi2 : Real) = - ∑ j, (mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by - simpa [dotIntervalLower2, ratToReal_def] using - (Linear.ratToReal_sumFin - (f := fun j => mulIntervalLower (lo1 j) (hi1 j) (lo2 j) (hi2 j))) - have hsum := - sum_mulIntervalLower_le_sum_mul_real - (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) (x := x) (y := y) - hlo1 hhi1 hlo2 hhi2 - simpa [hcast, dotProduct] using hsum - -theorem dotProduct_le_dotIntervalUpper2_real {n : Nat} (lo1 hi1 lo2 hi2 : Fin n → Rat) - (x y : Fin n → Real) - (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) - (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : - dotProduct x y ≤ (dotIntervalUpper2 lo1 hi1 lo2 hi2 : Real) := by - classical - have hcast : - (dotIntervalUpper2 lo1 hi1 lo2 hi2 : Real) = - ∑ j, (mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j) : Real) := by - simpa [dotIntervalUpper2, ratToReal_def] using - (Linear.ratToReal_sumFin - (f := fun j => mulIntervalUpper (lo1 j) (hi1 j) (lo2 j) (hi2 j))) - have hsum := - sum_mul_le_sum_mulIntervalUpper_real - (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) 
(x := x) (y := y) - hlo1 hhi1 hlo2 hhi2 - simpa [hcast, dotProduct] using hsum - -theorem dotIntervalLowerUpper2SignSplit_spec_real {n : Nat} (dims : List (Fin n)) - (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) - (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) - (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : - let bounds := dotIntervalLowerUpper2SignSplit dims lo1 hi1 lo2 hi2 - (bounds.1 : Real) ≤ dotProduct x y ∧ dotProduct x y ≤ (bounds.2 : Real) := by - classical - induction dims generalizing lo1 hi1 with - | nil => - have hlow := - dotIntervalLower2_le_dotProduct_real - (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) - (x := x) (y := y) hlo1 hhi1 hlo2 hhi2 - have hhigh := - dotProduct_le_dotIntervalUpper2_real - (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) - (x := x) (y := y) hlo1 hhi1 hlo2 hhi2 - simpa only [dotIntervalLowerUpper2SignSplit, dotIntervalLowerUpper2CommonDen_fst, - dotIntervalLowerUpper2CommonDen_snd] using And.intro hlow hhigh - | cons i rest ih => - by_cases hx : 0 ≤ x i - · let clamped := clampAt i true lo1 hi1 - let boundsPos := dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 - let boundsNeg := - dotIntervalLowerUpper2SignSplit rest (clampAt i false lo1 hi1).1 - (clampAt i false lo1 hi1).2 lo2 hi2 - have hlo1' : ∀ j, (clamped.1 j : Real) ≤ x j := by - intro j - by_cases hji : j = i - · have hmax : max (0 : Real) (lo1 i : Real) ≤ x i := - (max_le_iff).2 ⟨hx, hlo1 i⟩ - simpa [clamped, clampAt, hji, ratToReal_max] using hmax - · simpa [clamped, clampAt, hji] using hlo1 j - have hhi1' : ∀ j, x j ≤ (clamped.2 j : Real) := by - intro j - simpa [clamped, clampAt] using hhi1 j - have hpos := - ih (lo1 := clamped.1) (hi1 := clamped.2) hlo1' hhi1' - have hlow : (min boundsPos.1 boundsNeg.1 : Real) ≤ dotProduct x y := by - exact le_trans (min_le_left _ _) hpos.1 - have hhigh : dotProduct x y ≤ (max boundsPos.2 boundsNeg.2 : Real) := by - exact le_trans hpos.2 (le_max_left 
_ _) - simpa [dotIntervalLowerUpper2SignSplit, boundsPos, boundsNeg, clamped] using - And.intro hlow hhigh - · have hxneg : x i ≤ 0 := le_of_lt (lt_of_not_ge hx) - let clamped := clampAt i false lo1 hi1 - let boundsPos := - dotIntervalLowerUpper2SignSplit rest (clampAt i true lo1 hi1).1 - (clampAt i true lo1 hi1).2 lo2 hi2 - let boundsNeg := dotIntervalLowerUpper2SignSplit rest clamped.1 clamped.2 lo2 hi2 - have hlo1' : ∀ j, (clamped.1 j : Real) ≤ x j := by - intro j - simpa [clamped, clampAt] using hlo1 j - have hhi1' : ∀ j, x j ≤ (clamped.2 j : Real) := by - intro j - by_cases hji : j = i - · have hmin : x i ≤ min (0 : Real) (hi1 i : Real) := - (le_min_iff).2 ⟨hxneg, hhi1 i⟩ - simpa [clamped, clampAt, hji, ratToReal_min] using hmin - · simpa [clamped, clampAt, hji] using hhi1 j - have hneg := - ih (lo1 := clamped.1) (hi1 := clamped.2) hlo1' hhi1' - have hlow : (min boundsPos.1 boundsNeg.1 : Real) ≤ dotProduct x y := by - exact le_trans (min_le_right _ _) hneg.1 - have hhigh : dotProduct x y ≤ (max boundsPos.2 boundsNeg.2 : Real) := by - exact le_trans hneg.2 (le_max_right _ _) - simpa [dotIntervalLowerUpper2SignSplit, boundsPos, boundsNeg, clamped] using - And.intro hlow hhigh - -theorem dotIntervalLowerUpper2SignSplitBoth_spec_real {n : Nat} (dims1 dims2 : List (Fin n)) - (lo1 hi1 lo2 hi2 : Fin n → Rat) (x y : Fin n → Real) - (hlo1 : ∀ j, (lo1 j : Real) ≤ x j) (hhi1 : ∀ j, x j ≤ (hi1 j : Real)) - (hlo2 : ∀ j, (lo2 j : Real) ≤ y j) (hhi2 : ∀ j, y j ≤ (hi2 j : Real)) : - let bounds := dotIntervalLowerUpper2SignSplitBoth dims1 dims2 lo1 hi1 lo2 hi2 - (bounds.1 : Real) ≤ dotProduct x y ∧ dotProduct x y ≤ (bounds.2 : Real) := by - classical - let bounds1 := dotIntervalLowerUpper2SignSplit dims1 lo1 hi1 lo2 hi2 - let bounds2 := dotIntervalLowerUpper2SignSplit dims2 lo2 hi2 lo1 hi1 - have h1 := - dotIntervalLowerUpper2SignSplit_spec_real - (dims := dims1) (lo1 := lo1) (hi1 := hi1) (lo2 := lo2) (hi2 := hi2) - (x := x) (y := y) hlo1 hhi1 hlo2 hhi2 - have h2swap := - 
dotIntervalLowerUpper2SignSplit_spec_real - (dims := dims2) (lo1 := lo2) (hi1 := hi2) (lo2 := lo1) (hi2 := hi1) - (x := y) (y := x) hlo2 hhi2 hlo1 hhi1 - have h2 : (bounds2.1 : Real) ≤ dotProduct x y ∧ dotProduct x y ≤ (bounds2.2 : Real) := by - simpa only [dotProduct_comm] using h2swap - have hlow' : max (bounds1.1 : Real) (bounds2.1 : Real) ≤ dotProduct x y := - (max_le_iff).2 ⟨h1.1, h2.1⟩ - have hhigh' : dotProduct x y ≤ min (bounds1.2 : Real) (bounds2.2 : Real) := - (le_min_iff).2 ⟨h1.2, h2.2⟩ - have hlow : ((max bounds1.1 bounds2.1 : Rat) : Real) ≤ dotProduct x y := by - simpa [ratToReal_max] using hlow' - have hhigh : dotProduct x y ≤ ((min bounds1.2 bounds2.2 : Rat) : Real) := by - simpa [ratToReal_min] using hhigh' - simpa only [dotIntervalLowerUpper2SignSplitBoth, bounds1, bounds2] using And.intro hlow hhigh - -theorem dotIntervalLower_le_dotProduct_real {n : Nat} (v lo hi : Fin n → Rat) - (x : Fin n → Real) - (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : - (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := by - classical - have hcast : - (dotIntervalLower v lo hi : Real) = - ∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real) := by - have hcast' : - ratToReal (dotIntervalLower v lo hi) = - ∑ j, if 0 ≤ v j then ratToReal (v j) * ratToReal (lo j) else - ratToReal (v j) * ratToReal (hi j) := by - simpa [dotIntervalLower, ratToReal_if, ratToReal_mul] using - (Linear.ratToReal_sumFin - (f := fun j => if 0 ≤ v j then v j * lo j else v j * hi j)) - simpa [ratToReal_def] using hcast' - have hsum : - (∑ j, if 0 ≤ v j then (v j : Real) * (lo j : Real) else (v j : Real) * (hi j : Real)) ≤ - ∑ j, (v j : Real) * x j := by - refine Finset.sum_le_sum ?_ - intro j _ - by_cases hv : 0 ≤ v j - · have hv' : (0 : Real) ≤ (v j : Real) := by - simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hv - have hmul : (v j : Real) * (lo j : Real) ≤ (v j : Real) * x j := by - exact 
mul_le_mul_of_nonneg_left (hlo j) hv' - simpa [hv] using hmul - · have hv' : (v j : Real) ≤ 0 := by - simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) - have hmul : (v j : Real) * (hi j : Real) ≤ (v j : Real) * x j := by - exact mul_le_mul_of_nonpos_left (hhi j) hv' - simpa [hv] using hmul - simpa [hcast, dotProduct] using hsum - -theorem dotIntervalLower_le_dotProduct_real_add {n : Nat} - (v lo hi : Fin n → Rat) (x : Fin n → Real) (b : Real) - (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : - (dotIntervalLower v lo hi : Real) + b ≤ - dotProduct (fun j => (v j : Real)) x + b := by - have hlow := - dotIntervalLower_le_dotProduct_real (v := v) (lo := lo) (hi := hi) - (x := x) hlo hhi - simpa [add_comm] using add_le_add_left hlow b - -theorem dotProduct_le_dotIntervalUpper_real {n : Nat} (v lo hi : Fin n → Rat) - (x : Fin n → Real) - (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : - dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := by - classical - have hcast : - (dotIntervalUpper v lo hi : Real) = - ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by - have hcast' : - ratToReal (dotIntervalUpper v lo hi) = - ∑ j, if 0 ≤ v j then ratToReal (v j) * ratToReal (hi j) else - ratToReal (v j) * ratToReal (lo j) := by - simpa [dotIntervalUpper, ratToReal_if, ratToReal_mul] using - (Linear.ratToReal_sumFin - (f := fun j => if 0 ≤ v j then v j * hi j else v j * lo j)) - simpa [ratToReal_def] using hcast' - have hsum : - ∑ j, (v j : Real) * x j ≤ - ∑ j, if 0 ≤ v j then (v j : Real) * (hi j : Real) else (v j : Real) * (lo j : Real) := by - refine Finset.sum_le_sum ?_ - intro j _ - by_cases hv : 0 ≤ v j - · have hv' : (0 : Real) ≤ (v j : Real) := by - simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hv - have hmul : (v j : Real) * x j ≤ (v j : Real) * (hi j : Real) := by - exact mul_le_mul_of_nonneg_left (hhi j) hv' - simpa [hv] using hmul 
- · have hv' : (v j : Real) ≤ 0 := by - simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := v j)).2 (le_of_not_ge hv) - have hmul : (v j : Real) * x j ≤ (v j : Real) * (lo j : Real) := by - exact mul_le_mul_of_nonpos_left (hlo j) hv' - simpa [hv] using hmul - simpa [hcast, dotProduct] using hsum - -theorem dotProduct_le_dotIntervalUpper_real_add {n : Nat} - (v lo hi : Fin n → Rat) (x : Fin n → Real) (b : Real) - (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : - dotProduct (fun j => (v j : Real)) x + b ≤ - (dotIntervalUpper v lo hi : Real) + b := by - have hhigh := - dotProduct_le_dotIntervalUpper_real (v := v) (lo := lo) (hi := hi) - (x := x) hlo hhi - simpa [add_comm] using add_le_add_left hhigh b - -theorem abs_le_max_abs_abs_of_interval_real {a b x : Real} (hlo : a ≤ x) (hhi : x ≤ b) : - |x| ≤ max |a| |b| := by - exact abs_le_max_abs_abs hlo hhi - -/-- `intervalAbsBound` controls real-valued coordinates inside a rational interval. -/ -theorem abs_le_intervalAbsBound_real {n : Nat} (lo hi : Fin n → Rat) (x : Fin n → Real) - (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) (i : Fin n) : - |x i| ≤ (intervalAbsBound lo hi : Real) := by - have hbound : |x i| ≤ max |(lo i : Real)| |(hi i : Real)| := - abs_le_max_abs_abs_of_interval_real (hlo i) (hhi i) - have hsup : max |lo i| |hi i| ≤ intervalAbsBound lo hi := - max_abs_le_intervalAbsBound lo hi i - have hsup_real : - max |(lo i : Real)| |(hi i : Real)| ≤ (intervalAbsBound lo hi : Real) := by - refine max_le_iff.mpr ?_ - constructor - · have hleft : |lo i| ≤ intervalAbsBound lo hi := by - exact le_trans (le_max_left _ _) (max_abs_le_intervalAbsBound lo hi i) - simpa [ratToReal_def] using ratToReal_abs_le_of_le hleft - · have hright : |hi i| ≤ intervalAbsBound lo hi := by - exact le_trans (le_max_right _ _) (max_abs_le_intervalAbsBound lo hi i) - simpa [ratToReal_def] using ratToReal_abs_le_of_le hright - exact le_trans hbound hsup_real - -theorem 
abs_dotProduct_le_dotIntervalAbsBound_real {n : Nat} (v lo hi : Fin n → Rat) - (x : Fin n → Real) - (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : - |dotProduct (fun j => (v j : Real)) x| ≤ (dotIntervalAbsBound v lo hi : Real) := by - have hlow : - (dotIntervalLower v lo hi : Real) ≤ dotProduct (fun j => (v j : Real)) x := - dotIntervalLower_le_dotProduct_real v lo hi x hlo hhi - have hhigh : - dotProduct (fun j => (v j : Real)) x ≤ (dotIntervalUpper v lo hi : Real) := - dotProduct_le_dotIntervalUpper_real v lo hi x hlo hhi - have habs : - |dotProduct (fun j => (v j : Real)) x| ≤ - max |(dotIntervalLower v lo hi : Real)| |(dotIntervalUpper v lo hi : Real)| := - abs_le_max_abs_abs_of_interval_real hlow hhigh - simpa [dotIntervalAbsBound] using habs - -/-! Matrix-vector interval bounds. -/ - -theorem mulVecIntervalLower_le_mulVec {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - ∀ i, mulVecIntervalLower W lo hi i ≤ Matrix.mulVec W x i := by - intro i - simpa [mulVecIntervalLower, Matrix.mulVec, dotProduct] using - (dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi x hlo hhi) - -theorem mulVec_le_mulVecIntervalUpper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi x : Fin n → Rat) (hlo : ∀ j, lo j ≤ x j) (hhi : ∀ j, x j ≤ hi j) : - ∀ i, Matrix.mulVec W x i ≤ mulVecIntervalUpper W lo hi i := by - intro i - simpa [mulVecIntervalUpper, Matrix.mulVec, dotProduct] using - (dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi x hlo hhi) - -theorem mulVecIntervalLower_le_upper {m n : Nat} (W : Matrix (Fin m) (Fin n) Rat) - (lo hi : Fin n → Rat) (hlohi : ∀ j, lo j ≤ hi j) : - ∀ i, mulVecIntervalLower W lo hi i ≤ mulVecIntervalUpper W lo hi i := by - intro i - have hlow : - dotIntervalLower (fun j => W i j) lo hi ≤ dotProduct (fun j => W i j) lo := - dotIntervalLower_le_dotProduct (v := fun j => W i j) lo hi lo - (fun j => le_rfl) hlohi - have hhigh : - dotProduct 
(fun j => W i j) lo ≤ dotIntervalUpper (fun j => W i j) lo hi := - dotProduct_le_dotIntervalUpper (v := fun j => W i j) lo hi lo - (fun j => le_rfl) hlohi - exact le_trans hlow hhigh - -end Bounds - - -end Nfp diff --git a/Nfp/Bounds/Mlp.lean b/Nfp/Bounds/Mlp.lean index 66eed18..1a443fe 100644 --- a/Nfp/Bounds/Mlp.lean +++ b/Nfp/Bounds/Mlp.lean @@ -6,7 +6,6 @@ public import Mathlib.Algebra.BigOperators.Group.Finset.Basic public import Nfp.Core.Basic public import Nfp.Bounds.Gelu public import Nfp.Bounds.LayerNorm -public import Nfp.Bounds.MatrixNorm /-! Interval bounds for GPT-2 MLP blocks (linear + GELU + linear). diff --git a/Nfp/Bounds/Transformer.lean b/Nfp/Bounds/Transformer.lean deleted file mode 100644 index 23a059a..0000000 --- a/Nfp/Bounds/Transformer.lean +++ /dev/null @@ -1,10 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Bounds.Transformer.Basic -public import Nfp.Bounds.Transformer.Embedding - -/-! -Transformer-stack interval bounds and supporting lemmas. --/ diff --git a/Nfp/Bounds/Transformer/Basic.lean b/Nfp/Bounds/Transformer/Basic.lean deleted file mode 100644 index c5388e3..0000000 --- a/Nfp/Bounds/Transformer/Basic.lean +++ /dev/null @@ -1,522 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Algebra.BigOperators.Group.Finset.Basic -public import Mathlib.Data.List.Range -public import Mathlib.Data.Real.Basic -public import Nfp.Circuit.Cert.ResidualInterval -public import Nfp.Model.Gpt2 -public import Nfp.Bounds.Attention -public import Nfp.Bounds.LayerNorm -public import Nfp.Bounds.Transformer.Embedding -public import Nfp.Linear.FinFold - -/-! -Interval bounds for transformer stacks and final LayerNorm outputs. --/ - -public section - -namespace Nfp - - -namespace Bounds - -open scoped BigOperators - -/-- Real-valued output of a transformer layer. 
-/ -noncomputable def transformerLayerReal {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numHeads → Fin seq → Fin seq → Real) - (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := - x q i + - attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias scores x q i + - mlpReal layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut - (layerNormRealOfReal eps layer.ln2Gamma layer.ln2Beta - (fun j => - x q j + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores x q j)) i - -/-- `transformerLayerBounds` soundness for `transformerLayerReal`. -/ -theorem transformerLayerBounds_spec_real {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : - let bounds := transformerLayerBounds eps layer.ln1Gamma layer.ln1Beta layer.ln2Gamma - layer.ln2Beta heads layer.attnBias layer.mlpWIn layer.mlpBIn layer.mlpWOut - layer.mlpBOut lo hi - ∀ q i, - (bounds.1 i : Real) ≤ transformerLayerReal eps layer heads scores x q i ∧ - transformerLayerReal eps layer heads scores x q i ≤ (bounds.2 i : Real) := by - classical - simpa [transformerLayerReal] using - (transformerLayerBounds_spec (eps := eps) - (ln1Gamma := layer.ln1Gamma) (ln1Beta := layer.ln1Beta) - (ln2Gamma := layer.ln2Gamma) (ln2Beta := layer.ln2Beta) - (heads := heads) (attnBias := layer.attnBias) - (mlpWIn := layer.mlpWIn) (mlpBIn := layer.mlpBIn) - (mlpWOut := layer.mlpWOut) (mlpBOut := layer.mlpBOut) - (scores := scores) (lo := lo) (hi := hi) (x := 
x) - hne heps hsqrt hlo hhi) - -/-- Interval bounds for a transformer layer from per-position bounds. -/ -def transformerLayerBoundsPos {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Rat) : - (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := - let positions := (Finset.univ : Finset (Fin seq)) - let hpos : positions.Nonempty := by - simp [positions] - let loCached := cacheBound2 lo - let hiCached := cacheBound2 hi - let base := intervalBoundsOn positions hpos loCached hiCached - let baseLo := cacheBound base.1 - let baseHi := cacheBound base.2 - let attn := attentionOutputBounds eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias - baseLo baseHi - let attnLo := cacheBound attn.1 - let attnHi := cacheBound attn.2 - let yLo : Fin seq → Fin dModel → Rat := fun q i => loCached q i + attnLo i - let yHi : Fin seq → Fin dModel → Rat := fun q i => hiCached q i + attnHi i - let yLoCached := cacheBound2 yLo - let yHiCached := cacheBound2 yHi - let out := cacheBoundPair2 (fun q => - layerNormAbsMlpResidualBounds eps layer.ln2Gamma layer.ln2Beta - layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut - (yLoCached q) (yHiCached q)) - out - -/-- `transformerLayerBoundsPos` soundness for `transformerLayerReal`. 
-/ -theorem transformerLayerBoundsPos_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] - (eps : Rat) (layer : Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo q i : Real) ≤ x q i) - (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : - let bounds := transformerLayerBoundsPos eps layer heads lo hi - ∀ q i, - (bounds.1 q i : Real) ≤ transformerLayerReal eps layer heads scores x q i ∧ - transformerLayerReal eps layer heads scores x q i ≤ (bounds.2 q i : Real) := by - classical - intro bounds q i - let positions := (Finset.univ : Finset (Fin seq)) - have hpos : positions.Nonempty := by - simp [positions] - let loCached := cacheBound2 lo - let hiCached := cacheBound2 hi - have hloCached : ∀ q i, (loCached q i : Real) ≤ x q i := by - intro q i - simpa [loCached, cacheBound2_apply] using hlo q i - have hhiCached : ∀ q i, x q i ≤ (hiCached q i : Real) := by - intro q i - simpa [hiCached, cacheBound2_apply] using hhi q i - let base := intervalBoundsOn positions hpos loCached hiCached - have hbase := intervalBoundsOn_spec positions hpos loCached hiCached x - (fun q _ i => hloCached q i) (fun q _ i => hhiCached q i) - have hloBase : ∀ q i, (base.1 i : Real) ≤ x q i := fun q i => - (hbase q (by simp [positions]) i).1 - have hhiBase : ∀ q i, x q i ≤ (base.2 i : Real) := fun q i => - (hbase q (by simp [positions]) i).2 - let baseLo := cacheBound base.1 - let baseHi := cacheBound base.2 - have hloBaseCached : ∀ q i, (baseLo i : Real) ≤ x q i := by - intro q i - simpa [baseLo, cacheBound_apply] using hloBase q i - have hhiBaseCached : ∀ q i, x q i ≤ (baseHi i : Real) := by - intro q i - simpa [baseHi, cacheBound_apply] using hhiBase q i - let attn := attentionOutputBounds eps layer.ln1Gamma layer.ln1Beta heads layer.attnBias 
- baseLo baseHi - have hattn := attentionOutputBounds_spec eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores baseLo baseHi x hne heps hsqrt hloBaseCached hhiBaseCached q - let attnLo := cacheBound attn.1 - let attnHi := cacheBound attn.2 - let y := fun j => - x q j + attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores x q j - have yLo : ∀ j, (loCached q j : Real) + (attn.1 j : Real) ≤ y j := by - intro j - have hlow : - (loCached q j : Real) + (attn.1 j : Real) ≤ - x q j + - attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores x q j := by - exact add_le_add (hloCached q j) (hattn j).1 - simpa [y] using hlow - have yHi : ∀ j, y j ≤ (hiCached q j : Real) + (attn.2 j : Real) := by - intro j - have hhigh : - x q j + - attentionOutputReal eps layer.ln1Gamma layer.ln1Beta heads - layer.attnBias scores x q j ≤ - (hiCached q j : Real) + (attn.2 j : Real) := by - exact add_le_add (hhiCached q j) (hattn j).2 - simpa [y] using hhigh - let yLoCached := cacheBound2 (fun q i => loCached q i + attnLo i) - let yHiCached := cacheBound2 (fun q i => hiCached q i + attnHi i) - have yLoCached_bound : ∀ j, (yLoCached q j : Real) ≤ y j := by - intro j - simpa [yLoCached, attnLo, cacheBound_apply, cacheBound2_apply] using (yLo j) - have yHiCached_bound : ∀ j, y j ≤ (yHiCached q j : Real) := by - intro j - simpa [yHiCached, attnHi, cacheBound_apply, cacheBound2_apply] using (yHi j) - have hmlp := - layerNormAbsMlpResidualBounds_spec eps layer.ln2Gamma layer.ln2Beta - layer.mlpWIn layer.mlpBIn layer.mlpWOut layer.mlpBOut - (yLoCached q) (yHiCached q) y hne heps hsqrt yLoCached_bound yHiCached_bound - have hmlp_i := hmlp i - simpa [bounds, transformerLayerBoundsPos, positions, base, loCached, hiCached, baseLo, baseHi, - attn, attnLo, attnHi, y, yLoCached, yHiCached, cacheBound2_apply, cacheBoundPair2_apply_left, - cacheBoundPair2_apply_right, transformerLayerReal, cacheBound_apply] using hmlp_i - -/-- Real-valued 
transformer stack output (folded left over layers). -/ -noncomputable def transformerStackReal - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (x : Fin seq → Fin dModel → Real) : Fin seq → Fin dModel → Real := - let step := fun x layerIdx => - transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x - Linear.foldlFin numLayers step x - -/-- Interval bounds for a transformer stack (folded left over layers). -/ -def transformerStackBounds {dModel dHead numHeads hidden numLayers : Nat} - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let step := fun bounds layerIdx => - transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta - (layers layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) - (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn - (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2 - Linear.foldlFin numLayers step (lo, hi) - -/-- Interval bounds for a transformer stack from per-position bounds. 
-/ -def transformerStackBoundsPos {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Rat) : - (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := - let step := fun bounds layerIdx => - transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2 - Linear.foldlFin numLayers step (lo, hi) - -private theorem transformerStackBoundsPos_spec_list - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - ∀ (ls : List (Fin numLayers)) (lo hi : Fin seq → Fin dModel → Rat) - (x : Fin seq → Fin dModel → Real), - (∀ q i, (lo q i : Real) ≤ x q i) → - (∀ q i, x q i ≤ (hi q i : Real)) → - let bounds := (ls.foldl - (fun bounds layerIdx => - transformerLayerBoundsPos eps (layers layerIdx) (heads layerIdx) bounds.1 bounds.2) - (lo, hi)) - let x' := (ls.foldl - (fun x layerIdx => - transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x) - x) - ∀ q i, - (bounds.1 q i : Real) ≤ x' q i ∧ - x' q i ≤ (bounds.2 q i : Real) := by - intro ls lo hi x hlo hhi - induction ls generalizing lo hi x hlo hhi with - | nil => - simpa using fun q i => And.intro (hlo q i) (hhi q i) - | cons l ls ih => - have hstep := - transformerLayerBoundsPos_spec eps (layers l) (heads l) (scores l) lo hi x - hne heps hsqrt hlo hhi - let bounds1 := transformerLayerBoundsPos eps (layers l) (heads l) lo hi - let x1 := transformerLayerReal eps (layers l) (heads l) (scores l) x - have hlo1 : ∀ q i, (bounds1.1 q i : Real) ≤ x1 q i := fun q i => (hstep q i).1 - have hhi1 
: ∀ q i, x1 q i ≤ (bounds1.2 q i : Real) := fun q i => (hstep q i).2 - have ih' := ih bounds1.1 bounds1.2 x1 hlo1 hhi1 - simpa [bounds1, x1] using ih' - -/-- `transformerStackBoundsPos` soundness for real transformer-stack outputs. -/ -theorem transformerStackBoundsPos_spec {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo q i : Real) ≤ x q i) - (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : - let bounds := transformerStackBoundsPos eps layers heads lo hi - ∀ q i, - (bounds.1 q i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ - transformerStackReal eps layers heads scores x q i ≤ (bounds.2 q i : Real) := by - classical - simpa [transformerStackBoundsPos, transformerStackReal, Linear.foldlFin_eq_foldl, - Fin.foldl_eq_foldl_finRange] using - transformerStackBoundsPos_spec_list eps layers heads scores hne heps hsqrt - (List.finRange numLayers) lo hi x hlo hhi - -private theorem transformerStackBounds_spec_list - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - ∀ (ls : List (Fin numLayers)) (lo hi : Fin dModel → Rat) - (x : Fin seq → Fin dModel → Real), - (∀ q i, (lo i : Real) ≤ x q i) → - (∀ q i, x q i ≤ (hi i : Real)) → - let bounds := (ls.foldl - (fun bounds layerIdx => - transformerLayerBounds eps (layers layerIdx).ln1Gamma (layers layerIdx).ln1Beta - (layers 
layerIdx).ln2Gamma (layers layerIdx).ln2Beta (heads layerIdx) - (layers layerIdx).attnBias (layers layerIdx).mlpWIn (layers layerIdx).mlpBIn - (layers layerIdx).mlpWOut (layers layerIdx).mlpBOut bounds.1 bounds.2) - (lo, hi)) - let x' := (ls.foldl - (fun x layerIdx => - transformerLayerReal eps (layers layerIdx) (heads layerIdx) (scores layerIdx) x) - x) - ∀ q i, - (bounds.1 i : Real) ≤ x' q i ∧ - x' q i ≤ (bounds.2 i : Real) := by - intro ls lo hi x hlo hhi - induction ls generalizing lo hi x hlo hhi with - | nil => - simpa using fun q i => And.intro (hlo q i) (hhi q i) - | cons l ls ih => - have hstep := - transformerLayerBounds_spec_real eps (layers l) (heads l) (scores l) lo hi x - hne heps hsqrt hlo hhi - let bounds1 := - transformerLayerBounds eps (layers l).ln1Gamma (layers l).ln1Beta (layers l).ln2Gamma - (layers l).ln2Beta (heads l) (layers l).attnBias (layers l).mlpWIn (layers l).mlpBIn - (layers l).mlpWOut (layers l).mlpBOut lo hi - let x1 := transformerLayerReal eps (layers l) (heads l) (scores l) x - have hlo1 : ∀ q i, (bounds1.1 i : Real) ≤ x1 q i := fun q i => (hstep q i).1 - have hhi1 : ∀ q i, x1 q i ≤ (bounds1.2 i : Real) := fun q i => (hstep q i).2 - have ih' := ih bounds1.1 bounds1.2 x1 hlo1 hhi1 - simpa [bounds1, x1] using ih' - -/-- `transformerStackBounds` soundness for real transformer-stack outputs. 
-/ -theorem transformerStackBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : - let bounds := transformerStackBounds eps layers heads lo hi - ∀ q i, - (bounds.1 i : Real) ≤ transformerStackReal eps layers heads scores x q i ∧ - transformerStackReal eps layers heads scores x q i ≤ (bounds.2 i : Real) := by - classical - simpa [transformerStackBounds, transformerStackReal, Linear.foldlFin_eq_foldl, - Fin.foldl_eq_foldl_finRange] using - transformerStackBounds_spec_list eps layers heads scores hne heps hsqrt - (List.finRange numLayers) lo hi x hlo hhi - -/-- Real-valued transformer stack output after the final LayerNorm. -/ -noncomputable def transformerStackFinalReal {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (x : Fin seq → Fin dModel → Real) (q : Fin seq) (i : Fin dModel) : Real := - layerNormRealOfReal eps finalLn.gamma finalLn.beta - (fun j => transformerStackReal eps layers heads scores x q j) i - -/-- Interval bounds for transformer stack outputs after the final LayerNorm. 
-/ -def transformerStackFinalBounds {dModel dHead numHeads hidden numLayers : Nat} - (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let stack := transformerStackBounds eps layers heads lo hi - layerNormIntervalBounds eps finalLn.gamma finalLn.beta stack.1 stack.2 - -/-- `transformerStackFinalBounds` soundness for real outputs. -/ -theorem transformerStackFinalBounds_spec {seq dModel dHead numHeads hidden numLayers : Nat} - [NeZero seq] (eps : Rat) (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : - let bounds := transformerStackFinalBounds eps finalLn layers heads lo hi - ∀ q i, - (bounds.1 i : Real) ≤ transformerStackFinalReal eps finalLn layers heads scores x q i ∧ - transformerStackFinalReal eps finalLn layers heads scores x q i ≤ (bounds.2 i : Real) := by - classical - intro bounds q i - let stack := transformerStackBounds eps layers heads lo hi - have hstack := - transformerStackBounds_spec eps layers heads scores lo hi x hne heps hsqrt hlo hhi q - have hlo' : ∀ k, (stack.1 k : Real) ≤ transformerStackReal eps layers heads scores x q k := - fun k => (hstack k).1 - have hhi' : ∀ k, transformerStackReal eps layers heads scores x q k ≤ (stack.2 k : Real) := - fun k => (hstack k).2 - have hln := - layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta stack.1 stack.2 - (fun j => transformerStackReal eps layers heads scores x q 
j) hne heps hsqrt hlo' hhi' - simpa [bounds, transformerStackFinalBounds, stack, transformerStackFinalReal] using hln i - -/-- Interval bounds for transformer stack outputs after the final LayerNorm (per-position). -/ -def transformerStackFinalBoundsPos - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (lo hi : Fin seq → Fin dModel → Rat) : - (Fin seq → Fin dModel → Rat) × (Fin seq → Fin dModel → Rat) := - let stack := transformerStackBoundsPos eps layers heads lo hi - let ln := fun q => - layerNormIntervalBounds eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) - (fun q i => (ln q).1 i, fun q i => (ln q).2 i) - -/-- `transformerStackFinalBoundsPos` soundness for real outputs. -/ -theorem transformerStackFinalBoundsPos_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) - (hlo : ∀ q i, (lo q i : Real) ≤ x q i) - (hhi : ∀ q i, x q i ≤ (hi q i : Real)) : - let bounds := transformerStackFinalBoundsPos eps finalLn layers heads lo hi - ∀ q i, - (bounds.1 q i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores x q i ∧ - transformerStackFinalReal eps finalLn layers heads scores x q i ≤ - (bounds.2 q i : Real) := by - classical - intro bounds q i - let stack := transformerStackBoundsPos eps layers heads lo hi - have hstack := - transformerStackBoundsPos_spec eps layers heads scores lo hi x hne heps hsqrt hlo hhi q - 
have hlo' : ∀ j, (stack.1 q j : Real) ≤ transformerStackReal eps layers heads scores x q j := - fun j => (hstack j).1 - have hhi' : ∀ j, transformerStackReal eps layers heads scores x q j ≤ (stack.2 q j : Real) := - fun j => (hstack j).2 - have hln := - layerNormIntervalBounds_spec_real eps finalLn.gamma finalLn.beta (stack.1 q) (stack.2 q) - (fun j => transformerStackReal eps layers heads scores x q j) hne heps hsqrt hlo' hhi' - simpa [bounds, transformerStackFinalBoundsPos, stack, transformerStackFinalReal] using hln i - -/-- Residual interval bounds for a GPT-2 stack from exact embeddings. -/ -def gpt2ResidualIntervalBounds - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let base := embeddingIntervalBounds embed - transformerStackFinalBounds eps finalLn layers heads base.1 base.2 - -/-- `gpt2ResidualIntervalBounds` soundness for real GPT-2 outputs. 
-/ -theorem gpt2ResidualIntervalBounds_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Rat) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - let bounds := gpt2ResidualIntervalBounds eps layers heads finalLn embed - ∀ q i, - (bounds.1 i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ∧ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by - classical - intro bounds q i - let base := embeddingIntervalBounds embed - have hbase := embeddingIntervalBounds_spec embed - have hlo : ∀ q i, (base.1 i : Real) ≤ (embed q i : Real) := fun q i => (hbase q i).1 - have hhi : ∀ q i, (embed q i : Real) ≤ (base.2 i : Real) := fun q i => (hbase q i).2 - have hstack := - transformerStackFinalBounds_spec eps finalLn layers heads scores base.1 base.2 - (fun q i => (embed q i : Real)) hne heps hsqrt hlo hhi q i - simpa [bounds, gpt2ResidualIntervalBounds, base] using hstack - -/-- Residual interval bounds over an active set from exact embeddings. 
-/ -def gpt2ResidualIntervalBoundsActive - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (embed : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let baseLo : Fin seq → Fin dModel → Rat := embed - let baseHi : Fin seq → Fin dModel → Rat := embed - let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi - intervalBoundsOn active hactive final.1 final.2 - -/-- `gpt2ResidualIntervalBoundsActive` soundness for real GPT-2 outputs. -/ -theorem gpt2ResidualIntervalBoundsActive_spec - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (active : Finset (Fin seq)) (hactive : active.Nonempty) (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (embed : Fin seq → Fin dModel → Rat) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : - let bounds := gpt2ResidualIntervalBoundsActive active hactive eps layers heads finalLn embed - ∀ q, q ∈ active → ∀ i, - (bounds.1 i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ∧ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (bounds.2 i : Real) := by - classical - intro bounds q hq i - let baseLo : Fin seq → Fin dModel → Rat := embed - let baseHi : Fin seq → Fin dModel → Rat := embed - let final := transformerStackFinalBoundsPos eps finalLn layers heads baseLo baseHi - have hfinal := - transformerStackFinalBoundsPos_spec eps finalLn layers heads scores baseLo 
baseHi - (fun q i => (embed q i : Real)) hne heps hsqrt - (fun q i => by simp [baseLo]) - (fun q i => by simp [baseHi]) - have hlo : ∀ q, q ∈ active → ∀ i, - (final.1 q i : Real) ≤ - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i := by - intro q hq i - simpa [final] using (hfinal q i).1 - have hhi : ∀ q, q ∈ active → ∀ i, - transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i ≤ (final.2 q i : Real) := by - intro q hq i - simpa [final] using (hfinal q i).2 - have hbounds := intervalBoundsOn_spec active hactive final.1 final.2 - (fun q i => transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (embed q i : Real)) q i) - hlo hhi - simpa [bounds, gpt2ResidualIntervalBoundsActive, final, baseLo, baseHi] using - hbounds q hq i - -end Bounds - - -end Nfp diff --git a/Nfp/Bounds/Transformer/Embedding.lean b/Nfp/Bounds/Transformer/Embedding.lean deleted file mode 100644 index 9f1943a..0000000 --- a/Nfp/Bounds/Transformer/Embedding.lean +++ /dev/null @@ -1,131 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Mathlib.Algebra.BigOperators.Group.Finset.Basic -public import Nfp.Core.Basic - -/-! -Embedding interval bounds for transformer stacks. - -This module isolates per-position and per-set embedding bounds. --/ - -public section - -namespace Nfp - - -namespace Bounds - -open scoped BigOperators - -private lemma fin_univ_nonempty (seq : Nat) [NeZero seq] : - (Finset.univ : Finset (Fin seq)).Nonempty := by - classical - refine ⟨⟨0, ?_⟩, by simp⟩ - exact Nat.pos_of_ne_zero (NeZero.ne (n := seq)) - -/-- `inf'`/`sup'` bounds for a selected position. 
-/ -private lemma inf_sup_bounds {seq : Nat} (positions : Finset (Fin seq)) - (hpos : positions.Nonempty) (f : Fin seq → Rat) - {q : Fin seq} (hq : q ∈ positions) : - positions.inf' hpos f ≤ f q ∧ f q ≤ positions.sup' hpos f := by - constructor - · exact Finset.inf'_le (s := positions) (f := f) (b := q) hq - · exact Finset.le_sup' (s := positions) (f := f) (b := q) hq - -/-- Interval bounds across tokens for an embedding map. -/ -def embeddingIntervalBounds {seq dModel : Nat} [NeZero seq] - (x : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - let h : (Finset.univ : Finset (Fin seq)).Nonempty := fin_univ_nonempty (seq := seq) - (fun i => (Finset.univ).inf' h (fun q => x q i), - fun i => (Finset.univ).sup' h (fun q => x q i)) - -/-- `embeddingIntervalBounds` bounds embeddings coordinatewise. -/ -theorem embeddingIntervalBounds_spec {seq dModel : Nat} [NeZero seq] - (x : Fin seq → Fin dModel → Rat) : - let bounds := embeddingIntervalBounds x - ∀ q i, - (bounds.1 i : Real) ≤ (x q i : Real) ∧ - (x q i : Real) ≤ (bounds.2 i : Real) := by - classical - intro bounds q i - have hbounds : - bounds.1 i ≤ x q i ∧ x q i ≤ bounds.2 i := by - have h := inf_sup_bounds (positions := (Finset.univ : Finset (Fin seq))) - (hpos := fin_univ_nonempty (seq := seq)) (f := fun k => x k i) - (q := q) (hq := by simp) - simpa [bounds, embeddingIntervalBounds, fin_univ_nonempty] using h - constructor - · simpa [ratToReal_def] using ratToReal_le_of_le hbounds.1 - · simpa [ratToReal_def] using ratToReal_le_of_le hbounds.2 - -/-- Interval bounds across a finite set of positions for an embedding map. 
-/ -def embeddingIntervalBoundsOn {seq dModel : Nat} [NeZero seq] - (positions : Finset (Fin seq)) (hpos : positions.Nonempty) - (x : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - (fun i => positions.inf' hpos (fun q => x q i), - fun i => positions.sup' hpos (fun q => x q i)) - -/-- `embeddingIntervalBoundsOn` bounds embeddings on the chosen positions. -/ -theorem embeddingIntervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] - (positions : Finset (Fin seq)) (hpos : positions.Nonempty) - (x : Fin seq → Fin dModel → Rat) : - let bounds := embeddingIntervalBoundsOn positions hpos x - ∀ q, q ∈ positions → ∀ i, - (bounds.1 i : Real) ≤ (x q i : Real) ∧ - (x q i : Real) ≤ (bounds.2 i : Real) := by - classical - intro bounds q hq i - have hbounds : bounds.1 i ≤ x q i ∧ x q i ≤ bounds.2 i := by - have h := inf_sup_bounds (positions := positions) (hpos := hpos) - (f := fun k => x k i) (q := q) (hq := hq) - simpa [bounds, embeddingIntervalBoundsOn] using h - constructor - · simpa [ratToReal_def] using ratToReal_le_of_le hbounds.1 - · simpa [ratToReal_def] using ratToReal_le_of_le hbounds.2 - -/-- Collapse per-position interval bounds over a finite set of positions. -/ -def intervalBoundsOn {seq dModel : Nat} [NeZero seq] - (positions : Finset (Fin seq)) (hpos : positions.Nonempty) - (lo hi : Fin seq → Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := - (fun i => positions.inf' hpos (fun q => lo q i), - fun i => positions.sup' hpos (fun q => hi q i)) - -/-- `intervalBoundsOn` soundness for bounds on the chosen positions. 
-/ -theorem intervalBoundsOn_spec {seq dModel : Nat} [NeZero seq] - (positions : Finset (Fin seq)) (hpos : positions.Nonempty) - (lo hi : Fin seq → Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) - (hlo : ∀ q, q ∈ positions → ∀ i, (lo q i : Real) ≤ x q i) - (hhi : ∀ q, q ∈ positions → ∀ i, x q i ≤ (hi q i : Real)) : - let bounds := intervalBoundsOn positions hpos lo hi - ∀ q, q ∈ positions → ∀ i, - (bounds.1 i : Real) ≤ x q i ∧ - x q i ≤ (bounds.2 i : Real) := by - classical - intro bounds q hq i - have hmin : bounds.1 i ≤ lo q i := by - have h := inf_sup_bounds (positions := positions) (hpos := hpos) - (f := fun k => lo k i) (q := q) (hq := hq) - simpa [bounds, intervalBoundsOn] using h.1 - have hmax : hi q i ≤ bounds.2 i := by - have h := inf_sup_bounds (positions := positions) (hpos := hpos) - (f := fun k => hi k i) (q := q) (hq := hq) - simpa [bounds, intervalBoundsOn] using h.2 - have hlo' := hlo q hq i - have hhi' := hhi q hq i - constructor - · have hmin_real : - (bounds.1 i : Real) ≤ (lo q i : Real) := by - simpa [ratToReal_def] using ratToReal_le_of_le hmin - exact le_trans hmin_real hlo' - · have hmax_real : - (hi q i : Real) ≤ (bounds.2 i : Real) := by - simpa [ratToReal_def] using ratToReal_le_of_le hmax - exact le_trans hhi' hmax_real - -end Bounds - - -end Nfp diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean index b6f1bc4..e078289 100644 --- a/Nfp/Circuit/Cert.lean +++ b/Nfp/Circuit/Cert.lean @@ -3,11 +3,8 @@ module public import Nfp.Circuit.Cert.Basic -public import Nfp.Circuit.Cert.DownstreamLinear public import Nfp.Circuit.Cert.InductionHead public import Nfp.Circuit.Cert.LogitDiff -public import Nfp.Circuit.Cert.ResidualBound -public import Nfp.Circuit.Cert.ResidualInterval public import Nfp.Circuit.Cert.SoftmaxMargin public import Nfp.Circuit.Cert.ValueRange diff --git a/Nfp/Circuit/Cert/DownstreamLinear.lean b/Nfp/Circuit/Cert/DownstreamLinear.lean deleted file mode 100644 index 9d21059..0000000 --- 
a/Nfp/Circuit/Cert/DownstreamLinear.lean +++ /dev/null @@ -1,66 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Core.Basic -public import Nfp.Circuit.Cert.Basic - -/-! -Downstream linear certificates for end-to-end induction bounds. - -These certificates record a nonnegative error bound computed externally. -The checker only verifies arithmetic consistency (`error = gain * inputBound`) -and nonnegativity of the reported quantities. --/ - -public section - -namespace Nfp - -namespace Circuit - -/-- Certificate payload for downstream linear error bounds. -/ -structure DownstreamLinearCert where - /-- Upper bound on the downstream logit-diff error. -/ - error : Rat - /-- Operator gain bound used to justify the error. -/ - gain : Rat - /-- Input magnitude bound used to justify the error. -/ - inputBound : Rat - -/-- Arithmetic properties enforced by `checkDownstreamLinearCert`. -/ -structure DownstreamLinearBounds (c : DownstreamLinearCert) : Prop where - /-- Error bound is nonnegative. -/ - error_nonneg : 0 ≤ c.error - /-- Gain bound is nonnegative. -/ - gain_nonneg : 0 ≤ c.gain - /-- Input bound is nonnegative. -/ - input_nonneg : 0 ≤ c.inputBound - /-- Error bound matches the reported gain/input product. -/ - error_eq : c.error = c.gain * c.inputBound - -/-- Boolean checker for downstream linear certificates. -/ -def checkDownstreamLinearCert (c : DownstreamLinearCert) : Bool := - decide (0 ≤ c.error) && - decide (0 ≤ c.gain) && - decide (0 ≤ c.inputBound) && - decide (c.error = c.gain * c.inputBound) - -/-- `checkDownstreamLinearCert` is sound for `DownstreamLinearBounds`. 
-/ -theorem checkDownstreamLinearCert_sound (c : DownstreamLinearCert) : - checkDownstreamLinearCert c = true → DownstreamLinearBounds c := by - intro h - have h' : - ((0 ≤ c.error ∧ 0 ≤ c.gain) ∧ 0 ≤ c.inputBound) ∧ - c.error = c.gain * c.inputBound := by - simpa [checkDownstreamLinearCert, Bool.and_eq_true, decide_eq_true_iff] using h - rcases h' with ⟨⟨⟨herror, hgain⟩, hinput⟩, heq⟩ - refine - { error_nonneg := herror - gain_nonneg := hgain - input_nonneg := hinput - error_eq := heq } - -end Circuit - -end Nfp diff --git a/Nfp/Circuit/Cert/ResidualBound.lean b/Nfp/Circuit/Cert/ResidualBound.lean deleted file mode 100644 index 167803b..0000000 --- a/Nfp/Circuit/Cert/ResidualBound.lean +++ /dev/null @@ -1,51 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Core.Basic -public import Nfp.Circuit.Cert.Basic - -/-! -Residual-stream bound certificates. - -These certificates record per-coordinate absolute bounds for residual vectors. --/ - -public section - -namespace Nfp - -namespace Circuit - -/-- Certificate payload for per-coordinate residual absolute bounds. -/ -structure ResidualBoundCert (n : Nat) where - /-- Absolute bound per coordinate. -/ - bound : Fin n → Rat - -/-- Properties enforced by `checkResidualBoundCert`. -/ -structure ResidualBoundBounds {n : Nat} (c : ResidualBoundCert n) : Prop where - /-- Residual bounds are nonnegative. -/ - bound_nonneg : ∀ i, 0 ≤ c.bound i - -/-- Boolean checker for residual-bound certificates. -/ -def checkResidualBoundCert {n : Nat} (c : ResidualBoundCert n) : Bool := - finsetAll (Finset.univ : Finset (Fin n)) (fun i => decide (0 ≤ c.bound i)) - -/-- `checkResidualBoundCert` is sound for `ResidualBoundBounds`. 
-/ -theorem checkResidualBoundCert_sound {n : Nat} (c : ResidualBoundCert n) : - checkResidualBoundCert c = true → ResidualBoundBounds c := by - intro hcheck - have hall : - finsetAll (Finset.univ : Finset (Fin n)) (fun i => - decide (0 ≤ c.bound i)) = true := by - simpa [checkResidualBoundCert] using hcheck - have hall' := - (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin n)))).1 hall - refine { bound_nonneg := ?_ } - intro i - have hi := hall' i (by simp) - simpa [decide_eq_true_iff] using hi - -end Circuit - -end Nfp diff --git a/Nfp/Circuit/Cert/ResidualInterval.lean b/Nfp/Circuit/Cert/ResidualInterval.lean deleted file mode 100644 index 7295ecc..0000000 --- a/Nfp/Circuit/Cert/ResidualInterval.lean +++ /dev/null @@ -1,53 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Core.Basic -public import Nfp.Circuit.Cert.Basic - -/-! -Residual-stream interval certificates. - -These certificates record per-coordinate lower/upper bounds for residual vectors. --/ - -public section - -namespace Nfp - -namespace Circuit - -/-- Certificate payload for per-coordinate residual intervals. -/ -structure ResidualIntervalCert (n : Nat) where - /-- Lower bound per coordinate. -/ - lo : Fin n → Rat - /-- Upper bound per coordinate. -/ - hi : Fin n → Rat - -/-- Properties enforced by `checkResidualIntervalCert`. -/ -structure ResidualIntervalBounds {n : Nat} (c : ResidualIntervalCert n) : Prop where - /-- Lower bounds are at most upper bounds. -/ - lo_le_hi : ∀ i, c.lo i ≤ c.hi i - -/-- Boolean checker for residual-interval certificates. -/ -def checkResidualIntervalCert {n : Nat} (c : ResidualIntervalCert n) : Bool := - finsetAll (Finset.univ : Finset (Fin n)) (fun i => decide (c.lo i ≤ c.hi i)) - -/-- `checkResidualIntervalCert` is sound for `ResidualIntervalBounds`. 
-/ -theorem checkResidualIntervalCert_sound {n : Nat} (c : ResidualIntervalCert n) : - checkResidualIntervalCert c = true → ResidualIntervalBounds c := by - intro hcheck - have hall : - finsetAll (Finset.univ : Finset (Fin n)) (fun i => - decide (c.lo i ≤ c.hi i)) = true := by - simpa [checkResidualIntervalCert] using hcheck - have hall' := - (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin n)))).1 hall - refine { lo_le_hi := ?_ } - intro i - have hi := hall' i (by simp) - simpa [decide_eq_true_iff] using hi - -end Circuit - -end Nfp diff --git a/Nfp/IO/Loaders.lean b/Nfp/IO/Loaders.lean deleted file mode 100644 index 8d1259e..0000000 --- a/Nfp/IO/Loaders.lean +++ /dev/null @@ -1,54 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.IO.Parse -public import Nfp.Circuit.Cert.DownstreamLinear -public import Nfp.Circuit.Cert.ResidualBound -public import Nfp.Circuit.Cert.ResidualInterval - -/-! -IO loaders for certificates. --/ - -public section - -namespace Nfp - -namespace IO - -open Nfp.Circuit - -/-- Load a softmax-margin certificate from disk. -/ -def loadSoftmaxMarginCert (path : System.FilePath) : - IO (Except String (Sigma SoftmaxMarginCert)) := do - let data ← IO.FS.readFile path - return Parse.parseSoftmaxMarginCert data - -/-- Load a value-range certificate from disk. -/ -def loadValueRangeCert (path : System.FilePath) : - IO (Except String (Sigma ValueRangeCert)) := do - let data ← IO.FS.readFile path - return Parse.parseValueRangeCert data - -/-- Load a downstream linear certificate from disk. -/ -def loadDownstreamLinearCert (path : System.FilePath) : - IO (Except String DownstreamLinearCert) := do - let data ← IO.FS.readFile path - return Parse.parseDownstreamLinearCert data - -/-- Load a residual-bound certificate from disk. 
-/ -def loadResidualBoundCert (path : System.FilePath) : - IO (Except String (Sigma (fun n => ResidualBoundCert n))) := do - let data ← IO.FS.readFile path - return Parse.parseResidualBoundCert data - -/-- Load a residual-interval certificate from disk. -/ -def loadResidualIntervalCert (path : System.FilePath) : - IO (Except String (Sigma (fun n => ResidualIntervalCert n))) := do - let data ← IO.FS.readFile path - return Parse.parseResidualIntervalCert data - -end IO - -end Nfp diff --git a/Nfp/IO/Parse.lean b/Nfp/IO/Parse.lean index 3bc69fa..7d4b73c 100644 --- a/Nfp/IO/Parse.lean +++ b/Nfp/IO/Parse.lean @@ -3,8 +3,6 @@ module public import Nfp.IO.Parse.Basic -public import Nfp.IO.Parse.Downstream -public import Nfp.IO.Parse.Residual public import Nfp.IO.Parse.SoftmaxMargin public import Nfp.IO.Parse.ValueRange diff --git a/Nfp/IO/Parse/Downstream.lean b/Nfp/IO/Parse/Downstream.lean deleted file mode 100644 index f950aec..0000000 --- a/Nfp/IO/Parse/Downstream.lean +++ /dev/null @@ -1,80 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Circuit.Cert.DownstreamLinear -public import Nfp.IO.Parse.Basic - -/-! -Parse parsing helpers for downstream linear certificates. 
--/ - -public section - -namespace Nfp - -namespace IO - -namespace Parse - -open Nfp.Circuit - -private structure DownstreamLinearParseState where - error : Option Rat - gain : Option Rat - inputBound : Option Rat - -private def initDownstreamLinearState : DownstreamLinearParseState := - { error := none, gain := none, inputBound := none } - -private def parseDownstreamLinearLine (st : DownstreamLinearParseState) - (tokens : List String) : Except String DownstreamLinearParseState := do - match tokens with - | ["error", val] => - if st.error.isSome then - throw "duplicate error entry" - else - return { st with error := some (← parseRat val) } - | ["gain", val] => - if st.gain.isSome then - throw "duplicate gain entry" - else - return { st with gain := some (← parseRat val) } - | ["input-bound", val] => - if st.inputBound.isSome then - throw "duplicate input-bound entry" - else - return { st with inputBound := some (← parseRat val) } - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeDownstreamLinearState (st : DownstreamLinearParseState) : - Except String Circuit.DownstreamLinearCert := do - let error ← - match st.error with - | some v => pure v - | none => throw "missing error entry" - let gain ← - match st.gain with - | some v => pure v - | none => throw "missing gain entry" - let inputBound ← - match st.inputBound with - | some v => pure v - | none => throw "missing input-bound entry" - return { error := error, gain := gain, inputBound := inputBound } - -/-- Parse a downstream linear certificate from a text payload. 
-/ -def parseDownstreamLinearCert (input : String) : - Except String Circuit.DownstreamLinearCert := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - let st0 := initDownstreamLinearState - let st ← tokens.foldlM (fun st t => parseDownstreamLinearLine st t) st0 - finalizeDownstreamLinearState st - -end Parse - -end IO - -end Nfp diff --git a/Nfp/IO/Parse/Residual.lean b/Nfp/IO/Parse/Residual.lean deleted file mode 100644 index 3769a79..0000000 --- a/Nfp/IO/Parse/Residual.lean +++ /dev/null @@ -1,140 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Circuit.Cert.ResidualBound -public import Nfp.Circuit.Cert.ResidualInterval -public import Nfp.IO.Parse.Basic - -/-! -Parse parsing helpers for residual-bound and residual-interval certificates. --/ - -public section - -namespace Nfp - -namespace IO - -namespace Parse - -open Nfp.Circuit - -private structure ResidualBoundParseState (n : Nat) where - bounds : Fin n → Option Rat - -private def initResidualBoundState (n : Nat) : ResidualBoundParseState n := - { bounds := fun _ => none } - -private def setVectorEntry {n : Nat} (bounds : Fin n → Option Rat) - (i : Nat) (v : Rat) : Except String (Fin n → Option Rat) := do - if hi : i < n then - let iFin : Fin n := ⟨i, hi⟩ - match bounds iFin with - | some _ => - throw s!"duplicate bound entry at index {i}" - | none => - let bounds' : Fin n → Option Rat := fun i' => - if i' = iFin then - some v - else - bounds i' - return bounds' - else - throw s!"index out of range: {i}" - -private def parseResidualBoundLine {n : Nat} (st : ResidualBoundParseState n) - (tokens : List String) : Except String (ResidualBoundParseState n) := do - match tokens with - | ["bound", i, val] => - let bounds ← setVectorEntry st.bounds (← parseNat i) (← parseRat val) - return { st with bounds := bounds } - | ["dim", _] => - throw "duplicate dim entry" - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private 
def finalizeResidualBoundState {n : Nat} (st : ResidualBoundParseState n) : - Except String (Circuit.ResidualBoundCert n) := do - if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.bounds i).isSome) then - throw "missing bound entries" - let bound : Fin n → Rat := fun i => - (st.bounds i).getD 0 - return { bound := bound } - -/-- Parse a residual-bound payload from text. -/ -def parseResidualBoundCert (input : String) : - Except String (Sigma (fun n => Circuit.ResidualBoundCert n)) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - match tokens with - | [] => throw "empty residual-bound payload" - | ["dim", nStr] :: rest => - let n ← parseNat nStr - match n with - | 0 => throw "dim must be positive" - | Nat.succ n' => - let dim := Nat.succ n' - let st0 := initResidualBoundState dim - let st ← rest.foldlM (fun st t => parseResidualBoundLine st t) st0 - let cert ← finalizeResidualBoundState st - return ⟨dim, cert⟩ - | _ => throw "expected header 'dim '" - -private structure ResidualIntervalParseState (n : Nat) where - lo : Fin n → Option Rat - hi : Fin n → Option Rat - -private def initResidualIntervalState (n : Nat) : ResidualIntervalParseState n := - { lo := fun _ => none, hi := fun _ => none } - -private def parseResidualIntervalLine {n : Nat} (st : ResidualIntervalParseState n) - (tokens : List String) : Except String (ResidualIntervalParseState n) := do - match tokens with - | ["lo", i, val] => - let lo ← setVectorEntry st.lo (← parseNat i) (← parseRat val) - return { st with lo := lo } - | ["hi", i, val] => - let hi ← setVectorEntry st.hi (← parseNat i) (← parseRat val) - return { st with hi := hi } - | ["dim", _] => - throw "duplicate dim entry" - | _ => - throw s!"unrecognized line: '{String.intercalate " " tokens}'" - -private def finalizeResidualIntervalState {n : Nat} (st : ResidualIntervalParseState n) : - Except String (Circuit.ResidualIntervalCert n) := do - if !finsetAll (Finset.univ : Finset (Fin n)) (fun i 
=> (st.lo i).isSome) then - throw "missing lo entries" - if !finsetAll (Finset.univ : Finset (Fin n)) (fun i => (st.hi i).isSome) then - throw "missing hi entries" - let lo : Fin n → Rat := fun i => - (st.lo i).getD 0 - let hi : Fin n → Rat := fun i => - (st.hi i).getD 0 - return { lo := lo, hi := hi } - -/-- Parse a residual-interval payload from text. -/ -def parseResidualIntervalCert (input : String) : - Except String (Sigma (fun n => Circuit.ResidualIntervalCert n)) := do - let lines := input.splitOn "\n" - let tokens := lines.filterMap cleanTokens - match tokens with - | [] => throw "empty residual-interval payload" - | ["dim", nStr] :: rest => - let n ← parseNat nStr - match n with - | 0 => throw "dim must be positive" - | Nat.succ n' => - let dim := Nat.succ n' - let st0 := initResidualIntervalState dim - let st ← rest.foldlM (fun st t => parseResidualIntervalLine st t) st0 - let cert ← finalizeResidualIntervalState st - return ⟨dim, cert⟩ - | _ => throw "expected header 'dim '" - -end Parse - -end IO - -end Nfp diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean index 61d23a7..1f14211 100644 --- a/Nfp/Model/Gpt2.lean +++ b/Nfp/Model/Gpt2.lean @@ -6,11 +6,11 @@ public import Nfp.Core.Basic public import Nfp.Circuit.Cert.ValueRange /-! -Exact GPT-2 slices for induction certification and downstream bounds. +Exact GPT-2 slices for induction certification. This module holds token embeddings, head projection weights, and per-layer -MLP/LayerNorm parameters used to define `InductionHeadInputs` and downstream -bound computations. +MLP/LayerNorm parameters used to define `InductionHeadInputs` and bound +computations. -/ public section diff --git a/Nfp/Model/InductionCircuit.lean b/Nfp/Model/InductionCircuit.lean index 0ddd03b..18a0e56 100644 --- a/Nfp/Model/InductionCircuit.lean +++ b/Nfp/Model/InductionCircuit.lean @@ -10,7 +10,7 @@ Circuit-level induction prompt specifications. 
These wrappers name the *shifted* prev/active maps that correspond to the canonical induction-head circuit (previous-token head feeding induction head). They are definitional aliases over `InductionPrevSpec*Shift`, but make the -intended mechanistic interpretation explicit for downstream lemmas. +intended mechanistic interpretation explicit for later lemmas. -/ public section diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean index a1360a4..c0352fe 100644 --- a/Nfp/Sound/Induction.lean +++ b/Nfp/Sound/Induction.lean @@ -3,7 +3,6 @@ module public import Nfp.Sound.Induction.Core -public import Nfp.Sound.Induction.EndToEnd public import Nfp.Sound.Induction.HeadOutput public import Nfp.Sound.Induction.LogitDiff public import Nfp.Sound.Induction.OneHot @@ -11,6 +10,6 @@ public import Nfp.Sound.Induction.OneHot /-! Soundness lemmas for induction certificates. -This module re-exports the core definitions, head-output interval predicates, -and logit-diff helpers that operate on explicit certificates. +This module re-exports the core definitions, head-output helpers, and logit-diff +lemmas that operate on explicit certificates. -/ diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean index e620385..d48984a 100644 --- a/Nfp/Sound/Induction/CoreDefs.lean +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -9,7 +9,6 @@ public import Nfp.Circuit.Layers.Softmax public import Nfp.Core.Basic public import Nfp.Model.InductionHead public import Nfp.Bounds.LayerNorm -public import Nfp.Bounds.MatrixNorm public import Nfp.Linear.FinFold /-! diff --git a/Nfp/Sound/Induction/EndToEnd.lean b/Nfp/Sound/Induction/EndToEnd.lean deleted file mode 100644 index 400911c..0000000 --- a/Nfp/Sound/Induction/EndToEnd.lean +++ /dev/null @@ -1,83 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -module - -public import Nfp.Bounds.Transformer -public import Nfp.Sound.Induction.HeadOutput -public import Nfp.Sound.Induction.LogitDiff - -/-! 
-End-to-end induction bounds that combine head certificates with transformer-stack intervals. --/ - -public section - -namespace Nfp - -namespace Sound - -/-- Compose head logit-diff bounds with GPT-2 stack output intervals. -/ -theorem logitDiffLowerBound_end_to_end_gpt2 - {seq dModel dHead numHeads hidden numLayers : Nat} [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (lb : Rat) - (hlb : ∀ q, q ∈ inputs.active → (lb : Real) ≤ headLogitDiff inputs q) - (headCert : Circuit.ResidualIntervalCert dModel) - (hhead : HeadOutputIntervalSound inputs inputs.active headCert) - (eps : Rat) - (layers : Fin numLayers → Model.Gpt2LayerSlice dModel hidden) - (heads : Fin numLayers → Fin numHeads → Model.Gpt2HeadWeights dModel dHead) - (finalLn : Model.Gpt2FinalLayerNorm dModel) - (scores : Fin numLayers → Fin numHeads → Fin seq → Fin seq → Real) - (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < Bounds.sqrtLower eps) - (hactive : inputs.active.Nonempty) : - let bounds := - Bounds.gpt2ResidualIntervalBoundsActive inputs.active hactive eps layers heads finalLn - inputs.embed - let output : Fin seq → Fin dModel → Real := - fun q i => Bounds.transformerStackFinalReal eps finalLn layers heads scores - (fun q i => (inputs.embed q i : Real)) q i - ∀ q, q ∈ inputs.active → - (lb : Real) - - (Bounds.dotIntervalAbsBound inputs.direction - (fun i => bounds.1 i - headCert.hi i) - (fun i => bounds.2 i - headCert.lo i) : Real) ≤ - dotProduct (fun i => (inputs.direction i : Real)) (fun i => output q i) := by - classical - intro bounds output q hq - have hbounds : - ∀ q, q ∈ inputs.active → ∀ i, - (bounds.1 i : Real) ≤ output q i ∧ output q i ≤ (bounds.2 i : Real) := by - simpa [bounds, output] using - (Bounds.gpt2ResidualIntervalBoundsActive_spec - (active := inputs.active) - (hactive := hactive) - (eps := eps) - (layers := layers) - (heads := heads) - (finalLn := finalLn) - (scores := scores) - (embed := inputs.embed) - (hne := hne) - (heps := heps) - (hsqrt := hsqrt)) 
- have hhead_out := hhead.output_mem - have h := - logitDiffLowerBound_with_output_intervals - (inputs := inputs) - (lb := lb) - (hlb := hlb) - (output := output) - (outLo := bounds.1) - (outHi := bounds.2) - (hout := hbounds) - (headLo := headCert.lo) - (headHi := headCert.hi) - (hhead := hhead_out) - q - hq - simpa [bounds] using h - -end Sound - -end Nfp diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean index 2723975..9035efc 100644 --- a/Nfp/Sound/Induction/HeadOutput.lean +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -2,11 +2,10 @@ module -public import Nfp.Circuit.Cert.ResidualInterval public import Nfp.Sound.Induction.CoreDefs /-! -Head-output interval certificates for induction heads. +Head-output definitions for induction heads. -/ public section @@ -53,19 +52,6 @@ theorem headOutput_def (inputs : Model.InductionHeadInputs seq dModel dHead) headOutputWithScores (scoresRealOfInputs inputs) inputs q i := by simp [headOutput] -/-- Soundness predicate for head-output interval bounds. -/ -structure HeadOutputIntervalSound [NeZero seq] - (inputs : Model.InductionHeadInputs seq dModel dHead) - (active : Finset (Fin seq)) - (c : Circuit.ResidualIntervalCert dModel) : Prop where - /-- Interval bounds are ordered coordinatewise. -/ - bounds : Circuit.ResidualIntervalBounds c - /-- Active-query outputs lie inside the interval bounds. -/ - output_mem : - ∀ q, q ∈ active → ∀ i, - (c.lo i : Real) ≤ headOutput inputs q i ∧ - headOutput inputs q i ≤ (c.hi i : Real) - end end Sound diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean index 875df15..0f6c987 100644 --- a/Nfp/Sound/Induction/LogitDiff.lean +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -7,7 +7,6 @@ public import Mathlib.Data.List.MinMax public import Mathlib.Data.Vector.Basic public import Nfp.Circuit.Cert.LogitDiff public import Nfp.Bounds.Cache -public import Nfp.Bounds.MatrixNorm.Interval public import Nfp.Sound.Induction.HeadOutput /-! 
@@ -947,7 +946,7 @@ theorem logitDiffLowerBoundFromCertBest_le end WithNeZero -/-! End-to-end lower bounds from head certificates plus residual intervals. -/ +/-! Head-output identities. -/ /-- The head logit-diff equals the direction dot product of the head output. -/ theorem headLogitDiff_eq_direction_dot_headOutput @@ -1005,125 +1004,6 @@ theorem headLogitDiff_eq_direction_dot_headOutput (fun i => headOutput inputs q i) := by simpa [dir] using hsum.symm -/-- Combine a head logit-diff bound with residual interval bounds. -/ -theorem logitDiffLowerBound_with_residual - (inputs : Model.InductionHeadInputs seq dModel dHead) - (lb : Rat) - (hlb : ∀ q, q ∈ inputs.active → (lb : Real) ≤ headLogitDiff inputs q) - (residual : Fin seq → Fin dModel → Real) - (lo hi : Fin dModel → Rat) - (hres : ∀ q, q ∈ inputs.active → ∀ i, - (lo i : Real) ≤ residual q i ∧ residual q i ≤ (hi i : Real)) : - ∀ q, q ∈ inputs.active → - (lb : Real) - (Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ - dotProduct (fun i => (inputs.direction i : Real)) - (fun i => headOutput inputs q i + residual q i) := by - intro q hq - have hhead := hlb q hq - have hres' : - |dotProduct (fun i => (inputs.direction i : Real)) (residual q)| ≤ - (Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) := by - have hlo : ∀ i, (lo i : Real) ≤ residual q i := fun i => (hres q hq i).1 - have hhi : ∀ i, residual q i ≤ (hi i : Real) := fun i => (hres q hq i).2 - simpa using - (Bounds.abs_dotProduct_le_dotIntervalAbsBound_real - (v := inputs.direction) (lo := lo) (hi := hi) (x := residual q) hlo hhi) - have hres_lower : - -(Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ - dotProduct (fun i => (inputs.direction i : Real)) (residual q) := by - exact (abs_le.mp hres').1 - have hsum : - (lb : Real) - (Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ - headLogitDiff inputs q + - dotProduct (fun i => (inputs.direction i : Real)) (residual q) := by - have hsum' : (lb : Real) + 
-(Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ - headLogitDiff inputs q + - dotProduct (fun i => (inputs.direction i : Real)) (residual q) := - add_le_add hhead hres_lower - simpa [sub_eq_add_neg] using hsum' - calc - (lb : Real) - (Bounds.dotIntervalAbsBound inputs.direction lo hi : Real) ≤ - headLogitDiff inputs q + - dotProduct (fun i => (inputs.direction i : Real)) (residual q) := hsum - _ = - dotProduct (fun i => (inputs.direction i : Real)) - (fun i => headOutput inputs q i + residual q i) := by - have hdot : - dotProduct (fun i => (inputs.direction i : Real)) - (fun i => headOutput inputs q i + residual q i) = - dotProduct (fun i => (inputs.direction i : Real)) - (fun i => headOutput inputs q i) + - dotProduct (fun i => (inputs.direction i : Real)) (residual q) := by - simpa using - (Linear.dotProduct_add_right - (x := fun i => (inputs.direction i : Real)) - (y := fun i => headOutput inputs q i) - (z := residual q)) - simp [headLogitDiff_eq_direction_dot_headOutput, hdot] - -/-- Combine a head logit-diff bound with intervals on head output and a downstream output. 
-/ -theorem logitDiffLowerBound_with_output_intervals - (inputs : Model.InductionHeadInputs seq dModel dHead) - (lb : Rat) - (hlb : ∀ q, q ∈ inputs.active → (lb : Real) ≤ headLogitDiff inputs q) - (output : Fin seq → Fin dModel → Real) - (outLo outHi : Fin dModel → Rat) - (hout : ∀ q, q ∈ inputs.active → ∀ i, - (outLo i : Real) ≤ output q i ∧ output q i ≤ (outHi i : Real)) - (headLo headHi : Fin dModel → Rat) - (hhead : ∀ q, q ∈ inputs.active → ∀ i, - (headLo i : Real) ≤ headOutput inputs q i ∧ - headOutput inputs q i ≤ (headHi i : Real)) : - ∀ q, q ∈ inputs.active → - (lb : Real) - - (Bounds.dotIntervalAbsBound inputs.direction - (fun i => outLo i - headHi i) (fun i => outHi i - headLo i) : Real) ≤ - dotProduct (fun i => (inputs.direction i : Real)) (fun i => output q i) := by - intro q hq - let residual : Fin seq → Fin dModel → Real := - fun q i => output q i - headOutput inputs q i - let lo : Fin dModel → Rat := fun i => outLo i - headHi i - let hi : Fin dModel → Rat := fun i => outHi i - headLo i - have hres : ∀ q, q ∈ inputs.active → ∀ i, - (lo i : Real) ≤ residual q i ∧ residual q i ≤ (hi i : Real) := by - intro q hq i - have hout_q := hout q hq i - have hhead_q := hhead q hq i - have hlow : - (outLo i : Real) - (headHi i : Real) ≤ - output q i - headOutput inputs q i := by - exact sub_le_sub hout_q.1 hhead_q.2 - have hhigh : - output q i - headOutput inputs q i ≤ - (outHi i : Real) - (headLo i : Real) := by - exact sub_le_sub hout_q.2 hhead_q.1 - constructor - · simpa [lo, residual, ratToReal_sub] using hlow - · simpa [hi, residual, ratToReal_sub] using hhigh - have hbound := - logitDiffLowerBound_with_residual - (inputs := inputs) - (lb := lb) - (hlb := hlb) - (residual := residual) - (lo := lo) - (hi := hi) - hres - q - hq - have hdot : - dotProduct (fun i => (inputs.direction i : Real)) - (fun i => headOutput inputs q i + residual q i) = - dotProduct (fun i => (inputs.direction i : Real)) - (fun i => output q i) := by - refine Finset.sum_congr rfl ?_ - 
intro i _ - have hsum : - headOutput inputs q i + residual q i = output q i := by - simp [residual, sub_eq_add_neg, add_left_comm] - simp [hsum] - simpa [lo, hi, hdot] using hbound - end LogitDiffLowerBound end Sound diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index b02e687..d1088ec 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -14,14 +14,11 @@ It is intentionally brief and focused on the soundness boundary. unembedding columns represent token logits. - The active set is user-supplied (or defaulted by the parser); bounds only hold for `q ∈ active`. -- Residual and downstream bounds are provided as explicit certificates; there is no verified - end-to-end model derivation of these bounds inside Lean. - Performance: checking large certificates can be expensive for long sequences. ## Remaining work - Prove or verify that `prev`, `active`, and `direction` are derived from token-level semantics. - Add a verified extraction pipeline from model weights to explicit certificates. -- Tighten residual and downstream interval bounds to avoid vacuity. -- Extend the bridge from certificates to full circuit/model semantics and (eventually) to - end-to-end transformer claims. +- Extend the bridge from head-level certificates to full circuit/model semantics and + (eventually) to end-to-end transformer claims. diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 43724bc..2e53363 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -13,11 +13,8 @@ induction heads, and spell out the scope and limitations of that claim. - `logitDiffLowerBoundAt` plus `logitDiffLowerBoundAt_le` give a certified lower bound on the logit-diff contribution derived from the certificate’s values (`Nfp/Circuit/Cert/LogitDiff.lean`). 
-- `headLogitDiff_eq_direction_dot_headOutput`, `logitDiffLowerBound_with_residual`, - and `logitDiffLowerBound_with_output_intervals` compose head-level logit-diff - bounds with output intervals (`Nfp/Sound/Induction/LogitDiff.lean`). -- `logitDiffLowerBound_end_to_end_gpt2` instantiates the composition for GPT-2 - stack outputs (`Nfp/Sound/Induction/EndToEnd.lean`). +- `headLogitDiff_eq_direction_dot_headOutput` connects the logit-diff definition + to head-output semantics (`Nfp/Sound/Induction/LogitDiff.lean`). ## Mechanistic mapping (Transformer Circuits) @@ -48,8 +45,6 @@ Key assumptions and limitations: scripts; Lean does not (yet) verify their derivation from token-level semantics. - The active set can be strict; bounds only hold for `q ∈ active`, not all positions. - The direction metadata assumes the unembedding columns encode the model’s logit map. -- End-to-end claims rely on external residual/downstream interval certificates; the - current checker only verifies those certificates once provided. ## Conclusion @@ -62,4 +57,3 @@ along a specified direction, conditional on an explicit certificate. - Add a verified extraction pipeline from model weights to explicit certificates. - Prove that `prev`, `active`, and `direction` correspond to token-level semantics. -- Tighten residual/downstream interval bounds to strengthen end-to-end claims. diff --git a/scripts/build_downstream_linear_cert.py b/scripts/build_downstream_linear_cert.py deleted file mode 100644 index d1640cb..0000000 --- a/scripts/build_downstream_linear_cert.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Build a downstream linear certificate from externally computed bounds. - -This script is untrusted: it only formats rational inputs into the downstream -linear certificate format expected by the Lean checker. 
- -Usage: - python scripts/build_downstream_linear_cert.py \ - --output reports/gpt2_downstream.cert \ - --gain 3/2 \ - --input-bound 5/4 - -Optional: - --error 15/8 # override gain * input-bound -""" - -import argparse -from fractions import Fraction -from pathlib import Path - - -def parse_rat(raw: str) -> Fraction: - if "/" in raw: - num, den = raw.split("/", 1) - return Fraction(int(num.strip()), int(den.strip())) - return Fraction(int(raw.strip()), 1) - - -def rat_to_str(q: Fraction) -> str: - if q.denominator == 1: - return str(q.numerator) - return f"{q.numerator}/{q.denominator}" - - -def main() -> None: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--output", required=True, help="Path to write certificate") - parser.add_argument("--gain", required=True, help="Nonnegative gain bound (Rat)") - parser.add_argument("--input-bound", required=True, - help="Nonnegative input bound (Rat)") - parser.add_argument("--error", - help="Optional error override (Rat). 
Defaults to gain * input-bound.") - args = parser.parse_args() - - gain = parse_rat(args.gain) - input_bound = parse_rat(args.input_bound) - if gain < 0 or input_bound < 0: - raise SystemExit("gain and input-bound must be nonnegative") - error = parse_rat(args.error) if args.error else gain * input_bound - if error < 0: - raise SystemExit("error must be nonnegative") - - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w", encoding="ascii") as f: - f.write(f"error {rat_to_str(error)}\n") - f.write(f"gain {rat_to_str(gain)}\n") - f.write(f"input-bound {rat_to_str(input_bound)}\n") - print(f"Wrote downstream certificate to {output_path}") - - -if __name__ == "__main__": - main() diff --git a/scripts/build_residual_bound_cert.py b/scripts/build_residual_bound_cert.py deleted file mode 100644 index 7f54c51..0000000 --- a/scripts/build_residual_bound_cert.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Build a residual-bound certificate from a GPT-2 forward pass. - -This script is untrusted. It computes per-coordinate absolute bounds by -taking maxima over a fixed input sequence (optionally restricted to active -positions from a softmax-margin certificate). The resulting bounds are -rounded up to rationals for Lean-side checking. - -Usage: - uv run scripts/build_residual_bound_cert.py \ - --output reports/gpt2_residual.bound \ - --seq 32 --pattern-length 16 \ - --scores reports/gpt2_induction.cert - -Optional: - --tokens tokens.txt # whitespace-separated token ids - --random-pattern --seed 0 - --decimals 6 --safety 1e-6 -""" - -import argparse -import math -from fractions import Fraction -from pathlib import Path - -import numpy as np - -try: - import torch - from transformers import GPT2Model -except ImportError: - raise SystemExit( - "Missing dependencies. 
Install with: uv add transformers torch" - ) - - -def rat_to_str(q: Fraction) -> str: - if q.denominator == 1: - return str(q.numerator) - return f"{q.numerator}/{q.denominator}" - - -def build_tokens(seq: int, pattern_len: int, random_pattern: bool, seed: int) -> np.ndarray: - if random_pattern: - rng = np.random.default_rng(seed) - pattern = rng.integers(1000, 30000, size=pattern_len, endpoint=False) - else: - pattern = np.arange(pattern_len) - repeats = (seq // pattern_len) + 1 - return np.tile(pattern, repeats)[:seq] - - -def parse_tokens(path: Path) -> np.ndarray: - raw = path.read_text(encoding="ascii") - tokens = [int(tok) for tok in raw.split() if tok.strip()] - if not tokens: - raise SystemExit(f"no tokens found in {path}") - return np.array(tokens, dtype=np.int64) - - -def parse_active_positions(path: Path) -> tuple[int | None, list[int]]: - seq = None - active: list[int] = [] - for line in path.read_text(encoding="ascii").splitlines(): - line = line.strip() - if not line or line.startswith("#"): - continue - parts = line.split() - if parts[0] == "seq" and len(parts) >= 2: - seq = int(parts[1]) - elif parts[0] == "active" and len(parts) >= 2: - active.append(int(parts[1])) - return seq, active - - -def ceil_rat(x: float, decimals: int, safety: float) -> Fraction: - scale = 10 ** decimals - scaled = abs(x) * (1.0 + safety) * scale - return Fraction(int(math.ceil(scaled)), scale) - - -def main() -> None: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--output", required=True, help="Path to write certificate") - parser.add_argument("--seq", type=int, default=32, help="Sequence length") - parser.add_argument("--pattern-length", type=int, default=16, help="Pattern length") - parser.add_argument("--random-pattern", action="store_true", - help="Use random token pattern") - parser.add_argument("--seed", type=int, default=0, help="RNG seed for random pattern") - parser.add_argument("--tokens", help="Optional path to 
whitespace-separated tokens") - parser.add_argument("--scores", help="Optional softmax-margin certificate for active queries") - parser.add_argument("--model", default="gpt2", help="HuggingFace model name") - parser.add_argument("--device", default="cpu", help="Torch device") - parser.add_argument("--decimals", type=int, default=6, - help="Decimal rounding for rationals (ceil)") - parser.add_argument("--safety", type=float, default=1e-6, - help="Relative safety slack added before rounding") - args = parser.parse_args() - - if args.seq <= 0: - raise SystemExit("seq must be positive") - if args.decimals < 0: - raise SystemExit("decimals must be nonnegative") - if args.safety < 0: - raise SystemExit("safety must be nonnegative") - - if args.tokens: - tokens = parse_tokens(Path(args.tokens)) - seq = len(tokens) - else: - seq = args.seq - tokens = build_tokens(seq, args.pattern_length, args.random_pattern, args.seed) - - positions = list(range(seq)) - if args.scores: - cert_seq, active = parse_active_positions(Path(args.scores)) - if cert_seq is not None and cert_seq != seq: - raise SystemExit(f"seq mismatch: scores={cert_seq} tokens={seq}") - if active: - positions = active - - model = GPT2Model.from_pretrained(args.model) - model.to(args.device) - model.eval() - input_ids = torch.tensor(tokens, dtype=torch.long, device=args.device).unsqueeze(0) - with torch.no_grad(): - outputs = model(input_ids) - hidden = outputs.last_hidden_state.squeeze(0).cpu().numpy() - - if hidden.shape[0] != seq: - raise SystemExit(f"hidden state seq mismatch: {hidden.shape[0]} vs {seq}") - - chosen = hidden[positions] - max_abs = np.max(np.abs(chosen), axis=0) - bounds = [ceil_rat(float(val), args.decimals, args.safety) for val in max_abs] - - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w", encoding="ascii") as f: - f.write(f"dim {len(bounds)}\n") - for i, bound in enumerate(bounds): - f.write(f"bound {i} 
{rat_to_str(bound)}\n") - - print(f"Wrote residual-bound certificate to {output_path}") - - -if __name__ == "__main__": - main() diff --git a/scripts/build_residual_interval_cert.py b/scripts/build_residual_interval_cert.py deleted file mode 100644 index 51896d2..0000000 --- a/scripts/build_residual_interval_cert.py +++ /dev/null @@ -1,209 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Build a residual-interval certificate from a GPT-2 forward pass. - -This script is untrusted. It computes per-coordinate min/max bounds by -taking extrema over a fixed input sequence (optionally restricted to active -positions from a softmax-margin certificate). The resulting intervals are -expanded slightly and rounded outwards to rationals for Lean-side checking. - -Usage: - uv run scripts/build_residual_interval_cert.py \ - --output reports/gpt2_residual.interval \ - --seq 32 --pattern-length 16 \ - --scores reports/gpt2_induction.cert - -Optional: - --tokens tokens.txt # whitespace-separated token ids - --nfpt model.nfpt # read tokens from binary model - --random-pattern --seed 0 - --decimals 6 --safety 1e-6 -""" - -import argparse -import math -from fractions import Fraction -from pathlib import Path - -import numpy as np - -try: - import torch - from transformers import GPT2Model -except ImportError: - raise SystemExit( - "Missing dependencies. 
Install with: uv add transformers torch" - ) - - -def rat_to_str(q: Fraction) -> str: - if q.denominator == 1: - return str(q.numerator) - return f"{q.numerator}/{q.denominator}" - - -def build_tokens(seq: int, pattern_len: int, random_pattern: bool, seed: int) -> np.ndarray: - if random_pattern: - rng = np.random.default_rng(seed) - pattern = rng.integers(1000, 30000, size=pattern_len, endpoint=False) - else: - pattern = np.arange(pattern_len) - repeats = (seq // pattern_len) + 1 - return np.tile(pattern, repeats)[:seq] - - -def parse_tokens(path: Path) -> np.ndarray: - raw = path.read_text(encoding="ascii") - tokens = [int(tok) for tok in raw.split() if tok.strip()] - if not tokens: - raise SystemExit(f"no tokens found in {path}") - return np.array(tokens, dtype=np.int64) - - -def parse_tokens_from_nfpt(path: Path) -> np.ndarray: - header: dict[str, str] = {} - with path.open("rb") as f: - while True: - line = f.readline() - if not line: - raise SystemExit("unexpected EOF while reading header") - text = line.decode("ascii").strip() - if text == "BINARY_START": - break - if "=" in text: - key, value = text.split("=", 1) - header[key.strip()] = value.strip() - seq_len_raw = header.get("seq_len") - if seq_len_raw is None: - raise SystemExit("header missing seq_len") - seq_len = int(seq_len_raw) - token_bytes = f.read(seq_len * 4) - if len(token_bytes) != seq_len * 4: - raise SystemExit("unexpected EOF while reading tokens") - tokens = np.frombuffer(token_bytes, dtype=" tuple[int | None, list[int]]: - seq = None - active: list[int] = [] - for line in path.read_text(encoding="ascii").splitlines(): - line = line.strip() - if not line or line.startswith("#"): - continue - parts = line.split() - if parts[0] == "seq" and len(parts) >= 2: - seq = int(parts[1]) - elif parts[0] == "active" and len(parts) >= 2: - active.append(int(parts[1])) - return seq, active - - -def expand_lo(val: float, safety: float) -> float: - slack = safety * max(1.0, abs(val)) - return val - slack 
- - -def expand_hi(val: float, safety: float) -> float: - slack = safety * max(1.0, abs(val)) - return val + slack - - -def floor_rat(val: float, decimals: int) -> Fraction: - scale = 10 ** decimals - return Fraction(int(math.floor(val * scale)), scale) - - -def ceil_rat(val: float, decimals: int) -> Fraction: - scale = 10 ** decimals - return Fraction(int(math.ceil(val * scale)), scale) - - -def main() -> None: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--output", required=True, help="Path to write certificate") - parser.add_argument("--seq", type=int, default=32, help="Sequence length") - parser.add_argument("--pattern-length", type=int, default=16, help="Pattern length") - parser.add_argument("--random-pattern", action="store_true", - help="Use random token pattern") - parser.add_argument("--seed", type=int, default=0, help="RNG seed for random pattern") - parser.add_argument("--tokens", help="Optional path to whitespace-separated tokens") - parser.add_argument("--nfpt", help="Optional .nfpt file to read tokens from") - parser.add_argument("--scores", help="Optional softmax-margin certificate for active queries") - parser.add_argument("--model", default="gpt2", help="HuggingFace model name") - parser.add_argument("--device", default="cpu", help="Torch device") - parser.add_argument("--decimals", type=int, default=6, - help="Decimal rounding for rationals (outward)") - parser.add_argument("--safety", type=float, default=1e-6, - help="Relative safety slack added before rounding") - args = parser.parse_args() - - if args.seq <= 0: - raise SystemExit("seq must be positive") - if args.decimals < 0: - raise SystemExit("decimals must be nonnegative") - if args.safety < 0: - raise SystemExit("safety must be nonnegative") - - if args.nfpt: - tokens = parse_tokens_from_nfpt(Path(args.nfpt)) - seq = len(tokens) - elif args.tokens: - tokens = parse_tokens(Path(args.tokens)) - seq = len(tokens) - else: - seq = args.seq - tokens = 
build_tokens(seq, args.pattern_length, args.random_pattern, args.seed) - - positions = list(range(seq)) - if args.scores: - cert_seq, active = parse_active_positions(Path(args.scores)) - if cert_seq is not None and cert_seq != seq: - raise SystemExit(f"seq mismatch: scores={cert_seq} tokens={seq}") - if active: - positions = active - - model = GPT2Model.from_pretrained(args.model) - model.to(args.device) - model.eval() - input_ids = torch.tensor(tokens, dtype=torch.long, device=args.device).unsqueeze(0) - with torch.no_grad(): - outputs = model(input_ids) - hidden = outputs.last_hidden_state.squeeze(0).cpu().numpy() - - if hidden.shape[0] != seq: - raise SystemExit(f"hidden state seq mismatch: {hidden.shape[0]} vs {seq}") - - chosen = hidden[positions] - mins = np.min(chosen, axis=0) - maxs = np.max(chosen, axis=0) - - lo_bounds = [] - hi_bounds = [] - for lo_val, hi_val in zip(mins.tolist(), maxs.tolist(), strict=True): - lo_adj = expand_lo(float(lo_val), args.safety) - hi_adj = expand_hi(float(hi_val), args.safety) - lo_rat = floor_rat(lo_adj, args.decimals) - hi_rat = ceil_rat(hi_adj, args.decimals) - if lo_rat > hi_rat: - lo_rat, hi_rat = hi_rat, lo_rat - lo_bounds.append(lo_rat) - hi_bounds.append(hi_rat) - - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w", encoding="ascii") as f: - f.write(f"dim {len(lo_bounds)}\n") - for i, (lo, hi) in enumerate(zip(lo_bounds, hi_bounds, strict=True)): - f.write(f"lo {i} {rat_to_str(lo)}\n") - f.write(f"hi {i} {rat_to_str(hi)}\n") - - print(f"Wrote residual-interval certificate to {output_path}") - - -if __name__ == "__main__": - main() From b23ec0c54e7a43d239da26455ce92a42ffe8c4b7 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 09:31:16 +0100 Subject: [PATCH 224/244] Trim model-level wording from parity docs --- CLAIMS.md | 2 +- SOUNDNESS_LIMITATIONS.md | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/CLAIMS.md 
b/CLAIMS.md index 44f453d..bb6cdad 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -34,5 +34,5 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri ## Not yet proven - A verified extraction pipeline from model weights to explicit certificates. -- End-to-end claims about GPT-2 logits or Jacobians derived from certificates. +- Model-level claims about logits or Jacobians derived from certificates. - A full bridge from explicit head certificates to complete model semantics. diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index d1088ec..0377f38 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -9,7 +9,7 @@ It is intentionally brief and focused on the soundness boundary. run model evaluation. - Induction certificates are **head-level** (softmax-margin + value-interval + logit-diff lower bound) and conditional on the supplied `prev`, `active`, and `direction` inputs. They do **not** - yet imply end-to-end model behavior. + yet imply full model behavior. - Direction metadata (`direction-target`, `direction-negative`) is untrusted and assumes that the unembedding columns represent token logits. - The active set is user-supplied (or defaulted by the parser); bounds only hold for @@ -20,5 +20,4 @@ It is intentionally brief and focused on the soundness boundary. - Prove or verify that `prev`, `active`, and `direction` are derived from token-level semantics. - Add a verified extraction pipeline from model weights to explicit certificates. -- Extend the bridge from head-level certificates to full circuit/model semantics and - (eventually) to end-to-end transformer claims. +- Extend the bridge from head-level certificates to full circuit/model semantics. 
From 872a584c7d0d88208d034f8f8043950174282bdc Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 09:35:05 +0100 Subject: [PATCH 225/244] Remove .nfpt format dependencies --- CLAIMS.md | 6 +- README.md | 13 +- .../build_gpt2_induction_cert_from_binary.py | 431 ------------------ scripts/convert_text_fixture_to_binary.py | 273 ----------- scripts/ensure_gelu_kind.py | 86 ---- scripts/export_gpt2.py | 239 ---------- scripts/generate_induction_data.py | 148 ------ scripts/generate_rigorous_induction.py | 232 ---------- tests/fixtures/tiny_sound_binary.nfpt | Bin 1387 -> 0 bytes tests/fixtures/tiny_sound_input.nfpt | 17 - tests/fixtures/tiny_sound_model.nfpt | 62 --- 11 files changed, 3 insertions(+), 1504 deletions(-) delete mode 100644 scripts/build_gpt2_induction_cert_from_binary.py delete mode 100644 scripts/convert_text_fixture_to_binary.py delete mode 100644 scripts/ensure_gelu_kind.py delete mode 100644 scripts/export_gpt2.py delete mode 100644 scripts/generate_induction_data.py delete mode 100644 scripts/generate_rigorous_induction.py delete mode 100644 tests/fixtures/tiny_sound_binary.nfpt delete mode 100644 tests/fixtures/tiny_sound_input.nfpt delete mode 100644 tests/fixtures/tiny_sound_model.nfpt diff --git a/CLAIMS.md b/CLAIMS.md index bb6cdad..118bb05 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -25,10 +25,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri ## Untrusted / heuristic -- Python helpers that generate explicit induction-head certificates from GPT-2 weights or - `.nfpt` files: `scripts/build_gpt2_induction_cert.py`, - `scripts/build_gpt2_induction_cert_from_binary.py`. -- Exporters and dataset generators for `.nfpt` model files. +- Python helpers that generate explicit induction-head certificates from GPT-2 weights: + `scripts/build_gpt2_induction_cert.py`. - Any choice of prompts, directions, or candidate heads used by certificate generators. 
## Not yet proven diff --git a/README.md b/README.md index 1763292..597fed8 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,7 @@ lake exe nfp induction --help ``` Current subcommands are limited to **induction certificate checking**. The CLI does **not** run a -full model forward pass and does **not** ingest `.nfpt` weights directly; weight ingestion is done -by untrusted helper scripts (see below). +full model forward pass; certificate generation is done by untrusted helper scripts (see below). ## Module map @@ -59,16 +58,6 @@ Optional direction metadata: --direction-target --direction-negative ``` -If you already have an `NFP_BINARY_V1` model file: - -```bash -python scripts/build_gpt2_induction_cert_from_binary.py \ - --model models/gpt2_rigorous.nfpt \ - --layer 5 --head 1 \ - --direction-target 1 --direction-negative 2 \ - --output reports/gpt2_induction.cert -``` - ### Verify a head certificate (trusted checker) ```bash diff --git a/scripts/build_gpt2_induction_cert_from_binary.py b/scripts/build_gpt2_induction_cert_from_binary.py deleted file mode 100644 index c7862d3..0000000 --- a/scripts/build_gpt2_induction_cert_from_binary.py +++ /dev/null @@ -1,431 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Build an induction-head certificate from an NFP_BINARY_V1 model. - -This is untrusted and uses floating-point arithmetic to produce a rational -induction-head certificate compatible with `nfp induction certify`. 
-""" - -from __future__ import annotations - -import argparse -import math -import struct -from fractions import Fraction -from pathlib import Path -from typing import Dict, Tuple - -import numpy as np - - -def rat_from_float(x: float, decimals: int) -> Fraction: - scale = 10 ** decimals - return Fraction(int(round(x * scale)), scale) - - -def rat_to_str(q: Fraction) -> str: - if q.denominator == 1: - return str(q.numerator) - return f"{q.numerator}/{q.denominator}" - - -def parse_header(f) -> Dict[str, str]: - header: Dict[str, str] = {} - magic = f.readline().decode("ascii").strip() - if magic != "NFP_BINARY_V1": - raise SystemExit(f"Unsupported magic header: {magic}") - while True: - line = f.readline() - if line == b"": - raise SystemExit("Unexpected EOF while reading header.") - text = line.decode("ascii").strip() - if text == "BINARY_START": - break - if "=" in text: - key, value = text.split("=", 1) - header[key.strip()] = value.strip() - return header - - -def read_i32(f, count: int) -> np.ndarray: - raw = f.read(count * 4) - if len(raw) != count * 4: - raise SystemExit("Unexpected EOF while reading int32 payload.") - return np.frombuffer(raw, dtype=" np.ndarray: - raw = f.read(count * 8) - if len(raw) != count * 8: - raise SystemExit("Unexpected EOF while reading float64 payload.") - return np.frombuffer(raw, dtype=" None: - offset = count * 8 - f.seek(offset, 1) - - -def build_prev(tokens: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - prev = np.zeros_like(tokens) - active = np.zeros_like(tokens, dtype=bool) - last_seen: Dict[int, int] = {} - for idx, tok in enumerate(tokens.tolist()): - if idx == 0: - prev[idx] = 0 - active[idx] = False - else: - if tok in last_seen: - prev[idx] = last_seen[tok] - active[idx] = True - else: - prev[idx] = 0 - active[idx] = False - last_seen[tok] = idx - return prev, active - - -def read_head_weights( - f, - num_layers: int, - num_heads: int, - model_dim: int, - head_dim: int, - hidden_dim: int, - layer: int, - head: int, 
-) -> Tuple[ - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, -]: - target = (layer, head) - wq = wk = wv = wo = None - bq = bk = bv = None - attn_bias = ln1_gamma = ln1_beta = None - for layer_idx in range(num_layers): - for head_idx in range(num_heads): - wq_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bq_block = read_f64(f, head_dim) - wk_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bk_block = read_f64(f, head_dim) - wv_block = read_f64(f, model_dim * head_dim).reshape(model_dim, head_dim) - bv_block = read_f64(f, head_dim) - wo_block = read_f64(f, head_dim * model_dim).reshape(head_dim, model_dim) - if (layer_idx, head_idx) == target: - wq = wq_block - wk = wk_block - wv = wv_block - wo = wo_block - bq = bq_block - bk = bk_block - bv = bv_block - attn_bias_block = read_f64(f, model_dim) - skip_f64(f, model_dim * hidden_dim) - skip_f64(f, hidden_dim) - skip_f64(f, hidden_dim * model_dim) - skip_f64(f, model_dim) - ln1_gamma_block = read_f64(f, model_dim) - ln1_beta_block = read_f64(f, model_dim) - skip_f64(f, model_dim) - skip_f64(f, model_dim) - if layer_idx == layer: - attn_bias = attn_bias_block - ln1_gamma = ln1_gamma_block - ln1_beta = ln1_beta_block - if ( - wq is None - or wk is None - or wv is None - or wo is None - or bq is None - or bk is None - or bv is None - or attn_bias is None - or ln1_gamma is None - or ln1_beta is None - ): - raise SystemExit("Failed to locate head weights.") - return wq, bq, wk, bk, wv, bv, wo, attn_bias, ln1_gamma, ln1_beta - - -def read_unembed_columns( - f, - start: int, - model_dim: int, - vocab_size: int, - target: int, - negative: int, -) -> Tuple[np.ndarray, np.ndarray]: - row_bytes = vocab_size * 8 - col_t = np.zeros(model_dim, dtype=np.float64) - col_n = np.zeros(model_dim, dtype=np.float64) - for row in range(model_dim): - base = start + row * row_bytes - f.seek(base + 
target * 8) - col_t[row] = struct.unpack(" np.ndarray: - mean = x.mean(axis=1, keepdims=True) - var = ((x - mean) ** 2).mean(axis=1, keepdims=True) - x_hat = (x - mean) / np.sqrt(var + eps) - return x_hat * gamma + beta - - -def softmax(scores: np.ndarray) -> np.ndarray: - shift = scores - scores.max(axis=1, keepdims=True) - exp = np.exp(shift) - return exp / exp.sum(axis=1, keepdims=True) - - -def write_induction_cert(path: Path, seq: int, prev: np.ndarray, - scores_rat, weights_rat, eps: Fraction, - margin: Fraction, active_positions, - eps_at, weight_bound_at, vals_rat, - direction_target: int | None, - direction_negative: int | None) -> None: - lo = min(vals_rat) - hi = max(vals_rat) - with path.open("w", encoding="ascii") as f: - f.write(f"seq {seq}\n") - if direction_target is not None and direction_negative is not None: - f.write(f"direction-target {direction_target}\n") - f.write(f"direction-negative {direction_negative}\n") - f.write(f"eps {rat_to_str(eps)}\n") - f.write(f"margin {rat_to_str(margin)}\n") - for q in active_positions: - f.write(f"active {q}\n") - for q, k in enumerate(prev.tolist()): - f.write(f"prev {q} {k}\n") - for q in range(seq): - for k in range(seq): - f.write(f"score {q} {k} {rat_to_str(scores_rat[q][k])}\n") - for q in range(seq): - for k in range(seq): - f.write(f"weight {q} {k} {rat_to_str(weights_rat[q][k])}\n") - for q in range(seq): - f.write(f"eps-at {q} {rat_to_str(eps_at[q])}\n") - for q in range(seq): - for k in range(seq): - f.write(f"weight-bound {q} {k} {rat_to_str(weight_bound_at[q][k])}\n") - f.write(f"lo {rat_to_str(lo)}\n") - f.write(f"hi {rat_to_str(hi)}\n") - for k, val in enumerate(vals_rat): - val_str = rat_to_str(val) - f.write(f"val {k} {val_str}\n") - f.write(f"val-lo {k} {val_str}\n") - f.write(f"val-hi {k} {val_str}\n") - - -def write_value_range(path: Path, seq: int, values, decimals: int, - direction_target=None, direction_negative=None) -> None: - vals_rat = [rat_from_float(float(values[k]), decimals) for 
k in range(seq)] - lo = min(vals_rat) - hi = max(vals_rat) - with path.open("w", encoding="ascii") as f: - f.write(f"seq {seq}\n") - if direction_target is not None and direction_negative is not None: - f.write(f"direction-target {direction_target}\n") - f.write(f"direction-negative {direction_negative}\n") - f.write(f"lo {rat_to_str(lo)}\n") - f.write(f"hi {rat_to_str(hi)}\n") - for k, val in enumerate(vals_rat): - f.write(f"val {k} {rat_to_str(val)}\n") - - -def main() -> None: - ap = argparse.ArgumentParser(description=__doc__) - ap.add_argument("--model", type=Path, required=True, help="Path to NFP_BINARY_V1 model") - ap.add_argument("--layer", type=int, required=True, help="Layer index") - ap.add_argument("--head", type=int, required=True, help="Head index") - ap.add_argument("--output", type=Path, required=True, - help="Path for induction-head certificate") - ap.add_argument("--values-out", type=Path, - help="Optional path for a value-range certificate") - ap.add_argument("--direction-target", type=int, required=True, - help="Target token id for logit-diff direction") - ap.add_argument("--direction-negative", type=int, required=True, - help="Negative token id for logit-diff direction") - ap.add_argument("--decimals", type=int, default=6, - help="Decimal rounding for rationals") - ap.add_argument("--active-eps-max", default="1/2", - help="Maximum eps to include an active position") - args = ap.parse_args() - - if not args.model.exists(): - raise SystemExit(f"Missing model file: {args.model}") - - with args.model.open("rb") as f: - header = parse_header(f) - num_layers = int(header["num_layers"]) - num_heads = int(header["num_heads"]) - model_dim = int(header["model_dim"]) - head_dim = int(header["head_dim"]) - vocab_size = int(header["vocab_size"]) - seq_len = int(header["seq_len"]) - hidden_dim = int(header["hidden_dim"]) - ln_eps = float(header.get("layer_norm_eps", header.get("eps", "0"))) - - if args.layer < 0 or args.layer >= num_layers: - raise 
SystemExit("layer index out of range") - if args.head < 0 or args.head >= num_heads: - raise SystemExit("head index out of range") - if not (0 <= args.direction_target < vocab_size): - raise SystemExit("direction-target out of vocab range") - if not (0 <= args.direction_negative < vocab_size): - raise SystemExit("direction-negative out of vocab range") - - tokens = read_i32(f, seq_len) - embeddings = read_f64(f, seq_len * model_dim).reshape(seq_len, model_dim) - - wq, bq, wk, bk, wv, bv, wo_raw, _attn_bias, ln1_gamma, ln1_beta = read_head_weights( - f, - num_layers, - num_heads, - model_dim, - head_dim, - hidden_dim, - args.layer, - args.head, - ) - - skip_f64(f, model_dim) - skip_f64(f, model_dim) - - unembed_start = f.tell() - col_target, col_negative = read_unembed_columns( - f, - unembed_start, - model_dim, - vocab_size, - args.direction_target, - args.direction_negative, - ) - - prev, active_mask = build_prev(tokens) - candidate_positions = [int(i) for i, flag in enumerate(active_mask) if flag] - active_eps_max = Fraction(args.active_eps_max) - - scale_denom = int(math.isqrt(head_dim)) - if scale_denom * scale_denom != head_dim: - scale = 1.0 / math.sqrt(head_dim) - else: - scale = 1.0 / scale_denom - - ln = layer_norm(embeddings, ln1_gamma, ln1_beta, ln_eps) - q = ln @ wq + bq - k = ln @ wk + bk - v = ln @ wv + bv - - scores = scale * (q @ k.T) - mask_value = -10000.0 - mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1) - scores = scores.copy() - scores[mask] = mask_value - weights = softmax(scores) - - scores_rat = [[rat_from_float(float(scores[q, k]), args.decimals) - for k in range(seq_len)] for q in range(seq_len)] - weights_rat = [[rat_from_float(float(weights[q, k]), args.decimals) - for k in range(seq_len)] for q in range(seq_len)] - - for q in range(seq_len): - total = sum(weights_rat[q], Fraction(0)) - if total == 0: - raise SystemExit(f"zero weight sum at q={q}") - weights_rat[q] = [w / total for w in weights_rat[q]] - - eps_by_q: 
dict[int, Fraction] = {} - margin_by_q: dict[int, Fraction] = {} - for q in range(seq_len): - prev_q = prev[q] - prev_w = weights_rat[q][prev_q] - if seq_len == 1: - max_other = Fraction(0) - else: - max_other = max(weights_rat[q][k] for k in range(seq_len) if k != prev_q) - deficit = Fraction(1) - prev_w - eps_by_q[q] = max(max_other, deficit) - - diffs = [scores_rat[q][prev_q] - scores_rat[q][k] - for k in range(seq_len) if k != prev_q] - margin_by_q[q] = min(diffs) if diffs else Fraction(0) - - active_positions = [q for q in candidate_positions if eps_by_q[q] <= active_eps_max] - if not active_positions and candidate_positions: - print("Warning: no active positions satisfy active-eps-max; certificate may be vacuous.") - - if not active_positions and seq_len > 1: - if candidate_positions: - print("Warning: no active positions satisfy active-eps-max; using all nonzero queries.") - active_positions = list(range(1, seq_len)) - if active_positions: - eps = max(eps_by_q[q] for q in active_positions) - margin = min((margin_by_q[q] for q in active_positions), default=Fraction(0)) - else: - eps = Fraction(0) - margin = Fraction(0) - - eps_at = [] - for q in range(seq_len): - prev_q = prev[q] - if seq_len == 1: - max_other = Fraction(0) - else: - max_other = max(weights_rat[q][k] for k in range(seq_len) if k != prev_q) - deficit = Fraction(1) - weights_rat[q][prev_q] - eps_at.append(max(max_other, deficit)) - weight_bound_at = weights_rat - - wo = wo_raw.T - direction = col_target - col_negative - dir_head = wo.T @ direction - dir_vals = v @ dir_head - vals_rat = [rat_from_float(float(dir_vals[k]), args.decimals) for k in range(seq_len)] - - output_path = args.output - output_path.parent.mkdir(parents=True, exist_ok=True) - write_induction_cert( - output_path, - seq_len, - prev, - scores_rat, - weights_rat, - eps, - margin, - active_positions, - eps_at, - weight_bound_at, - vals_rat, - args.direction_target, - args.direction_negative, - ) - - if args.values_out: - 
values_path = args.values_out - values_path.parent.mkdir(parents=True, exist_ok=True) - write_value_range(values_path, seq_len, dir_vals, args.decimals, - direction_target=args.direction_target, - direction_negative=args.direction_negative) - print(f"Wrote value-range certificate to {values_path}") - - print(f"Wrote induction-head certificate to {output_path}") - if candidate_positions: - print(f"Active positions: {len(active_positions)}/{len(candidate_positions)}") - - -if __name__ == "__main__": - main() diff --git a/scripts/convert_text_fixture_to_binary.py b/scripts/convert_text_fixture_to_binary.py deleted file mode 100644 index 82d8751..0000000 --- a/scripts/convert_text_fixture_to_binary.py +++ /dev/null @@ -1,273 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import struct -from pathlib import Path -from typing import Callable, Dict, List, Tuple - - -def parse_header(lines: List[str], path: Path) -> Tuple[Dict[str, str], int]: - if not lines: - raise ValueError(f"{path}: empty file") - magic = lines[0].strip() - if not magic.startswith("NFP_TEXT"): - raise ValueError(f"{path}: unexpected header '{magic}'") - header: Dict[str, str] = {} - i = 1 - while i < len(lines): - line = lines[i].strip() - i += 1 - if line == "": - break - if "=" in line: - key, value = line.split("=", 1) - header[key.strip()] = value.strip() - return header, i - - -def next_nonempty(lines: List[str], i: int, path: Path) -> Tuple[int, str]: - while i < len(lines) and lines[i].strip() == "": - i += 1 - if i >= len(lines): - raise ValueError(f"{path}: unexpected EOF") - return i, lines[i].strip() - - -def expect_line(lines: List[str], i: int, expected: str, path: Path) -> int: - i, line = next_nonempty(lines, i, path) - if line != expected: - raise ValueError(f"{path}: expected '{expected}', got '{line}'") - return i + 1 - - -def expect_prefix(lines: List[str], i: int, prefix: str, path: Path) -> int: - i, line = next_nonempty(lines, i, path) - if 
not line.startswith(prefix): - raise ValueError(f"{path}: expected '{prefix} ...', got '{line}'") - return i + 1 - - -def read_numbers( - lines: List[str], - i: int, - count: int, - path: Path, - cast: Callable[[str], float], -) -> Tuple[List[float], int]: - out: List[float] = [] - while len(out) < count: - if i >= len(lines): - raise ValueError(f"{path}: unexpected EOF while reading numbers") - line = lines[i].strip() - i += 1 - if line == "": - continue - for tok in line.split(): - try: - out.append(cast(tok)) - except ValueError as exc: - raise ValueError(f"{path}: invalid number '{tok}'") from exc - if len(out) == count: - break - return out, i - - -def read_input_embeddings(path: Path) -> Tuple[Dict[str, str], List[int], List[float]]: - lines = path.read_text().splitlines() - header, i = parse_header(lines, path) - seq_len = int(header["seq_len"]) - model_dim = int(header["model_dim"]) - i = expect_line(lines, i, "TOKENS", path) - tokens_f, i = read_numbers(lines, i, seq_len, path, int) - tokens = [int(t) for t in tokens_f] - i = expect_line(lines, i, "EMBEDDINGS", path) - embeddings_f, _ = read_numbers(lines, i, seq_len * model_dim, path, float) - return header, tokens, embeddings_f - - -def read_model_weights(path: Path) -> Tuple[Dict[str, str], Dict[str, object]]: - lines = path.read_text().splitlines() - header, i = parse_header(lines, path) - num_layers = int(header["num_layers"]) - num_heads = int(header["num_heads"]) - model_dim = int(header["model_dim"]) - head_dim = int(header["head_dim"]) - hidden_dim = int(header["hidden_dim"]) - - layers: List[Dict[str, object]] = [] - for _ in range(num_layers): - i = expect_prefix(lines, i, "LAYER", path) - heads: List[Dict[str, List[float]]] = [] - for _ in range(num_heads): - i = expect_prefix(lines, i, "HEAD", path) - i = expect_line(lines, i, "W_Q", path) - w_q, i = read_numbers(lines, i, model_dim * head_dim, path, float) - i = expect_line(lines, i, "b_Q", path) - b_q, i = read_numbers(lines, i, head_dim, 
path, float) - i = expect_line(lines, i, "W_K", path) - w_k, i = read_numbers(lines, i, model_dim * head_dim, path, float) - i = expect_line(lines, i, "b_K", path) - b_k, i = read_numbers(lines, i, head_dim, path, float) - i = expect_line(lines, i, "W_V", path) - w_v, i = read_numbers(lines, i, model_dim * head_dim, path, float) - i = expect_line(lines, i, "b_V", path) - b_v, i = read_numbers(lines, i, head_dim, path, float) - i = expect_line(lines, i, "W_O", path) - w_o, i = read_numbers(lines, i, head_dim * model_dim, path, float) - heads.append( - { - "W_Q": w_q, - "b_Q": b_q, - "W_K": w_k, - "b_K": b_k, - "W_V": w_v, - "b_V": b_v, - "W_O": w_o, - } - ) - i = expect_line(lines, i, "ATTN_BIAS", path) - attn_bias, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "MLP", path) - i = expect_line(lines, i, "W_in", path) - w_in, i = read_numbers(lines, i, model_dim * hidden_dim, path, float) - i = expect_line(lines, i, "b_in", path) - b_in, i = read_numbers(lines, i, hidden_dim, path, float) - i = expect_line(lines, i, "W_out", path) - w_out, i = read_numbers(lines, i, hidden_dim * model_dim, path, float) - i = expect_line(lines, i, "b_out", path) - b_out, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "LN1_GAMMA", path) - ln1_gamma, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "LN1_BETA", path) - ln1_beta, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "LN2_GAMMA", path) - ln2_gamma, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "LN2_BETA", path) - ln2_beta, i = read_numbers(lines, i, model_dim, path, float) - layers.append( - { - "heads": heads, - "attn_bias": attn_bias, - "w_in": w_in, - "b_in": b_in, - "w_out": w_out, - "b_out": b_out, - "ln1_gamma": ln1_gamma, - "ln1_beta": ln1_beta, - "ln2_gamma": ln2_gamma, - "ln2_beta": ln2_beta, - } - ) - - return header, {"layers": layers} - - -def write_i32(f, data: 
List[int]) -> None: - if not data: - return - f.write(struct.pack("<" + "i" * len(data), *data)) - - -def write_f64(f, data: List[float]) -> None: - if not data: - return - f.write(struct.pack("<" + "d" * len(data), *data)) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Convert a tiny NFP_TEXT fixture into NFP_BINARY_V1." - ) - parser.add_argument( - "--model", - default="tests/fixtures/tiny_sound_model.nfpt", - help="Text model fixture path (NFP_TEXT_V2).", - ) - parser.add_argument( - "--input", - default="tests/fixtures/tiny_sound_input.nfpt", - help="Text input fixture path (NFP_TEXT_V2).", - ) - parser.add_argument( - "--output", - default="tests/fixtures/tiny_sound_binary.nfpt", - help="Output binary fixture path.", - ) - args = parser.parse_args() - - model_path = Path(args.model) - input_path = Path(args.input) - output_path = Path(args.output) - - input_header, tokens, embeddings = read_input_embeddings(input_path) - model_header, model_payload = read_model_weights(model_path) - - if int(input_header["model_dim"]) != int(model_header["model_dim"]): - raise ValueError("input/model model_dim mismatch") - if int(input_header["seq_len"]) != int(model_header["seq_len"]): - raise ValueError("input/model seq_len mismatch") - input_eps = input_header.get("layer_norm_eps") or input_header.get("eps") - model_eps = model_header.get("layer_norm_eps") or model_header.get("eps") - model_gelu = model_header.get("gelu_kind") or model_header.get("gelu_deriv") - input_gelu = input_header.get("gelu_kind") or input_header.get("gelu_deriv") - if model_eps is None: - raise ValueError("model header missing layer_norm_eps") - if input_eps is not None and input_eps != model_eps: - raise ValueError("input/model layer_norm_eps mismatch") - if model_gelu is None: - raise ValueError("model header missing gelu_kind") - if input_gelu is not None and input_gelu != model_gelu: - raise ValueError("input/model gelu_kind mismatch") - - 
output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("wb") as f: - f.write(b"NFP_BINARY_V1\n") - for key in [ - "num_layers", - "num_heads", - "model_dim", - "head_dim", - "hidden_dim", - "vocab_size", - "seq_len", - ]: - f.write(f"{key}={model_header[key]}\n".encode("ascii")) - f.write(f"layer_norm_eps={model_eps}\n".encode("ascii")) - f.write(f"gelu_kind={model_gelu}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - write_i32(f, tokens) - write_f64(f, embeddings) - - layers = model_payload["layers"] - for layer in layers: - for head in layer["heads"]: - write_f64(f, head["W_Q"]) - write_f64(f, head["b_Q"]) - write_f64(f, head["W_K"]) - write_f64(f, head["b_K"]) - write_f64(f, head["W_V"]) - write_f64(f, head["b_V"]) - write_f64(f, head["W_O"]) - write_f64(f, layer["attn_bias"]) - write_f64(f, layer["w_in"]) - write_f64(f, layer["b_in"]) - write_f64(f, layer["w_out"]) - write_f64(f, layer["b_out"]) - write_f64(f, layer["ln1_gamma"]) - write_f64(f, layer["ln1_beta"]) - write_f64(f, layer["ln2_gamma"]) - write_f64(f, layer["ln2_beta"]) - - model_dim = int(model_header["model_dim"]) - vocab_size = int(model_header["vocab_size"]) - write_f64(f, [1.0] * model_dim) - write_f64(f, [0.0] * model_dim) - write_f64(f, [0.0] * (model_dim * vocab_size)) - - print(f"Wrote {output_path}") - - -if __name__ == "__main__": - main() diff --git a/scripts/ensure_gelu_kind.py b/scripts/ensure_gelu_kind.py deleted file mode 100644 index 0a908a4..0000000 --- a/scripts/ensure_gelu_kind.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 -"""Ensure a `.nfpt` header has `gelu_kind`, optionally writing a patched copy.""" - -from __future__ import annotations - -import argparse -import shutil -import sys -from pathlib import Path - - -def has_gelu_kind(path: Path) -> bool: - with path.open("rb") as f: - while True: - line = f.readline() - if not line: - raise ValueError("unexpected EOF while reading header") - stripped = line.strip() - if 
stripped.startswith(b"gelu_kind=") or stripped.startswith(b"gelu_deriv="): - return True - if stripped == b"BINARY_START": - return False - - -def patch_header(path: Path, output: Path, default_kind: str) -> None: - with path.open("rb") as f: - header_lines: list[bytes] = [] - gelu_present = False - while True: - line = f.readline() - if not line: - raise ValueError("unexpected EOF while reading header") - stripped = line.strip() - if stripped.startswith(b"gelu_kind=") or stripped.startswith(b"gelu_deriv="): - gelu_present = True - header_lines.append(line) - if stripped == b"BINARY_START": - break - payload = f.read() - - output.parent.mkdir(parents=True, exist_ok=True) - if gelu_present: - shutil.copyfile(path, output) - print(f"gelu_kind already present; copied to {output}") - return - - patched: list[bytes] = [] - inserted = False - for line in header_lines: - if line.strip() == b"BINARY_START" and not inserted: - patched.append(f"gelu_kind={default_kind}\n".encode("ascii")) - inserted = True - patched.append(line) - - if not inserted: - raise ValueError("missing BINARY_START while patching header") - - with output.open("wb") as f: - f.writelines(patched) - f.write(payload) - print(f"added gelu_kind={default_kind}; wrote {output}") - - -def main() -> int: - parser = argparse.ArgumentParser() - parser.add_argument("input", type=Path) - parser.add_argument("--output", type=Path) - parser.add_argument("--default", default="tanh") - parser.add_argument("--check", action="store_true") - args = parser.parse_args() - - if args.check: - return 0 if has_gelu_kind(args.input) else 1 - - if args.output is None: - raise SystemExit("missing --output (use --check to test only)") - - if args.output.resolve() == args.input.resolve(): - raise SystemExit("refusing in-place patch; pass a distinct --output path") - - patch_header(args.input, args.output, args.default) - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/export_gpt2.py 
b/scripts/export_gpt2.py deleted file mode 100644 index 04e8f7c..0000000 --- a/scripts/export_gpt2.py +++ /dev/null @@ -1,239 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Export GPT-2 Small weights to NFP_BINARY_V1 format (.nfpt). - -This script loads GPT-2 Small from HuggingFace and exports all weights needed -for circuit analysis in the NFP library. - -Usage: - uv run scripts/export_gpt2.py [output_path] - -Default output: models/gpt2.nfpt -""" - -import sys -from pathlib import Path -import numpy as np - -try: - from transformers import GPT2Model -except ImportError: - print("Error: transformers library not installed.") - print("Install with: uv add transformers torch") - sys.exit(1) - - -def export_gpt2_weights(output_path: str = "models/gpt2.nfpt", seq_len: int = 256): - """ - Export GPT-2 Small weights to .nfpt format. - - GPT-2 Small architecture: - - 12 layers - - 12 attention heads per layer - - 768 model dimension - - 64 head dimension (768 / 12) - - 3072 MLP hidden dimension (4 * 768) - - 50257 vocabulary size - - Args: - output_path: Path to write the .nfpt file - seq_len: Sequence length for analysis (default 256, smaller than GPT-2's 1024 for speed) - """ - print("Loading GPT-2 Small from HuggingFace...") - model = GPT2Model.from_pretrained("gpt2") - config = model.config - - # Model dimensions - num_layers = config.n_layer # 12 - num_heads = config.n_head # 12 - model_dim = config.n_embd # 768 - head_dim = model_dim // num_heads # 64 - hidden_dim = config.n_inner or 4 * model_dim # 3072 - vocab_size = config.vocab_size # 50257 - layer_norm_eps = float(config.layer_norm_epsilon) - - print("Model configuration:") - print(f" Layers: {num_layers}") - print(f" Heads: {num_heads}") - print(f" Model dim: {model_dim}") - print(f" Head dim: {head_dim}") - print(f" Hidden dim: {hidden_dim}") - print(f" Vocab size: {vocab_size}") - print(f" Sequence length (for analysis): {seq_len}") - - # Create output directory - 
output_file = Path(output_path) - output_file.parent.mkdir(parents=True, exist_ok=True) - - print(f"\nExporting to {output_path}...") - - with open(output_path, "wb") as f: - # Header - write_header( - f, - num_layers=num_layers, - num_heads=num_heads, - model_dim=model_dim, - head_dim=head_dim, - hidden_dim=hidden_dim, - vocab_size=vocab_size, - seq_len=seq_len, - layer_norm_eps=layer_norm_eps, - gelu_kind="tanh", - ) - - # Sample input embeddings for analysis - # For circuit analysis, we need input embeddings for a specific sequence - # We'll create a dummy sequence of token indices and get their embeddings - # Shape: (seq_len, model_dim) - wte = model.wte.weight.detach().numpy() - wpe = model.wpe.weight.detach().numpy() # Position embeddings - - # Create sample token indices (repeating pattern to make induction detectable) - # Pattern: 0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ... - sample_tokens = np.array([i % 16 for i in range(seq_len)]) - # Include both token embeddings and position embeddings - # This is what GPT-2 computes: x_0 = token_emb + pos_emb - sample_embeddings = wte[sample_tokens] + wpe[:seq_len] # (seq_len, model_dim) - - # Ground-truth token sequence for self-supervised induction targeting. - # This is used by the Lean pipeline to choose the correct induction target - # algorithmically from the sequence history. 
- write_i32(f, sample_tokens) - write_f64(f, sample_embeddings) - - # Process each layer - for layer_idx in range(num_layers): - print(f" Processing layer {layer_idx}...") - block = model.h[layer_idx] - - - # Attention weights - # GPT-2 stores Q, K, V concatenated in c_attn.weight - # Shape: (model_dim, 3 * model_dim) for weight - # We need to split into Q, K, V and then split each by heads - - c_attn_weight = block.attn.c_attn.weight.detach().numpy() # (768, 2304) - c_attn_bias = get_bias(block.attn.c_attn.bias, 3 * model_dim) # (2304,) - - # c_attn.weight layout: input_dim -> [Q_all_heads, K_all_heads, V_all_heads] - # Each section is (model_dim, model_dim) - W_Q_all = c_attn_weight[:, 0:model_dim] # (768, 768) - W_K_all = c_attn_weight[:, model_dim : 2 * model_dim] # (768, 768) - W_V_all = c_attn_weight[:, 2 * model_dim : 3 * model_dim] # (768, 768) - b_Q_all = c_attn_bias[0:model_dim] # (768,) - b_K_all = c_attn_bias[model_dim : 2 * model_dim] # (768,) - b_V_all = c_attn_bias[2 * model_dim : 3 * model_dim] # (768,) - - # Output projection c_proj - c_proj_weight = block.attn.c_proj.weight.detach().numpy() # (768, 768) - c_proj_bias = get_bias(block.attn.c_proj.bias, model_dim) # (768,) - - # Split into heads - # W_Q_all columns are organized as [head0, head1, ..., head11] - # Each head gets head_dim columns - for head_idx in range(num_heads): - start = head_idx * head_dim - end = (head_idx + 1) * head_dim - - # Extract per-head Q, K, V projections - # W_Q: (model_dim, head_dim) - projects input to queries for this head - W_Q = W_Q_all[:, start:end] # (768, 64) - W_K = W_K_all[:, start:end] # (768, 64) - W_V = W_V_all[:, start:end] # (768, 64) - - # Output projection for this head - # c_proj.weight: (768, 768), organized as [head0_rows, head1_rows, ...] 
- # W_O: (head_dim, model_dim) - projects head output back to model dim - W_O = c_proj_weight[start:end, :] # (64, 768) - - write_f64(f, W_Q) - write_f64(f, b_Q_all[start:end]) - write_f64(f, W_K) - write_f64(f, b_K_all[start:end]) - write_f64(f, W_V) - write_f64(f, b_V_all[start:end]) - write_f64(f, W_O) - - # Attention output bias (c_proj.bias), applied once after combining heads. - write_f64(f, c_proj_bias) - - # MLP weights - # GPT-2 MLP: x -> c_fc (expand) -> GELU -> c_proj (contract) -> out - # c_fc: (model_dim, hidden_dim) - # c_proj: (hidden_dim, model_dim) - - mlp_c_fc_weight = block.mlp.c_fc.weight.detach().numpy() # (768, 3072) - mlp_c_fc_bias = get_bias(block.mlp.c_fc.bias, hidden_dim) # (3072,) - mlp_c_proj_weight = block.mlp.c_proj.weight.detach().numpy() # (3072, 768) - mlp_c_proj_bias = get_bias(block.mlp.c_proj.bias, model_dim) # (768,) - - # Note: GPT-2 uses Conv1D which stores weights transposed compared to Linear - # Conv1D weight shape is (in_features, out_features) - # Our format expects W_in: (model_dim, hidden_dim) and W_out: (hidden_dim, model_dim) - - write_f64(f, mlp_c_fc_weight) # (768, 3072) - write_f64(f, mlp_c_fc_bias) # (3072,) - write_f64(f, mlp_c_proj_weight) # (3072, 768) - write_f64(f, mlp_c_proj_bias) # (768,) - - # LayerNorm parameters (Pre-LN) - # ln_1 is applied before attention; ln_2 is applied before MLP. 
- ln1_gamma = block.ln_1.weight.detach().numpy() # (model_dim,) - ln1_beta = get_bias(block.ln_1.bias, model_dim) # (model_dim,) - ln2_gamma = block.ln_2.weight.detach().numpy() # (model_dim,) - ln2_beta = get_bias(block.ln_2.bias, model_dim) # (model_dim,) - - write_f64(f, ln1_gamma) - write_f64(f, ln1_beta) - write_f64(f, ln2_gamma) - write_f64(f, ln2_beta) - - # Final LayerNorm (ln_f) before unembedding - ln_f_gamma = model.ln_f.weight.detach().numpy() # (model_dim,) - ln_f_beta = get_bias(model.ln_f.bias, model_dim) # (model_dim,) - write_f64(f, ln_f_gamma) - write_f64(f, ln_f_beta) - - # Unembedding matrix - # In GPT-2, the unembedding is tied to the embedding (same weights transposed) - # We want (model_dim, vocab_size) for projecting hidden states to logits - # wte is (vocab_size, model_dim), so we transpose it - unembed = wte.T # (768, 50257) - write_f64(f, unembed) - - # Report file size - file_size = output_file.stat().st_size - print("\nExport complete!") - print(f" File size: {file_size / 1024 / 1024:.1f} MB") - print(f" Output: {output_path}") - - -def write_header(f, **fields: object) -> None: - f.write(b"NFP_BINARY_V1\n") - for key, value in fields.items(): - f.write(f"{key}={value}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - -def write_i32(f, data: np.ndarray) -> None: - arr = np.ascontiguousarray(data, dtype=" None: - arr = np.ascontiguousarray(data, dtype=" np.ndarray: - if param is None: - return np.zeros(size, dtype=np.float64) - return param.detach().numpy() - - -if __name__ == "__main__": - output = sys.argv[1] if len(sys.argv) > 1 else "models/gpt2.nfpt" - export_gpt2_weights(output) diff --git a/scripts/generate_induction_data.py b/scripts/generate_induction_data.py deleted file mode 100644 index 2b6fe63..0000000 --- a/scripts/generate_induction_data.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Generate a "Strong Induction" dataset for NFP verification. 
-Creates a sequence of random common words repeated multiple times. -""" - -import sys -from pathlib import Path -import numpy as np -import torch - -try: - from transformers import GPT2Model -except ImportError: - print("Error: transformers library not installed.") - print("Install with: uv add transformers torch") - sys.exit(1) - - -def export_induction_weights(output_path: str = "models/gpt2_induction.nfpt"): - print("Loading GPT-2 Small...") - model = GPT2Model.from_pretrained("gpt2") - config = model.config - layer_norm_eps = float(config.layer_norm_epsilon) - - # Fixed parameters for GPT-2 Small - seq_len = 256 - model_dim = 768 - num_heads = 12 - head_dim = 64 - - # --- CRITICAL CHANGE: Better Induction Data --- - # Instead of tokens 0-15, we use random "word" tokens (1000-30000). - # A pattern of length 30 repeated ~8 times creates strong induction pressure. - np.random.seed(42) - pattern_len = 30 - # Generate 30 random token IDs - pattern = np.random.randint(1000, 30000, size=pattern_len) - # Repeat the pattern to fill seq_len - repeats = (seq_len // pattern_len) + 1 - sample_tokens = np.tile(pattern, repeats)[:seq_len] - - print(f"Generated induction sequence (len={seq_len}):") - print(f"Pattern (first 10): {sample_tokens[:10]}") - # ----------------------------------------------- - - # Compute embeddings - wte = model.wte.weight.detach().numpy() - wpe = model.wpe.weight.detach().numpy() - sample_embeddings = wte[sample_tokens] + wpe[:seq_len] - - output_file = Path(output_path) - output_file.parent.mkdir(parents=True, exist_ok=True) - - print(f"Exporting to {output_path}...") - with open(output_path, "wb") as f: - write_header( - f, - num_layers=config.n_layer, - num_heads=config.n_head, - model_dim=model_dim, - head_dim=head_dim, - hidden_dim=config.n_inner or 4 * model_dim, - vocab_size=config.vocab_size, - seq_len=seq_len, - layer_norm_eps=layer_norm_eps, - gelu_kind="tanh", - ) - - write_i32(f, sample_tokens) - write_f64(f, sample_embeddings) - - # 
Export weights - for layer_idx in range(config.n_layer): - block = model.h[layer_idx] - - # Attention - c_attn = block.attn.c_attn.weight.detach().numpy() - c_attn_bias = get_bias(block.attn.c_attn.bias, 3 * model_dim) - c_proj = block.attn.c_proj.weight.detach().numpy() - c_proj_bias = get_bias(block.attn.c_proj.bias, model_dim) - - W_Q_all = c_attn[:, 0:model_dim] - W_K_all = c_attn[:, model_dim : 2 * model_dim] - W_V_all = c_attn[:, 2 * model_dim : 3 * model_dim] - b_Q_all = c_attn_bias[0:model_dim] - b_K_all = c_attn_bias[model_dim : 2 * model_dim] - b_V_all = c_attn_bias[2 * model_dim : 3 * model_dim] - - for h in range(num_heads): - start, end = h * head_dim, (h + 1) * head_dim - write_f64(f, W_Q_all[:, start:end]) - write_f64(f, b_Q_all[start:end]) - write_f64(f, W_K_all[:, start:end]) - write_f64(f, b_K_all[start:end]) - write_f64(f, W_V_all[:, start:end]) - write_f64(f, b_V_all[start:end]) - write_f64(f, c_proj[start:end, :]) - - write_f64(f, c_proj_bias) - - # MLP - write_f64(f, block.mlp.c_fc.weight.detach().numpy()) - write_f64(f, get_bias(block.mlp.c_fc.bias, config.n_inner or 4 * model_dim)) - write_f64(f, block.mlp.c_proj.weight.detach().numpy()) - write_f64(f, get_bias(block.mlp.c_proj.bias, model_dim)) - - # LayerNorm - write_f64(f, block.ln_1.weight.detach().numpy()) - write_f64(f, get_bias(block.ln_1.bias, model_dim)) - write_f64(f, block.ln_2.weight.detach().numpy()) - write_f64(f, get_bias(block.ln_2.bias, model_dim)) - - # Unembedding - write_f64(f, model.ln_f.weight.detach().numpy()) - write_f64(f, get_bias(model.ln_f.bias, model_dim)) - write_f64(f, wte.T) - - print("Done.") - - -def write_header(f, **fields: object) -> None: - f.write(b"NFP_BINARY_V1\n") - for key, value in fields.items(): - f.write(f"{key}={value}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - -def write_i32(f, data: np.ndarray) -> None: - arr = np.ascontiguousarray(data, dtype=" None: - arr = np.ascontiguousarray(data, dtype=" np.ndarray: - if param is None: - 
return np.zeros(size, dtype=np.float64) - return param.detach().numpy() - - -if __name__ == "__main__": - export_induction_weights() diff --git a/scripts/generate_rigorous_induction.py b/scripts/generate_rigorous_induction.py deleted file mode 100644 index 40e339d..0000000 --- a/scripts/generate_rigorous_induction.py +++ /dev/null @@ -1,232 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Generate a "Rigorous Induction" dataset for NFP verification. -Constructs a sequence of random single-token words repeated in a fixed pattern. - -Dataset intent (heuristic, model-agnostic): -- The next token is deterministically defined by a previous occurrence in the sequence. -- Randomized content reduces semantic cues but does not guarantee model probabilities. - -This aims to isolate induction-style copying from semantic completion. -""" - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path - -import numpy as np -import torch - -try: - from transformers import GPT2Model, GPT2Tokenizer -except ImportError: - print("Error: transformers library not installed.") - print("Install with: uv add transformers torch") - sys.exit(1) - - -def select_vocab_candidates( - tokenizer, - vocab_min: int, - vocab_max: int, - min_word_length: int, - require_leading_space: bool, -) -> list[int]: - candidates = [] - for tid in range(vocab_min, vocab_max): - word = tokenizer.decode([tid]) - if len(word.strip()) <= min_word_length: - continue - if require_leading_space and not word.startswith(" "): - continue - candidates.append(tid) - return candidates - - -def export_rigorous_induction( - output_path: str = "models/gpt2_rigorous.nfpt", - seq_len: int = 256, - pattern_len: int = 20, - seed: int = 1337, - vocab_min: int = 1000, - vocab_max: int = 5000, - min_word_length: int = 4, - require_leading_space: bool = True, - model_name: str = "gpt2", -) -> None: - print(f"Loading {model_name}...") - model = 
GPT2Model.from_pretrained(model_name) - tokenizer = GPT2Tokenizer.from_pretrained(model_name) - config = model.config - layer_norm_eps = float(config.layer_norm_epsilon) - - if seq_len <= 0: - raise ValueError("seq_len must be positive") - if pattern_len <= 0 or pattern_len > seq_len: - raise ValueError("pattern_len must be between 1 and seq_len") - if vocab_min < 0 or vocab_max <= vocab_min: - raise ValueError("invalid vocab range") - - vocab_candidates = select_vocab_candidates( - tokenizer, - vocab_min=vocab_min, - vocab_max=vocab_max, - min_word_length=min_word_length, - require_leading_space=require_leading_space, - ) - if len(vocab_candidates) < pattern_len: - raise ValueError( - f"Need at least {pattern_len} vocab candidates; only found {len(vocab_candidates)}" - ) - - np.random.seed(seed) - unique_pattern = np.random.choice(vocab_candidates, size=pattern_len, replace=False) - - repeats = (seq_len // pattern_len) + 1 - full_sequence = np.tile(unique_pattern, repeats)[:seq_len] - - last_token = full_sequence[-1] - prev_idx = seq_len - 1 - pattern_len - target_token = full_sequence[prev_idx + 1] - - print("\nSequence Structure:") - print(f" Pattern Length: {pattern_len}") - print(f" Total Length: {seq_len}") - print(f" Seed: {seed}") - print(f" Vocab Range: [{vocab_min}, {vocab_max})") - print( - f" Token Filter: min_len>{min_word_length}, " - f"leading_space={require_leading_space}" - ) - print(f" Last Token: '{tokenizer.decode([last_token])}' (ID: {last_token})") - print(f" Previous Occur: Index {prev_idx}") - print( - f" True Target: '{tokenizer.decode([target_token])}' (ID: {target_token})" - ) - - wte = model.wte.weight.detach().numpy() - wpe = model.wpe.weight.detach().numpy() - sample_embeddings = wte[full_sequence] + wpe[:seq_len] - - output_file = Path(output_path) - output_file.parent.mkdir(parents=True, exist_ok=True) - - print(f"\nExporting to {output_path}...") - with output_file.open("wb") as f: - write_header( - f, - num_layers=config.n_layer, - 
num_heads=config.n_head, - model_dim=768, - head_dim=64, - hidden_dim=config.n_inner or 4 * 768, - vocab_size=config.vocab_size, - seq_len=seq_len, - layer_norm_eps=layer_norm_eps, - gelu_kind="tanh", - ) - - write_i32(f, full_sequence) - write_f64(f, sample_embeddings) - - for layer_idx in range(config.n_layer): - block = model.h[layer_idx] - - c_attn = block.attn.c_attn.weight.detach().numpy() - c_attn_bias = get_bias(block.attn.c_attn.bias, 3 * 768) - c_proj = block.attn.c_proj.weight.detach().numpy() - c_proj_bias = get_bias(block.attn.c_proj.bias, 768) - - w_q_all = c_attn[:, 0:768] - w_k_all = c_attn[:, 768 : 2 * 768] - w_v_all = c_attn[:, 2 * 768 : 3 * 768] - b_q_all = c_attn_bias[0:768] - b_k_all = c_attn_bias[768 : 2 * 768] - b_v_all = c_attn_bias[2 * 768 : 3 * 768] - - for h in range(12): - start, end = h * 64, (h + 1) * 64 - write_f64(f, w_q_all[:, start:end]) - write_f64(f, b_q_all[start:end]) - write_f64(f, w_k_all[:, start:end]) - write_f64(f, b_k_all[start:end]) - write_f64(f, w_v_all[:, start:end]) - write_f64(f, b_v_all[start:end]) - write_f64(f, c_proj[start:end, :]) - - write_f64(f, c_proj_bias) - - write_f64(f, block.mlp.c_fc.weight.detach().numpy()) - write_f64(f, get_bias(block.mlp.c_fc.bias, config.n_inner or 4 * 768)) - write_f64(f, block.mlp.c_proj.weight.detach().numpy()) - write_f64(f, get_bias(block.mlp.c_proj.bias, 768)) - - write_f64(f, block.ln_1.weight.detach().numpy()) - write_f64(f, get_bias(block.ln_1.bias, 768)) - write_f64(f, block.ln_2.weight.detach().numpy()) - write_f64(f, get_bias(block.ln_2.bias, 768)) - - write_f64(f, model.ln_f.weight.detach().numpy()) - write_f64(f, get_bias(model.ln_f.bias, 768)) - write_f64(f, wte.T) - - print("Done.") - - -def write_header(f, **fields: object) -> None: - f.write(b"NFP_BINARY_V1\n") - for key, value in fields.items(): - f.write(f"{key}={value}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - -def write_i32(f, data: np.ndarray) -> None: - arr = np.ascontiguousarray(data, dtype=" 
None: - arr = np.ascontiguousarray(data, dtype=" np.ndarray: - if param is None: - return np.zeros(size, dtype=np.float64) - return param.detach().numpy() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--output", default="models/gpt2_rigorous.nfpt") - parser.add_argument("--seq-len", type=int, default=256) - parser.add_argument("--pattern-len", type=int, default=20) - parser.add_argument("--seed", type=int, default=1337) - parser.add_argument("--vocab-min", type=int, default=1000) - parser.add_argument("--vocab-max", type=int, default=5000) - parser.add_argument("--min-word-length", type=int, default=4) - parser.add_argument("--require-leading-space", action="store_true", default=True) - parser.add_argument( - "--allow-no-leading-space", - action="store_true", - help="Permit tokens without a leading space", - ) - parser.add_argument("--model", default="gpt2") - args = parser.parse_args() - - require_leading_space = args.require_leading_space and not args.allow_no_leading_space - export_rigorous_induction( - output_path=args.output, - seq_len=args.seq_len, - pattern_len=args.pattern_len, - seed=args.seed, - vocab_min=args.vocab_min, - vocab_max=args.vocab_max, - min_word_length=args.min_word_length, - require_leading_space=require_leading_space, - model_name=args.model, - ) diff --git a/tests/fixtures/tiny_sound_binary.nfpt b/tests/fixtures/tiny_sound_binary.nfpt deleted file mode 100644 index 30238ce31a5eb7d87b836e71af927438b177b389..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1387 zcmeZs3y62}^m7b~j1M#9$}7!{&q=IIEh@H!a57R8Q$Xz8{FKz3_>|0CTN5sj7?^3q zm64f}l9~qtG~g;uEsW1e&9gP)0vi~gmtT|{pIQJkELGQ(D?K%* zG(I~sFU7VbF)xD)>Vn`9$Dj}{Mg|53CLo4@2lf#9!hQ&S0LtG0<-dUPA3*sRp!@^- zXUznGo%Rs=8bBW$g_?WZ9^&2;P6X;l{uK4PUr^2!AD1-AbtV zPBit1cz{MQ+#IyA2i&iK>gbSw-;*een@=GgQ{Bsb0>QEL-lXjPbmDM z;SG)7gZ8@4GOr(*NkM6?{fS1Q23jkPp?s_T5dVIFx(gQ#3l&^6E`2a@T>4<*Fm Date: Sat, 17 Jan 2026 
09:40:41 +0100 Subject: [PATCH 226/244] Clean .gitignore of nfpt entries --- .gitignore | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.gitignore b/.gitignore index ed0dfed..8d7a9be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,4 @@ /.lake -models -*.nfpt -!tests/fixtures/*.nfpt .DS_Store logs/ sound_cache/ @@ -9,5 +6,4 @@ lean-reference-manual/ inductionhead_papers/ nfp.zip reports/ -tests/fixtures/tiny_sound_binary.nfpt nfp_agent_book/ From d34ad22864b624ed7c021b57a18266c8159d9b6b Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 09:54:23 +0100 Subject: [PATCH 227/244] Streamline induction CLI and tighten cert generation --- CLAIMS.md | 5 +-- Nfp/Cli.lean | 57 ++-------------------------- README.md | 3 +- scripts/build_gpt2_induction_cert.py | 29 ++++++++------ 4 files changed, 26 insertions(+), 68 deletions(-) diff --git a/CLAIMS.md b/CLAIMS.md index 118bb05..570c0d0 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -19,9 +19,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri ## Soundly checked by the trusted CLI -- `nfp induction certify`, `nfp induction certify_nonvacuous`, and - `nfp induction head_cert_check` verify explicit induction-head certificates from a single - cert file, optionally enforcing minimum `active`, `margin`, `eps`, and logit-diff gates. +- `nfp induction certify` verifies explicit induction-head certificates from a single cert + file, optionally enforcing minimum `active`, `margin`, `eps`, and logit-diff gates. ## Untrusted / heuristic diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index 67c1b4e..b5df7e5 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -29,7 +29,7 @@ def versionCmd : Cmd := `[Cli| "Print the NFP version." ] -private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : IO UInt32 := do +private def runInductionCertifySimple (p : Parsed) : IO UInt32 := do let certPath? := (p.flag? "cert").map (·.as! String) let minActive? := (p.flag? "min-active").map (·.as! 
Nat) let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) @@ -41,18 +41,8 @@ private def runInductionCertifyUnified (requireNonvacuous : Bool) (p : Parsed) : match certPath? with | none => fail "provide --cert" | some certPath => - if requireNonvacuous then - IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? - else - IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? - -private def runInductionCertifySimple (p : Parsed) : IO UInt32 := - runInductionCertifyUnified false p - -private def runInductionCertifyNonvacuousSimple (p : Parsed) : IO UInt32 := - runInductionCertifyUnified true p + IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? /-- `nfp induction certify` subcommand (streamlined). -/ def inductionCertifySimpleCmd : Cmd := `[Cli| @@ -68,51 +58,12 @@ def inductionCertifySimpleCmd : Cmd := `[Cli| "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." ] -/-- `nfp induction certify_nonvacuous` subcommand (streamlined). -/ -def inductionCertifyNonvacuousSimpleCmd : Cmd := `[Cli| - certify_nonvacuous VIA runInductionCertifyNonvacuousSimple; - "Require a strictly positive logit-diff bound from a cert." - FLAGS: - cert : String; "Path to the induction head certificate file." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal; default: 0)." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." -] - -/-- `nfp induction head_cert_check` subcommand. -/ -def runInductionHeadCertCheck (p : Parsed) : IO UInt32 := do - let certPath := p.flag! "cert" |>.as! String - let minActive? := (p.flag? "min-active").map (·.as! 
Nat) - let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) - let minMarginStr? := (p.flag? "min-margin").map (·.as! String) - let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) - IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? minMarginStr? maxEpsStr? - -/-- `nfp induction head_cert_check` subcommand. -/ -def inductionHeadCertCheckCmd : Cmd := `[Cli| - head_cert_check VIA runInductionHeadCertCheck; - "Check an explicit induction-head certificate." - FLAGS: - cert : String; "Path to the induction-head certificate file." - "min-active" : Nat; "Optional minimum number of active queries required \ - (default: max 1 (seq/8))." - "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ - (rational literal; defaults to 0 when direction is set)." - "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." - "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." -] - /-- Induction-head subcommands. -/ def inductionCmd : Cmd := `[Cli| induction NOOP; "Induction-head utilities (streamlined)." SUBCOMMANDS: - inductionCertifySimpleCmd; - inductionCertifyNonvacuousSimpleCmd; - inductionHeadCertCheckCmd + inductionCertifySimpleCmd ] /-- The root CLI command. 
-/ diff --git a/README.md b/README.md index 597fed8..40c39ad 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,8 @@ by **untrusted** Python scripts and verified by the Lean CLI; no model forward p ```bash python scripts/build_gpt2_induction_cert.py \ --output reports/gpt2_induction.cert \ - --layer 5 --head 1 --seq 32 --pattern-length 16 \ + --layer 0 --head 5 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ --active-eps-max 1/2 ``` diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index fa206df..c81163a 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -10,12 +10,16 @@ Usage: python scripts/build_gpt2_induction_cert.py --output reports/gpt2_induction.cert \ - --layer 5 --head 1 --seq 32 --pattern-length 16 \ + --layer 0 --head 5 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ --values-out reports/gpt2_induction.values --value-dim 0 \ - --active-eps-max 0.2 + --active-eps-max 0.2 --min-margin 0 Optionally, provide a logit-diff direction: --direction-target --direction-negative + +Note: active positions are filtered by --active-eps-max and --min-margin. If +none qualify, the script exits with an error. 
""" import argparse @@ -192,6 +196,8 @@ def main() -> None: help="Value dimension index for the value-range certificate") parser.add_argument("--active-eps-max", default="1/2", help="Maximum eps to include an active position (default: 1/2).") + parser.add_argument("--min-margin", default="0", + help="Minimum score gap required for an active position (default: 0).") parser.add_argument("--direction-target", type=int, help="Target token id for logit-diff direction (optional)") parser.add_argument("--direction-negative", type=int, @@ -241,16 +247,17 @@ def main() -> None: for k in range(args.seq) if k != prev_q] margin_by_q[q] = min(diffs) if diffs else Fraction(0) - active_positions = candidate_positions eps_threshold = Fraction(args.active_eps_max) - active_positions = [q for q in candidate_positions if eps_by_q[q] <= eps_threshold] - if not active_positions and candidate_positions: - print("Warning: no active positions satisfy active-eps-max; certificate may be vacuous.") - - if not active_positions and args.seq > 1: - if candidate_positions: - print("Warning: no active positions satisfy active-eps-max; using all nonzero queries.") - active_positions = list(range(1, args.seq)) + min_margin = Fraction(args.min_margin) + active_positions = [ + q for q in candidate_positions + if eps_by_q[q] <= eps_threshold and margin_by_q[q] >= min_margin + ] + if not active_positions: + raise SystemExit( + "No active positions satisfy active-eps-max/min-margin. " + "Try a different head/layer, random-pattern/seed, or relax the thresholds." 
+ ) if active_positions: eps = max(eps_by_q[q] for q in active_positions) margin = min((margin_by_q[q] for q in active_positions), default=Fraction(0)) From ae60aec25e0edaf9771a5610f55f8d1573706109 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:03:07 +0100 Subject: [PATCH 228/244] Adopt 1-based indexing in certificates --- Nfp/Circuit/Layers/Induction/Basic.lean | 42 +++++++------ Nfp/IO/InductionHead/Cert.lean | 80 ++++++++++++------------- Nfp/IO/Parse/SoftmaxMargin/Shared.lean | 68 ++++++++++----------- Nfp/IO/Parse/ValueRange/Shared.lean | 16 +++-- README.md | 4 +- docs/induction_cert_audit.md | 2 +- scripts/build_gpt2_induction_cert.py | 30 +++++----- 7 files changed, 129 insertions(+), 113 deletions(-) diff --git a/Nfp/Circuit/Layers/Induction/Basic.lean b/Nfp/Circuit/Layers/Induction/Basic.lean index cc398db..db2d94f 100644 --- a/Nfp/Circuit/Layers/Induction/Basic.lean +++ b/Nfp/Circuit/Layers/Induction/Basic.lean @@ -55,7 +55,8 @@ section Spec variable {Val : Type v} variable {n : Nat} -/-- Induction-head spec: for nonzero queries, outputs copy `prev` values. -/ +/-- Induction-head spec: for non-initial queries (1-based indices ≥ 2), + outputs copy `prev` values. -/ def InductionSpec (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) (out vals : Fin (Nat.succ n) → Val) : Prop := ∀ q, q ≠ 0 → out q = vals (prev q) @@ -147,17 +148,19 @@ section Bounds variable {Val : Type v} [Semiring Val] [PartialOrder Val] variable {seq : Nat} [NeZero seq] -/-- Numeric bounds certifying one-hot weights on nonzero queries. -/ +/-- Numeric bounds certifying one-hot weights on non-initial queries + (1-based indices ≥ 2). -/ structure OneHotBoundsOn (prev : Fin seq → Fin seq) (weights : Fin seq → Fin seq → Val) : Prop where - /-- All weights are nonnegative on nonzero queries. -/ + /-- All weights are nonnegative on non-initial queries (1-based indices ≥ 2). -/ nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k - /-- Weights sum to one on nonzero queries. 
-/ + /-- Weights sum to one on non-initial queries (1-based indices ≥ 2). -/ sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 - /-- Non-prev weights are nonpositive on nonzero queries. -/ + /-- Non-prev weights are nonpositive on non-initial queries (1-based indices ≥ 2). -/ other_le_zero : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ 0 -/-- Certified bounds imply one-hot weights on nonzero queries. -/ +/-- Certified bounds imply one-hot weights on non-initial queries + (1-based indices ≥ 2). -/ theorem oneHot_of_boundsOn (prev : Fin seq → Fin seq) (weights : Fin seq → Fin seq → Val) [DecidableEq (Fin seq)] (h : OneHotBoundsOn prev weights) : @@ -200,16 +203,18 @@ section ApproxBounds variable {Val : Type v} [Semiring Val] [PartialOrder Val] variable {seq : Nat} [NeZero seq] -/-- Approximate one-hot bounds for attention weights on nonzero queries. -/ +/-- Approximate one-hot bounds for attention weights on non-initial queries + (1-based indices ≥ 2). -/ structure OneHotApproxBoundsOn (ε : Val) (prev : Fin seq → Fin seq) (weights : Fin seq → Fin seq → Val) : Prop where - /-- All weights are nonnegative on nonzero queries. -/ + /-- All weights are nonnegative on non-initial queries (1-based indices ≥ 2). -/ nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k - /-- Weights sum to one on nonzero queries. -/ + /-- Weights sum to one on non-initial queries (1-based indices ≥ 2). -/ sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 - /-- The `prev` weight is within `ε` of one on nonzero queries. -/ + /-- The `prev` weight is within `ε` of one on non-initial queries + (1-based indices ≥ 2). -/ prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε - /-- Non-prev weights are at most `ε` on nonzero queries. -/ + /-- Non-prev weights are at most `ε` on non-initial queries (1-based indices ≥ 2). -/ other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε /-- Approximate one-hot bounds for attention weights on active queries. 
-/ @@ -514,15 +519,17 @@ variable {seq : Nat} [NeZero seq] /-- Softmax margin certificates for approximate one-hot weights. -/ structure SoftmaxMarginBounds (ε margin : Val) (prev : Fin seq → Fin seq) (scores weights : Fin seq → Fin seq → Val) : Prop where - /-- Score gap between `prev` and other keys on nonzero queries. -/ + /-- Score gap between `prev` and other keys on non-initial queries + (1-based indices ≥ 2). -/ score_margin : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) - /-- All weights are nonnegative on nonzero queries. -/ + /-- All weights are nonnegative on non-initial queries (1-based indices ≥ 2). -/ nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k - /-- Weights sum to one on nonzero queries. -/ + /-- Weights sum to one on non-initial queries (1-based indices ≥ 2). -/ sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 - /-- The `prev` weight is within `ε` of one on nonzero queries. -/ + /-- The `prev` weight is within `ε` of one on non-initial queries + (1-based indices ≥ 2). -/ prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε - /-- Non-prev weights are at most `ε` on nonzero queries. -/ + /-- Non-prev weights are at most `ε` on non-initial queries (1-based indices ≥ 2). -/ other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε /-- Softmax margin certificates for approximate one-hot weights on active queries. -/ @@ -796,7 +803,8 @@ variable {Val : Type v} [NonAssocSemiring Val] variable (scale : Val) variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) -/-- One-hot weights on nonzero queries imply the induction spec for typed evaluation. -/ +/-- One-hot weights on non-initial queries (1-based indices ≥ 2) imply the induction spec + for typed evaluation. 
-/ theorem attentionTyped_eval_inductionSpec_of_oneHot (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) (input : AttentionInput Batch (Nat.succ n) heads dim → Val) diff --git a/Nfp/IO/InductionHead/Cert.lean b/Nfp/IO/InductionHead/Cert.lean index 55555f0..e6b12f2 100644 --- a/Nfp/IO/InductionHead/Cert.lean +++ b/Nfp/IO/InductionHead/Cert.lean @@ -10,6 +10,9 @@ public import Nfp.IO.Util /-! Untrusted parsing and checking for explicit induction-head certificates. + +All sequence indices in the certificate payload are 1-based (literature convention) and +are converted to `Fin` indices internally. -/ public section @@ -78,59 +81,54 @@ def initState (seq : Nat) : ParseState seq := directionTarget := none directionNegative := none } +private def toIndex1 {seq : Nat} (label : String) (idx : Nat) : Except String (Fin seq) := do + if idx = 0 then + throw s!"{label} index must be >= 1" + let idx' := idx - 1 + if h : idx' < seq then + return ⟨idx', h⟩ + else + throw s!"{label} index out of range: {idx}" + private def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : Except String (ParseState seq) := do - if hq : q < seq then - let qFin : Fin seq := ⟨q, hq⟩ - if qFin ∈ st.active then - throw s!"duplicate active entry for q={q}" - else - return { st with active := insert qFin st.active, activeSeen := true } + let qFin ← toIndex1 (seq := seq) "q" q + if qFin ∈ st.active then + throw s!"duplicate active entry for q={q}" else - throw s!"active index out of range: q={q}" + return { st with active := insert qFin st.active, activeSeen := true } private def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : Except String (ParseState seq) := do - if q < seq then - if hk : k < seq then - let kFin : Fin seq := ⟨k, hk⟩ - match st.prev[q]! with - | some _ => - throw s!"duplicate prev entry for q={q}" - | none => - let prev' := st.prev.set! 
q (some kFin) - return { st with prev := prev' } - else - throw s!"prev index out of range: k={k}" - else - throw s!"prev index out of range: q={q}" + let qFin ← toIndex1 (seq := seq) "q" q + let kFin ← toIndex1 (seq := seq) "k" k + match st.prev[qFin.1]! with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + let prev' := st.prev.set! qFin.1 (some kFin) + return { st with prev := prev' } private def setVecEntry {seq : Nat} (arr : Array (Option Rat)) (idx : Nat) (v : Rat) : Except String (Array (Option Rat)) := do - if idx < seq then - match arr[idx]! with - | some _ => - throw s!"duplicate entry for k={idx}" - | none => - return arr.set! idx (some v) - else - throw s!"index out of range: k={idx}" + let kFin ← toIndex1 (seq := seq) "k" idx + match arr[kFin.1]! with + | some _ => + throw s!"duplicate entry for k={idx}" + | none => + return arr.set! kFin.1 (some v) private def setMatrixEntry {seq : Nat} (mat : Array (Array (Option Rat))) (q k : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do - if q < seq then - if k < seq then - let row := mat[q]! - match row[k]! with - | some _ => - throw s!"duplicate matrix entry at ({q}, {k})" - | none => - let row' := row.set! k (some v) - return mat.set! q row' - else - throw s!"index out of range: k={k}" - else - throw s!"index out of range: q={q}" + let qFin ← toIndex1 (seq := seq) "q" q + let kFin ← toIndex1 (seq := seq) "k" k + let row := mat[qFin.1]! + match row[kFin.1]! with + | some _ => + throw s!"duplicate matrix entry at ({q}, {k})" + | none => + let row' := row.set! kFin.1 (some v) + return mat.set! qFin.1 row' /-- Parse a tokenized line into the parse state. 
-/ def parseLine {seq : Nat} (st : ParseState seq) (tokens : List String) : diff --git a/Nfp/IO/Parse/SoftmaxMargin/Shared.lean b/Nfp/IO/Parse/SoftmaxMargin/Shared.lean index 0ac29c9..2ddfb57 100644 --- a/Nfp/IO/Parse/SoftmaxMargin/Shared.lean +++ b/Nfp/IO/Parse/SoftmaxMargin/Shared.lean @@ -7,6 +7,9 @@ public import Nfp.IO.Parse.Basic /-! Shared parsing helpers for softmax-margin payloads. + +All sequence indices in the payload are 1-based (literature convention) and are converted to +`Fin` indices internally. -/ public section @@ -47,50 +50,47 @@ def initState (seq : Nat) : ParseState seq := scores := Array.replicate seq row weights := Array.replicate seq row } +private def toIndex1 {seq : Nat} (label : String) (idx : Nat) : Except String (Fin seq) := do + if idx = 0 then + throw s!"{label} index must be >= 1" + let idx' := idx - 1 + if h : idx' < seq then + return ⟨idx', h⟩ + else + throw s!"{label} index out of range: {idx}" + /-- Set a predecessor entry from `(q, k)` tokens. -/ def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : Except String (ParseState seq) := do - if q < seq then - if hk : k < seq then - let kFin : Fin seq := ⟨k, hk⟩ - match st.prev[q]! with - | some _ => - throw s!"duplicate prev entry for q={q}" - | none => - let prev' := st.prev.set! q (some kFin) - return { st with prev := prev' } - else - throw s!"prev index out of range: k={k}" - else - throw s!"prev index out of range: q={q}" + let qFin ← toIndex1 (seq := seq) "q" q + let kFin ← toIndex1 (seq := seq) "k" k + match st.prev[qFin.1]! with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + let prev' := st.prev.set! qFin.1 (some kFin) + return { st with prev := prev' } /-- Mark an active query index. 
-/ def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : Except String (ParseState seq) := do - if hq : q < seq then - let qFin : Fin seq := ⟨q, hq⟩ - if qFin ∈ st.active then - throw s!"duplicate active entry for q={q}" - else - return { st with active := insert qFin st.active, activeSeen := true } + let qFin ← toIndex1 (seq := seq) "q" q + if qFin ∈ st.active then + throw s!"duplicate active entry for q={q}" else - throw s!"active index out of range: q={q}" + return { st with active := insert qFin st.active, activeSeen := true } /-- Insert a matrix entry for scores/weights. -/ def setMatrixEntry {seq : Nat} (mat : Array (Array (Option Rat))) (q k : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do - if q < seq then - if k < seq then - let row := mat[q]! - match row[k]! with - | some _ => - throw s!"duplicate matrix entry at ({q}, {k})" - | none => - let row' := row.set! k (some v) - let mat' := mat.set! q row' - return mat' - else - throw s!"index out of range: k={k}" - else - throw s!"index out of range: q={q}" + let qFin ← toIndex1 (seq := seq) "q" q + let kFin ← toIndex1 (seq := seq) "k" k + let row := mat[qFin.1]! + match row[kFin.1]! with + | some _ => + throw s!"duplicate matrix entry at ({q}, {k})" + | none => + let row' := row.set! kFin.1 (some v) + let mat' := mat.set! qFin.1 row' + return mat' /-- Parse a tokenized line into the softmax-margin parse state. -/ def parseLine {seq : Nat} (st : ParseState seq) diff --git a/Nfp/IO/Parse/ValueRange/Shared.lean b/Nfp/IO/Parse/ValueRange/Shared.lean index e51800c..df180e8 100644 --- a/Nfp/IO/Parse/ValueRange/Shared.lean +++ b/Nfp/IO/Parse/ValueRange/Shared.lean @@ -7,6 +7,9 @@ public import Nfp.IO.Parse.Basic /-! Shared parsing helpers for value-range payloads. + +All sequence indices in the payload are 1-based (literature convention) and are converted to +`Fin` indices internally. 
-/ public section @@ -47,17 +50,20 @@ def initState (seq : Nat) : ParseState seq := /-- Set a value entry from `(k, v)` tokens. -/ def setVal {seq : Nat} (st : ParseState seq) (k : Nat) (v : Rat) : Except String (ParseState seq) := do - if hk : k < seq then - let kFin : Fin seq := ⟨k, hk⟩ + if k = 0 then + throw "value index must be >= 1" + let k' := k - 1 + if hk : k' < seq then + let kFin : Fin seq := ⟨k', hk⟩ match st.vals kFin with | some _ => throw s!"duplicate value entry for k={k}" | none => - let vals' : Fin seq → Option Rat := fun k' => - if k' = kFin then + let vals' : Fin seq → Option Rat := fun k'' => + if k'' = kFin then some v else - st.vals k' + st.vals k'' return { st with vals := vals' } else throw s!"value index out of range: k={k}" diff --git a/README.md b/README.md index 40c39ad..8941d04 100644 --- a/README.md +++ b/README.md @@ -94,8 +94,10 @@ val-lo val-hi ``` +All sequence indices (`q`, `k`) are **1-based** (literature convention). Direction token IDs +(`direction-target`, `direction-negative`) are raw model IDs (tokenizer convention). `direction-*` lines are optional metadata; if present, both must appear. If no `active` lines -appear, the checker defaults to all nonzero queries. +appear, the checker defaults to all non-initial queries (indices 2.. in 1-based indexing). ## Soundness boundary diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 2e53363..f9cd7fe 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -6,7 +6,7 @@ induction heads, and spell out the scope and limitations of that claim. ## Formal proof chain (Lean) - Explicit induction-head certificates are parsed from text in - `Nfp/IO/InductionHead/Cert.lean`. + `Nfp/IO/InductionHead/Cert.lean` (sequence indices are 1-based in the payload). - `checkInductionHeadCert` and `checkInductionHeadCert_sound` show that a passing certificate satisfies `InductionHeadCertBounds` (`Nfp/Circuit/Cert/InductionHead.lean`). 
diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index c81163a..1b5a84a 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -7,6 +7,7 @@ This script is untrusted and uses floating-point arithmetic to produce a rational induction-head certificate compatible with `nfp induction certify`. Active induction positions are recorded as `active ` lines in the output. +All sequence indices in the certificate are 1-based (literature convention). Usage: python scripts/build_gpt2_induction_cert.py --output reports/gpt2_induction.cert \ @@ -20,6 +21,7 @@ Note: active positions are filtered by --active-eps-max and --min-margin. If none qualify, the script exits with an error. +Direction token IDs use the model's raw tokenizer indexing. """ import argparse @@ -115,15 +117,15 @@ def write_scores(path: Path, seq: int, prev: np.ndarray, scores, weights, eps=No f.write(f"margin {rat_to_str(margin)}\n") if active is not None: for q in active: - f.write(f"active {q}\n") + f.write(f"active {q + 1}\n") for q, k in enumerate(prev.tolist()): - f.write(f"prev {q} {k}\n") + f.write(f"prev {q + 1} {k + 1}\n") for q in range(seq): for k in range(seq): - f.write(f"score {q} {k} {rat_to_str(scores[q][k])}\n") + f.write(f"score {q + 1} {k + 1} {rat_to_str(scores[q][k])}\n") for q in range(seq): for k in range(seq): - f.write(f"weight {q} {k} {rat_to_str(weights[q][k])}\n") + f.write(f"weight {q + 1} {k + 1} {rat_to_str(weights[q][k])}\n") def write_induction_cert(path: Path, seq: int, prev: np.ndarray, scores, weights, eps, margin, active, eps_at, weight_bound_at, @@ -139,27 +141,27 @@ def write_induction_cert(path: Path, seq: int, prev: np.ndarray, scores, weights f.write(f"margin {rat_to_str(margin)}\n") if active is not None: for q in active: - f.write(f"active {q}\n") + f.write(f"active {q + 1}\n") for q, k in enumerate(prev.tolist()): - f.write(f"prev {q} {k}\n") + f.write(f"prev {q + 1} {k + 1}\n") for q in 
range(seq): for k in range(seq): - f.write(f"score {q} {k} {rat_to_str(scores[q][k])}\n") + f.write(f"score {q + 1} {k + 1} {rat_to_str(scores[q][k])}\n") for q in range(seq): for k in range(seq): - f.write(f"weight {q} {k} {rat_to_str(weights[q][k])}\n") + f.write(f"weight {q + 1} {k + 1} {rat_to_str(weights[q][k])}\n") for q in range(seq): - f.write(f"eps-at {q} {rat_to_str(eps_at[q])}\n") + f.write(f"eps-at {q + 1} {rat_to_str(eps_at[q])}\n") for q in range(seq): for k in range(seq): - f.write(f"weight-bound {q} {k} {rat_to_str(weight_bound_at[q][k])}\n") + f.write(f"weight-bound {q + 1} {k + 1} {rat_to_str(weight_bound_at[q][k])}\n") f.write(f"lo {rat_to_str(lo)}\n") f.write(f"hi {rat_to_str(hi)}\n") for k, val in enumerate(vals): val_str = rat_to_str(val) - f.write(f"val {k} {val_str}\n") - f.write(f"val-lo {k} {val_str}\n") - f.write(f"val-hi {k} {val_str}\n") + f.write(f"val {k + 1} {val_str}\n") + f.write(f"val-lo {k + 1} {val_str}\n") + f.write(f"val-hi {k + 1} {val_str}\n") def write_value_range(path: Path, seq: int, values, decimals: int, @@ -175,7 +177,7 @@ def write_value_range(path: Path, seq: int, values, decimals: int, f.write(f"lo {rat_to_str(lo)}\n") f.write(f"hi {rat_to_str(hi)}\n") for k, val in enumerate(vals_rat): - f.write(f"val {k} {rat_to_str(val)}\n") + f.write(f"val {k + 1} {rat_to_str(val)}\n") def main() -> None: From 64491668c6c5d8086faeee0bbcba946919814599 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:07:41 +0100 Subject: [PATCH 229/244] Use 1-based layer/head indexing in cert generator --- README.md | 4 +++- scripts/build_gpt2_induction_cert.py | 27 ++++++++++++++++++++------- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 8941d04..a942d99 100644 --- a/README.md +++ b/README.md @@ -48,11 +48,13 @@ by **untrusted** Python scripts and verified by the Lean CLI; no model forward p ```bash python scripts/build_gpt2_induction_cert.py \ --output reports/gpt2_induction.cert \ 
- --layer 0 --head 5 --seq 32 --pattern-length 16 \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ --random-pattern --seed 0 \ --active-eps-max 1/2 ``` +Layer/head indices in the generator are 1-based to match the literature. + Optional direction metadata: ``` diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index 1b5a84a..861abe1 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -11,7 +11,7 @@ Usage: python scripts/build_gpt2_induction_cert.py --output reports/gpt2_induction.cert \ - --layer 0 --head 5 --seq 32 --pattern-length 16 \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ --random-pattern --seed 0 \ --values-out reports/gpt2_induction.values --value-dim 0 \ --active-eps-max 0.2 --min-margin 0 @@ -21,6 +21,7 @@ Note: active positions are filtered by --active-eps-max and --min-margin. If none qualify, the script exits with an error. +Layer/head indices are 1-based in the CLI and converted to 0-based internally. Direction token IDs use the model's raw tokenizer indexing. 
""" @@ -184,8 +185,10 @@ def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--output", required=True, help="Path to write certificate") parser.add_argument("--scores-out", help="Optional path for raw scores/weights dump") - parser.add_argument("--layer", type=int, default=0, help="Transformer layer index") - parser.add_argument("--head", type=int, default=0, help="Attention head index") + parser.add_argument("--layer", type=int, default=1, + help="Transformer layer index (1-based)") + parser.add_argument("--head", type=int, default=1, + help="Attention head index (1-based)") parser.add_argument("--seq", type=int, default=32, help="Sequence length") parser.add_argument("--pattern-length", type=int, default=16, help="Pattern length") parser.add_argument("--random-pattern", action="store_true", help="Use random token pattern") @@ -213,12 +216,22 @@ def main() -> None: prev, active_mask = build_prev(tokens) candidate_positions = [int(i) for i, flag in enumerate(active_mask) if flag] + if args.layer <= 0: + raise SystemExit("layer must be >= 1") + if args.head <= 0: + raise SystemExit("head must be >= 1") + model = GPT2Model.from_pretrained(args.model) + layer = args.layer - 1 + head = args.head - 1 + if layer >= model.config.n_layer: + raise SystemExit(f"layer must be in [1, {model.config.n_layer}]") + if head >= model.config.n_head: + raise SystemExit(f"head must be in [1, {model.config.n_head}]") model.to(args.device) input_ids = torch.tensor(tokens, dtype=torch.long, device=args.device).unsqueeze(0) - scores, weights, values = compute_scores_weights(model, input_ids, args.layer, args.head, - args.device) + scores, weights, values = compute_scores_weights(model, input_ids, layer, head, args.device) scores_rat = [[rat_from_float(float(scores[q, k]), args.decimals) for k in range(args.seq)] for q in range(args.seq)] @@ -294,8 +307,8 @@ def main() -> None: direction_negative = args.direction_negative direction = 
wte[direction_target] - wte[direction_negative] head_dim = model.config.n_embd // model.config.n_head - start, end = args.head * head_dim, (args.head + 1) * head_dim - w_o = model.h[args.layer].attn.c_proj.weight.detach().cpu().numpy()[:, start:end] + start, end = head * head_dim, (head + 1) * head_dim + w_o = model.h[layer].attn.c_proj.weight.detach().cpu().numpy()[:, start:end] dir_head = w_o.T @ direction vals = values @ dir_head else: From 3c1d011ec466bf76166925a9802216856b9e6fb8 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:12:41 +0100 Subject: [PATCH 230/244] Document literature-aligned induction diagnostics --- docs/induction_cert_audit.md | 8 ++++++++ scripts/build_gpt2_induction_cert.py | 3 +++ scripts/diagnose_induction_heads.py | 2 ++ 3 files changed, 13 insertions(+) diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index f9cd7fe..918dc09 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -33,6 +33,14 @@ This is direct mechanistic evidence in the Transformer Circuits sense: it ties parameters (Q/K/V/O + LayerNorm) to certified bounds on attention and value contributions, but only for the specific inputs and direction supplied. +## Literature alignment + +We follow the standard induction-head diagnostic setup from the literature: +repeated-token sequences (a pattern repeated twice) and attention stripes that +look back by one period (`q -> q - period`). The diagnostic script +`scripts/diagnose_induction_heads.py` mirrors this setup, and the certificate +generator uses repeated patterns for its inputs. 
+ ## Preconditions and scope limits These proofs are sufficient for a **conditional** certification claim: diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index 861abe1..37653f7 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -9,6 +9,9 @@ Active induction positions are recorded as `active ` lines in the output. All sequence indices in the certificate are 1-based (literature convention). +The repeated-pattern inputs match the standard induction-head diagnostic setup +from the literature (pattern repeated twice). + Usage: python scripts/build_gpt2_induction_cert.py --output reports/gpt2_induction.cert \ --layer 1 --head 6 --seq 32 --pattern-length 16 \ diff --git a/scripts/diagnose_induction_heads.py b/scripts/diagnose_induction_heads.py index b35976f..7d17ebe 100644 --- a/scripts/diagnose_induction_heads.py +++ b/scripts/diagnose_induction_heads.py @@ -9,6 +9,8 @@ - run GPT-2 with output_attentions, - rank heads by induction stripe attention (q -> q - period), - rank heads by previous-token attention (q -> q - 1). + +Layer/head indices in the report are 1-based to match the literature. """ from __future__ import annotations From 79e94c01412549225fa0481a879bc99ebe0d03fa Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:18:00 +0100 Subject: [PATCH 231/244] Document non-vacuous induction certification --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index a942d99..f436dc8 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,17 @@ python scripts/build_gpt2_induction_cert.py \ Layer/head indices in the generator are 1-based to match the literature. 
+To certify a **non-vacuous** logit-diff lower bound, supply a direction: + +```bash +python scripts/build_gpt2_induction_cert.py \ + --output reports/gpt2_induction.cert \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ + --active-eps-max 1/2 \ + --direction-target 1268 --direction-negative 1796 +``` + Optional direction metadata: ``` @@ -73,6 +84,12 @@ Optional gates: --min-active --min-margin --max-eps --min-logit-diff ``` +Example non-vacuous check: + +```bash +lake exe nfp induction certify --cert reports/gpt2_induction.cert --min-logit-diff 1/10 +``` + ## File formats ### Induction-head certificate From 15a36143835ecc6fdb6d3c5e18055b699e763fc3 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:23:07 +0100 Subject: [PATCH 232/244] Add direction search for non-vacuous certs --- README.md | 12 +++ scripts/build_gpt2_induction_cert.py | 114 ++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index f436dc8..c05492f 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,18 @@ python scripts/build_gpt2_induction_cert.py \ --direction-target 1268 --direction-negative 1796 ``` +Or let the untrusted script search for a direction in a vocab slice: + +```bash +python scripts/build_gpt2_induction_cert.py \ + --output reports/gpt2_induction.cert \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ + --active-eps-max 1/2 \ + --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 \ + --direction-min-lb 1/10 +``` + Optional direction metadata: ``` diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index 37653f7..f68a1b1 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -22,6 +22,9 @@ Optionally, provide a logit-diff direction: --direction-target --direction-negative +Or ask the script to search for a direction (untrusted): + 
--search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 + Note: active positions are filtered by --active-eps-max and --min-margin. If none qualify, the script exits with an error. Layer/head indices are 1-based in the CLI and converted to 0-based internally. @@ -183,6 +186,63 @@ def write_value_range(path: Path, seq: int, values, decimals: int, for k, val in enumerate(vals_rat): f.write(f"val {k + 1} {rat_to_str(val)}\n") +def search_direction( + model, + values: np.ndarray, + layer: int, + head: int, + prev: np.ndarray, + active_positions: list[int], + eps_at: list[float], + vocab_min: int, + vocab_max: int, + max_candidates: int | None, + seed: int, +) -> tuple[int, int, float]: + if vocab_min < 0 or vocab_max <= vocab_min: + raise SystemExit("direction vocab range must satisfy 0 <= min < max") + cand_ids = list(range(vocab_min, vocab_max)) + if max_candidates is not None and max_candidates > 0 and len(cand_ids) > max_candidates: + rng = np.random.default_rng(seed) + cand_ids = rng.choice(cand_ids, size=max_candidates, replace=False).tolist() + + wte = model.wte.weight.detach().cpu().numpy() + if max(cand_ids) >= wte.shape[0]: + raise SystemExit("direction vocab range exceeds model vocab size") + head_dim = model.config.n_embd // model.config.n_head + start, end = head * head_dim, (head + 1) * head_dim + w_o = model.h[layer].attn.c_proj.weight.detach().cpu().numpy()[:, start:end] + + proj = wte[cand_ids] @ w_o # (m, head_dim) + vals_mat = values @ proj.T # (seq, m) + active_prev = [(q, int(prev[q])) for q in active_positions] + best_lb = None + best_pair = None + + for i in range(vals_mat.shape[1]): + vals_i = vals_mat[:, i] + for j in range(vals_mat.shape[1]): + if i == j: + continue + vals = vals_i - vals_mat[:, j] + lo = float(vals.min()) + hi = float(vals.max()) + gap = hi - lo + lb = None + for q, pq in active_prev: + cand = float(vals[pq]) - eps_at[q] * gap + if lb is None or cand < lb: + lb = cand + if lb is None: + continue + if best_lb 
is None or lb > best_lb: + best_lb = lb + best_pair = (cand_ids[i], cand_ids[j]) + + if best_pair is None or best_lb is None: + raise SystemExit("No direction candidates found") + return best_pair[0], best_pair[1], best_lb + def main() -> None: parser = argparse.ArgumentParser(description=__doc__) @@ -210,6 +270,16 @@ def main() -> None: help="Target token id for logit-diff direction (optional)") parser.add_argument("--direction-negative", type=int, help="Negative token id for logit-diff direction (optional)") + parser.add_argument("--search-direction", action="store_true", + help="Search for a direction within the vocab range (untrusted).") + parser.add_argument("--direction-vocab-min", type=int, default=1000, + help="Minimum vocab id for direction search (inclusive).") + parser.add_argument("--direction-vocab-max", type=int, default=2000, + help="Maximum vocab id for direction search (exclusive).") + parser.add_argument("--direction-max-candidates", type=int, default=0, + help="Limit number of direction candidates (0 = all in range).") + parser.add_argument("--direction-min-lb", default="0", + help="Minimum logit-diff lower bound to accept (default: 0).") args = parser.parse_args() if args.seq <= 0: @@ -296,18 +366,54 @@ def main() -> None: eps_at.append(max(max_other, deficit)) weight_bound_at = weights_rat + if args.search_direction and (args.direction_target is not None or args.direction_negative is not None): + raise SystemExit("search-direction is mutually exclusive with explicit direction tokens") + direction_target = None direction_negative = None + if args.search_direction: + eps_at_float = [float(eps_at[q]) for q in range(args.seq)] + max_candidates = args.direction_max_candidates + if max_candidates == 0: + max_candidates = None + direction_target, direction_negative, best_lb = search_direction( + model=model, + values=values, + layer=layer, + head=head, + prev=prev, + active_positions=active_positions, + eps_at=eps_at_float, + 
vocab_min=args.direction_vocab_min, + vocab_max=args.direction_vocab_max, + max_candidates=max_candidates, + seed=args.seed, + ) + try: + min_lb = float(Fraction(args.direction_min_lb)) + except (ValueError, ZeroDivisionError) as exc: + raise SystemExit("direction-min-lb must be a rational literal") from exc + if best_lb < min_lb: + raise SystemExit( + f"Best direction lower bound {best_lb:.6f} below minimum {min_lb:.6f}." + ) + print( + f"Selected direction: target={direction_target} negative={direction_negative} " + f"(estimated LB {best_lb:.6f})" + ) + if (args.direction_target is None) != (args.direction_negative is None): raise SystemExit("direction-target and direction-negative must be provided together") if args.direction_target is not None: + direction_target = args.direction_target + direction_negative = args.direction_negative + + if direction_target is not None: wte = model.wte.weight.detach().cpu().numpy() - if args.direction_target < 0 or args.direction_target >= wte.shape[0]: + if direction_target < 0 or direction_target >= wte.shape[0]: raise SystemExit("direction-target out of vocab range") - if args.direction_negative < 0 or args.direction_negative >= wte.shape[0]: + if direction_negative < 0 or direction_negative >= wte.shape[0]: raise SystemExit("direction-negative out of vocab range") - direction_target = args.direction_target - direction_negative = args.direction_negative direction = wte[direction_target] - wte[direction_negative] head_dim = model.config.n_embd // model.config.n_head start, end = head * head_dim, (head + 1) * head_dim From d69f3f0de416a8ee5d98008eacacd2c410668db7 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:31:08 +0100 Subject: [PATCH 233/244] Remove obsolete script and clarify untrusted search --- CLAIMS.md | 4 +- README.md | 3 ++ scripts/variance_lower_bound_check.py | 53 --------------------------- 3 files changed, 5 insertions(+), 55 deletions(-) delete mode 100644 
scripts/variance_lower_bound_check.py diff --git a/CLAIMS.md b/CLAIMS.md index 570c0d0..21d062a 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -24,8 +24,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri ## Untrusted / heuristic -- Python helpers that generate explicit induction-head certificates from GPT-2 weights: - `scripts/build_gpt2_induction_cert.py`. +- Python helpers that generate explicit induction-head certificates from GPT-2 weights, + including optional direction search: `scripts/build_gpt2_induction_cert.py`. - Any choice of prompts, directions, or candidate heads used by certificate generators. ## Not yet proven diff --git a/README.md b/README.md index c05492f..95d4871 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,9 @@ python scripts/build_gpt2_induction_cert.py \ --direction-min-lb 1/10 ``` +Direction search is **untrusted witness generation**; the Lean CLI only verifies the resulting +explicit certificate. + Optional direction metadata: ``` diff --git a/scripts/variance_lower_bound_check.py b/scripts/variance_lower_bound_check.py deleted file mode 100644 index 5acd5c6..0000000 --- a/scripts/variance_lower_bound_check.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -from fractions import Fraction - - -def population_variance(xs: list[Fraction]) -> Fraction: - n = len(xs) - if n == 0: - return Fraction(0) - mean = sum(xs, Fraction(0)) / n - return sum((x - mean) ** 2 for x in xs) / n - - -def variance_bound_range(delta: Fraction, n: int) -> Fraction: - if n <= 0 or delta <= 0: - return Fraction(0) - return (delta * delta) / (2 * n) - - -def main() -> None: - xs = [Fraction(0), Fraction(1, 2), Fraction(1)] - n = len(xs) - delta = max(xs) - min(xs) - var = population_variance(xs) - old_bound = Fraction(n - 1, n * n) * (delta**2) - new_bound = variance_bound_range(delta, n) - print(f"x = {xs}") - print(f"n = {n}, delta = {delta}") - print(f"variance = {var} = 
{float(var):.6f}") - print(f"old_bound = {old_bound} = {float(old_bound):.6f}") - print(f"new_bound = {new_bound} = {float(new_bound):.6f}") - print(f"old_bound <= variance? {old_bound <= var}") - print(f"new_bound <= variance? {new_bound <= var}") - - # Quick brute-force sanity check on a small grid. - grid = [Fraction(0), Fraction(1, 2), Fraction(1)] - for a in grid: - for b in grid: - for c in grid: - xs = [a, b, c] - delta = max(xs) - min(xs) - var = population_variance(xs) - bound = variance_bound_range(delta, len(xs)) - if bound > var: - raise SystemExit( - f"violation: xs={xs} delta={delta} var={var} bound={bound}" - ) - print("grid_check=OK") - - -if __name__ == "__main__": - main() From 3c2b20a46eae7088f903c71af27adbc3040de782 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:34:48 +0100 Subject: [PATCH 234/244] Clarify untrusted direction search in limitations --- SOUNDNESS_LIMITATIONS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 0377f38..04d8b00 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -12,6 +12,8 @@ It is intentionally brief and focused on the soundness boundary. yet imply full model behavior. - Direction metadata (`direction-target`, `direction-negative`) is untrusted and assumes that the unembedding columns represent token logits. +- Any direction search performed by Python helpers is untrusted witness generation; only the + resulting explicit certificate is checked by the Lean CLI. - The active set is user-supplied (or defaulted by the parser); bounds only hold for `q ∈ active`. - Performance: checking large certificates can be expensive for long sequences. 
From 0643b4067484123540ad28c4b057d321d8a4c229 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:41:27 +0100 Subject: [PATCH 235/244] Add direction search report output --- README.md | 3 +- scripts/build_gpt2_induction_cert.py | 50 ++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 95d4871..92794ac 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,8 @@ python scripts/build_gpt2_induction_cert.py \ --random-pattern --seed 0 \ --active-eps-max 1/2 \ --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 \ - --direction-min-lb 1/10 + --direction-min-lb 1/10 \ + --direction-report-out reports/direction_report.txt --direction-topk 10 ``` Direction search is **untrusted witness generation**; the Lean CLI only verifies the resulting diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index f68a1b1..b7c8f54 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -24,6 +24,7 @@ Or ask the script to search for a direction (untrusted): --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 + --direction-report-out reports/direction_report.txt --direction-topk 10 Note: active positions are filtered by --active-eps-max and --min-margin. If none qualify, the script exits with an error. 
@@ -197,8 +198,11 @@ def search_direction( vocab_min: int, vocab_max: int, max_candidates: int | None, + topk: int, seed: int, -) -> tuple[int, int, float]: +) -> tuple[int, int, float, list[tuple[float, int, int]]]: + import heapq + if vocab_min < 0 or vocab_max <= vocab_min: raise SystemExit("direction vocab range must satisfy 0 <= min < max") cand_ids = list(range(vocab_min, vocab_max)) @@ -218,6 +222,7 @@ def search_direction( active_prev = [(q, int(prev[q])) for q in active_positions] best_lb = None best_pair = None + topk_entries: list[tuple[float, int, int]] = [] for i in range(vals_mat.shape[1]): vals_i = vals_mat[:, i] @@ -238,10 +243,34 @@ def search_direction( if best_lb is None or lb > best_lb: best_lb = lb best_pair = (cand_ids[i], cand_ids[j]) + if topk > 0: + # Preserve the best few candidates for reporting. + if len(topk_entries) < topk: + heapq.heappush(topk_entries, (lb, cand_ids[i], cand_ids[j])) + else: + if lb > topk_entries[0][0]: + heapq.heapreplace(topk_entries, (lb, cand_ids[i], cand_ids[j])) if best_pair is None or best_lb is None: raise SystemExit("No direction candidates found") - return best_pair[0], best_pair[1], best_lb + topk_sorted = sorted(topk_entries, key=lambda x: x[0], reverse=True) + return best_pair[0], best_pair[1], best_lb, topk_sorted + + +def write_direction_report( + path: Path, + entries: list[tuple[float, int, int]], + vocab_min: int, + vocab_max: int, + seed: int, +) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="ascii") as f: + f.write("direction_report\n") + f.write(f"vocab_min={vocab_min} vocab_max={vocab_max} seed={seed}\n") + f.write("rank\tlb\ttarget\tnegative\n") + for rank, (lb, target, negative) in enumerate(entries, start=1): + f.write(f"{rank}\t{lb:.6f}\t{target}\t{negative}\n") def main() -> None: @@ -280,6 +309,10 @@ def main() -> None: help="Limit number of direction candidates (0 = all in range).") parser.add_argument("--direction-min-lb", default="0", 
help="Minimum logit-diff lower bound to accept (default: 0).") + parser.add_argument("--direction-report-out", type=Path, + help="Optional path to write a ranked direction report.") + parser.add_argument("--direction-topk", type=int, default=10, + help="How many top directions to report (default: 10).") args = parser.parse_args() if args.seq <= 0: @@ -368,6 +401,8 @@ def main() -> None: if args.search_direction and (args.direction_target is not None or args.direction_negative is not None): raise SystemExit("search-direction is mutually exclusive with explicit direction tokens") + if args.direction_report_out is not None and not args.search_direction: + raise SystemExit("direction-report-out requires --search-direction") direction_target = None direction_negative = None @@ -376,7 +411,7 @@ def main() -> None: max_candidates = args.direction_max_candidates if max_candidates == 0: max_candidates = None - direction_target, direction_negative, best_lb = search_direction( + direction_target, direction_negative, best_lb, topk_entries = search_direction( model=model, values=values, layer=layer, @@ -387,8 +422,17 @@ def main() -> None: vocab_min=args.direction_vocab_min, vocab_max=args.direction_vocab_max, max_candidates=max_candidates, + topk=args.direction_topk, seed=args.seed, ) + if args.direction_report_out is not None: + write_direction_report( + args.direction_report_out, + topk_entries, + args.direction_vocab_min, + args.direction_vocab_max, + args.seed, + ) try: min_lb = float(Fraction(args.direction_min_lb)) except (ValueError, ZeroDivisionError) as exc: From 27f74d5cbaf4fe1d25f387c3e3761e75f72229e9 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 10:49:28 +0100 Subject: [PATCH 236/244] Update README for literature alignment and reports --- README.md | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 92794ac..6f4257f 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,8 @@ High-level 
layout: The current prototype checks **explicit induction-head certificates**. Certificates are produced by **untrusted** Python scripts and verified by the Lean CLI; no model forward pass runs in Lean. +The input setup follows the standard literature diagnostic: repeated token patterns (pattern +repeated twice) and attention stripes that look back by one period. ### Build a head certificate (untrusted) @@ -80,7 +82,8 @@ python scripts/build_gpt2_induction_cert.py \ ``` Direction search is **untrusted witness generation**; the Lean CLI only verifies the resulting -explicit certificate. +explicit certificate. The direction report lists the top-ranked candidates by estimated lower +bound so you can pick a stable non-vacuous direction. Optional direction metadata: @@ -134,6 +137,17 @@ All sequence indices (`q`, `k`) are **1-based** (literature convention). Directi `direction-*` lines are optional metadata; if present, both must appear. If no `active` lines appear, the checker defaults to all non-initial queries (indices 2.. in 1-based indexing). +### Direction report (untrusted) + +``` +direction_report +vocab_min= vocab_max= seed= +rank\tlb\ttarget\tnegative +``` + +This file is an **untrusted helper artifact**; it only ranks candidate directions and does not +change what the Lean checker accepts. + ## Soundness boundary - Untrusted scripts may use floating-point numerics to generate candidate certificates. @@ -144,7 +158,7 @@ For known gaps, see `SOUNDNESS_LIMITATIONS.md`. ## Requirements - **Lean 4** (pinned in `lean-toolchain`) and **Lake**. -- Optional: **Python** for helper scripts (`scripts/`). +- Optional: **Python** for helper scripts (`scripts/`), plus `torch`, `transformers`, and `numpy`. 
## Contributing From 260f406bd34216de6f5163ce57216ee7169798b1 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 11:06:46 +0100 Subject: [PATCH 237/244] Verify cert prev/active against optional tokens --- CLAIMS.md | 3 +- Nfp/Cli.lean | 4 +- Nfp/IO/InductionHead/Cert.lean | 57 +++++++++++++++++- Nfp/IO/InductionHead/Tokens.lean | 87 ++++++++++++++++++++++++++++ README.md | 19 +++++- SOUNDNESS_LIMITATIONS.md | 3 +- docs/induction_cert_audit.md | 4 ++ scripts/build_gpt2_induction_cert.py | 17 ++++++ 8 files changed, 188 insertions(+), 6 deletions(-) create mode 100644 Nfp/IO/InductionHead/Tokens.lean diff --git a/CLAIMS.md b/CLAIMS.md index 21d062a..cbffcfb 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -20,7 +20,8 @@ what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewri ## Soundly checked by the trusted CLI - `nfp induction certify` verifies explicit induction-head certificates from a single cert - file, optionally enforcing minimum `active`, `margin`, `eps`, and logit-diff gates. + file, optionally enforcing minimum `active`, `margin`, `eps`, and logit-diff gates, and + optionally checking `prev`/`active` against a supplied token list. ## Untrusted / heuristic diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean index b5df7e5..3016af6 100644 --- a/Nfp/Cli.lean +++ b/Nfp/Cli.lean @@ -35,6 +35,7 @@ private def runInductionCertifySimple (p : Parsed) : IO UInt32 := do let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) let minMarginStr? := (p.flag? "min-margin").map (·.as! String) let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let tokensPath? := (p.flag? "tokens").map (·.as! String) let fail (msg : String) : IO UInt32 := do IO.eprintln s!"error: {msg}" return 2 @@ -42,7 +43,7 @@ private def runInductionCertifySimple (p : Parsed) : IO UInt32 := do | none => fail "provide --cert" | some certPath => IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? - minMarginStr? maxEpsStr? + minMarginStr? maxEpsStr? 
tokensPath? /-- `nfp induction certify` subcommand (streamlined). -/ def inductionCertifySimpleCmd : Cmd := `[Cli| @@ -50,6 +51,7 @@ def inductionCertifySimpleCmd : Cmd := `[Cli| "Check induction head certificates from an explicit cert." FLAGS: cert : String; "Path to the induction head certificate file." + tokens : String; "Optional path to a token list to verify prev/active." "min-active" : Nat; "Optional minimum number of active queries required \ (default: max 1 (seq/8))." "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ diff --git a/Nfp/IO/InductionHead/Cert.lean b/Nfp/IO/InductionHead/Cert.lean index e6b12f2..283cf3e 100644 --- a/Nfp/IO/InductionHead/Cert.lean +++ b/Nfp/IO/InductionHead/Cert.lean @@ -5,8 +5,10 @@ module public import Mathlib.Data.Finset.Insert public import Nfp.Circuit.Cert.InductionHead public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.IO.InductionHead.Tokens public import Nfp.IO.Parse.Basic public import Nfp.IO.Util +public import Nfp.Model.InductionPrompt /-! Untrusted parsing and checking for explicit induction-head certificates. @@ -321,13 +323,38 @@ def loadInductionHeadCert (path : System.FilePath) : let data ← IO.FS.readFile path return parseInductionHeadCert data +/-- Parse a token list payload. -/ +def parseInductionHeadTokens (input : String) : + Except String (Sigma fun seq => Fin seq → Nat) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← InductionHeadTokens.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let st0 : InductionHeadTokens.ParseState seq := InductionHeadTokens.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => InductionHeadTokens.parseLine st t) st0 + let tokensFun ← InductionHeadTokens.finalizeState st + return ⟨seq, tokensFun⟩ + +/-- Load a token list from disk. 
-/ +def loadInductionHeadTokens (path : System.FilePath) : + IO (Except String (Sigma fun seq => Fin seq → Nat)) := do + let data ← IO.FS.readFile path + return parseInductionHeadTokens data + private def ratToString (x : Rat) : String := toString x /-- Check an explicit induction-head certificate from disk. -/ def runInductionHeadCertCheck (certPath : System.FilePath) (minActive? : Option Nat) (minLogitDiffStr? : Option String) - (minMarginStr? : Option String) (maxEpsStr? : Option String) : IO UInt32 := do + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (tokensPath? : Option String) : IO UInt32 := do let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? let minMargin?E := parseRatOpt "min-margin" minMarginStr? let maxEps?E := parseRatOpt "max-eps" maxEpsStr? @@ -378,6 +405,34 @@ def runInductionHeadCertCheck (certPath : System.FilePath) s!"error: eps {ratToString cert.eps} \ above maximum {ratToString maxEps}" return 2 + if let some tokensPath := tokensPath? 
then + let tokensParsed ← loadInductionHeadTokens tokensPath + match tokensParsed with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok ⟨seqTokens, tokens⟩ => + if hseq : seqTokens = seq then + let tokens' : Fin seq → Nat := by + simpa [hseq] using tokens + let activeTokens := Model.activeOfTokens (seq := seq) tokens' + if !decide (cert.active ⊆ activeTokens) then + IO.eprintln "error: active set not contained in token repeats" + return 2 + let prevTokens := Model.prevOfTokens (seq := seq) tokens' + let prevOk := + (List.finRange seq).all (fun q => + if decide (q ∈ cert.active) then + decide (prevTokens q = cert.prev q) + else + true) + if !prevOk then + IO.eprintln "error: prev map does not match tokens on active queries" + return 2 + else + IO.eprintln + s!"error: tokens seq {seqTokens} does not match cert seq {seq}" + return 2 let effectiveMinLogitDiff := match minLogitDiff?, cert.values.direction with | some v, _ => some v diff --git a/Nfp/IO/InductionHead/Tokens.lean b/Nfp/IO/InductionHead/Tokens.lean new file mode 100644 index 0000000..827e94d --- /dev/null +++ b/Nfp/IO/InductionHead/Tokens.lean @@ -0,0 +1,87 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Fintype.Basic +public import Nfp.IO.Parse.Basic + +/-! +Untrusted parsing helpers for optional induction-head token lists. +-/ + +public section + +namespace Nfp + +namespace IO + +open Nfp.IO.Parse + +namespace InductionHeadTokens + +/-- State for parsing token lists. -/ +structure ParseState (seq : Nat) where + /-- Optional per-position tokens. -/ + tokens : Array (Option Nat) + +/-- Initialize a token parse state. 
-/ +def initState (seq : Nat) : ParseState seq := + { tokens := Array.replicate seq none } + +private def toIndex1 {seq : Nat} (label : String) (idx : Nat) : Except String (Fin seq) := do + if idx = 0 then + throw s!"{label} index must be >= 1" + let idx' := idx - 1 + if h : idx' < seq then + return ⟨idx', h⟩ + else + throw s!"{label} index out of range: {idx}" + +private def setToken {seq : Nat} (st : ParseState seq) (q tok : Nat) : + Except String (ParseState seq) := do + let qFin ← toIndex1 (seq := seq) "q" q + match st.tokens[qFin.1]! with + | some _ => + throw s!"duplicate token entry for q={q}" + | none => + let tokens' := st.tokens.set! qFin.1 (some tok) + return { st with tokens := tokens' } + +/-- Parse a tokenized line into the token parse state. -/ +def parseLine {seq : Nat} (st : ParseState seq) (tokens : List String) : + Except String (ParseState seq) := do + match tokens with + | ["token", q, tok] => + setToken st (← parseNat q) (← parseNat tok) + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +/-- Extract the `seq` header from tokenized lines. -/ +def parseSeq (tokens : List (List String)) : Except String Nat := do + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + match seq? with + | some v => pure v + | none => throw "missing seq entry" + +/-- Finalize a token parse state into a total token map. 
-/ +def finalizeState {seq : Nat} (st : ParseState seq) : + Except String (Fin seq → Nat) := do + if !st.tokens.all Option.isSome then + throw "missing token entries" + let tokensFun : Fin seq → Nat := fun q => + (st.tokens[q.1]!).getD 0 + pure tokensFun + +end InductionHeadTokens + +end IO + +end Nfp diff --git a/README.md b/README.md index 6f4257f..5055a08 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,8 @@ python scripts/build_gpt2_induction_cert.py \ --active-eps-max 1/2 \ --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 \ --direction-min-lb 1/10 \ - --direction-report-out reports/direction_report.txt --direction-topk 10 + --direction-report-out reports/direction_report.txt --direction-topk 10 \ + --tokens-out reports/gpt2_induction.tokens ``` Direction search is **untrusted witness generation**; the Lean CLI only verifies the resulting @@ -100,9 +101,12 @@ lake exe nfp induction certify --cert reports/gpt2_induction.cert Optional gates: ``` ---min-active --min-margin --max-eps --min-logit-diff +--min-active --min-margin --max-eps --min-logit-diff --tokens ``` +If `--tokens` is provided, the CLI verifies that the certificate's `prev` and `active` +match the token-sequence semantics for repeated tokens (previous occurrence). + Example non-vacuous check: ```bash @@ -148,6 +152,17 @@ rank\tlb\ttarget\tnegative This file is an **untrusted helper artifact**; it only ranks candidate directions and does not change what the Lean checker accepts. +### Token list (untrusted) + +``` +seq +token +``` + +This file is an **untrusted helper artifact** used to check that `prev` and `active` match the +token sequence (previous-occurrence semantics) when `--tokens` is supplied to the CLI. Indices +are 1-based. + ## Soundness boundary - Untrusted scripts may use floating-point numerics to generate candidate certificates. 
diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 04d8b00..04bb374 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -15,7 +15,8 @@ It is intentionally brief and focused on the soundness boundary. - Any direction search performed by Python helpers is untrusted witness generation; only the resulting explicit certificate is checked by the Lean CLI. - The active set is user-supplied (or defaulted by the parser); bounds only hold for - `q ∈ active`. + `q ∈ active`. You can optionally verify `prev`/`active` against a token list via + `nfp induction certify --tokens ...`. - Performance: checking large certificates can be expensive for long sequences. ## Remaining work diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 918dc09..37c8d70 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -54,6 +54,10 @@ Key assumptions and limitations: - The active set can be strict; bounds only hold for `q ∈ active`, not all positions. - The direction metadata assumes the unembedding columns encode the model’s logit map. +Optional safeguard: +- If a token list is supplied to the CLI (`--tokens`), the checker verifies that `prev` + and `active` match the previous-occurrence semantics for that sequence. + ## Conclusion Yes—**within the formal scope** of the current definitions, the proofs are diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py index b7c8f54..d4e926d 100644 --- a/scripts/build_gpt2_induction_cert.py +++ b/scripts/build_gpt2_induction_cert.py @@ -26,6 +26,9 @@ --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 --direction-report-out reports/direction_report.txt --direction-topk 10 +Optional token dump (for Lean-side prev/active verification): + --tokens-out reports/gpt2_induction.tokens + Note: active positions are filtered by --active-eps-max and --min-margin. If none qualify, the script exits with an error. 
Layer/head indices are 1-based in the CLI and converted to 0-based internally. @@ -187,6 +190,14 @@ def write_value_range(path: Path, seq: int, values, decimals: int, for k, val in enumerate(vals_rat): f.write(f"val {k + 1} {rat_to_str(val)}\n") + +def write_tokens(path: Path, tokens: np.ndarray) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="ascii") as f: + f.write(f"seq {len(tokens)}\n") + for idx, tok in enumerate(tokens.tolist(), start=1): + f.write(f"token {idx} {tok}\n") + def search_direction( model, values: np.ndarray, @@ -291,6 +302,7 @@ def main() -> None: parser.add_argument("--values-out", help="Optional path for a value-range certificate") parser.add_argument("--value-dim", type=int, default=0, help="Value dimension index for the value-range certificate") + parser.add_argument("--tokens-out", help="Optional path to write the token list") parser.add_argument("--active-eps-max", default="1/2", help="Maximum eps to include an active position (default: 1/2).") parser.add_argument("--min-margin", default="0", @@ -501,12 +513,17 @@ def main() -> None: write_value_range(values_path, args.seq, vals, args.decimals, direction_target=direction_target, direction_negative=direction_negative) + if args.tokens_out: + tokens_path = Path(args.tokens_out) + write_tokens(tokens_path, tokens) print(f"Wrote certificate to {output_path}") if args.scores_out: print(f"Wrote scores dump to {scores_path}") if args.values_out: print(f"Wrote value-range certificate to {values_path}") + if args.tokens_out: + print(f"Wrote token list to {tokens_path}") if __name__ == "__main__": From 382e98f0953b952c78ad006362e253d73412e114 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 20:04:09 +0100 Subject: [PATCH 238/244] Add user-facing demo guide --- README.md | 2 ++ docs/demo.md | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 docs/demo.md diff --git a/README.md 
b/README.md index 5055a08..85d4e32 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,8 @@ by **untrusted** Python scripts and verified by the Lean CLI; no model forward p The input setup follows the standard literature diagnostic: repeated token patterns (pattern repeated twice) and attention stripes that look back by one period. +For a step-by-step walkthrough, see `docs/demo.md`. + ### Build a head certificate (untrusted) ```bash diff --git a/docs/demo.md b/docs/demo.md new file mode 100644 index 0000000..bc1400c --- /dev/null +++ b/docs/demo.md @@ -0,0 +1,61 @@ +# NFP User-Facing Demo (Induction Head Certification) + +This demo shows a full, reproducible path from **untrusted certificate generation** +to **trusted Lean verification** for an induction head. + +## 0. Prerequisites + +- Lean 4 / Lake (pinned in `lean-toolchain`) +- Python with: `numpy`, `torch`, `transformers` + +## 1. Build + +```bash +lake build -q --wfail +lake build nfp -q --wfail +``` + +## 2. Generate an explicit certificate (untrusted) + +This uses a repeated pattern (period repeated twice) and searches for a +non-vacuous logit-diff direction. + +```bash +python scripts/build_gpt2_induction_cert.py \ + --output reports/gpt2_induction.cert \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ + --active-eps-max 1/2 \ + --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 \ + --direction-min-lb 1/10 \ + --direction-report-out reports/direction_report.txt --direction-topk 10 \ + --tokens-out reports/gpt2_induction.tokens +``` + +Expected output includes: +- a certificate (`reports/gpt2_induction.cert`) +- a ranked direction report (`reports/direction_report.txt`) +- a token list (`reports/gpt2_induction.tokens`) + +## 3. Verify with the Lean checker (trusted) + +This enforces a **non-vacuous** logit-diff lower bound and checks `prev/active` +against the token list. 
+ +```bash +lake exe nfp induction certify \ + --cert reports/gpt2_induction.cert \ + --min-logit-diff 1/10 \ + --tokens reports/gpt2_induction.tokens +``` + +Expected output (example): + +``` +ok: induction head certificate checked (seq=32, active=15, margin=..., eps=..., logitDiffLB=...) +``` + +## Notes + +- Everything in `scripts/` is **untrusted witness generation**. +- The Lean CLI **only verifies** explicit certificates and token semantics. From 3627f0f8606805bc89e83ecf1a9b65cb7f02a258 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 20:07:39 +0100 Subject: [PATCH 239/244] Add certificate usefulness note --- README.md | 2 ++ docs/cert_usefulness.md | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 docs/cert_usefulness.md diff --git a/README.md b/README.md index 85d4e32..2d70ea4 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,8 @@ The input setup follows the standard literature diagnostic: repeated token patte repeated twice) and attention stripes that look back by one period. For a step-by-step walkthrough, see `docs/demo.md`. +For a careful statement of what certificates do and do not claim, see +`docs/cert_usefulness.md`. ### Build a head certificate (untrusted) diff --git a/docs/cert_usefulness.md b/docs/cert_usefulness.md new file mode 100644 index 0000000..700d63f --- /dev/null +++ b/docs/cert_usefulness.md @@ -0,0 +1,39 @@ +# What Induction-Head Certificates Do (and Do Not) Claim + +This document summarizes the **useful, limited guarantees** provided by an induction‑head +certificate, without overselling. + +## What the certificate **does** guarantee + +If the Lean checker accepts a certificate, then: +- **Softmax‑margin bounds** hold on the specified active queries (the `prev` score dominates other + keys by the declared margin). +- **One‑hot‑style weight bounds** hold on those queries (non‑`prev` weights are bounded by `ε`). 
+- **Value‑interval bounds** hold for the supplied value ranges. +- **Logit‑diff lower bound** holds for the supplied direction (if direction metadata is present + and the checker is run with `--min-logit-diff`). + +These are **formal, exact** statements about the explicit certificate data. + +## Why this is useful + +- **Quantitative guarantees:** the bounds are numeric and can be gated (e.g., require a strictly + positive logit‑diff lower bound). +- **Reproducibility:** certificates are explicit artifacts that can be re‑checked later. +- **Comparability:** bounds provide a principled way to compare heads or settings. +- **Soundness boundary clarity:** generation is untrusted, verification is trusted. + +## What the certificate **does not** guarantee + +- **No full‑model claim:** this is a head‑level certificate; it does not imply end‑to‑end model + behavior. +- **Input‑specific:** guarantees apply only to the specified inputs / token patterns. +- **Untrusted semantics:** unless you pass `--tokens`, the `prev` and `active` sets are not + verified against a token sequence. +- **Direction is untrusted:** `direction-target` / `direction-negative` are supplied metadata. + +## Optional token verification + +If you provide a token list to the CLI (`--tokens`), the checker verifies that `prev` and `active` +match **previous‑occurrence semantics** for that token sequence. This strengthens the link to the +induction‑head diagnostic while keeping the trusted checker lightweight. From 656838dc2778c1e3da2a1bab8a1047d2b0f89ddd Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 20:12:50 +0100 Subject: [PATCH 240/244] Add literature references --- README.md | 5 +++++ docs/induction_cert_audit.md | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/README.md b/README.md index 2d70ea4..f368ed7 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,11 @@ For known gaps, see `SOUNDNESS_LIMITATIONS.md`. - **Lean 4** (pinned in `lean-toolchain`) and **Lake**. 
- Optional: **Python** for helper scripts (`scripts/`), plus `torch`, `transformers`, and `numpy`. +## References + +- Elhage et al., “A Mathematical Framework for Transformer Circuits.” citeturn0search0 +- Olsson et al., “In-context Learning and Induction Heads.” citeturn0academia12 + ## Contributing Please follow the project rules in `AGENTS.md` (no `sorry`, no linter disables, total soundness in diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 37c8d70..581d79f 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -69,3 +69,8 @@ along a specified direction, conditional on an explicit certificate. - Add a verified extraction pipeline from model weights to explicit certificates. - Prove that `prev`, `active`, and `direction` correspond to token-level semantics. + +## References + +- Elhage et al., “A Mathematical Framework for Transformer Circuits.” citeturn0search0 +- Olsson et al., “In-context Learning and Induction Heads.” citeturn0academia12 From 02384d7b8acb095384a5636c14174abeb9f9d451 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 20:16:22 +0100 Subject: [PATCH 241/244] Fix literature references --- README.md | 2 +- docs/induction_cert_audit.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f368ed7..f590e10 100644 --- a/README.md +++ b/README.md @@ -181,7 +181,7 @@ For known gaps, see `SOUNDNESS_LIMITATIONS.md`. 
## References -- Elhage et al., “A Mathematical Framework for Transformer Circuits.” citeturn0search0 +- Elhage et al., “A Mathematical Framework for Transformer Circuits.” citeturn0search1 - Olsson et al., “In-context Learning and Induction Heads.” citeturn0academia12 ## Contributing diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 581d79f..18b62ff 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -72,5 +72,5 @@ along a specified direction, conditional on an explicit certificate. ## References -- Elhage et al., “A Mathematical Framework for Transformer Circuits.” citeturn0search0 +- Elhage et al., “A Mathematical Framework for Transformer Circuits.” citeturn0search1 - Olsson et al., “In-context Learning and Induction Heads.” citeturn0academia12 From ebf4e0faafc8b7ec1210e97c9ea4448cbb5e4b6d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 20:20:08 +0100 Subject: [PATCH 242/244] Update reference links --- README.md | 6 ++++-- docs/induction_cert_audit.md | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index f590e10..79d0b01 100644 --- a/README.md +++ b/README.md @@ -181,8 +181,10 @@ For known gaps, see `SOUNDNESS_LIMITATIONS.md`. 
## References -- Elhage et al., “A Mathematical Framework for Transformer Circuits.” citeturn0search1 -- Olsson et al., “In-context Learning and Induction Heads.” citeturn0academia12 +- Elhage et al., “A Mathematical Framework for Transformer Circuits.” + Link: `https://transformer-circuits.pub/2021/framework/index.html` citeturn0search1 +- Olsson et al., “In-context Learning and Induction Heads.” + Link: `https://arxiv.org/abs/2209.11895` citeturn0academia12 ## Contributing diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index 18b62ff..cefe075 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -72,5 +72,7 @@ along a specified direction, conditional on an explicit certificate. ## References -- Elhage et al., “A Mathematical Framework for Transformer Circuits.” citeturn0search1 -- Olsson et al., “In-context Learning and Induction Heads.” citeturn0academia12 +- Elhage et al., “A Mathematical Framework for Transformer Circuits.” + Link: `https://transformer-circuits.pub/2021/framework/index.html` citeturn0search1 +- Olsson et al., “In-context Learning and Induction Heads.” + Link: `https://arxiv.org/abs/2209.11895` citeturn0academia12 From 397fc91f7124568cf1ac3e02ff7350a9f6a44097 Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sat, 17 Jan 2026 20:23:08 +0100 Subject: [PATCH 243/244] Remove citation artifacts --- README.md | 4 ++-- docs/induction_cert_audit.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 79d0b01..8da1edd 100644 --- a/README.md +++ b/README.md @@ -182,9 +182,9 @@ For known gaps, see `SOUNDNESS_LIMITATIONS.md`. 
## References - Elhage et al., “A Mathematical Framework for Transformer Circuits.” - Link: `https://transformer-circuits.pub/2021/framework/index.html` citeturn0search1 + Link: `https://transformer-circuits.pub/2021/framework/index.html` - Olsson et al., “In-context Learning and Induction Heads.” - Link: `https://arxiv.org/abs/2209.11895` citeturn0academia12 + Link: `https://arxiv.org/abs/2209.11895` ## Contributing diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md index cefe075..eafb351 100644 --- a/docs/induction_cert_audit.md +++ b/docs/induction_cert_audit.md @@ -73,6 +73,6 @@ along a specified direction, conditional on an explicit certificate. ## References - Elhage et al., “A Mathematical Framework for Transformer Circuits.” - Link: `https://transformer-circuits.pub/2021/framework/index.html` citeturn0search1 + Link: `https://transformer-circuits.pub/2021/framework/index.html` - Olsson et al., “In-context Learning and Induction Heads.” - Link: `https://arxiv.org/abs/2209.11895` citeturn0academia12 + Link: `https://arxiv.org/abs/2209.11895` From 0c3a4735d0e6231e75a6e67c24dd31634a99fa1d Mon Sep 17 00:00:00 2001 From: TheDarkchip Date: Sun, 18 Jan 2026 23:24:34 +0100 Subject: [PATCH 244/244] Trim CI and fix parse cert docstrings --- .github/workflows/ci.yml | 30 ---------------------------- Nfp/IO/Parse/SoftmaxMargin/Cert.lean | 2 +- Nfp/IO/Parse/ValueRange/Cert.lean | 2 +- 3 files changed, 2 insertions(+), 32 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 364b4e3..06d8d82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,33 +33,3 @@ jobs: echo "Quiet build failed; retrying verbose..." 
lake build nfp -v --no-ansi --wfail } - - - name: SOUND cache regression test (tiny fixture) - run: | - set -euo pipefail - lake exe nfp sound_cache_check --scalePow10 9 --maxTokens 0 tests/fixtures/tiny_sound_model.nfpt - - - name: Python setup - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - - name: Cache HF model files - uses: actions/cache@v4 - with: - path: | - ~/.cache/huggingface - ~/.cache/torch - key: ${{ runner.os }}-hf-tiny-gpt2-v1 - - - name: Install PyTorch + Transformers (CPU) - run: | - set -euo pipefail - python -m pip install --upgrade pip - python -m pip install --index-url https://download.pytorch.org/whl/cpu torch - python -m pip install transformers numpy - - - name: Sanity check (Lean vs PyTorch forward step) - run: | - set -euo pipefail - python scripts/ci_sanity_forward_step.py --model sshleifer/tiny-gpt2 --seqLen 8 --layer 0 --pos 0 --take 16 diff --git a/Nfp/IO/Parse/SoftmaxMargin/Cert.lean b/Nfp/IO/Parse/SoftmaxMargin/Cert.lean index 94892f0..014c810 100644 --- a/Nfp/IO/Parse/SoftmaxMargin/Cert.lean +++ b/Nfp/IO/Parse/SoftmaxMargin/Cert.lean @@ -6,7 +6,7 @@ public import Nfp.Circuit.Cert.SoftmaxMargin public import Nfp.IO.Parse.SoftmaxMargin.Shared /-! -Parse parsing helpers for softmax-margin certificates. +Parsing helpers for softmax-margin certificates. -/ public section diff --git a/Nfp/IO/Parse/ValueRange/Cert.lean b/Nfp/IO/Parse/ValueRange/Cert.lean index e13d6e1..3e7ee8d 100644 --- a/Nfp/IO/Parse/ValueRange/Cert.lean +++ b/Nfp/IO/Parse/ValueRange/Cert.lean @@ -6,7 +6,7 @@ public import Nfp.Circuit.Cert.ValueRange public import Nfp.IO.Parse.ValueRange.Shared /-! -Parse parsing helpers for value-range certificates. +Parsing helpers for value-range certificates. -/ public section